mirror of
https://github.com/gradio-app/gradio.git
synced 2025-04-06 12:30:29 +08:00
Add support for async generators (#3821)
* Add impl + unit test * CHANGELOG * Lint * Type check * Remove print * Fix tests * revert change * Lint * formatting * Fix test * Lint --------- Co-authored-by: Abubakar Abid <abubakar@huggingface.co> Co-authored-by: Ali Abid <aabid94@gmail.com>
This commit is contained in:
parent
f1703d5f53
commit
96c17a7470
@ -2,6 +2,7 @@
|
||||
|
||||
## New Features:
|
||||
|
||||
- Support for asynchronous iterators by [@freddyaboulton](https://github.com/freddyaboulton) in [PR 3821](https://github.com/gradio-app/gradio/pull/3821)
|
||||
- Returning language agnostic types in the `/info` route by [@freddyaboulton](https://github.com/freddyaboulton) in [PR 4039](https://github.com/gradio-app/gradio/pull/4039)
|
||||
|
||||
## Bug Fixes:
|
||||
|
@ -12,7 +12,7 @@ import warnings
|
||||
import webbrowser
|
||||
from abc import abstractmethod
|
||||
from types import ModuleType
|
||||
from typing import TYPE_CHECKING, Any, Callable, Iterator
|
||||
from typing import TYPE_CHECKING, Any, AsyncIterator, Callable
|
||||
|
||||
import anyio
|
||||
import requests
|
||||
@ -996,7 +996,7 @@ class Blocks(BlockContext):
|
||||
self,
|
||||
fn_index: int,
|
||||
processed_input: list[Any],
|
||||
iterator: Iterator[Any] | None = None,
|
||||
iterator: AsyncIterator[Any] | None = None,
|
||||
requests: routes.Request | list[routes.Request] | None = None,
|
||||
event_id: str | None = None,
|
||||
event_data: EventData | None = None,
|
||||
@ -1046,17 +1046,17 @@ class Blocks(BlockContext):
|
||||
else:
|
||||
prediction = None
|
||||
|
||||
if inspect.isasyncgenfunction(block_fn.fn):
|
||||
raise ValueError("Gradio does not support async generators.")
|
||||
if inspect.isgeneratorfunction(block_fn.fn):
|
||||
if inspect.isgeneratorfunction(block_fn.fn) or inspect.isasyncgenfunction(
|
||||
block_fn.fn
|
||||
):
|
||||
if not self.enable_queue:
|
||||
raise ValueError("Need to enable queue to use generators.")
|
||||
try:
|
||||
if iterator is None:
|
||||
iterator = prediction
|
||||
prediction = await anyio.to_thread.run_sync(
|
||||
utils.async_iteration, iterator, limiter=self.limiter
|
||||
)
|
||||
if inspect.isgenerator(iterator):
|
||||
iterator = utils.SyncToAsyncIterator(iterator, self.limiter)
|
||||
prediction = await utils.async_iteration(iterator)
|
||||
is_generating = True
|
||||
except StopAsyncIteration:
|
||||
n_outputs = len(self.dependencies[fn_index].get("outputs"))
|
||||
@ -1320,7 +1320,6 @@ Received outputs:
|
||||
|
||||
block_fn.total_runtime += result["duration"]
|
||||
block_fn.total_runs += 1
|
||||
|
||||
return {
|
||||
"data": data,
|
||||
"is_generating": is_generating,
|
||||
|
@ -500,12 +500,17 @@ class Interface(Blocks):
|
||||
# is created. We use whether a generator function is provided
|
||||
# as a proxy of whether the queue will be enabled.
|
||||
# Using a generator function without the queue will raise an error.
|
||||
if inspect.isgeneratorfunction(self.fn):
|
||||
if inspect.isgeneratorfunction(
|
||||
self.fn
|
||||
) or inspect.isasyncgenfunction(self.fn):
|
||||
stop_btn = Button("Stop", variant="stop", visible=False)
|
||||
elif self.interface_type == InterfaceTypes.UNIFIED:
|
||||
clear_btn = Button("Clear")
|
||||
submit_btn = Button("Submit", variant="primary")
|
||||
if inspect.isgeneratorfunction(self.fn) and not self.live:
|
||||
if (
|
||||
inspect.isgeneratorfunction(self.fn)
|
||||
or inspect.isasyncgenfunction(self.fn)
|
||||
) and not self.live:
|
||||
stop_btn = Button("Stop", variant="stop")
|
||||
if self.allow_flagging == "manual":
|
||||
flag_btns = self.render_flag_btns()
|
||||
@ -536,7 +541,10 @@ class Interface(Blocks):
|
||||
if self.interface_type == InterfaceTypes.OUTPUT_ONLY:
|
||||
clear_btn = Button("Clear")
|
||||
submit_btn = Button("Generate", variant="primary")
|
||||
if inspect.isgeneratorfunction(self.fn) and not self.live:
|
||||
if (
|
||||
inspect.isgeneratorfunction(self.fn)
|
||||
or inspect.isasyncgenfunction(self.fn)
|
||||
) and not self.live:
|
||||
# Stopping jobs only works if the queue is enabled
|
||||
# We don't know if the queue is enabled when the interface
|
||||
# is created. We use whether a generator function is provided
|
||||
@ -599,54 +607,39 @@ class Interface(Blocks):
|
||||
assert submit_btn is not None, "Submit button not rendered"
|
||||
fn = self.fn
|
||||
extra_output = []
|
||||
if stop_btn:
|
||||
|
||||
# Wrap the original function to show/hide the "Stop" button
|
||||
def fn(*args):
|
||||
# The main idea here is to call the original function
|
||||
# and append some updates to keep the "Submit" button
|
||||
# hidden and the "Stop" button visible
|
||||
# The 'finally' block hides the "Stop" button and
|
||||
# shows the "submit" button. Having a 'finally' block
|
||||
# will make sure the UI is "reset" even if there is an exception
|
||||
try:
|
||||
for output in self.fn(*args):
|
||||
if len(self.output_components) == 1 and not self.batch:
|
||||
output = [output]
|
||||
output = list(output)
|
||||
yield output + [
|
||||
Button.update(visible=False),
|
||||
Button.update(visible=True),
|
||||
]
|
||||
finally:
|
||||
yield [
|
||||
{"__type__": "generic_update"}
|
||||
for _ in self.output_components
|
||||
] + [Button.update(visible=True), Button.update(visible=False)]
|
||||
|
||||
extra_output = [submit_btn, stop_btn]
|
||||
triggers = [submit_btn.click] + [
|
||||
component.submit
|
||||
for component in self.input_components
|
||||
if isinstance(component, Submittable)
|
||||
]
|
||||
predict_events = []
|
||||
for i, trigger in enumerate(triggers):
|
||||
predict_events.append(
|
||||
trigger(
|
||||
fn,
|
||||
self.input_components,
|
||||
self.output_components + extra_output,
|
||||
api_name="predict" if i == 0 else None,
|
||||
scroll_to_output=True,
|
||||
preprocess=not (self.api_mode),
|
||||
postprocess=not (self.api_mode),
|
||||
batch=self.batch,
|
||||
max_batch_size=self.max_batch_size,
|
||||
)
|
||||
)
|
||||
if stop_btn:
|
||||
trigger(
|
||||
|
||||
if stop_btn:
|
||||
|
||||
# Wrap the original function to show/hide the "Stop" button
|
||||
async def fn(*args):
|
||||
# The main idea here is to call the original function
|
||||
# and append some updates to keep the "Submit" button
|
||||
# hidden and the "Stop" button visible
|
||||
|
||||
if inspect.isasyncgenfunction(self.fn):
|
||||
iterator = self.fn(*args)
|
||||
else:
|
||||
iterator = utils.SyncToAsyncIterator(
|
||||
self.fn(*args), limiter=self.limiter
|
||||
)
|
||||
async for output in iterator:
|
||||
yield output
|
||||
|
||||
extra_output = [submit_btn, stop_btn]
|
||||
|
||||
cleanup = lambda: [
|
||||
Button.update(visible=True),
|
||||
Button.update(visible=False),
|
||||
]
|
||||
for i, trigger in enumerate(triggers):
|
||||
predict_event = trigger(
|
||||
lambda: (
|
||||
submit_btn.update(visible=False),
|
||||
stop_btn.update(visible=True),
|
||||
@ -654,28 +647,48 @@ class Interface(Blocks):
|
||||
inputs=None,
|
||||
outputs=[submit_btn, stop_btn],
|
||||
queue=False,
|
||||
).then(
|
||||
fn,
|
||||
self.input_components,
|
||||
self.output_components,
|
||||
api_name="predict" if i == 0 else None,
|
||||
scroll_to_output=True,
|
||||
preprocess=not (self.api_mode),
|
||||
postprocess=not (self.api_mode),
|
||||
batch=self.batch,
|
||||
max_batch_size=self.max_batch_size,
|
||||
)
|
||||
predict_events.append(predict_event)
|
||||
|
||||
predict_event.then(
|
||||
cleanup,
|
||||
inputs=None,
|
||||
outputs=extra_output, # type: ignore
|
||||
queue=False,
|
||||
)
|
||||
|
||||
if stop_btn:
|
||||
submit_btn.click(
|
||||
lambda: (
|
||||
submit_btn.update(visible=False),
|
||||
stop_btn.update(visible=True),
|
||||
),
|
||||
inputs=None,
|
||||
outputs=[submit_btn, stop_btn],
|
||||
queue=False,
|
||||
)
|
||||
stop_btn.click(
|
||||
lambda: (
|
||||
submit_btn.update(visible=True),
|
||||
stop_btn.update(visible=False),
|
||||
),
|
||||
cleanup,
|
||||
inputs=None,
|
||||
outputs=[submit_btn, stop_btn],
|
||||
cancels=predict_events,
|
||||
queue=False,
|
||||
)
|
||||
else:
|
||||
for i, trigger in enumerate(triggers):
|
||||
predict_events.append(
|
||||
trigger(
|
||||
fn,
|
||||
self.input_components,
|
||||
self.output_components,
|
||||
api_name="predict" if i == 0 else None,
|
||||
scroll_to_output=True,
|
||||
preprocess=not (self.api_mode),
|
||||
postprocess=not (self.api_mode),
|
||||
batch=self.batch,
|
||||
max_batch_size=self.max_batch_size,
|
||||
)
|
||||
)
|
||||
|
||||
def attach_clear_events(
|
||||
self,
|
||||
|
@ -33,6 +33,7 @@ from typing import (
|
||||
)
|
||||
|
||||
import aiohttp
|
||||
import anyio
|
||||
import httpx
|
||||
import matplotlib
|
||||
import requests
|
||||
@ -483,7 +484,8 @@ def run_coro_in_background(func: Callable, *args, **kwargs):
|
||||
return event_loop.create_task(func(*args, **kwargs))
|
||||
|
||||
|
||||
def async_iteration(iterator):
|
||||
def run_sync_iterator_async(iterator):
|
||||
"""Helper for yielding StopAsyncIteration from sync iterators."""
|
||||
try:
|
||||
return next(iterator)
|
||||
except StopIteration:
|
||||
@ -491,6 +493,27 @@ def async_iteration(iterator):
|
||||
raise StopAsyncIteration() from None
|
||||
|
||||
|
||||
class SyncToAsyncIterator:
|
||||
"""Treat a synchronous iterator as async one."""
|
||||
|
||||
def __init__(self, iterator, limiter) -> None:
|
||||
self.iterator = iterator
|
||||
self.limiter = limiter
|
||||
|
||||
def __aiter__(self):
|
||||
return self
|
||||
|
||||
async def __anext__(self):
|
||||
return await anyio.to_thread.run_sync(
|
||||
run_sync_iterator_async, self.iterator, limiter=self.limiter
|
||||
)
|
||||
|
||||
|
||||
async def async_iteration(iterator):
|
||||
# anext not introduced until 3.10 :(
|
||||
return await iterator.__anext__()
|
||||
|
||||
|
||||
class AsyncRequest:
|
||||
"""
|
||||
The AsyncRequest class is a low-level API that allow you to create asynchronous HTTP requests without a context manager.
|
||||
|
@ -264,6 +264,98 @@ class TestBlocksMethods:
|
||||
completed = True
|
||||
assert msg["output"]["data"][0] == "Victor"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_generators(self):
|
||||
async def async_iteration(count: int):
|
||||
for i in range(count):
|
||||
yield i
|
||||
await asyncio.sleep(0.2)
|
||||
|
||||
def iteration(count: int):
|
||||
for i in range(count):
|
||||
yield i
|
||||
time.sleep(0.2)
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
num1 = gr.Number(value=4, precision=0)
|
||||
o1 = gr.Number()
|
||||
async_iterate = gr.Button(value="Async Iteration")
|
||||
async_iterate.click(async_iteration, num1, o1)
|
||||
with gr.Column():
|
||||
num2 = gr.Number(value=4, precision=0)
|
||||
o2 = gr.Number()
|
||||
iterate = gr.Button(value="Iterate")
|
||||
iterate.click(iteration, num2, o2)
|
||||
|
||||
demo.queue(concurrency_count=2).launch(prevent_thread_lock=True)
|
||||
|
||||
def _get_ws_pred(data, fn_index):
|
||||
async def wrapped():
|
||||
async with websockets.connect(
|
||||
f"{demo.local_url.replace('http', 'ws')}queue/join"
|
||||
) as ws:
|
||||
completed = False
|
||||
while not completed:
|
||||
msg = json.loads(await ws.recv())
|
||||
if msg["msg"] == "send_data":
|
||||
await ws.send(
|
||||
json.dumps({"data": [data], "fn_index": fn_index})
|
||||
)
|
||||
if msg["msg"] == "send_hash":
|
||||
await ws.send(
|
||||
json.dumps(
|
||||
{"fn_index": fn_index, "session_hash": "shdce"}
|
||||
)
|
||||
)
|
||||
if msg["msg"] == "process_completed":
|
||||
completed = True
|
||||
assert msg["output"]["data"][0] == data - 1
|
||||
|
||||
return wrapped
|
||||
|
||||
try:
|
||||
await asyncio.gather(_get_ws_pred(3, 0)(), _get_ws_pred(4, 1)())
|
||||
finally:
|
||||
demo.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sync_generators(self):
|
||||
def generator(string):
|
||||
yield from string
|
||||
|
||||
demo = gr.Interface(generator, "text", "text")
|
||||
demo.queue().launch(prevent_thread_lock=True)
|
||||
|
||||
async def _get_ws_pred(data, fn_index):
|
||||
outputs = []
|
||||
async with websockets.connect(
|
||||
f"{demo.local_url.replace('http', 'ws')}queue/join"
|
||||
) as ws:
|
||||
completed = False
|
||||
while not completed:
|
||||
msg = json.loads(await ws.recv())
|
||||
if msg["msg"] == "send_data":
|
||||
await ws.send(
|
||||
json.dumps({"data": [data], "fn_index": fn_index})
|
||||
)
|
||||
if msg["msg"] == "send_hash":
|
||||
await ws.send(
|
||||
json.dumps({"fn_index": fn_index, "session_hash": "shdce"})
|
||||
)
|
||||
if msg["msg"] in ["process_generating"]:
|
||||
outputs.append(msg["output"]["data"])
|
||||
if msg["msg"] == "process_completed":
|
||||
completed = True
|
||||
return outputs
|
||||
|
||||
try:
|
||||
output = await _get_ws_pred(fn_index=1, data="abc")
|
||||
assert [o[0] for o in output] == ["a", "b", "c"]
|
||||
finally:
|
||||
demo.close()
|
||||
|
||||
def test_socket_reuse(self):
|
||||
|
||||
try:
|
||||
@ -1158,29 +1250,15 @@ class TestCancel:
|
||||
f"{io.local_url.replace('http', 'ws')}queue/join"
|
||||
) as ws:
|
||||
completed = False
|
||||
checked_iteration = False
|
||||
while not completed:
|
||||
msg = json.loads(await ws.recv())
|
||||
if msg["msg"] == "send_data":
|
||||
await ws.send(json.dumps({"data": ["freddy"], "fn_index": 0}))
|
||||
await ws.send(json.dumps({"data": ["freddy"], "fn_index": 1}))
|
||||
if msg["msg"] == "send_hash":
|
||||
await ws.send(json.dumps({"fn_index": 0, "session_hash": "shdce"}))
|
||||
if msg["msg"] == "process_generating" and isinstance(
|
||||
msg["output"]["data"][0], str
|
||||
):
|
||||
checked_iteration = True
|
||||
assert msg["output"]["data"][1:] == [
|
||||
{"visible": False, "__type__": "update"},
|
||||
{"visible": True, "__type__": "update"},
|
||||
]
|
||||
if msg["msg"] == "process_completed":
|
||||
assert msg["output"]["data"] == [
|
||||
{"__type__": "update"},
|
||||
{"visible": True, "__type__": "update"},
|
||||
{"visible": False, "__type__": "update"},
|
||||
]
|
||||
assert msg["output"]["data"] == ["3"]
|
||||
completed = True
|
||||
assert checked_iteration
|
||||
|
||||
io.close()
|
||||
|
||||
|
@ -394,68 +394,6 @@ class TestRoutes:
|
||||
demo.close()
|
||||
|
||||
|
||||
class TestGeneratorRoutes:
|
||||
def test_generator(self):
|
||||
def generator(string):
|
||||
yield from string
|
||||
|
||||
io = Interface(generator, "text", "text")
|
||||
app, _, _ = io.queue().launch(prevent_thread_lock=True)
|
||||
client = TestClient(app)
|
||||
|
||||
response = client.post(
|
||||
"/api/predict/",
|
||||
json={"data": ["abc"], "fn_index": 0, "session_hash": "11"},
|
||||
headers={"Authorization": f"Bearer {app.queue_token}"},
|
||||
)
|
||||
output = dict(response.json())
|
||||
assert output["data"][0] == "a"
|
||||
|
||||
response = client.post(
|
||||
"/api/predict/",
|
||||
json={"data": ["abc"], "fn_index": 0, "session_hash": "11"},
|
||||
headers={"Authorization": f"Bearer {app.queue_token}"},
|
||||
)
|
||||
output = dict(response.json())
|
||||
assert output["data"][0] == "b"
|
||||
|
||||
response = client.post(
|
||||
"/api/predict/",
|
||||
json={"data": ["abc"], "fn_index": 0, "session_hash": "11"},
|
||||
headers={"Authorization": f"Bearer {app.queue_token}"},
|
||||
)
|
||||
output = dict(response.json())
|
||||
assert output["data"][0] == "c"
|
||||
|
||||
response = client.post(
|
||||
"/api/predict/",
|
||||
json={"data": ["abc"], "fn_index": 0, "session_hash": "11"},
|
||||
headers={"Authorization": f"Bearer {app.queue_token}"},
|
||||
)
|
||||
output = dict(response.json())
|
||||
assert output["data"] == [
|
||||
{"__type__": "update"},
|
||||
{"__type__": "update", "visible": True},
|
||||
{"__type__": "update", "visible": False},
|
||||
]
|
||||
|
||||
response = client.post(
|
||||
"/api/predict/",
|
||||
json={"data": ["abc"], "fn_index": 0, "session_hash": "11"},
|
||||
headers={"Authorization": f"Bearer {app.queue_token}"},
|
||||
)
|
||||
output = dict(response.json())
|
||||
assert output["data"][0] is None
|
||||
|
||||
response = client.post(
|
||||
"/api/predict/",
|
||||
json={"data": ["abc"], "fn_index": 0, "session_hash": "11"},
|
||||
headers={"Authorization": f"Bearer {app.queue_token}"},
|
||||
)
|
||||
output = dict(response.json())
|
||||
assert output["data"][0] == "a"
|
||||
|
||||
|
||||
class TestApp:
|
||||
def test_create_app(self):
|
||||
app = routes.App.create_app(Interface(lambda x: x, "text", "text"))
|
||||
|
Loading…
x
Reference in New Issue
Block a user