diff --git a/gradio/blocks.py b/gradio/blocks.py
index 0b0a434e22..bb4b59a607 100644
--- a/gradio/blocks.py
+++ b/gradio/blocks.py
@@ -7,11 +7,11 @@ import os
 import random
 import sys
 import time
-import warnings
 import webbrowser
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
 
-from fastapi.concurrency import run_in_threadpool
+import anyio
+from anyio import CapacityLimiter
 
 from gradio import encryptor, external, networking, queueing, routes, strings, utils
 from gradio.context import Context
@@ -218,6 +218,7 @@ class Blocks(BlockContext):
             mode (str): a human-friendly name for the kind of Blocks interface being created.
         """
         # Cleanup shared parameters with Interface #TODO: is this part still necessary after Interface with Blocks?
+        self.limiter = None
         self.save_to = None
         self.api_mode = False
         self.theme = theme
@@ -429,7 +430,9 @@ class Blocks(BlockContext):
         if inspect.iscoroutinefunction(block_fn.fn):
             prediction = await block_fn.fn(*processed_input)
         else:
-            prediction = await run_in_threadpool(block_fn.fn, *processed_input)
+            prediction = await anyio.to_thread.run_sync(
+                block_fn.fn, *processed_input, limiter=self.limiter
+            )
         duration = time.time() - start
         return prediction, duration
 
@@ -525,6 +528,11 @@ class Blocks(BlockContext):
             "average_duration": block_fn.total_runtime / block_fn.total_runs,
         }
 
+    async def create_limiter(self, max_threads: Optional[int]):
+        self.limiter = (
+            None if max_threads is None else CapacityLimiter(total_tokens=max_threads)
+        )
+
     def get_config(self):
         return {"type": "column"}
 
@@ -645,6 +653,7 @@ class Blocks(BlockContext):
         share: bool = False,
         debug: bool = False,
         enable_queue: bool = None,
+        max_threads: Optional[int] = None,
         auth: Optional[Callable | Tuple[str, str] | List[Tuple[str, str]]] = None,
         auth_message: Optional[str] = None,
         prevent_thread_lock: bool = False,
@@ -678,6 +687,7 @@ class Blocks(BlockContext):
             server_name (str | None): to make app accessible on local network, set this to "0.0.0.0". Can be set by environment variable GRADIO_SERVER_NAME. If None, will use "127.0.0.1".
             show_tips (bool): if True, will occasionally show tips about new Gradio features
             enable_queue (bool | None): if True, inference requests will be served through a queue instead of with parallel threads. Required for longer inference times (> 1min) to prevent timeout. The default option in HuggingFace Spaces is True. The default option elsewhere is False.
+            max_threads (int | None): the maximum number of worker threads used to run non-async functions in parallel. If None, anyio's default thread limiter applies (currently 40 threads).
             width (int): The width in pixels of the iframe element containing the interface (used if inline=True)
             height (int): The height in pixels of the iframe element containing the interface (used if inline=True)
             encrypt (bool): If True, flagged data will be encrypted by key provided by creator at launch
@@ -710,7 +720,7 @@ class Blocks(BlockContext):
             self.enable_queue = True
         else:
             self.enable_queue = enable_queue or False
-
+        utils.synchronize_async(self.create_limiter, max_threads)
         self.config = self.get_config_file()
         self.share = share
         self.encrypt = encrypt
diff --git a/gradio/utils.py b/gradio/utils.py
index d12c5174b1..a695a7e05d 100644
--- a/gradio/utils.py
+++ b/gradio/utils.py
@@ -2,8 +2,7 @@
 
 from __future__ import annotations
 
-import copy
-import csv
+import asyncio
 import inspect
 import json
 import json.decoder
@@ -311,3 +310,28 @@ def component_or_layout_class(cls_name: str) -> Component | BlockContext:
     ):
         return cls
     raise ValueError(f"No such component or layout: {cls_name}")
+
+
+def synchronize_async(
+    func: Callable, *args: object, callback_func: Callable | None = None
+):
+    """
+    Runs the coroutine function `func` with `args` on the current event loop.
+
+    If the loop is not already running, this blocks until the coroutine has
+    finished, so its side effects are visible to the caller on return. If the
+    loop is already running, the coroutine is only scheduled, because the loop
+    cannot be driven re-entrantly from synchronous code.
+    Parameters:
+        func: coroutine function to run.
+        args: positional arguments passed to `func`.
+        callback_func: optional callable invoked with the finished task.
+    """
+    event_loop = asyncio.get_event_loop()
+    task = event_loop.create_task(func(*args))
+    # Registering None as a done-callback fails once the task completes.
+    if callback_func is not None:
+        task.add_done_callback(callback_func)
+    # A task created on a stopped loop never executes unless the loop is driven.
+    if not event_loop.is_running():
+        event_loop.run_until_complete(task)