Allow textbox / number submits to trigger Interface submit (#4090)

* changes

* changes

* changes

* changes

* changes

---------

Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
This commit is contained in:
aliabid94 2023-05-08 20:51:05 -05:00 committed by GitHub
parent 1910df10d9
commit 5ef0bfeefc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 38 additions and 18 deletions

View File

@ -28,11 +28,11 @@ No changes to highlight.
## Full Changelog:
- Allow users to submit with enter in Interfaces with textbox / number inputs [@aliabid94](https://github.com/aliabid94) in [PR 4090](https://github.com/gradio-app/gradio/pull/4090).
- Updates gradio's requirements.txt to requires uvicorn>=0.14.0 by [@abidlabs](https://github.com/abidlabs) in [PR 4086](https://github.com/gradio-app/gradio/pull/4086)
- Updates some error messaging by [@abidlabs](https://github.com/abidlabs) in [PR 4086](https://github.com/gradio-app/gradio/pull/4086)
- Renames simplified Chinese translation file from `zh-cn.json` to `zh-CN.json` by [@abidlabs](https://github.com/abidlabs) in [PR 4086](https://github.com/gradio-app/gradio/pull/4086)
## Contributors Shoutout:
No changes to highlight.

View File

@ -25,7 +25,7 @@ from gradio.components import (
get_component_instance,
)
from gradio.data_classes import InterfaceTypes
from gradio.events import Changeable, Streamable
from gradio.events import Changeable, Streamable, Submittable
from gradio.flagging import CSVLogger, FlaggingCallback, FlagMethod
from gradio.layouts import Column, Row, Tab, Tabs
from gradio.pipelines import load_from_pipeline
@ -625,17 +625,37 @@ class Interface(Blocks):
] + [Button.update(visible=True), Button.update(visible=False)]
extra_output = [submit_btn, stop_btn]
pred = submit_btn.click(
fn,
self.input_components,
self.output_components + extra_output,
api_name="predict",
scroll_to_output=True,
preprocess=not (self.api_mode),
postprocess=not (self.api_mode),
batch=self.batch,
max_batch_size=self.max_batch_size,
)
triggers = [submit_btn.click] + [
component.submit
for component in self.input_components
if isinstance(component, Submittable)
]
predict_events = []
for i, trigger in enumerate(triggers):
predict_events.append(
trigger(
fn,
self.input_components,
self.output_components + extra_output,
api_name="predict" if i == 0 else None,
scroll_to_output=True,
preprocess=not (self.api_mode),
postprocess=not (self.api_mode),
batch=self.batch,
max_batch_size=self.max_batch_size,
)
)
if stop_btn:
trigger(
lambda: (
submit_btn.update(visible=False),
stop_btn.update(visible=True),
),
inputs=None,
outputs=[submit_btn, stop_btn],
queue=False,
)
if stop_btn:
submit_btn.click(
lambda: (
@ -653,7 +673,7 @@ class Interface(Blocks):
),
inputs=None,
outputs=[submit_btn, stop_btn],
cancels=[pred],
cancels=predict_events,
queue=False,
)

View File

@ -327,10 +327,10 @@ class TestProcessExamples:
app, _, _ = io.launch(prevent_thread_lock=True)
client = TestClient(app)
response = client.post("/api/predict/", json={"fn_index": 5, "data": [0]})
response = client.post("/api/predict/", json={"fn_index": 7, "data": [0]})
assert response.json()["data"] == ["Hello,"]
response = client.post("/api/predict/", json={"fn_index": 5, "data": [1]})
response = client.post("/api/predict/", json={"fn_index": 7, "data": [1]})
assert response.json()["data"] == ["Michael"]
def test_end_to_end_cache_examples(self):
@ -348,8 +348,8 @@ class TestProcessExamples:
app, _, _ = io.launch(prevent_thread_lock=True)
client = TestClient(app)
response = client.post("/api/predict/", json={"fn_index": 5, "data": [0]})
response = client.post("/api/predict/", json={"fn_index": 7, "data": [0]})
assert response.json()["data"] == ["Hello,", "World", "Hello, World"]
response = client.post("/api/predict/", json={"fn_index": 5, "data": [1]})
response = client.post("/api/predict/", json={"fn_index": 7, "data": [1]})
assert response.json()["data"] == ["Michael", "Jordan", "Michael Jordan"]