Turn Submit Button into Cancel Button for Interfaces with generators (#3124)

* Switch back from stop button

* Add unit test

* CHANGELOG

* Fix test

* lint

* Reset UI after exception
This commit is contained in:
Freddy Boulton 2023-02-06 14:30:21 -05:00 committed by GitHub
parent ec2b68f554
commit 9beb15b3ec
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 122 additions and 12 deletions

View File

@ -2,6 +2,18 @@
## New Features:
### Revamped Stop Button for Interfaces 🛑
If your Interface function is a generator, there used to be a separate `Stop` button displayed next
to the `Submit` button.
We've revamed the `Submit` button so that it turns into a `Stop` button during the generation process.
Clicking on the `Stop` button will cancel the generation and turn it back to a `Submit` button.
The `Stop` button will automatically turn back to a `Submit` button at the end of the generation if you don't use it!
By [@freddyaboulton](https://github.com/freddyaboulton) in [PR 3124](https://github.com/gradio-app/gradio/pull/3124)
### Queue now works with reload mode!
You can now call `queue` on your `demo` outside of the `if __name__ == "__main__"` block and

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.7 MiB

After

Width:  |  Height:  |  Size: 762 KiB

View File

@ -547,7 +547,7 @@ class Interface(Blocks):
# as a proxy of whether the queue will be enabled.
# Using a generator function without the queue will raise an error.
if inspect.isgeneratorfunction(self.fn):
stop_btn = Button("Stop", variant="stop")
stop_btn = Button("Stop", variant="stop", visible=False)
elif self.interface_type == InterfaceTypes.UNIFIED:
clear_btn = Button("Clear")
submit_btn = Button("Submit", variant="primary")
@ -588,7 +588,7 @@ class Interface(Blocks):
# is created. We use whether a generator function is provided
# as a proxy of whether the queue will be enabled.
# Using a generator function without the queue will raise an error.
stop_btn = Button("Stop", variant="stop")
stop_btn = Button("Stop", variant="stop", visible=False)
if self.allow_flagging == "manual":
flag_btns = self.render_flag_btns()
elif self.allow_flagging == "auto":
@ -643,10 +643,38 @@ class Interface(Blocks):
)
else:
assert submit_btn is not None, "Submit button not rendered"
fn = self.fn
extra_output = []
if stop_btn:
# Wrap the original function to show/hide the "Stop" button
def fn(*args):
# The main idea here is to call the original function
# and append some updates to keep the "Submit" button
# hidden and the "Stop" button visible
# The 'finally' block hides the "Stop" button and
# shows the "submit" button. Having a 'finally' block
# will make sure the UI is "reset" even if there is an exception
try:
for output in self.fn(*args):
if len(self.output_components) == 1 and not self.batch:
output = [output]
output = [o for o in output]
yield output + [
Button.update(visible=False),
Button.update(visible=True),
]
finally:
yield [
{"__type__": "generic_update"}
for _ in self.output_components
] + [Button.update(visible=True), Button.update(visible=False)]
extra_output = [submit_btn, stop_btn]
pred = submit_btn.click(
self.fn,
fn,
self.input_components,
self.output_components,
self.output_components + extra_output,
api_name="predict",
scroll_to_output=True,
preprocess=not (self.api_mode),
@ -655,11 +683,24 @@ class Interface(Blocks):
max_batch_size=self.max_batch_size,
)
if stop_btn:
stop_btn.click(
None,
submit_btn.click(
lambda: (
submit_btn.update(visible=False),
stop_btn.update(visible=True),
),
inputs=None,
outputs=None,
outputs=[submit_btn, stop_btn],
queue=False,
)
stop_btn.click(
lambda: (
submit_btn.update(visible=True),
stop_btn.update(visible=False),
),
inputs=None,
outputs=[submit_btn, stop_btn],
cancels=[pred],
queue=False,
)
def attach_clear_events(

View File

@ -1039,6 +1039,51 @@ class TestCancel:
cancel.click(None, None, None, cancels=[click])
demo.queue().launch(prevent_thread_lock=True)
@pytest.mark.asyncio
async def test_cancel_button_for_interfaces(self):
def generate(x):
for i in range(4):
yield i
time.sleep(0.2)
io = gr.Interface(generate, gr.Textbox(), gr.Textbox()).queue()
stop_btn_id = next(
i for i, k in io.blocks.items() if getattr(k, "value", None) == "Stop"
)
assert not io.blocks[stop_btn_id].visible
io.launch(prevent_thread_lock=True)
async with websockets.connect(
f"{io.local_url.replace('http', 'ws')}queue/join"
) as ws:
completed = False
checked_iteration = False
while not completed:
msg = json.loads(await ws.recv())
if msg["msg"] == "send_data":
await ws.send(json.dumps({"data": ["freddy"], "fn_index": 0}))
if msg["msg"] == "send_hash":
await ws.send(json.dumps({"fn_index": 0, "session_hash": "shdce"}))
if msg["msg"] == "process_generating" and isinstance(
msg["output"]["data"][0], str
):
checked_iteration = True
assert msg["output"]["data"][1:] == [
{"visible": False, "__type__": "update"},
{"visible": True, "__type__": "update"},
]
if msg["msg"] == "process_completed":
assert msg["output"]["data"] == [
{"__type__": "update"},
{"visible": True, "__type__": "update"},
{"visible": False, "__type__": "update"},
]
completed = True
assert checked_iteration
io.close()
class TestEvery:
def test_raise_exception_if_parameters_invalid(self):

View File

@ -277,7 +277,7 @@ class TestGeneratorRoutes:
headers={"Authorization": f"Bearer {app.queue_token}"},
)
output = dict(response.json())
assert output["data"] == ["a"]
assert output["data"][0] == "a"
response = client.post(
"/api/predict/",
@ -285,7 +285,7 @@ class TestGeneratorRoutes:
headers={"Authorization": f"Bearer {app.queue_token}"},
)
output = dict(response.json())
assert output["data"] == ["b"]
assert output["data"][0] == "b"
response = client.post(
"/api/predict/",
@ -293,7 +293,7 @@ class TestGeneratorRoutes:
headers={"Authorization": f"Bearer {app.queue_token}"},
)
output = dict(response.json())
assert output["data"] == ["c"]
assert output["data"][0] == "c"
response = client.post(
"/api/predict/",
@ -301,7 +301,11 @@ class TestGeneratorRoutes:
headers={"Authorization": f"Bearer {app.queue_token}"},
)
output = dict(response.json())
assert output["data"] == [None]
assert output["data"] == [
{"__type__": "update"},
{"__type__": "update", "visible": True},
{"__type__": "update", "visible": False},
]
response = client.post(
"/api/predict/",
@ -309,7 +313,15 @@ class TestGeneratorRoutes:
headers={"Authorization": f"Bearer {app.queue_token}"},
)
output = dict(response.json())
assert output["data"] == ["a"]
assert output["data"][0] is None
response = client.post(
"/api/predict/",
json={"data": ["abc"], "fn_index": 0, "session_hash": "11"},
headers={"Authorization": f"Bearer {app.queue_token}"},
)
output = dict(response.json())
assert output["data"][0] == "a"
class TestApp: