This commit is contained in:
dawoodkhan82 2021-10-25 18:30:44 -04:00
commit bcf46b225f
22 changed files with 458 additions and 122 deletions

View File

@ -1,6 +1,7 @@
version: 2.1
orbs:
codecov: codecov/codecov@3.1.1
node: circleci/node@1.0.1
jobs:
build:
docker:
@ -26,6 +27,14 @@ jobs:
key: deps1-{{ .Branch }}-{{ checksum "gradio.egg-info/requires.txt" }}
paths:
- "venv"
- node/install
- node/install-npm
- run:
name: Build frontend
command: |
cd frontend
npm install
npm run build
- run:
command: |
mkdir screenshots

8
demo/divide.py Normal file
View File

@ -0,0 +1,8 @@
import gradio as gr
def divide(num1):
return num1/0
iface = gr.Interface(fn=divide, inputs="number", outputs="number")
if __name__ == "__main__":
iface.launch(debug=True)

View File

@ -7,7 +7,7 @@ def image_mod(image):
return image.rotate(45)
iface = gr.Interface(image_mod, gr.inputs.Image(type="pil", optional=True), "image")
iface = gr.Interface(image_mod, gr.inputs.Image(type="pil"), "image")
if __name__ == "__main__":
iface.launch()

View File

@ -7,7 +7,7 @@ class ImageOutput extends BaseComponent {
return this.props.value ? (
<div className="output_image">
<div class="image_preview_holder">
<img class="image_preview" alt="" src={this.props.value[0]}></img>
<img class="image_preview" alt="" src={this.props.value}></img>
</div>
</div>
) : (

View File

@ -1,15 +1,15 @@
Flask-Cors>=3.0.8
Flask-Login
Flask>=1.1.1
analytics-python
ffmpy
flask-cachebuster
markdown2
matplotlib
numpy
pandas
paramiko
pillow
pycryptodome
pydub
matplotlib
pandas
pillow
ffmpy
markdown2
pycryptodome
requests
paramiko
analytics-python
Flask>=1.1.1
Flask-Cors>=3.0.8
flask-cachebuster
Flask-Login

View File

@ -3,10 +3,14 @@ import tempfile
import requests
from gradio import inputs, outputs
import re
import base64
def get_huggingface_interface(model_name, api_key, alias):
model_url = "https://huggingface.co/{}".format(model_name)
api_url = "https://api-inference.huggingface.co/models/{}".format(model_name)
print("Fetching model from: {}".format(model_url))
if api_key is not None:
headers = {"Authorization": f"Bearer {api_key}"}
else:
@ -17,6 +21,7 @@ def get_huggingface_interface(model_name, api_key, alias):
assert response.status_code == 200, "Invalid model name or src"
p = response.json().get('pipeline_tag')
# convert from binary to base64
def post_process_binary_body(r: requests.Response):
with tempfile.NamedTemporaryFile(delete=False) as fp:
fp.write(r.content)
@ -27,78 +32,77 @@ def get_huggingface_interface(model_name, api_key, alias):
'inputs': [inputs.Textbox(label="Context", lines=7), inputs.Textbox(label="Question")],
'outputs': [outputs.Textbox(label="Answer"), outputs.Label(label="Score")],
'preprocess': lambda c, q: {"inputs": {"context": c, "question": q}},
'postprocess': lambda r: (r["answer"], r["score"]),
'postprocess': lambda r: (r.json()["answer"], r.json()["score"]),
# 'examples': [['My name is Sarah and I live in London', 'Where do I live?']]
},
'text-generation': {
'inputs': inputs.Textbox(label="Input"),
'outputs': outputs.Textbox(label="Output"),
'preprocess': lambda x: {"inputs": x},
'postprocess': lambda r: r[0]["generated_text"],
'postprocess': lambda r: r.json()[0]["generated_text"],
# 'examples': [['My name is Clara and I am']]
},
'summarization': {
'inputs': inputs.Textbox(label="Input"),
'outputs': outputs.Textbox(label="Summary"),
'preprocess': lambda x: {"inputs": x},
'postprocess': lambda r: r[0]["summary_text"]
'postprocess': lambda r: r.json()[0]["summary_text"]
},
'translation': {
'inputs': inputs.Textbox(label="Input"),
'outputs': outputs.Textbox(label="Translation"),
'preprocess': lambda x: {"inputs": x},
'postprocess': lambda r: r[0]["translation_text"]
'postprocess': lambda r: r.json()[0]["translation_text"]
},
'text2text-generation': {
'inputs': inputs.Textbox(label="Input"),
'outputs': outputs.Textbox(label="Generated Text"),
'preprocess': lambda x: {"inputs": x},
'postprocess': lambda r: r[0]["generated_text"]
'postprocess': lambda r: r.json()[0]["generated_text"]
},
'text-classification': {
'inputs': inputs.Textbox(label="Input"),
'outputs': outputs.Label(label="Classification"),
'outputs': outputs.Label(label="Classification", type="confidences"),
'preprocess': lambda x: {"inputs": x},
'postprocess': lambda r: {'Negative': r[0][0]["score"],
'Positive': r[0][1]["score"]}
'postprocess': lambda r: {'Negative': r.json()[0][0]["score"],
'Positive': r.json()[0][1]["score"]}
},
'fill-mask': {
'inputs': inputs.Textbox(label="Input"),
'outputs': "label",
'preprocess': lambda x: {"inputs": x},
'postprocess': lambda r: {i["token_str"]: i["score"] for i in r}
'postprocess': lambda r: {i["token_str"]: i["score"] for i in r.json()}
},
'zero-shot-classification': {
'inputs': [inputs.Textbox(label="Input"),
inputs.Textbox(label="Possible class names ("
"comma-separated)"),
inputs.Checkbox(label="Allow multiple true classes")],
'outputs': "label",
'outputs': outputs.Label(label="Classification", type="confidences"),
'preprocess': lambda i, c, m: {"inputs": i, "parameters":
{"candidate_labels": c, "multi_class": m}},
'postprocess': lambda r: {r["labels"][i]: r["scores"][i] for i in
range(len(r["labels"]))}
'postprocess': lambda r: {r.json()["labels"][i]: r.json()["scores"][i] for i in
range(len(r.json()["labels"]))}
},
'automatic-speech-recognition': {
'inputs': inputs.Audio(label="Input", source="upload",
type="file"),
type="filepath"),
'outputs': outputs.Textbox(label="Output"),
'preprocess': lambda i: {"inputs": i},
'postprocess': lambda r: r["text"]
'postprocess': lambda r: r.json()["text"]
},
'image-classification': {
'inputs': inputs.Image(label="Input Image", type="file"),
'outputs': outputs.Label(label="Classification"),
'preprocess': lambda i: i,
'postprocess': lambda r: {i["label"].split(", ")[0]: i["score"] for
i in r}
'inputs': inputs.Image(label="Input Image", type="filepath"),
'outputs': outputs.Label(label="Classification", type="confidences"),
'preprocess': lambda i: base64.b64decode(i.split(",")[1]), # convert the base64 representation to binary
'postprocess': lambda r: {i["label"].split(", ")[0]: i["score"] for i in r.json()}
},
'feature-extraction': {
# example model: hf.co/julien-c/distilbert-feature-extraction
'inputs': inputs.Textbox(label="Input"),
'outputs': outputs.Dataframe(label="Output"),
'preprocess': lambda x: {"inputs": x},
'postprocess': lambda r: r[0],
'postprocess': lambda r: r.json()[0],
},
'sentence-similarity': {
# example model: hf.co/sentence-transformers/distilbert-base-nli-stsb-mean-tokens
@ -106,26 +110,26 @@ def get_huggingface_interface(model_name, api_key, alias):
inputs.Textbox(label="Source Sentence", default="That is a happy person"),
inputs.Textbox(lines=7, label="Sentences to compare to", placeholder="Separate each sentence by a newline"),
],
'outputs': outputs.Label(label="Classification"),
'outputs': outputs.Label(label="Classification", type="confidences"),
'preprocess': lambda src, sentences: {"inputs": {
"source_sentence": src,
"sentences": [s for s in sentences.splitlines() if s != ""],
}},
'postprocess': lambda r: { f"sentence {i}": v for i, v in enumerate(r) },
'postprocess': lambda r: { f"sentence {i}": v for i, v in enumerate(r.json()) },
},
'text-to-speech': {
# example model: hf.co/julien-c/ljspeech_tts_train_tacotron2_raw_phn_tacotron_g2p_en_no_space_train
'inputs': inputs.Textbox(label="Input"),
'outputs': outputs.Audio(label="Audio"),
'preprocess': lambda x: {"inputs": x},
'postprocess': post_process_binary_body,
'postprocess': lambda x: base64.b64encode(x.content).decode('utf-8'),
},
'text-to-image': {
# example model: hf.co/osanseviero/BigGAN-deep-128
'inputs': inputs.Textbox(label="Input"),
'outputs': outputs.Image(label="Output"),
'preprocess': lambda x: {"inputs": x},
'postprocess': post_process_binary_body,
'postprocess': lambda x: base64.b64encode(x.content).decode('utf-8'),
},
}
@ -134,23 +138,16 @@ def get_huggingface_interface(model_name, api_key, alias):
pipeline = pipelines[p]
def query_huggingface_api(*input):
payload = pipeline['preprocess'](*input)
if p == 'automatic-speech-recognition' or p == 'image-classification':
with open(input[0].name, "rb") as f:
data = f.read()
else:
payload.update({'options': {'wait_for_model': True}})
data = json.dumps(payload)
response = requests.request("POST", api_url, headers=headers, data=data)
if response.status_code == 200:
if p == 'text-to-speech' or p == 'text-to-image':
output = pipeline['postprocess'](response)
else:
result = response.json()
output = pipeline['postprocess'](result)
else:
def query_huggingface_api(*params):
# Convert to a list of input components
data = pipeline['preprocess'](*params)
if isinstance(data, dict): # HF doesn't allow additional parameters for binary files (e.g. images or audio files)
data.update({'options': {'wait_for_model': True}})
data = json.dumps(data)
response = requests.request("POST", api_url, headers=headers, data=data)
if not(response.status_code == 200):
raise ValueError("Could not complete request to HuggingFace API, Error {}".format(response.status_code))
output = pipeline['postprocess'](response)
return output
if alias is None:
@ -163,14 +160,14 @@ def get_huggingface_interface(model_name, api_key, alias):
'inputs': pipeline['inputs'],
'outputs': pipeline['outputs'],
'title': model_name,
# 'examples': pipeline['examples'],
'api_mode': True,
}
return interface_info
def load_interface(name, src=None, api_key=None, alias=None):
if src is None:
tokens = name.split("/")
tokens = name.split("/") # Separate the source (e.g. "huggingface") from the repo name (e.g. "google/vit-base-patch16-224")
assert len(tokens) > 1, "Either `src` parameter must be provided, or `name` must be formatted as \{src\}/\{repo name\}"
src = tokens[0]
name = "/".join(tokens[1:])
@ -182,20 +179,16 @@ def interface_params_from_config(config_dict):
## instantiate input component and output component
config_dict["inputs"] = [inputs.get_input_instance(component) for component in config_dict["input_components"]]
config_dict["outputs"] = [outputs.get_output_instance(component) for component in config_dict["output_components"]]
# remove preprocessing and postprocessing (since they'll be performed remotely)
for component in config_dict["inputs"]:
component.preprocess = lambda x:x
for component in config_dict["outputs"]:
component.postprocess = lambda x:x
# Remove keys that are not parameters to Interface() class
not_parameters = ("allow_embedding", "allow_interpretation", "avg_durations", "function_count",
"queue", "input_components", "output_components", "examples")
for key in not_parameters:
if key in config_dict:
del config_dict[key]
parameters = {
"allow_flagging", "allow_screenshot", "article", "description", "flagging_options", "inputs", "outputs",
"show_input", "show_output", "theme", "title"
}
config_dict = {k: config_dict[k] for k in parameters}
return config_dict
def get_spaces_interface(model_name, api_key, alias):
space_url = "https://huggingface.co/spaces/{}".format(model_name)
print("Fetching interface from: {}".format(space_url))
iframe_url = "https://huggingface.co/gradioiframe/{}/+".format(model_name)
api_url = "https://huggingface.co/gradioiframe/{}/api/predict/".format(model_name)
headers = {'Content-Type': 'application/json'}
@ -213,19 +206,20 @@ def get_spaces_interface(model_name, api_key, alias):
output = result["data"]
if len(interface_info["outputs"])==1: # if the fn is supposed to return a single value, pop it
output = output[0]
if len(interface_info["outputs"])==1 and isinstance(output, list): # Needed to support Output.Image() returning bounding boxes as well (TODO: handle different versions of gradio since they have slightly different APIs)
output = output[0]
return output
if alias is None:
fn.__name__ = model_name
else:
fn.__name__ = alias
fn.__name__ = alias if (alias is not None) else model_name
interface_info["fn"] = fn
interface_info["api_mode"] = True
return interface_info
repos = {
# for each repo, we have a method that returns the Interface given the model name & optionally an api_key
"huggingface": get_huggingface_interface,
"models": get_huggingface_interface,
"spaces": get_spaces_interface,
}

View File

@ -1,6 +1,6 @@
{
"files": {
"main.css": "/static/css/main.e82c4b43.css",
"main.css": "/static/css/main.ac4c682c.css",
"main.js": "/static/bundle.js",
"index.html": "/index.html",
"static/bundle.js.LICENSE.txt": "/static/bundle.js.LICENSE.txt",
@ -11,7 +11,7 @@
},
"entrypoints": [
"static/bundle.css",
"static/css/main.e82c4b43.css",
"static/css/main.ac4c682c.css",
"static/bundle.js"
]
}

View File

@ -8,4 +8,4 @@
window.config = {{ config|tojson }};
} catch (e) {
window.config = {};
}</script><script src="https://cdnjs.cloudflare.com/ajax/libs/iframe-resizer/4.3.1/iframeResizer.contentWindow.min.js"></script><title>Gradio</title><link href="static/bundle.css" rel="stylesheet"><link href="static/css/main.e82c4b43.css" rel="stylesheet"></head><body style="height:100%"><div id="root" style="height:100%"></div><script src="static/bundle.js"></script></body></html>
}</script><script src="https://cdnjs.cloudflare.com/ajax/libs/iframe-resizer/4.3.1/iframeResizer.contentWindow.min.js"></script><title>Gradio</title><link href="static/bundle.css" rel="stylesheet"><link href="static/css/main.ac4c682c.css" rel="stylesheet"></head><body style="height:100%"><div id="root" style="height:100%"></div><script src="static/bundle.js"></script></body></html>

