diff --git a/CHANGELOG.md b/CHANGELOG.md index 117fc7ab5f..ca0ee031c7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,7 @@ ## New Features: +- Support for asynchronous iterators by [@freddyaboulton](https://github.com/freddyaboulton) in [PR 3821](https://github.com/gradio-app/gradio/pull/3821) - Returning language agnostic types in the `/info` route by [@freddyaboulton](https://github.com/freddyaboulton) in [PR 4039](https://github.com/gradio-app/gradio/pull/4039) ## Bug Fixes: diff --git a/gradio/blocks.py b/gradio/blocks.py index 0e2e8b89d3..e6c28a94b4 100644 --- a/gradio/blocks.py +++ b/gradio/blocks.py @@ -12,7 +12,7 @@ import warnings import webbrowser from abc import abstractmethod from types import ModuleType -from typing import TYPE_CHECKING, Any, Callable, Iterator +from typing import TYPE_CHECKING, Any, AsyncIterator, Callable import anyio import requests @@ -996,7 +996,7 @@ class Blocks(BlockContext): self, fn_index: int, processed_input: list[Any], - iterator: Iterator[Any] | None = None, + iterator: AsyncIterator[Any] | None = None, requests: routes.Request | list[routes.Request] | None = None, event_id: str | None = None, event_data: EventData | None = None, @@ -1046,17 +1046,17 @@ class Blocks(BlockContext): else: prediction = None - if inspect.isasyncgenfunction(block_fn.fn): - raise ValueError("Gradio does not support async generators.") - if inspect.isgeneratorfunction(block_fn.fn): + if inspect.isgeneratorfunction(block_fn.fn) or inspect.isasyncgenfunction( + block_fn.fn + ): if not self.enable_queue: raise ValueError("Need to enable queue to use generators.") try: if iterator is None: iterator = prediction - prediction = await anyio.to_thread.run_sync( - utils.async_iteration, iterator, limiter=self.limiter - ) + if inspect.isgenerator(iterator): + iterator = utils.SyncToAsyncIterator(iterator, self.limiter) + prediction = await utils.async_iteration(iterator) is_generating = True except StopAsyncIteration: n_outputs = len(self.dependencies[fn_index].get("outputs")) @@ -1320,7 +1320,6 @@ Received outputs: block_fn.total_runtime += result["duration"] block_fn.total_runs += 1 - return { "data": data, "is_generating": is_generating, diff --git a/gradio/interface.py b/gradio/interface.py index 5d82e9bd4c..c566547dcc 100644 --- a/gradio/interface.py +++ b/gradio/interface.py @@ -500,12 +500,17 @@ class Interface(Blocks): # is created. We use whether a generator function is provided # as a proxy of whether the queue will be enabled. # Using a generator function without the queue will raise an error. - if inspect.isgeneratorfunction(self.fn): + if inspect.isgeneratorfunction( + self.fn + ) or inspect.isasyncgenfunction(self.fn): stop_btn = Button("Stop", variant="stop", visible=False) elif self.interface_type == InterfaceTypes.UNIFIED: clear_btn = Button("Clear") submit_btn = Button("Submit", variant="primary") - if inspect.isgeneratorfunction(self.fn) and not self.live: + if ( + inspect.isgeneratorfunction(self.fn) + or inspect.isasyncgenfunction(self.fn) + ) and not self.live: stop_btn = Button("Stop", variant="stop") if self.allow_flagging == "manual": flag_btns = self.render_flag_btns() @@ -536,7 +541,10 @@ class Interface(Blocks): if self.interface_type == InterfaceTypes.OUTPUT_ONLY: clear_btn = Button("Clear") submit_btn = Button("Generate", variant="primary") - if inspect.isgeneratorfunction(self.fn) and not self.live: + if ( + inspect.isgeneratorfunction(self.fn) + or inspect.isasyncgenfunction(self.fn) + ) and not self.live: # Stopping jobs only works if the queue is enabled # We don't know if the queue is enabled when the interface # is created. We use whether a generator function is provided @@ -599,54 +607,39 @@ class Interface(Blocks): assert submit_btn is not None, "Submit button not rendered" fn = self.fn extra_output = [] - if stop_btn: - # Wrap the original function to show/hide the "Stop" button - def fn(*args): - # The main idea here is to call the original function - # and append some updates to keep the "Submit" button - # hidden and the "Stop" button visible - # The 'finally' block hides the "Stop" button and - # shows the "submit" button. Having a 'finally' block - # will make sure the UI is "reset" even if there is an exception - try: - for output in self.fn(*args): - if len(self.output_components) == 1 and not self.batch: - output = [output] - output = list(output) - yield output + [ - Button.update(visible=False), - Button.update(visible=True), - ] - finally: - yield [ - {"__type__": "generic_update"} - for _ in self.output_components - ] + [Button.update(visible=True), Button.update(visible=False)] - - extra_output = [submit_btn, stop_btn] triggers = [submit_btn.click] + [ component.submit for component in self.input_components if isinstance(component, Submittable) ] predict_events = [] - for i, trigger in enumerate(triggers): - predict_events.append( - trigger( - fn, - self.input_components, - self.output_components + extra_output, - api_name="predict" if i == 0 else None, - scroll_to_output=True, - preprocess=not (self.api_mode), - postprocess=not (self.api_mode), - batch=self.batch, - max_batch_size=self.max_batch_size, - ) - ) - if stop_btn: - trigger( + + if stop_btn: + + # Wrap the original function to show/hide the "Stop" button + async def fn(*args): + # The main idea here is to call the original function + # and append some updates to keep the "Submit" button + # hidden and the "Stop" button visible + + if inspect.isasyncgenfunction(self.fn): + iterator = self.fn(*args) + else: + iterator = utils.SyncToAsyncIterator( + self.fn(*args), limiter=self.limiter + ) + async for output in iterator: + yield output + + extra_output = [submit_btn, stop_btn] + + cleanup = lambda: [ + Button.update(visible=True), + Button.update(visible=False), + ] + for i, trigger in enumerate(triggers): + predict_event = trigger( lambda: ( submit_btn.update(visible=False), stop_btn.update(visible=True), @@ -654,28 +647,48 @@ class Interface(Blocks): inputs=None, outputs=[submit_btn, stop_btn], queue=False, + ).then( + fn, + self.input_components, + self.output_components, + api_name="predict" if i == 0 else None, + scroll_to_output=True, + preprocess=not (self.api_mode), + postprocess=not (self.api_mode), + batch=self.batch, + max_batch_size=self.max_batch_size, + ) + predict_events.append(predict_event) + + predict_event.then( + cleanup, + inputs=None, + outputs=extra_output, # type: ignore + queue=False, ) - if stop_btn: - submit_btn.click( - lambda: ( - submit_btn.update(visible=False), - stop_btn.update(visible=True), - ), - inputs=None, - outputs=[submit_btn, stop_btn], - queue=False, - ) stop_btn.click( - lambda: ( - submit_btn.update(visible=True), - stop_btn.update(visible=False), - ), + cleanup, inputs=None, outputs=[submit_btn, stop_btn], cancels=predict_events, queue=False, ) + else: + for i, trigger in enumerate(triggers): + predict_events.append( + trigger( + fn, + self.input_components, + self.output_components, + api_name="predict" if i == 0 else None, + scroll_to_output=True, + preprocess=not (self.api_mode), + postprocess=not (self.api_mode), + batch=self.batch, + max_batch_size=self.max_batch_size, + ) + ) def attach_clear_events( self, diff --git a/gradio/utils.py b/gradio/utils.py index cbaf049257..6688b0d658 100644 --- a/gradio/utils.py +++ b/gradio/utils.py @@ -33,6 +33,7 @@ from typing import ( ) import aiohttp +import anyio import httpx import matplotlib import requests @@ -483,7 +484,8 @@ def run_coro_in_background(func: Callable, *args, **kwargs): return event_loop.create_task(func(*args, **kwargs)) -def async_iteration(iterator): +def run_sync_iterator_async(iterator): + """Helper for yielding StopAsyncIteration from sync iterators.""" try: return next(iterator) except StopIteration: @@ -491,6 +493,27 @@ def async_iteration(iterator): raise StopAsyncIteration() from None +class SyncToAsyncIterator: + """Treat a synchronous iterator as async one.""" + + def __init__(self, iterator, limiter) -> None: + self.iterator = iterator + self.limiter = limiter + + def __aiter__(self): + return self + + async def __anext__(self): + return await anyio.to_thread.run_sync( + run_sync_iterator_async, self.iterator, limiter=self.limiter + ) + + +async def async_iteration(iterator): + # anext not introduced until 3.10 :( + return await iterator.__anext__() + + class AsyncRequest: """ The AsyncRequest class is a low-level API that allow you to create asynchronous HTTP requests without a context manager. diff --git a/test/test_blocks.py b/test/test_blocks.py index baf76037a5..2d32408767 100644 --- a/test/test_blocks.py +++ b/test/test_blocks.py @@ -264,6 +264,98 @@ class TestBlocksMethods: completed = True assert msg["output"]["data"][0] == "Victor" + @pytest.mark.asyncio + async def test_async_generators(self): + async def async_iteration(count: int): + for i in range(count): + yield i + await asyncio.sleep(0.2) + + def iteration(count: int): + for i in range(count): + yield i + time.sleep(0.2) + + with gr.Blocks() as demo: + with gr.Row(): + with gr.Column(): + num1 = gr.Number(value=4, precision=0) + o1 = gr.Number() + async_iterate = gr.Button(value="Async Iteration") + async_iterate.click(async_iteration, num1, o1) + with gr.Column(): + num2 = gr.Number(value=4, precision=0) + o2 = gr.Number() + iterate = gr.Button(value="Iterate") + iterate.click(iteration, num2, o2) + + demo.queue(concurrency_count=2).launch(prevent_thread_lock=True) + + def _get_ws_pred(data, fn_index): + async def wrapped(): + async with websockets.connect( + f"{demo.local_url.replace('http', 'ws')}queue/join" + ) as ws: + completed = False + while not completed: + msg = json.loads(await ws.recv()) + if msg["msg"] == "send_data": + await ws.send( + json.dumps({"data": [data], "fn_index": fn_index}) + ) + if msg["msg"] == "send_hash": + await ws.send( + json.dumps( + {"fn_index": fn_index, "session_hash": "shdce"} + ) + ) + if msg["msg"] == "process_completed": + completed = True + assert msg["output"]["data"][0] == data - 1 + + return wrapped + + try: + await asyncio.gather(_get_ws_pred(3, 0)(), _get_ws_pred(4, 1)()) + finally: + demo.close() + + @pytest.mark.asyncio + async def test_sync_generators(self): + def generator(string): + yield from string + + demo = gr.Interface(generator, "text", "text") + demo.queue().launch(prevent_thread_lock=True) + + async def _get_ws_pred(data, fn_index): + outputs = [] + async with websockets.connect( + f"{demo.local_url.replace('http', 'ws')}queue/join" + ) as ws: + completed = False + while not completed: + msg = json.loads(await ws.recv()) + if msg["msg"] == "send_data": + await ws.send( + json.dumps({"data": [data], "fn_index": fn_index}) + ) + if msg["msg"] == "send_hash": + await ws.send( + json.dumps({"fn_index": fn_index, "session_hash": "shdce"}) + ) + if msg["msg"] in ["process_generating"]: + outputs.append(msg["output"]["data"]) + if msg["msg"] == "process_completed": + completed = True + return outputs + + try: + output = await _get_ws_pred(fn_index=1, data="abc") + assert [o[0] for o in output] == ["a", "b", "c"] + finally: + demo.close() + def test_socket_reuse(self): try: @@ -1158,29 +1250,15 @@ class TestCancel: f"{io.local_url.replace('http', 'ws')}queue/join" ) as ws: completed = False - checked_iteration = False while not completed: msg = json.loads(await ws.recv()) if msg["msg"] == "send_data": - await ws.send(json.dumps({"data": ["freddy"], "fn_index": 0})) + await ws.send(json.dumps({"data": ["freddy"], "fn_index": 1})) if msg["msg"] == "send_hash": await ws.send(json.dumps({"fn_index": 0, "session_hash": "shdce"})) - if msg["msg"] == "process_generating" and isinstance( - msg["output"]["data"][0], str - ): - checked_iteration = True - assert msg["output"]["data"][1:] == [ - {"visible": False, "__type__": "update"}, - {"visible": True, "__type__": "update"}, - ] if msg["msg"] == "process_completed": - assert msg["output"]["data"] == [ - {"__type__": "update"}, - {"visible": True, "__type__": "update"}, - {"visible": False, "__type__": "update"}, - ] + assert msg["output"]["data"] == ["3"] completed = True - assert checked_iteration io.close() diff --git a/test/test_routes.py b/test/test_routes.py index bd1a4e4fea..918b459d7d 100644 --- a/test/test_routes.py +++ b/test/test_routes.py @@ -394,68 +394,6 @@ class TestRoutes: demo.close() -class TestGeneratorRoutes: - def test_generator(self): - def generator(string): - yield from string - - io = Interface(generator, "text", "text") - app, _, _ = io.queue().launch(prevent_thread_lock=True) - client = TestClient(app) - - response = client.post( - "/api/predict/", - json={"data": ["abc"], "fn_index": 0, "session_hash": "11"}, - headers={"Authorization": f"Bearer {app.queue_token}"}, - ) - output = dict(response.json()) - assert output["data"][0] == "a" - - response = client.post( - "/api/predict/", - json={"data": ["abc"], "fn_index": 0, "session_hash": "11"}, - headers={"Authorization": f"Bearer {app.queue_token}"}, - ) - output = dict(response.json()) - assert output["data"][0] == "b" - - response = client.post( - "/api/predict/", - json={"data": ["abc"], "fn_index": 0, "session_hash": "11"}, - headers={"Authorization": f"Bearer {app.queue_token}"}, - ) - output = dict(response.json()) - assert output["data"][0] == "c" - - response = client.post( - "/api/predict/", - json={"data": ["abc"], "fn_index": 0, "session_hash": "11"}, - headers={"Authorization": f"Bearer {app.queue_token}"}, - ) - output = dict(response.json()) - assert output["data"] == [ - {"__type__": "update"}, - {"__type__": "update", "visible": True}, - {"__type__": "update", "visible": False}, - ] - - response = client.post( - "/api/predict/", - json={"data": ["abc"], "fn_index": 0, "session_hash": "11"}, - headers={"Authorization": f"Bearer {app.queue_token}"}, - ) - output = dict(response.json()) - assert output["data"][0] is None - - response = client.post( - "/api/predict/", - json={"data": ["abc"], "fn_index": 0, "session_hash": "11"}, - headers={"Authorization": f"Bearer {app.queue_token}"}, - ) - output = dict(response.json()) - assert output["data"][0] == "a" - - class TestApp: def test_create_app(self): app = routes.App.create_app(Interface(lambda x: x, "text", "text"))