Fix multimodal chatinterface api bug (#9054)

* fix

* add changeset

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
Freddy Boulton 2024-08-08 04:36:14 -04:00 committed by GitHub
parent f29aef4528
commit 9fa635a8fd
3 changed files with 20 additions and 60 deletions


@@ -0,0 +1,5 @@
+---
+"gradio": patch
+---
+
+fix:Fix multimodal chatinterface api bug


@@ -517,7 +517,7 @@ class ChatInterface(Blocks):
         self.fake_api_btn.click(
             api_fn,
             [self.textbox, self.chatbot_state] + self.additional_inputs,
-            [self.textbox, self.chatbot_state],
+            [self.fake_response_textbox, self.chatbot_state],
             api_name="chat",
             concurrency_limit=cast(
                 Union[int, Literal["default"], None], self.concurrency_limit
@@ -697,65 +697,6 @@ class ChatInterface(Blocks):
             history_with_input[-1] = response  # type: ignore
             yield history_with_input
 
-    async def _api_submit_fn(
-        self,
-        message: str,
-        history: TupleFormat | list[MessageDict],
-        request: Request,
-        *args,
-    ) -> tuple[str, TupleFormat | list[MessageDict]]:
-        inputs, _, _ = special_args(
-            self.fn, inputs=[message, history, *args], request=request
-        )
-        if self.is_async:
-            response = await self.fn(*inputs)
-        else:
-            response = await anyio.to_thread.run_sync(
-                self.fn, *inputs, limiter=self.limiter
-            )
-        if self.type == "tuples":
-            history.append([message, response])  # type: ignore
-        else:
-            new_response = self.response_as_dict(response)
-            history.extend([{"role": "user", "content": message}, new_response])  # type: ignore
-        return response, history
-
-    async def _api_stream_fn(
-        self, message: str, history: list[list[str | None]], request: Request, *args
-    ) -> AsyncGenerator:
-        inputs, _, _ = special_args(
-            self.fn, inputs=[message, history, *args], request=request
-        )
-        if self.is_async:
-            generator = self.fn(*inputs)
-        else:
-            generator = await anyio.to_thread.run_sync(
-                self.fn, *inputs, limiter=self.limiter
-            )
-            generator = SyncToAsyncIterator(generator, self.limiter)
-        try:
-            first_response = await async_iteration(generator)
-            if self.type == "tuples":
-                yield first_response, history + [[message, first_response]]
-            else:
-                first_response = self.response_as_dict(first_response)
-                yield (
-                    first_response,
-                    history + [{"role": "user", "content": message}, first_response],
-                )
-        except StopIteration:
-            yield None, history + [[message, None]]
-        async for response in generator:
-            if self.type == "tuples":
-                yield response, history + [[message, response]]
-            else:
-                new_response = self.response_as_dict(response)
-                yield (
-                    new_response,
-                    history + [{"role": "user", "content": message}, new_response],
-                )
-
     async def _examples_fn(
         self, message: str, *args
     ) -> TupleFormat | list[MessageDict]:
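
For context (not part of the commit): a minimal sketch of the setup this hunk affects. With multimodal=True, the chat function receives the user turn as a dict with "text" and "files" keys, and after this fix the /chat endpoint returns the bot response through fake_response_textbox rather than routing it back into the input textbox. The demo below is illustrative only; names are chosen to mirror the test added in this commit.

import gradio as gr

# Illustrative sketch, not code from this commit.
# With multimodal=True, `message` arrives as {"text": ..., "files": [...]}.
def double(message, history):
    return message["text"] + " " + message["text"]

demo = gr.ChatInterface(double, multimodal=True, type="messages")

if __name__ == "__main__":
    demo.launch()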


@@ -301,3 +301,17 @@ class TestAPI:
             "robot ",
             "robot h",
         ]
+
+    @pytest.mark.parametrize("type", ["tuples", "messages"])
+    def test_multimodal_api(self, type, connect):
+        def double_multimodal(msg, history):
+            return msg["text"] + " " + msg["text"]
+
+        chatbot = gr.ChatInterface(
+            double_multimodal,
+            type=type,
+            multimodal=True,
+        )
+        with connect(chatbot) as client:
+            result = client.predict({"text": "hello", "files": []}, api_name="/chat")
+            assert result == "hello hello"
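
The same endpoint can also be exercised from a regular client; a sketch using gradio_client, assuming a multimodal ChatInterface like the one above is running at the default local URL (the address is an assumption, not from the commit):

from gradio_client import Client

# Assumes a running demo at the default local address.
client = Client("http://127.0.0.1:7860/")
# The payload mirrors the test: a dict with "text" and "files" keys.
result = client.predict({"text": "hello", "files": []}, api_name="/chat")
print(result)  # "hello hello" for the doubling function above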