mirror of
https://github.com/gradio-app/gradio.git
synced 2024-11-27 01:40:20 +08:00
Async functions and async generator functions with the every
option to work (#6395)
* Extend `get_continuous_fn()` to deal with async functions and async generator functions * add changeset * add changeset --------- Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
parent
03491ef497
commit
8ef48f8524
5
.changeset/cold-gifts-tickle.md
Normal file
5
.changeset/cold-gifts-tickle.md
Normal file
@ -0,0 +1,5 @@
|
||||
---
|
||||
"gradio": patch
|
||||
---
|
||||
|
||||
feat:Async functions and async generator functions with the `every` option to work
|
@ -21,7 +21,7 @@ from contextlib import contextmanager
|
||||
from io import BytesIO
|
||||
from numbers import Number
|
||||
from pathlib import Path
|
||||
from types import GeneratorType
|
||||
from types import AsyncGeneratorType, GeneratorType
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
@ -604,6 +604,11 @@ def get_continuous_fn(fn: Callable, every: float) -> Callable:
|
||||
if isinstance(output, GeneratorType):
|
||||
for item in output:
|
||||
yield item
|
||||
elif isinstance(output, AsyncGeneratorType):
|
||||
async for item in output:
|
||||
yield item
|
||||
elif inspect.isawaitable(output):
|
||||
yield await output
|
||||
else:
|
||||
yield output
|
||||
await asyncio.sleep(every)
|
||||
|
@ -348,6 +348,39 @@ class TestGetContinuousFn:
|
||||
assert [1, 1] == await agener_list.__anext__()
|
||||
assert [1, 1, 1] == await agener_list.__anext__()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_continuous_fn_with_async_function(self):
|
||||
async def async_int_return(x): # for origin condition
|
||||
return x + 1
|
||||
|
||||
agen_int_return = get_continuous_fn(fn=async_int_return, every=0.01)
|
||||
agener_int_return = agen_int_return(1)
|
||||
assert await agener_int_return.__anext__() == 2
|
||||
assert await agener_int_return.__anext__() == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_continuous_fn_with_async_generator(self):
|
||||
async def async_int_yield(x): # new condition
|
||||
for _i in range(2):
|
||||
yield x
|
||||
x += 1
|
||||
|
||||
async def async_list_yield(x): # new condition
|
||||
for _i in range(2):
|
||||
yield x
|
||||
x += [1]
|
||||
|
||||
agen_int_yield = get_continuous_fn(fn=async_int_yield, every=0.01)
|
||||
agen_list_yield = get_continuous_fn(fn=async_list_yield, every=0.01)
|
||||
agener_int = agen_int_yield(1) # Primitive
|
||||
agener_list = agen_list_yield([1]) # Reference
|
||||
assert await agener_int.__anext__() == 1
|
||||
assert await agener_int.__anext__() == 2
|
||||
assert await agener_int.__anext__() == 1
|
||||
assert [1] == await agener_list.__anext__()
|
||||
assert [1, 1] == await agener_list.__anext__()
|
||||
assert [1, 1, 1] == await agener_list.__anext__()
|
||||
|
||||
|
||||
def test_tex2svg_preserves_matplotlib_backend():
|
||||
import matplotlib
|
||||
|
Loading…
Reference in New Issue
Block a user