From 398115b39ace7159f1efe3e67f759c24a79d41eb Mon Sep 17 00:00:00 2001 From: Omar Sanseviero Date: Mon, 1 May 2023 19:18:58 +0200 Subject: [PATCH] Support new tasks with Hugging Face integration (#3887) * Update pipelines.py * Add pipelines that run models locally * Fix typo * Add new pipelines with API * Add changelog * changelog * changes * fix tests * linting --------- Co-authored-by: Abubakar Abid Co-authored-by: Freddy Boulton --- CHANGELOG.md | 2 +- gradio/external.py | 49 ++++++++++++++++++++++++++++++++------ gradio/interface.py | 2 +- gradio/pipelines.py | 34 ++++++++++++++++++++++++++ gradio/processing_utils.py | 5 ++++ test/test_external.py | 8 +++++++ 6 files changed, 91 insertions(+), 9 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d6b0bfbf8a..1c32a12e3f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,7 +2,7 @@ ## New Features: -No changes to highlight. +- Add support for `visual-question-answering`, `document-question-answering`, and `image-to-text` using `gr.Interface.load("models/...")` and `gr.Interface.from_pipeline` by [@osanseviero](https://github.com/osanseviero) in [PR 3887](https://github.com/gradio-app/gradio/pull/3887) ## Bug Fixes: diff --git a/gradio/external.py b/gradio/external.py index 79c7aac5cb..569c24a6bd 100644 --- a/gradio/external.py +++ b/gradio/external.py @@ -24,7 +24,7 @@ from gradio.external_utils import ( rows_to_cols, streamline_spaces_interface, ) -from gradio.processing_utils import to_binary +from gradio.processing_utils import extract_base64_data, to_binary if TYPE_CHECKING: from gradio.blocks import Blocks @@ -201,12 +201,6 @@ def from_model(model_name: str, api_key: str | None, alias: str | None, **kwargs {i["label"].split(", ")[0]: i["score"] for i in r.json()} ), }, - "image-to-text": { - "inputs": components.Image(type="filepath", label="Input Image"), - "outputs": components.Textbox(), - "preprocess": to_binary, - "postprocess": lambda r: r.json()[0]["generated_text"], - }, "question-answering": { # Example: deepset/xlm-roberta-base-squad2 "inputs": [ @@ -319,6 +313,47 @@ def from_model(model_name: str, api_key: str | None, alias: str | None, **kwargs "preprocess": lambda x: {"inputs": x}, "postprocess": lambda r: r, # Handled as a special case in query_huggingface_api() }, + "document-question-answering": { + # example model: impira/layoutlm-document-qa + "inputs": [ + components.Image(type="filepath", label="Input Document"), + components.Textbox(label="Question"), + ], + "outputs": components.Label(label="Label"), + "preprocess": lambda img, q: { + "inputs": { + "image": extract_base64_data(img), # Extract base64 data + "question": q, + } + }, + "postprocess": lambda r: postprocess_label( + {i["answer"]: i["score"] for i in r.json()} + ), + }, + "visual-question-answering": { + # example model: dandelin/vilt-b32-finetuned-vqa + "inputs": [ + components.Image(type="filepath", label="Input Image"), + components.Textbox(label="Question"), + ], + "outputs": components.Label(label="Label"), + "preprocess": lambda img, q: { + "inputs": { + "image": extract_base64_data(img), + "question": q, + } + }, + "postprocess": lambda r: postprocess_label( + {i["answer"]: i["score"] for i in r.json()} + ), + }, + "image-to-text": { + # example model: Salesforce/blip-image-captioning-base + "inputs": components.Image(type="filepath", label="Input Image"), + "outputs": components.Textbox(label="Generated Text"), + "preprocess": to_binary, + "postprocess": lambda r: r.json()[0]["generated_text"], + }, } if p in ["tabular-classification", "tabular-regression"]: diff --git a/gradio/interface.py b/gradio/interface.py index 65e0623824..c55ac16233 100644 --- a/gradio/interface.py +++ b/gradio/interface.py @@ -91,7 +91,7 @@ class Interface(Blocks): Returns: a Gradio Interface object for the given model """ - warnings.warn("gr.Intrerface.load() will be deprecated. Use gr.load() instead.") + warnings.warn("gr.Interface.load() will be deprecated. Use gr.load() instead.") return external.load( name=name, src=src, hf_token=api_key, alias=alias, **kwargs ) diff --git a/gradio/pipelines.py b/gradio/pipelines.py index 2d523e4cd6..73bb15e608 100644 --- a/gradio/pipelines.py +++ b/gradio/pipelines.py @@ -159,6 +159,40 @@ def load_from_pipeline(pipeline: pipelines.base.Pipeline) -> Dict: r["labels"][i]: r["scores"][i] for i in range(len(r["labels"])) }, } + elif hasattr(transformers, "DocumentQuestionAnsweringPipeline") and isinstance( + pipeline, + pipelines.document_question_answering.DocumentQuestionAnsweringPipeline, # type: ignore + ): + pipeline_info = { + "inputs": [ + components.Image(type="filepath", label="Input Document"), + components.Textbox(label="Question"), + ], + "outputs": components.Label(label="Label"), + "preprocess": lambda img, q: {"image": img, "question": q}, + "postprocess": lambda r: {i["answer"]: i["score"] for i in r}, + } + elif hasattr(transformers, "VisualQuestionAnsweringPipeline") and isinstance( + pipeline, pipelines.visual_question_answering.VisualQuestionAnsweringPipeline + ): + pipeline_info = { + "inputs": [ + components.Image(type="filepath", label="Input Image"), + components.Textbox(label="Question"), + ], + "outputs": components.Label(label="Score"), + "preprocess": lambda img, q: {"image": img, "question": q}, + "postprocess": lambda r: {i["answer"]: i["score"] for i in r}, + } + elif hasattr(transformers, "ImageToTextPipeline") and isinstance( + pipeline, pipelines.image_to_text.ImageToTextPipeline # type: ignore + ): + pipeline_info = { + "inputs": components.Image(type="filepath", label="Input Image"), + "outputs": components.Textbox(label="Text"), + "preprocess": lambda i: {"images": i}, + "postprocess": lambda r: r[0]["generated_text"], + } else: raise ValueError(f"Unsupported pipeline type: {type(pipeline)}") diff --git a/gradio/processing_utils.py b/gradio/processing_utils.py index ad7d027a9b..f39823701b 100644 --- a/gradio/processing_utils.py +++ b/gradio/processing_utils.py @@ -37,6 +37,11 @@ def to_binary(x: str | Dict) -> bytes: return base64.b64decode(base64str.split(",")[1]) +def extract_base64_data(x: str) -> str: + """Just extracts the base64 data from a general base64 string.""" + return x.split("base64,")[1] + + ######################### # IMAGE PRE-PROCESSING ######################### diff --git a/test/test_external.py b/test/test_external.py index 40aa7b9705..6bb1352438 100644 --- a/test/test_external.py +++ b/test/test_external.py @@ -220,6 +220,14 @@ class TestLoadInterface: except TooManyRequestsError: pass + def test_visual_question_answering(self): + io = gr.load("models/dandelin/vilt-b32-finetuned-vqa") + try: + output = io("gradio/test_data/lion.jpg", "What is in the image?") + assert isinstance(output, str) and output.endswith(".json") + except TooManyRequestsError: + pass + def test_image_to_text(self): io = gr.load("models/nlpconnect/vit-gpt2-image-captioning") try: