Add support for async generators (#3821)

* Add impl + unit test

* CHANGELOG

* Lint

* Type check

* Remove print

* Fix tests

* revert change

* Lint

* formatting

* Fix test

* Lint

---------

Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
Co-authored-by: Ali Abid <aabid94@gmail.com>
Freddy Boulton 2023-05-09 00:21:47 -04:00 committed by GitHub
parent f1703d5f53
commit 96c17a7470
6 changed files with 198 additions and 146 deletions

CHANGELOG.md

@@ -2,6 +2,7 @@
## New Features:
- Support for asynchronous iterators by [@freddyaboulton](https://github.com/freddyaboulton) in [PR 3821](https://github.com/gradio-app/gradio/pull/3821)
- Returning language agnostic types in the `/info` route by [@freddyaboulton](https://github.com/freddyaboulton) in [PR 4039](https://github.com/gradio-app/gradio/pull/4039)
## Bug Fixes:

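For reference (not part of the CHANGELOG diff), a minimal sketch of what the asynchronous-iterator entry above enables, assuming gradio 3.x with the queue turned on; the count_up function and its sleep interval are illustrative, not from the PR:

import asyncio
import gradio as gr

async def count_up(n: int):
    # Async generator handler: each yield streams an intermediate value to the UI
    for i in range(int(n)):
        await asyncio.sleep(0.5)
        yield i

demo = gr.Interface(count_up, gr.Number(value=4), gr.Number())
# Generator handlers (sync or async) require the queue to be enabled
demo.queue().launch()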
gradio/blocks.py

@@ -12,7 +12,7 @@ import warnings
import webbrowser
from abc import abstractmethod
from types import ModuleType
from typing import TYPE_CHECKING, Any, Callable, Iterator
from typing import TYPE_CHECKING, Any, AsyncIterator, Callable
import anyio
import requests
@@ -996,7 +996,7 @@ class Blocks(BlockContext):
self,
fn_index: int,
processed_input: list[Any],
iterator: Iterator[Any] | None = None,
iterator: AsyncIterator[Any] | None = None,
requests: routes.Request | list[routes.Request] | None = None,
event_id: str | None = None,
event_data: EventData | None = None,
@@ -1046,17 +1046,17 @@
else:
prediction = None
if inspect.isasyncgenfunction(block_fn.fn):
raise ValueError("Gradio does not support async generators.")
if inspect.isgeneratorfunction(block_fn.fn):
if inspect.isgeneratorfunction(block_fn.fn) or inspect.isasyncgenfunction(
block_fn.fn
):
if not self.enable_queue:
raise ValueError("Need to enable queue to use generators.")
try:
if iterator is None:
iterator = prediction
prediction = await anyio.to_thread.run_sync(
utils.async_iteration, iterator, limiter=self.limiter
)
if inspect.isgenerator(iterator):
iterator = utils.SyncToAsyncIterator(iterator, self.limiter)
prediction = await utils.async_iteration(iterator)
is_generating = True
except StopAsyncIteration:
n_outputs = len(self.dependencies[fn_index].get("outputs"))
@@ -1320,7 +1320,6 @@ Received outputs:
block_fn.total_runtime += result["duration"]
block_fn.total_runs += 1
return {
"data": data,
"is_generating": is_generating,

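A rough, self-contained sketch of the dispatch above, assuming an anyio-backed wrapper equivalent to the one added in gradio/utils.py; names prefixed with an underscore are illustrative stand-ins, not the library's API:

import inspect
import anyio

def _sync_next(it):
    # next() that converts exhaustion into StopAsyncIteration
    try:
        return next(it)
    except StopIteration:
        raise StopAsyncIteration() from None

class _SyncToAsync:
    # Minimal stand-in for utils.SyncToAsyncIterator
    def __init__(self, it):
        self.it = it

    def __aiter__(self):
        return self

    async def __anext__(self):
        # Run the blocking next() in a worker thread so the event loop stays free
        return await anyio.to_thread.run_sync(_sync_next, self.it)

async def step(fn, iterator, *args):
    # One queue tick: build the iterator on the first call, then pull one value,
    # mirroring how call_function treats sync and async generator functions alike
    if iterator is None:
        iterator = fn(*args)
    if inspect.isgenerator(iterator):
        iterator = _SyncToAsync(iterator)
    try:
        return await iterator.__anext__(), iterator, True   # still generating
    except StopAsyncIteration:
        return None, None, False                             # finished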
gradio/interface.py

@@ -500,12 +500,17 @@ class Interface(Blocks):
# is created. We use whether a generator function is provided
# as a proxy of whether the queue will be enabled.
# Using a generator function without the queue will raise an error.
if inspect.isgeneratorfunction(self.fn):
if inspect.isgeneratorfunction(
self.fn
) or inspect.isasyncgenfunction(self.fn):
stop_btn = Button("Stop", variant="stop", visible=False)
elif self.interface_type == InterfaceTypes.UNIFIED:
clear_btn = Button("Clear")
submit_btn = Button("Submit", variant="primary")
if inspect.isgeneratorfunction(self.fn) and not self.live:
if (
inspect.isgeneratorfunction(self.fn)
or inspect.isasyncgenfunction(self.fn)
) and not self.live:
stop_btn = Button("Stop", variant="stop")
if self.allow_flagging == "manual":
flag_btns = self.render_flag_btns()
@@ -536,7 +541,10 @@
if self.interface_type == InterfaceTypes.OUTPUT_ONLY:
clear_btn = Button("Clear")
submit_btn = Button("Generate", variant="primary")
if inspect.isgeneratorfunction(self.fn) and not self.live:
if (
inspect.isgeneratorfunction(self.fn)
or inspect.isasyncgenfunction(self.fn)
) and not self.live:
# Stopping jobs only works if the queue is enabled
# We don't know if the queue is enabled when the interface
# is created. We use whether a generator function is provided
@@ -599,54 +607,39 @@
assert submit_btn is not None, "Submit button not rendered"
fn = self.fn
extra_output = []
if stop_btn:
# Wrap the original function to show/hide the "Stop" button
def fn(*args):
# The main idea here is to call the original function
# and append some updates to keep the "Submit" button
# hidden and the "Stop" button visible
# The 'finally' block hides the "Stop" button and
# shows the "submit" button. Having a 'finally' block
# will make sure the UI is "reset" even if there is an exception
try:
for output in self.fn(*args):
if len(self.output_components) == 1 and not self.batch:
output = [output]
output = list(output)
yield output + [
Button.update(visible=False),
Button.update(visible=True),
]
finally:
yield [
{"__type__": "generic_update"}
for _ in self.output_components
] + [Button.update(visible=True), Button.update(visible=False)]
extra_output = [submit_btn, stop_btn]
triggers = [submit_btn.click] + [
component.submit
for component in self.input_components
if isinstance(component, Submittable)
]
predict_events = []
for i, trigger in enumerate(triggers):
predict_events.append(
trigger(
fn,
self.input_components,
self.output_components + extra_output,
api_name="predict" if i == 0 else None,
scroll_to_output=True,
preprocess=not (self.api_mode),
postprocess=not (self.api_mode),
batch=self.batch,
max_batch_size=self.max_batch_size,
)
)
if stop_btn:
trigger(
if stop_btn:
# Wrap the original function to show/hide the "Stop" button
async def fn(*args):
# The main idea here is to call the original function
# and append some updates to keep the "Submit" button
# hidden and the "Stop" button visible
if inspect.isasyncgenfunction(self.fn):
iterator = self.fn(*args)
else:
iterator = utils.SyncToAsyncIterator(
self.fn(*args), limiter=self.limiter
)
async for output in iterator:
yield output
extra_output = [submit_btn, stop_btn]
cleanup = lambda: [
Button.update(visible=True),
Button.update(visible=False),
]
for i, trigger in enumerate(triggers):
predict_event = trigger(
lambda: (
submit_btn.update(visible=False),
stop_btn.update(visible=True),
@@ -654,28 +647,48 @@ class Interface(Blocks):
inputs=None,
outputs=[submit_btn, stop_btn],
queue=False,
).then(
fn,
self.input_components,
self.output_components,
api_name="predict" if i == 0 else None,
scroll_to_output=True,
preprocess=not (self.api_mode),
postprocess=not (self.api_mode),
batch=self.batch,
max_batch_size=self.max_batch_size,
)
predict_events.append(predict_event)
predict_event.then(
cleanup,
inputs=None,
outputs=extra_output, # type: ignore
queue=False,
)
if stop_btn:
submit_btn.click(
lambda: (
submit_btn.update(visible=False),
stop_btn.update(visible=True),
),
inputs=None,
outputs=[submit_btn, stop_btn],
queue=False,
)
stop_btn.click(
lambda: (
submit_btn.update(visible=True),
stop_btn.update(visible=False),
),
cleanup,
inputs=None,
outputs=[submit_btn, stop_btn],
cancels=predict_events,
queue=False,
)
else:
for i, trigger in enumerate(triggers):
predict_events.append(
trigger(
fn,
self.input_components,
self.output_components,
api_name="predict" if i == 0 else None,
scroll_to_output=True,
preprocess=not (self.api_mode),
postprocess=not (self.api_mode),
batch=self.batch,
max_batch_size=self.max_batch_size,
)
)
def attach_clear_events(
self,

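The show/hide-and-cancel wiring above can be reproduced in user code; a hedged sketch with gr.Blocks follows (slow_echo and the component layout are made up for illustration, while the update and cancels calls follow the same pattern as the diff):

import time
import gradio as gr

def slow_echo(text):
    # Sync generator handler; the queue streams each partial result
    out = ""
    for ch in text:
        time.sleep(0.3)
        out += ch
        yield out

with gr.Blocks() as demo:
    inp = gr.Textbox()
    out = gr.Textbox()
    submit = gr.Button("Submit", variant="primary")
    stop = gr.Button("Stop", variant="stop", visible=False)

    # Hide "Submit" and show "Stop" just before the generator starts...
    pred = submit.click(
        lambda: (gr.Button.update(visible=False), gr.Button.update(visible=True)),
        None,
        [submit, stop],
        queue=False,
    ).then(slow_echo, inp, out)

    # ...and restore both buttons once it finishes.
    pred.then(
        lambda: (gr.Button.update(visible=True), gr.Button.update(visible=False)),
        None,
        [submit, stop],
        queue=False,
    )

    # Clicking "Stop" cancels the in-flight generator event.
    stop.click(lambda: None, None, None, cancels=[pred], queue=False)

demo.queue().launch()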
gradio/utils.py

@@ -33,6 +33,7 @@ from typing import (
)
import aiohttp
import anyio
import httpx
import matplotlib
import requests
@@ -483,7 +484,8 @@ def run_coro_in_background(func: Callable, *args, **kwargs):
return event_loop.create_task(func(*args, **kwargs))
def async_iteration(iterator):
def run_sync_iterator_async(iterator):
"""Helper for yielding StopAsyncIteration from sync iterators."""
try:
return next(iterator)
except StopIteration:
@@ -491,6 +493,27 @@ def async_iteration(iterator):
raise StopAsyncIteration() from None
class SyncToAsyncIterator:
"""Treat a synchronous iterator as async one."""
def __init__(self, iterator, limiter) -> None:
self.iterator = iterator
self.limiter = limiter
def __aiter__(self):
return self
async def __anext__(self):
return await anyio.to_thread.run_sync(
run_sync_iterator_async, self.iterator, limiter=self.limiter
)
async def async_iteration(iterator):
# anext not introduced until 3.10 :(
return await iterator.__anext__()
class AsyncRequest:
"""
The AsyncRequest class is a low-level API that allows you to create asynchronous HTTP requests without a context manager.

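A small usage sketch for the new helper, assuming the gradio.utils module path once this change lands; passing limiter=None falls back to anyio's default thread limiter:

import asyncio
from gradio import utils

def countdown(n):
    # Plain sync generator
    while n > 0:
        yield n
        n -= 1

async def main():
    # Wrap the sync generator so it can be consumed with `async for`;
    # each next() call runs in a worker thread via anyio
    async for value in utils.SyncToAsyncIterator(countdown(3), limiter=None):
        print(value)

asyncio.run(main())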
test/test_blocks.py

@@ -264,6 +264,98 @@ class TestBlocksMethods:
completed = True
assert msg["output"]["data"][0] == "Victor"
@pytest.mark.asyncio
async def test_async_generators(self):
async def async_iteration(count: int):
for i in range(count):
yield i
await asyncio.sleep(0.2)
def iteration(count: int):
for i in range(count):
yield i
time.sleep(0.2)
with gr.Blocks() as demo:
with gr.Row():
with gr.Column():
num1 = gr.Number(value=4, precision=0)
o1 = gr.Number()
async_iterate = gr.Button(value="Async Iteration")
async_iterate.click(async_iteration, num1, o1)
with gr.Column():
num2 = gr.Number(value=4, precision=0)
o2 = gr.Number()
iterate = gr.Button(value="Iterate")
iterate.click(iteration, num2, o2)
demo.queue(concurrency_count=2).launch(prevent_thread_lock=True)
def _get_ws_pred(data, fn_index):
async def wrapped():
async with websockets.connect(
f"{demo.local_url.replace('http', 'ws')}queue/join"
) as ws:
completed = False
while not completed:
msg = json.loads(await ws.recv())
if msg["msg"] == "send_data":
await ws.send(
json.dumps({"data": [data], "fn_index": fn_index})
)
if msg["msg"] == "send_hash":
await ws.send(
json.dumps(
{"fn_index": fn_index, "session_hash": "shdce"}
)
)
if msg["msg"] == "process_completed":
completed = True
assert msg["output"]["data"][0] == data - 1
return wrapped
try:
await asyncio.gather(_get_ws_pred(3, 0)(), _get_ws_pred(4, 1)())
finally:
demo.close()
@pytest.mark.asyncio
async def test_sync_generators(self):
def generator(string):
yield from string
demo = gr.Interface(generator, "text", "text")
demo.queue().launch(prevent_thread_lock=True)
async def _get_ws_pred(data, fn_index):
outputs = []
async with websockets.connect(
f"{demo.local_url.replace('http', 'ws')}queue/join"
) as ws:
completed = False
while not completed:
msg = json.loads(await ws.recv())
if msg["msg"] == "send_data":
await ws.send(
json.dumps({"data": [data], "fn_index": fn_index})
)
if msg["msg"] == "send_hash":
await ws.send(
json.dumps({"fn_index": fn_index, "session_hash": "shdce"})
)
if msg["msg"] in ["process_generating"]:
outputs.append(msg["output"]["data"])
if msg["msg"] == "process_completed":
completed = True
return outputs
try:
output = await _get_ws_pred(fn_index=1, data="abc")
assert [o[0] for o in output] == ["a", "b", "c"]
finally:
demo.close()
def test_socket_reuse(self):
try:
@@ -1158,29 +1250,15 @@ class TestCancel:
f"{io.local_url.replace('http', 'ws')}queue/join"
) as ws:
completed = False
checked_iteration = False
while not completed:
msg = json.loads(await ws.recv())
if msg["msg"] == "send_data":
await ws.send(json.dumps({"data": ["freddy"], "fn_index": 0}))
await ws.send(json.dumps({"data": ["freddy"], "fn_index": 1}))
if msg["msg"] == "send_hash":
await ws.send(json.dumps({"fn_index": 0, "session_hash": "shdce"}))
if msg["msg"] == "process_generating" and isinstance(
msg["output"]["data"][0], str
):
checked_iteration = True
assert msg["output"]["data"][1:] == [
{"visible": False, "__type__": "update"},
{"visible": True, "__type__": "update"},
]
if msg["msg"] == "process_completed":
assert msg["output"]["data"] == [
{"__type__": "update"},
{"visible": True, "__type__": "update"},
{"visible": False, "__type__": "update"},
]
assert msg["output"]["data"] == ["3"]
completed = True
assert checked_iteration
io.close()

test/test_routes.py

@@ -394,68 +394,6 @@ class TestRoutes:
demo.close()
class TestGeneratorRoutes:
def test_generator(self):
def generator(string):
yield from string
io = Interface(generator, "text", "text")
app, _, _ = io.queue().launch(prevent_thread_lock=True)
client = TestClient(app)
response = client.post(
"/api/predict/",
json={"data": ["abc"], "fn_index": 0, "session_hash": "11"},
headers={"Authorization": f"Bearer {app.queue_token}"},
)
output = dict(response.json())
assert output["data"][0] == "a"
response = client.post(
"/api/predict/",
json={"data": ["abc"], "fn_index": 0, "session_hash": "11"},
headers={"Authorization": f"Bearer {app.queue_token}"},
)
output = dict(response.json())
assert output["data"][0] == "b"
response = client.post(
"/api/predict/",
json={"data": ["abc"], "fn_index": 0, "session_hash": "11"},
headers={"Authorization": f"Bearer {app.queue_token}"},
)
output = dict(response.json())
assert output["data"][0] == "c"
response = client.post(
"/api/predict/",
json={"data": ["abc"], "fn_index": 0, "session_hash": "11"},
headers={"Authorization": f"Bearer {app.queue_token}"},
)
output = dict(response.json())
assert output["data"] == [
{"__type__": "update"},
{"__type__": "update", "visible": True},
{"__type__": "update", "visible": False},
]
response = client.post(
"/api/predict/",
json={"data": ["abc"], "fn_index": 0, "session_hash": "11"},
headers={"Authorization": f"Bearer {app.queue_token}"},
)
output = dict(response.json())
assert output["data"][0] is None
response = client.post(
"/api/predict/",
json={"data": ["abc"], "fn_index": 0, "session_hash": "11"},
headers={"Authorization": f"Bearer {app.queue_token}"},
)
output = dict(response.json())
assert output["data"][0] == "a"
class TestApp:
def test_create_app(self):
app = routes.App.create_app(Interface(lambda x: x, "text", "text"))