mirror of
https://github.com/gradio-app/gradio.git
synced 2025-03-31 12:20:26 +08:00
merge
This commit is contained in:
commit
bcf46b225f
@ -1,6 +1,7 @@
|
||||
version: 2.1
|
||||
orbs:
|
||||
codecov: codecov/codecov@3.1.1
|
||||
node: circleci/node@1.0.1
|
||||
jobs:
|
||||
build:
|
||||
docker:
|
||||
@ -26,6 +27,14 @@ jobs:
|
||||
key: deps1-{{ .Branch }}-{{ checksum "gradio.egg-info/requires.txt" }}
|
||||
paths:
|
||||
- "venv"
|
||||
- node/install
|
||||
- node/install-npm
|
||||
- run:
|
||||
name: Build frontend
|
||||
command: |
|
||||
cd frontend
|
||||
npm install
|
||||
npm run build
|
||||
- run:
|
||||
command: |
|
||||
mkdir screenshots
|
||||
|
8
demo/divide.py
Normal file
8
demo/divide.py
Normal file
@ -0,0 +1,8 @@
|
||||
import gradio as gr
|
||||
|
||||
def divide(num1):
|
||||
return num1/0
|
||||
|
||||
iface = gr.Interface(fn=divide, inputs="number", outputs="number")
|
||||
if __name__ == "__main__":
|
||||
iface.launch(debug=True)
|
@ -7,7 +7,7 @@ def image_mod(image):
|
||||
return image.rotate(45)
|
||||
|
||||
|
||||
iface = gr.Interface(image_mod, gr.inputs.Image(type="pil", optional=True), "image")
|
||||
iface = gr.Interface(image_mod, gr.inputs.Image(type="pil"), "image")
|
||||
|
||||
if __name__ == "__main__":
|
||||
iface.launch()
|
||||
|
@ -7,7 +7,7 @@ class ImageOutput extends BaseComponent {
|
||||
return this.props.value ? (
|
||||
<div className="output_image">
|
||||
<div class="image_preview_holder">
|
||||
<img class="image_preview" alt="" src={this.props.value[0]}></img>
|
||||
<img class="image_preview" alt="" src={this.props.value}></img>
|
||||
</div>
|
||||
</div>
|
||||
) : (
|
||||
|
@ -1,15 +1,15 @@
|
||||
Flask-Cors>=3.0.8
|
||||
Flask-Login
|
||||
Flask>=1.1.1
|
||||
analytics-python
|
||||
ffmpy
|
||||
flask-cachebuster
|
||||
markdown2
|
||||
matplotlib
|
||||
numpy
|
||||
pandas
|
||||
paramiko
|
||||
pillow
|
||||
pycryptodome
|
||||
pydub
|
||||
matplotlib
|
||||
pandas
|
||||
pillow
|
||||
ffmpy
|
||||
markdown2
|
||||
pycryptodome
|
||||
requests
|
||||
paramiko
|
||||
analytics-python
|
||||
Flask>=1.1.1
|
||||
Flask-Cors>=3.0.8
|
||||
flask-cachebuster
|
||||
Flask-Login
|
||||
|
@ -3,10 +3,14 @@ import tempfile
|
||||
import requests
|
||||
from gradio import inputs, outputs
|
||||
import re
|
||||
import base64
|
||||
|
||||
|
||||
def get_huggingface_interface(model_name, api_key, alias):
|
||||
model_url = "https://huggingface.co/{}".format(model_name)
|
||||
api_url = "https://api-inference.huggingface.co/models/{}".format(model_name)
|
||||
print("Fetching model from: {}".format(model_url))
|
||||
|
||||
if api_key is not None:
|
||||
headers = {"Authorization": f"Bearer {api_key}"}
|
||||
else:
|
||||
@ -17,6 +21,7 @@ def get_huggingface_interface(model_name, api_key, alias):
|
||||
assert response.status_code == 200, "Invalid model name or src"
|
||||
p = response.json().get('pipeline_tag')
|
||||
|
||||
# convert from binary to base64
|
||||
def post_process_binary_body(r: requests.Response):
|
||||
with tempfile.NamedTemporaryFile(delete=False) as fp:
|
||||
fp.write(r.content)
|
||||
@ -27,78 +32,77 @@ def get_huggingface_interface(model_name, api_key, alias):
|
||||
'inputs': [inputs.Textbox(label="Context", lines=7), inputs.Textbox(label="Question")],
|
||||
'outputs': [outputs.Textbox(label="Answer"), outputs.Label(label="Score")],
|
||||
'preprocess': lambda c, q: {"inputs": {"context": c, "question": q}},
|
||||
'postprocess': lambda r: (r["answer"], r["score"]),
|
||||
'postprocess': lambda r: (r.json()["answer"], r.json()["score"]),
|
||||
# 'examples': [['My name is Sarah and I live in London', 'Where do I live?']]
|
||||
},
|
||||
'text-generation': {
|
||||
'inputs': inputs.Textbox(label="Input"),
|
||||
'outputs': outputs.Textbox(label="Output"),
|
||||
'preprocess': lambda x: {"inputs": x},
|
||||
'postprocess': lambda r: r[0]["generated_text"],
|
||||
'postprocess': lambda r: r.json()[0]["generated_text"],
|
||||
# 'examples': [['My name is Clara and I am']]
|
||||
},
|
||||
'summarization': {
|
||||
'inputs': inputs.Textbox(label="Input"),
|
||||
'outputs': outputs.Textbox(label="Summary"),
|
||||
'preprocess': lambda x: {"inputs": x},
|
||||
'postprocess': lambda r: r[0]["summary_text"]
|
||||
'postprocess': lambda r: r.json()[0]["summary_text"]
|
||||
},
|
||||
'translation': {
|
||||
'inputs': inputs.Textbox(label="Input"),
|
||||
'outputs': outputs.Textbox(label="Translation"),
|
||||
'preprocess': lambda x: {"inputs": x},
|
||||
'postprocess': lambda r: r[0]["translation_text"]
|
||||
'postprocess': lambda r: r.json()[0]["translation_text"]
|
||||
},
|
||||
'text2text-generation': {
|
||||
'inputs': inputs.Textbox(label="Input"),
|
||||
'outputs': outputs.Textbox(label="Generated Text"),
|
||||
'preprocess': lambda x: {"inputs": x},
|
||||
'postprocess': lambda r: r[0]["generated_text"]
|
||||
'postprocess': lambda r: r.json()[0]["generated_text"]
|
||||
},
|
||||
'text-classification': {
|
||||
'inputs': inputs.Textbox(label="Input"),
|
||||
'outputs': outputs.Label(label="Classification"),
|
||||
'outputs': outputs.Label(label="Classification", type="confidences"),
|
||||
'preprocess': lambda x: {"inputs": x},
|
||||
'postprocess': lambda r: {'Negative': r[0][0]["score"],
|
||||
'Positive': r[0][1]["score"]}
|
||||
'postprocess': lambda r: {'Negative': r.json()[0][0]["score"],
|
||||
'Positive': r.json()[0][1]["score"]}
|
||||
},
|
||||
'fill-mask': {
|
||||
'inputs': inputs.Textbox(label="Input"),
|
||||
'outputs': "label",
|
||||
'preprocess': lambda x: {"inputs": x},
|
||||
'postprocess': lambda r: {i["token_str"]: i["score"] for i in r}
|
||||
'postprocess': lambda r: {i["token_str"]: i["score"] for i in r.json()}
|
||||
},
|
||||
'zero-shot-classification': {
|
||||
'inputs': [inputs.Textbox(label="Input"),
|
||||
inputs.Textbox(label="Possible class names ("
|
||||
"comma-separated)"),
|
||||
inputs.Checkbox(label="Allow multiple true classes")],
|
||||
'outputs': "label",
|
||||
'outputs': outputs.Label(label="Classification", type="confidences"),
|
||||
'preprocess': lambda i, c, m: {"inputs": i, "parameters":
|
||||
{"candidate_labels": c, "multi_class": m}},
|
||||
'postprocess': lambda r: {r["labels"][i]: r["scores"][i] for i in
|
||||
range(len(r["labels"]))}
|
||||
'postprocess': lambda r: {r.json()["labels"][i]: r.json()["scores"][i] for i in
|
||||
range(len(r.json()["labels"]))}
|
||||
},
|
||||
'automatic-speech-recognition': {
|
||||
'inputs': inputs.Audio(label="Input", source="upload",
|
||||
type="file"),
|
||||
type="filepath"),
|
||||
'outputs': outputs.Textbox(label="Output"),
|
||||
'preprocess': lambda i: {"inputs": i},
|
||||
'postprocess': lambda r: r["text"]
|
||||
'postprocess': lambda r: r.json()["text"]
|
||||
},
|
||||
'image-classification': {
|
||||
'inputs': inputs.Image(label="Input Image", type="file"),
|
||||
'outputs': outputs.Label(label="Classification"),
|
||||
'preprocess': lambda i: i,
|
||||
'postprocess': lambda r: {i["label"].split(", ")[0]: i["score"] for
|
||||
i in r}
|
||||
'inputs': inputs.Image(label="Input Image", type="filepath"),
|
||||
'outputs': outputs.Label(label="Classification", type="confidences"),
|
||||
'preprocess': lambda i: base64.b64decode(i.split(",")[1]), # convert the base64 representation to binary
|
||||
'postprocess': lambda r: {i["label"].split(", ")[0]: i["score"] for i in r.json()}
|
||||
},
|
||||
'feature-extraction': {
|
||||
# example model: hf.co/julien-c/distilbert-feature-extraction
|
||||
'inputs': inputs.Textbox(label="Input"),
|
||||
'outputs': outputs.Dataframe(label="Output"),
|
||||
'preprocess': lambda x: {"inputs": x},
|
||||
'postprocess': lambda r: r[0],
|
||||
'postprocess': lambda r: r.json()[0],
|
||||
},
|
||||
'sentence-similarity': {
|
||||
# example model: hf.co/sentence-transformers/distilbert-base-nli-stsb-mean-tokens
|
||||
@ -106,26 +110,26 @@ def get_huggingface_interface(model_name, api_key, alias):
|
||||
inputs.Textbox(label="Source Sentence", default="That is a happy person"),
|
||||
inputs.Textbox(lines=7, label="Sentences to compare to", placeholder="Separate each sentence by a newline"),
|
||||
],
|
||||
'outputs': outputs.Label(label="Classification"),
|
||||
'outputs': outputs.Label(label="Classification", type="confidences"),
|
||||
'preprocess': lambda src, sentences: {"inputs": {
|
||||
"source_sentence": src,
|
||||
"sentences": [s for s in sentences.splitlines() if s != ""],
|
||||
}},
|
||||
'postprocess': lambda r: { f"sentence {i}": v for i, v in enumerate(r) },
|
||||
'postprocess': lambda r: { f"sentence {i}": v for i, v in enumerate(r.json()) },
|
||||
},
|
||||
'text-to-speech': {
|
||||
# example model: hf.co/julien-c/ljspeech_tts_train_tacotron2_raw_phn_tacotron_g2p_en_no_space_train
|
||||
'inputs': inputs.Textbox(label="Input"),
|
||||
'outputs': outputs.Audio(label="Audio"),
|
||||
'preprocess': lambda x: {"inputs": x},
|
||||
'postprocess': post_process_binary_body,
|
||||
'postprocess': lambda x: base64.b64encode(x.content).decode('utf-8'),
|
||||
},
|
||||
'text-to-image': {
|
||||
# example model: hf.co/osanseviero/BigGAN-deep-128
|
||||
'inputs': inputs.Textbox(label="Input"),
|
||||
'outputs': outputs.Image(label="Output"),
|
||||
'preprocess': lambda x: {"inputs": x},
|
||||
'postprocess': post_process_binary_body,
|
||||
'postprocess': lambda x: base64.b64encode(x.content).decode('utf-8'),
|
||||
},
|
||||
}
|
||||
|
||||
@ -134,23 +138,16 @@ def get_huggingface_interface(model_name, api_key, alias):
|
||||
|
||||
pipeline = pipelines[p]
|
||||
|
||||
def query_huggingface_api(*input):
|
||||
payload = pipeline['preprocess'](*input)
|
||||
if p == 'automatic-speech-recognition' or p == 'image-classification':
|
||||
with open(input[0].name, "rb") as f:
|
||||
data = f.read()
|
||||
else:
|
||||
payload.update({'options': {'wait_for_model': True}})
|
||||
data = json.dumps(payload)
|
||||
response = requests.request("POST", api_url, headers=headers, data=data)
|
||||
if response.status_code == 200:
|
||||
if p == 'text-to-speech' or p == 'text-to-image':
|
||||
output = pipeline['postprocess'](response)
|
||||
else:
|
||||
result = response.json()
|
||||
output = pipeline['postprocess'](result)
|
||||
else:
|
||||
def query_huggingface_api(*params):
|
||||
# Convert to a list of input components
|
||||
data = pipeline['preprocess'](*params)
|
||||
if isinstance(data, dict): # HF doesn't allow additional parameters for binary files (e.g. images or audio files)
|
||||
data.update({'options': {'wait_for_model': True}})
|
||||
data = json.dumps(data)
|
||||
response = requests.request("POST", api_url, headers=headers, data=data)
|
||||
if not(response.status_code == 200):
|
||||
raise ValueError("Could not complete request to HuggingFace API, Error {}".format(response.status_code))
|
||||
output = pipeline['postprocess'](response)
|
||||
return output
|
||||
|
||||
if alias is None:
|
||||
@ -163,14 +160,14 @@ def get_huggingface_interface(model_name, api_key, alias):
|
||||
'inputs': pipeline['inputs'],
|
||||
'outputs': pipeline['outputs'],
|
||||
'title': model_name,
|
||||
# 'examples': pipeline['examples'],
|
||||
'api_mode': True,
|
||||
}
|
||||
|
||||
return interface_info
|
||||
|
||||
def load_interface(name, src=None, api_key=None, alias=None):
|
||||
if src is None:
|
||||
tokens = name.split("/")
|
||||
tokens = name.split("/") # Separate the source (e.g. "huggingface") from the repo name (e.g. "google/vit-base-patch16-224")
|
||||
assert len(tokens) > 1, "Either `src` parameter must be provided, or `name` must be formatted as \{src\}/\{repo name\}"
|
||||
src = tokens[0]
|
||||
name = "/".join(tokens[1:])
|
||||
@ -182,20 +179,16 @@ def interface_params_from_config(config_dict):
|
||||
## instantiate input component and output component
|
||||
config_dict["inputs"] = [inputs.get_input_instance(component) for component in config_dict["input_components"]]
|
||||
config_dict["outputs"] = [outputs.get_output_instance(component) for component in config_dict["output_components"]]
|
||||
# remove preprocessing and postprocessing (since they'll be performed remotely)
|
||||
for component in config_dict["inputs"]:
|
||||
component.preprocess = lambda x:x
|
||||
for component in config_dict["outputs"]:
|
||||
component.postprocess = lambda x:x
|
||||
# Remove keys that are not parameters to Interface() class
|
||||
not_parameters = ("allow_embedding", "allow_interpretation", "avg_durations", "function_count",
|
||||
"queue", "input_components", "output_components", "examples")
|
||||
for key in not_parameters:
|
||||
if key in config_dict:
|
||||
del config_dict[key]
|
||||
parameters = {
|
||||
"allow_flagging", "allow_screenshot", "article", "description", "flagging_options", "inputs", "outputs",
|
||||
"show_input", "show_output", "theme", "title"
|
||||
}
|
||||
config_dict = {k: config_dict[k] for k in parameters}
|
||||
return config_dict
|
||||
|
||||
def get_spaces_interface(model_name, api_key, alias):
|
||||
space_url = "https://huggingface.co/spaces/{}".format(model_name)
|
||||
print("Fetching interface from: {}".format(space_url))
|
||||
iframe_url = "https://huggingface.co/gradioiframe/{}/+".format(model_name)
|
||||
api_url = "https://huggingface.co/gradioiframe/{}/api/predict/".format(model_name)
|
||||
headers = {'Content-Type': 'application/json'}
|
||||
@ -213,19 +206,20 @@ def get_spaces_interface(model_name, api_key, alias):
|
||||
output = result["data"]
|
||||
if len(interface_info["outputs"])==1: # if the fn is supposed to return a single value, pop it
|
||||
output = output[0]
|
||||
if len(interface_info["outputs"])==1 and isinstance(output, list): # Needed to support Output.Image() returning bounding boxes as well (TODO: handle different versions of gradio since they have slightly different APIs)
|
||||
output = output[0]
|
||||
return output
|
||||
|
||||
if alias is None:
|
||||
fn.__name__ = model_name
|
||||
else:
|
||||
fn.__name__ = alias
|
||||
fn.__name__ = alias if (alias is not None) else model_name
|
||||
interface_info["fn"] = fn
|
||||
|
||||
interface_info["api_mode"] = True
|
||||
|
||||
return interface_info
|
||||
|
||||
repos = {
|
||||
# for each repo, we have a method that returns the Interface given the model name & optionally an api_key
|
||||
"huggingface": get_huggingface_interface,
|
||||
"models": get_huggingface_interface,
|
||||
"spaces": get_spaces_interface,
|
||||
}
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
{
|
||||
"files": {
|
||||
"main.css": "/static/css/main.e82c4b43.css",
|
||||
"main.css": "/static/css/main.ac4c682c.css",
|
||||
"main.js": "/static/bundle.js",
|
||||
"index.html": "/index.html",
|
||||
"static/bundle.js.LICENSE.txt": "/static/bundle.js.LICENSE.txt",
|
||||
@ -11,7 +11,7 @@
|
||||
},
|
||||
"entrypoints": [
|
||||
"static/bundle.css",
|
||||
"static/css/main.e82c4b43.css",
|
||||
"static/css/main.ac4c682c.css",
|
||||
"static/bundle.js"
|
||||
]
|
||||
}
|
@ -8,4 +8,4 @@
|
||||
window.config = {{ config|tojson }};
|
||||
} catch (e) {
|
||||
window.config = {};
|
||||
}</script><script src="https://cdnjs.cloudflare.com/ajax/libs/iframe-resizer/4.3.1/iframeResizer.contentWindow.min.js"></script><title>Gradio</title><link href="static/bundle.css" rel="stylesheet"><link href="static/css/main.e82c4b43.css" rel="stylesheet"></head><body style="height:100%"><div id="root" style="height:100%"></div><script src="static/bundle.js"></script></body></html>
|
||||
}</script><script src="https://cdnjs.cloudflare.com/ajax/libs/iframe-resizer/4.3.1/iframeResizer.contentWindow.min.js"></script><title>Gradio</title><link href="static/bundle.css" rel="stylesheet"><link href="static/css/main.ac4c682c.css" rel="stylesheet"></head><body style="height:100%"><div id="root" style="height:100%"></div><script src="static/bundle.js"></script></body></html>
|
@ -32,6 +32,15 @@ class InputComponent(Component):
|
||||
"""
|
||||
return x
|
||||
|
||||
def serialize(self, x, called_directly):
|
||||
"""
|
||||
Convert from a human-readable version of the input (path of an image, URL of a video, etc.) into the interface to a serialized version (e.g. base64) to pass into an API. May do different things if the interface is called() vs. used via GUI.
|
||||
Parameters:
|
||||
x (Any): Input to interface
|
||||
called_directly (bool): if true, the interface was called(), otherwise, it is being used via the GUI
|
||||
"""
|
||||
return x
|
||||
|
||||
def preprocess_example(self, x):
|
||||
"""
|
||||
Any preprocessing needed to be performed on an example before being passed to the main function.
|
||||
@ -603,7 +612,7 @@ class Image(InputComponent):
|
||||
invert_colors (bool): whether to invert the image as a preprocessing step.
|
||||
source (str): Source of image. "upload" creates a box where user can drop an image file, "webcam" allows user to take snapshot from their webcam, "canvas" defaults to a white image that can be edited and drawn upon with tools.
|
||||
tool (str): Tools used for editing. "editor" allows a full screen editor, "select" provides a cropping and zoom tool.
|
||||
type (str): Type of value to be returned by component. "numpy" returns a numpy array with shape (width, height, 3) and values from 0 to 255, "pil" returns a PIL image object, "file" returns a temporary file object whose path can be retrieved by file_obj.name, "base64" leaves as a base64 string.
|
||||
type (str): Type of value to be returned by component. "numpy" returns a numpy array with shape (width, height, 3) and values from 0 to 255, "pil" returns a PIL image object, "file" returns a temporary file object whose path can be retrieved by file_obj.name, "filepath" returns the path directly.
|
||||
label (str): component name in interface.
|
||||
optional (bool): If True, the interface can be submitted with no uploaded image, in which case the input value is None.
|
||||
'''
|
||||
@ -638,7 +647,7 @@ class Image(InputComponent):
|
||||
}
|
||||
|
||||
def preprocess(self, x):
|
||||
if x is None or self.type == "base64":
|
||||
if x is None:
|
||||
return x
|
||||
im = processing_utils.decode_base64_to_image(x)
|
||||
fmt = im.format
|
||||
@ -653,18 +662,41 @@ class Image(InputComponent):
|
||||
return im
|
||||
elif self.type == "numpy":
|
||||
return np.array(im)
|
||||
elif self.type == "file":
|
||||
elif self.type == "file" or self.type == "filepath":
|
||||
file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=(
|
||||
"."+fmt.lower() if fmt is not None else ".png"))
|
||||
im.save(file_obj.name)
|
||||
return file_obj
|
||||
if self.type == "file":
|
||||
warnings.warn(
|
||||
"The 'file' type has been deprecated. Set parameter 'type' to 'filepath' instead.", DeprecationWarning)
|
||||
return file_obj
|
||||
else:
|
||||
return file_obj.name
|
||||
else:
|
||||
raise ValueError("Unknown type: " + str(self.type) +
|
||||
". Please choose from: 'numpy', 'pil', 'file'.")
|
||||
". Please choose from: 'numpy', 'pil', 'filepath'.")
|
||||
|
||||
def preprocess_example(self, x):
|
||||
return processing_utils.encode_file_to_base64(x)
|
||||
|
||||
def serialize(self, x, called_directly=False):
|
||||
# if called directly, can assume it's a URL or filepath
|
||||
if self.type == "filepath" or called_directly:
|
||||
return processing_utils.encode_url_or_file_to_base64(x)
|
||||
elif self.type == "file":
|
||||
return processing_utils.encode_url_or_file_to_base64(x.name)
|
||||
elif self.type == "numpy" or "pil":
|
||||
if self.type == "numpy":
|
||||
x = PIL.Image.fromarray(np.uint8(x)).convert('RGB')
|
||||
fmt = x.format
|
||||
file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=(
|
||||
"."+fmt.lower() if fmt is not None else ".png"))
|
||||
x.save(file_obj.name)
|
||||
return processing_utils.encode_url_or_file_to_base64(file_obj.name)
|
||||
else:
|
||||
raise ValueError("Unknown type: " + str(self.type) +
|
||||
". Please choose from: 'numpy', 'pil', 'filepath'.")
|
||||
|
||||
def set_interpret_parameters(self, segments=16):
|
||||
"""
|
||||
Calculates interpretation score of image subsections by splitting the image into subsections, then using a "leave one out" method to calculate the score of each subsection by whiting out the subsection and measuring the delta of the output value.
|
||||
@ -813,6 +845,9 @@ class Video(InputComponent):
|
||||
else:
|
||||
return file_name
|
||||
|
||||
def serialize(self, x, called_directly):
|
||||
raise NotImplementedError()
|
||||
|
||||
def preprocess_example(self, x):
|
||||
return processing_utils.encode_file_to_base64(x)
|
||||
|
||||
@ -834,7 +869,7 @@ class Audio(InputComponent):
|
||||
"""
|
||||
Parameters:
|
||||
source (str): Source of audio. "upload" creates a box where user can drop an audio file, "microphone" creates a microphone input.
|
||||
type (str): Type of value to be returned by component. "numpy" returns a 2-set tuple with an integer sample_rate and the data numpy.array of shape (samples, 2), "file" returns a temporary file object whose path can be retrieved by file_obj.name.
|
||||
type (str): Type of value to be returned by component. "numpy" returns a 2-set tuple with an integer sample_rate and the data numpy.array of shape (samples, 2), "file" returns a temporary file object whose path can be retrieved by file_obj.name, "filepath" returns the path directly.
|
||||
label (str): component name in interface.
|
||||
optional (bool): If True, the interface can be submitted with no uploaded audio, in which case the input value is None.
|
||||
"""
|
||||
@ -863,9 +898,6 @@ class Audio(InputComponent):
|
||||
}
|
||||
|
||||
def preprocess(self, x):
|
||||
"""
|
||||
By default, no pre-processing is applied to a microphone input file
|
||||
"""
|
||||
if x is None:
|
||||
return x
|
||||
file_name, file_data, is_example = x["name"], x["data"], x["is_example"]
|
||||
@ -875,13 +907,38 @@ class Audio(InputComponent):
|
||||
file_obj = processing_utils.decode_base64_to_file(
|
||||
file_data, file_path=file_name)
|
||||
if self.type == "file":
|
||||
warnings.warn(
|
||||
"The 'file' type has been deprecated. Set parameter 'type' to 'filepath' instead.", DeprecationWarning)
|
||||
return file_obj
|
||||
elif self.type == "filepath":
|
||||
return file_obj.name
|
||||
elif self.type == "numpy":
|
||||
return processing_utils.audio_from_file(file_obj.name)
|
||||
else:
|
||||
raise ValueError("Unknown type: " + str(self.type) +
|
||||
". Please choose from: 'numpy', 'filepath'.")
|
||||
|
||||
def preprocess_example(self, x):
|
||||
return processing_utils.encode_file_to_base64(x, type="audio")
|
||||
|
||||
def serialize(self, x, called_directly):
|
||||
if self.type == "filepath" or called_directly:
|
||||
name = x
|
||||
elif self.type == "file":
|
||||
warnings.warn(
|
||||
"The 'file' type has been deprecated. Set parameter 'type' to 'filepath' instead.", DeprecationWarning)
|
||||
name = x.name
|
||||
elif self.type == "numpy":
|
||||
file = tempfile.NamedTemporaryFile(delete=False)
|
||||
name = file.name
|
||||
processing_utils.audio_to_file(x[0], x[1], name)
|
||||
else:
|
||||
raise ValueError("Unknown type: " + str(self.type) +
|
||||
". Please choose from: 'numpy', 'filepath'.")
|
||||
|
||||
file_data = processing_utils.encode_url_or_file_to_base64(name, type="audio")
|
||||
return {"name": name, "data": file_data, "is_example": False}
|
||||
|
||||
def set_interpret_parameters(self, segments=8):
|
||||
"""
|
||||
Calculates interpretation score of audio subsections by splitting the audio into subsections, then using a "leave one out" method to calculate the score of each subsection by removing the subsection and measuring the delta of the output value.
|
||||
|
@ -71,7 +71,7 @@ class Interface:
|
||||
title=None, description=None, article=None, thumbnail=None,
|
||||
css=None, server_port=None, server_name=networking.LOCALHOST_NAME, height=500, width=900,
|
||||
allow_screenshot=True, allow_flagging=True, flagging_options=None, encrypt=False,
|
||||
show_tips=False, flagging_dir="flagged", analytics_enabled=True, enable_queue=False):
|
||||
show_tips=False, flagging_dir="flagged", analytics_enabled=True, enable_queue=False, api_mode=False):
|
||||
"""
|
||||
Parameters:
|
||||
fn (Callable): the function to wrap an interface around.
|
||||
@ -100,6 +100,7 @@ class Interface:
|
||||
flagging_dir (str): what to name the dir where flagged data is stored.
|
||||
show_tips (bool): if True, will occasionally show tips about new Gradio features
|
||||
enable_queue (bool): if True, inference requests will be served through a queue instead of with parallel threads. Required for longer inference times (> 1min) to prevent timeout.
|
||||
api_mode (bool): If True, will skip preprocessing steps when the Interface is called() as a function (should remain False unless the Interface is loaded from an external repo)
|
||||
"""
|
||||
if not isinstance(fn, list):
|
||||
fn = [fn]
|
||||
@ -181,6 +182,7 @@ class Interface:
|
||||
self.requires_permissions = any(
|
||||
[component.requires_permissions for component in self.input_components])
|
||||
self.enable_queue = enable_queue
|
||||
self.api_mode = api_mode
|
||||
|
||||
data = {'fn': fn,
|
||||
'inputs': inputs,
|
||||
@ -216,10 +218,11 @@ class Interface:
|
||||
pass # do not push analytics if no network
|
||||
|
||||
def __call__(self, *params):
|
||||
output = [p(*params) for p in self.predict]
|
||||
if len(output) == 1:
|
||||
return output.pop() # if there's only one output, then don't return as list
|
||||
return output
|
||||
if self.api_mode: # skip the preprocessing/postprocessing if sending to a remote API
|
||||
output = self.run_prediction(params, called_directly=True)
|
||||
else:
|
||||
output, _ = self.process(params)
|
||||
return output[0] if len(output) == 1 else output
|
||||
|
||||
def __str__(self):
|
||||
return self.__repr__()
|
||||
@ -308,7 +311,10 @@ class Interface:
|
||||
config["examples"] = self.examples
|
||||
return config
|
||||
|
||||
def run_prediction(self, processed_input, return_duration=False):
|
||||
def run_prediction(self, processed_input, return_duration=False, called_directly=False):
|
||||
if self.api_mode: # Serialize the input
|
||||
processed_input = [input_component.serialize(processed_input[i], called_directly)
|
||||
for i, input_component in enumerate(self.input_components)]
|
||||
predictions = []
|
||||
durations = []
|
||||
for predict_fn in self.predict:
|
||||
@ -330,6 +336,10 @@ class Interface:
|
||||
if len(self.output_components) == len(self.predict):
|
||||
prediction = [prediction]
|
||||
|
||||
if self.api_mode: # Serialize the input
|
||||
prediction = [output_component.deserialize(prediction[o])
|
||||
for o, output_component in enumerate(self.output_components)]
|
||||
|
||||
durations.append(duration)
|
||||
predictions.extend(prediction)
|
||||
|
||||
@ -348,8 +358,8 @@ class Interface:
|
||||
for i, input_component in enumerate(self.input_components)]
|
||||
predictions, durations = self.run_prediction(
|
||||
processed_input, return_duration=True)
|
||||
processed_output = [output_component.postprocess(
|
||||
predictions[i]) if predictions[i] is not None else None for i, output_component in enumerate(self.output_components)]
|
||||
processed_output = [output_component.postprocess(predictions[i]) if predictions[i] is not None else None
|
||||
for i, output_component in enumerate(self.output_components)]
|
||||
return processed_output, durations
|
||||
|
||||
def interpret(self, raw_input):
|
||||
|
@ -25,6 +25,8 @@ from gradio import encryptor
|
||||
from gradio import queue
|
||||
from functools import wraps
|
||||
import io
|
||||
import traceback
|
||||
|
||||
|
||||
INITIAL_PORT_VALUE = int(os.getenv(
|
||||
'GRADIO_SERVER_PORT', "7860")) # The http server will try to open on port 7860. If not available, 7861, 7862, etc.
|
||||
@ -138,6 +140,7 @@ def static_resource(path):
|
||||
return send_file(os.path.join(STATIC_PATH_LIB, path))
|
||||
|
||||
|
||||
# TODO(@aliabid94): this throws a 500 error if app.auth is None (should probalbly just redirect to '/')
|
||||
@app.route('/login', methods=["GET", "POST"])
|
||||
def login():
|
||||
if request.method == "GET":
|
||||
|
@ -31,6 +31,12 @@ class OutputComponent(Component):
|
||||
"""
|
||||
return y
|
||||
|
||||
def deserialize(self, x):
|
||||
"""
|
||||
Convert from serialized output (e.g. base64 representation) from a call() to the interface to a human-readable version of the output (path of an image, etc.)
|
||||
"""
|
||||
return x
|
||||
|
||||
|
||||
class Textbox(OutputComponent):
|
||||
'''
|
||||
@ -115,6 +121,21 @@ class Label(OutputComponent):
|
||||
raise ValueError("The `Label` output interface expects one of: a string label, or an int label, a "
|
||||
"float label, or a dictionary whose keys are labels and values are confidences.")
|
||||
|
||||
def deserialize(self, y):
|
||||
# 4 cases: (1): {'label': 'lion'}, {'label': 'lion', 'confidences':...}, {'lion': 0.46, ...}, 'lion'
|
||||
if self.type == "label" or (self.type == "auto" and (isinstance(y, str) or ('label' in y and not('confidences' in y.keys())))):
|
||||
if isinstance(y, str):
|
||||
return y
|
||||
else:
|
||||
return y['label']
|
||||
elif self.type == "confidences" or self.type == "auto":
|
||||
if 'confidences' in y.keys() and isinstance(y['confidences'], list):
|
||||
return {k['label']:k['confidence'] for k in y['confidences']}
|
||||
else:
|
||||
return y
|
||||
raise ValueError("Unable to deserialize output: {}".format(y))
|
||||
|
||||
|
||||
@classmethod
|
||||
def get_shortcut_implementations(cls):
|
||||
return {
|
||||
@ -145,15 +166,13 @@ class Image(OutputComponent):
|
||||
Demos: image_mod.py, webcam.py
|
||||
'''
|
||||
|
||||
def __init__(self, type="auto", labeled_segments=False, plot=False, label=None):
|
||||
def __init__(self, type="auto", plot=False, label=None):
|
||||
'''
|
||||
Parameters:
|
||||
type (str): Type of value to be passed to component. "numpy" expects a numpy array with shape (width, height, 3), "pil" expects a PIL image object, "file" expects a file path to the saved image or a remote URL, "plot" expects a matplotlib.pyplot object, "auto" detects return type.
|
||||
labeled_segments (bool): If True, expects a two-element tuple to be returned. The first element of the tuple is the image of format specified by type. The second element is a list of tuples, where each tuple represents a labeled segment within the image. The first element of the tuple is the string label of the segment, followed by 4 floats that represent the left-x, top-y, right-x, and bottom-y coordinates of the bounding box.
|
||||
plot (bool): DEPRECATED. Whether to expect a plot to be returned by the function.
|
||||
label (str): component name in interface.
|
||||
'''
|
||||
self.labeled_segments = labeled_segments
|
||||
if plot:
|
||||
warnings.warn(
|
||||
"The 'plot' parameter has been deprecated. Set parameter 'type' to 'plot' instead.", DeprecationWarning)
|
||||
@ -166,16 +185,11 @@ class Image(OutputComponent):
|
||||
def get_shortcut_implementations(cls):
|
||||
return {
|
||||
"image": {},
|
||||
"segmented_image": {"labeled_segments": True},
|
||||
"plot": {"type": "plot"},
|
||||
"pil": {"type": "pil"}
|
||||
}
|
||||
|
||||
def postprocess(self, y):
|
||||
if self.labeled_segments:
|
||||
y, coordinates = y
|
||||
else:
|
||||
coordinates = []
|
||||
if self.type == "auto":
|
||||
if isinstance(y, np.ndarray):
|
||||
dtype = "numpy"
|
||||
@ -195,17 +209,17 @@ class Image(OutputComponent):
|
||||
y = np.array(y)
|
||||
out_y = processing_utils.encode_array_to_base64(y)
|
||||
elif dtype == "file":
|
||||
try:
|
||||
requests.get(y)
|
||||
out_y = processing_utils.encode_url_to_base64(y)
|
||||
except requests.exceptions.MissingSchema:
|
||||
out_y = processing_utils.encode_file_to_base64(y)
|
||||
out_y = processing_utils.encode_url_or_file_to_base64(y)
|
||||
elif dtype == "plot":
|
||||
out_y = processing_utils.encode_plot_to_base64(y)
|
||||
else:
|
||||
raise ValueError("Unknown type: " + dtype +
|
||||
". Please choose from: 'numpy', 'pil', 'file', 'plot'.")
|
||||
return out_y, coordinates
|
||||
return out_y
|
||||
|
||||
def deserialize(self, x):
|
||||
y = processing_utils.decode_base64_to_file(x).name
|
||||
return y
|
||||
|
||||
def save_flagged(self, dir, label, data, encryption_key):
|
||||
"""
|
||||
@ -253,6 +267,9 @@ class Video(OutputComponent):
|
||||
"data": processing_utils.encode_file_to_base64(y, type="video")
|
||||
}
|
||||
|
||||
def deserialize(self, x):
|
||||
return processing_utils.decode_base64_to_file(x).name
|
||||
|
||||
def save_flagged(self, dir, label, data, encryption_key):
|
||||
"""
|
||||
Returns: (str) path to image file
|
||||
@ -369,11 +386,14 @@ class Audio(OutputComponent):
|
||||
file = tempfile.NamedTemporaryFile(prefix="sample", suffix=".wav", delete=False)
|
||||
processing_utils.audio_to_file(sample_rate, data, file.name)
|
||||
y = file.name
|
||||
return processing_utils.encode_file_to_base64(y, type="audio", ext="wav")
|
||||
return processing_utils.encode_url_or_file_to_base64(y, type="audio", ext="wav")
|
||||
else:
|
||||
raise ValueError("Unknown type: " + self.type +
|
||||
". Please choose from: 'numpy', 'file'.")
|
||||
|
||||
def deserialize(self, x):
|
||||
return processing_utils.decode_base64_to_file(x).name
|
||||
|
||||
def save_flagged(self, dir, label, data, encryption_key):
|
||||
"""
|
||||
Returns: (str) path to audio file
|
||||
|
@ -21,6 +21,22 @@ def decode_base64_to_image(encoding):
|
||||
return Image.open(BytesIO(base64.b64decode(image_encoded)))
|
||||
|
||||
|
||||
def get_url_or_file_as_bytes(path):
|
||||
try:
|
||||
return requests.get(path).content
|
||||
except (requests.exceptions.MissingSchema, requests.exceptions.InvalidSchema):
|
||||
with open(path, "rb") as f:
|
||||
return f.read()
|
||||
|
||||
|
||||
def encode_url_or_file_to_base64(path, type="image", ext=None, header=True):
|
||||
try:
|
||||
requests.get(path)
|
||||
return encode_url_to_base64(path, type, ext, header)
|
||||
except (requests.exceptions.MissingSchema, requests.exceptions.InvalidSchema):
|
||||
return encode_file_to_base64(path, type, ext, header)
|
||||
|
||||
|
||||
def encode_file_to_base64(f, type="image", ext=None, header=True):
|
||||
with open(f, "rb") as file:
|
||||
encoded_string = base64.b64encode(file.read())
|
||||
|
@ -11,8 +11,6 @@ from io import StringIO
|
||||
import warnings
|
||||
import paramiko
|
||||
|
||||
DEBUG_MODE = False
|
||||
|
||||
|
||||
def handler(chan, host, port):
|
||||
sock = socket.socket()
|
||||
@ -55,8 +53,8 @@ def reverse_forward_tunnel(server_port, remote_host, remote_port, transport):
|
||||
thr.start()
|
||||
|
||||
|
||||
def verbose(s):
|
||||
if DEBUG_MODE:
|
||||
def verbose(s, debug_mode=False):
|
||||
if debug_mode:
|
||||
print(s)
|
||||
|
||||
|
||||
|
@ -65,12 +65,13 @@ def ipython_check():
|
||||
Check if interface is launching from iPython (not colab)
|
||||
:return is_ipython (bool): True or False
|
||||
"""
|
||||
is_ipython = False
|
||||
try: # Check if running interactively using ipython.
|
||||
from IPython import get_ipython
|
||||
get_ipython()
|
||||
is_ipython = True
|
||||
if get_ipython() is not None:
|
||||
is_ipython = True
|
||||
except (ImportError, NameError):
|
||||
is_ipython = False
|
||||
pass
|
||||
return is_ipython
|
||||
|
||||
|
||||
|
@ -18,7 +18,7 @@ GOLDEN_PATH = "test/golden/{}/{}.png"
|
||||
TOLERANCE = 0.1
|
||||
TIMEOUT = 10
|
||||
|
||||
GAP_TO_SCREENSHOT = 1
|
||||
GAP_TO_SCREENSHOT = 2
|
||||
|
||||
def wait_for_url(url):
|
||||
for i in range(TIMEOUT):
|
||||
|
@ -1,7 +1,12 @@
|
||||
import unittest
|
||||
import pathlib
|
||||
import gradio as gr
|
||||
|
||||
class TestHuggingFaceModels(unittest.TestCase):
|
||||
"""
|
||||
WARNING: These tests have an external dependency: namely that Hugging Face's Hub and Space APIs do not change, and they keep their most famous models up. So if, e.g. Spaces is down, then these test will not pass.
|
||||
"""
|
||||
|
||||
class TestHuggingFaceModelAPI(unittest.TestCase):
|
||||
def test_text_generation(self):
|
||||
model_type = "text_generation"
|
||||
interface_info = gr.external.get_huggingface_interface("gpt2", api_key=None, alias=None)
|
||||
@ -45,7 +50,8 @@ class TestHuggingFaceModels(unittest.TestCase):
|
||||
self.assertIsInstance(interface_info["inputs"], gr.inputs.Textbox)
|
||||
self.assertIsInstance(interface_info["outputs"], gr.outputs.Image)
|
||||
|
||||
class TestHuggingFaceSpaces(unittest.TestCase):
|
||||
|
||||
class TestHuggingFaceSpaceAPI(unittest.TestCase):
|
||||
def test_english_to_spanish(self):
|
||||
interface_info = gr.external.get_spaces_interface("abidlabs/english_to_spanish", api_key=None, alias=None)
|
||||
self.assertIsInstance(interface_info["inputs"][0], gr.inputs.Textbox)
|
||||
@ -63,7 +69,46 @@ class TestLoadInterface(unittest.TestCase):
|
||||
self.assertIsInstance(interface_info["inputs"], gr.inputs.Textbox)
|
||||
self.assertIsInstance(interface_info["outputs"], gr.outputs.Label)
|
||||
|
||||
def test_models_src(self):
|
||||
interface_info = gr.external.load_interface("models/distilbert-base-uncased-finetuned-sst-2-english", alias="sentiment_classifier")
|
||||
self.assertEqual(interface_info["fn"].__name__, "sentiment_classifier")
|
||||
self.assertIsInstance(interface_info["inputs"], gr.inputs.Textbox)
|
||||
self.assertIsInstance(interface_info["outputs"], gr.outputs.Label)
|
||||
|
||||
class TestCallingLoadInterface(unittest.TestCase):
|
||||
def test_sentiment_model(self):
|
||||
interface_info = gr.external.load_interface("models/distilbert-base-uncased-finetuned-sst-2-english", alias="sentiment_classifier")
|
||||
io = gr.Interface(**interface_info)
|
||||
output = io("I am happy, I love you.")
|
||||
self.assertGreater(output['Positive'], 0.5)
|
||||
|
||||
def test_image_classification_model(self):
|
||||
interface_info = gr.external.load_interface("models/google/vit-base-patch16-224")
|
||||
io = gr.Interface(**interface_info)
|
||||
output = io("test/images/lion.jpg")
|
||||
self.assertGreater(output['lion'], 0.5)
|
||||
|
||||
def test_translation_model(self):
|
||||
interface_info = gr.external.load_interface("models/t5-base")
|
||||
io = gr.Interface(**interface_info)
|
||||
output = io("My name is Sarah and I live in London")
|
||||
self.assertEquals(output, 'Mein Name ist Sarah und ich lebe in London')
|
||||
|
||||
def test_numerical_to_label_space(self):
|
||||
interface_info = gr.external.load_interface("spaces/abidlabs/titanic-survival")
|
||||
io = gr.Interface(**interface_info)
|
||||
output = io("male", 77, 10)
|
||||
self.assertLess(output['Survives'], 0.5)
|
||||
|
||||
def test_image_to_image_space(self):
|
||||
def assertIsFile(path):
|
||||
if not pathlib.Path(path).resolve().is_file():
|
||||
raise AssertionError("File does not exist: %s" % str(path))
|
||||
|
||||
interface_info = gr.external.load_interface("spaces/abidlabs/image-identity")
|
||||
io = gr.Interface(**interface_info)
|
||||
output = io("test/images/lion.jpg")
|
||||
assertIsFile(output)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
@ -83,7 +83,7 @@ class TestImage(unittest.TestCase):
|
||||
open_and_rotate,
|
||||
gr.inputs.Image(shape=(30, 10), type="file"),
|
||||
"image")
|
||||
output = iface.process([x_img])[0][0][0]
|
||||
output = iface.process([x_img])[0][0]
|
||||
self.assertEqual(gr.processing_utils.decode_base64_to_image(output).size, (10, 30))
|
||||
|
||||
|
||||
|
99
test/test_networking.py
Normal file
99
test/test_networking.py
Normal file
@ -0,0 +1,99 @@
|
||||
from gradio import networking
|
||||
import gradio as gr
|
||||
import unittest
|
||||
import unittest.mock as mock
|
||||
import ipaddress
|
||||
import requests
|
||||
import warnings
|
||||
|
||||
|
||||
class TestUser(unittest.TestCase):
|
||||
def test_id(self):
|
||||
user = networking.User("test")
|
||||
self.assertEqual(user.get_id(), "test")
|
||||
|
||||
def test_load_user(self):
|
||||
user = networking.load_user("test")
|
||||
self.assertEqual(user.get_id(), "test")
|
||||
|
||||
class TestIPAddress(unittest.TestCase):
|
||||
def test_get_ip(self):
|
||||
ip = networking.get_local_ip_address()
|
||||
try: # check whether ip is valid
|
||||
ipaddress.ip_address(ip)
|
||||
except ValueError:
|
||||
self.fail("Invalid IP address")
|
||||
|
||||
@mock.patch("requests.get")
|
||||
def test_get_ip_without_internet(self, mock_get):
|
||||
mock_get.side_effect = requests.ConnectionError()
|
||||
ip = networking.get_local_ip_address()
|
||||
self.assertEqual(ip, "No internet connection")
|
||||
|
||||
class TestPort(unittest.TestCase):
|
||||
def test_port_is_in_range(self):
|
||||
start = 7860
|
||||
end = 7960
|
||||
try:
|
||||
port = networking.get_first_available_port(start, end)
|
||||
self.assertTrue(start <= port <= end)
|
||||
except OSError:
|
||||
warnings.warn("Unable to test, no ports available")
|
||||
|
||||
def test_same_port_is_returned(self):
|
||||
start = 7860
|
||||
end = 7960
|
||||
try:
|
||||
port1 = networking.get_first_available_port(start, end)
|
||||
port2 = networking.get_first_available_port(start, end)
|
||||
self.assertEqual(port1, port2)
|
||||
except OSError:
|
||||
warnings.warn("Unable to test, no ports available")
|
||||
|
||||
|
||||
class TestFlaskRoutes(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.io = gr.Interface(lambda x: x, "text", "text")
|
||||
self.app, _, _ = self.io.launch(prevent_thread_lock=True)
|
||||
self.client = self.app.test_client()
|
||||
|
||||
def test_get_main_route(self):
|
||||
response = self.client.get('/')
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
def test_get_config_route(self):
|
||||
response = self.client.get('/config/')
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
def test_get_static_route(self):
|
||||
response = self.client.get('/static/bundle.css')
|
||||
self.assertEqual(response.status_code, 302) # This should redirect to static files.
|
||||
|
||||
def tearDown(self) -> None:
|
||||
self.io.close()
|
||||
gr.reset_all()
|
||||
|
||||
|
||||
class TestAuthenticatedFlaskRoutes(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.io = gr.Interface(lambda x: x, "text", "text")
|
||||
self.app, _, _ = self.io.launch(auth=("test", "correct_password"), prevent_thread_lock=True)
|
||||
self.client = self.app.test_client()
|
||||
|
||||
def test_get_login_route(self):
|
||||
response = self.client.get('/login')
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
def test_post_login(self):
|
||||
response = self.client.post('/login', data=dict(username="test", password="correct_password"))
|
||||
self.assertEqual(response.status_code, 302)
|
||||
response = self.client.post('/login', data=dict(username="test", password="incorrect_password"))
|
||||
self.assertEqual(response.status_code, 401)
|
||||
|
||||
def tearDown(self) -> None:
|
||||
self.io.close()
|
||||
gr.reset_all()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
@ -62,15 +62,15 @@ class TestImage(unittest.TestCase):
|
||||
def test_as_component(self):
|
||||
y_img = gr.processing_utils.decode_base64_to_image(gr.test_data.BASE64_IMAGE)
|
||||
image_output = gr.outputs.Image()
|
||||
self.assertTrue(image_output.postprocess(y_img)[0].startswith(""))
|
||||
self.assertTrue(image_output.postprocess(np.array(y_img))[0].startswith(""))
|
||||
self.assertTrue(image_output.postprocess(y_img).startswith(""))
|
||||
self.assertTrue(image_output.postprocess(np.array(y_img)).startswith(""))
|
||||
|
||||
def test_in_interface(self):
|
||||
def generate_noise(width, height):
|
||||
return np.random.randint(0, 256, (width, height, 3))
|
||||
|
||||
iface = gr.Interface(generate_noise, ["slider", "slider"], "image")
|
||||
self.assertTrue(iface.process([10, 20])[0][0][0].startswith("data:image/png;base64"))
|
||||
self.assertTrue(iface.process([10, 20])[0][0].startswith("data:image/png;base64"))
|
||||
|
||||
class TestKeyValues(unittest.TestCase):
|
||||
def test_in_interface(self):
|
||||
|
33
test/test_tunneling.py
Normal file
33
test/test_tunneling.py
Normal file
@ -0,0 +1,33 @@
|
||||
import io
|
||||
import sys
|
||||
import unittest
|
||||
import unittest.mock as mock
|
||||
from gradio import tunneling
|
||||
|
||||
|
||||
# class TestTunneling(unittest.TestCase):
|
||||
# pass
|
||||
# @mock.patch("pkg_resources.require")
|
||||
# def test_should_fail_with_distribution_not_found(self, mock_require):
|
||||
|
||||
class TestVerbose(unittest.TestCase):
|
||||
"""Unncessary tests but just including them for the sake of completion."""
|
||||
|
||||
def setUp(self):
|
||||
self.message = "print test"
|
||||
self.capturedOutput = io.StringIO() # Create StringIO object
|
||||
sys.stdout = self.capturedOutput # and redirect stdout.
|
||||
|
||||
def test_verbose_debug_true(self):
|
||||
tunneling.verbose(self.message, debug_mode=True)
|
||||
self.assertEqual(self.capturedOutput.getvalue().strip(), self.message)
|
||||
|
||||
def test_verbose_debug_false(self):
|
||||
tunneling.verbose(self.message, debug_mode=False)
|
||||
self.assertEqual(self.capturedOutput.getvalue().strip(), '')
|
||||
|
||||
def tearDown(self):
|
||||
sys.stdout = sys.__stdout__
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
@ -3,6 +3,7 @@ import unittest
|
||||
import pkg_resources
|
||||
import unittest.mock as mock
|
||||
import warnings
|
||||
import requests
|
||||
|
||||
|
||||
class TestUtils(unittest.TestCase):
|
||||
@ -29,7 +30,7 @@ class TestUtils(unittest.TestCase):
|
||||
@mock.patch("requests.get")
|
||||
def test_should_warn_with_connection_error(self, mock_get):
|
||||
|
||||
mock_get.side_effect = ConnectionError()
|
||||
mock_get.side_effect = requests.ConnectionError()
|
||||
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter("always")
|
||||
@ -45,7 +46,49 @@ class TestUtils(unittest.TestCase):
|
||||
warnings.simplefilter("always")
|
||||
version_check()
|
||||
self.assertEqual(str(w[-1].message), "package URL does not contain version info.")
|
||||
|
||||
|
||||
@mock.patch("requests.post")
|
||||
def test_error_analytics_doesnt_crash_on_connection_error(self, mock_post):
|
||||
|
||||
mock_post.side_effect = requests.ConnectionError()
|
||||
error_analytics("placeholder")
|
||||
|
||||
@mock.patch("requests.post")
|
||||
def test_error_analytics_successful(self, mock_post):
|
||||
error_analytics("placeholder")
|
||||
|
||||
|
||||
@mock.patch("IPython.get_ipython")
|
||||
@mock.patch("gradio.utils.error_analytics")
|
||||
def test_colab_check_sends_analytics_on_import_fail(self, mock_error_analytics, mock_get_ipython):
|
||||
|
||||
mock_get_ipython.side_effect = ImportError()
|
||||
colab_check()
|
||||
mock_error_analytics.assert_called_with("NameError")
|
||||
|
||||
@mock.patch("IPython.get_ipython")
|
||||
def test_colab_check_no_ipython(self, mock_get_ipython):
|
||||
mock_get_ipython.return_value = None
|
||||
assert colab_check() is False
|
||||
|
||||
@mock.patch("IPython.get_ipython")
|
||||
def test_ipython_check_import_fail(self, mock_get_ipython):
|
||||
mock_get_ipython.side_effect = ImportError()
|
||||
assert ipython_check() is False
|
||||
|
||||
@mock.patch("IPython.get_ipython")
|
||||
def test_ipython_check_no_ipython(self, mock_get_ipython):
|
||||
mock_get_ipython.return_value = None
|
||||
assert ipython_check() is False
|
||||
|
||||
@mock.patch("requests.get")
|
||||
def test_readme_to_html_doesnt_crah_on_connection_error(self, mock_get):
|
||||
mock_get.side_effect = requests.ConnectionError()
|
||||
readme_to_html("placeholder")
|
||||
|
||||
def test_readme_to_html_correct_parse(self):
|
||||
readme_to_html("https://github.com/gradio-app/gradio/blob/master/README.md")
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user