Fix chatinterface multimodal bug (#9119)

* Add test

* add changeset

* comments

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
Freddy Boulton 2024-08-15 10:50:32 -04:00 committed by GitHub
parent e1c404da11
commit 30b5d6f2b7
8 changed files with 164 additions and 103 deletions


@@ -0,0 +1,5 @@
+---
+"gradio": patch
+---
+
+fix:Fix chatinterface multimodal bug

demo/test_chatinterface_streaming_echo/messages_testcase.py

@@ -3,6 +3,10 @@ import gradio as gr
 
 runs = 0
 
+def reset_runs():
+    global runs
+    runs = 0
+
 def slow_echo(message, history):
     global runs  # i didn't want to add state or anything to this demo
     runs = runs + 1
@@ -10,7 +14,16 @@ def slow_echo(message, history):
         time.sleep(0.05)
         yield f"Run {runs} - You typed: " + message[: i + 1]
 
-demo = gr.ChatInterface(slow_echo, type="messages")
+chat = gr.ChatInterface(slow_echo, type="messages")
+
+with gr.Blocks() as demo:
+    chat.render()
+    # We reset the global variable to minimize flakes
+    # this works because CI runs only one test at a time
+    # need to use gr.State if we want to parallelize this test
+    # currently chatinterface does not support that
+    demo.unload(reset_runs)
 
 if __name__ == "__main__":
     demo.launch()
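
For context (not part of the commit): the testcases build the ChatInterface up front and then render() it inside a gr.Blocks so that a Blocks-level unload handler can reset the module-global counter between CI runs. A minimal sketch of that pattern in isolation:

import gradio as gr

chat = gr.ChatInterface(lambda message, history: message)  # any pre-built interface

with gr.Blocks() as demo:
    chat.render()              # attach the pre-built interface inside this Blocks
    demo.unload(lambda: None)  # Blocks-level event; the testcases use it to reset state

demo.launch()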

demo/test_chatinterface_streaming_echo/multimodal_messages_testcase.py

@@ -0,0 +1,26 @@
+import gradio as gr
+
+runs = 0
+
+def reset_runs():
+    global runs
+    runs = 0
+
+def slow_echo(message, history):
+    global runs  # i didn't want to add state or anything to this demo
+    runs = runs + 1
+    for i in range(len(message['text'])):
+        yield f"Run {runs} - You typed: " + message['text'][: i + 1]
+
+chat = gr.ChatInterface(slow_echo, multimodal=True, type="messages")
+
+with gr.Blocks() as demo:
+    chat.render()
+    # We reset the global variable to minimize flakes
+    # this works because CI runs only one test at a time
+    # need to use gr.State if we want to parallelize this test
+    # currently chatinterface does not support that
+    demo.unload(reset_runs)
+
+if __name__ == "__main__":
+    demo.launch()
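
For context (not part of the commit): with multimodal=True the wrapped function receives the user turn as a dict-like payload rather than a plain string, which is why these testcases index message['text']. A minimal sketch of the shape the demos rely on; the file path is a hypothetical placeholder:

message = {
    "text": "hi",               # what the user typed
    "files": ["/tmp/cat.png"],  # hypothetical upload; empty list if none
}

def echo_text(message, history):
    # stream the text field one character at a time, as the testcases do
    for i in range(len(message["text"])):
        yield "You typed: " + message["text"][: i + 1]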

demo/test_chatinterface_streaming_echo/multimodal_non_stream_testcase.py

@@ -0,0 +1,25 @@
+import gradio as gr
+
+runs = 0
+
+def reset_runs():
+    global runs
+    runs = 0
+
+def slow_echo(message, history):
+    global runs  # i didn't want to add state or anything to this demo
+    runs = runs + 1
+    return f"Run {runs} - You typed: " + message['text']
+
+chat = gr.ChatInterface(slow_echo, multimodal=True, type="tuples")
+
+with gr.Blocks() as demo:
+    chat.render()
+    # We reset the global variable to minimize flakes
+    # this works because CI runs only one test at a time
+    # need to use gr.State if we want to parallelize this test
+    # currently chatinterface does not support that
+    demo.unload(reset_runs)
+
+if __name__ == "__main__":
+    demo.launch()

demo/test_chatinterface_streaming_echo/multimodal_tuples_testcase.py

@@ -0,0 +1,26 @@
+import gradio as gr
+
+runs = 0
+
+def reset_runs():
+    global runs
+    runs = 0
+
+def slow_echo(message, history):
+    global runs  # i didn't want to add state or anything to this demo
+    runs = runs + 1
+    for i in range(len(message['text'])):
+        yield f"Run {runs} - You typed: " + message['text'][: i + 1]
+
+chat = gr.ChatInterface(slow_echo, multimodal=True, type="tuples")
+
+with gr.Blocks() as demo:
+    chat.render()
+    # We reset the global variable to minimize flakes
+    # this works because CI runs only one test at a time
+    # need to use gr.State if we want to parallelize this test
+    # currently chatinterface does not support that
+    demo.unload(reset_runs)
+
+if __name__ == "__main__":
+    demo.launch()
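
For context (not part of the commit): the type= argument picks the chat history format, which is what the fix in gradio/chat_interface.py below has to branch on. A minimal sketch of the two shapes after one completed exchange:

# type="tuples": a list of [user, bot] pairs
history_tuples = [["hi", "Run 1 - You typed: hi"]]

# type="messages": a list of openai-style role/content dicts
history_messages = [
    {"role": "user", "content": "hi"},
    {"role": "assistant", "content": "Run 1 - You typed: hi"},
]

This is why _append_history in the diff below assigns history[-1][1] for tuples but appends or replaces whole dicts for messages.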

demo/test_chatinterface_streaming_echo/run.ipynb

@@ -1 +1 @@
-{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: test_chatinterface_streaming_echo"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/test_chatinterface_streaming_echo/messages_testcase.py"]}, {"cell_type": "code", "execution_count": null, "id": "44380577570523278879349135829904343037", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "\n", "runs = 0\n", "\n", "def reset_runs():\n", "    global runs\n", "    runs = 0\n", "\n", "def slow_echo(message, history):\n", "    global runs  # i didn't want to add state or anything to this demo\n", "    runs = runs + 1\n", "    for i in range(len(message)):\n", "        yield f\"Run {runs} - You typed: \" + message[: i + 1]\n", "\n", "chat = gr.ChatInterface(slow_echo, fill_height=True)\n", "\n", "with gr.Blocks() as demo:\n", "    chat.render()\n", "    # We reset the global variable to minimize flakes\n", "    # this works because CI runs only one test at a time\n", "    # need to use gr.State if we want to parallelize this test\n", "    # currently chatinterface does not support that\n", "    demo.unload(reset_runs)\n", "\n", "if __name__ == \"__main__\":\n", "    demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
+{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: test_chatinterface_streaming_echo"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/test_chatinterface_streaming_echo/messages_testcase.py\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/test_chatinterface_streaming_echo/multimodal_messages_testcase.py\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/test_chatinterface_streaming_echo/multimodal_non_stream_testcase.py\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/test_chatinterface_streaming_echo/multimodal_tuples_testcase.py"]}, {"cell_type": "code", "execution_count": null, "id": "44380577570523278879349135829904343037", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "\n", "runs = 0\n", "\n", "def reset_runs():\n", "    global runs\n", "    runs = 0\n", "\n", "def slow_echo(message, history):\n", "    global runs  # i didn't want to add state or anything to this demo\n", "    runs = runs + 1\n", "    for i in range(len(message)):\n", "        yield f\"Run {runs} - You typed: \" + message[: i + 1]\n", "\n", "chat = gr.ChatInterface(slow_echo, fill_height=True)\n", "\n", "with gr.Blocks() as demo:\n", "    chat.render()\n", "    # We reset the global variable to minimize flakes\n", "    # this works because CI runs only one test at a time\n", "    # need to use gr.State if we want to parallelize this test\n", "    # currently chatinterface does not support that\n", "    demo.unload(reset_runs)\n", "\n", "if __name__ == \"__main__\":\n", "    demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}

gradio/chat_interface.py

@@ -579,6 +579,30 @@ class ChatInterface(Blocks):
             new_response = response
         return cast(MessageDict, new_response)
 
+    def _process_msg_and_trim_history(
+        self,
+        message: str | MultimodalData,
+        history_with_input: TupleFormat | list[MessageDict],
+    ) -> tuple[str | dict, TupleFormat | list[MessageDict]]:
+        if isinstance(message, MultimodalData):
+            remove_input = len(message.files) + int(message.text is not None)
+            history = history_with_input[:-remove_input]
+            message_serialized = message.model_dump()
+        else:
+            history = history_with_input[:-1]
+            message_serialized = message
+        return message_serialized, history
+
+    def _append_history(self, history, message, first_response=True):
+        if self.type == "tuples":
+            history[-1][1] = message  # type: ignore
+        else:
+            message = self.response_as_dict(message)
+            if first_response:
+                history.append(message)  # type: ignore
+            else:
+                history[-1] = message
+
     async def _submit_fn(
         self,
         message: str | MultimodalData,
@@ -586,18 +610,9 @@ class ChatInterface(Blocks):
         request: Request,
         *args,
     ) -> tuple[TupleFormat, TupleFormat] | tuple[list[MessageDict], list[MessageDict]]:
-        if self.multimodal and isinstance(message, MultimodalData):
-            remove_input = (
-                len(message.files) + 1
-                if message.text is not None
-                else len(message.files)
-            )
-            history = history_with_input[:-remove_input]
-            message_serialized = message.model_dump()
-        else:
-            history = history_with_input[:-1]
-            message_serialized = message
-
+        message_serialized, history = self._process_msg_and_trim_history(
+            message, history_with_input
+        )
         inputs, _, _ = special_args(
             self.fn, inputs=[message_serialized, history, *args], request=request
         )
@@ -609,15 +624,8 @@ class ChatInterface(Blocks):
                 self.fn, *inputs, limiter=self.limiter
             )
 
-        if self.type == "messages":
-            new_response = self.response_as_dict(response)
-        else:
-            new_response = response
-
-        if self.type == "tuples":
-            history_with_input[-1][1] = new_response  # type: ignore
-        elif self.type == "messages":
-            history_with_input.append(new_response)  # type: ignore
+        self._append_history(history_with_input, response)
         return history_with_input  # type: ignore
 
     async def _stream_fn(
@@ -627,17 +635,11 @@ class ChatInterface(Blocks):
         request: Request,
         *args,
     ) -> AsyncGenerator:
-        if self.multimodal and isinstance(message, MultimodalData):
-            remove_input = (
-                len(message.files) + 1
-                if message.text is not None
-                else len(message.files)
-            )
-            history = history_with_input[:-remove_input]
-        else:
-            history = history_with_input[:-1]
+        message_serialized, history = self._process_msg_and_trim_history(
+            message, history_with_input
+        )
         inputs, _, _ = special_args(
-            self.fn, inputs=[message, history, *args], request=request
+            self.fn, inputs=[message_serialized, history, *args], request=request
         )
 
         if self.is_async:
@@ -649,53 +651,13 @@ class ChatInterface(Blocks):
             generator = SyncToAsyncIterator(generator, self.limiter)
         try:
             first_response = await async_iteration(generator)
-            if self.type == "messages":
-                first_response = self.response_as_dict(first_response)
-            if (
-                self.multimodal
-                and isinstance(message, MultimodalData)
-                and self.type == "tuples"
-            ):
-                history_with_input[-1][1] = first_response  # type: ignore
-                yield history_with_input
-            elif (
-                self.multimodal
-                and isinstance(message, MultimodalData)
-                and self.type == "messages"
-            ):
-                history_with_input.append(first_response)  # type: ignore
-                yield history_with_input
-            elif self.type == "tuples":
-                history_with_input[-1][1] = first_response  # type: ignore
-                yield history_with_input
-            else:
-                history_with_input.append(first_response)  # type: ignore
-                yield history_with_input
+            self._append_history(history_with_input, first_response)
+            yield history_with_input
         except StopIteration:
             yield history_with_input
         async for response in generator:
-            if self.type == "messages":
-                response = self.response_as_dict(response)
-            if (
-                self.multimodal
-                and isinstance(message, MultimodalData)
-                and self.type == "tuples"
-            ):
-                history_with_input[-1][1] = response  # type: ignore
-                yield history_with_input
-            elif (
-                self.multimodal
-                and isinstance(message, MultimodalData)
-                and self.type == "messages"
-            ):
-                history_with_input[-1] = response  # type: ignore
-                yield history_with_input
-            elif self.type == "tuples":
-                history_with_input[-1][1] = response  # type: ignore
-                yield history_with_input
-            else:
-                history_with_input[-1] = response  # type: ignore
-                yield history_with_input
+            self._append_history(history_with_input, response, first_response=False)
+            yield history_with_input
 
     async def _examples_fn(
         self, message: str, *args
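
For context (not part of the commit): the trimming arithmetic in _process_msg_and_trim_history works because a multimodal submission adds one history row per uploaded file plus one for the text, and those just-added input rows must be stripped before the history is handed to the user function. A worked sketch with hypothetical values:

files = ["a.png", "b.png"]                          # two hypothetical uploads
text = "hi"
remove_input = len(files) + int(text is not None)   # 2 + 1 == 3
history_with_input = [
    ["a.png", None], ["b.png", None], ["hi", None], # rows added for this turn
]
history = history_with_input[:-remove_input]        # [] ... input rows trimmed

The old _submit_fn and _stream_fn each duplicated this logic, and _stream_fn passed the raw message (rather than message_serialized) to the user function, which appears to be the multimodal streaming bug this commit fixes; the shared helpers remove the duplication.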

js/app/test/chat_interface.spec.ts

@@ -1,13 +1,24 @@
 import { test, expect, go_to_testcase } from "@gradio/tootils";
 
-for (const msg_format of ["tuples", "messages"]) {
-	test(`msg format ${msg_format} chatinterface works with streaming functions and all buttons behave as expected`, async ({
+const cases = [
+	"tuples",
+	"messages",
+	"multimodal_tuples",
+	"multimodal_messages",
+	"multimodal_non_stream"
+];
+
+for (const test_case of cases) {
+	test(`test case ${test_case} chatinterface works with streaming functions and all buttons behave as expected`, async ({
 		page
 	}) => {
-		if (msg_format === "messages") {
-			await go_to_testcase(page, "messages");
+		if (cases.slice(1).includes(test_case)) {
+			await go_to_testcase(page, test_case);
 		}
-		const submit_button = page.getByRole("button", { name: "Submit" });
+		let submit_button = page.getByRole("button", { name: "Submit" });
+		if (test_case.startsWith("multimodal")) {
+			submit_button = page.locator(".submit-button");
+		}
 		const retry_button = page.getByRole("button", { name: "🔄 Retry" });
 		const undo_button = page.getByRole("button", { name: "↩️ Undo" });
 		const clear_button = page.getByRole("button", { name: "🗑️ Clear" });
@@ -32,25 +43,19 @@ for (const msg_format of ["tuples", "messages"]) {
 			hasText: "Run 2 - You typed: hi"
 		});
 		await expect(expected_text_el_1).toBeVisible();
-		await expect
-			.poll(async () => page.locator(".bot.message").count(), { timeout: 2000 })
-			.toBe(2);
+		await expect(page.locator(".bot.message")).toHaveCount(2);
 
 		await undo_button.click();
-		await expect
-			.poll(async () => page.locator(".message.bot").count(), { timeout: 5000 })
-			.toBe(1);
+		await expect(page.locator(".bot.message")).toHaveCount(1);
 		await expect(textbox).toHaveValue("hi");
 
 		await retry_button.click();
 		const expected_text_el_2 = page.locator(".bot p", {
-			hasText: ""
+			hasText: "Run 3 - You typed: hi"
 		});
 		await expect(expected_text_el_2).toBeVisible();
-		await expect
-			.poll(async () => page.locator(".message.bot").count(), { timeout: 5000 })
-			.toBe(1);
+		await expect(page.locator(".bot.message")).toHaveCount(1);
 
 		await textbox.fill("hi");
 		await submit_button.click();
@@ -59,24 +64,22 @@ for (const msg_format of ["tuples", "messages"]) {
 			hasText: "Run 4 - You typed: hi"
 		});
 		await expect(expected_text_el_3).toBeVisible();
-		await expect
-			.poll(async () => page.locator(".bot.message").count(), { timeout: 2000 })
-			.toBe(2);
+		await expect(page.locator(".bot.message")).toHaveCount(2);
 
 		await clear_button.click();
-		await expect
-			.poll(async () => page.locator(".bot.message").count(), { timeout: 5000 })
-			.toBe(0);
+		await expect(page.locator(".bot.message")).toHaveCount(0);
 	});
 
-	test(`msg format ${msg_format} the api recorder correctly records the api calls`, async ({
+	test(`test case ${test_case} the api recorder correctly records the api calls`, async ({
 		page
 	}) => {
-		if (msg_format === "messages") {
-			await go_to_testcase(page, "messages");
+		if (cases.slice(1).includes(test_case)) {
+			await go_to_testcase(page, test_case);
 		}
 		const textbox = page.getByPlaceholder("Type a message...");
-		const submit_button = page.getByRole("button", { name: "Submit" });
+		let submit_button = page.getByRole("button", { name: "Submit" });
+		if (test_case.startsWith("multimodal")) {
+			submit_button = page.locator(".submit-button");
+		}
 
 		await textbox.fill("hi");
 		await page.getByRole("button", { name: "Use via API logo" }).click();
@@ -88,8 +91,9 @@ for (const msg_format of ["tuples", "messages"]) {
 		);
 		const api_recorder = await page.locator("#api-recorder");
 		await api_recorder.click();
+		const n_calls = test_case.includes("non_stream") ? 3 : 5;
 		await expect(page.locator("#num-recorded-api-calls")).toContainText(
-			"🪄 Recorded API Calls [5]"
+			`🪄 Recorded API Calls [${n_calls}]`
 		);
 	});
 }