add testing and flask changes

This commit is contained in:
Ali Abid 2020-09-21 11:51:39 -07:00
parent 7dc1a83ce3
commit 498a615ccf
20 changed files with 371 additions and 356 deletions

3
.gitignore vendored
View File

@ -22,4 +22,5 @@ dist/*
docs.json
*.bak
demo/tmp.zip
demo/flagged
demo/flagged
test.txt

View File

@ -450,7 +450,7 @@ class Dataframe(InputComponent):
else:
return pd.DataFrame(x)
if self.col_count == 1:
x = x[0]
x = [row[0] for row in x]
if self.type == "numpy":
return np.array(x)
elif self.type == "array":

View File

@ -4,13 +4,11 @@ interface using the input and output types.
"""
import tempfile
import webbrowser
# import webbrowser
from gradio.inputs import InputComponent
from gradio.outputs import OutputComponent
from gradio import networking, strings, utils
from distutils.version import StrictVersion
import pkg_resources
import requests
import random
import time
@ -21,7 +19,6 @@ import weakref
import analytics
import os
PKG_VERSION_URL = "https://gradio.app/api/pkg-version"
analytics.write_key = "uxIFddIEuuUcFLf9VgH2teTEtPlWdkNy"
analytics_url = 'https://api.gradio.app/'
try:
@ -46,7 +43,7 @@ class Interface:
def __init__(self, fn, inputs, outputs, verbose=False, examples=None,
live=False, show_input=True, show_output=True,
capture_session=False, title=None, description=None,
capture_session=False, explain_by=None, title=None, description=None,
thumbnail=None, server_port=None, server_name=networking.LOCALHOST_NAME,
allow_screenshot=True, allow_flagging=True,
flagging_dir="flagged", analytics_enabled=True):
@ -110,6 +107,7 @@ class Interface:
self.show_output = show_output
self.flag_hash = random.getrandbits(32)
self.capture_session = capture_session
self.explain_by = explain_by
self.session = None
self.server_name = server_name
self.title = title
@ -177,7 +175,8 @@ class Interface:
"description": self.description,
"thumbnail": self.thumbnail,
"allow_screenshot": self.allow_screenshot,
"allow_flagging": self.allow_flagging
"allow_flagging": self.allow_flagging,
"allow_interpretation": self.explain_by is not None
}
try:
param_names = inspect.getfullargspec(self.predict[0])[0]
@ -190,7 +189,14 @@ class Interface:
iface[1]["label"] = ret_name
except ValueError:
pass
processed_examples = []
if self.examples is not None:
for example_set in self.examples:
processed_set = []
for iface, example in zip(self.input_interfaces, example_set):
processed_set.append(iface.process_example(example))
processed_examples.append(processed_set)
config["examples"] = processed_examples
return config
def process(self, raw_input, predict_fn=None):
@ -210,7 +216,7 @@ class Interface:
durations = []
for predict_fn in self.predict:
start = time.time()
if self.capture_session and not (self.session is None):
if self.capture_session and self.session is not None:
graph, sess = self.session
with graph.as_default():
with sess.as_default():
@ -277,33 +283,22 @@ class Interface:
share (bool): whether to create a publicly shareable link from your computer for the interface.
debug (bool): if True, and the interface was launched from Google Colab, prints the errors in the cell output.
Returns
httpd (str): HTTPServer object
app (flask.Flask): Flask app object
path_to_local_server (str): Locally accessible link
share_url (str): Publicly accessible link (if share=True)
"""
output_directory = tempfile.mkdtemp()
# Set up a port to serve the directory containing the static files with interface.
server_port, httpd, thread = networking.start_simple_server(
self, output_directory, self.server_name, server_port=self.server_port)
path_to_local_server = "http://{}:{}/".format(self.server_name, server_port)
networking.build_template(output_directory)
config = self.get_config_file()
networking.set_config(config)
networking.set_meta_tags(self.title, self.description, self.thumbnail)
server_port, app, thread = networking.start_server(
self, self.server_port)
path_to_local_server = "http://{}:{}/".format(self.server_name, server_port)
self.server_port = server_port
self.status = "RUNNING"
self.simple_server = httpd
try:
current_pkg_version = pkg_resources.require("gradio")[0].version
latest_pkg_version = requests.get(url=PKG_VERSION_URL).json()["version"]
if StrictVersion(latest_pkg_version) > StrictVersion(current_pkg_version):
print("IMPORTANT: You are using gradio version {}, "
"however version {} "
"is available, please upgrade.".format(
current_pkg_version, latest_pkg_version))
print('--------')
except: # TODO(abidlabs): don't catch all exceptions
pass
self.server = app
utils.version_check()
is_colab = utils.colab_check()
if not is_colab:
if not networking.url_ok(path_to_local_server):
@ -373,20 +368,7 @@ class Interface:
else:
display(IFrame(path_to_local_server, width=1000, height=500))
config = self.get_config_file()
config["share_url"] = share_url
processed_examples = []
if self.examples is not None:
for example_set in self.examples:
processed_set = []
for iface, example in zip(self.input_interfaces, example_set):
processed_set.append(iface.process_example(example))
processed_examples.append(processed_set)
config["examples"] = processed_examples
networking.set_config(config, output_directory)
networking.set_meta_tags(output_directory, self.title, self.description, self.thumbnail)
r = requests.get(path_to_local_server + "enable_sharing/" + (share_url or "None"))
if debug:
while True:
@ -394,14 +376,15 @@ class Interface:
time.sleep(0.1)
launch_method = 'browser' if inbrowser else 'inline'
data = {'launch_method': launch_method,
if self.analytics_enabled:
data = {
'launch_method': launch_method,
'is_google_colab': is_colab,
'is_sharing_on': share,
'share_url': share_url,
'ip_address': ip_address
}
if self.analytics_enabled:
}
try:
requests.post(analytics_url + 'gradio-launched-analytics/',
data=data)
@ -412,7 +395,8 @@ class Interface:
if not is_in_interactive_mode:
self.run_until_interrupted(thread, path_to_local_server)
return httpd, path_to_local_server, share_url
return app, path_to_local_server, share_url
def reset_all():

View File

@ -83,6 +83,13 @@ def gradio():
def config():
return jsonify(app.app_globals["config"])
@app.route("/enable_sharing/<path:path>", methods=["GET"])
def enable_sharing(path):
if path == "None":
path = None
app.app_globals["config"]["share_url"] = path
return jsonify(success=True)
@app.route("/api/predict/", methods=["POST"])
def predict():
raw_input = request.json["data"]
@ -128,14 +135,14 @@ def file(path):
return send_file(os.path.join(os.getcwd(), path))
def start_server(interface, directory_to_serve=None, server_name=None, server_port=None):
def start_server(interface, server_port=None):
if server_port is None:
server_port = INITIAL_PORT_VALUE
port = get_first_available_port(
server_port, server_port + TRY_NUM_PORTS
)
app.interface = interface
process = Process(target=app.run, kwargs={"port": port, "debug": True})
process = Process(target=app.run, kwargs={"port": port})
process.start()
return port, app, process

View File

@ -68,8 +68,6 @@ class Label(OutputComponent):
Output type: Union[Dict[str, float], str, int, float]
'''
LABEL_KEY = "label"
CONFIDENCE_KEY = "confidence"
CONFIDENCES_KEY = "confidences"
def __init__(self, num_top_classes=None, type="auto", label=None):
@ -85,7 +83,7 @@ class Label(OutputComponent):
def postprocess(self, y):
if self.type == "label" or (self.type == "auto" and (isinstance(y, str) or isinstance(y, Number))):
return {self.LABEL_KEY: str(y)}
return {"label": str(y)}
elif self.type == "confidences" or (self.type == "auto" and isinstance(y, dict)):
sorted_pred = sorted(
y.items(),
@ -95,11 +93,11 @@ class Label(OutputComponent):
if self.num_top_classes is not None:
sorted_pred = sorted_pred[:self.num_top_classes]
return {
self.LABEL_KEY: sorted_pred[0][0],
self.CONFIDENCES_KEY: [
"label": sorted_pred[0][0],
"confidences": [
{
self.LABEL_KEY: pred[0],
self.CONFIDENCE_KEY: pred[1]
"label": pred[0],
"confidence": pred[1]
} for pred in sorted_pred
]
}

File diff suppressed because one or more lines are too long

View File

@ -1,7 +1,22 @@
import requests
import pkg_resources
from distutils.version import StrictVersion
from IPython import get_ipython
analytics_url = 'https://api.gradio.app/'
PKG_VERSION_URL = "https://gradio.app/api/pkg-version"
def version_check():
try:
current_pkg_version = pkg_resources.require("gradio")[0].version
latest_pkg_version = requests.get(url=PKG_VERSION_URL).json()["version"]
if StrictVersion(latest_pkg_version) > StrictVersion(current_pkg_version):
print("IMPORTANT: You are using gradio version {}, "
"however version {} "
"is available, please upgrade.".format(
current_pkg_version, latest_pkg_version))
print('--------')
except: # TODO(abidlabs): don't catch all exceptions
pass
def error_analytics(type):
"""

View File

@ -11,5 +11,5 @@ def reverse_audio(audio):
io = gr.Interface(reverse_audio, "microphone", "audio")
# io.test_launch()
io.test_launch()
io.launch()

View File

@ -1,8 +1,8 @@
# Demo: (Slider, Dropdown, Radio, CheckboxGroup, Checkbox) -> (Textbox)
print("0 -->")
import gradio as gr
def sentence_builder(quantity, animal, place, activity_list, morning):
return f"""The {quantity} {animal}s went to the {place} where they {" and ".join(activity_list)} until the {"morning" if morning else "night"}"""
@ -24,5 +24,8 @@ io = gr.Interface(
[8, "cat", "zoo", ["ate"], True],
])
io.test_launch()
# io.test_launch()
a = 1
print("start -->", a)
a += 1
io.launch()

View File

@ -119,5 +119,5 @@ gradio/static/js/vendor/webcam.min.js
gradio/static/js/vendor/white-theme.js
gradio/templates/index.html
test/test_inputs.py
test/test_networking.py
test/test_interfaces.py
test/test_outputs.py

View File

@ -450,7 +450,7 @@ class Dataframe(InputComponent):
else:
return pd.DataFrame(x)
if self.col_count == 1:
x = x[0]
x = [row[0] for row in x]
if self.type == "numpy":
return np.array(x)
elif self.type == "array":

View File

@ -4,19 +4,11 @@ interface using the input and output types.
"""
import tempfile
import webbrowser
# import webbrowser
from gradio.inputs import InputComponent
from gradio.inputs import Image
from gradio.inputs import Textbox
from gradio.outputs import OutputComponent
from gradio import networking, strings, utils, processing_utils
from distutils.version import StrictVersion
from skimage.segmentation import slic
from skimage.util import img_as_float
from gradio import processing_utils
import PIL
import pkg_resources
from gradio import networking, strings, utils
import requests
import random
import time
@ -26,9 +18,7 @@ import sys
import weakref
import analytics
import os
import numpy as np
PKG_VERSION_URL = "https://gradio.app/api/pkg-version"
analytics.write_key = "uxIFddIEuuUcFLf9VgH2teTEtPlWdkNy"
analytics_url = 'https://api.gradio.app/'
try:
@ -199,6 +189,14 @@ class Interface:
iface[1]["label"] = ret_name
except ValueError:
pass
processed_examples = []
if self.examples is not None:
for example_set in self.examples:
processed_set = []
for iface, example in zip(self.input_interfaces, example_set):
processed_set.append(iface.process_example(example))
processed_examples.append(processed_set)
config["examples"] = processed_examples
return config
def process(self, raw_input, predict_fn=None):
@ -285,33 +283,22 @@ class Interface:
share (bool): whether to create a publicly shareable link from your computer for the interface.
debug (bool): if True, and the interface was launched from Google Colab, prints the errors in the cell output.
Returns
httpd (str): HTTPServer object
app (flask.Flask): Flask app object
path_to_local_server (str): Locally accessible link
share_url (str): Publicly accessible link (if share=True)
"""
output_directory = tempfile.mkdtemp()
# Set up a port to serve the directory containing the static files with interface.
server_port, httpd, thread = networking.start_simple_server(
self, self.server_name, server_port=self.server_port)
path_to_local_server = "http://{}:{}/".format(self.server_name, server_port)
networking.build_template(output_directory)
config = self.get_config_file()
networking.set_config(config)
networking.set_meta_tags(self.title, self.description, self.thumbnail)
server_port, app, thread = networking.start_server(
self, self.server_port)
path_to_local_server = "http://{}:{}/".format(self.server_name, server_port)
self.server_port = server_port
self.status = "RUNNING"
self.simple_server = httpd
try:
current_pkg_version = pkg_resources.require("gradio")[0].version
latest_pkg_version = requests.get(url=PKG_VERSION_URL).json()["version"]
if StrictVersion(latest_pkg_version) > StrictVersion(current_pkg_version):
print("IMPORTANT: You are using gradio version {}, "
"however version {} "
"is available, please upgrade.".format(
current_pkg_version, latest_pkg_version))
print('--------')
except: # TODO(abidlabs): don't catch all exceptions
pass
self.server = app
utils.version_check()
is_colab = utils.colab_check()
if not is_colab:
if not networking.url_ok(path_to_local_server):
@ -381,20 +368,7 @@ class Interface:
else:
display(IFrame(path_to_local_server, width=1000, height=500))
config = self.get_config_file()
config["share_url"] = share_url
processed_examples = []
if self.examples is not None:
for example_set in self.examples:
processed_set = []
for iface, example in zip(self.input_interfaces, example_set):
processed_set.append(iface.process_example(example))
processed_examples.append(processed_set)
config["examples"] = processed_examples
networking.set_config(config, output_directory)
networking.set_meta_tags(output_directory, self.title, self.description, self.thumbnail)
r = requests.get(path_to_local_server + "enable_sharing/" + (share_url or "None"))
if debug:
while True:
@ -402,14 +376,15 @@ class Interface:
time.sleep(0.1)
launch_method = 'browser' if inbrowser else 'inline'
data = {'launch_method': launch_method,
if self.analytics_enabled:
data = {
'launch_method': launch_method,
'is_google_colab': is_colab,
'is_sharing_on': share,
'share_url': share_url,
'ip_address': ip_address
}
if self.analytics_enabled:
}
try:
requests.post(analytics_url + 'gradio-launched-analytics/',
data=data)
@ -420,96 +395,9 @@ class Interface:
if not is_in_interactive_mode:
self.run_until_interrupted(thread, path_to_local_server)
return httpd, path_to_local_server, share_url
def tokenize_text(self, text):
leave_one_out_tokens = []
tokens = text.split()
for idx, _ in enumerate(tokens):
new_token_array = tokens.copy()
del new_token_array[idx]
leave_one_out_tokens.append(new_token_array)
return tokens, leave_one_out_tokens
return app, path_to_local_server, share_url
def tokenize_image(self, image):
image = np.array(processing_utils.decode_base64_to_image(image))
segments_slic = slic(image, n_segments=20, compactness=10, sigma=1)
leave_one_out_tokens = []
for (i, segVal) in enumerate(np.unique(segments_slic)):
mask = segments_slic == segVal
white_screen = np.copy(image)
white_screen[segments_slic == segVal] = 255
leave_one_out_tokens.append((mask, white_screen))
return leave_one_out_tokens
def score_text(self, tokens, leave_one_out_tokens, text):
original_label = ""
original_confidence = 0
tokens = text.split()
input_text = " ".join(tokens)
original_output = self.process([input_text])
output = {result["label"] : result["confidence"]
for result in original_output[0][0]['confidences']}
original_label = original_output[0][0]["label"]
original_confidence = output[original_label]
scores = []
for idx, input_text in enumerate(leave_one_out_tokens):
input_text = " ".join(input_text)
raw_output = self.process([input_text])
output = {result["label"] : result["confidence"]
for result in raw_output[0][0]['confidences']}
scores.append(original_confidence - output[original_label])
scores_by_char = []
for idx, token in enumerate(tokens):
if idx != 0:
scores_by_char.append((" ", 0))
for char in token:
scores_by_char.append((char, scores[idx]))
return scores_by_char
def score_image(self, leave_one_out_tokens, image):
original_output = self.process([image])
output = {result["label"] : result["confidence"]
for result in original_output[0][0]['confidences']}
original_label = original_output[0][0]["label"]
original_confidence = output[original_label]
image_interface = self.input_interfaces[0]
shape = processing_utils.decode_base64_to_image(image).size
output_scores = np.full((shape[1], shape[0]), 0.0)
for mask, input_image in leave_one_out_tokens:
input_image_base64 = processing_utils.encode_array_to_base64(
input_image)
raw_output = self.process([input_image_base64])
output = {result["label"] : result["confidence"]
for result in raw_output[0][0]['confidences']}
score = original_confidence - output[original_label]
output_scores += score * mask
max_val = np.max(np.abs(output_scores))
if max_val > 0:
output_scores = output_scores / max_val
return output_scores.tolist()
def simple_explanation(self, x):
if isinstance(self.input_interfaces[0], Textbox):
tokens, leave_one_out_tokens = self.tokenize_text(x[0])
return [self.score_text(tokens, leave_one_out_tokens, x[0])]
elif isinstance(self.input_interfaces[0], Image):
leave_one_out_tokens = self.tokenize_image(x[0])
return [self.score_image(leave_one_out_tokens, x[0])]
else:
print("Not valid input type")
def explain(self, x):
if self.explain_by == "default":
return self.simple_explanation(x)
else:
preprocessed_x = [input_interface(x_i) for x_i, input_interface in zip(x, self.input_interfaces)]
return self.explain_by(*preprocessed_x)
def reset_all():
for io in Interface.get_instances():

View File

@ -83,6 +83,13 @@ def gradio():
def config():
return jsonify(app.app_globals["config"])
@app.route("/enable_sharing/<path:path>", methods=["GET"])
def enable_sharing(path):
if path == "None":
path = None
app.app_globals["config"]["share_url"] = path
return jsonify(success=True)
@app.route("/api/predict/", methods=["POST"])
def predict():
raw_input = request.json["data"]
@ -128,14 +135,14 @@ def file(path):
return send_file(os.path.join(os.getcwd(), path))
def start_server(interface, server_name=None, server_port=None):
def start_server(interface, server_port=None):
if server_port is None:
server_port = INITIAL_PORT_VALUE
port = get_first_available_port(
server_port, server_port + TRY_NUM_PORTS
)
app.interface = interface
process = Process(target=app.run, kwargs={"port": port, "debug": True})
process = Process(target=app.run, kwargs={"port": port})
process.start()
return port, app, process

View File

@ -68,8 +68,6 @@ class Label(OutputComponent):
Output type: Union[Dict[str, float], str, int, float]
'''
LABEL_KEY = "label"
CONFIDENCE_KEY = "confidence"
CONFIDENCES_KEY = "confidences"
def __init__(self, num_top_classes=None, type="auto", label=None):
@ -85,7 +83,7 @@ class Label(OutputComponent):
def postprocess(self, y):
if self.type == "label" or (self.type == "auto" and (isinstance(y, str) or isinstance(y, Number))):
return {self.LABEL_KEY: str(y)}
return {"label": str(y)}
elif self.type == "confidences" or (self.type == "auto" and isinstance(y, dict)):
sorted_pred = sorted(
y.items(),
@ -95,11 +93,11 @@ class Label(OutputComponent):
if self.num_top_classes is not None:
sorted_pred = sorted_pred[:self.num_top_classes]
return {
self.LABEL_KEY: sorted_pred[0][0],
self.CONFIDENCES_KEY: [
"label": sorted_pred[0][0],
"confidences": [
{
self.LABEL_KEY: pred[0],
self.CONFIDENCE_KEY: pred[1]
"label": pred[0],
"confidence": pred[1]
} for pred in sorted_pred
]
}

File diff suppressed because one or more lines are too long

View File

@ -1,7 +1,22 @@
import requests
import pkg_resources
from distutils.version import StrictVersion
from IPython import get_ipython
analytics_url = 'https://api.gradio.app/'
PKG_VERSION_URL = "https://gradio.app/api/pkg-version"
def version_check():
try:
current_pkg_version = pkg_resources.require("gradio")[0].version
latest_pkg_version = requests.get(url=PKG_VERSION_URL).json()["version"]
if StrictVersion(latest_pkg_version) > StrictVersion(current_pkg_version):
print("IMPORTANT: You are using gradio version {}, "
"however version {} "
"is available, please upgrade.".format(
current_pkg_version, latest_pkg_version))
print('--------')
except: # TODO(abidlabs): don't catch all exceptions
pass
def error_analytics(type):
"""

View File

@ -2,69 +2,71 @@ import unittest
import gradio as gr
import PIL
import numpy as np
import scipy
import os
class TestTextbox(unittest.TestCase):
def test_interface(self):
def test_in_interface(self):
iface = gr.Interface(lambda x: x[::-1], "textbox", "textbox")
assert iface.process(["Hello"])[0] == ["olleH"]
self.assertEqual(iface.process(["Hello"])[0], ["olleH"])
iface = gr.Interface(lambda x: x*x, "number", "number")
assert iface.process(["5"])[0] == [25]
self.assertEqual(iface.process(["5"])[0], [25])
class TestSlider(unittest.TestCase):
def test_interface(self):
def test_in_interface(self):
iface = gr.Interface(lambda x: str(x) + " cats", "slider", "textbox")
assert iface.process([4])[0] == ["4 cats"]
self.assertEqual(iface.process([4])[0], ["4 cats"])
class TestCheckbox(unittest.TestCase):
def test_interface(self):
def test_in_interface(self):
iface = gr.Interface(lambda x: "yes" if x else "no", "checkbox", "textbox")
assert iface.process([False])[0] == ["no"]
self.assertEqual(iface.process([False])[0], ["no"])
class TestCheckboxGroup(unittest.TestCase):
def test_interface(self):
def test_in_interface(self):
checkboxes = gr.inputs.CheckboxGroup(["a", "b", "c"])
iface = gr.Interface(lambda x: "|".join(x), checkboxes, "textbox")
assert iface.process([["a", "c"]])[0] == ["a|c"]
assert iface.process([[]])[0] == [""]
self.assertEqual(iface.process([["a", "c"]])[0], ["a|c"])
self.assertEqual(iface.process([[]])[0], [""])
checkboxes = gr.inputs.CheckboxGroup(["a", "b", "c"], type="index")
iface = gr.Interface(lambda x: "|".join(map(str, x)), checkboxes, "textbox")
assert iface.process([["a", "c"]])[0] == ["0|2"]
self.assertEqual(iface.process([["a", "c"]])[0], ["0|2"])
class TestRadio(unittest.TestCase):
def test_interface(self):
def test_in_interface(self):
radio = gr.inputs.Radio(["a", "b", "c"])
iface = gr.Interface(lambda x: 2 * x, radio, "textbox")
assert iface.process(["c"])[0] == ["cc"]
self.assertEqual(iface.process(["c"])[0], ["cc"])
radio = gr.inputs.Radio(["a", "b", "c"], type="index")
iface = gr.Interface(lambda x: 2 * x, radio, "number")
assert iface.process(["c"])[0] == [4]
self.assertEqual(iface.process(["c"])[0], [4])
class TestDropdown(unittest.TestCase):
def test_interface(self):
def test_in_interface(self):
dropdown = gr.inputs.Dropdown(["a", "b", "c"])
iface = gr.Interface(lambda x: 2 * x, dropdown, "textbox")
assert iface.process(["c"])[0] == ["cc"]
self.assertEqual(iface.process(["c"])[0], ["cc"])
dropdown = gr.inputs.Dropdown(["a", "b", "c"], type="index")
iface = gr.Interface(lambda x: 2 * x, dropdown, "number")
assert iface.process(["c"])[0] == [4]
self.assertEqual(iface.process(["c"])[0], [4])
class TestImage(unittest.TestCase):
def test_component(self):
def test_as_component(self):
x_img = gr.test_data.BASE64_IMAGE
image_input = gr.inputs.Image()
assert image_input.preprocess(x_img).shape == (68, 61 ,3)
self.assertEqual(image_input.preprocess(x_img).shape, (68, 61 ,3))
image_input = gr.inputs.Image(image_mode="L", shape=(25, 25))
assert image_input.preprocess(x_img).shape == (25, 25)
self.assertEqual(image_input.preprocess(x_img).shape, (25, 25))
image_input = gr.inputs.Image(shape=(30, 10), type="pil")
assert image_input.preprocess(x_img).size == (30, 10)
self.assertEqual(image_input.preprocess(x_img).size, (30, 10))
def test_interface(self):
def test_in_interface(self):
x_img = gr.test_data.BASE64_IMAGE
def open_and_rotate(img_file):
@ -76,31 +78,58 @@ class TestImage(unittest.TestCase):
gr.inputs.Image(shape=(30, 10), type="file"),
"image")
output = iface.process([x_img])[0][0]
assert gr.processing_utils.decode_base64_to_image(output).size == (10, 30)
self.assertEqual(gr.processing_utils.decode_base64_to_image(output).size, (10, 30))
class TestAudio(unittest.TestCase):
def test_component(self):
def test_as_component(self):
x_wav = gr.test_data.BASE64_AUDIO
audio_input = gr.inputs.Audio()
output = audio_input.preprocess(x_wav)
print(output[0])
print(output[1].shape)
assert output[0] == 44000
assert output[1].shape == (100, 2)
self.assertEqual(output[0], 8000)
self.assertEqual(output[1].shape, (8046,))
def test_in_interface(self):
x_wav = gr.test_data.BASE64_AUDIO
def max_amplitude_from_wav_file(wav_file):
_, data = scipy.io.wavfile.read(wav_file.name)
return np.max(data)
def test_interface(self):
pass
iface = gr.Interface(
max_amplitude_from_wav_file,
gr.inputs.Audio(type="file"),
"number")
self.assertEqual(iface.process([x_wav])[0], [5239])
class TestFile(unittest.TestCase):
pass
def test_in_interface(self):
x_file = gr.test_data.BASE64_AUDIO
def get_size_of_file(file_obj):
return os.path.getsize(file_obj.name)
iface = gr.Interface(
get_size_of_file, "file", "number")
self.assertEqual(iface.process([x_file])[0], [16362])
class TestDataframe(unittest.TestCase):
pass
def test_as_component(self):
x_data = [["Tim",12,False],["Jan",24,True]]
dataframe_input = gr.inputs.Dataframe(headers=["Name","Age","Member"])
output = dataframe_input.preprocess(x_data)
self.assertEqual(output["Age"][1], 24)
self.assertEqual(output["Member"][0], False)
def test_in_interface(self):
x_data = [[1,2,3],[4,5,6]]
iface = gr.Interface(np.max, "numpy", "number")
self.assertEqual(iface.process([x_data])[0], [6])
x_data = [["Tim"], ["Jon"], ["Sal"]]
def get_last(l):
return l[-1]
iface = gr.Interface(get_last, "list", "text")
self.assertEqual(iface.process([x_data])[0], ["Sal"])
if __name__ == '__main__':

4
test/test_interfaces.py Normal file
View File

@ -0,0 +1,4 @@
import unittest
if __name__ == '__main__':
unittest.main()

View File

@ -1,63 +0,0 @@
import unittest
from gradio import networking
from gradio import inputs
from gradio import outputs
import socket
import tempfile
import os
import json
class TestGetAvailablePort(unittest.TestCase):
def test_get_first_available_port_by_blocking_port(self):
initial = 7000
final = 8000
port_found = False
for port in range(initial, final):
try:
s = socket.socket() # create a socket object
s.bind((networking.LOCALHOST_NAME, port)) # Bind to the port
s.close()
port_found = True
break
except OSError:
pass
if port_found:
s = socket.socket() # create a socket object
s.bind((networking.LOCALHOST_NAME, port)) # Bind to the port
new_port = networking.get_first_available_port(initial, final)
s.close()
self.assertFalse(port == new_port)
# class TestSetSampleData(unittest.TestCase):
# def test_set_sample_data(self):
# test_array = ["test1", "test2", "test3"]
# temp_dir = tempfile.mkdtemp()
# inp = inputs.Sketchpad()
# out = outputs.Label()
# networking.build_template(temp_dir, inp, out)
# networking.set_sample_data_in_config_file(temp_dir, test_array)
# # We need to come up with a better way so that the config file isn't invalid json unless
# # the following parameters are set... (TODO: abidlabs)
# networking.set_always_flagged_in_config_file(temp_dir, False)
# networking.set_disabled_in_config_file(temp_dir, False)
# config_file = os.path.join(temp_dir, 'static/config.json')
# with open(config_file) as json_file:
# data = json.load(json_file)
# self.assertTrue(test_array == data["sample_inputs"])
# class TestCopyFiles(unittest.TestCase):
# def test_copy_files(self):
# filename = "a.txt"
# with tempfile.TemporaryDirectory() as temp_src:
# with open(os.path.join(temp_src, "a.txt"), "w+") as f:
# f.write('Hi')
# with tempfile.TemporaryDirectory() as temp_dest:
# self.assertFalse(os.path.exists(os.path.join(temp_dest, filename)))
# networking.copy_files(temp_src, temp_dest)
# self.assertTrue(os.path.exists(os.path.join(temp_dest, filename)))
if __name__ == '__main__':
unittest.main()

File diff suppressed because one or more lines are too long