View File

@ -32,6 +32,15 @@ class InputComponent(Component):
"""
return x
def serialize(self, x, called_directly):
"""
Convert from a human-readable version of the input (path of an image, URL of a video, etc.) into the interface to a serialized version (e.g. base64) to pass into an API. May do different things if the interface is called() vs. used via GUI.
Parameters:
x (Any): Input to interface
called_directly (bool): if true, the interface was called(), otherwise, it is being used via the GUI
"""
return x
def preprocess_example(self, x):
"""
Any preprocessing needed to be performed on an example before being passed to the main function.
@ -603,7 +612,7 @@ class Image(InputComponent):
invert_colors (bool): whether to invert the image as a preprocessing step.
source (str): Source of image. "upload" creates a box where user can drop an image file, "webcam" allows user to take snapshot from their webcam, "canvas" defaults to a white image that can be edited and drawn upon with tools.
tool (str): Tools used for editing. "editor" allows a full screen editor, "select" provides a cropping and zoom tool.
type (str): Type of value to be returned by component. "numpy" returns a numpy array with shape (width, height, 3) and values from 0 to 255, "pil" returns a PIL image object, "file" returns a temporary file object whose path can be retrieved by file_obj.name, "base64" leaves as a base64 string.
type (str): Type of value to be returned by component. "numpy" returns a numpy array with shape (width, height, 3) and values from 0 to 255, "pil" returns a PIL image object, "file" returns a temporary file object whose path can be retrieved by file_obj.name, "filepath" returns the path directly.
label (str): component name in interface.
optional (bool): If True, the interface can be submitted with no uploaded image, in which case the input value is None.
'''
@ -638,7 +647,7 @@ class Image(InputComponent):
}
def preprocess(self, x):
if x is None or self.type == "base64":
if x is None:
return x
im = processing_utils.decode_base64_to_image(x)
fmt = im.format
@ -653,18 +662,41 @@ class Image(InputComponent):
return im
elif self.type == "numpy":
return np.array(im)
elif self.type == "file":
elif self.type == "file" or self.type == "filepath":
file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=(
"."+fmt.lower() if fmt is not None else ".png"))
im.save(file_obj.name)
return file_obj
if self.type == "file":
warnings.warn(
"The 'file' type has been deprecated. Set parameter 'type' to 'filepath' instead.", DeprecationWarning)
return file_obj
else:
return file_obj.name
else:
raise ValueError("Unknown type: " + str(self.type) +
". Please choose from: 'numpy', 'pil', 'file'.")
". Please choose from: 'numpy', 'pil', 'filepath'.")
def preprocess_example(self, x):
return processing_utils.encode_file_to_base64(x)
def serialize(self, x, called_directly=False):
# if called directly, can assume it's a URL or filepath
if self.type == "filepath" or called_directly:
return processing_utils.encode_url_or_file_to_base64(x)
elif self.type == "file":
return processing_utils.encode_url_or_file_to_base64(x.name)
elif self.type == "numpy" or "pil":
if self.type == "numpy":
x = PIL.Image.fromarray(np.uint8(x)).convert('RGB')
fmt = x.format
file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=(
"."+fmt.lower() if fmt is not None else ".png"))
x.save(file_obj.name)
return processing_utils.encode_url_or_file_to_base64(file_obj.name)
else:
raise ValueError("Unknown type: " + str(self.type) +
". Please choose from: 'numpy', 'pil', 'filepath'.")
def set_interpret_parameters(self, segments=16):
"""
Calculates interpretation score of image subsections by splitting the image into subsections, then using a "leave one out" method to calculate the score of each subsection by whiting out the subsection and measuring the delta of the output value.
@ -813,6 +845,9 @@ class Video(InputComponent):
else:
return file_name
def serialize(self, x, called_directly):
raise NotImplementedError()
def preprocess_example(self, x):
return processing_utils.encode_file_to_base64(x)
@ -834,7 +869,7 @@ class Audio(InputComponent):
"""
Parameters:
source (str): Source of audio. "upload" creates a box where user can drop an audio file, "microphone" creates a microphone input.
type (str): Type of value to be returned by component. "numpy" returns a 2-set tuple with an integer sample_rate and the data numpy.array of shape (samples, 2), "file" returns a temporary file object whose path can be retrieved by file_obj.name.
type (str): Type of value to be returned by component. "numpy" returns a 2-set tuple with an integer sample_rate and the data numpy.array of shape (samples, 2), "file" returns a temporary file object whose path can be retrieved by file_obj.name, "filepath" returns the path directly.
label (str): component name in interface.
optional (bool): If True, the interface can be submitted with no uploaded audio, in which case the input value is None.
"""
@ -863,9 +898,6 @@ class Audio(InputComponent):
}
def preprocess(self, x):
"""
By default, no pre-processing is applied to a microphone input file
"""
if x is None:
return x
file_name, file_data, is_example = x["name"], x["data"], x["is_example"]
@ -875,13 +907,38 @@ class Audio(InputComponent):
file_obj = processing_utils.decode_base64_to_file(
file_data, file_path=file_name)
if self.type == "file":
warnings.warn(
"The 'file' type has been deprecated. Set parameter 'type' to 'filepath' instead.", DeprecationWarning)
return file_obj
elif self.type == "filepath":
return file_obj.name
elif self.type == "numpy":
return processing_utils.audio_from_file(file_obj.name)
else:
raise ValueError("Unknown type: " + str(self.type) +
". Please choose from: 'numpy', 'filepath'.")
def preprocess_example(self, x):
return processing_utils.encode_file_to_base64(x, type="audio")
def serialize(self, x, called_directly):
if self.type == "filepath" or called_directly:
name = x
elif self.type == "file":
warnings.warn(
"The 'file' type has been deprecated. Set parameter 'type' to 'filepath' instead.", DeprecationWarning)
name = x.name
elif self.type == "numpy":
file = tempfile.NamedTemporaryFile(delete=False)
name = file.name
processing_utils.audio_to_file(x[0], x[1], name)
else:
raise ValueError("Unknown type: " + str(self.type) +
". Please choose from: 'numpy', 'filepath'.")
file_data = processing_utils.encode_url_or_file_to_base64(name, type="audio")
return {"name": name, "data": file_data, "is_example": False}
def set_interpret_parameters(self, segments=8):
"""
Calculates interpretation score of audio subsections by splitting the audio into subsections, then using a "leave one out" method to calculate the score of each subsection by removing the subsection and measuring the delta of the output value.

View File

@ -71,7 +71,7 @@ class Interface:
title=None, description=None, article=None, thumbnail=None,
css=None, server_port=None, server_name=networking.LOCALHOST_NAME, height=500, width=900,
allow_screenshot=True, allow_flagging=True, flagging_options=None, encrypt=False,
show_tips=False, flagging_dir="flagged", analytics_enabled=True, enable_queue=False):
show_tips=False, flagging_dir="flagged", analytics_enabled=True, enable_queue=False, api_mode=False):
"""
Parameters:
fn (Callable): the function to wrap an interface around.
@ -100,6 +100,7 @@ class Interface:
flagging_dir (str): what to name the dir where flagged data is stored.
show_tips (bool): if True, will occasionally show tips about new Gradio features
enable_queue (bool): if True, inference requests will be served through a queue instead of with parallel threads. Required for longer inference times (> 1min) to prevent timeout.
api_mode (bool): If True, will skip preprocessing steps when the Interface is called() as a function (should remain False unless the Interface is loaded from an external repo)
"""
if not isinstance(fn, list):
fn = [fn]
@ -181,6 +182,7 @@ class Interface:
self.requires_permissions = any(
[component.requires_permissions for component in self.input_components])
self.enable_queue = enable_queue
self.api_mode = api_mode
data = {'fn': fn,
'inputs': inputs,
@ -216,10 +218,11 @@ class Interface:
pass # do not push analytics if no network
def __call__(self, *params):
output = [p(*params) for p in self.predict]
if len(output) == 1:
return output.pop() # if there's only one output, then don't return as list
return output
if self.api_mode: # skip the preprocessing/postprocessing if sending to a remote API
output = self.run_prediction(params, called_directly=True)
else:
output, _ = self.process(params)
return output[0] if len(output) == 1 else output
def __str__(self):
return self.__repr__()
@ -308,7 +311,10 @@ class Interface:
config["examples"] = self.examples
return config
def run_prediction(self, processed_input, return_duration=False):
def run_prediction(self, processed_input, return_duration=False, called_directly=False):
if self.api_mode: # Serialize the input
processed_input = [input_component.serialize(processed_input[i], called_directly)
for i, input_component in enumerate(self.input_components)]
predictions = []
durations = []
for predict_fn in self.predict:
@ -330,6 +336,10 @@ class Interface:
if len(self.output_components) == len(self.predict):
prediction = [prediction]
if self.api_mode: # Serialize the input
prediction = [output_component.deserialize(prediction[o])
for o, output_component in enumerate(self.output_components)]
durations.append(duration)
predictions.extend(prediction)
@ -348,8 +358,8 @@ class Interface:
for i, input_component in enumerate(self.input_components)]
predictions, durations = self.run_prediction(
processed_input, return_duration=True)
processed_output = [output_component.postprocess(
predictions[i]) if predictions[i] is not None else None for i, output_component in enumerate(self.output_components)]
processed_output = [output_component.postprocess(predictions[i]) if predictions[i] is not None else None
for i, output_component in enumerate(self.output_components)]
return processed_output, durations
def interpret(self, raw_input):

View File

@ -25,6 +25,8 @@ from gradio import encryptor
from gradio import queue
from functools import wraps
import io
import traceback
INITIAL_PORT_VALUE = int(os.getenv(
'GRADIO_SERVER_PORT', "7860")) # The http server will try to open on port 7860. If not available, 7861, 7862, etc.
@ -138,6 +140,7 @@ def static_resource(path):
return send_file(os.path.join(STATIC_PATH_LIB, path))
# TODO(@aliabid94): this throws a 500 error if app.auth is None (should probalbly just redirect to '/')
@app.route('/login', methods=["GET", "POST"])
def login():
if request.method == "GET":

View File

@ -31,6 +31,12 @@ class OutputComponent(Component):
"""
return y
def deserialize(self, x):
"""
Convert from serialized output (e.g. base64 representation) from a call() to the interface to a human-readable version of the output (path of an image, etc.)
"""
return x
class Textbox(OutputComponent):
'''
@ -115,6 +121,21 @@ class Label(OutputComponent):
raise ValueError("The `Label` output interface expects one of: a string label, or an int label, a "
"float label, or a dictionary whose keys are labels and values are confidences.")
def deserialize(self, y):
# 4 cases: (1): {'label': 'lion'}, {'label': 'lion', 'confidences':...}, {'lion': 0.46, ...}, 'lion'
if self.type == "label" or (self.type == "auto" and (isinstance(y, str) or ('label' in y and not('confidences' in y.keys())))):
if isinstance(y, str):
return y
else:
return y['label']
elif self.type == "confidences" or self.type == "auto":
if 'confidences' in y.keys() and isinstance(y['confidences'], list):
return {k['label']:k['confidence'] for k in y['confidences']}
else:
return y
raise ValueError("Unable to deserialize output: {}".format(y))
@classmethod
def get_shortcut_implementations(cls):
return {
@ -145,15 +166,13 @@ class Image(OutputComponent):
Demos: image_mod.py, webcam.py
'''
def __init__(self, type="auto", labeled_segments=False, plot=False, label=None):
def __init__(self, type="auto", plot=False, label=None):
'''
Parameters:
type (str): Type of value to be passed to component. "numpy" expects a numpy array with shape (width, height, 3), "pil" expects a PIL image object, "file" expects a file path to the saved image or a remote URL, "plot" expects a matplotlib.pyplot object, "auto" detects return type.
labeled_segments (bool): If True, expects a two-element tuple to be returned. The first element of the tuple is the image of format specified by type. The second element is a list of tuples, where each tuple represents a labeled segment within the image. The first element of the tuple is the string label of the segment, followed by 4 floats that represent the left-x, top-y, right-x, and bottom-y coordinates of the bounding box.
plot (bool): DEPRECATED. Whether to expect a plot to be returned by the function.
label (str): component name in interface.
'''
self.labeled_segments = labeled_segments
if plot:
warnings.warn(
"The 'plot' parameter has been deprecated. Set parameter 'type' to 'plot' instead.", DeprecationWarning)
@ -166,16 +185,11 @@ class Image(OutputComponent):
def get_shortcut_implementations(cls):
return {
"image": {},
"segmented_image": {"labeled_segments": True},
"plot": {"type": "plot"},
"pil": {"type": "pil"}
}
def postprocess(self, y):
if self.labeled_segments:
y, coordinates = y
else:
coordinates = []
if self.type == "auto":
if isinstance(y, np.ndarray):
dtype = "numpy"
@ -195,17 +209,17 @@ class Image(OutputComponent):
y = np.array(y)
out_y = processing_utils.encode_array_to_base64(y)
elif dtype == "file":
try:
requests.get(y)
out_y = processing_utils.encode_url_to_base64(y)
except requests.exceptions.MissingSchema:
out_y = processing_utils.encode_file_to_base64(y)
out_y = processing_utils.encode_url_or_file_to_base64(y)
elif dtype == "plot":
out_y = processing_utils.encode_plot_to_base64(y)
else:
raise ValueError("Unknown type: " + dtype +
". Please choose from: 'numpy', 'pil', 'file', 'plot'.")
return out_y, coordinates
return out_y
def deserialize(self, x):
y = processing_utils.decode_base64_to_file(x).name
return y
def save_flagged(self, dir, label, data, encryption_key):
"""
@ -253,6 +267,9 @@ class Video(OutputComponent):
"data": processing_utils.encode_file_to_base64(y, type="video")
}
def deserialize(self, x):
return processing_utils.decode_base64_to_file(x).name
def save_flagged(self, dir, label, data, encryption_key):
"""
Returns: (str) path to image file
@ -369,11 +386,14 @@ class Audio(OutputComponent):
file = tempfile.NamedTemporaryFile(prefix="sample", suffix=".wav", delete=False)
processing_utils.audio_to_file(sample_rate, data, file.name)
y = file.name
return processing_utils.encode_file_to_base64(y, type="audio", ext="wav")
return processing_utils.encode_url_or_file_to_base64(y, type="audio", ext="wav")
else:
raise ValueError("Unknown type: " + self.type +
". Please choose from: 'numpy', 'file'.")
def deserialize(self, x):
return processing_utils.decode_base64_to_file(x).name
def save_flagged(self, dir, label, data, encryption_key):
"""
Returns: (str) path to audio file

View File

@ -21,6 +21,22 @@ def decode_base64_to_image(encoding):
return Image.open(BytesIO(base64.b64decode(image_encoded)))
def get_url_or_file_as_bytes(path):
try:
return requests.get(path).content
except (requests.exceptions.MissingSchema, requests.exceptions.InvalidSchema):
with open(path, "rb") as f:
return f.read()
def encode_url_or_file_to_base64(path, type="image", ext=None, header=True):
try:
requests.get(path)
return encode_url_to_base64(path, type, ext, header)
except (requests.exceptions.MissingSchema, requests.exceptions.InvalidSchema):
return encode_file_to_base64(path, type, ext, header)
def encode_file_to_base64(f, type="image", ext=None, header=True):
with open(f, "rb") as file:
encoded_string = base64.b64encode(file.read())

View File

@ -11,8 +11,6 @@ from io import StringIO
import warnings
import paramiko
DEBUG_MODE = False
def handler(chan, host, port):
sock = socket.socket()
@ -55,8 +53,8 @@ def reverse_forward_tunnel(server_port, remote_host, remote_port, transport):
thr.start()
def verbose(s):
if DEBUG_MODE:
def verbose(s, debug_mode=False):
if debug_mode:
print(s)

View File

@ -65,12 +65,13 @@ def ipython_check():
Check if interface is launching from iPython (not colab)
:return is_ipython (bool): True or False
"""
is_ipython = False
try: # Check if running interactively using ipython.
from IPython import get_ipython
get_ipython()
is_ipython = True
if get_ipython() is not None:
is_ipython = True
except (ImportError, NameError):
is_ipython = False
pass
return is_ipython

View File

@ -18,7 +18,7 @@ GOLDEN_PATH = "test/golden/{}/{}.png"
TOLERANCE = 0.1
TIMEOUT = 10
GAP_TO_SCREENSHOT = 1
GAP_TO_SCREENSHOT = 2
def wait_for_url(url):
for i in range(TIMEOUT):

View File

@ -1,7 +1,12 @@
import unittest
import pathlib
import gradio as gr
class TestHuggingFaceModels(unittest.TestCase):
"""
WARNING: These tests have an external dependency: namely that Hugging Face's Hub and Space APIs do not change, and they keep their most famous models up. So if, e.g. Spaces is down, then these test will not pass.
"""
class TestHuggingFaceModelAPI(unittest.TestCase):
def test_text_generation(self):
model_type = "text_generation"
interface_info = gr.external.get_huggingface_interface("gpt2", api_key=None, alias=None)
@ -45,7 +50,8 @@ class TestHuggingFaceModels(unittest.TestCase):
self.assertIsInstance(interface_info["inputs"], gr.inputs.Textbox)
self.assertIsInstance(interface_info["outputs"], gr.outputs.Image)
class TestHuggingFaceSpaces(unittest.TestCase):
class TestHuggingFaceSpaceAPI(unittest.TestCase):
def test_english_to_spanish(self):
interface_info = gr.external.get_spaces_interface("abidlabs/english_to_spanish", api_key=None, alias=None)
self.assertIsInstance(interface_info["inputs"][0], gr.inputs.Textbox)
@ -63,7 +69,46 @@ class TestLoadInterface(unittest.TestCase):
self.assertIsInstance(interface_info["inputs"], gr.inputs.Textbox)
self.assertIsInstance(interface_info["outputs"], gr.outputs.Label)
def test_models_src(self):
interface_info = gr.external.load_interface("models/distilbert-base-uncased-finetuned-sst-2-english", alias="sentiment_classifier")
self.assertEqual(interface_info["fn"].__name__, "sentiment_classifier")
self.assertIsInstance(interface_info["inputs"], gr.inputs.Textbox)
self.assertIsInstance(interface_info["outputs"], gr.outputs.Label)
class TestCallingLoadInterface(unittest.TestCase):
def test_sentiment_model(self):
interface_info = gr.external.load_interface("models/distilbert-base-uncased-finetuned-sst-2-english", alias="sentiment_classifier")
io = gr.Interface(**interface_info)
output = io("I am happy, I love you.")
self.assertGreater(output['Positive'], 0.5)
def test_image_classification_model(self):
interface_info = gr.external.load_interface("models/google/vit-base-patch16-224")
io = gr.Interface(**interface_info)
output = io("test/images/lion.jpg")
self.assertGreater(output['lion'], 0.5)
def test_translation_model(self):
interface_info = gr.external.load_interface("models/t5-base")
io = gr.Interface(**interface_info)
output = io("My name is Sarah and I live in London")
self.assertEquals(output, 'Mein Name ist Sarah und ich lebe in London')
def test_numerical_to_label_space(self):
interface_info = gr.external.load_interface("spaces/abidlabs/titanic-survival")
io = gr.Interface(**interface_info)
output = io("male", 77, 10)
self.assertLess(output['Survives'], 0.5)
def test_image_to_image_space(self):
def assertIsFile(path):
if not pathlib.Path(path).resolve().is_file():
raise AssertionError("File does not exist: %s" % str(path))
interface_info = gr.external.load_interface("spaces/abidlabs/image-identity")
io = gr.Interface(**interface_info)
output = io("test/images/lion.jpg")
assertIsFile(output)
if __name__ == '__main__':
unittest.main()

View File

@ -83,7 +83,7 @@ class TestImage(unittest.TestCase):
open_and_rotate,
gr.inputs.Image(shape=(30, 10), type="file"),
"image")
output = iface.process([x_img])[0][0][0]
output = iface.process([x_img])[0][0]
self.assertEqual(gr.processing_utils.decode_base64_to_image(output).size, (10, 30))

99
test/test_networking.py Normal file
View File

@ -0,0 +1,99 @@
from gradio import networking
import gradio as gr
import unittest
import unittest.mock as mock
import ipaddress
import requests
import warnings
class TestUser(unittest.TestCase):
def test_id(self):
user = networking.User("test")
self.assertEqual(user.get_id(), "test")
def test_load_user(self):
user = networking.load_user("test")
self.assertEqual(user.get_id(), "test")
class TestIPAddress(unittest.TestCase):
def test_get_ip(self):
ip = networking.get_local_ip_address()
try: # check whether ip is valid
ipaddress.ip_address(ip)
except ValueError:
self.fail("Invalid IP address")
@mock.patch("requests.get")
def test_get_ip_without_internet(self, mock_get):
mock_get.side_effect = requests.ConnectionError()
ip = networking.get_local_ip_address()
self.assertEqual(ip, "No internet connection")
class TestPort(unittest.TestCase):
def test_port_is_in_range(self):
start = 7860
end = 7960
try:
port = networking.get_first_available_port(start, end)
self.assertTrue(start <= port <= end)
except OSError:
warnings.warn("Unable to test, no ports available")
def test_same_port_is_returned(self):
start = 7860
end = 7960
try:
port1 = networking.get_first_available_port(start, end)
port2 = networking.get_first_available_port(start, end)
self.assertEqual(port1, port2)
except OSError:
warnings.warn("Unable to test, no ports available")
class TestFlaskRoutes(unittest.TestCase):
def setUp(self) -> None:
self.io = gr.Interface(lambda x: x, "text", "text")
self.app, _, _ = self.io.launch(prevent_thread_lock=True)
self.client = self.app.test_client()
def test_get_main_route(self):
response = self.client.get('/')
self.assertEqual(response.status_code, 200)
def test_get_config_route(self):
response = self.client.get('/config/')
self.assertEqual(response.status_code, 200)
def test_get_static_route(self):
response = self.client.get('/static/bundle.css')
self.assertEqual(response.status_code, 302) # This should redirect to static files.
def tearDown(self) -> None:
self.io.close()
gr.reset_all()
class TestAuthenticatedFlaskRoutes(unittest.TestCase):
def setUp(self) -> None:
self.io = gr.Interface(lambda x: x, "text", "text")
self.app, _, _ = self.io.launch(auth=("test", "correct_password"), prevent_thread_lock=True)
self.client = self.app.test_client()
def test_get_login_route(self):
response = self.client.get('/login')
self.assertEqual(response.status_code, 200)
def test_post_login(self):
response = self.client.post('/login', data=dict(username="test", password="correct_password"))
self.assertEqual(response.status_code, 302)
response = self.client.post('/login', data=dict(username="test", password="incorrect_password"))
self.assertEqual(response.status_code, 401)
def tearDown(self) -> None:
self.io.close()
gr.reset_all()
if __name__ == '__main__':
unittest.main()

View File

@ -62,15 +62,15 @@ class TestImage(unittest.TestCase):
def test_as_component(self):
y_img = gr.processing_utils.decode_base64_to_image(gr.test_data.BASE64_IMAGE)
image_output = gr.outputs.Image()
self.assertTrue(image_output.postprocess(y_img)[0].startswith(""))
self.assertTrue(image_output.postprocess(np.array(y_img))[0].startswith(""))
self.assertTrue(image_output.postprocess(y_img).startswith(""))
self.assertTrue(image_output.postprocess(np.array(y_img)).startswith(""))
def test_in_interface(self):
def generate_noise(width, height):
return np.random.randint(0, 256, (width, height, 3))
iface = gr.Interface(generate_noise, ["slider", "slider"], "image")
self.assertTrue(iface.process([10, 20])[0][0][0].startswith("data:image/png;base64"))
self.assertTrue(iface.process([10, 20])[0][0].startswith("data:image/png;base64"))
class TestKeyValues(unittest.TestCase):
def test_in_interface(self):

33
test/test_tunneling.py Normal file
View File

@ -0,0 +1,33 @@
import io
import sys
import unittest
import unittest.mock as mock
from gradio import tunneling
# class TestTunneling(unittest.TestCase):
# pass
# @mock.patch("pkg_resources.require")
# def test_should_fail_with_distribution_not_found(self, mock_require):
class TestVerbose(unittest.TestCase):
"""Unncessary tests but just including them for the sake of completion."""
def setUp(self):
self.message = "print test"
self.capturedOutput = io.StringIO() # Create StringIO object
sys.stdout = self.capturedOutput # and redirect stdout.
def test_verbose_debug_true(self):
tunneling.verbose(self.message, debug_mode=True)
self.assertEqual(self.capturedOutput.getvalue().strip(), self.message)
def test_verbose_debug_false(self):
tunneling.verbose(self.message, debug_mode=False)
self.assertEqual(self.capturedOutput.getvalue().strip(), '')
def tearDown(self):
sys.stdout = sys.__stdout__
if __name__ == '__main__':
unittest.main()

View File

@ -3,6 +3,7 @@ import unittest
import pkg_resources
import unittest.mock as mock
import warnings
import requests
class TestUtils(unittest.TestCase):
@ -29,7 +30,7 @@ class TestUtils(unittest.TestCase):
@mock.patch("requests.get")
def test_should_warn_with_connection_error(self, mock_get):
mock_get.side_effect = ConnectionError()
mock_get.side_effect = requests.ConnectionError()
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
@ -45,7 +46,49 @@ class TestUtils(unittest.TestCase):
warnings.simplefilter("always")
version_check()
self.assertEqual(str(w[-1].message), "package URL does not contain version info.")
@mock.patch("requests.post")
def test_error_analytics_doesnt_crash_on_connection_error(self, mock_post):
mock_post.side_effect = requests.ConnectionError()
error_analytics("placeholder")
@mock.patch("requests.post")
def test_error_analytics_successful(self, mock_post):
error_analytics("placeholder")
@mock.patch("IPython.get_ipython")
@mock.patch("gradio.utils.error_analytics")
def test_colab_check_sends_analytics_on_import_fail(self, mock_error_analytics, mock_get_ipython):
mock_get_ipython.side_effect = ImportError()
colab_check()
mock_error_analytics.assert_called_with("NameError")
@mock.patch("IPython.get_ipython")
def test_colab_check_no_ipython(self, mock_get_ipython):
mock_get_ipython.return_value = None
assert colab_check() is False
@mock.patch("IPython.get_ipython")
def test_ipython_check_import_fail(self, mock_get_ipython):
mock_get_ipython.side_effect = ImportError()
assert ipython_check() is False
@mock.patch("IPython.get_ipython")
def test_ipython_check_no_ipython(self, mock_get_ipython):
mock_get_ipython.return_value = None
assert ipython_check() is False
@mock.patch("requests.get")
def test_readme_to_html_doesnt_crah_on_connection_error(self, mock_get):
mock_get.side_effect = requests.ConnectionError()
readme_to_html("placeholder")
def test_readme_to_html_correct_parse(self):
readme_to_html("https://github.com/gradio-app/gradio/blob/master/README.md")
if __name__ == '__main__':
unittest.main()