Fix stopping chat interface when stop button is clicked (#9626)

* changes

* add changeset

* changes

* changes

* fix test

---------

Co-authored-by: Ali Abid <aliabid94@gmail.com>
Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
This commit is contained in:
aliabid94 2024-10-09 14:00:58 -07:00 committed by GitHub
parent 5923c67913
commit ec95b0212b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 61 additions and 28 deletions

View File

@ -0,0 +1,6 @@
---
"@gradio/textbox": minor
"gradio": minor
---
feat:Fix stopping chat interface when stop button is clicked

View File

@ -314,7 +314,7 @@ class ChatInterface(Blocks):
def _setup_events(self) -> None:
submit_fn = self._stream_fn if self.is_generator else self._submit_fn
submit_triggers = [self.textbox.submit]
submit_triggers = [self.textbox.submit, self.chatbot.retry]
submit_event = (
self.textbox.submit(
@ -343,12 +343,12 @@ class ChatInterface(Blocks):
Literal["full", "minimal", "hidden"], self.show_progress
),
)
.then(
lambda: update(value=None, interactive=True),
None,
self.textbox,
show_api=False,
)
)
submit_event.then(
lambda: update(value=None, interactive=True),
None,
self.textbox,
show_api=False,
)
if (
@ -381,7 +381,6 @@ class ChatInterface(Blocks):
Literal["full", "minimal", "hidden"], self.show_progress
),
)
self._setup_stop_events(submit_triggers, submit_event)
retry_event = (
self.chatbot.retry(
@ -415,13 +414,14 @@ class ChatInterface(Blocks):
Literal["full", "minimal", "hidden"], self.show_progress
),
)
.then(
lambda: update(interactive=True),
outputs=[self.textbox],
show_api=False,
)
)
self._setup_stop_events([self.chatbot.retry], retry_event)
retry_event.then(
lambda: update(interactive=True),
outputs=[self.textbox],
show_api=False,
)
self._setup_stop_events(submit_triggers, [submit_event, retry_event])
self.chatbot.undo(
self._undo_msg,
@ -432,7 +432,7 @@ class ChatInterface(Blocks):
)
def _setup_stop_events(
self, event_triggers: list[Callable], event_to_cancel: Dependency
self, event_triggers: list[Callable], events_to_cancel: list[Dependency]
) -> None:
textbox_component = MultimodalTextbox if self.multimodal else Textbox
if self.is_generator:
@ -450,22 +450,23 @@ class ChatInterface(Blocks):
show_api=False,
queue=False,
)
event_to_cancel.then(
async_lambda(
lambda: textbox_component(
submit_btn=original_submit_btn, stop_btn=False
)
),
None,
[self.textbox],
show_api=False,
queue=False,
)
for event_to_cancel in events_to_cancel:
event_to_cancel.then(
async_lambda(
lambda: textbox_component(
submit_btn=original_submit_btn, stop_btn=False
)
),
None,
[self.textbox],
show_api=False,
queue=False,
)
self.textbox.stop(
None,
None,
None,
cancels=event_to_cancel,
cancels=events_to_cancel,
show_api=False,
)

View File

@ -88,3 +88,29 @@ for (const test_case of cases) {
);
});
}
test("test stopping generation", async ({ page }) => {
const submit_button = page.locator(".submit-button");
const textbox = page.getByPlaceholder("Type a message...");
const long_string = "abc".repeat(1000);
await textbox.fill(long_string);
await submit_button.click();
await expect(page.locator(".bot.message").first()).toContainText("abc");
const stop_button = page.locator(".stop-button");
await stop_button.click();
await expect(page.locator(".bot.message").first()).toContainText("abc");
const current_content = await page
.locator(".bot.message")
.first()
.textContent();
await page.waitForTimeout(1000);
const new_content = await page.locator(".bot.message").first().textContent();
await expect(current_content).toBe(new_content);
await expect(new_content!.length).toBeLessThan(3000);
});

View File

@ -283,7 +283,7 @@
use:text_area_resize={value}
class="scroll-hide"
dir={rtl ? "rtl" : "ltr"}
class:no-label={!show_label && submit_btn}
class:no-label={!show_label && (submit_btn || stop_btn)}
bind:value
bind:this={el}
{placeholder}