Mirror of https://github.com/gradio-app/gradio.git (synced 2025-04-06 12:30:29 +08:00)
Fix chatinterface multimodal bug (#9119)
* Add test
* add changeset
* comments

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
parent e1c404da11 · commit 30b5d6f2b7
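For context on what the new demos below exercise: with multimodal=True, the chat function no longer receives a plain string but a dict-shaped message, which the test cases index as message['text']. A minimal sketch of such a handler (illustrative only, not part of this commit; the exact contents of the "files" entries are not assumed here):

import gradio as gr

def echo(message, history):
    # Assumed shape: {"text": "...", "files": [...]} as used by the test cases below
    n_files = len(message.get("files", []))
    yield f"You typed: {message['text']} ({n_files} file(s) attached)"

demo = gr.ChatInterface(echo, multimodal=True, type="messages")

if __name__ == "__main__":
    demo.launch()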
.changeset/smart-pants-dance.md (new file, +5)
@@ -0,0 +1,5 @@
---
"gradio": patch
---

fix:Fix chatinterface multimodal bug
demo/test_chatinterface_streaming_echo/messages_testcase.py
@@ -3,6 +3,10 @@ import gradio as gr

 runs = 0

+def reset_runs():
+    global runs
+    runs = 0
+
 def slow_echo(message, history):
     global runs  # i didn't want to add state or anything to this demo
     runs = runs + 1
@@ -10,7 +14,16 @@ def slow_echo(message, history):
         time.sleep(0.05)
         yield f"Run {runs} - You typed: " + message[: i + 1]

-demo = gr.ChatInterface(slow_echo, type="messages")
+chat = gr.ChatInterface(slow_echo, type="messages")

+with gr.Blocks() as demo:
+    chat.render()
+    # We reset the global variable to minimize flakes
+    # this works because CI runs only one test at at time
+    # need to use gr.State if we want to parallelize this test
+    # currently chatinterface does not support that
+    demo.unload(reset_runs)
+
+
 if __name__ == "__main__":
     demo.launch()
demo/test_chatinterface_streaming_echo/multimodal_messages_testcase.py (new file, +26)
@@ -0,0 +1,26 @@
import gradio as gr

runs = 0

def reset_runs():
    global runs
    runs = 0

def slow_echo(message, history):
    global runs  # i didn't want to add state or anything to this demo
    runs = runs + 1
    for i in range(len(message['text'])):
        yield f"Run {runs} - You typed: " + message['text'][: i + 1]

chat = gr.ChatInterface(slow_echo, multimodal=True, type="messages")

with gr.Blocks() as demo:
    chat.render()
    # We reset the global variable to minimize flakes
    # this works because CI runs only one test at at time
    # need to use gr.State if we want to parallelize this test
    # currently chatinterface does not support that
    demo.unload(reset_runs)

if __name__ == "__main__":
    demo.launch()
demo/test_chatinterface_streaming_echo/multimodal_non_stream_testcase.py (new file, +25)
@@ -0,0 +1,25 @@
import gradio as gr

runs = 0

def reset_runs():
    global runs
    runs = 0

def slow_echo(message, history):
    global runs  # i didn't want to add state or anything to this demo
    runs = runs + 1
    return f"Run {runs} - You typed: " + message['text']

chat = gr.ChatInterface(slow_echo, multimodal=True, type="tuples")

with gr.Blocks() as demo:
    chat.render()
    # We reset the global variable to minimize flakes
    # this works because CI runs only one test at at time
    # need to use gr.State if we want to parallelize this test
    # currently chatinterface does not support that
    demo.unload(reset_runs)

if __name__ == "__main__":
    demo.launch()
demo/test_chatinterface_streaming_echo/multimodal_tuples_testcase.py (new file, +26)
@@ -0,0 +1,26 @@
import gradio as gr

runs = 0

def reset_runs():
    global runs
    runs = 0

def slow_echo(message, history):
    global runs  # i didn't want to add state or anything to this demo
    runs = runs + 1
    for i in range(len(message['text'])):
        yield f"Run {runs} - You typed: " + message['text'][: i + 1]

chat = gr.ChatInterface(slow_echo, multimodal=True, type="tuples")

with gr.Blocks() as demo:
    chat.render()
    # We reset the global variable to minimize flakes
    # this works because CI runs only one test at at time
    # need to use gr.State if we want to parallelize this test
    # currently chatinterface does not support that
    demo.unload(reset_runs)

if __name__ == "__main__":
    demo.launch()
demo/test_chatinterface_streaming_echo/run.ipynb
@@ -1 +1 @@
-{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: test_chatinterface_streaming_echo"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/test_chatinterface_streaming_echo/messages_testcase.py"]}, {"cell_type": "code", "execution_count": null, "id": "44380577570523278879349135829904343037", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "\n", "runs = 0\n", "\n", "def reset_runs():\n", " global runs\n", " runs = 0\n", "\n", "def slow_echo(message, history):\n", " global runs # i didn't want to add state or anything to this demo\n", " runs = runs + 1\n", " for i in range(len(message)):\n", " yield f\"Run {runs} - You typed: \" + message[: i + 1]\n", "\n", "chat = gr.ChatInterface(slow_echo, fill_height=True)\n", "\n", "with gr.Blocks() as demo:\n", " chat.render()\n", " # We reset the global variable to minimize flakes\n", " # this works because CI runs only one test at at time\n", " # need to use gr.State if we want to parallelize this test\n", " # currently chatinterface does not support that\n", " demo.unload(reset_runs)\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
+{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: test_chatinterface_streaming_echo"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/test_chatinterface_streaming_echo/messages_testcase.py\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/test_chatinterface_streaming_echo/multimodal_messages_testcase.py\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/test_chatinterface_streaming_echo/multimodal_non_stream_testcase.py\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/test_chatinterface_streaming_echo/multimodal_tuples_testcase.py"]}, {"cell_type": "code", "execution_count": null, "id": "44380577570523278879349135829904343037", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "\n", "runs = 0\n", "\n", "def reset_runs():\n", " global runs\n", " runs = 0\n", "\n", "def slow_echo(message, history):\n", " global runs # i didn't want to add state or anything to this demo\n", " runs = runs + 1\n", " for i in range(len(message)):\n", " yield f\"Run {runs} - You typed: \" + message[: i + 1]\n", "\n", "chat = gr.ChatInterface(slow_echo, fill_height=True)\n", "\n", "with gr.Blocks() as demo:\n", " chat.render()\n", " # We reset the global variable to minimize flakes\n", " # this works because CI runs only one test at at time\n", " # need to use gr.State if we want to parallelize this test\n", " # currently chatinterface does not support that\n", " demo.unload(reset_runs)\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
gradio/chat_interface.py
@@ -579,6 +579,30 @@ class ChatInterface(Blocks):
             new_response = response
         return cast(MessageDict, new_response)

+    def _process_msg_and_trim_history(
+        self,
+        message: str | MultimodalData,
+        history_with_input: TupleFormat | list[MessageDict],
+    ) -> tuple[str | dict, TupleFormat | list[MessageDict]]:
+        if isinstance(message, MultimodalData):
+            remove_input = len(message.files) + int(message.text is not None)
+            history = history_with_input[:-remove_input]
+            message_serialized = message.model_dump()
+        else:
+            history = history_with_input[:-1]
+            message_serialized = message
+        return message_serialized, history
+
+    def _append_history(self, history, message, first_response=True):
+        if self.type == "tuples":
+            history[-1][1] = message  # type: ignore
+        else:
+            message = self.response_as_dict(message)
+            if first_response:
+                history.append(message)  # type: ignore
+            else:
+                history[-1] = message
+
     async def _submit_fn(
         self,
         message: str | MultimodalData,
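The new _process_msg_and_trim_history helper centralizes the logic both submit paths need: a multimodal submission occupies one history row per attached file plus one row for the text, so exactly that many trailing rows are dropped to recover the pre-submission history, and the MultimodalData payload is serialized with model_dump(). A standalone sketch of the same trimming rule, using hypothetical stand-in names (FakeMultimodalMessage, process_and_trim) rather than gradio internals:

from dataclasses import dataclass, field

@dataclass
class FakeMultimodalMessage:  # hypothetical stand-in for gradio's MultimodalData
    text: str | None = None
    files: list[str] = field(default_factory=list)

def process_and_trim(message, history_with_input):
    if isinstance(message, FakeMultimodalMessage):
        # one history row per file, plus one if there is accompanying text
        remove_input = len(message.files) + int(message.text is not None)
        history = history_with_input[:-remove_input]
        serialized = {"text": message.text, "files": message.files}
    else:
        # plain-text submissions only ever add a single row
        history = history_with_input[:-1]
        serialized = message
    return serialized, history

# Text plus two files occupies the three trailing rows, so only the older turn survives.
msg = FakeMultimodalMessage(text="hi", files=["a.png", "b.png"])
rows = ["older turn", "file row 1", "file row 2", "text row"]  # placeholder history rows
print(process_and_trim(msg, rows)[1])  # -> ['older turn']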
@@ -586,18 +610,9 @@
         request: Request,
         *args,
     ) -> tuple[TupleFormat, TupleFormat] | tuple[list[MessageDict], list[MessageDict]]:
-        if self.multimodal and isinstance(message, MultimodalData):
-            remove_input = (
-                len(message.files) + 1
-                if message.text is not None
-                else len(message.files)
-            )
-            history = history_with_input[:-remove_input]
-            message_serialized = message.model_dump()
-        else:
-            history = history_with_input[:-1]
-            message_serialized = message
-
+        message_serialized, history = self._process_msg_and_trim_history(
+            message, history_with_input
+        )
         inputs, _, _ = special_args(
             self.fn, inputs=[message_serialized, history, *args], request=request
         )
@@ -609,15 +624,8 @@
                 self.fn, *inputs, limiter=self.limiter
             )

-        if self.type == "messages":
-            new_response = self.response_as_dict(response)
-        else:
-            new_response = response
+        self._append_history(history_with_input, response)

-        if self.type == "tuples":
-            history_with_input[-1][1] = new_response  # type: ignore
-        elif self.type == "messages":
-            history_with_input.append(new_response)  # type: ignore
         return history_with_input  # type: ignore

     async def _stream_fn(
@@ -627,17 +635,11 @@
         request: Request,
         *args,
     ) -> AsyncGenerator:
-        if self.multimodal and isinstance(message, MultimodalData):
-            remove_input = (
-                len(message.files) + 1
-                if message.text is not None
-                else len(message.files)
-            )
-            history = history_with_input[:-remove_input]
-        else:
-            history = history_with_input[:-1]
+        message_serialized, history = self._process_msg_and_trim_history(
+            message, history_with_input
+        )
         inputs, _, _ = special_args(
-            self.fn, inputs=[message, history, *args], request=request
+            self.fn, inputs=[message_serialized, history, *args], request=request
         )

         if self.is_async:
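This hunk is the core of the fix: the streaming path previously passed the raw message object into the user's function, while the non-streaming path passed the serialized dict, so handlers written against the dict shape (such as the new demos' message['text']) failed only when streaming with multimodal=True. A minimal sketch of the difference, using a hypothetical pydantic stand-in rather than gradio's actual MultimodalData class:

from pydantic import BaseModel

class FakeMultimodalData(BaseModel):  # hypothetical stand-in
    text: str | None = None
    files: list[str] = []

msg = FakeMultimodalData(text="hi")

try:
    msg["text"]  # what a dict-style handler effectively does
except TypeError as err:
    print("raw model:", err)  # BaseModel instances are not subscriptable

print("serialized:", msg.model_dump()["text"])  # the dumped dict works: -> hi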
@@ -649,53 +651,13 @@
             generator = SyncToAsyncIterator(generator, self.limiter)
         try:
             first_response = await async_iteration(generator)
-            if self.type == "messages":
-                first_response = self.response_as_dict(first_response)
-            if (
-                self.multimodal
-                and isinstance(message, MultimodalData)
-                and self.type == "tuples"
-            ):
-                history_with_input[-1][1] = first_response  # type: ignore
-                yield history_with_input
-            elif (
-                self.multimodal
-                and isinstance(message, MultimodalData)
-                and self.type == "messages"
-            ):
-                history_with_input.append(first_response)  # type: ignore
-                yield history_with_input
-            elif self.type == "tuples":
-                history_with_input[-1][1] = first_response  # type: ignore
-                yield history_with_input
-            else:
-                history_with_input.append(first_response)  # type: ignore
-                yield history_with_input
+            self._append_history(history_with_input, first_response)
+            yield history_with_input
         except StopIteration:
             yield history_with_input
         async for response in generator:
-            if self.type == "messages":
-                response = self.response_as_dict(response)
-            if (
-                self.multimodal
-                and isinstance(message, MultimodalData)
-                and self.type == "tuples"
-            ):
-                history_with_input[-1][1] = response  # type: ignore
-                yield history_with_input
-            elif (
-                self.multimodal
-                and isinstance(message, MultimodalData)
-                and self.type == "messages"
-            ):
-                history_with_input[-1] = response  # type: ignore
-                yield history_with_input
-            elif self.type == "tuples":
-                history_with_input[-1][1] = response  # type: ignore
-                yield history_with_input
-            else:
-                history_with_input[-1] = response  # type: ignore
-                yield history_with_input
+            self._append_history(history_with_input, response, first_response=False)
+            yield history_with_input

     async def _examples_fn(
         self, message: str, *args
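Both the non-streaming and streaming branches now funnel through _append_history, which replaces the four near-duplicate multimodal/msg-format branches: in tuples mode the response fills the last row's second slot, while in messages mode the first chunk is appended and every later chunk overwrites that same entry. A plain-Python sketch of that append-then-replace pattern (illustrative only, not the gradio API):

def append_history(history, message, first_response=True):
    # messages-style bookkeeping: append the first chunk, overwrite it afterwards
    entry = {"role": "assistant", "content": message}
    if first_response:
        history.append(entry)
    else:
        history[-1] = entry

history = [{"role": "user", "content": "hi"}]
chunks = iter(["He", "Hell", "Hello!"])

append_history(history, next(chunks))                     # first chunk -> new entry
for chunk in chunks:
    append_history(history, chunk, first_response=False)  # later chunks -> replace

print(history[-1])  # {'role': 'assistant', 'content': 'Hello!'}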
Playwright spec for the chatinterface streaming echo demo (.spec.ts)
@@ -1,13 +1,24 @@
 import { test, expect, go_to_testcase } from "@gradio/tootils";

-for (const msg_format of ["tuples", "messages"]) {
-	test(`msg format ${msg_format} chatinterface works with streaming functions and all buttons behave as expected`, async ({
+const cases = [
+	"tuples",
+	"messages",
+	"multimodal_tuples",
+	"multimodal_messages",
+	"multimodal_non_stream"
+];
+
+for (const test_case of cases) {
+	test(`test case ${test_case} chatinterface works with streaming functions and all buttons behave as expected`, async ({
 		page
 	}) => {
-		if (msg_format === "messages") {
-			await go_to_testcase(page, "messages");
+		if (cases.slice(1).includes(test_case)) {
+			await go_to_testcase(page, test_case);
 		}
-		const submit_button = page.getByRole("button", { name: "Submit" });
+		let submit_button = page.getByRole("button", { name: "Submit" });
+		if (test_case.startsWith("multimodal")) {
+			submit_button = page.locator(".submit-button");
+		}
 		const retry_button = page.getByRole("button", { name: "🔄 Retry" });
 		const undo_button = page.getByRole("button", { name: "↩️ Undo" });
 		const clear_button = page.getByRole("button", { name: "🗑️ Clear" });
@@ -32,25 +43,19 @@ for (const msg_format of ["tuples", "messages"]) {
 			hasText: "Run 2 - You typed: hi"
 		});
 		await expect(expected_text_el_1).toBeVisible();
-		await expect
-			.poll(async () => page.locator(".bot.message").count(), { timeout: 2000 })
-			.toBe(2);
+		await expect(page.locator(".bot.message")).toHaveCount(2);

 		await undo_button.click();
-		await expect
-			.poll(async () => page.locator(".message.bot").count(), { timeout: 5000 })
-			.toBe(1);
+		await expect(page.locator(".bot.message")).toHaveCount(1);
 		await expect(textbox).toHaveValue("hi");

 		await retry_button.click();
 		const expected_text_el_2 = page.locator(".bot p", {
-			hasText: ""
+			hasText: "Run 3 - You typed: hi"
 		});
 		await expect(expected_text_el_2).toBeVisible();

-		await expect
-			.poll(async () => page.locator(".message.bot").count(), { timeout: 5000 })
-			.toBe(1);
+		await expect(page.locator(".bot.message")).toHaveCount(1);

 		await textbox.fill("hi");
 		await submit_button.click();
@@ -59,24 +64,22 @@ for (const msg_format of ["tuples", "messages"]) {
 			hasText: "Run 4 - You typed: hi"
 		});
 		await expect(expected_text_el_3).toBeVisible();
-		await expect
-			.poll(async () => page.locator(".bot.message").count(), { timeout: 2000 })
-			.toBe(2);
-
+		await expect(page.locator(".bot.message")).toHaveCount(2);
 		await clear_button.click();
-		await expect
-			.poll(async () => page.locator(".bot.message").count(), { timeout: 5000 })
-			.toBe(0);
+		await expect(page.locator(".bot.message")).toHaveCount(0);
 	});

-	test(`msg format ${msg_format} the api recorder correctly records the api calls`, async ({
+	test(`test case ${test_case} the api recorder correctly records the api calls`, async ({
 		page
 	}) => {
-		if (msg_format === "messages") {
-			await go_to_testcase(page, "messages");
+		if (cases.slice(1).includes(test_case)) {
+			await go_to_testcase(page, test_case);
 		}
 		const textbox = page.getByPlaceholder("Type a message...");
-		const submit_button = page.getByRole("button", { name: "Submit" });
+		let submit_button = page.getByRole("button", { name: "Submit" });
+		if (test_case.startsWith("multimodal")) {
+			submit_button = page.locator(".submit-button");
+		}
 		await textbox.fill("hi");

 		await page.getByRole("button", { name: "Use via API logo" }).click();
@@ -88,8 +91,9 @@ for (const msg_format of ["tuples", "messages"]) {
 		);
 		const api_recorder = await page.locator("#api-recorder");
 		await api_recorder.click();
+		const n_calls = test_case.includes("non_stream") ? 3 : 5;
 		await expect(page.locator("#num-recorded-api-calls")).toContainText(
-			"🪄 Recorded API Calls [5]"
+			`🪄 Recorded API Calls [${n_calls}]`
 		);
 	});
 }