mirror of
https://github.com/gradio-app/gradio.git
synced 2024-11-21 01:01:05 +08:00
Fix multimodal chatinterface api bug (#9054)
* fix * add changeset --------- Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
parent
f29aef4528
commit
9fa635a8fd
5
.changeset/mighty-maps-double.md
Normal file
5
.changeset/mighty-maps-double.md
Normal file
@ -0,0 +1,5 @@
|
||||
---
|
||||
"gradio": patch
|
||||
---
|
||||
|
||||
fix:Fix multimodal chatinterface api bug
|
@ -517,7 +517,7 @@ class ChatInterface(Blocks):
|
||||
self.fake_api_btn.click(
|
||||
api_fn,
|
||||
[self.textbox, self.chatbot_state] + self.additional_inputs,
|
||||
[self.textbox, self.chatbot_state],
|
||||
[self.fake_response_textbox, self.chatbot_state],
|
||||
api_name="chat",
|
||||
concurrency_limit=cast(
|
||||
Union[int, Literal["default"], None], self.concurrency_limit
|
||||
@ -697,65 +697,6 @@ class ChatInterface(Blocks):
|
||||
history_with_input[-1] = response # type: ignore
|
||||
yield history_with_input
|
||||
|
||||
async def _api_submit_fn(
|
||||
self,
|
||||
message: str,
|
||||
history: TupleFormat | list[MessageDict],
|
||||
request: Request,
|
||||
*args,
|
||||
) -> tuple[str, TupleFormat | list[MessageDict]]:
|
||||
inputs, _, _ = special_args(
|
||||
self.fn, inputs=[message, history, *args], request=request
|
||||
)
|
||||
|
||||
if self.is_async:
|
||||
response = await self.fn(*inputs)
|
||||
else:
|
||||
response = await anyio.to_thread.run_sync(
|
||||
self.fn, *inputs, limiter=self.limiter
|
||||
)
|
||||
if self.type == "tuples":
|
||||
history.append([message, response]) # type: ignore
|
||||
else:
|
||||
new_response = self.response_as_dict(response)
|
||||
history.extend([{"role": "user", "content": message}, new_response]) # type: ignore
|
||||
return response, history
|
||||
|
||||
async def _api_stream_fn(
|
||||
self, message: str, history: list[list[str | None]], request: Request, *args
|
||||
) -> AsyncGenerator:
|
||||
inputs, _, _ = special_args(
|
||||
self.fn, inputs=[message, history, *args], request=request
|
||||
)
|
||||
if self.is_async:
|
||||
generator = self.fn(*inputs)
|
||||
else:
|
||||
generator = await anyio.to_thread.run_sync(
|
||||
self.fn, *inputs, limiter=self.limiter
|
||||
)
|
||||
generator = SyncToAsyncIterator(generator, self.limiter)
|
||||
try:
|
||||
first_response = await async_iteration(generator)
|
||||
if self.type == "tuples":
|
||||
yield first_response, history + [[message, first_response]]
|
||||
else:
|
||||
first_response = self.response_as_dict(first_response)
|
||||
yield (
|
||||
first_response,
|
||||
history + [{"role": "user", "content": message}, first_response],
|
||||
)
|
||||
except StopIteration:
|
||||
yield None, history + [[message, None]]
|
||||
async for response in generator:
|
||||
if self.type == "tuples":
|
||||
yield response, history + [[message, response]]
|
||||
else:
|
||||
new_response = self.response_as_dict(response)
|
||||
yield (
|
||||
new_response,
|
||||
history + [{"role": "user", "content": message}, new_response],
|
||||
)
|
||||
|
||||
async def _examples_fn(
|
||||
self, message: str, *args
|
||||
) -> TupleFormat | list[MessageDict]:
|
||||
|
@ -301,3 +301,17 @@ class TestAPI:
|
||||
"robot ",
|
||||
"robot h",
|
||||
]
|
||||
|
||||
@pytest.mark.parametrize("type", ["tuples", "messages"])
|
||||
def test_multimodal_api(self, type, connect):
|
||||
def double_multimodal(msg, history):
|
||||
return msg["text"] + " " + msg["text"]
|
||||
|
||||
chatbot = gr.ChatInterface(
|
||||
double_multimodal,
|
||||
type=type,
|
||||
multimodal=True,
|
||||
)
|
||||
with connect(chatbot) as client:
|
||||
result = client.predict({"text": "hello", "files": []}, api_name="/chat")
|
||||
assert result == "hello hello"
|
||||
|
Loading…
Reference in New Issue
Block a user