mirror of
https://github.com/gradio-app/gradio.git
synced 2025-01-30 11:00:11 +08:00
Max parallel threads (#1460)
* Set max parallel threads to 100.
  I could not find a better solution than monkey-patching `anyio.to_thread.run_sync` to use a custom `CapacityLimiter`. I did try to obtain the predictions inside a context manager instead:
  ```
  async with self.limiter:
      predictions = await run_in_threadpool(block_fn.fn, *processed_input)
  ```
  However, that approach didn't seem to work.
* Make max parallel threads configurable.
  `max_threads` can now be used in `launch` to specify the desired number of parallel threads supported.
* Fix import order.
* Update gradio/blocks.py
* Replace `run_in_threadpool` with `run_sync`.
  We create the capacity limiter on launch, and invoke `anyio.to_thread.run_sync` directly instead of going through `run_in_threadpool`. This allows for some code simplification: we no longer need to patch `anyio.to_thread.run_sync`.
* max_parallel_threads - resolve conflicts
* max_parallel_threads - solve error
* max_parallel_threads - solve async error
Co-authored-by: Pedro Cuenca <pedro@latenitesoft.com>
This commit is contained in:
parent
282748b4af
commit
267ba1e45b
@ -7,11 +7,11 @@ import os
|
||||
import random
|
||||
import sys
|
||||
import time
|
||||
import warnings
|
||||
import webbrowser
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
import anyio
|
||||
from anyio import CapacityLimiter
|
||||
|
||||
from gradio import encryptor, external, networking, queueing, routes, strings, utils
|
||||
from gradio.context import Context
|
||||
@ -218,6 +218,7 @@ class Blocks(BlockContext):
|
||||
mode (str): a human-friendly name for the kind of Blocks interface being created.
|
||||
"""
|
||||
# Cleanup shared parameters with Interface #TODO: is this part still necessary after Interface with Blocks?
|
||||
self.limiter = None
|
||||
self.save_to = None
|
||||
self.api_mode = False
|
||||
self.theme = theme
|
||||
@ -429,7 +430,9 @@ class Blocks(BlockContext):
|
||||
if inspect.iscoroutinefunction(block_fn.fn):
|
||||
prediction = await block_fn.fn(*processed_input)
|
||||
else:
|
||||
prediction = await run_in_threadpool(block_fn.fn, *processed_input)
|
||||
prediction = await anyio.to_thread.run_sync(
|
||||
block_fn.fn, *processed_input, limiter=self.limiter
|
||||
)
|
||||
duration = time.time() - start
|
||||
return prediction, duration
|
||||
|
||||
@ -525,6 +528,11 @@ class Blocks(BlockContext):
|
||||
"average_duration": block_fn.total_runtime / block_fn.total_runs,
|
||||
}
|
||||
|
||||
async def create_limiter(self, max_threads: Optional[int]):
    """
    Sets `self.limiter` from the requested number of parallel threads.

    Parameters:
        max_threads: total number of tokens given to the `CapacityLimiter` that is stored on `self.limiter`. If None, `self.limiter` is set to None instead.
    """
    if max_threads is None:
        # No explicit limit requested: store None rather than a limiter object.
        self.limiter = None
        return
    self.limiter = CapacityLimiter(total_tokens=max_threads)
|
||||
|
||||
def get_config(self):
    """
    Returns the configuration dictionary for this object, which only declares its type as "column".
    """
    config = dict(type="column")
    return config
|
||||
|
||||
@ -645,6 +653,7 @@ class Blocks(BlockContext):
|
||||
share: bool = False,
|
||||
debug: bool = False,
|
||||
enable_queue: bool = None,
|
||||
max_threads: Optional[int] = None,
|
||||
auth: Optional[Callable | Tuple[str, str] | List[Tuple[str, str]]] = None,
|
||||
auth_message: Optional[str] = None,
|
||||
prevent_thread_lock: bool = False,
|
||||
@ -678,6 +687,7 @@ class Blocks(BlockContext):
|
||||
server_name (str | None): to make app accessible on local network, set this to "0.0.0.0". Can be set by environment variable GRADIO_SERVER_NAME. If None, will use "127.0.0.1".
|
||||
show_tips (bool): if True, will occasionally show tips about new Gradio features
|
||||
enable_queue (bool | None): if True, inference requests will be served through a queue instead of with parallel threads. Required for longer inference times (> 1min) to prevent timeout. The default option in HuggingFace Spaces is True. The default option elsewhere is False.
|
||||
max_threads (int | None): allow up to `max_threads` to be processed in parallel. The default is inherited from the starlette library (currently 40).
|
||||
width (int): The width in pixels of the iframe element containing the interface (used if inline=True)
|
||||
height (int): The height in pixels of the iframe element containing the interface (used if inline=True)
|
||||
encrypt (bool): If True, flagged data will be encrypted by key provided by creator at launch
|
||||
@ -710,7 +720,7 @@ class Blocks(BlockContext):
|
||||
self.enable_queue = True
|
||||
else:
|
||||
self.enable_queue = enable_queue or False
|
||||
|
||||
utils.synchronize_async(self.create_limiter, max_threads)
|
||||
self.config = self.get_config_file()
|
||||
self.share = share
|
||||
self.encrypt = encrypt
|
||||
|
@ -2,8 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import csv
|
||||
import asyncio
|
||||
import inspect
|
||||
import json
|
||||
import json.decoder
|
||||
@ -311,3 +310,9 @@ def component_or_layout_class(cls_name: str) -> Component | BlockContext:
|
||||
):
|
||||
return cls
|
||||
raise ValueError(f"No such component or layout: {cls_name}")
|
||||
|
||||
|
||||
def synchronize_async(func: Callable, *args: object, callback_func: Callable = None):
|
||||
event_loop = asyncio.get_event_loop()
|
||||
task = event_loop.create_task(func(*args))
|
||||
task.add_done_callback(callback_func)
|
||||
|
Loading…
Reference in New Issue
Block a user