Refactoring external.py (#2579)

* pass token & add typing

* updated docstring

* reorg

* improve docs

* test for private space

* changelog

* changed asserts

* pipelines

* streamline

* formatting

* int

* dataclass

* formatting

* annotations

* formatting

* external test fixes

* formatting

* typing

* addressing review

* fix tests

* dataframedata

* refactoring

* removed unused imports

* added better error message for invalid spaces

* removed unnecessary import
Abubakar Abid 2022-11-01 09:14:15 -07:00 committed by GitHub
parent db85eb2f3a
commit e6cda90b69
7 changed files with 489 additions and 437 deletions

gradio/components.py

@@ -20,14 +20,6 @@ from pathlib import Path
from types import ModuleType
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
if TYPE_CHECKING:
from typing import TypedDict
class DataframeData(TypedDict):
headers: List[str]
data: List[List[str | int | bool]]
import matplotlib.figure
import numpy as np
import pandas as pd
@@ -59,6 +51,14 @@ from gradio.serializing import (
SimpleSerializable,
)
if TYPE_CHECKING:
from typing import TypedDict
class DataframeData(TypedDict):
headers: List[str]
data: List[List[str | int | bool]]
set_documentation_group("component")
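For context on the hunk above: the DataframeData TypedDict is only needed by type checkers, so it moves below the runtime imports while staying behind the TYPE_CHECKING guard, which is False at runtime. A minimal standalone sketch of the pattern (the summarize function is invented for illustration):

from __future__ import annotations

from typing import TYPE_CHECKING, List

if TYPE_CHECKING:
    # Only evaluated by static type checkers; skipped entirely at runtime.
    from typing import TypedDict

    class DataframeData(TypedDict):
        headers: List[str]
        data: List[List[str | int | bool]]

def summarize(df: DataframeData) -> str:
    # With `from __future__ import annotations` the annotation is lazy, so
    # DataframeData never has to exist at runtime.
    return f"{len(df['data'])} rows x {len(df['headers'])} columns"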

gradio/external.py

@@ -3,157 +3,74 @@ use the `gr.Blocks.load()` or `gr.Interface.load()` functions."""
from __future__ import annotations
import base64
import json
import math
import numbers
import operator
import re
import uuid
import warnings
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple
from typing import TYPE_CHECKING, Callable, Dict
import requests
import websockets
import yaml
from packaging import version
import gradio
from gradio import components, exceptions, utils
from gradio import components, utils
from gradio.exceptions import TooManyRequestsError
from gradio.external_utils import (
cols_to_rows,
encode_to_base64,
get_tabular_examples,
get_ws_fn,
postprocess_label,
rows_to_cols,
streamline_spaces_blocks,
streamline_spaces_interface,
use_websocket,
)
from gradio.processing_utils import to_binary
if TYPE_CHECKING:
from gradio.blocks import Blocks
from gradio.components import DataframeData
from gradio.interface import Interface
def load_blocks_from_repo(
name: str, src: str = None, api_key: str = None, alias: str = None, **kwargs
) -> Blocks:
"""Creates and returns a Blocks instance from several kinds of Hugging Face repos:
1) A model repo
2) A Spaces repo running Gradio 2.x
3) A Spaces repo running Gradio 3.x
"""
"""Creates and returns a Blocks instance from a Hugging Face model or Space repo."""
if src is None:
tokens = name.split(
"/"
) # Separate the source (e.g. "huggingface") from the repo name (e.g. "google/vit-base-patch16-224")
# Separate the repo type (e.g. "model") from repo name (e.g. "google/vit-base-patch16-224")
tokens = name.split("/")
assert (
len(tokens) > 1
), "Either `src` parameter must be provided, or `name` must be formatted as {src}/{repo name}"
src = tokens[0]
name = "/".join(tokens[1:])
factory_methods: Dict[str, Callable] = {
# for each repo type, we have a method that returns the Interface given the model name & optionally an api_key
"huggingface": from_model,
"models": from_model,
"spaces": from_spaces,
}
assert src.lower() in factory_methods, "parameter: src must be one of {}".format(
factory_methods.keys()
)
blocks: gradio.Blocks = factory_methods[src](name, api_key, alias, **kwargs)
return blocks
def get_tabular_examples(model_name: str) -> Dict[str, List[float]]:
readme = requests.get(f"https://huggingface.co/{model_name}/resolve/main/README.md")
if readme.status_code != 200:
warnings.warn(f"Cannot load examples from README for {model_name}", UserWarning)
example_data = {}
else:
yaml_regex = re.search(
"(?:^|[\r\n])---[\n\r]+([\\S\\s]*?)[\n\r]+---([\n\r]|$)", readme.text
)
example_yaml = next(yaml.safe_load_all(readme.text[: yaml_regex.span()[-1]]))
example_data = example_yaml.get("widget", {}).get("structuredData", {})
if not example_data:
raise ValueError(
f"No example data found in README.md of {model_name} - Cannot build gradio demo. "
"See the README.md here: https://huggingface.co/scikit-learn/tabular-playground/blob/main/README.md "
"for a reference on how to provide example data to your model."
)
# replace nan with string NaN for inference API
for data in example_data.values():
for i, val in enumerate(data):
if isinstance(val, numbers.Number) and math.isnan(val):
data[i] = "NaN"
return example_data
def cols_to_rows(
example_data: Dict[str, List[float]]
) -> Tuple[List[str], List[List[float]]]:
headers = list(example_data.keys())
n_rows = max(len(example_data[header] or []) for header in headers)
data = []
for row_index in range(n_rows):
row_data = []
for header in headers:
col = example_data[header] or []
if row_index >= len(col):
row_data.append("NaN")
else:
row_data.append(col[row_index])
data.append(row_data)
return headers, data
def rows_to_cols(
incoming_data: DataframeData,
) -> Dict[str, Dict[str, Dict[str, List[str]]]]:
data_column_wise = {}
for i, header in enumerate(incoming_data["headers"]):
data_column_wise[header] = [str(row[i]) for row in incoming_data["data"]]
return {"inputs": {"data": data_column_wise}}
def get_models_interface(model_name: str, api_key: str | None, alias: str, **kwargs):
def from_model(model_name: str, api_key: str | None, alias: str, **kwargs):
model_url = "https://huggingface.co/{}".format(model_name)
api_url = "https://api-inference.huggingface.co/models/{}".format(model_name)
print("Fetching model from: {}".format(model_url))
if api_key is not None:
headers = {"Authorization": f"Bearer {api_key}"}
else:
headers = {}
headers = {"Authorization": f"Bearer {api_key}"} if api_key is not None else {}
# Check if the model exists and, if so, get its pipeline tag
response = requests.request("GET", api_url, headers=headers)
assert response.status_code == 200, "Invalid model name or src"
p = response.json().get("pipeline_tag")
def postprocess_label(scores):
sorted_pred = sorted(scores.items(), key=operator.itemgetter(1), reverse=True)
return {
"label": sorted_pred[0][0],
"confidences": [
{"label": pred[0], "confidence": pred[1]} for pred in sorted_pred
],
}
def encode_to_base64(r: requests.Response) -> str:
# Handles the different ways HF API returns the prediction
base64_repr = base64.b64encode(r.content).decode("utf-8")
data_prefix = ";base64,"
# Case 1: base64 representation already includes data prefix
if data_prefix in base64_repr:
return base64_repr
else:
content_type = r.headers.get("content-type")
# Case 2: the data prefix is a key in the response
if content_type == "application/json":
try:
content_type = r.json()[0]["content-type"]
base64_repr = r.json()[0]["blob"]
except KeyError:
raise ValueError(
"Cannot determine content type returned" "by external API."
)
# Case 3: the data prefix is included in the response headers
else:
pass
new_base64 = "data:{};base64,".format(content_type) + base64_repr
return new_base64
pipelines = {
"audio-classification": {
# example model: ehcalabres/wav2vec2-lg-xlsr-en-speech-emotion-recognition
@@ -393,7 +310,7 @@ def get_models_interface(model_name: str, api_key: str | None, alias: str, **kwa
return interface
def get_spaces(space_name: str, api_key: str | None, alias: str, **kwargs) -> Blocks:
def from_spaces(space_name: str, api_key: str | None, alias: str, **kwargs) -> Blocks:
space_url = "https://huggingface.co/spaces/{}".format(space_name)
print("Fetching Space from: {}".format(space_url))
@@ -409,6 +326,11 @@ def get_spaces(space_name: str, api_key: str | None, alias: str, **kwargs) -> Bl
.get("host")
)
if iframe_url is None:
raise ValueError(
f"Could not find Space: {space_name}. If it is a private Space, please provide an access token in the `api_key` parameter."
)
r = requests.get(iframe_url, headers=headers)
result = re.search(
@@ -419,61 +341,17 @@ def get_spaces(space_name: str, api_key: str | None, alias: str, **kwargs) -> Bl
except AttributeError:
raise ValueError("Could not load the Space: {}".format(space_name))
if "allow_flagging" in config: # Create an Interface for Gradio 2.x Spaces
return get_spaces_interface(
return from_spaces_interface(
space_name, config, alias, api_key, iframe_url, **kwargs
)
else: # Create a Blocks for Gradio 3.x Spaces
return get_spaces_blocks(space_name, config, api_key, iframe_url)
return from_spaces_blocks(space_name, config, api_key, iframe_url)
async def get_pred_from_ws(
websocket: websockets.WebSocketClientProtocol, data: str, hash_data: str
) -> Dict[str, Any]:
completed = False
while not completed:
msg = await websocket.recv()
resp = json.loads(msg)
if resp["msg"] == "queue_full":
raise exceptions.Error("Queue is full! Please try again.")
if resp["msg"] == "send_hash":
await websocket.send(hash_data)
elif resp["msg"] == "send_data":
await websocket.send(data)
completed = resp["msg"] == "process_completed"
return resp["output"]
def get_ws_fn(ws_url, headers):
async def ws_fn(data, hash_data):
async with websockets.connect(
ws_url, open_timeout=10, extra_headers=headers
) as websocket:
return await get_pred_from_ws(websocket, data, hash_data)
return ws_fn
def use_websocket(config, dependency):
queue_enabled = config.get("enable_queue", False)
queue_uses_websocket = version.parse(
config.get("version", "2.0")
) >= version.Version("3.2")
dependency_uses_queue = dependency.get("queue", False) is not False
return queue_enabled and queue_uses_websocket and dependency_uses_queue
def get_spaces_blocks(
def from_spaces_blocks(
model_name: str, config: Dict, api_key: str | None, iframe_url: str
) -> Blocks:
def streamline_config(config: dict) -> dict:
"""Streamlines the blocks config dictionary to fix components that don't render correctly."""
# TODO(abidlabs): Need a better way to fix relative paths in dataset component
for c, component in enumerate(config["components"]):
if component["type"] == "dataset":
config["components"][c]["props"]["visible"] = False
return config
config = streamline_config(config)
config = streamline_spaces_blocks(config)
api_url = "{}/api/predict/".format(iframe_url)
headers = {"Content-Type": "application/json"}
if api_key is not None:
@@ -523,7 +401,7 @@ def get_spaces_blocks(
return gradio.Blocks.from_config(config, fns)
def get_spaces_interface(
def from_spaces_interface(
model_name: str,
config: Dict,
alias: str,
@@ -531,29 +409,8 @@ def get_spaces_interface(
iframe_url: str,
**kwargs,
) -> Interface:
def streamline_config(config: Dict) -> Dict:
"""Streamlines the interface config dictionary to remove unnecessary keys."""
config["inputs"] = [
components.get_component_instance(component)
for component in config["input_components"]
]
config["outputs"] = [
components.get_component_instance(component)
for component in config["output_components"]
]
parameters = {
"article",
"description",
"flagging_options",
"inputs",
"outputs",
"theme",
"title",
}
config = {k: config[k] for k in parameters}
return config
config = streamline_config(config)
config = streamline_spaces_interface(config)
api_url = "{}/api/predict/".format(iframe_url)
headers = {"Content-Type": "application/json"}
if api_key is not None:
@@ -589,189 +446,3 @@ def get_spaces_interface(
kwargs["_api_mode"] = True
interface = gradio.Interface(**kwargs)
return interface
factory_methods: Dict[str, Callable] = {
# for each repo type, we have a method that returns the Interface given the model name & optionally an api_key
"huggingface": get_models_interface,
"models": get_models_interface,
"spaces": get_spaces,
}
def load_from_pipeline(pipeline) -> Dict:
"""
Gets the appropriate Interface kwargs for a given Hugging Face transformers.Pipeline.
pipeline (transformers.Pipeline): the transformers.Pipeline from which to create an interface
Returns:
(dict): a dictionary of kwargs that can be used to construct an Interface object
"""
try:
import transformers
except ImportError:
raise ImportError(
"transformers not installed. Please try `pip install transformers`"
)
if not isinstance(pipeline, transformers.Pipeline):
raise ValueError("pipeline must be a transformers.Pipeline")
# Handle the different pipelines. The hasattr() checks make sure the pipeline
# class exists in the version of the transformers library that the user has installed.
if hasattr(transformers, "AudioClassificationPipeline") and isinstance(
pipeline, transformers.AudioClassificationPipeline
):
pipeline_info = {
"inputs": components.Audio(
source="microphone", type="filepath", label="Input"
),
"outputs": components.Label(label="Class"),
"preprocess": lambda i: {"inputs": i},
"postprocess": lambda r: {i["label"].split(", ")[0]: i["score"] for i in r},
}
elif hasattr(transformers, "AutomaticSpeechRecognitionPipeline") and isinstance(
pipeline, transformers.AutomaticSpeechRecognitionPipeline
):
pipeline_info = {
"inputs": components.Audio(
source="microphone", type="filepath", label="Input"
),
"outputs": components.Textbox(label="Output"),
"preprocess": lambda i: {"inputs": i},
"postprocess": lambda r: r["text"],
}
elif hasattr(transformers, "FeatureExtractionPipeline") and isinstance(
pipeline, transformers.FeatureExtractionPipeline
):
pipeline_info = {
"inputs": components.Textbox(label="Input"),
"outputs": components.Dataframe(label="Output"),
"preprocess": lambda x: {"inputs": x},
"postprocess": lambda r: r[0],
}
elif hasattr(transformers, "FillMaskPipeline") and isinstance(
pipeline, transformers.FillMaskPipeline
):
pipeline_info = {
"inputs": components.Textbox(label="Input"),
"outputs": components.Label(label="Classification"),
"preprocess": lambda x: {"inputs": x},
"postprocess": lambda r: {i["token_str"]: i["score"] for i in r},
}
elif hasattr(transformers, "ImageClassificationPipeline") and isinstance(
pipeline, transformers.ImageClassificationPipeline
):
pipeline_info = {
"inputs": components.Image(type="filepath", label="Input Image"),
"outputs": components.Label(type="confidences", label="Classification"),
"preprocess": lambda i: {"images": i},
"postprocess": lambda r: {i["label"].split(", ")[0]: i["score"] for i in r},
}
elif hasattr(transformers, "QuestionAnsweringPipeline") and isinstance(
pipeline, transformers.QuestionAnsweringPipeline
):
pipeline_info = {
"inputs": [
components.Textbox(lines=7, label="Context"),
components.Textbox(label="Question"),
],
"outputs": [
components.Textbox(label="Answer"),
components.Label(label="Score"),
],
"preprocess": lambda c, q: {"context": c, "question": q},
"postprocess": lambda r: (r["answer"], r["score"]),
}
elif hasattr(transformers, "SummarizationPipeline") and isinstance(
pipeline, transformers.SummarizationPipeline
):
pipeline_info = {
"inputs": components.Textbox(lines=7, label="Input"),
"outputs": components.Textbox(label="Summary"),
"preprocess": lambda x: {"inputs": x},
"postprocess": lambda r: r[0]["summary_text"],
}
elif hasattr(transformers, "TextClassificationPipeline") and isinstance(
pipeline, transformers.TextClassificationPipeline
):
pipeline_info = {
"inputs": components.Textbox(label="Input"),
"outputs": components.Label(label="Classification"),
"preprocess": lambda x: [x],
"postprocess": lambda r: {i["label"].split(", ")[0]: i["score"] for i in r},
}
elif hasattr(transformers, "TextGenerationPipeline") and isinstance(
pipeline, transformers.TextGenerationPipeline
):
pipeline_info = {
"inputs": components.Textbox(label="Input"),
"outputs": components.Textbox(label="Output"),
"preprocess": lambda x: {"text_inputs": x},
"postprocess": lambda r: r[0]["generated_text"],
}
elif hasattr(transformers, "TranslationPipeline") and isinstance(
pipeline, transformers.TranslationPipeline
):
pipeline_info = {
"inputs": components.Textbox(label="Input"),
"outputs": components.Textbox(label="Translation"),
"preprocess": lambda x: [x],
"postprocess": lambda r: r[0]["translation_text"],
}
elif hasattr(transformers, "Text2TextGenerationPipeline") and isinstance(
pipeline, transformers.Text2TextGenerationPipeline
):
pipeline_info = {
"inputs": components.Textbox(label="Input"),
"outputs": components.Textbox(label="Generated Text"),
"preprocess": lambda x: [x],
"postprocess": lambda r: r[0]["generated_text"],
}
elif hasattr(transformers, "ZeroShotClassificationPipeline") and isinstance(
pipeline, transformers.ZeroShotClassificationPipeline
):
pipeline_info = {
"inputs": [
components.Textbox(label="Input"),
components.Textbox(label="Possible class names (" "comma-separated)"),
components.Checkbox(label="Allow multiple true classes"),
],
"outputs": components.Label(label="Classification"),
"preprocess": lambda i, c, m: {
"sequences": i,
"candidate_labels": c,
"multi_label": m,
},
"postprocess": lambda r: {
r["labels"][i]: r["scores"][i] for i in range(len(r["labels"]))
},
}
else:
raise ValueError("Unsupported pipeline type: {}".format(type(pipeline)))
# define the function that will be called by the Interface
def fn(*params):
data = pipeline_info["preprocess"](*params)
# special cases that need to be handled differently
if isinstance(
pipeline,
(
transformers.TextClassificationPipeline,
transformers.Text2TextGenerationPipeline,
transformers.TranslationPipeline,
),
):
data = pipeline(*data)
else:
data = pipeline(**data)
output = pipeline_info["postprocess"](data)
return output
interface_info = pipeline_info.copy()
interface_info["fn"] = fn
del interface_info["preprocess"]
del interface_info["postprocess"]
# define the title/description of the Interface
interface_info["title"] = pipeline.model.__class__.__name__
return interface_info
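The removed block above is load_from_pipeline plus the factory_methods dict: the former moves wholesale into the new gradio/pipelines.py further down, and the latter moves up into load_blocks_from_repo. A hedged sketch of how the loading entry points are exercised after this refactor (both repo names are taken from the tests in this commit):

import gradio as gr

# "models/..." routes through from_model and calls the hosted Inference API.
demo = gr.Interface.load("models/distilbert-base-uncased-finetuned-sst-2-english")

# "spaces/..." routes through from_spaces; for a private Space, this commit
# now raises a clear error unless an access token is passed via api_key.
# demo = gr.Interface.load("spaces/abidlabs/english_to_spanish", api_key="hf_...")

demo.launch()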

gradio/external_utils.py Normal file (189 additions)

@@ -0,0 +1,189 @@
"""Utility function for gradio/external.py"""
import base64
import json
import math
import numbers
import operator
import re
import warnings
from typing import Any, Dict, List, Tuple
import requests
import websockets
import yaml
from packaging import version
from gradio import components, exceptions
##################
# Helper functions for processing tabular data
##################
def get_tabular_examples(model_name: str) -> Dict[str, List[float]]:
readme = requests.get(f"https://huggingface.co/{model_name}/resolve/main/README.md")
if readme.status_code != 200:
warnings.warn(f"Cannot load examples from README for {model_name}", UserWarning)
example_data = {}
else:
yaml_regex = re.search(
"(?:^|[\r\n])---[\n\r]+([\\S\\s]*?)[\n\r]+---([\n\r]|$)", readme.text
)
example_yaml = next(yaml.safe_load_all(readme.text[: yaml_regex.span()[-1]]))
example_data = example_yaml.get("widget", {}).get("structuredData", {})
if not example_data:
raise ValueError(
f"No example data found in README.md of {model_name} - Cannot build gradio demo. "
"See the README.md here: https://huggingface.co/scikit-learn/tabular-playground/blob/main/README.md "
"for a reference on how to provide example data to your model."
)
# replace nan with string NaN for inference API
for data in example_data.values():
for i, val in enumerate(data):
if isinstance(val, numbers.Number) and math.isnan(val):
data[i] = "NaN"
return example_data
def cols_to_rows(
example_data: Dict[str, List[float]]
) -> Tuple[List[str], List[List[float]]]:
headers = list(example_data.keys())
n_rows = max(len(example_data[header] or []) for header in headers)
data = []
for row_index in range(n_rows):
row_data = []
for header in headers:
col = example_data[header] or []
if row_index >= len(col):
row_data.append("NaN")
else:
row_data.append(col[row_index])
data.append(row_data)
return headers, data
def rows_to_cols(incoming_data: Dict) -> Dict[str, Dict[str, Dict[str, List[str]]]]:
data_column_wise = {}
for i, header in enumerate(incoming_data["headers"]):
data_column_wise[header] = [str(row[i]) for row in incoming_data["data"]]
return {"inputs": {"data": data_column_wise}}
##################
# Helper functions for processing other kinds of data
##################
def postprocess_label(scores):
sorted_pred = sorted(scores.items(), key=operator.itemgetter(1), reverse=True)
return {
"label": sorted_pred[0][0],
"confidences": [
{"label": pred[0], "confidence": pred[1]} for pred in sorted_pred
],
}
def encode_to_base64(r: requests.Response) -> str:
# Handles the different ways HF API returns the prediction
base64_repr = base64.b64encode(r.content).decode("utf-8")
data_prefix = ";base64,"
# Case 1: base64 representation already includes data prefix
if data_prefix in base64_repr:
return base64_repr
else:
content_type = r.headers.get("content-type")
# Case 2: the data prefix is a key in the response
if content_type == "application/json":
try:
content_type = r.json()[0]["content-type"]
base64_repr = r.json()[0]["blob"]
except KeyError:
raise ValueError(
"Cannot determine content type returned" "by external API."
)
# Case 3: the data prefix is included in the response headers
else:
pass
new_base64 = "data:{};base64,".format(content_type) + base64_repr
return new_base64
##################
# Helper functions for connecting to websockets
##################
async def get_pred_from_ws(
websocket: websockets.WebSocketClientProtocol, data: str, hash_data: str
) -> Dict[str, Any]:
completed = False
while not completed:
msg = await websocket.recv()
resp = json.loads(msg)
if resp["msg"] == "queue_full":
raise exceptions.Error("Queue is full! Please try again.")
if resp["msg"] == "send_hash":
await websocket.send(hash_data)
elif resp["msg"] == "send_data":
await websocket.send(data)
completed = resp["msg"] == "process_completed"
return resp["output"]
def get_ws_fn(ws_url, headers):
async def ws_fn(data, hash_data):
async with websockets.connect(
ws_url, open_timeout=10, extra_headers=headers
) as websocket:
return await get_pred_from_ws(websocket, data, hash_data)
return ws_fn
def use_websocket(config, dependency):
queue_enabled = config.get("enable_queue", False)
queue_uses_websocket = version.parse(
config.get("version", "2.0")
) >= version.Version("3.2")
dependency_uses_queue = dependency.get("queue", False) is not False
return queue_enabled and queue_uses_websocket and dependency_uses_queue
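use_websocket gates the websocket transport on three facts read from the Space's config: the queue is enabled, the Space runs gradio >= 3.2, and the dependency itself is queued. A quick illustration with invented config values:

from gradio.external_utils import use_websocket

dependency = {"queue": None}  # None is not False, so the event inherits the queue

use_websocket({"enable_queue": True, "version": "3.4"}, dependency)   # True
use_websocket({"enable_queue": True, "version": "3.1"}, dependency)   # False: < 3.2
use_websocket({"enable_queue": False, "version": "3.4"}, dependency)  # False: queue off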
##################
# Helper functions for cleaning up Interfaces/Blocks loaded from HF Spaces
##################
def streamline_spaces_interface(config: Dict) -> Dict:
"""Streamlines the interface config dictionary to remove unnecessary keys."""
config["inputs"] = [
components.get_component_instance(component)
for component in config["input_components"]
]
config["outputs"] = [
components.get_component_instance(component)
for component in config["output_components"]
]
parameters = {
"article",
"description",
"flagging_options",
"inputs",
"outputs",
"theme",
"title",
}
config = {k: config[k] for k in parameters}
return config
def streamline_spaces_blocks(config: dict) -> dict:
"""Streamlines the blocks config dictionary to fix components that don't render correctly."""
# TODO(abidlabs): Need a better way to fix relative paths in dataset component
for c, component in enumerate(config["components"]):
if component["type"] == "dataset":
config["components"][c]["props"]["visible"] = False
return config
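To illustrate the tabular helpers above: cols_to_rows turns the column-wise README example data into the row-wise shape a Dataframe component expects, padding ragged columns with "NaN", while rows_to_cols converts row-wise data back into the payload the Inference API wants. A small worked example with made-up values:

from gradio.external_utils import cols_to_rows, rows_to_cols

example_data = {"age": [22, 38], "fare": [7.25, 71.28, 8.05]}  # ragged columns

headers, data = cols_to_rows(example_data)
# headers == ["age", "fare"]
# data == [[22, 7.25], [38, 71.28], ["NaN", 8.05]]  -- short column padded

payload = rows_to_cols({"headers": headers, "data": data})
# payload == {"inputs": {"data": {"age": ["22", "38", "NaN"],
#                                 "fare": ["7.25", "71.28", "8.05"]}}}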

gradio/interface.py

@@ -31,9 +31,9 @@ from gradio.components import (
)
from gradio.documentation import document, set_documentation_group
from gradio.events import Changeable, Streamable
from gradio.external import load_from_pipeline # type: ignore
from gradio.flagging import CSVLogger, FlaggingCallback # type: ignore
from gradio.layouts import Column, Row, TabItem, Tabs
from gradio.pipelines import load_from_pipeline # type: ignore
set_documentation_group("interface")

gradio/pipelines.py Normal file (189 additions)

@@ -0,0 +1,189 @@
"""This module should not be used directly as its API is subject to change. Instead,
please use the `gr.Interface.from_pipeline()` function."""
from __future__ import annotations
from typing import TYPE_CHECKING, Dict
from gradio import components
if TYPE_CHECKING: # Only import for type checking (is False at runtime).
import transformers
def load_from_pipeline(pipeline: transformers.Pipeline) -> Dict:
"""
Gets the appropriate Interface kwargs for a given Hugging Face transformers.Pipeline.
pipeline (transformers.Pipeline): the transformers.Pipeline from which to create an interface
Returns:
(dict): a dictionary of kwargs that can be used to construct an Interface object
"""
try:
import transformers
except ImportError:
raise ImportError(
"transformers not installed. Please try `pip install transformers`"
)
if not isinstance(pipeline, transformers.Pipeline):
raise ValueError("pipeline must be a transformers.Pipeline")
# Handle the different pipelines. The hasattr() checks make sure the pipeline
# class exists in the version of the transformers library that the user has installed.
if hasattr(transformers, "AudioClassificationPipeline") and isinstance(
pipeline, transformers.AudioClassificationPipeline
):
pipeline_info = {
"inputs": components.Audio(
source="microphone", type="filepath", label="Input"
),
"outputs": components.Label(label="Class"),
"preprocess": lambda i: {"inputs": i},
"postprocess": lambda r: {i["label"].split(", ")[0]: i["score"] for i in r},
}
elif hasattr(transformers, "AutomaticSpeechRecognitionPipeline") and isinstance(
pipeline, transformers.AutomaticSpeechRecognitionPipeline
):
pipeline_info = {
"inputs": components.Audio(
source="microphone", type="filepath", label="Input"
),
"outputs": components.Textbox(label="Output"),
"preprocess": lambda i: {"inputs": i},
"postprocess": lambda r: r["text"],
}
elif hasattr(transformers, "FeatureExtractionPipeline") and isinstance(
pipeline, transformers.FeatureExtractionPipeline
):
pipeline_info = {
"inputs": components.Textbox(label="Input"),
"outputs": components.Dataframe(label="Output"),
"preprocess": lambda x: {"inputs": x},
"postprocess": lambda r: r[0],
}
elif hasattr(transformers, "FillMaskPipeline") and isinstance(
pipeline, transformers.FillMaskPipeline
):
pipeline_info = {
"inputs": components.Textbox(label="Input"),
"outputs": components.Label(label="Classification"),
"preprocess": lambda x: {"inputs": x},
"postprocess": lambda r: {i["token_str"]: i["score"] for i in r},
}
elif hasattr(transformers, "ImageClassificationPipeline") and isinstance(
pipeline, transformers.ImageClassificationPipeline
):
pipeline_info = {
"inputs": components.Image(type="filepath", label="Input Image"),
"outputs": components.Label(type="confidences", label="Classification"),
"preprocess": lambda i: {"images": i},
"postprocess": lambda r: {i["label"].split(", ")[0]: i["score"] for i in r},
}
elif hasattr(transformers, "QuestionAnsweringPipeline") and isinstance(
pipeline, transformers.QuestionAnsweringPipeline
):
pipeline_info = {
"inputs": [
components.Textbox(lines=7, label="Context"),
components.Textbox(label="Question"),
],
"outputs": [
components.Textbox(label="Answer"),
components.Label(label="Score"),
],
"preprocess": lambda c, q: {"context": c, "question": q},
"postprocess": lambda r: (r["answer"], r["score"]),
}
elif hasattr(transformers, "SummarizationPipeline") and isinstance(
pipeline, transformers.SummarizationPipeline
):
pipeline_info = {
"inputs": components.Textbox(lines=7, label="Input"),
"outputs": components.Textbox(label="Summary"),
"preprocess": lambda x: {"inputs": x},
"postprocess": lambda r: r[0]["summary_text"],
}
elif hasattr(transformers, "TextClassificationPipeline") and isinstance(
pipeline, transformers.TextClassificationPipeline
):
pipeline_info = {
"inputs": components.Textbox(label="Input"),
"outputs": components.Label(label="Classification"),
"preprocess": lambda x: [x],
"postprocess": lambda r: {i["label"].split(", ")[0]: i["score"] for i in r},
}
elif hasattr(transformers, "TextGenerationPipeline") and isinstance(
pipeline, transformers.TextGenerationPipeline
):
pipeline_info = {
"inputs": components.Textbox(label="Input"),
"outputs": components.Textbox(label="Output"),
"preprocess": lambda x: {"text_inputs": x},
"postprocess": lambda r: r[0]["generated_text"],
}
elif hasattr(transformers, "TranslationPipeline") and isinstance(
pipeline, transformers.TranslationPipeline
):
pipeline_info = {
"inputs": components.Textbox(label="Input"),
"outputs": components.Textbox(label="Translation"),
"preprocess": lambda x: [x],
"postprocess": lambda r: r[0]["translation_text"],
}
elif hasattr(transformers, "Text2TextGenerationPipeline") and isinstance(
pipeline, transformers.Text2TextGenerationPipeline
):
pipeline_info = {
"inputs": components.Textbox(label="Input"),
"outputs": components.Textbox(label="Generated Text"),
"preprocess": lambda x: [x],
"postprocess": lambda r: r[0]["generated_text"],
}
elif hasattr(transformers, "ZeroShotClassificationPipeline") and isinstance(
pipeline, transformers.ZeroShotClassificationPipeline
):
pipeline_info = {
"inputs": [
components.Textbox(label="Input"),
components.Textbox(label="Possible class names (" "comma-separated)"),
components.Checkbox(label="Allow multiple true classes"),
],
"outputs": components.Label(label="Classification"),
"preprocess": lambda i, c, m: {
"sequences": i,
"candidate_labels": c,
"multi_label": m,
},
"postprocess": lambda r: {
r["labels"][i]: r["scores"][i] for i in range(len(r["labels"]))
},
}
else:
raise ValueError("Unsupported pipeline type: {}".format(type(pipeline)))
# define the function that will be called by the Interface
def fn(*params):
data = pipeline_info["preprocess"](*params)
# special cases that need to be handled differently
if isinstance(
pipeline,
(
transformers.TextClassificationPipeline,
transformers.Text2TextGenerationPipeline,
transformers.TranslationPipeline,
),
):
data = pipeline(*data)
else:
data = pipeline(**data)
output = pipeline_info["postprocess"](data)
return output
interface_info = pipeline_info.copy()
interface_info["fn"] = fn
del interface_info["preprocess"]
del interface_info["postprocess"]
# define the title/description of the Interface
interface_info["title"] = pipeline.model.__class__.__name__
return interface_info
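As its docstring says, the new module is reached through gr.Interface.from_pipeline(), which takes the returned kwargs dict (fn, inputs, outputs, title) and builds the Interface. A minimal usage sketch, assuming transformers is installed; the checkpoint is the same tiny model the new test below uses:

import transformers
import gradio as gr

pipe = transformers.pipeline(model="sshleifer/bart-tiny-random")
demo = gr.Interface.from_pipeline(pipe)  # wraps load_from_pipeline internally
demo.launch()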

test/test_external.py

@@ -7,7 +7,6 @@ import unittest
from unittest.mock import MagicMock, patch
import pytest
import transformers
import gradio
import gradio as gr
@@ -15,10 +14,10 @@ from gradio import utils
from gradio.external import (
TooManyRequestsError,
cols_to_rows,
get_pred_from_ws,
get_tabular_examples,
use_websocket,
)
from gradio.external_utils import get_pred_from_ws
"""
WARNING: These tests have an external dependency: namely that Hugging Face's
@@ -43,52 +42,52 @@ class TestLoadInterface(unittest.TestCase):
src="models",
alias=model_type,
)
self.assertEqual(interface.__name__, model_type)
self.assertIsInstance(interface.input_components[0], gr.components.Audio)
self.assertIsInstance(interface.output_components[0], gr.components.Audio)
assert interface.__name__ == model_type
assert isinstance(interface.input_components[0], gr.components.Audio)
assert isinstance(interface.output_components[0], gr.components.Audio)
def test_question_answering(self):
model_type = "image-classification"
interface = gr.Blocks.load(
name="lysandre/tiny-vit-random", src="models", alias=model_type
)
self.assertEqual(interface.__name__, model_type)
self.assertIsInstance(interface.input_components[0], gr.components.Image)
self.assertIsInstance(interface.output_components[0], gr.components.Label)
assert interface.__name__ == model_type
assert isinstance(interface.input_components[0], gr.components.Image)
assert isinstance(interface.output_components[0], gr.components.Label)
def test_text_generation(self):
model_type = "text_generation"
interface = gr.Interface.load("models/gpt2", alias=model_type)
self.assertEqual(interface.__name__, model_type)
self.assertIsInstance(interface.input_components[0], gr.components.Textbox)
self.assertIsInstance(interface.output_components[0], gr.components.Textbox)
assert interface.__name__ == model_type
assert isinstance(interface.input_components[0], gr.components.Textbox)
assert isinstance(interface.output_components[0], gr.components.Textbox)
def test_summarization(self):
model_type = "summarization"
interface = gr.Interface.load(
"models/facebook/bart-large-cnn", api_key=None, alias=model_type
)
self.assertEqual(interface.__name__, model_type)
self.assertIsInstance(interface.input_components[0], gr.components.Textbox)
self.assertIsInstance(interface.output_components[0], gr.components.Textbox)
assert interface.__name__ == model_type
assert isinstance(interface.input_components[0], gr.components.Textbox)
assert isinstance(interface.output_components[0], gr.components.Textbox)
def test_translation(self):
model_type = "translation"
interface = gr.Interface.load(
"models/facebook/bart-large-cnn", api_key=None, alias=model_type
)
self.assertEqual(interface.__name__, model_type)
self.assertIsInstance(interface.input_components[0], gr.components.Textbox)
self.assertIsInstance(interface.output_components[0], gr.components.Textbox)
assert interface.__name__ == model_type
assert isinstance(interface.input_components[0], gr.components.Textbox)
assert isinstance(interface.output_components[0], gr.components.Textbox)
def test_text2text_generation(self):
model_type = "text2text-generation"
interface = gr.Interface.load(
"models/sshleifer/tiny-mbart", api_key=None, alias=model_type
)
self.assertEqual(interface.__name__, model_type)
self.assertIsInstance(interface.input_components[0], gr.components.Textbox)
self.assertIsInstance(interface.output_components[0], gr.components.Textbox)
assert interface.__name__ == model_type
assert isinstance(interface.input_components[0], gr.components.Textbox)
assert isinstance(interface.output_components[0], gr.components.Textbox)
def test_text_classification(self):
model_type = "text-classification"
@@ -97,47 +96,47 @@ class TestLoadInterface(unittest.TestCase):
api_key=None,
alias=model_type,
)
self.assertEqual(interface.__name__, model_type)
self.assertIsInstance(interface.input_components[0], gr.components.Textbox)
self.assertIsInstance(interface.output_components[0], gr.components.Label)
assert interface.__name__ == model_type
assert isinstance(interface.input_components[0], gr.components.Textbox)
assert isinstance(interface.output_components[0], gr.components.Label)
def test_fill_mask(self):
model_type = "fill-mask"
interface = gr.Interface.load(
"models/bert-base-uncased", api_key=None, alias=model_type
)
self.assertEqual(interface.__name__, model_type)
self.assertIsInstance(interface.input_components[0], gr.components.Textbox)
self.assertIsInstance(interface.output_components[0], gr.components.Label)
assert interface.__name__ == model_type
assert isinstance(interface.input_components[0], gr.components.Textbox)
assert isinstance(interface.output_components[0], gr.components.Label)
def test_zero_shot_classification(self):
model_type = "zero-shot-classification"
interface = gr.Interface.load(
"models/facebook/bart-large-mnli", api_key=None, alias=model_type
)
self.assertEqual(interface.__name__, model_type)
self.assertIsInstance(interface.input_components[0], gr.components.Textbox)
self.assertIsInstance(interface.input_components[1], gr.components.Textbox)
self.assertIsInstance(interface.input_components[2], gr.components.Checkbox)
self.assertIsInstance(interface.output_components[0], gr.components.Label)
assert interface.__name__ == model_type
assert isinstance(interface.input_components[0], gr.components.Textbox)
assert isinstance(interface.input_components[1], gr.components.Textbox)
assert isinstance(interface.input_components[2], gr.components.Checkbox)
assert isinstance(interface.output_components[0], gr.components.Label)
def test_automatic_speech_recognition(self):
model_type = "automatic-speech-recognition"
interface = gr.Interface.load(
"models/facebook/wav2vec2-base-960h", api_key=None, alias=model_type
)
self.assertEqual(interface.__name__, model_type)
self.assertIsInstance(interface.input_components[0], gr.components.Audio)
self.assertIsInstance(interface.output_components[0], gr.components.Textbox)
assert interface.__name__ == model_type
assert isinstance(interface.input_components[0], gr.components.Audio)
assert isinstance(interface.output_components[0], gr.components.Textbox)
def test_image_classification(self):
model_type = "image-classification"
interface = gr.Interface.load(
"models/google/vit-base-patch16-224", api_key=None, alias=model_type
)
self.assertEqual(interface.__name__, model_type)
self.assertIsInstance(interface.input_components[0], gr.components.Image)
self.assertIsInstance(interface.output_components[0], gr.components.Label)
assert interface.__name__ == model_type
assert isinstance(interface.input_components[0], gr.components.Image)
assert isinstance(interface.output_components[0], gr.components.Label)
def test_feature_extraction(self):
model_type = "feature-extraction"
@@ -146,9 +145,9 @@ class TestLoadInterface(unittest.TestCase):
api_key=None,
alias=model_type,
)
self.assertEqual(interface.__name__, model_type)
self.assertIsInstance(interface.input_components[0], gr.components.Textbox)
self.assertIsInstance(interface.output_components[0], gr.components.Dataframe)
assert interface.__name__ == model_type
assert isinstance(interface.input_components[0], gr.components.Textbox)
assert isinstance(interface.output_components[0], gr.components.Dataframe)
def test_sentence_similarity(self):
model_type = "text-to-speech"
@@ -157,9 +156,9 @@ class TestLoadInterface(unittest.TestCase):
api_key=None,
alias=model_type,
)
self.assertEqual(interface.__name__, model_type)
self.assertIsInstance(interface.input_components[0], gr.components.Textbox)
self.assertIsInstance(interface.output_components[0], gr.components.Audio)
assert interface.__name__ == model_type
assert isinstance(interface.input_components[0], gr.components.Textbox)
assert isinstance(interface.output_components[0], gr.components.Audio)
def test_text_to_speech(self):
model_type = "text-to-speech"
@@ -168,23 +167,23 @@ class TestLoadInterface(unittest.TestCase):
api_key=None,
alias=model_type,
)
self.assertEqual(interface.__name__, model_type)
self.assertIsInstance(interface.input_components[0], gr.components.Textbox)
self.assertIsInstance(interface.output_components[0], gr.components.Audio)
assert interface.__name__ == model_type
assert isinstance(interface.input_components[0], gr.components.Textbox)
assert isinstance(interface.output_components[0], gr.components.Audio)
def test_text_to_image(self):
model_type = "text-to-image"
interface = gr.Interface.load(
"models/osanseviero/BigGAN-deep-128", api_key=None, alias=model_type
)
self.assertEqual(interface.__name__, model_type)
self.assertIsInstance(interface.input_components[0], gr.components.Textbox)
self.assertIsInstance(interface.output_components[0], gr.components.Image)
assert interface.__name__ == model_type
assert isinstance(interface.input_components[0], gr.components.Textbox)
assert isinstance(interface.output_components[0], gr.components.Image)
def test_english_to_spanish(self):
interface = gr.Interface.load("spaces/abidlabs/english_to_spanish")
self.assertIsInstance(interface.input_components[0], gr.components.Textbox)
self.assertIsInstance(interface.output_components[0], gr.components.Textbox)
assert isinstance(interface.input_components[0], gr.components.Textbox)
assert isinstance(interface.output_components[0], gr.components.Textbox)
def test_sentiment_model(self):
io = gr.Interface.load("models/distilbert-base-uncased-finetuned-sst-2-english")
@@ -206,7 +205,7 @@ class TestLoadInterface(unittest.TestCase):
io = gr.Blocks.load(name="models/t5-base")
try:
output = io("My name is Sarah and I live in London")
self.assertEqual(output, "Mein Name ist Sarah und ich lebe in London")
assert output == "Mein Name ist Sarah und ich lebe in London"
except TooManyRequestsError:
pass
@@ -241,18 +240,11 @@ class TestLoadInterface(unittest.TestCase):
)
try:
output = io("abc")
self.assertEqual(output, "abc")
assert output == "abc"
except TooManyRequestsError:
pass
class TestLoadFromPipeline(unittest.TestCase):
def test_text_to_text_model_from_pipeline(self):
pipe = transformers.pipeline(model="sshleifer/bart-tiny-random")
output = pipe("My name is Sylvain and I work at Hugging Face in Brooklyn")
self.assertIsNotNone(output)
class TestLoadInterfaceWithExamples:
def test_interface_load_examples(self, tmp_path):
test_file_dir = pathlib.Path(pathlib.Path(__file__).parent, "test_files")
@@ -419,7 +411,7 @@ async def test_get_pred_from_ws_raises_if_queue_full():
def test_respect_queue_when_load_from_config():
with unittest.mock.patch("websockets.connect"):
with unittest.mock.patch(
"gradio.external.get_pred_from_ws", return_value={"data": ["foo"]}
"gradio.external_utils.get_pred_from_ws", return_value={"data": ["foo"]}
):
interface = gr.Interface.load("spaces/freddyaboulton/saymyname")
assert interface("bob") == "foo"
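Note the changed patch target in the hunk above: ws_fn now looks up get_pred_from_ws inside gradio.external_utils, so the mock must patch the name in the module where it is resolved at call time, not where it used to live. A minimal sketch of the rule:

from unittest import mock

# Patch the attribute on the module that performs the lookup; after the
# refactor that module is gradio.external_utils.
with mock.patch(
    "gradio.external_utils.get_pred_from_ws", return_value={"data": ["foo"]}
):
    ...  # code under test that ends up awaiting get_pred_from_ws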

test/test_pipelines.py Normal file (11 additions)

@@ -0,0 +1,11 @@
import transformers
import gradio as gr
class TestLoadFromPipeline:
def test_text_to_text_model_from_pipeline(self):
pipe = transformers.pipeline(model="sshleifer/bart-tiny-random")
io = gr.Interface.from_pipeline(pipe)
output = io("My name is Sylvain and I work at Hugging Face in Brooklyn")
assert isinstance(output, str)