Max parallel threads (#1460)

* Set max parallel threads to 100.

I could not find a better solution than monkey-patching
`anyio.to_thread.run_sync` to use a custom `CapacityLimiter`. I did try
obtaining the predictions inside a context manager instead:

```
async with self.limiter:
    predictions = await run_in_threadpool(block_fn.fn, *processed_input)
```

However, that approach didn't seem to work.
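For reference, a rough sketch of what such a monkey-patch could look like (illustrative only; the wrapper below and its lazy limiter creation are a reconstruction based on the description above, not the committed code):

```
import functools

import anyio
import anyio.to_thread

_original_run_sync = anyio.to_thread.run_sync
_limiter = None  # created lazily so it is built inside a running event loop


@functools.wraps(_original_run_sync)
async def _patched_run_sync(func, *args, **kwargs):
    # Route every to_thread call through a shared 100-token limiter.
    global _limiter
    if _limiter is None:
        _limiter = anyio.CapacityLimiter(total_tokens=100)
    kwargs.setdefault("limiter", _limiter)
    return await _original_run_sync(func, *args, **kwargs)


# Starlette's run_in_threadpool looks up anyio.to_thread.run_sync at call
# time, so swapping the attribute also lifts its default 40-thread cap.
anyio.to_thread.run_sync = _patched_run_sync
```

As the later changes below explain, this patch was eventually dropped in favor of creating the limiter at launch time and passing it to `run_sync` directly.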

* Make max parallel threads configurable.

`max_threads` can now be passed to `launch` to specify the maximum number
of parallel threads used to run predictions.
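For example (a minimal usage sketch; the demo app is just a placeholder):

```
import gradio as gr


def greet(name):
    return f"Hello, {name}!"


demo = gr.Interface(fn=greet, inputs="text", outputs="text")
# Allow up to 80 predictions to run in parallel worker threads.
demo.launch(max_threads=80)
```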

* Fix import order.

* Update gradio/blocks.py

* Replace `run_in_threadpool` with `run_sync`.

We create the capacity limiter on launch, and invoke
`anyio.to_thread.run_sync` directly instead of going through
`run_in_threadpool`. This allows for some code simplification: we no
longer need to patch `anyio.to_thread.run_sync`.
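Outside of Gradio, the resulting pattern looks roughly like this (a minimal, self-contained sketch with made-up function names, not the code from the diff below):

```
import time

import anyio
import anyio.to_thread
from anyio import CapacityLimiter


def slow_prediction(x):
    time.sleep(0.1)  # stand-in for a blocking model call
    return x * 2


async def predict(x, limiter):
    # Run the blocking function in a worker thread, bounded by the limiter.
    return await anyio.to_thread.run_sync(slow_prediction, x, limiter=limiter)


async def main():
    # At most two predictions run concurrently, however many are requested.
    limiter = CapacityLimiter(total_tokens=2)
    async with anyio.create_task_group() as tg:
        for i in range(5):
            tg.start_soon(predict, i, limiter)


anyio.run(main)
```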

* max_parallel_threads

- resolve conflicts

* max_parallel_threads

- fix an error

* max_parallel_threads

- fix an async error

Co-authored-by: Pedro Cuenca <pedro@latenitesoft.com>
Ömer Faruk Özdemir 2022-06-03 14:20:41 +03:00 committed by GitHub
parent 282748b4af
commit 267ba1e45b
2 changed files with 21 additions and 6 deletions

gradio/blocks.py

@@ -7,11 +7,11 @@ import os
import random
import sys
import time
import warnings
import webbrowser
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
from fastapi.concurrency import run_in_threadpool
import anyio
from anyio import CapacityLimiter
from gradio import encryptor, external, networking, queueing, routes, strings, utils
from gradio.context import Context
@@ -218,6 +218,7 @@ class Blocks(BlockContext):
mode (str): a human-friendly name for the kind of Blocks interface being created.
"""
# Cleanup shared parameters with Interface #TODO: is this part still necessary after Interface with Blocks?
self.limiter = None
self.save_to = None
self.api_mode = False
self.theme = theme
@@ -429,7 +430,9 @@ class Blocks(BlockContext):
if inspect.iscoroutinefunction(block_fn.fn):
prediction = await block_fn.fn(*processed_input)
else:
prediction = await run_in_threadpool(block_fn.fn, *processed_input)
prediction = await anyio.to_thread.run_sync(
block_fn.fn, *processed_input, limiter=self.limiter
)
duration = time.time() - start
return prediction, duration
@@ -525,6 +528,11 @@ class Blocks(BlockContext):
"average_duration": block_fn.total_runtime / block_fn.total_runs,
}
async def create_limiter(self, max_threads: Optional[int]):
self.limiter = (
None if max_threads is None else CapacityLimiter(total_tokens=max_threads)
)
def get_config(self):
return {"type": "column"}
@@ -645,6 +653,7 @@ class Blocks(BlockContext):
share: bool = False,
debug: bool = False,
enable_queue: bool = None,
max_threads: Optional[int] = None,
auth: Optional[Callable | Tuple[str, str] | List[Tuple[str, str]]] = None,
auth_message: Optional[str] = None,
prevent_thread_lock: bool = False,
@@ -678,6 +687,7 @@ class Blocks(BlockContext):
server_name (str | None): to make app accessible on local network, set this to "0.0.0.0". Can be set by environment variable GRADIO_SERVER_NAME. If None, will use "127.0.0.1".
show_tips (bool): if True, will occasionally show tips about new Gradio features
enable_queue (bool | None): if True, inference requests will be served through a queue instead of with parallel threads. Required for longer inference times (> 1min) to prevent timeout. The default option in HuggingFace Spaces is True. The default option elsewhere is False.
max_threads (int | None): allow up to `max_threads` to be processed in parallel. The default is inherited from the starlette library (currently 40).
width (int): The width in pixels of the iframe element containing the interface (used if inline=True)
height (int): The height in pixels of the iframe element containing the interface (used if inline=True)
encrypt (bool): If True, flagged data will be encrypted by key provided by creator at launch
@@ -710,7 +720,7 @@ class Blocks(BlockContext):
self.enable_queue = True
else:
self.enable_queue = enable_queue or False
utils.synchronize_async(self.create_limiter, max_threads)
self.config = self.get_config_file()
self.share = share
self.encrypt = encrypt

gradio/utils.py

@@ -2,8 +2,7 @@
from __future__ import annotations
import copy
import csv
import asyncio
import inspect
import json
import json.decoder
@@ -311,3 +310,9 @@ def component_or_layout_class(cls_name: str) -> Component | BlockContext:
):
return cls
raise ValueError(f"No such component or layout: {cls_name}")
def synchronize_async(func: Callable, *args: object, callback_func: Callable = None):
    event_loop = asyncio.get_event_loop()
    task = event_loop.create_task(func(*args))
    if callback_func is not None:
        task.add_done_callback(callback_func)
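As a standalone illustration of what this helper does (everything here apart from `synchronize_async` itself is invented for the example): it schedules the coroutine on the current event loop and returns immediately, so the caller never blocks, and an optional `callback_func` runs once the task finishes.

```
import asyncio


# Copy of the helper above so the snippet runs on its own.
def synchronize_async(func, *args, callback_func=None):
    event_loop = asyncio.get_event_loop()
    task = event_loop.create_task(func(*args))
    if callback_func is not None:
        task.add_done_callback(callback_func)


async def create_limiter(max_threads):
    print(f"limiter configured for {max_threads} threads")


async def main():
    # The caller does not await the setup; it runs in the background.
    synchronize_async(create_limiter, 40, callback_func=lambda task: print("done"))
    await asyncio.sleep(0.1)  # give the scheduled task and callback time to run


asyncio.run(main())
```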