From 87f1c2b4ac7c685c43477215fa5b96b6cbeffa05 Mon Sep 17 00:00:00 2001 From: Abubakar Abid Date: Tue, 15 Aug 2023 22:28:19 -0700 Subject: [PATCH] Allow `gr.Interface.from_pipeline()` and `gr.load()` to work within `gr.Blocks()` (#5231) * add test * external * fixed in external * add changeset * close demo --------- Co-authored-by: gradio-pr-bot --- .changeset/poor-papayas-behave.md | 5 ++ gradio/external.py | 114 +++++++++++++++++------------- gradio/pipelines.py | 81 +++++++++++---------- test/test_external.py | 5 +- test/test_pipelines.py | 24 +++++-- 5 files changed, 136 insertions(+), 93 deletions(-) create mode 100644 .changeset/poor-papayas-behave.md diff --git a/.changeset/poor-papayas-behave.md b/.changeset/poor-papayas-behave.md new file mode 100644 index 0000000000..b31a0fe7d9 --- /dev/null +++ b/.changeset/poor-papayas-behave.md @@ -0,0 +1,5 @@ +--- +"gradio": patch +--- + +fix:Allow `gr.Interface.from_pipeline()` and `gr.load()` to work within `gr.Blocks()` diff --git a/gradio/external.py b/gradio/external.py index 29ad28384c..88710ca7c4 100644 --- a/gradio/external.py +++ b/gradio/external.py @@ -152,8 +152,10 @@ def from_model(model_name: str, hf_token: str | None, alias: str | None, **kwarg pipelines = { "audio-classification": { # example model: ehcalabres/wav2vec2-lg-xlsr-en-speech-emotion-recognition - "inputs": components.Audio(source="upload", type="filepath", label="Input"), - "outputs": components.Label(label="Class"), + "inputs": components.Audio( + source="upload", type="filepath", label="Input", render=False + ), + "outputs": components.Label(label="Class", render=False), "preprocess": lambda i: to_binary, "postprocess": lambda r: postprocess_label( {i["label"].split(", ")[0]: i["score"] for i in r.json()} @@ -161,34 +163,38 @@ def from_model(model_name: str, hf_token: str | None, alias: str | None, **kwarg }, "audio-to-audio": { # example model: facebook/xm_transformer_sm_all-en - "inputs": components.Audio(source="upload", type="filepath", label="Input"), - "outputs": components.Audio(label="Output"), + "inputs": components.Audio( + source="upload", type="filepath", label="Input", render=False + ), + "outputs": components.Audio(label="Output", render=False), "preprocess": to_binary, "postprocess": encode_to_base64, }, "automatic-speech-recognition": { # example model: facebook/wav2vec2-base-960h - "inputs": components.Audio(source="upload", type="filepath", label="Input"), - "outputs": components.Textbox(label="Output"), + "inputs": components.Audio( + source="upload", type="filepath", label="Input", render=False + ), + "outputs": components.Textbox(label="Output", render=False), "preprocess": to_binary, "postprocess": lambda r: r.json()["text"], }, "conversational": { - "inputs": [components.Textbox(), components.State()], # type: ignore - "outputs": [components.Chatbot(), components.State()], # type: ignore + "inputs": [components.Textbox(render=False), components.State(render=False)], # type: ignore + "outputs": [components.Chatbot(render=False), components.State(render=False)], # type: ignore "preprocess": chatbot_preprocess, "postprocess": chatbot_postprocess, }, "feature-extraction": { # example model: julien-c/distilbert-feature-extraction - "inputs": components.Textbox(label="Input"), - "outputs": components.Dataframe(label="Output"), + "inputs": components.Textbox(label="Input", render=False), + "outputs": components.Dataframe(label="Output", render=False), "preprocess": lambda x: {"inputs": x}, "postprocess": lambda r: r.json()[0], }, "fill-mask": { - "inputs": components.Textbox(label="Input"), - "outputs": components.Label(label="Classification"), + "inputs": components.Textbox(label="Input", render=False), + "outputs": components.Label(label="Classification", render=False), "preprocess": lambda x: {"inputs": x}, "postprocess": lambda r: postprocess_label( {i["token_str"]: i["score"] for i in r.json()} @@ -196,8 +202,10 @@ def from_model(model_name: str, hf_token: str | None, alias: str | None, **kwarg }, "image-classification": { # Example: google/vit-base-patch16-224 - "inputs": components.Image(type="filepath", label="Input Image"), - "outputs": components.Label(label="Classification"), + "inputs": components.Image( + type="filepath", label="Input Image", render=False + ), + "outputs": components.Label(label="Classification", render=False), "preprocess": to_binary, "postprocess": lambda r: postprocess_label( {i["label"].split(", ")[0]: i["score"] for i in r.json()} @@ -206,27 +214,27 @@ def from_model(model_name: str, hf_token: str | None, alias: str | None, **kwarg "question-answering": { # Example: deepset/xlm-roberta-base-squad2 "inputs": [ - components.Textbox(lines=7, label="Context"), - components.Textbox(label="Question"), + components.Textbox(lines=7, label="Context", render=False), + components.Textbox(label="Question", render=False), ], "outputs": [ - components.Textbox(label="Answer"), - components.Label(label="Score"), + components.Textbox(label="Answer", render=False), + components.Label(label="Score", render=False), ], "preprocess": lambda c, q: {"inputs": {"context": c, "question": q}}, "postprocess": lambda r: (r.json()["answer"], {"label": r.json()["score"]}), }, "summarization": { # Example: facebook/bart-large-cnn - "inputs": components.Textbox(label="Input"), - "outputs": components.Textbox(label="Summary"), + "inputs": components.Textbox(label="Input", render=False), + "outputs": components.Textbox(label="Summary", render=False), "preprocess": lambda x: {"inputs": x}, "postprocess": lambda r: r.json()[0]["summary_text"], }, "text-classification": { # Example: distilbert-base-uncased-finetuned-sst-2-english - "inputs": components.Textbox(label="Input"), - "outputs": components.Label(label="Classification"), + "inputs": components.Textbox(label="Input", render=False), + "outputs": components.Label(label="Classification", render=False), "preprocess": lambda x: {"inputs": x}, "postprocess": lambda r: postprocess_label( {i["label"].split(", ")[0]: i["score"] for i in r.json()[0]} @@ -234,32 +242,34 @@ def from_model(model_name: str, hf_token: str | None, alias: str | None, **kwarg }, "text-generation": { # Example: gpt2 - "inputs": components.Textbox(label="Input"), - "outputs": components.Textbox(label="Output"), + "inputs": components.Textbox(label="Input", render=False), + "outputs": components.Textbox(label="Output", render=False), "preprocess": lambda x: {"inputs": x}, "postprocess": lambda r: r.json()[0]["generated_text"], }, "text2text-generation": { # Example: valhalla/t5-small-qa-qg-hl - "inputs": components.Textbox(label="Input"), - "outputs": components.Textbox(label="Generated Text"), + "inputs": components.Textbox(label="Input", render=False), + "outputs": components.Textbox(label="Generated Text", render=False), "preprocess": lambda x: {"inputs": x}, "postprocess": lambda r: r.json()[0]["generated_text"], }, "translation": { - "inputs": components.Textbox(label="Input"), - "outputs": components.Textbox(label="Translation"), + "inputs": components.Textbox(label="Input", render=False), + "outputs": components.Textbox(label="Translation", render=False), "preprocess": lambda x: {"inputs": x}, "postprocess": lambda r: r.json()[0]["translation_text"], }, "zero-shot-classification": { # Example: facebook/bart-large-mnli "inputs": [ - components.Textbox(label="Input"), - components.Textbox(label="Possible class names (" "comma-separated)"), - components.Checkbox(label="Allow multiple true classes"), + components.Textbox(label="Input", render=False), + components.Textbox( + label="Possible class names (" "comma-separated)", render=False + ), + components.Checkbox(label="Allow multiple true classes", render=False), ], - "outputs": components.Label(label="Classification"), + "outputs": components.Label(label="Classification", render=False), "preprocess": lambda i, c, m: { "inputs": i, "parameters": {"candidate_labels": c, "multi_class": m}, @@ -275,15 +285,18 @@ def from_model(model_name: str, hf_token: str | None, alias: str | None, **kwarg # Example: sentence-transformers/distilbert-base-nli-stsb-mean-tokens "inputs": [ components.Textbox( - value="That is a happy person", label="Source Sentence" + value="That is a happy person", + label="Source Sentence", + render=False, ), components.Textbox( lines=7, placeholder="Separate each sentence by a newline", label="Sentences to compare to", + render=False, ), ], - "outputs": components.Label(label="Classification"), + "outputs": components.Label(label="Classification", render=False), "preprocess": lambda src, sentences: { "inputs": { "source_sentence": src, @@ -296,32 +309,32 @@ def from_model(model_name: str, hf_token: str | None, alias: str | None, **kwarg }, "text-to-speech": { # Example: julien-c/ljspeech_tts_train_tacotron2_raw_phn_tacotron_g2p_en_no_space_train - "inputs": components.Textbox(label="Input"), - "outputs": components.Audio(label="Audio"), + "inputs": components.Textbox(label="Input", render=False), + "outputs": components.Audio(label="Audio", render=False), "preprocess": lambda x: {"inputs": x}, "postprocess": encode_to_base64, }, "text-to-image": { # example model: osanseviero/BigGAN-deep-128 - "inputs": components.Textbox(label="Input"), - "outputs": components.Image(label="Output"), + "inputs": components.Textbox(label="Input", render=False), + "outputs": components.Image(label="Output", render=False), "preprocess": lambda x: {"inputs": x}, "postprocess": encode_to_base64, }, "token-classification": { # example model: huggingface-course/bert-finetuned-ner - "inputs": components.Textbox(label="Input"), - "outputs": components.HighlightedText(label="Output"), + "inputs": components.Textbox(label="Input", render=False), + "outputs": components.HighlightedText(label="Output", render=False), "preprocess": lambda x: {"inputs": x}, "postprocess": lambda r: r, # Handled as a special case in query_huggingface_api() }, "document-question-answering": { # example model: impira/layoutlm-document-qa "inputs": [ - components.Image(type="filepath", label="Input Document"), - components.Textbox(label="Question"), + components.Image(type="filepath", label="Input Document", render=False), + components.Textbox(label="Question", render=False), ], - "outputs": components.Label(label="Label"), + "outputs": components.Label(label="Label", render=False), "preprocess": lambda img, q: { "inputs": { "image": extract_base64_data(img), # Extract base64 data @@ -335,10 +348,10 @@ def from_model(model_name: str, hf_token: str | None, alias: str | None, **kwarg "visual-question-answering": { # example model: dandelin/vilt-b32-finetuned-vqa "inputs": [ - components.Image(type="filepath", label="Input Image"), - components.Textbox(label="Question"), + components.Image(type="filepath", label="Input Image", render=False), + components.Textbox(label="Question", render=False), ], - "outputs": components.Label(label="Label"), + "outputs": components.Label(label="Label", render=False), "preprocess": lambda img, q: { "inputs": { "image": extract_base64_data(img), @@ -351,8 +364,10 @@ def from_model(model_name: str, hf_token: str | None, alias: str | None, **kwarg }, "image-to-text": { # example model: Salesforce/blip-image-captioning-base - "inputs": components.Image(type="filepath", label="Input Image"), - "outputs": components.Textbox(label="Generated Text"), + "inputs": components.Image( + type="filepath", label="Input Image", render=False + ), + "outputs": components.Textbox(label="Generated Text", render=False), "preprocess": to_binary, "postprocess": lambda r: r.json()[0]["generated_text"], }, @@ -369,9 +384,10 @@ def from_model(model_name: str, hf_token: str | None, alias: str | None, **kwarg type="pandas", headers=col_names, col_count=(len(col_names), "fixed"), + render=False, ), "outputs": components.Dataframe( - label="Predictions", type="array", headers=["prediction"] + label="Predictions", type="array", headers=["prediction"], render=False ), "preprocess": rows_to_cols, "postprocess": lambda r: { diff --git a/gradio/pipelines.py b/gradio/pipelines.py index 144f1f7ecd..cb088c4d87 100644 --- a/gradio/pipelines.py +++ b/gradio/pipelines.py @@ -35,9 +35,12 @@ def load_from_pipeline(pipeline: pipelines.base.Pipeline) -> dict: ): pipeline_info = { "inputs": components.Audio( - source="microphone", type="filepath", label="Input" + source="microphone", + type="filepath", + label="Input", + render=False, ), - "outputs": components.Label(label="Class"), + "outputs": components.Label(label="Class", render=False), "preprocess": lambda i: {"inputs": i}, "postprocess": lambda r: {i["label"].split(", ")[0]: i["score"] for i in r}, } @@ -47,9 +50,9 @@ def load_from_pipeline(pipeline: pipelines.base.Pipeline) -> dict: ): pipeline_info = { "inputs": components.Audio( - source="microphone", type="filepath", label="Input" + source="microphone", type="filepath", label="Input", render=False ), - "outputs": components.Textbox(label="Output"), + "outputs": components.Textbox(label="Output", render=False), "preprocess": lambda i: {"inputs": i}, "postprocess": lambda r: r["text"], } @@ -57,8 +60,8 @@ def load_from_pipeline(pipeline: pipelines.base.Pipeline) -> dict: pipeline, pipelines.feature_extraction.FeatureExtractionPipeline ): pipeline_info = { - "inputs": components.Textbox(label="Input"), - "outputs": components.Dataframe(label="Output"), + "inputs": components.Textbox(label="Input", render=False), + "outputs": components.Dataframe(label="Output", render=False), "preprocess": lambda x: {"inputs": x}, "postprocess": lambda r: r[0], } @@ -66,8 +69,8 @@ def load_from_pipeline(pipeline: pipelines.base.Pipeline) -> dict: pipeline, pipelines.fill_mask.FillMaskPipeline ): pipeline_info = { - "inputs": components.Textbox(label="Input"), - "outputs": components.Label(label="Classification"), + "inputs": components.Textbox(label="Input", render=False), + "outputs": components.Label(label="Classification", render=False), "preprocess": lambda x: {"inputs": x}, "postprocess": lambda r: {i["token_str"]: i["score"] for i in r}, } @@ -75,8 +78,10 @@ def load_from_pipeline(pipeline: pipelines.base.Pipeline) -> dict: pipeline, pipelines.image_classification.ImageClassificationPipeline ): pipeline_info = { - "inputs": components.Image(type="filepath", label="Input Image"), - "outputs": components.Label(type="confidences", label="Classification"), + "inputs": components.Image( + type="filepath", label="Input Image", render=False + ), + "outputs": components.Label(label="Classification", render=False), "preprocess": lambda i: {"images": i}, "postprocess": lambda r: {i["label"].split(", ")[0]: i["score"] for i in r}, } @@ -85,12 +90,12 @@ def load_from_pipeline(pipeline: pipelines.base.Pipeline) -> dict: ): pipeline_info = { "inputs": [ - components.Textbox(lines=7, label="Context"), - components.Textbox(label="Question"), + components.Textbox(lines=7, label="Context", render=False), + components.Textbox(label="Question", render=False), ], "outputs": [ - components.Textbox(label="Answer"), - components.Label(label="Score"), + components.Textbox(label="Answer", render=False), + components.Label(label="Score", render=False), ], "preprocess": lambda c, q: {"context": c, "question": q}, "postprocess": lambda r: (r["answer"], r["score"]), @@ -99,8 +104,8 @@ def load_from_pipeline(pipeline: pipelines.base.Pipeline) -> dict: pipeline, pipelines.text2text_generation.SummarizationPipeline ): pipeline_info = { - "inputs": components.Textbox(lines=7, label="Input"), - "outputs": components.Textbox(label="Summary"), + "inputs": components.Textbox(lines=7, label="Input", render=False), + "outputs": components.Textbox(label="Summary", render=False), "preprocess": lambda x: {"inputs": x}, "postprocess": lambda r: r[0]["summary_text"], } @@ -108,8 +113,8 @@ def load_from_pipeline(pipeline: pipelines.base.Pipeline) -> dict: pipeline, pipelines.text_classification.TextClassificationPipeline ): pipeline_info = { - "inputs": components.Textbox(label="Input"), - "outputs": components.Label(label="Classification"), + "inputs": components.Textbox(label="Input", render=False), + "outputs": components.Label(label="Classification", render=False), "preprocess": lambda x: [x], "postprocess": lambda r: {i["label"].split(", ")[0]: i["score"] for i in r}, } @@ -117,8 +122,8 @@ def load_from_pipeline(pipeline: pipelines.base.Pipeline) -> dict: pipeline, pipelines.text_generation.TextGenerationPipeline ): pipeline_info = { - "inputs": components.Textbox(label="Input"), - "outputs": components.Textbox(label="Output"), + "inputs": components.Textbox(label="Input", render=False), + "outputs": components.Textbox(label="Output", render=False), "preprocess": lambda x: {"text_inputs": x}, "postprocess": lambda r: r[0]["generated_text"], } @@ -126,8 +131,8 @@ def load_from_pipeline(pipeline: pipelines.base.Pipeline) -> dict: pipeline, pipelines.text2text_generation.TranslationPipeline ): pipeline_info = { - "inputs": components.Textbox(label="Input"), - "outputs": components.Textbox(label="Translation"), + "inputs": components.Textbox(label="Input", render=False), + "outputs": components.Textbox(label="Translation", render=False), "preprocess": lambda x: [x], "postprocess": lambda r: r[0]["translation_text"], } @@ -135,8 +140,8 @@ def load_from_pipeline(pipeline: pipelines.base.Pipeline) -> dict: pipeline, pipelines.text2text_generation.Text2TextGenerationPipeline ): pipeline_info = { - "inputs": components.Textbox(label="Input"), - "outputs": components.Textbox(label="Generated Text"), + "inputs": components.Textbox(label="Input", render=False), + "outputs": components.Textbox(label="Generated Text", render=False), "preprocess": lambda x: [x], "postprocess": lambda r: r[0]["generated_text"], } @@ -145,11 +150,13 @@ def load_from_pipeline(pipeline: pipelines.base.Pipeline) -> dict: ): pipeline_info = { "inputs": [ - components.Textbox(label="Input"), - components.Textbox(label="Possible class names (" "comma-separated)"), - components.Checkbox(label="Allow multiple true classes"), + components.Textbox(label="Input", render=False), + components.Textbox( + label="Possible class names (" "comma-separated)", render=False + ), + components.Checkbox(label="Allow multiple true classes", render=False), ], - "outputs": components.Label(label="Classification"), + "outputs": components.Label(label="Classification", render=False), "preprocess": lambda i, c, m: { "sequences": i, "candidate_labels": c, @@ -165,10 +172,10 @@ def load_from_pipeline(pipeline: pipelines.base.Pipeline) -> dict: ): pipeline_info = { "inputs": [ - components.Image(type="filepath", label="Input Document"), - components.Textbox(label="Question"), + components.Image(type="filepath", label="Input Document", render=False), + components.Textbox(label="Question", render=False), ], - "outputs": components.Label(label="Label"), + "outputs": components.Label(label="Label", render=False), "preprocess": lambda img, q: {"image": img, "question": q}, "postprocess": lambda r: {i["answer"]: i["score"] for i in r}, } @@ -177,10 +184,10 @@ def load_from_pipeline(pipeline: pipelines.base.Pipeline) -> dict: ): pipeline_info = { "inputs": [ - components.Image(type="filepath", label="Input Image"), - components.Textbox(label="Question"), + components.Image(type="filepath", label="Input Image", render=False), + components.Textbox(label="Question", render=False), ], - "outputs": components.Label(label="Score"), + "outputs": components.Label(label="Score", render=False), "preprocess": lambda img, q: {"image": img, "question": q}, "postprocess": lambda r: {i["answer"]: i["score"] for i in r}, } @@ -188,8 +195,10 @@ def load_from_pipeline(pipeline: pipelines.base.Pipeline) -> dict: pipeline, pipelines.image_to_text.ImageToTextPipeline # type: ignore ): pipeline_info = { - "inputs": components.Image(type="filepath", label="Input Image"), - "outputs": components.Textbox(label="Text"), + "inputs": components.Image( + type="filepath", label="Input Image", render=False + ), + "outputs": components.Textbox(label="Text", render=False), "preprocess": lambda i: {"images": i}, "postprocess": lambda r: r[0]["generated_text"], } diff --git a/test/test_external.py b/test/test_external.py index 83b627dce9..26a82385b8 100644 --- a/test/test_external.py +++ b/test/test_external.py @@ -236,8 +236,9 @@ class TestLoadInterface: except TooManyRequestsError: pass - def test_conversational(self): - io = gr.load("models/microsoft/DialoGPT-medium") + def test_conversational_in_blocks(self): + with gr.Blocks() as io: + gr.load("models/microsoft/DialoGPT-medium") app, _, _ = io.launch(prevent_thread_lock=True) client = TestClient(app) assert app.state_holder == {} diff --git a/test/test_pipelines.py b/test/test_pipelines.py index bdb6633b05..c55fb5167f 100644 --- a/test/test_pipelines.py +++ b/test/test_pipelines.py @@ -5,9 +5,21 @@ import gradio as gr @pytest.mark.flaky -class TestLoadFromPipeline: - def test_text_to_text_model_from_pipeline(self): - pipe = transformers.pipeline(model="sshleifer/bart-tiny-random") - io = gr.Interface.from_pipeline(pipe) - output = io("My name is Sylvain and I work at Hugging Face in Brooklyn") - assert isinstance(output, str) +def test_text_to_text_model_from_pipeline(): + pipe = transformers.pipeline(model="sshleifer/bart-tiny-random") + io = gr.Interface.from_pipeline(pipe) + output = io("My name is Sylvain and I work at Hugging Face in Brooklyn") + assert isinstance(output, str) + + +@pytest.mark.flaky +def test_interface_in_blocks(): + pipe1 = transformers.pipeline(model="sshleifer/bart-tiny-random") + pipe2 = transformers.pipeline(model="sshleifer/bart-tiny-random") + with gr.Blocks() as demo: + with gr.Tab("Image Inference"): + gr.Interface.from_pipeline(pipe1) + with gr.Tab("Image Inference"): + gr.Interface.from_pipeline(pipe2) + demo.launch(prevent_thread_lock=True) + demo.close()