Async functions and async generator functions with the every option to work (#6395)

* Extend `get_continuous_fn()` to deal with async functions and async generator functions

* add changeset

* add changeset

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
Yuichiro Tachibana (Tsuchiya) 2023-11-14 06:12:49 +09:00 committed by GitHub
parent 03491ef497
commit 8ef48f8524
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 44 additions and 1 deletions

View File

@ -0,0 +1,5 @@
---
"gradio": patch
---
feat:Async functions and async generator functions with the `every` option to work

View File

@ -21,7 +21,7 @@ from contextlib import contextmanager
from io import BytesIO
from numbers import Number
from pathlib import Path
from types import GeneratorType
from types import AsyncGeneratorType, GeneratorType
from typing import (
TYPE_CHECKING,
Any,
@ -604,6 +604,11 @@ def get_continuous_fn(fn: Callable, every: float) -> Callable:
if isinstance(output, GeneratorType):
for item in output:
yield item
elif isinstance(output, AsyncGeneratorType):
async for item in output:
yield item
elif inspect.isawaitable(output):
yield await output
else:
yield output
await asyncio.sleep(every)

View File

@ -348,6 +348,39 @@ class TestGetContinuousFn:
assert [1, 1] == await agener_list.__anext__()
assert [1, 1, 1] == await agener_list.__anext__()
@pytest.mark.asyncio
async def test_get_continuous_fn_with_async_function(self):
async def async_int_return(x): # for origin condition
return x + 1
agen_int_return = get_continuous_fn(fn=async_int_return, every=0.01)
agener_int_return = agen_int_return(1)
assert await agener_int_return.__anext__() == 2
assert await agener_int_return.__anext__() == 2
@pytest.mark.asyncio
async def test_get_continuous_fn_with_async_generator(self):
async def async_int_yield(x): # new condition
for _i in range(2):
yield x
x += 1
async def async_list_yield(x): # new condition
for _i in range(2):
yield x
x += [1]
agen_int_yield = get_continuous_fn(fn=async_int_yield, every=0.01)
agen_list_yield = get_continuous_fn(fn=async_list_yield, every=0.01)
agener_int = agen_int_yield(1) # Primitive
agener_list = agen_list_yield([1]) # Reference
assert await agener_int.__anext__() == 1
assert await agener_int.__anext__() == 2
assert await agener_int.__anext__() == 1
assert [1] == await agener_list.__anext__()
assert [1, 1] == await agener_list.__anext__()
assert [1, 1, 1] == await agener_list.__anext__()
def test_tex2svg_preserves_matplotlib_backend():
import matplotlib