Allow gr.Interface.from_pipeline() and gr.load() to work within gr.Blocks() (#5231)

* add test

* external

* fixed in external

* add changeset

* close demo

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
Abubakar Abid 2023-08-15 22:28:19 -07:00 committed by GitHub
parent b2f49cfa36
commit 87f1c2b4ac
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 136 additions and 93 deletions

View File

@ -0,0 +1,5 @@
---
"gradio": patch
---
fix:Allow `gr.Interface.from_pipeline()` and `gr.load()` to work within `gr.Blocks()`

View File

@ -152,8 +152,10 @@ def from_model(model_name: str, hf_token: str | None, alias: str | None, **kwarg
pipelines = {
"audio-classification": {
# example model: ehcalabres/wav2vec2-lg-xlsr-en-speech-emotion-recognition
"inputs": components.Audio(source="upload", type="filepath", label="Input"),
"outputs": components.Label(label="Class"),
"inputs": components.Audio(
source="upload", type="filepath", label="Input", render=False
),
"outputs": components.Label(label="Class", render=False),
"preprocess": lambda i: to_binary,
"postprocess": lambda r: postprocess_label(
{i["label"].split(", ")[0]: i["score"] for i in r.json()}
@ -161,34 +163,38 @@ def from_model(model_name: str, hf_token: str | None, alias: str | None, **kwarg
},
"audio-to-audio": {
# example model: facebook/xm_transformer_sm_all-en
"inputs": components.Audio(source="upload", type="filepath", label="Input"),
"outputs": components.Audio(label="Output"),
"inputs": components.Audio(
source="upload", type="filepath", label="Input", render=False
),
"outputs": components.Audio(label="Output", render=False),
"preprocess": to_binary,
"postprocess": encode_to_base64,
},
"automatic-speech-recognition": {
# example model: facebook/wav2vec2-base-960h
"inputs": components.Audio(source="upload", type="filepath", label="Input"),
"outputs": components.Textbox(label="Output"),
"inputs": components.Audio(
source="upload", type="filepath", label="Input", render=False
),
"outputs": components.Textbox(label="Output", render=False),
"preprocess": to_binary,
"postprocess": lambda r: r.json()["text"],
},
"conversational": {
"inputs": [components.Textbox(), components.State()], # type: ignore
"outputs": [components.Chatbot(), components.State()], # type: ignore
"inputs": [components.Textbox(render=False), components.State(render=False)], # type: ignore
"outputs": [components.Chatbot(render=False), components.State(render=False)], # type: ignore
"preprocess": chatbot_preprocess,
"postprocess": chatbot_postprocess,
},
"feature-extraction": {
# example model: julien-c/distilbert-feature-extraction
"inputs": components.Textbox(label="Input"),
"outputs": components.Dataframe(label="Output"),
"inputs": components.Textbox(label="Input", render=False),
"outputs": components.Dataframe(label="Output", render=False),
"preprocess": lambda x: {"inputs": x},
"postprocess": lambda r: r.json()[0],
},
"fill-mask": {
"inputs": components.Textbox(label="Input"),
"outputs": components.Label(label="Classification"),
"inputs": components.Textbox(label="Input", render=False),
"outputs": components.Label(label="Classification", render=False),
"preprocess": lambda x: {"inputs": x},
"postprocess": lambda r: postprocess_label(
{i["token_str"]: i["score"] for i in r.json()}
@ -196,8 +202,10 @@ def from_model(model_name: str, hf_token: str | None, alias: str | None, **kwarg
},
"image-classification": {
# Example: google/vit-base-patch16-224
"inputs": components.Image(type="filepath", label="Input Image"),
"outputs": components.Label(label="Classification"),
"inputs": components.Image(
type="filepath", label="Input Image", render=False
),
"outputs": components.Label(label="Classification", render=False),
"preprocess": to_binary,
"postprocess": lambda r: postprocess_label(
{i["label"].split(", ")[0]: i["score"] for i in r.json()}
@ -206,27 +214,27 @@ def from_model(model_name: str, hf_token: str | None, alias: str | None, **kwarg
"question-answering": {
# Example: deepset/xlm-roberta-base-squad2
"inputs": [
components.Textbox(lines=7, label="Context"),
components.Textbox(label="Question"),
components.Textbox(lines=7, label="Context", render=False),
components.Textbox(label="Question", render=False),
],
"outputs": [
components.Textbox(label="Answer"),
components.Label(label="Score"),
components.Textbox(label="Answer", render=False),
components.Label(label="Score", render=False),
],
"preprocess": lambda c, q: {"inputs": {"context": c, "question": q}},
"postprocess": lambda r: (r.json()["answer"], {"label": r.json()["score"]}),
},
"summarization": {
# Example: facebook/bart-large-cnn
"inputs": components.Textbox(label="Input"),
"outputs": components.Textbox(label="Summary"),
"inputs": components.Textbox(label="Input", render=False),
"outputs": components.Textbox(label="Summary", render=False),
"preprocess": lambda x: {"inputs": x},
"postprocess": lambda r: r.json()[0]["summary_text"],
},
"text-classification": {
# Example: distilbert-base-uncased-finetuned-sst-2-english
"inputs": components.Textbox(label="Input"),
"outputs": components.Label(label="Classification"),
"inputs": components.Textbox(label="Input", render=False),
"outputs": components.Label(label="Classification", render=False),
"preprocess": lambda x: {"inputs": x},
"postprocess": lambda r: postprocess_label(
{i["label"].split(", ")[0]: i["score"] for i in r.json()[0]}
@ -234,32 +242,34 @@ def from_model(model_name: str, hf_token: str | None, alias: str | None, **kwarg
},
"text-generation": {
# Example: gpt2
"inputs": components.Textbox(label="Input"),
"outputs": components.Textbox(label="Output"),
"inputs": components.Textbox(label="Input", render=False),
"outputs": components.Textbox(label="Output", render=False),
"preprocess": lambda x: {"inputs": x},
"postprocess": lambda r: r.json()[0]["generated_text"],
},
"text2text-generation": {
# Example: valhalla/t5-small-qa-qg-hl
"inputs": components.Textbox(label="Input"),
"outputs": components.Textbox(label="Generated Text"),
"inputs": components.Textbox(label="Input", render=False),
"outputs": components.Textbox(label="Generated Text", render=False),
"preprocess": lambda x: {"inputs": x},
"postprocess": lambda r: r.json()[0]["generated_text"],
},
"translation": {
"inputs": components.Textbox(label="Input"),
"outputs": components.Textbox(label="Translation"),
"inputs": components.Textbox(label="Input", render=False),
"outputs": components.Textbox(label="Translation", render=False),
"preprocess": lambda x: {"inputs": x},
"postprocess": lambda r: r.json()[0]["translation_text"],
},
"zero-shot-classification": {
# Example: facebook/bart-large-mnli
"inputs": [
components.Textbox(label="Input"),
components.Textbox(label="Possible class names (" "comma-separated)"),
components.Checkbox(label="Allow multiple true classes"),
components.Textbox(label="Input", render=False),
components.Textbox(
label="Possible class names (" "comma-separated)", render=False
),
components.Checkbox(label="Allow multiple true classes", render=False),
],
"outputs": components.Label(label="Classification"),
"outputs": components.Label(label="Classification", render=False),
"preprocess": lambda i, c, m: {
"inputs": i,
"parameters": {"candidate_labels": c, "multi_class": m},
@ -275,15 +285,18 @@ def from_model(model_name: str, hf_token: str | None, alias: str | None, **kwarg
# Example: sentence-transformers/distilbert-base-nli-stsb-mean-tokens
"inputs": [
components.Textbox(
value="That is a happy person", label="Source Sentence"
value="That is a happy person",
label="Source Sentence",
render=False,
),
components.Textbox(
lines=7,
placeholder="Separate each sentence by a newline",
label="Sentences to compare to",
render=False,
),
],
"outputs": components.Label(label="Classification"),
"outputs": components.Label(label="Classification", render=False),
"preprocess": lambda src, sentences: {
"inputs": {
"source_sentence": src,
@ -296,32 +309,32 @@ def from_model(model_name: str, hf_token: str | None, alias: str | None, **kwarg
},
"text-to-speech": {
# Example: julien-c/ljspeech_tts_train_tacotron2_raw_phn_tacotron_g2p_en_no_space_train
"inputs": components.Textbox(label="Input"),
"outputs": components.Audio(label="Audio"),
"inputs": components.Textbox(label="Input", render=False),
"outputs": components.Audio(label="Audio", render=False),
"preprocess": lambda x: {"inputs": x},
"postprocess": encode_to_base64,
},
"text-to-image": {
# example model: osanseviero/BigGAN-deep-128
"inputs": components.Textbox(label="Input"),
"outputs": components.Image(label="Output"),
"inputs": components.Textbox(label="Input", render=False),
"outputs": components.Image(label="Output", render=False),
"preprocess": lambda x: {"inputs": x},
"postprocess": encode_to_base64,
},
"token-classification": {
# example model: huggingface-course/bert-finetuned-ner
"inputs": components.Textbox(label="Input"),
"outputs": components.HighlightedText(label="Output"),
"inputs": components.Textbox(label="Input", render=False),
"outputs": components.HighlightedText(label="Output", render=False),
"preprocess": lambda x: {"inputs": x},
"postprocess": lambda r: r, # Handled as a special case in query_huggingface_api()
},
"document-question-answering": {
# example model: impira/layoutlm-document-qa
"inputs": [
components.Image(type="filepath", label="Input Document"),
components.Textbox(label="Question"),
components.Image(type="filepath", label="Input Document", render=False),
components.Textbox(label="Question", render=False),
],
"outputs": components.Label(label="Label"),
"outputs": components.Label(label="Label", render=False),
"preprocess": lambda img, q: {
"inputs": {
"image": extract_base64_data(img), # Extract base64 data
@ -335,10 +348,10 @@ def from_model(model_name: str, hf_token: str | None, alias: str | None, **kwarg
"visual-question-answering": {
# example model: dandelin/vilt-b32-finetuned-vqa
"inputs": [
components.Image(type="filepath", label="Input Image"),
components.Textbox(label="Question"),
components.Image(type="filepath", label="Input Image", render=False),
components.Textbox(label="Question", render=False),
],
"outputs": components.Label(label="Label"),
"outputs": components.Label(label="Label", render=False),
"preprocess": lambda img, q: {
"inputs": {
"image": extract_base64_data(img),
@ -351,8 +364,10 @@ def from_model(model_name: str, hf_token: str | None, alias: str | None, **kwarg
},
"image-to-text": {
# example model: Salesforce/blip-image-captioning-base
"inputs": components.Image(type="filepath", label="Input Image"),
"outputs": components.Textbox(label="Generated Text"),
"inputs": components.Image(
type="filepath", label="Input Image", render=False
),
"outputs": components.Textbox(label="Generated Text", render=False),
"preprocess": to_binary,
"postprocess": lambda r: r.json()[0]["generated_text"],
},
@ -369,9 +384,10 @@ def from_model(model_name: str, hf_token: str | None, alias: str | None, **kwarg
type="pandas",
headers=col_names,
col_count=(len(col_names), "fixed"),
render=False,
),
"outputs": components.Dataframe(
label="Predictions", type="array", headers=["prediction"]
label="Predictions", type="array", headers=["prediction"], render=False
),
"preprocess": rows_to_cols,
"postprocess": lambda r: {

View File

@ -35,9 +35,12 @@ def load_from_pipeline(pipeline: pipelines.base.Pipeline) -> dict:
):
pipeline_info = {
"inputs": components.Audio(
source="microphone", type="filepath", label="Input"
source="microphone",
type="filepath",
label="Input",
render=False,
),
"outputs": components.Label(label="Class"),
"outputs": components.Label(label="Class", render=False),
"preprocess": lambda i: {"inputs": i},
"postprocess": lambda r: {i["label"].split(", ")[0]: i["score"] for i in r},
}
@ -47,9 +50,9 @@ def load_from_pipeline(pipeline: pipelines.base.Pipeline) -> dict:
):
pipeline_info = {
"inputs": components.Audio(
source="microphone", type="filepath", label="Input"
source="microphone", type="filepath", label="Input", render=False
),
"outputs": components.Textbox(label="Output"),
"outputs": components.Textbox(label="Output", render=False),
"preprocess": lambda i: {"inputs": i},
"postprocess": lambda r: r["text"],
}
@ -57,8 +60,8 @@ def load_from_pipeline(pipeline: pipelines.base.Pipeline) -> dict:
pipeline, pipelines.feature_extraction.FeatureExtractionPipeline
):
pipeline_info = {
"inputs": components.Textbox(label="Input"),
"outputs": components.Dataframe(label="Output"),
"inputs": components.Textbox(label="Input", render=False),
"outputs": components.Dataframe(label="Output", render=False),
"preprocess": lambda x: {"inputs": x},
"postprocess": lambda r: r[0],
}
@ -66,8 +69,8 @@ def load_from_pipeline(pipeline: pipelines.base.Pipeline) -> dict:
pipeline, pipelines.fill_mask.FillMaskPipeline
):
pipeline_info = {
"inputs": components.Textbox(label="Input"),
"outputs": components.Label(label="Classification"),
"inputs": components.Textbox(label="Input", render=False),
"outputs": components.Label(label="Classification", render=False),
"preprocess": lambda x: {"inputs": x},
"postprocess": lambda r: {i["token_str"]: i["score"] for i in r},
}
@ -75,8 +78,10 @@ def load_from_pipeline(pipeline: pipelines.base.Pipeline) -> dict:
pipeline, pipelines.image_classification.ImageClassificationPipeline
):
pipeline_info = {
"inputs": components.Image(type="filepath", label="Input Image"),
"outputs": components.Label(type="confidences", label="Classification"),
"inputs": components.Image(
type="filepath", label="Input Image", render=False
),
"outputs": components.Label(label="Classification", render=False),
"preprocess": lambda i: {"images": i},
"postprocess": lambda r: {i["label"].split(", ")[0]: i["score"] for i in r},
}
@ -85,12 +90,12 @@ def load_from_pipeline(pipeline: pipelines.base.Pipeline) -> dict:
):
pipeline_info = {
"inputs": [
components.Textbox(lines=7, label="Context"),
components.Textbox(label="Question"),
components.Textbox(lines=7, label="Context", render=False),
components.Textbox(label="Question", render=False),
],
"outputs": [
components.Textbox(label="Answer"),
components.Label(label="Score"),
components.Textbox(label="Answer", render=False),
components.Label(label="Score", render=False),
],
"preprocess": lambda c, q: {"context": c, "question": q},
"postprocess": lambda r: (r["answer"], r["score"]),
@ -99,8 +104,8 @@ def load_from_pipeline(pipeline: pipelines.base.Pipeline) -> dict:
pipeline, pipelines.text2text_generation.SummarizationPipeline
):
pipeline_info = {
"inputs": components.Textbox(lines=7, label="Input"),
"outputs": components.Textbox(label="Summary"),
"inputs": components.Textbox(lines=7, label="Input", render=False),
"outputs": components.Textbox(label="Summary", render=False),
"preprocess": lambda x: {"inputs": x},
"postprocess": lambda r: r[0]["summary_text"],
}
@ -108,8 +113,8 @@ def load_from_pipeline(pipeline: pipelines.base.Pipeline) -> dict:
pipeline, pipelines.text_classification.TextClassificationPipeline
):
pipeline_info = {
"inputs": components.Textbox(label="Input"),
"outputs": components.Label(label="Classification"),
"inputs": components.Textbox(label="Input", render=False),
"outputs": components.Label(label="Classification", render=False),
"preprocess": lambda x: [x],
"postprocess": lambda r: {i["label"].split(", ")[0]: i["score"] for i in r},
}
@ -117,8 +122,8 @@ def load_from_pipeline(pipeline: pipelines.base.Pipeline) -> dict:
pipeline, pipelines.text_generation.TextGenerationPipeline
):
pipeline_info = {
"inputs": components.Textbox(label="Input"),
"outputs": components.Textbox(label="Output"),
"inputs": components.Textbox(label="Input", render=False),
"outputs": components.Textbox(label="Output", render=False),
"preprocess": lambda x: {"text_inputs": x},
"postprocess": lambda r: r[0]["generated_text"],
}
@ -126,8 +131,8 @@ def load_from_pipeline(pipeline: pipelines.base.Pipeline) -> dict:
pipeline, pipelines.text2text_generation.TranslationPipeline
):
pipeline_info = {
"inputs": components.Textbox(label="Input"),
"outputs": components.Textbox(label="Translation"),
"inputs": components.Textbox(label="Input", render=False),
"outputs": components.Textbox(label="Translation", render=False),
"preprocess": lambda x: [x],
"postprocess": lambda r: r[0]["translation_text"],
}
@ -135,8 +140,8 @@ def load_from_pipeline(pipeline: pipelines.base.Pipeline) -> dict:
pipeline, pipelines.text2text_generation.Text2TextGenerationPipeline
):
pipeline_info = {
"inputs": components.Textbox(label="Input"),
"outputs": components.Textbox(label="Generated Text"),
"inputs": components.Textbox(label="Input", render=False),
"outputs": components.Textbox(label="Generated Text", render=False),
"preprocess": lambda x: [x],
"postprocess": lambda r: r[0]["generated_text"],
}
@ -145,11 +150,13 @@ def load_from_pipeline(pipeline: pipelines.base.Pipeline) -> dict:
):
pipeline_info = {
"inputs": [
components.Textbox(label="Input"),
components.Textbox(label="Possible class names (" "comma-separated)"),
components.Checkbox(label="Allow multiple true classes"),
components.Textbox(label="Input", render=False),
components.Textbox(
label="Possible class names (" "comma-separated)", render=False
),
components.Checkbox(label="Allow multiple true classes", render=False),
],
"outputs": components.Label(label="Classification"),
"outputs": components.Label(label="Classification", render=False),
"preprocess": lambda i, c, m: {
"sequences": i,
"candidate_labels": c,
@ -165,10 +172,10 @@ def load_from_pipeline(pipeline: pipelines.base.Pipeline) -> dict:
):
pipeline_info = {
"inputs": [
components.Image(type="filepath", label="Input Document"),
components.Textbox(label="Question"),
components.Image(type="filepath", label="Input Document", render=False),
components.Textbox(label="Question", render=False),
],
"outputs": components.Label(label="Label"),
"outputs": components.Label(label="Label", render=False),
"preprocess": lambda img, q: {"image": img, "question": q},
"postprocess": lambda r: {i["answer"]: i["score"] for i in r},
}
@ -177,10 +184,10 @@ def load_from_pipeline(pipeline: pipelines.base.Pipeline) -> dict:
):
pipeline_info = {
"inputs": [
components.Image(type="filepath", label="Input Image"),
components.Textbox(label="Question"),
components.Image(type="filepath", label="Input Image", render=False),
components.Textbox(label="Question", render=False),
],
"outputs": components.Label(label="Score"),
"outputs": components.Label(label="Score", render=False),
"preprocess": lambda img, q: {"image": img, "question": q},
"postprocess": lambda r: {i["answer"]: i["score"] for i in r},
}
@ -188,8 +195,10 @@ def load_from_pipeline(pipeline: pipelines.base.Pipeline) -> dict:
pipeline, pipelines.image_to_text.ImageToTextPipeline # type: ignore
):
pipeline_info = {
"inputs": components.Image(type="filepath", label="Input Image"),
"outputs": components.Textbox(label="Text"),
"inputs": components.Image(
type="filepath", label="Input Image", render=False
),
"outputs": components.Textbox(label="Text", render=False),
"preprocess": lambda i: {"images": i},
"postprocess": lambda r: r[0]["generated_text"],
}

View File

@ -236,8 +236,9 @@ class TestLoadInterface:
except TooManyRequestsError:
pass
def test_conversational(self):
io = gr.load("models/microsoft/DialoGPT-medium")
def test_conversational_in_blocks(self):
with gr.Blocks() as io:
gr.load("models/microsoft/DialoGPT-medium")
app, _, _ = io.launch(prevent_thread_lock=True)
client = TestClient(app)
assert app.state_holder == {}

View File

@ -5,9 +5,21 @@ import gradio as gr
@pytest.mark.flaky
class TestLoadFromPipeline:
def test_text_to_text_model_from_pipeline(self):
pipe = transformers.pipeline(model="sshleifer/bart-tiny-random")
io = gr.Interface.from_pipeline(pipe)
output = io("My name is Sylvain and I work at Hugging Face in Brooklyn")
assert isinstance(output, str)
def test_text_to_text_model_from_pipeline():
pipe = transformers.pipeline(model="sshleifer/bart-tiny-random")
io = gr.Interface.from_pipeline(pipe)
output = io("My name is Sylvain and I work at Hugging Face in Brooklyn")
assert isinstance(output, str)
@pytest.mark.flaky
def test_interface_in_blocks():
pipe1 = transformers.pipeline(model="sshleifer/bart-tiny-random")
pipe2 = transformers.pipeline(model="sshleifer/bart-tiny-random")
with gr.Blocks() as demo:
with gr.Tab("Image Inference"):
gr.Interface.from_pipeline(pipe1)
with gr.Tab("Image Inference"):
gr.Interface.from_pipeline(pipe2)
demo.launch(prevent_thread_lock=True)
demo.close()