Add support for async functions and async generators to gr.ChatInterface (#5116)

* add support for async functions and iterators to ChatInterface

* fix api

* added support to examples

* add tests

* chat

* add changeset

* typing

* chat interface

* anyio

* anyio generator

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
Abubakar Abid 2023-08-08 15:57:55 -04:00 committed by GitHub
parent eaa1ce14ac
commit 0dc49b4c51
4 changed files with 117 additions and 27 deletions

View File

@@ -0,0 +1,5 @@
+---
+"gradio": minor
+---
+
+feat:Add support for async functions and async generators to `gr.ChatInterface`
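
In plain terms: the callable passed to gr.ChatInterface may now be a regular function, a sync generator, an async function, or an async generator. A minimal usage sketch (the handler bodies mirror the ones in the test diff below; the asyncio.sleep stands in for an awaited model or API call):

import asyncio
import gradio as gr

async def async_greet(message, history):
    # Async function: produces the whole reply in one shot.
    await asyncio.sleep(0.1)  # stand-in for an awaited model/API call
    return "hi, " + message

async def async_stream(message, history):
    # Async generator: yields successively longer partial replies,
    # which ChatInterface streams into the chatbot.
    for i in range(len(message)):
        yield message[: i + 1]

# Either handler can be passed directly; streaming needs the queue.
demo = gr.ChatInterface(async_stream).queue()

if __name__ == "__main__":
    demo.launch()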

View File

@@ -6,8 +6,9 @@ This file defines a useful high-level abstraction to build Gradio chatbots: Chat
 from __future__ import annotations
 
 import inspect
-from typing import Callable, Generator
+from typing import AsyncGenerator, Callable
 
+import anyio
 from gradio_client import utils as client_utils
 from gradio_client.documentation import document, set_documentation_group
@@ -25,6 +26,7 @@ from gradio.events import Dependency, EventListenerMethod
 from gradio.helpers import create_examples as Examples  # noqa: N812
 from gradio.layouts import Accordion, Column, Group, Row
 from gradio.themes import ThemeClass as Theme
+from gradio.utils import SyncToAsyncIterator, async_iteration
 
 set_documentation_group("chatinterface")
@@ -100,7 +102,12 @@ class ChatInterface(Blocks):
             theme=theme,
         )
         self.fn = fn
-        self.is_generator = inspect.isgeneratorfunction(self.fn)
+        self.is_async = inspect.iscoroutinefunction(
+            self.fn
+        ) or inspect.isasyncgenfunction(self.fn)
+        self.is_generator = inspect.isgeneratorfunction(
+            self.fn
+        ) or inspect.isasyncgenfunction(self.fn)
         self.examples = examples
         if self.space_id and cache_examples is None:
             self.cache_examples = True
@@ -397,67 +404,99 @@ class ChatInterface(Blocks):
         history.append([message, None])
         return history, history
 
-    def _submit_fn(
+    async def _submit_fn(
         self,
         message: str,
         history_with_input: list[list[str | None]],
         *args,
-        **kwargs,
     ) -> tuple[list[list[str | None]], list[list[str | None]]]:
         history = history_with_input[:-1]
-        response = self.fn(message, history, *args, **kwargs)
+        if self.is_async:
+            response = await self.fn(message, history, *args)
+        else:
+            response = await anyio.to_thread.run_sync(
+                self.fn, message, history, *args, limiter=self.limiter
+            )
         history.append([message, response])
         return history, history
 
-    def _stream_fn(
+    async def _stream_fn(
         self,
         message: str,
         history_with_input: list[list[str | None]],
         *args,
-        **kwargs,
-    ) -> Generator[tuple[list[list[str | None]], list[list[str | None]]], None, None]:
+    ) -> AsyncGenerator:
         history = history_with_input[:-1]
-        generator = self.fn(message, history, *args, **kwargs)
+        if self.is_async:
+            generator = self.fn(message, history, *args)
+        else:
+            generator = await anyio.to_thread.run_sync(
+                self.fn, message, history, *args, limiter=self.limiter
+            )
+            generator = SyncToAsyncIterator(generator, self.limiter)
         try:
-            first_response = next(generator)
+            first_response = await async_iteration(generator)
             update = history + [[message, first_response]]
             yield update, update
         except StopIteration:
             update = history + [[message, None]]
             yield update, update
-        for response in generator:
+        async for response in generator:
             update = history + [[message, response]]
             yield update, update
 
-    def _api_submit_fn(
-        self, message: str, history: list[list[str | None]], *args, **kwargs
+    async def _api_submit_fn(
+        self, message: str, history: list[list[str | None]], *args
     ) -> tuple[str, list[list[str | None]]]:
-        response = self.fn(message, history)
+        if self.is_async:
+            response = await self.fn(message, history, *args)
+        else:
+            response = await anyio.to_thread.run_sync(
+                self.fn, message, history, *args, limiter=self.limiter
+            )
         history.append([message, response])
         return response, history
 
-    def _api_stream_fn(
-        self, message: str, history: list[list[str | None]], *args, **kwargs
-    ) -> Generator[tuple[str | None, list[list[str | None]]], None, None]:
-        generator = self.fn(message, history, *args, **kwargs)
+    async def _api_stream_fn(
+        self, message: str, history: list[list[str | None]], *args
+    ) -> AsyncGenerator:
+        if self.is_async:
+            generator = self.fn(message, history, *args)
+        else:
+            generator = await anyio.to_thread.run_sync(
+                self.fn, message, history, *args, limiter=self.limiter
+            )
+            generator = SyncToAsyncIterator(generator, self.limiter)
         try:
-            first_response = next(generator)
+            first_response = await async_iteration(generator)
             yield first_response, history + [[message, first_response]]
         except StopIteration:
             yield None, history + [[message, None]]
-        for response in generator:
+        async for response in generator:
             yield response, history + [[message, response]]
 
-    def _examples_fn(self, message: str, *args, **kwargs) -> list[list[str | None]]:
-        return [[message, self.fn(message, [], *args, **kwargs)]]
+    async def _examples_fn(self, message: str, *args) -> list[list[str | None]]:
+        if self.is_async:
+            response = await self.fn(message, [], *args)
+        else:
+            response = await anyio.to_thread.run_sync(
+                self.fn, message, [], *args, limiter=self.limiter
+            )
+        return [[message, response]]
 
-    def _examples_stream_fn(
+    async def _examples_stream_fn(
         self,
         message: str,
         *args,
-        **kwargs,
-    ) -> Generator[list[list[str | None]], None, None]:
-        for response in self.fn(message, [], *args, **kwargs):
+    ) -> AsyncGenerator:
+        if self.is_async:
+            generator = self.fn(message, [], *args)
+        else:
+            generator = await anyio.to_thread.run_sync(
+                self.fn, message, [], *args, limiter=self.limiter
+            )
+            generator = SyncToAsyncIterator(generator, self.limiter)
+        async for response in generator:
             yield [[message, response]]
 
     def _delete_prev_fn(
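
The else branches above offload sync handlers to a worker thread with anyio.to_thread.run_sync, and sync generators are additionally wrapped so they can be consumed with `async for`. The SyncToAsyncIterator and async_iteration helpers live in gradio.utils and are not part of this diff; the following is a rough, self-contained sketch of the same bridging pattern, not the actual gradio implementation (the real helpers evidently signal exhaustion in a form the `except StopIteration` above catches; this sketch uses a sentinel and StopAsyncIteration so it runs on its own):

import anyio

_SENTINEL = object()

class SyncToAsyncIterator:
    # Wrap a sync iterator so each next() call runs on a worker thread,
    # keeping the event loop free while the user's function blocks.
    def __init__(self, iterator, limiter=None):
        self.iterator = iterator
        self.limiter = limiter

    def __aiter__(self):
        return self

    async def __anext__(self):
        def _next():
            try:
                return next(self.iterator)
            except StopIteration:
                # StopIteration cannot cross the thread boundary cleanly,
                # so translate it into a sentinel value here.
                return _SENTINEL

        item = await anyio.to_thread.run_sync(_next, limiter=self.limiter)
        if item is _SENTINEL:
            raise StopAsyncIteration
        return item

async def async_iteration(iterator):
    # Pull a single item, as _stream_fn does for the first response
    # before entering its `async for` loop.
    return await iterator.__anext__()

async def _demo():
    gen = SyncToAsyncIterator(iter(["h", "he", "hel"]))
    print(await async_iteration(gen))  # "h"
    async for chunk in gen:            # "he", "hel"
        print(chunk)

anyio.run(_demo)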

View File

@@ -5,5 +5,6 @@ cd "$(dirname ${0})/.."
 echo "Formatting the backend... Our style follows the Black code style."
 ruff --fix gradio test
 black gradio test
+bash scripts/type_check_backend.sh
 bash client/python/scripts/format.sh  # Call the client library's formatting script
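
The test diff below defines one handler of each supported shape (double, async_greet, stream, async_stream). As a quick standalone illustration of how the inspect-based classification added to the ChatInterface constructor (the is_async / is_generator hunk above) sorts those shapes:

import inspect

def double(message, history):
    return message + " " + message

async def async_greet(message, history):
    return "hi, " + message

def stream(message, history):
    for i in range(len(message)):
        yield message[: i + 1]

async def async_stream(message, history):
    for i in range(len(message)):
        yield message[: i + 1]

for fn in (double, async_greet, stream, async_stream):
    is_async = inspect.iscoroutinefunction(fn) or inspect.isasyncgenfunction(fn)
    is_generator = inspect.isgeneratorfunction(fn) or inspect.isasyncgenfunction(fn)
    print(f"{fn.__name__}: is_async={is_async}, is_generator={is_generator}")
# double: is_async=False, is_generator=False
# async_greet: is_async=True, is_generator=False
# stream: is_async=False, is_generator=True
# async_stream: is_async=True, is_generator=True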

View File

@@ -15,11 +15,20 @@ def double(message, history):
     return message + " " + message
 
 
+async def async_greet(message, history):
+    return "hi, " + message
+
+
 def stream(message, history):
     for i in range(len(message)):
         yield message[: i + 1]
 
 
+async def async_stream(message, history):
+    for i in range(len(message)):
+        yield message[: i + 1]
+
+
 def count(message, history):
     return str(len(history))
@@ -68,7 +77,8 @@ class TestInit:
         )
 
     @pytest.mark.asyncio
-    async def test_example_caching(self):
+    async def test_example_caching(self, monkeypatch):
+        monkeypatch.setattr(helpers, "CACHED_FOLDER", tempfile.mkdtemp())
         chatbot = gr.ChatInterface(
             double, examples=["hello", "hi"], cache_examples=True
         )
@@ -77,6 +87,17 @@ class TestInit:
         assert prediction_hello[0][0] == ["hello", "hello hello"]
         assert prediction_hi[0][0] == ["hi", "hi hi"]
 
+    @pytest.mark.asyncio
+    async def test_example_caching_async(self, monkeypatch):
+        monkeypatch.setattr(helpers, "CACHED_FOLDER", tempfile.mkdtemp())
+        chatbot = gr.ChatInterface(
+            async_greet, examples=["abubakar", "tom"], cache_examples=True
+        )
+        prediction_hello = await chatbot.examples_handler.load_from_cache(0)
+        prediction_hi = await chatbot.examples_handler.load_from_cache(1)
+        assert prediction_hello[0][0] == ["abubakar", "hi, abubakar"]
+        assert prediction_hi[0][0] == ["tom", "hi, tom"]
+
     @pytest.mark.asyncio
     async def test_example_caching_with_streaming(self, monkeypatch):
         monkeypatch.setattr(helpers, "CACHED_FOLDER", tempfile.mkdtemp())
@@ -88,6 +109,17 @@ class TestInit:
         assert prediction_hello[0][0] == ["hello", "hello"]
         assert prediction_hi[0][0] == ["hi", "hi"]
 
+    @pytest.mark.asyncio
+    async def test_example_caching_with_streaming_async(self, monkeypatch):
+        monkeypatch.setattr(helpers, "CACHED_FOLDER", tempfile.mkdtemp())
+        chatbot = gr.ChatInterface(
+            async_stream, examples=["hello", "hi"], cache_examples=True
+        )
+        prediction_hello = await chatbot.examples_handler.load_from_cache(0)
+        prediction_hi = await chatbot.examples_handler.load_from_cache(1)
+        assert prediction_hello[0][0] == ["hello", "hello"]
+        assert prediction_hi[0][0] == ["hi", "hi"]
+
     @pytest.mark.asyncio
     async def test_example_caching_with_additional_inputs(self, monkeypatch):
         monkeypatch.setattr(helpers, "CACHED_FOLDER", tempfile.mkdtemp())
@@ -138,12 +170,25 @@ class TestAPI:
             wait([job])
             assert job.outputs() == ["h", "he", "hel", "hell", "hello"]
 
+    def test_streaming_api_async(self, connect):
+        chatbot = gr.ChatInterface(async_stream).queue()
+        with connect(chatbot) as client:
+            job = client.submit("hello")
+            wait([job])
+            assert job.outputs() == ["h", "he", "hel", "hell", "hello"]
+
     def test_non_streaming_api(self, connect):
         chatbot = gr.ChatInterface(double)
         with connect(chatbot) as client:
             result = client.predict("hello")
             assert result == "hello hello"
 
+    def test_non_streaming_api_async(self, connect):
+        chatbot = gr.ChatInterface(async_greet)
+        with connect(chatbot) as client:
+            result = client.predict("gradio")
+            assert result == "hi, gradio"
+
     def test_streaming_api_with_additional_inputs(self, connect):
         chatbot = gr.ChatInterface(
             echo_system_prompt_plus_message,