mirror of
https://github.com/gradio-app/gradio.git
synced 2025-03-31 12:20:26 +08:00
Refactoring external.py
(#2579)
* pass token & add typing * updated docstring * reorg * improve docs * test for private space * changelog * changed asserts * pipelines * streamline * formatting * int * dataclass * formatting * annotations * formatting * external test fixes * formatting * typing * addressing review * fix tests * dataframedata * refactoring * removed unused imports * added better error message for invalid spaces * removed unnecessary import
This commit is contained in:
parent
db85eb2f3a
commit
e6cda90b69
@ -20,14 +20,6 @@ from pathlib import Path
|
||||
from types import ModuleType
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import TypedDict
|
||||
|
||||
class DataframeData(TypedDict):
|
||||
headers: List[str]
|
||||
data: List[List[str | int | bool]]
|
||||
|
||||
|
||||
import matplotlib.figure
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
@ -59,6 +51,14 @@ from gradio.serializing import (
|
||||
SimpleSerializable,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import TypedDict
|
||||
|
||||
class DataframeData(TypedDict):
|
||||
headers: List[str]
|
||||
data: List[List[str | int | bool]]
|
||||
|
||||
|
||||
set_documentation_group("component")
|
||||
|
||||
|
||||
|
@ -3,157 +3,74 @@ use the `gr.Blocks.load()` or `gr.Interface.load()` functions."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
import math
|
||||
import numbers
|
||||
import operator
|
||||
import re
|
||||
import uuid
|
||||
import warnings
|
||||
from copy import deepcopy
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple
|
||||
from typing import TYPE_CHECKING, Callable, Dict
|
||||
|
||||
import requests
|
||||
import websockets
|
||||
import yaml
|
||||
from packaging import version
|
||||
|
||||
import gradio
|
||||
from gradio import components, exceptions, utils
|
||||
from gradio import components, utils
|
||||
from gradio.exceptions import TooManyRequestsError
|
||||
from gradio.external_utils import (
|
||||
cols_to_rows,
|
||||
encode_to_base64,
|
||||
get_tabular_examples,
|
||||
get_ws_fn,
|
||||
postprocess_label,
|
||||
rows_to_cols,
|
||||
streamline_spaces_blocks,
|
||||
streamline_spaces_interface,
|
||||
use_websocket,
|
||||
)
|
||||
from gradio.processing_utils import to_binary
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from gradio.blocks import Blocks
|
||||
from gradio.components import DataframeData
|
||||
from gradio.interface import Interface
|
||||
|
||||
|
||||
def load_blocks_from_repo(
|
||||
name: str, src: str = None, api_key: str = None, alias: str = None, **kwargs
|
||||
) -> Blocks:
|
||||
"""Creates and returns a Blocks instance from several kinds of Hugging Face repos:
|
||||
1) A model repo
|
||||
2) A Spaces repo running Gradio 2.x
|
||||
3) A Spaces repo running Gradio 3.x
|
||||
"""
|
||||
"""Creates and returns a Blocks instance from a Hugging Face model or Space repo."""
|
||||
if src is None:
|
||||
tokens = name.split(
|
||||
"/"
|
||||
) # Separate the source (e.g. "huggingface") from the repo name (e.g. "google/vit-base-patch16-224")
|
||||
# Separate the repo type (e.g. "model") from repo name (e.g. "google/vit-base-patch16-224")
|
||||
tokens = name.split("/")
|
||||
assert (
|
||||
len(tokens) > 1
|
||||
), "Either `src` parameter must be provided, or `name` must be formatted as {src}/{repo name}"
|
||||
src = tokens[0]
|
||||
name = "/".join(tokens[1:])
|
||||
|
||||
factory_methods: Dict[str, Callable] = {
|
||||
# for each repo type, we have a method that returns the Interface given the model name & optionally an api_key
|
||||
"huggingface": from_model,
|
||||
"models": from_model,
|
||||
"spaces": from_spaces,
|
||||
}
|
||||
assert src.lower() in factory_methods, "parameter: src must be one of {}".format(
|
||||
factory_methods.keys()
|
||||
)
|
||||
|
||||
blocks: gradio.Blocks = factory_methods[src](name, api_key, alias, **kwargs)
|
||||
return blocks
|
||||
|
||||
|
||||
def get_tabular_examples(model_name: str) -> Dict[str, List[float]]:
|
||||
readme = requests.get(f"https://huggingface.co/{model_name}/resolve/main/README.md")
|
||||
if readme.status_code != 200:
|
||||
warnings.warn(f"Cannot load examples from README for {model_name}", UserWarning)
|
||||
example_data = {}
|
||||
else:
|
||||
yaml_regex = re.search(
|
||||
"(?:^|[\r\n])---[\n\r]+([\\S\\s]*?)[\n\r]+---([\n\r]|$)", readme.text
|
||||
)
|
||||
example_yaml = next(yaml.safe_load_all(readme.text[: yaml_regex.span()[-1]]))
|
||||
example_data = example_yaml.get("widget", {}).get("structuredData", {})
|
||||
if not example_data:
|
||||
raise ValueError(
|
||||
f"No example data found in README.md of {model_name} - Cannot build gradio demo. "
|
||||
"See the README.md here: https://huggingface.co/scikit-learn/tabular-playground/blob/main/README.md "
|
||||
"for a reference on how to provide example data to your model."
|
||||
)
|
||||
# replace nan with string NaN for inference API
|
||||
for data in example_data.values():
|
||||
for i, val in enumerate(data):
|
||||
if isinstance(val, numbers.Number) and math.isnan(val):
|
||||
data[i] = "NaN"
|
||||
return example_data
|
||||
|
||||
|
||||
def cols_to_rows(
|
||||
example_data: Dict[str, List[float]]
|
||||
) -> Tuple[List[str], List[List[float]]]:
|
||||
headers = list(example_data.keys())
|
||||
n_rows = max(len(example_data[header] or []) for header in headers)
|
||||
data = []
|
||||
for row_index in range(n_rows):
|
||||
row_data = []
|
||||
for header in headers:
|
||||
col = example_data[header] or []
|
||||
if row_index >= len(col):
|
||||
row_data.append("NaN")
|
||||
else:
|
||||
row_data.append(col[row_index])
|
||||
data.append(row_data)
|
||||
return headers, data
|
||||
|
||||
|
||||
def rows_to_cols(
|
||||
incoming_data: DataframeData,
|
||||
) -> Dict[str, Dict[str, Dict[str, List[str]]]]:
|
||||
data_column_wise = {}
|
||||
for i, header in enumerate(incoming_data["headers"]):
|
||||
data_column_wise[header] = [str(row[i]) for row in incoming_data["data"]]
|
||||
return {"inputs": {"data": data_column_wise}}
|
||||
|
||||
|
||||
def get_models_interface(model_name: str, api_key: str | None, alias: str, **kwargs):
|
||||
def from_model(model_name: str, api_key: str | None, alias: str, **kwargs):
|
||||
model_url = "https://huggingface.co/{}".format(model_name)
|
||||
api_url = "https://api-inference.huggingface.co/models/{}".format(model_name)
|
||||
print("Fetching model from: {}".format(model_url))
|
||||
|
||||
if api_key is not None:
|
||||
headers = {"Authorization": f"Bearer {api_key}"}
|
||||
else:
|
||||
headers = {}
|
||||
headers = {"Authorization": f"Bearer {api_key}"} if api_key is not None else {}
|
||||
|
||||
# Checking if model exists, and if so, it gets the pipeline
|
||||
response = requests.request("GET", api_url, headers=headers)
|
||||
assert response.status_code == 200, "Invalid model name or src"
|
||||
p = response.json().get("pipeline_tag")
|
||||
|
||||
def postprocess_label(scores):
|
||||
sorted_pred = sorted(scores.items(), key=operator.itemgetter(1), reverse=True)
|
||||
return {
|
||||
"label": sorted_pred[0][0],
|
||||
"confidences": [
|
||||
{"label": pred[0], "confidence": pred[1]} for pred in sorted_pred
|
||||
],
|
||||
}
|
||||
|
||||
def encode_to_base64(r: requests.Response) -> str:
|
||||
# Handles the different ways HF API returns the prediction
|
||||
base64_repr = base64.b64encode(r.content).decode("utf-8")
|
||||
data_prefix = ";base64,"
|
||||
# Case 1: base64 representation already includes data prefix
|
||||
if data_prefix in base64_repr:
|
||||
return base64_repr
|
||||
else:
|
||||
content_type = r.headers.get("content-type")
|
||||
# Case 2: the data prefix is a key in the response
|
||||
if content_type == "application/json":
|
||||
try:
|
||||
content_type = r.json()[0]["content-type"]
|
||||
base64_repr = r.json()[0]["blob"]
|
||||
except KeyError:
|
||||
raise ValueError(
|
||||
"Cannot determine content type returned" "by external API."
|
||||
)
|
||||
# Case 3: the data prefix is included in the response headers
|
||||
else:
|
||||
pass
|
||||
new_base64 = "data:{};base64,".format(content_type) + base64_repr
|
||||
return new_base64
|
||||
|
||||
pipelines = {
|
||||
"audio-classification": {
|
||||
# example model: ehcalabres/wav2vec2-lg-xlsr-en-speech-emotion-recognition
|
||||
@ -393,7 +310,7 @@ def get_models_interface(model_name: str, api_key: str | None, alias: str, **kwa
|
||||
return interface
|
||||
|
||||
|
||||
def get_spaces(space_name: str, api_key: str | None, alias: str, **kwargs) -> Blocks:
|
||||
def from_spaces(space_name: str, api_key: str | None, alias: str, **kwargs) -> Blocks:
|
||||
space_url = "https://huggingface.co/spaces/{}".format(space_name)
|
||||
print("Fetching Space from: {}".format(space_url))
|
||||
|
||||
@ -409,6 +326,11 @@ def get_spaces(space_name: str, api_key: str | None, alias: str, **kwargs) -> Bl
|
||||
.get("host")
|
||||
)
|
||||
|
||||
if iframe_url is None:
|
||||
raise ValueError(
|
||||
f"Could not find Space: {space_name}. If it is a private Space, please provide an access token in the `api_key` parameter."
|
||||
)
|
||||
|
||||
r = requests.get(iframe_url, headers=headers)
|
||||
|
||||
result = re.search(
|
||||
@ -419,61 +341,17 @@ def get_spaces(space_name: str, api_key: str | None, alias: str, **kwargs) -> Bl
|
||||
except AttributeError:
|
||||
raise ValueError("Could not load the Space: {}".format(space_name))
|
||||
if "allow_flagging" in config: # Create an Interface for Gradio 2.x Spaces
|
||||
return get_spaces_interface(
|
||||
return from_spaces_interface(
|
||||
space_name, config, alias, api_key, iframe_url, **kwargs
|
||||
)
|
||||
else: # Create a Blocks for Gradio 3.x Spaces
|
||||
return get_spaces_blocks(space_name, config, api_key, iframe_url)
|
||||
return from_spaces_blocks(space_name, config, api_key, iframe_url)
|
||||
|
||||
|
||||
async def get_pred_from_ws(
|
||||
websocket: websockets.WebSocketClientProtocol, data: str, hash_data: str
|
||||
) -> Dict[str, Any]:
|
||||
completed = False
|
||||
while not completed:
|
||||
msg = await websocket.recv()
|
||||
resp = json.loads(msg)
|
||||
if resp["msg"] == "queue_full":
|
||||
raise exceptions.Error("Queue is full! Please try again.")
|
||||
if resp["msg"] == "send_hash":
|
||||
await websocket.send(hash_data)
|
||||
elif resp["msg"] == "send_data":
|
||||
await websocket.send(data)
|
||||
completed = resp["msg"] == "process_completed"
|
||||
return resp["output"]
|
||||
|
||||
|
||||
def get_ws_fn(ws_url, headers):
|
||||
async def ws_fn(data, hash_data):
|
||||
async with websockets.connect(
|
||||
ws_url, open_timeout=10, extra_headers=headers
|
||||
) as websocket:
|
||||
return await get_pred_from_ws(websocket, data, hash_data)
|
||||
|
||||
return ws_fn
|
||||
|
||||
|
||||
def use_websocket(config, dependency):
|
||||
queue_enabled = config.get("enable_queue", False)
|
||||
queue_uses_websocket = version.parse(
|
||||
config.get("version", "2.0")
|
||||
) >= version.Version("3.2")
|
||||
dependency_uses_queue = dependency.get("queue", False) is not False
|
||||
return queue_enabled and queue_uses_websocket and dependency_uses_queue
|
||||
|
||||
|
||||
def get_spaces_blocks(
|
||||
def from_spaces_blocks(
|
||||
model_name: str, config: Dict, api_key: str | None, iframe_url: str
|
||||
) -> Blocks:
|
||||
def streamline_config(config: dict) -> dict:
|
||||
"""Streamlines the blocks config dictionary to fix components that don't render correctly."""
|
||||
# TODO(abidlabs): Need a better way to fix relative paths in dataset component
|
||||
for c, component in enumerate(config["components"]):
|
||||
if component["type"] == "dataset":
|
||||
config["components"][c]["props"]["visible"] = False
|
||||
return config
|
||||
|
||||
config = streamline_config(config)
|
||||
config = streamline_spaces_blocks(config)
|
||||
api_url = "{}/api/predict/".format(iframe_url)
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if api_key is not None:
|
||||
@ -523,7 +401,7 @@ def get_spaces_blocks(
|
||||
return gradio.Blocks.from_config(config, fns)
|
||||
|
||||
|
||||
def get_spaces_interface(
|
||||
def from_spaces_interface(
|
||||
model_name: str,
|
||||
config: Dict,
|
||||
alias: str,
|
||||
@ -531,29 +409,8 @@ def get_spaces_interface(
|
||||
iframe_url: str,
|
||||
**kwargs,
|
||||
) -> Interface:
|
||||
def streamline_config(config: Dict) -> Dict:
|
||||
"""Streamlines the interface config dictionary to remove unnecessary keys."""
|
||||
config["inputs"] = [
|
||||
components.get_component_instance(component)
|
||||
for component in config["input_components"]
|
||||
]
|
||||
config["outputs"] = [
|
||||
components.get_component_instance(component)
|
||||
for component in config["output_components"]
|
||||
]
|
||||
parameters = {
|
||||
"article",
|
||||
"description",
|
||||
"flagging_options",
|
||||
"inputs",
|
||||
"outputs",
|
||||
"theme",
|
||||
"title",
|
||||
}
|
||||
config = {k: config[k] for k in parameters}
|
||||
return config
|
||||
|
||||
config = streamline_config(config)
|
||||
config = streamline_spaces_interface(config)
|
||||
api_url = "{}/api/predict/".format(iframe_url)
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if api_key is not None:
|
||||
@ -589,189 +446,3 @@ def get_spaces_interface(
|
||||
kwargs["_api_mode"] = True
|
||||
interface = gradio.Interface(**kwargs)
|
||||
return interface
|
||||
|
||||
|
||||
factory_methods: Dict[str, Callable] = {
|
||||
# for each repo type, we have a method that returns the Interface given the model name & optionally an api_key
|
||||
"huggingface": get_models_interface,
|
||||
"models": get_models_interface,
|
||||
"spaces": get_spaces,
|
||||
}
|
||||
|
||||
|
||||
def load_from_pipeline(pipeline) -> Dict:
|
||||
"""
|
||||
Gets the appropriate Interface kwargs for a given Hugging Face transformers.Pipeline.
|
||||
pipeline (transformers.Pipeline): the transformers.Pipeline from which to create an interface
|
||||
Returns:
|
||||
(dict): a dictionary of kwargs that can be used to construct an Interface object
|
||||
"""
|
||||
try:
|
||||
import transformers
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"transformers not installed. Please try `pip install transformers`"
|
||||
)
|
||||
if not isinstance(pipeline, transformers.Pipeline):
|
||||
raise ValueError("pipeline must be a transformers.Pipeline")
|
||||
|
||||
# Handle the different pipelines. The has_attr() checks to make sure the pipeline exists in the
|
||||
# version of the transformers library that the user has installed.
|
||||
if hasattr(transformers, "AudioClassificationPipeline") and isinstance(
|
||||
pipeline, transformers.AudioClassificationPipeline
|
||||
):
|
||||
pipeline_info = {
|
||||
"inputs": components.Audio(
|
||||
source="microphone", type="filepath", label="Input"
|
||||
),
|
||||
"outputs": components.Label(label="Class"),
|
||||
"preprocess": lambda i: {"inputs": i},
|
||||
"postprocess": lambda r: {i["label"].split(", ")[0]: i["score"] for i in r},
|
||||
}
|
||||
elif hasattr(transformers, "AutomaticSpeechRecognitionPipeline") and isinstance(
|
||||
pipeline, transformers.AutomaticSpeechRecognitionPipeline
|
||||
):
|
||||
pipeline_info = {
|
||||
"inputs": components.Audio(
|
||||
source="microphone", type="filepath", label="Input"
|
||||
),
|
||||
"outputs": components.Textbox(label="Output"),
|
||||
"preprocess": lambda i: {"inputs": i},
|
||||
"postprocess": lambda r: r["text"],
|
||||
}
|
||||
elif hasattr(transformers, "FeatureExtractionPipeline") and isinstance(
|
||||
pipeline, transformers.FeatureExtractionPipeline
|
||||
):
|
||||
pipeline_info = {
|
||||
"inputs": components.Textbox(label="Input"),
|
||||
"outputs": components.Dataframe(label="Output"),
|
||||
"preprocess": lambda x: {"inputs": x},
|
||||
"postprocess": lambda r: r[0],
|
||||
}
|
||||
elif hasattr(transformers, "FillMaskPipeline") and isinstance(
|
||||
pipeline, transformers.FillMaskPipeline
|
||||
):
|
||||
pipeline_info = {
|
||||
"inputs": components.Textbox(label="Input"),
|
||||
"outputs": components.Label(label="Classification"),
|
||||
"preprocess": lambda x: {"inputs": x},
|
||||
"postprocess": lambda r: {i["token_str"]: i["score"] for i in r},
|
||||
}
|
||||
elif hasattr(transformers, "ImageClassificationPipeline") and isinstance(
|
||||
pipeline, transformers.ImageClassificationPipeline
|
||||
):
|
||||
pipeline_info = {
|
||||
"inputs": components.Image(type="filepath", label="Input Image"),
|
||||
"outputs": components.Label(type="confidences", label="Classification"),
|
||||
"preprocess": lambda i: {"images": i},
|
||||
"postprocess": lambda r: {i["label"].split(", ")[0]: i["score"] for i in r},
|
||||
}
|
||||
elif hasattr(transformers, "QuestionAnsweringPipeline") and isinstance(
|
||||
pipeline, transformers.QuestionAnsweringPipeline
|
||||
):
|
||||
pipeline_info = {
|
||||
"inputs": [
|
||||
components.Textbox(lines=7, label="Context"),
|
||||
components.Textbox(label="Question"),
|
||||
],
|
||||
"outputs": [
|
||||
components.Textbox(label="Answer"),
|
||||
components.Label(label="Score"),
|
||||
],
|
||||
"preprocess": lambda c, q: {"context": c, "question": q},
|
||||
"postprocess": lambda r: (r["answer"], r["score"]),
|
||||
}
|
||||
elif hasattr(transformers, "SummarizationPipeline") and isinstance(
|
||||
pipeline, transformers.SummarizationPipeline
|
||||
):
|
||||
pipeline_info = {
|
||||
"inputs": components.Textbox(lines=7, label="Input"),
|
||||
"outputs": components.Textbox(label="Summary"),
|
||||
"preprocess": lambda x: {"inputs": x},
|
||||
"postprocess": lambda r: r[0]["summary_text"],
|
||||
}
|
||||
elif hasattr(transformers, "TextClassificationPipeline") and isinstance(
|
||||
pipeline, transformers.TextClassificationPipeline
|
||||
):
|
||||
pipeline_info = {
|
||||
"inputs": components.Textbox(label="Input"),
|
||||
"outputs": components.Label(label="Classification"),
|
||||
"preprocess": lambda x: [x],
|
||||
"postprocess": lambda r: {i["label"].split(", ")[0]: i["score"] for i in r},
|
||||
}
|
||||
elif hasattr(transformers, "TextGenerationPipeline") and isinstance(
|
||||
pipeline, transformers.TextGenerationPipeline
|
||||
):
|
||||
pipeline_info = {
|
||||
"inputs": components.Textbox(label="Input"),
|
||||
"outputs": components.Textbox(label="Output"),
|
||||
"preprocess": lambda x: {"text_inputs": x},
|
||||
"postprocess": lambda r: r[0]["generated_text"],
|
||||
}
|
||||
elif hasattr(transformers, "TranslationPipeline") and isinstance(
|
||||
pipeline, transformers.TranslationPipeline
|
||||
):
|
||||
pipeline_info = {
|
||||
"inputs": components.Textbox(label="Input"),
|
||||
"outputs": components.Textbox(label="Translation"),
|
||||
"preprocess": lambda x: [x],
|
||||
"postprocess": lambda r: r[0]["translation_text"],
|
||||
}
|
||||
elif hasattr(transformers, "Text2TextGenerationPipeline") and isinstance(
|
||||
pipeline, transformers.Text2TextGenerationPipeline
|
||||
):
|
||||
pipeline_info = {
|
||||
"inputs": components.Textbox(label="Input"),
|
||||
"outputs": components.Textbox(label="Generated Text"),
|
||||
"preprocess": lambda x: [x],
|
||||
"postprocess": lambda r: r[0]["generated_text"],
|
||||
}
|
||||
elif hasattr(transformers, "ZeroShotClassificationPipeline") and isinstance(
|
||||
pipeline, transformers.ZeroShotClassificationPipeline
|
||||
):
|
||||
pipeline_info = {
|
||||
"inputs": [
|
||||
components.Textbox(label="Input"),
|
||||
components.Textbox(label="Possible class names (" "comma-separated)"),
|
||||
components.Checkbox(label="Allow multiple true classes"),
|
||||
],
|
||||
"outputs": components.Label(label="Classification"),
|
||||
"preprocess": lambda i, c, m: {
|
||||
"sequences": i,
|
||||
"candidate_labels": c,
|
||||
"multi_label": m,
|
||||
},
|
||||
"postprocess": lambda r: {
|
||||
r["labels"][i]: r["scores"][i] for i in range(len(r["labels"]))
|
||||
},
|
||||
}
|
||||
else:
|
||||
raise ValueError("Unsupported pipeline type: {}".format(type(pipeline)))
|
||||
|
||||
# define the function that will be called by the Interface
|
||||
def fn(*params):
|
||||
data = pipeline_info["preprocess"](*params)
|
||||
# special cases that needs to be handled differently
|
||||
if isinstance(
|
||||
pipeline,
|
||||
(
|
||||
transformers.TextClassificationPipeline,
|
||||
transformers.Text2TextGenerationPipeline,
|
||||
transformers.TranslationPipeline,
|
||||
),
|
||||
):
|
||||
data = pipeline(*data)
|
||||
else:
|
||||
data = pipeline(**data)
|
||||
output = pipeline_info["postprocess"](data)
|
||||
return output
|
||||
|
||||
interface_info = pipeline_info.copy()
|
||||
interface_info["fn"] = fn
|
||||
del interface_info["preprocess"]
|
||||
del interface_info["postprocess"]
|
||||
|
||||
# define the title/description of the Interface
|
||||
interface_info["title"] = pipeline.model.__class__.__name__
|
||||
|
||||
return interface_info
|
||||
|
189
gradio/external_utils.py
Normal file
189
gradio/external_utils.py
Normal file
@ -0,0 +1,189 @@
|
||||
"""Utility function for gradio/external.py"""
|
||||
|
||||
import base64
|
||||
import json
|
||||
import math
|
||||
import numbers
|
||||
import operator
|
||||
import re
|
||||
import warnings
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
import requests
|
||||
import websockets
|
||||
import yaml
|
||||
from packaging import version
|
||||
|
||||
from gradio import components, exceptions
|
||||
|
||||
##################
|
||||
# Helper functions for processing tabular data
|
||||
##################
|
||||
|
||||
|
||||
def get_tabular_examples(model_name: str) -> Dict[str, List[float]]:
|
||||
readme = requests.get(f"https://huggingface.co/{model_name}/resolve/main/README.md")
|
||||
if readme.status_code != 200:
|
||||
warnings.warn(f"Cannot load examples from README for {model_name}", UserWarning)
|
||||
example_data = {}
|
||||
else:
|
||||
yaml_regex = re.search(
|
||||
"(?:^|[\r\n])---[\n\r]+([\\S\\s]*?)[\n\r]+---([\n\r]|$)", readme.text
|
||||
)
|
||||
example_yaml = next(yaml.safe_load_all(readme.text[: yaml_regex.span()[-1]]))
|
||||
example_data = example_yaml.get("widget", {}).get("structuredData", {})
|
||||
if not example_data:
|
||||
raise ValueError(
|
||||
f"No example data found in README.md of {model_name} - Cannot build gradio demo. "
|
||||
"See the README.md here: https://huggingface.co/scikit-learn/tabular-playground/blob/main/README.md "
|
||||
"for a reference on how to provide example data to your model."
|
||||
)
|
||||
# replace nan with string NaN for inference API
|
||||
for data in example_data.values():
|
||||
for i, val in enumerate(data):
|
||||
if isinstance(val, numbers.Number) and math.isnan(val):
|
||||
data[i] = "NaN"
|
||||
return example_data
|
||||
|
||||
|
||||
def cols_to_rows(
|
||||
example_data: Dict[str, List[float]]
|
||||
) -> Tuple[List[str], List[List[float]]]:
|
||||
headers = list(example_data.keys())
|
||||
n_rows = max(len(example_data[header] or []) for header in headers)
|
||||
data = []
|
||||
for row_index in range(n_rows):
|
||||
row_data = []
|
||||
for header in headers:
|
||||
col = example_data[header] or []
|
||||
if row_index >= len(col):
|
||||
row_data.append("NaN")
|
||||
else:
|
||||
row_data.append(col[row_index])
|
||||
data.append(row_data)
|
||||
return headers, data
|
||||
|
||||
|
||||
def rows_to_cols(incoming_data: Dict) -> Dict[str, Dict[str, Dict[str, List[str]]]]:
|
||||
data_column_wise = {}
|
||||
for i, header in enumerate(incoming_data["headers"]):
|
||||
data_column_wise[header] = [str(row[i]) for row in incoming_data["data"]]
|
||||
return {"inputs": {"data": data_column_wise}}
|
||||
|
||||
|
||||
##################
|
||||
# Helper functions for processing other kinds of data
|
||||
##################
|
||||
|
||||
|
||||
def postprocess_label(scores):
|
||||
sorted_pred = sorted(scores.items(), key=operator.itemgetter(1), reverse=True)
|
||||
return {
|
||||
"label": sorted_pred[0][0],
|
||||
"confidences": [
|
||||
{"label": pred[0], "confidence": pred[1]} for pred in sorted_pred
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def encode_to_base64(r: requests.Response) -> str:
|
||||
# Handles the different ways HF API returns the prediction
|
||||
base64_repr = base64.b64encode(r.content).decode("utf-8")
|
||||
data_prefix = ";base64,"
|
||||
# Case 1: base64 representation already includes data prefix
|
||||
if data_prefix in base64_repr:
|
||||
return base64_repr
|
||||
else:
|
||||
content_type = r.headers.get("content-type")
|
||||
# Case 2: the data prefix is a key in the response
|
||||
if content_type == "application/json":
|
||||
try:
|
||||
content_type = r.json()[0]["content-type"]
|
||||
base64_repr = r.json()[0]["blob"]
|
||||
except KeyError:
|
||||
raise ValueError(
|
||||
"Cannot determine content type returned" "by external API."
|
||||
)
|
||||
# Case 3: the data prefix is included in the response headers
|
||||
else:
|
||||
pass
|
||||
new_base64 = "data:{};base64,".format(content_type) + base64_repr
|
||||
return new_base64
|
||||
|
||||
|
||||
##################
|
||||
# Helper functions for connecting to websockets
|
||||
##################
|
||||
|
||||
|
||||
async def get_pred_from_ws(
|
||||
websocket: websockets.WebSocketClientProtocol, data: str, hash_data: str
|
||||
) -> Dict[str, Any]:
|
||||
completed = False
|
||||
while not completed:
|
||||
msg = await websocket.recv()
|
||||
resp = json.loads(msg)
|
||||
if resp["msg"] == "queue_full":
|
||||
raise exceptions.Error("Queue is full! Please try again.")
|
||||
if resp["msg"] == "send_hash":
|
||||
await websocket.send(hash_data)
|
||||
elif resp["msg"] == "send_data":
|
||||
await websocket.send(data)
|
||||
completed = resp["msg"] == "process_completed"
|
||||
return resp["output"]
|
||||
|
||||
|
||||
def get_ws_fn(ws_url, headers):
|
||||
async def ws_fn(data, hash_data):
|
||||
async with websockets.connect(
|
||||
ws_url, open_timeout=10, extra_headers=headers
|
||||
) as websocket:
|
||||
return await get_pred_from_ws(websocket, data, hash_data)
|
||||
|
||||
return ws_fn
|
||||
|
||||
|
||||
def use_websocket(config, dependency):
|
||||
queue_enabled = config.get("enable_queue", False)
|
||||
queue_uses_websocket = version.parse(
|
||||
config.get("version", "2.0")
|
||||
) >= version.Version("3.2")
|
||||
dependency_uses_queue = dependency.get("queue", False) is not False
|
||||
return queue_enabled and queue_uses_websocket and dependency_uses_queue
|
||||
|
||||
|
||||
##################
|
||||
# Helper functions for cleaning up Interfaces/Blocks loaded from HF Spaces
|
||||
##################
|
||||
|
||||
|
||||
def streamline_spaces_interface(config: Dict) -> Dict:
|
||||
"""Streamlines the interface config dictionary to remove unnecessary keys."""
|
||||
config["inputs"] = [
|
||||
components.get_component_instance(component)
|
||||
for component in config["input_components"]
|
||||
]
|
||||
config["outputs"] = [
|
||||
components.get_component_instance(component)
|
||||
for component in config["output_components"]
|
||||
]
|
||||
parameters = {
|
||||
"article",
|
||||
"description",
|
||||
"flagging_options",
|
||||
"inputs",
|
||||
"outputs",
|
||||
"theme",
|
||||
"title",
|
||||
}
|
||||
config = {k: config[k] for k in parameters}
|
||||
return config
|
||||
|
||||
|
||||
def streamline_spaces_blocks(config: dict) -> dict:
|
||||
"""Streamlines the blocks config dictionary to fix components that don't render correctly."""
|
||||
# TODO(abidlabs): Need a better way to fix relative paths in dataset component
|
||||
for c, component in enumerate(config["components"]):
|
||||
if component["type"] == "dataset":
|
||||
config["components"][c]["props"]["visible"] = False
|
||||
return config
|
@ -31,9 +31,9 @@ from gradio.components import (
|
||||
)
|
||||
from gradio.documentation import document, set_documentation_group
|
||||
from gradio.events import Changeable, Streamable
|
||||
from gradio.external import load_from_pipeline # type: ignore
|
||||
from gradio.flagging import CSVLogger, FlaggingCallback # type: ignore
|
||||
from gradio.layouts import Column, Row, TabItem, Tabs
|
||||
from gradio.pipelines import load_from_pipeline # type: ignore
|
||||
|
||||
set_documentation_group("interface")
|
||||
|
||||
|
189
gradio/pipelines.py
Normal file
189
gradio/pipelines.py
Normal file
@ -0,0 +1,189 @@
|
||||
"""This module should not be used directly as its API is subject to change. Instead,
|
||||
please use the `gr.Interface.from_pipeline()` function."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Dict
|
||||
|
||||
from gradio import components
|
||||
|
||||
if TYPE_CHECKING: # Only import for type checking (is False at runtime).
|
||||
import transformers
|
||||
|
||||
|
||||
def load_from_pipeline(pipeline: transformers.Pipeline) -> Dict:
|
||||
"""
|
||||
Gets the appropriate Interface kwargs for a given Hugging Face transformers.Pipeline.
|
||||
pipeline (transformers.Pipeline): the transformers.Pipeline from which to create an interface
|
||||
Returns:
|
||||
(dict): a dictionary of kwargs that can be used to construct an Interface object
|
||||
"""
|
||||
try:
|
||||
import transformers
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"transformers not installed. Please try `pip install transformers`"
|
||||
)
|
||||
if not isinstance(pipeline, transformers.Pipeline):
|
||||
raise ValueError("pipeline must be a transformers.Pipeline")
|
||||
|
||||
# Handle the different pipelines. The has_attr() checks to make sure the pipeline exists in the
|
||||
# version of the transformers library that the user has installed.
|
||||
if hasattr(transformers, "AudioClassificationPipeline") and isinstance(
|
||||
pipeline, transformers.AudioClassificationPipeline
|
||||
):
|
||||
pipeline_info = {
|
||||
"inputs": components.Audio(
|
||||
source="microphone", type="filepath", label="Input"
|
||||
),
|
||||
"outputs": components.Label(label="Class"),
|
||||
"preprocess": lambda i: {"inputs": i},
|
||||
"postprocess": lambda r: {i["label"].split(", ")[0]: i["score"] for i in r},
|
||||
}
|
||||
elif hasattr(transformers, "AutomaticSpeechRecognitionPipeline") and isinstance(
|
||||
pipeline, transformers.AutomaticSpeechRecognitionPipeline
|
||||
):
|
||||
pipeline_info = {
|
||||
"inputs": components.Audio(
|
||||
source="microphone", type="filepath", label="Input"
|
||||
),
|
||||
"outputs": components.Textbox(label="Output"),
|
||||
"preprocess": lambda i: {"inputs": i},
|
||||
"postprocess": lambda r: r["text"],
|
||||
}
|
||||
elif hasattr(transformers, "FeatureExtractionPipeline") and isinstance(
|
||||
pipeline, transformers.FeatureExtractionPipeline
|
||||
):
|
||||
pipeline_info = {
|
||||
"inputs": components.Textbox(label="Input"),
|
||||
"outputs": components.Dataframe(label="Output"),
|
||||
"preprocess": lambda x: {"inputs": x},
|
||||
"postprocess": lambda r: r[0],
|
||||
}
|
||||
elif hasattr(transformers, "FillMaskPipeline") and isinstance(
|
||||
pipeline, transformers.FillMaskPipeline
|
||||
):
|
||||
pipeline_info = {
|
||||
"inputs": components.Textbox(label="Input"),
|
||||
"outputs": components.Label(label="Classification"),
|
||||
"preprocess": lambda x: {"inputs": x},
|
||||
"postprocess": lambda r: {i["token_str"]: i["score"] for i in r},
|
||||
}
|
||||
elif hasattr(transformers, "ImageClassificationPipeline") and isinstance(
|
||||
pipeline, transformers.ImageClassificationPipeline
|
||||
):
|
||||
pipeline_info = {
|
||||
"inputs": components.Image(type="filepath", label="Input Image"),
|
||||
"outputs": components.Label(type="confidences", label="Classification"),
|
||||
"preprocess": lambda i: {"images": i},
|
||||
"postprocess": lambda r: {i["label"].split(", ")[0]: i["score"] for i in r},
|
||||
}
|
||||
elif hasattr(transformers, "QuestionAnsweringPipeline") and isinstance(
|
||||
pipeline, transformers.QuestionAnsweringPipeline
|
||||
):
|
||||
pipeline_info = {
|
||||
"inputs": [
|
||||
components.Textbox(lines=7, label="Context"),
|
||||
components.Textbox(label="Question"),
|
||||
],
|
||||
"outputs": [
|
||||
components.Textbox(label="Answer"),
|
||||
components.Label(label="Score"),
|
||||
],
|
||||
"preprocess": lambda c, q: {"context": c, "question": q},
|
||||
"postprocess": lambda r: (r["answer"], r["score"]),
|
||||
}
|
||||
elif hasattr(transformers, "SummarizationPipeline") and isinstance(
|
||||
pipeline, transformers.SummarizationPipeline
|
||||
):
|
||||
pipeline_info = {
|
||||
"inputs": components.Textbox(lines=7, label="Input"),
|
||||
"outputs": components.Textbox(label="Summary"),
|
||||
"preprocess": lambda x: {"inputs": x},
|
||||
"postprocess": lambda r: r[0]["summary_text"],
|
||||
}
|
||||
elif hasattr(transformers, "TextClassificationPipeline") and isinstance(
|
||||
pipeline, transformers.TextClassificationPipeline
|
||||
):
|
||||
pipeline_info = {
|
||||
"inputs": components.Textbox(label="Input"),
|
||||
"outputs": components.Label(label="Classification"),
|
||||
"preprocess": lambda x: [x],
|
||||
"postprocess": lambda r: {i["label"].split(", ")[0]: i["score"] for i in r},
|
||||
}
|
||||
elif hasattr(transformers, "TextGenerationPipeline") and isinstance(
|
||||
pipeline, transformers.TextGenerationPipeline
|
||||
):
|
||||
pipeline_info = {
|
||||
"inputs": components.Textbox(label="Input"),
|
||||
"outputs": components.Textbox(label="Output"),
|
||||
"preprocess": lambda x: {"text_inputs": x},
|
||||
"postprocess": lambda r: r[0]["generated_text"],
|
||||
}
|
||||
elif hasattr(transformers, "TranslationPipeline") and isinstance(
|
||||
pipeline, transformers.TranslationPipeline
|
||||
):
|
||||
pipeline_info = {
|
||||
"inputs": components.Textbox(label="Input"),
|
||||
"outputs": components.Textbox(label="Translation"),
|
||||
"preprocess": lambda x: [x],
|
||||
"postprocess": lambda r: r[0]["translation_text"],
|
||||
}
|
||||
elif hasattr(transformers, "Text2TextGenerationPipeline") and isinstance(
|
||||
pipeline, transformers.Text2TextGenerationPipeline
|
||||
):
|
||||
pipeline_info = {
|
||||
"inputs": components.Textbox(label="Input"),
|
||||
"outputs": components.Textbox(label="Generated Text"),
|
||||
"preprocess": lambda x: [x],
|
||||
"postprocess": lambda r: r[0]["generated_text"],
|
||||
}
|
||||
elif hasattr(transformers, "ZeroShotClassificationPipeline") and isinstance(
|
||||
pipeline, transformers.ZeroShotClassificationPipeline
|
||||
):
|
||||
pipeline_info = {
|
||||
"inputs": [
|
||||
components.Textbox(label="Input"),
|
||||
components.Textbox(label="Possible class names (" "comma-separated)"),
|
||||
components.Checkbox(label="Allow multiple true classes"),
|
||||
],
|
||||
"outputs": components.Label(label="Classification"),
|
||||
"preprocess": lambda i, c, m: {
|
||||
"sequences": i,
|
||||
"candidate_labels": c,
|
||||
"multi_label": m,
|
||||
},
|
||||
"postprocess": lambda r: {
|
||||
r["labels"][i]: r["scores"][i] for i in range(len(r["labels"]))
|
||||
},
|
||||
}
|
||||
else:
|
||||
raise ValueError("Unsupported pipeline type: {}".format(type(pipeline)))
|
||||
|
||||
# define the function that will be called by the Interface
|
||||
def fn(*params):
|
||||
data = pipeline_info["preprocess"](*params)
|
||||
# special cases that needs to be handled differently
|
||||
if isinstance(
|
||||
pipeline,
|
||||
(
|
||||
transformers.TextClassificationPipeline,
|
||||
transformers.Text2TextGenerationPipeline,
|
||||
transformers.TranslationPipeline,
|
||||
),
|
||||
):
|
||||
data = pipeline(*data)
|
||||
else:
|
||||
data = pipeline(**data)
|
||||
output = pipeline_info["postprocess"](data)
|
||||
return output
|
||||
|
||||
interface_info = pipeline_info.copy()
|
||||
interface_info["fn"] = fn
|
||||
del interface_info["preprocess"]
|
||||
del interface_info["postprocess"]
|
||||
|
||||
# define the title/description of the Interface
|
||||
interface_info["title"] = pipeline.model.__class__.__name__
|
||||
|
||||
return interface_info
|
@ -7,7 +7,6 @@ import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import transformers
|
||||
|
||||
import gradio
|
||||
import gradio as gr
|
||||
@ -15,10 +14,10 @@ from gradio import utils
|
||||
from gradio.external import (
|
||||
TooManyRequestsError,
|
||||
cols_to_rows,
|
||||
get_pred_from_ws,
|
||||
get_tabular_examples,
|
||||
use_websocket,
|
||||
)
|
||||
from gradio.external_utils import get_pred_from_ws
|
||||
|
||||
"""
|
||||
WARNING: These tests have an external dependency: namely that Hugging Face's
|
||||
@ -43,52 +42,52 @@ class TestLoadInterface(unittest.TestCase):
|
||||
src="models",
|
||||
alias=model_type,
|
||||
)
|
||||
self.assertEqual(interface.__name__, model_type)
|
||||
self.assertIsInstance(interface.input_components[0], gr.components.Audio)
|
||||
self.assertIsInstance(interface.output_components[0], gr.components.Audio)
|
||||
assert interface.__name__ == model_type
|
||||
assert isinstance(interface.input_components[0], gr.components.Audio)
|
||||
assert isinstance(interface.output_components[0], gr.components.Audio)
|
||||
|
||||
def test_question_answering(self):
|
||||
model_type = "image-classification"
|
||||
interface = gr.Blocks.load(
|
||||
name="lysandre/tiny-vit-random", src="models", alias=model_type
|
||||
)
|
||||
self.assertEqual(interface.__name__, model_type)
|
||||
self.assertIsInstance(interface.input_components[0], gr.components.Image)
|
||||
self.assertIsInstance(interface.output_components[0], gr.components.Label)
|
||||
assert interface.__name__ == model_type
|
||||
assert isinstance(interface.input_components[0], gr.components.Image)
|
||||
assert isinstance(interface.output_components[0], gr.components.Label)
|
||||
|
||||
def test_text_generation(self):
|
||||
model_type = "text_generation"
|
||||
interface = gr.Interface.load("models/gpt2", alias=model_type)
|
||||
self.assertEqual(interface.__name__, model_type)
|
||||
self.assertIsInstance(interface.input_components[0], gr.components.Textbox)
|
||||
self.assertIsInstance(interface.output_components[0], gr.components.Textbox)
|
||||
assert interface.__name__ == model_type
|
||||
assert isinstance(interface.input_components[0], gr.components.Textbox)
|
||||
assert isinstance(interface.output_components[0], gr.components.Textbox)
|
||||
|
||||
def test_summarization(self):
|
||||
model_type = "summarization"
|
||||
interface = gr.Interface.load(
|
||||
"models/facebook/bart-large-cnn", api_key=None, alias=model_type
|
||||
)
|
||||
self.assertEqual(interface.__name__, model_type)
|
||||
self.assertIsInstance(interface.input_components[0], gr.components.Textbox)
|
||||
self.assertIsInstance(interface.output_components[0], gr.components.Textbox)
|
||||
assert interface.__name__ == model_type
|
||||
assert isinstance(interface.input_components[0], gr.components.Textbox)
|
||||
assert isinstance(interface.output_components[0], gr.components.Textbox)
|
||||
|
||||
def test_translation(self):
|
||||
model_type = "translation"
|
||||
interface = gr.Interface.load(
|
||||
"models/facebook/bart-large-cnn", api_key=None, alias=model_type
|
||||
)
|
||||
self.assertEqual(interface.__name__, model_type)
|
||||
self.assertIsInstance(interface.input_components[0], gr.components.Textbox)
|
||||
self.assertIsInstance(interface.output_components[0], gr.components.Textbox)
|
||||
assert interface.__name__ == model_type
|
||||
assert isinstance(interface.input_components[0], gr.components.Textbox)
|
||||
assert isinstance(interface.output_components[0], gr.components.Textbox)
|
||||
|
||||
def test_text2text_generation(self):
|
||||
model_type = "text2text-generation"
|
||||
interface = gr.Interface.load(
|
||||
"models/sshleifer/tiny-mbart", api_key=None, alias=model_type
|
||||
)
|
||||
self.assertEqual(interface.__name__, model_type)
|
||||
self.assertIsInstance(interface.input_components[0], gr.components.Textbox)
|
||||
self.assertIsInstance(interface.output_components[0], gr.components.Textbox)
|
||||
assert interface.__name__ == model_type
|
||||
assert isinstance(interface.input_components[0], gr.components.Textbox)
|
||||
assert isinstance(interface.output_components[0], gr.components.Textbox)
|
||||
|
||||
def test_text_classification(self):
|
||||
model_type = "text-classification"
|
||||
@ -97,47 +96,47 @@ class TestLoadInterface(unittest.TestCase):
|
||||
api_key=None,
|
||||
alias=model_type,
|
||||
)
|
||||
self.assertEqual(interface.__name__, model_type)
|
||||
self.assertIsInstance(interface.input_components[0], gr.components.Textbox)
|
||||
self.assertIsInstance(interface.output_components[0], gr.components.Label)
|
||||
assert interface.__name__ == model_type
|
||||
assert isinstance(interface.input_components[0], gr.components.Textbox)
|
||||
assert isinstance(interface.output_components[0], gr.components.Label)
|
||||
|
||||
def test_fill_mask(self):
|
||||
model_type = "fill-mask"
|
||||
interface = gr.Interface.load(
|
||||
"models/bert-base-uncased", api_key=None, alias=model_type
|
||||
)
|
||||
self.assertEqual(interface.__name__, model_type)
|
||||
self.assertIsInstance(interface.input_components[0], gr.components.Textbox)
|
||||
self.assertIsInstance(interface.output_components[0], gr.components.Label)
|
||||
assert interface.__name__ == model_type
|
||||
assert isinstance(interface.input_components[0], gr.components.Textbox)
|
||||
assert isinstance(interface.output_components[0], gr.components.Label)
|
||||
|
||||
def test_zero_shot_classification(self):
|
||||
model_type = "zero-shot-classification"
|
||||
interface = gr.Interface.load(
|
||||
"models/facebook/bart-large-mnli", api_key=None, alias=model_type
|
||||
)
|
||||
self.assertEqual(interface.__name__, model_type)
|
||||
self.assertIsInstance(interface.input_components[0], gr.components.Textbox)
|
||||
self.assertIsInstance(interface.input_components[1], gr.components.Textbox)
|
||||
self.assertIsInstance(interface.input_components[2], gr.components.Checkbox)
|
||||
self.assertIsInstance(interface.output_components[0], gr.components.Label)
|
||||
assert interface.__name__ == model_type
|
||||
assert isinstance(interface.input_components[0], gr.components.Textbox)
|
||||
assert isinstance(interface.input_components[1], gr.components.Textbox)
|
||||
assert isinstance(interface.input_components[2], gr.components.Checkbox)
|
||||
assert isinstance(interface.output_components[0], gr.components.Label)
|
||||
|
||||
def test_automatic_speech_recognition(self):
|
||||
model_type = "automatic-speech-recognition"
|
||||
interface = gr.Interface.load(
|
||||
"models/facebook/wav2vec2-base-960h", api_key=None, alias=model_type
|
||||
)
|
||||
self.assertEqual(interface.__name__, model_type)
|
||||
self.assertIsInstance(interface.input_components[0], gr.components.Audio)
|
||||
self.assertIsInstance(interface.output_components[0], gr.components.Textbox)
|
||||
assert interface.__name__ == model_type
|
||||
assert isinstance(interface.input_components[0], gr.components.Audio)
|
||||
assert isinstance(interface.output_components[0], gr.components.Textbox)
|
||||
|
||||
def test_image_classification(self):
|
||||
model_type = "image-classification"
|
||||
interface = gr.Interface.load(
|
||||
"models/google/vit-base-patch16-224", api_key=None, alias=model_type
|
||||
)
|
||||
self.assertEqual(interface.__name__, model_type)
|
||||
self.assertIsInstance(interface.input_components[0], gr.components.Image)
|
||||
self.assertIsInstance(interface.output_components[0], gr.components.Label)
|
||||
assert interface.__name__ == model_type
|
||||
assert isinstance(interface.input_components[0], gr.components.Image)
|
||||
assert isinstance(interface.output_components[0], gr.components.Label)
|
||||
|
||||
def test_feature_extraction(self):
|
||||
model_type = "feature-extraction"
|
||||
@ -146,9 +145,9 @@ class TestLoadInterface(unittest.TestCase):
|
||||
api_key=None,
|
||||
alias=model_type,
|
||||
)
|
||||
self.assertEqual(interface.__name__, model_type)
|
||||
self.assertIsInstance(interface.input_components[0], gr.components.Textbox)
|
||||
self.assertIsInstance(interface.output_components[0], gr.components.Dataframe)
|
||||
assert interface.__name__ == model_type
|
||||
assert isinstance(interface.input_components[0], gr.components.Textbox)
|
||||
assert isinstance(interface.output_components[0], gr.components.Dataframe)
|
||||
|
||||
def test_sentence_similarity(self):
|
||||
model_type = "text-to-speech"
|
||||
@ -157,9 +156,9 @@ class TestLoadInterface(unittest.TestCase):
|
||||
api_key=None,
|
||||
alias=model_type,
|
||||
)
|
||||
self.assertEqual(interface.__name__, model_type)
|
||||
self.assertIsInstance(interface.input_components[0], gr.components.Textbox)
|
||||
self.assertIsInstance(interface.output_components[0], gr.components.Audio)
|
||||
assert interface.__name__ == model_type
|
||||
assert isinstance(interface.input_components[0], gr.components.Textbox)
|
||||
assert isinstance(interface.output_components[0], gr.components.Audio)
|
||||
|
||||
def test_text_to_speech(self):
|
||||
model_type = "text-to-speech"
|
||||
@ -168,23 +167,23 @@ class TestLoadInterface(unittest.TestCase):
|
||||
api_key=None,
|
||||
alias=model_type,
|
||||
)
|
||||
self.assertEqual(interface.__name__, model_type)
|
||||
self.assertIsInstance(interface.input_components[0], gr.components.Textbox)
|
||||
self.assertIsInstance(interface.output_components[0], gr.components.Audio)
|
||||
assert interface.__name__ == model_type
|
||||
assert isinstance(interface.input_components[0], gr.components.Textbox)
|
||||
assert isinstance(interface.output_components[0], gr.components.Audio)
|
||||
|
||||
def test_text_to_image(self):
|
||||
model_type = "text-to-image"
|
||||
interface = gr.Interface.load(
|
||||
"models/osanseviero/BigGAN-deep-128", api_key=None, alias=model_type
|
||||
)
|
||||
self.assertEqual(interface.__name__, model_type)
|
||||
self.assertIsInstance(interface.input_components[0], gr.components.Textbox)
|
||||
self.assertIsInstance(interface.output_components[0], gr.components.Image)
|
||||
assert interface.__name__ == model_type
|
||||
assert isinstance(interface.input_components[0], gr.components.Textbox)
|
||||
assert isinstance(interface.output_components[0], gr.components.Image)
|
||||
|
||||
def test_english_to_spanish(self):
|
||||
interface = gr.Interface.load("spaces/abidlabs/english_to_spanish")
|
||||
self.assertIsInstance(interface.input_components[0], gr.components.Textbox)
|
||||
self.assertIsInstance(interface.output_components[0], gr.components.Textbox)
|
||||
assert isinstance(interface.input_components[0], gr.components.Textbox)
|
||||
assert isinstance(interface.output_components[0], gr.components.Textbox)
|
||||
|
||||
def test_sentiment_model(self):
|
||||
io = gr.Interface.load("models/distilbert-base-uncased-finetuned-sst-2-english")
|
||||
@ -206,7 +205,7 @@ class TestLoadInterface(unittest.TestCase):
|
||||
io = gr.Blocks.load(name="models/t5-base")
|
||||
try:
|
||||
output = io("My name is Sarah and I live in London")
|
||||
self.assertEqual(output, "Mein Name ist Sarah und ich lebe in London")
|
||||
assert output == "Mein Name ist Sarah und ich lebe in London"
|
||||
except TooManyRequestsError:
|
||||
pass
|
||||
|
||||
@ -241,18 +240,11 @@ class TestLoadInterface(unittest.TestCase):
|
||||
)
|
||||
try:
|
||||
output = io("abc")
|
||||
self.assertEqual(output, "abc")
|
||||
assert output == "abc"
|
||||
except TooManyRequestsError:
|
||||
pass
|
||||
|
||||
|
||||
class TestLoadFromPipeline(unittest.TestCase):
|
||||
def test_text_to_text_model_from_pipeline(self):
|
||||
pipe = transformers.pipeline(model="sshleifer/bart-tiny-random")
|
||||
output = pipe("My name is Sylvain and I work at Hugging Face in Brooklyn")
|
||||
self.assertIsNotNone(output)
|
||||
|
||||
|
||||
class TestLoadInterfaceWithExamples:
|
||||
def test_interface_load_examples(self, tmp_path):
|
||||
test_file_dir = pathlib.Path(pathlib.Path(__file__).parent, "test_files")
|
||||
@ -419,7 +411,7 @@ async def test_get_pred_from_ws_raises_if_queue_full():
|
||||
def test_respect_queue_when_load_from_config():
|
||||
with unittest.mock.patch("websockets.connect"):
|
||||
with unittest.mock.patch(
|
||||
"gradio.external.get_pred_from_ws", return_value={"data": ["foo"]}
|
||||
"gradio.external_utils.get_pred_from_ws", return_value={"data": ["foo"]}
|
||||
):
|
||||
interface = gr.Interface.load("spaces/freddyaboulton/saymyname")
|
||||
assert interface("bob") == "foo"
|
||||
|
11
test/test_pipelines.py
Normal file
11
test/test_pipelines.py
Normal file
@ -0,0 +1,11 @@
|
||||
import transformers
|
||||
|
||||
import gradio as gr
|
||||
|
||||
|
||||
class TestLoadFromPipeline:
|
||||
def test_text_to_text_model_from_pipeline(self):
|
||||
pipe = transformers.pipeline(model="sshleifer/bart-tiny-random")
|
||||
io = gr.Interface.from_pipeline(pipe)
|
||||
output = io("My name is Sylvain and I work at Hugging Face in Brooklyn")
|
||||
assert isinstance(output, str)
|
Loading…
x
Reference in New Issue
Block a user