Fix gr.Interface.from_pipeline() to allow audio uploads and to display classification labels correctly (#8080)

* Allow audio file inputs

* Stop comma-splitting label texts

* add changeset

* add changeset

* add changeset

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
Yuichiro Tachibana (Tsuchiya) 2024-04-19 22:55:38 +01:00 committed by GitHub
parent d1e3676e73
commit 568eeb26a9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 10 additions and 12 deletions

View File

@ -0,0 +1,5 @@
---
"gradio": patch
---
feat:Fix `gr.Interface.from_pipeline()` to allow audio uploads and to display classification labels correctly

View File

@ -26,21 +26,14 @@ def handle_transformers_pipeline(pipeline: Any) -> Optional[Dict[str, Any]]:
# version of the transformers library that the user has installed.
if is_transformers_pipeline_type(pipeline, "AudioClassificationPipeline"):
return {
"inputs": components.Audio(
sources=["microphone"],
type="filepath",
label="Input",
render=False,
),
"inputs": components.Audio(type="filepath", label="Input", render=False),
"outputs": components.Label(label="Class", render=False),
"preprocess": lambda i: {"inputs": i},
"postprocess": lambda r: {i["label"].split(", ")[0]: i["score"] for i in r},
"postprocess": lambda r: {i["label"]: i["score"] for i in r},
}
if is_transformers_pipeline_type(pipeline, "AutomaticSpeechRecognitionPipeline"):
return {
"inputs": components.Audio(
sources=["microphone"], type="filepath", label="Input", render=False
),
"inputs": components.Audio(type="filepath", label="Input", render=False),
"outputs": components.Textbox(label="Output", render=False),
"preprocess": lambda i: {"inputs": i},
"postprocess": lambda r: r["text"],
@ -66,7 +59,7 @@ def handle_transformers_pipeline(pipeline: Any) -> Optional[Dict[str, Any]]:
),
"outputs": components.Label(label="Classification", render=False),
"preprocess": lambda i: {"images": i},
"postprocess": lambda r: {i["label"].split(", ")[0]: i["score"] for i in r},
"postprocess": lambda r: {i["label"]: i["score"] for i in r},
}
if is_transformers_pipeline_type(pipeline, "QuestionAnsweringPipeline"):
return {
@ -93,7 +86,7 @@ def handle_transformers_pipeline(pipeline: Any) -> Optional[Dict[str, Any]]:
"inputs": components.Textbox(label="Input", render=False),
"outputs": components.Label(label="Classification", render=False),
"preprocess": lambda x: [x],
"postprocess": lambda r: {i["label"].split(", ")[0]: i["score"] for i in r},
"postprocess": lambda r: {i["label"]: i["score"] for i in r},
}
if is_transformers_pipeline_type(pipeline, "TextGenerationPipeline"):
return {