mirror of https://github.com/gradio-app/gradio.git
synced 2024-11-21 01:01:05 +08:00

Add support for async functions and async generators to gr.ChatInterface (#5116)

* add support for async functions and iterators to ChatInterface
* fix api
* added support to examples
* add tests
* chat
* add changeset
* typing
* chat interface
* anyio
* anyio generator

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>

This commit is contained in:
parent eaa1ce14ac
commit 0dc49b4c51
.changeset/moody-buttons-tell.md (new file, +5)

@@ -0,0 +1,5 @@
+---
+"gradio": minor
+---
+
+feat:Add support for async functions and async generators to `gr.ChatInterface`
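With this commit, the `fn` passed to `gr.ChatInterface` may be a plain function, a sync generator, an async function, or an async generator. A minimal usage sketch of what the change enables — the handler name and the artificial delay are illustrative, not part of this commit; streaming responses require the queue to be enabled:

```python
import asyncio

import gradio as gr

async def slow_echo(message, history):
    # async generator: stream the reply back one character at a time
    for i in range(len(message)):
        await asyncio.sleep(0.05)
        yield message[: i + 1]

gr.ChatInterface(slow_echo).queue().launch()
```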
gradio/chat_interface.py

@@ -6,8 +6,9 @@ This file defines a useful high-level abstraction to build Gradio chatbots: Chat
 from __future__ import annotations
 
 import inspect
-from typing import Callable, Generator
+from typing import AsyncGenerator, Callable
 
+import anyio
 from gradio_client import utils as client_utils
 from gradio_client.documentation import document, set_documentation_group
 
@@ -25,6 +26,7 @@ from gradio.events import Dependency, EventListenerMethod
 from gradio.helpers import create_examples as Examples  # noqa: N812
 from gradio.layouts import Accordion, Column, Group, Row
 from gradio.themes import ThemeClass as Theme
+from gradio.utils import SyncToAsyncIterator, async_iteration
 
 set_documentation_group("chatinterface")
 
@@ -100,7 +102,12 @@ class ChatInterface(Blocks):
             theme=theme,
         )
         self.fn = fn
-        self.is_generator = inspect.isgeneratorfunction(self.fn)
+        self.is_async = inspect.iscoroutinefunction(
+            self.fn
+        ) or inspect.isasyncgenfunction(self.fn)
+        self.is_generator = inspect.isgeneratorfunction(
+            self.fn
+        ) or inspect.isasyncgenfunction(self.fn)
         self.examples = examples
         if self.space_id and cache_examples is None:
            self.cache_examples = True
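Note that an async generator function is neither a coroutine function nor a plain generator function as far as `inspect` is concerned, which is why `isasyncgenfunction` is OR-ed into both flags: an async generator must be treated as async *and* as a generator. A small stdlib-only illustration of how the four kinds of callables are classified:

```python
import inspect

def sync_fn(message, history): return message
def sync_gen(message, history): yield message
async def async_fn(message, history): return message
async def async_gen(message, history): yield message

for fn in (sync_fn, sync_gen, async_fn, async_gen):
    is_async = inspect.iscoroutinefunction(fn) or inspect.isasyncgenfunction(fn)
    is_generator = inspect.isgeneratorfunction(fn) or inspect.isasyncgenfunction(fn)
    print(f"{fn.__name__}: is_async={is_async}, is_generator={is_generator}")

# sync_fn:   is_async=False, is_generator=False
# sync_gen:  is_async=False, is_generator=True
# async_fn:  is_async=True,  is_generator=False
# async_gen: is_async=True,  is_generator=True
```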
@@ -397,67 +404,99 @@ class ChatInterface(Blocks):
         history.append([message, None])
         return history, history
 
-    def _submit_fn(
+    async def _submit_fn(
         self,
         message: str,
         history_with_input: list[list[str | None]],
         *args,
-        **kwargs,
     ) -> tuple[list[list[str | None]], list[list[str | None]]]:
         history = history_with_input[:-1]
-        response = self.fn(message, history, *args, **kwargs)
+        if self.is_async:
+            response = await self.fn(message, history, *args)
+        else:
+            response = await anyio.to_thread.run_sync(
+                self.fn, message, history, *args, limiter=self.limiter
+            )
         history.append([message, response])
         return history, history
 
-    def _stream_fn(
+    async def _stream_fn(
         self,
         message: str,
         history_with_input: list[list[str | None]],
         *args,
-        **kwargs,
-    ) -> Generator[tuple[list[list[str | None]], list[list[str | None]]], None, None]:
+    ) -> AsyncGenerator:
         history = history_with_input[:-1]
-        generator = self.fn(message, history, *args, **kwargs)
+        if self.is_async:
+            generator = self.fn(message, history, *args)
+        else:
+            generator = await anyio.to_thread.run_sync(
+                self.fn, message, history, *args, limiter=self.limiter
+            )
+            generator = SyncToAsyncIterator(generator, self.limiter)
         try:
-            first_response = next(generator)
+            first_response = await async_iteration(generator)
             update = history + [[message, first_response]]
             yield update, update
         except StopIteration:
             update = history + [[message, None]]
             yield update, update
-        for response in generator:
+        async for response in generator:
             update = history + [[message, response]]
             yield update, update
 
-    def _api_submit_fn(
-        self, message: str, history: list[list[str | None]], *args, **kwargs
+    async def _api_submit_fn(
+        self, message: str, history: list[list[str | None]], *args
     ) -> tuple[str, list[list[str | None]]]:
-        response = self.fn(message, history)
+        if self.is_async:
+            response = await self.fn(message, history, *args)
+        else:
+            response = await anyio.to_thread.run_sync(
+                self.fn, message, history, *args, limiter=self.limiter
+            )
         history.append([message, response])
         return response, history
 
-    def _api_stream_fn(
-        self, message: str, history: list[list[str | None]], *args, **kwargs
-    ) -> Generator[tuple[str | None, list[list[str | None]]], None, None]:
-        generator = self.fn(message, history, *args, **kwargs)
+    async def _api_stream_fn(
+        self, message: str, history: list[list[str | None]], *args
+    ) -> AsyncGenerator:
+        if self.is_async:
+            generator = self.fn(message, history, *args)
+        else:
+            generator = await anyio.to_thread.run_sync(
+                self.fn, message, history, *args, limiter=self.limiter
+            )
+            generator = SyncToAsyncIterator(generator, self.limiter)
         try:
-            first_response = next(generator)
+            first_response = await async_iteration(generator)
             yield first_response, history + [[message, first_response]]
         except StopIteration:
             yield None, history + [[message, None]]
-        for response in generator:
+        async for response in generator:
             yield response, history + [[message, response]]
 
-    def _examples_fn(self, message: str, *args, **kwargs) -> list[list[str | None]]:
-        return [[message, self.fn(message, [], *args, **kwargs)]]
+    async def _examples_fn(self, message: str, *args) -> list[list[str | None]]:
+        if self.is_async:
+            response = await self.fn(message, [], *args)
+        else:
+            response = await anyio.to_thread.run_sync(
+                self.fn, message, [], *args, limiter=self.limiter
+            )
+        return [[message, response]]
 
-    def _examples_stream_fn(
+    async def _examples_stream_fn(
         self,
         message: str,
         *args,
-        **kwargs,
-    ) -> Generator[list[list[str | None]], None, None]:
-        for response in self.fn(message, [], *args, **kwargs):
+    ) -> AsyncGenerator:
+        if self.is_async:
+            generator = self.fn(message, [], *args)
+        else:
+            generator = await anyio.to_thread.run_sync(
+                self.fn, message, [], *args, limiter=self.limiter
+            )
+            generator = SyncToAsyncIterator(generator, self.limiter)
+        async for response in generator:
             yield [[message, response]]
 
     def _delete_prev_fn(
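Every wrapper above follows the same dispatch: async callables are awaited (or iterated with `async for`) directly, while sync callables are pushed onto a worker thread via `anyio.to_thread.run_sync`, and sync generators are additionally wrapped so they can be consumed asynchronously without blocking the event loop. A standalone sketch of that bridging pattern — the `SyncToAsyncIterator` here is a simplified stand-in for the helper imported from `gradio.utils`, not its actual implementation:

```python
import anyio
import anyio.to_thread

class SyncToAsyncIterator:
    # simplified stand-in: advance a sync iterator on a worker
    # thread so the event loop is never blocked
    def __init__(self, iterator, limiter=None):
        self.iterator = iterator
        self.limiter = limiter

    def __aiter__(self):
        return self

    async def __anext__(self):
        # use the StopAsyncIteration class itself as an exhaustion sentinel
        value = await anyio.to_thread.run_sync(
            next, self.iterator, StopAsyncIteration, limiter=self.limiter
        )
        if value is StopAsyncIteration:
            raise StopAsyncIteration
        return value

def sync_stream(message):
    for i in range(len(message)):
        yield message[: i + 1]

async def main():
    # create the sync generator on a worker thread, then iterate it asynchronously
    generator = await anyio.to_thread.run_sync(sync_stream, "hi")
    async for chunk in SyncToAsyncIterator(generator):
        print(chunk)

anyio.run(main)
```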
scripts/format_backend.sh

@@ -5,5 +5,6 @@ cd "$(dirname ${0})/.."
 echo "Formatting the backend... Our style follows the Black code style."
 ruff --fix gradio test
 black gradio test
+bash scripts/type_check_backend.sh
 
 bash client/python/scripts/format.sh # Call the client library's formatting script
test/test_chat_interface.py

@@ -15,11 +15,20 @@ def double(message, history):
     return message + " " + message
 
 
+async def async_greet(message, history):
+    return "hi, " + message
+
+
 def stream(message, history):
     for i in range(len(message)):
         yield message[: i + 1]
 
 
+async def async_stream(message, history):
+    for i in range(len(message)):
+        yield message[: i + 1]
+
+
 def count(message, history):
     return str(len(history))
 
@@ -68,7 +77,8 @@ class TestInit:
         )
 
     @pytest.mark.asyncio
-    async def test_example_caching(self):
+    async def test_example_caching(self, monkeypatch):
+        monkeypatch.setattr(helpers, "CACHED_FOLDER", tempfile.mkdtemp())
         chatbot = gr.ChatInterface(
             double, examples=["hello", "hi"], cache_examples=True
         )
@@ -77,6 +87,17 @@ class TestInit:
         assert prediction_hello[0][0] == ["hello", "hello hello"]
         assert prediction_hi[0][0] == ["hi", "hi hi"]
 
+    @pytest.mark.asyncio
+    async def test_example_caching_async(self, monkeypatch):
+        monkeypatch.setattr(helpers, "CACHED_FOLDER", tempfile.mkdtemp())
+        chatbot = gr.ChatInterface(
+            async_greet, examples=["abubakar", "tom"], cache_examples=True
+        )
+        prediction_hello = await chatbot.examples_handler.load_from_cache(0)
+        prediction_hi = await chatbot.examples_handler.load_from_cache(1)
+        assert prediction_hello[0][0] == ["abubakar", "hi, abubakar"]
+        assert prediction_hi[0][0] == ["tom", "hi, tom"]
+
     @pytest.mark.asyncio
     async def test_example_caching_with_streaming(self, monkeypatch):
         monkeypatch.setattr(helpers, "CACHED_FOLDER", tempfile.mkdtemp())
@@ -88,6 +109,17 @@ class TestInit:
         assert prediction_hello[0][0] == ["hello", "hello"]
         assert prediction_hi[0][0] == ["hi", "hi"]
 
+    @pytest.mark.asyncio
+    async def test_example_caching_with_streaming_async(self, monkeypatch):
+        monkeypatch.setattr(helpers, "CACHED_FOLDER", tempfile.mkdtemp())
+        chatbot = gr.ChatInterface(
+            async_stream, examples=["hello", "hi"], cache_examples=True
+        )
+        prediction_hello = await chatbot.examples_handler.load_from_cache(0)
+        prediction_hi = await chatbot.examples_handler.load_from_cache(1)
+        assert prediction_hello[0][0] == ["hello", "hello"]
+        assert prediction_hi[0][0] == ["hi", "hi"]
+
     @pytest.mark.asyncio
     async def test_example_caching_with_additional_inputs(self, monkeypatch):
         monkeypatch.setattr(helpers, "CACHED_FOLDER", tempfile.mkdtemp())
@@ -138,12 +170,25 @@ class TestAPI:
             wait([job])
             assert job.outputs() == ["h", "he", "hel", "hell", "hello"]
 
+    def test_streaming_api_async(self, connect):
+        chatbot = gr.ChatInterface(async_stream).queue()
+        with connect(chatbot) as client:
+            job = client.submit("hello")
+            wait([job])
+            assert job.outputs() == ["h", "he", "hel", "hell", "hello"]
+
     def test_non_streaming_api(self, connect):
         chatbot = gr.ChatInterface(double)
         with connect(chatbot) as client:
             result = client.predict("hello")
             assert result == "hello hello"
 
+    def test_non_streaming_api_async(self, connect):
+        chatbot = gr.ChatInterface(async_greet)
+        with connect(chatbot) as client:
+            result = client.predict("gradio")
+            assert result == "hi, gradio"
+
     def test_streaming_api_with_additional_inputs(self, connect):
         chatbot = gr.ChatInterface(
             echo_system_prompt_plus_message,
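The new tests are plain `async def` tests driven by pytest-asyncio's `asyncio` marker. A self-contained sketch of the same pattern (requires `pytest-asyncio`; the helper mirrors the `async_stream` test helper above):

```python
import pytest

async def async_stream(message, history):
    # same shape as the test helper above: yield growing prefixes
    for i in range(len(message)):
        yield message[: i + 1]

@pytest.mark.asyncio
async def test_async_stream_yields_prefixes():
    chunks = [chunk async for chunk in async_stream("hello", [])]
    assert chunks == ["h", "he", "hel", "hell", "hello"]
```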