mirror of
https://github.com/gradio-app/gradio.git
synced 2025-04-12 12:40:29 +08:00
Support new tasks with Hugging Face integration (#3887)
* Update pipelines.py * Add pipelines that run models locally * Fix typo * Add new pipelines with API * Add changelog * changelog * changes * fix tests * linting --------- Co-authored-by: Abubakar Abid <abubakar@huggingface.co> Co-authored-by: Freddy Boulton <alfonsoboulton@gmail.com>
This commit is contained in:
parent
b4206cd33a
commit
398115b39a
@ -2,7 +2,7 @@
|
||||
|
||||
## New Features:
|
||||
|
||||
No changes to highlight.
|
||||
- Add support for `visual-question-answering`, `document-question-answering`, and `image-to-text` using `gr.Interface.load("models/...")` and `gr.Interface.from_pipeline` by [@osanseviero](https://github.com/osanseviero) in [PR 3887](https://github.com/gradio-app/gradio/pull/3887)
|
||||
|
||||
## Bug Fixes:
|
||||
|
||||
|
@ -24,7 +24,7 @@ from gradio.external_utils import (
|
||||
rows_to_cols,
|
||||
streamline_spaces_interface,
|
||||
)
|
||||
from gradio.processing_utils import to_binary
|
||||
from gradio.processing_utils import extract_base64_data, to_binary
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from gradio.blocks import Blocks
|
||||
@ -201,12 +201,6 @@ def from_model(model_name: str, api_key: str | None, alias: str | None, **kwargs
|
||||
{i["label"].split(", ")[0]: i["score"] for i in r.json()}
|
||||
),
|
||||
},
|
||||
"image-to-text": {
|
||||
"inputs": components.Image(type="filepath", label="Input Image"),
|
||||
"outputs": components.Textbox(),
|
||||
"preprocess": to_binary,
|
||||
"postprocess": lambda r: r.json()[0]["generated_text"],
|
||||
},
|
||||
"question-answering": {
|
||||
# Example: deepset/xlm-roberta-base-squad2
|
||||
"inputs": [
|
||||
@ -319,6 +313,47 @@ def from_model(model_name: str, api_key: str | None, alias: str | None, **kwargs
|
||||
"preprocess": lambda x: {"inputs": x},
|
||||
"postprocess": lambda r: r, # Handled as a special case in query_huggingface_api()
|
||||
},
|
||||
"document-question-answering": {
|
||||
# example model: impira/layoutlm-document-qa
|
||||
"inputs": [
|
||||
components.Image(type="filepath", label="Input Document"),
|
||||
components.Textbox(label="Question"),
|
||||
],
|
||||
"outputs": components.Label(label="Label"),
|
||||
"preprocess": lambda img, q: {
|
||||
"inputs": {
|
||||
"image": extract_base64_data(img), # Extract base64 data
|
||||
"question": q,
|
||||
}
|
||||
},
|
||||
"postprocess": lambda r: postprocess_label(
|
||||
{i["answer"]: i["score"] for i in r.json()}
|
||||
),
|
||||
},
|
||||
"visual-question-answering": {
|
||||
# example model: dandelin/vilt-b32-finetuned-vqa
|
||||
"inputs": [
|
||||
components.Image(type="filepath", label="Input Image"),
|
||||
components.Textbox(label="Question"),
|
||||
],
|
||||
"outputs": components.Label(label="Label"),
|
||||
"preprocess": lambda img, q: {
|
||||
"inputs": {
|
||||
"image": extract_base64_data(img),
|
||||
"question": q,
|
||||
}
|
||||
},
|
||||
"postprocess": lambda r: postprocess_label(
|
||||
{i["answer"]: i["score"] for i in r.json()}
|
||||
),
|
||||
},
|
||||
"image-to-text": {
|
||||
# example model: Salesforce/blip-image-captioning-base
|
||||
"inputs": components.Image(type="filepath", label="Input Image"),
|
||||
"outputs": components.Textbox(label="Generated Text"),
|
||||
"preprocess": to_binary,
|
||||
"postprocess": lambda r: r.json()[0]["generated_text"],
|
||||
},
|
||||
}
|
||||
|
||||
if p in ["tabular-classification", "tabular-regression"]:
|
||||
|
@ -91,7 +91,7 @@ class Interface(Blocks):
|
||||
Returns:
|
||||
a Gradio Interface object for the given model
|
||||
"""
|
||||
warnings.warn("gr.Intrerface.load() will be deprecated. Use gr.load() instead.")
|
||||
warnings.warn("gr.Interface.load() will be deprecated. Use gr.load() instead.")
|
||||
return external.load(
|
||||
name=name, src=src, hf_token=api_key, alias=alias, **kwargs
|
||||
)
|
||||
|
@ -159,6 +159,40 @@ def load_from_pipeline(pipeline: pipelines.base.Pipeline) -> Dict:
|
||||
r["labels"][i]: r["scores"][i] for i in range(len(r["labels"]))
|
||||
},
|
||||
}
|
||||
elif hasattr(transformers, "DocumentQuestionAnsweringPipeline") and isinstance(
|
||||
pipeline,
|
||||
pipelines.document_question_answering.DocumentQuestionAnsweringPipeline, # type: ignore
|
||||
):
|
||||
pipeline_info = {
|
||||
"inputs": [
|
||||
components.Image(type="filepath", label="Input Document"),
|
||||
components.Textbox(label="Question"),
|
||||
],
|
||||
"outputs": components.Label(label="Label"),
|
||||
"preprocess": lambda img, q: {"image": img, "question": q},
|
||||
"postprocess": lambda r: {i["answer"]: i["score"] for i in r},
|
||||
}
|
||||
elif hasattr(transformers, "VisualQuestionAnsweringPipeline") and isinstance(
|
||||
pipeline, pipelines.visual_question_answering.VisualQuestionAnsweringPipeline
|
||||
):
|
||||
pipeline_info = {
|
||||
"inputs": [
|
||||
components.Image(type="filepath", label="Input Image"),
|
||||
components.Textbox(label="Question"),
|
||||
],
|
||||
"outputs": components.Label(label="Score"),
|
||||
"preprocess": lambda img, q: {"image": img, "question": q},
|
||||
"postprocess": lambda r: {i["answer"]: i["score"] for i in r},
|
||||
}
|
||||
elif hasattr(transformers, "ImageToTextPipeline") and isinstance(
|
||||
pipeline, pipelines.image_to_text.ImageToTextPipeline # type: ignore
|
||||
):
|
||||
pipeline_info = {
|
||||
"inputs": components.Image(type="filepath", label="Input Image"),
|
||||
"outputs": components.Textbox(label="Text"),
|
||||
"preprocess": lambda i: {"images": i},
|
||||
"postprocess": lambda r: r[0]["generated_text"],
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Unsupported pipeline type: {type(pipeline)}")
|
||||
|
||||
|
@ -37,6 +37,11 @@ def to_binary(x: str | Dict) -> bytes:
|
||||
return base64.b64decode(base64str.split(",")[1])
|
||||
|
||||
|
||||
def extract_base64_data(x: str) -> str:
    """Return the raw base64 payload from a data-URL-style string.

    Parameters:
        x: a string of the form ``"data:<mime>;base64,<payload>"`` (or any
           string containing the ``"base64,"`` marker).
    Returns:
        Everything after the *first* ``"base64,"`` marker.
    Raises:
        IndexError: if the marker is not present in ``x``.
    """
    # maxsplit=1 ensures a payload that happens to contain the literal
    # substring "base64," is returned intact rather than truncated at
    # its second occurrence.
    return x.split("base64,", 1)[1]
|
||||
|
||||
|
||||
#########################
|
||||
# IMAGE PRE-PROCESSING
|
||||
#########################
|
||||
|
@ -220,6 +220,14 @@ class TestLoadInterface:
|
||||
except TooManyRequestsError:
|
||||
pass
|
||||
|
||||
def test_visual_question_answering(self):
|
||||
io = gr.load("models/dandelin/vilt-b32-finetuned-vqa")
|
||||
try:
|
||||
output = io("gradio/test_data/lion.jpg", "What is in the image?")
|
||||
assert isinstance(output, str) and output.endswith(".json")
|
||||
except TooManyRequestsError:
|
||||
pass
|
||||
|
||||
def test_image_to_text(self):
|
||||
io = gr.load("models/nlpconnect/vit-gpt2-image-captioning")
|
||||
try:
|
||||
|
Loading…
x
Reference in New Issue
Block a user