Support new tasks with Hugging Face integration (#3887)

* Update pipelines.py

* Add pipelines that run models locally

* Fix typo

* Add new pipelines with API

* Add changelog

* changelog

* changes

* fix tests

* linting

---------

Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
Co-authored-by: Freddy Boulton <alfonsoboulton@gmail.com>
This commit is contained in:
Omar Sanseviero 2023-05-01 19:18:58 +02:00 committed by GitHub
parent b4206cd33a
commit 398115b39a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 91 additions and 9 deletions

View File

@ -2,7 +2,7 @@
## New Features:
No changes to highlight.
- Add support for `visual-question-answering`, `document-question-answering`, and `image-to-text` using `gr.Interface.load("models/...")` and `gr.Interface.from_pipeline` by [@osanseviero](https://github.com/osanseviero) in [PR 3887](https://github.com/gradio-app/gradio/pull/3887)
## Bug Fixes:

View File

@ -24,7 +24,7 @@ from gradio.external_utils import (
rows_to_cols,
streamline_spaces_interface,
)
from gradio.processing_utils import to_binary
from gradio.processing_utils import extract_base64_data, to_binary
if TYPE_CHECKING:
from gradio.blocks import Blocks
@ -201,12 +201,6 @@ def from_model(model_name: str, api_key: str | None, alias: str | None, **kwargs
{i["label"].split(", ")[0]: i["score"] for i in r.json()}
),
},
"image-to-text": {
"inputs": components.Image(type="filepath", label="Input Image"),
"outputs": components.Textbox(),
"preprocess": to_binary,
"postprocess": lambda r: r.json()[0]["generated_text"],
},
"question-answering": {
# Example: deepset/xlm-roberta-base-squad2
"inputs": [
@ -319,6 +313,47 @@ def from_model(model_name: str, api_key: str | None, alias: str | None, **kwargs
"preprocess": lambda x: {"inputs": x},
"postprocess": lambda r: r, # Handled as a special case in query_huggingface_api()
},
"document-question-answering": {
# example model: impira/layoutlm-document-qa
"inputs": [
components.Image(type="filepath", label="Input Document"),
components.Textbox(label="Question"),
],
"outputs": components.Label(label="Label"),
"preprocess": lambda img, q: {
"inputs": {
"image": extract_base64_data(img), # Extract base64 data
"question": q,
}
},
"postprocess": lambda r: postprocess_label(
{i["answer"]: i["score"] for i in r.json()}
),
},
"visual-question-answering": {
# example model: dandelin/vilt-b32-finetuned-vqa
"inputs": [
components.Image(type="filepath", label="Input Image"),
components.Textbox(label="Question"),
],
"outputs": components.Label(label="Label"),
"preprocess": lambda img, q: {
"inputs": {
"image": extract_base64_data(img),
"question": q,
}
},
"postprocess": lambda r: postprocess_label(
{i["answer"]: i["score"] for i in r.json()}
),
},
"image-to-text": {
# example model: Salesforce/blip-image-captioning-base
"inputs": components.Image(type="filepath", label="Input Image"),
"outputs": components.Textbox(label="Generated Text"),
"preprocess": to_binary,
"postprocess": lambda r: r.json()[0]["generated_text"],
},
}
if p in ["tabular-classification", "tabular-regression"]:

View File

@ -91,7 +91,7 @@ class Interface(Blocks):
Returns:
a Gradio Interface object for the given model
"""
warnings.warn("gr.Intrerface.load() will be deprecated. Use gr.load() instead.")
warnings.warn("gr.Interface.load() will be deprecated. Use gr.load() instead.")
return external.load(
name=name, src=src, hf_token=api_key, alias=alias, **kwargs
)

View File

@ -159,6 +159,40 @@ def load_from_pipeline(pipeline: pipelines.base.Pipeline) -> Dict:
r["labels"][i]: r["scores"][i] for i in range(len(r["labels"]))
},
}
elif hasattr(transformers, "DocumentQuestionAnsweringPipeline") and isinstance(
pipeline,
pipelines.document_question_answering.DocumentQuestionAnsweringPipeline, # type: ignore
):
pipeline_info = {
"inputs": [
components.Image(type="filepath", label="Input Document"),
components.Textbox(label="Question"),
],
"outputs": components.Label(label="Label"),
"preprocess": lambda img, q: {"image": img, "question": q},
"postprocess": lambda r: {i["answer"]: i["score"] for i in r},
}
elif hasattr(transformers, "VisualQuestionAnsweringPipeline") and isinstance(
pipeline, pipelines.visual_question_answering.VisualQuestionAnsweringPipeline
):
pipeline_info = {
"inputs": [
components.Image(type="filepath", label="Input Image"),
components.Textbox(label="Question"),
],
"outputs": components.Label(label="Score"),
"preprocess": lambda img, q: {"image": img, "question": q},
"postprocess": lambda r: {i["answer"]: i["score"] for i in r},
}
elif hasattr(transformers, "ImageToTextPipeline") and isinstance(
pipeline, pipelines.image_to_text.ImageToTextPipeline # type: ignore
):
pipeline_info = {
"inputs": components.Image(type="filepath", label="Input Image"),
"outputs": components.Textbox(label="Text"),
"preprocess": lambda i: {"images": i},
"postprocess": lambda r: r[0]["generated_text"],
}
else:
raise ValueError(f"Unsupported pipeline type: {type(pipeline)}")

View File

@ -37,6 +37,11 @@ def to_binary(x: str | Dict) -> bytes:
return base64.b64decode(base64str.split(",")[1])
def extract_base64_data(x: str) -> str:
    """Return the raw base64 payload from a data-URL-style string.

    Parameters:
        x: a string containing a ``base64,`` marker, e.g.
           ``"data:image/png;base64,iVBORw0..."``.
    Returns:
        Everything after the first ``"base64,"`` marker.
    Raises:
        IndexError: if ``x`` contains no ``"base64,"`` marker.
    """
    # maxsplit=1 guards against any later occurrence of the separator
    # truncating the payload; for well-formed inputs the result is
    # identical to an unbounded split.
    return x.split("base64,", 1)[1]
#########################
# IMAGE PRE-PROCESSING
#########################

View File

@ -220,6 +220,14 @@ class TestLoadInterface:
except TooManyRequestsError:
pass
def test_visual_question_answering(self):
    # Smoke test: load a hosted VQA model through the Hugging Face
    # Inference API via gr.load("models/...").
    io = gr.load("models/dandelin/vilt-b32-finetuned-vqa")
    try:
        output = io("gradio/test_data/lion.jpg", "What is in the image?")
        # The interface returns a filepath to a serialized Label (.json);
        # we only check the shape of the result, not the model's answer.
        assert isinstance(output, str) and output.endswith(".json")
    except TooManyRequestsError:
        # The public Inference API is rate-limited; treat throttling as
        # a skip rather than a failure.
        pass
def test_image_to_text(self):
io = gr.load("models/nlpconnect/vit-gpt2-image-captioning")
try: