mirror of
https://github.com/gradio-app/gradio.git
synced 2024-12-27 02:30:17 +08:00
add testing and flask changes
This commit is contained in:
parent
7dc1a83ce3
commit
498a615ccf
3
.gitignore
vendored
3
.gitignore
vendored
@ -22,4 +22,5 @@ dist/*
|
||||
docs.json
|
||||
*.bak
|
||||
demo/tmp.zip
|
||||
demo/flagged
|
||||
demo/flagged
|
||||
test.txt
|
@ -450,7 +450,7 @@ class Dataframe(InputComponent):
|
||||
else:
|
||||
return pd.DataFrame(x)
|
||||
if self.col_count == 1:
|
||||
x = x[0]
|
||||
x = [row[0] for row in x]
|
||||
if self.type == "numpy":
|
||||
return np.array(x)
|
||||
elif self.type == "array":
|
||||
|
@ -4,13 +4,11 @@ interface using the input and output types.
|
||||
"""
|
||||
|
||||
import tempfile
|
||||
import webbrowser
|
||||
# import webbrowser
|
||||
|
||||
from gradio.inputs import InputComponent
|
||||
from gradio.outputs import OutputComponent
|
||||
from gradio import networking, strings, utils
|
||||
from distutils.version import StrictVersion
|
||||
import pkg_resources
|
||||
import requests
|
||||
import random
|
||||
import time
|
||||
@ -21,7 +19,6 @@ import weakref
|
||||
import analytics
|
||||
import os
|
||||
|
||||
PKG_VERSION_URL = "https://gradio.app/api/pkg-version"
|
||||
analytics.write_key = "uxIFddIEuuUcFLf9VgH2teTEtPlWdkNy"
|
||||
analytics_url = 'https://api.gradio.app/'
|
||||
try:
|
||||
@ -46,7 +43,7 @@ class Interface:
|
||||
|
||||
def __init__(self, fn, inputs, outputs, verbose=False, examples=None,
|
||||
live=False, show_input=True, show_output=True,
|
||||
capture_session=False, title=None, description=None,
|
||||
capture_session=False, explain_by=None, title=None, description=None,
|
||||
thumbnail=None, server_port=None, server_name=networking.LOCALHOST_NAME,
|
||||
allow_screenshot=True, allow_flagging=True,
|
||||
flagging_dir="flagged", analytics_enabled=True):
|
||||
@ -110,6 +107,7 @@ class Interface:
|
||||
self.show_output = show_output
|
||||
self.flag_hash = random.getrandbits(32)
|
||||
self.capture_session = capture_session
|
||||
self.explain_by = explain_by
|
||||
self.session = None
|
||||
self.server_name = server_name
|
||||
self.title = title
|
||||
@ -177,7 +175,8 @@ class Interface:
|
||||
"description": self.description,
|
||||
"thumbnail": self.thumbnail,
|
||||
"allow_screenshot": self.allow_screenshot,
|
||||
"allow_flagging": self.allow_flagging
|
||||
"allow_flagging": self.allow_flagging,
|
||||
"allow_interpretation": self.explain_by is not None
|
||||
}
|
||||
try:
|
||||
param_names = inspect.getfullargspec(self.predict[0])[0]
|
||||
@ -190,7 +189,14 @@ class Interface:
|
||||
iface[1]["label"] = ret_name
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
processed_examples = []
|
||||
if self.examples is not None:
|
||||
for example_set in self.examples:
|
||||
processed_set = []
|
||||
for iface, example in zip(self.input_interfaces, example_set):
|
||||
processed_set.append(iface.process_example(example))
|
||||
processed_examples.append(processed_set)
|
||||
config["examples"] = processed_examples
|
||||
return config
|
||||
|
||||
def process(self, raw_input, predict_fn=None):
|
||||
@ -210,7 +216,7 @@ class Interface:
|
||||
durations = []
|
||||
for predict_fn in self.predict:
|
||||
start = time.time()
|
||||
if self.capture_session and not (self.session is None):
|
||||
if self.capture_session and self.session is not None:
|
||||
graph, sess = self.session
|
||||
with graph.as_default():
|
||||
with sess.as_default():
|
||||
@ -277,33 +283,22 @@ class Interface:
|
||||
share (bool): whether to create a publicly shareable link from your computer for the interface.
|
||||
debug (bool): if True, and the interface was launched from Google Colab, prints the errors in the cell output.
|
||||
Returns
|
||||
httpd (str): HTTPServer object
|
||||
app (flask.Flask): Flask app object
|
||||
path_to_local_server (str): Locally accessible link
|
||||
share_url (str): Publicly accessible link (if share=True)
|
||||
"""
|
||||
output_directory = tempfile.mkdtemp()
|
||||
# Set up a port to serve the directory containing the static files with interface.
|
||||
server_port, httpd, thread = networking.start_simple_server(
|
||||
self, output_directory, self.server_name, server_port=self.server_port)
|
||||
path_to_local_server = "http://{}:{}/".format(self.server_name, server_port)
|
||||
networking.build_template(output_directory)
|
||||
config = self.get_config_file()
|
||||
networking.set_config(config)
|
||||
networking.set_meta_tags(self.title, self.description, self.thumbnail)
|
||||
|
||||
server_port, app, thread = networking.start_server(
|
||||
self, self.server_port)
|
||||
path_to_local_server = "http://{}:{}/".format(self.server_name, server_port)
|
||||
self.server_port = server_port
|
||||
self.status = "RUNNING"
|
||||
self.simple_server = httpd
|
||||
|
||||
try:
|
||||
current_pkg_version = pkg_resources.require("gradio")[0].version
|
||||
latest_pkg_version = requests.get(url=PKG_VERSION_URL).json()["version"]
|
||||
if StrictVersion(latest_pkg_version) > StrictVersion(current_pkg_version):
|
||||
print("IMPORTANT: You are using gradio version {}, "
|
||||
"however version {} "
|
||||
"is available, please upgrade.".format(
|
||||
current_pkg_version, latest_pkg_version))
|
||||
print('--------')
|
||||
except: # TODO(abidlabs): don't catch all exceptions
|
||||
pass
|
||||
self.server = app
|
||||
|
||||
utils.version_check()
|
||||
is_colab = utils.colab_check()
|
||||
if not is_colab:
|
||||
if not networking.url_ok(path_to_local_server):
|
||||
@ -373,20 +368,7 @@ class Interface:
|
||||
else:
|
||||
display(IFrame(path_to_local_server, width=1000, height=500))
|
||||
|
||||
config = self.get_config_file()
|
||||
config["share_url"] = share_url
|
||||
|
||||
processed_examples = []
|
||||
if self.examples is not None:
|
||||
for example_set in self.examples:
|
||||
processed_set = []
|
||||
for iface, example in zip(self.input_interfaces, example_set):
|
||||
processed_set.append(iface.process_example(example))
|
||||
processed_examples.append(processed_set)
|
||||
config["examples"] = processed_examples
|
||||
|
||||
networking.set_config(config, output_directory)
|
||||
networking.set_meta_tags(output_directory, self.title, self.description, self.thumbnail)
|
||||
r = requests.get(path_to_local_server + "enable_sharing/" + (share_url or "None"))
|
||||
|
||||
if debug:
|
||||
while True:
|
||||
@ -394,14 +376,15 @@ class Interface:
|
||||
time.sleep(0.1)
|
||||
|
||||
launch_method = 'browser' if inbrowser else 'inline'
|
||||
data = {'launch_method': launch_method,
|
||||
|
||||
if self.analytics_enabled:
|
||||
data = {
|
||||
'launch_method': launch_method,
|
||||
'is_google_colab': is_colab,
|
||||
'is_sharing_on': share,
|
||||
'share_url': share_url,
|
||||
'ip_address': ip_address
|
||||
}
|
||||
|
||||
if self.analytics_enabled:
|
||||
}
|
||||
try:
|
||||
requests.post(analytics_url + 'gradio-launched-analytics/',
|
||||
data=data)
|
||||
@ -412,7 +395,8 @@ class Interface:
|
||||
if not is_in_interactive_mode:
|
||||
self.run_until_interrupted(thread, path_to_local_server)
|
||||
|
||||
return httpd, path_to_local_server, share_url
|
||||
|
||||
return app, path_to_local_server, share_url
|
||||
|
||||
|
||||
def reset_all():
|
||||
|
@ -83,6 +83,13 @@ def gradio():
|
||||
def config():
|
||||
return jsonify(app.app_globals["config"])
|
||||
|
||||
@app.route("/enable_sharing/<path:path>", methods=["GET"])
|
||||
def enable_sharing(path):
|
||||
if path == "None":
|
||||
path = None
|
||||
app.app_globals["config"]["share_url"] = path
|
||||
return jsonify(success=True)
|
||||
|
||||
@app.route("/api/predict/", methods=["POST"])
|
||||
def predict():
|
||||
raw_input = request.json["data"]
|
||||
@ -128,14 +135,14 @@ def file(path):
|
||||
return send_file(os.path.join(os.getcwd(), path))
|
||||
|
||||
|
||||
def start_server(interface, directory_to_serve=None, server_name=None, server_port=None):
|
||||
def start_server(interface, server_port=None):
|
||||
if server_port is None:
|
||||
server_port = INITIAL_PORT_VALUE
|
||||
port = get_first_available_port(
|
||||
server_port, server_port + TRY_NUM_PORTS
|
||||
)
|
||||
app.interface = interface
|
||||
process = Process(target=app.run, kwargs={"port": port, "debug": True})
|
||||
process = Process(target=app.run, kwargs={"port": port})
|
||||
process.start()
|
||||
return port, app, process
|
||||
|
||||
|
@ -68,8 +68,6 @@ class Label(OutputComponent):
|
||||
Output type: Union[Dict[str, float], str, int, float]
|
||||
'''
|
||||
|
||||
LABEL_KEY = "label"
|
||||
CONFIDENCE_KEY = "confidence"
|
||||
CONFIDENCES_KEY = "confidences"
|
||||
|
||||
def __init__(self, num_top_classes=None, type="auto", label=None):
|
||||
@ -85,7 +83,7 @@ class Label(OutputComponent):
|
||||
|
||||
def postprocess(self, y):
|
||||
if self.type == "label" or (self.type == "auto" and (isinstance(y, str) or isinstance(y, Number))):
|
||||
return {self.LABEL_KEY: str(y)}
|
||||
return {"label": str(y)}
|
||||
elif self.type == "confidences" or (self.type == "auto" and isinstance(y, dict)):
|
||||
sorted_pred = sorted(
|
||||
y.items(),
|
||||
@ -95,11 +93,11 @@ class Label(OutputComponent):
|
||||
if self.num_top_classes is not None:
|
||||
sorted_pred = sorted_pred[:self.num_top_classes]
|
||||
return {
|
||||
self.LABEL_KEY: sorted_pred[0][0],
|
||||
self.CONFIDENCES_KEY: [
|
||||
"label": sorted_pred[0][0],
|
||||
"confidences": [
|
||||
{
|
||||
self.LABEL_KEY: pred[0],
|
||||
self.CONFIDENCE_KEY: pred[1]
|
||||
"label": pred[0],
|
||||
"confidence": pred[1]
|
||||
} for pred in sorted_pred
|
||||
]
|
||||
}
|
||||
|
File diff suppressed because one or more lines are too long
@ -1,7 +1,22 @@
|
||||
import requests
|
||||
import pkg_resources
|
||||
from distutils.version import StrictVersion
|
||||
from IPython import get_ipython
|
||||
analytics_url = 'https://api.gradio.app/'
|
||||
|
||||
PKG_VERSION_URL = "https://gradio.app/api/pkg-version"
|
||||
def version_check():
|
||||
try:
|
||||
current_pkg_version = pkg_resources.require("gradio")[0].version
|
||||
latest_pkg_version = requests.get(url=PKG_VERSION_URL).json()["version"]
|
||||
if StrictVersion(latest_pkg_version) > StrictVersion(current_pkg_version):
|
||||
print("IMPORTANT: You are using gradio version {}, "
|
||||
"however version {} "
|
||||
"is available, please upgrade.".format(
|
||||
current_pkg_version, latest_pkg_version))
|
||||
print('--------')
|
||||
except: # TODO(abidlabs): don't catch all exceptions
|
||||
pass
|
||||
|
||||
def error_analytics(type):
|
||||
"""
|
||||
|
@ -11,5 +11,5 @@ def reverse_audio(audio):
|
||||
|
||||
io = gr.Interface(reverse_audio, "microphone", "audio")
|
||||
|
||||
# io.test_launch()
|
||||
io.test_launch()
|
||||
io.launch()
|
||||
|
@ -1,8 +1,8 @@
|
||||
# Demo: (Slider, Dropdown, Radio, CheckboxGroup, Checkbox) -> (Textbox)
|
||||
print("0 -->")
|
||||
|
||||
import gradio as gr
|
||||
|
||||
|
||||
def sentence_builder(quantity, animal, place, activity_list, morning):
|
||||
return f"""The {quantity} {animal}s went to the {place} where they {" and ".join(activity_list)} until the {"morning" if morning else "night"}"""
|
||||
|
||||
@ -24,5 +24,8 @@ io = gr.Interface(
|
||||
[8, "cat", "zoo", ["ate"], True],
|
||||
])
|
||||
|
||||
io.test_launch()
|
||||
# io.test_launch()
|
||||
a = 1
|
||||
print("start -->", a)
|
||||
a += 1
|
||||
io.launch()
|
||||
|
@ -119,5 +119,5 @@ gradio/static/js/vendor/webcam.min.js
|
||||
gradio/static/js/vendor/white-theme.js
|
||||
gradio/templates/index.html
|
||||
test/test_inputs.py
|
||||
test/test_networking.py
|
||||
test/test_interfaces.py
|
||||
test/test_outputs.py
|
@ -450,7 +450,7 @@ class Dataframe(InputComponent):
|
||||
else:
|
||||
return pd.DataFrame(x)
|
||||
if self.col_count == 1:
|
||||
x = x[0]
|
||||
x = [row[0] for row in x]
|
||||
if self.type == "numpy":
|
||||
return np.array(x)
|
||||
elif self.type == "array":
|
||||
|
@ -4,19 +4,11 @@ interface using the input and output types.
|
||||
"""
|
||||
|
||||
import tempfile
|
||||
import webbrowser
|
||||
# import webbrowser
|
||||
|
||||
from gradio.inputs import InputComponent
|
||||
from gradio.inputs import Image
|
||||
from gradio.inputs import Textbox
|
||||
from gradio.outputs import OutputComponent
|
||||
from gradio import networking, strings, utils, processing_utils
|
||||
from distutils.version import StrictVersion
|
||||
from skimage.segmentation import slic
|
||||
from skimage.util import img_as_float
|
||||
from gradio import processing_utils
|
||||
import PIL
|
||||
import pkg_resources
|
||||
from gradio import networking, strings, utils
|
||||
import requests
|
||||
import random
|
||||
import time
|
||||
@ -26,9 +18,7 @@ import sys
|
||||
import weakref
|
||||
import analytics
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
PKG_VERSION_URL = "https://gradio.app/api/pkg-version"
|
||||
analytics.write_key = "uxIFddIEuuUcFLf9VgH2teTEtPlWdkNy"
|
||||
analytics_url = 'https://api.gradio.app/'
|
||||
try:
|
||||
@ -199,6 +189,14 @@ class Interface:
|
||||
iface[1]["label"] = ret_name
|
||||
except ValueError:
|
||||
pass
|
||||
processed_examples = []
|
||||
if self.examples is not None:
|
||||
for example_set in self.examples:
|
||||
processed_set = []
|
||||
for iface, example in zip(self.input_interfaces, example_set):
|
||||
processed_set.append(iface.process_example(example))
|
||||
processed_examples.append(processed_set)
|
||||
config["examples"] = processed_examples
|
||||
return config
|
||||
|
||||
def process(self, raw_input, predict_fn=None):
|
||||
@ -285,33 +283,22 @@ class Interface:
|
||||
share (bool): whether to create a publicly shareable link from your computer for the interface.
|
||||
debug (bool): if True, and the interface was launched from Google Colab, prints the errors in the cell output.
|
||||
Returns
|
||||
httpd (str): HTTPServer object
|
||||
app (flask.Flask): Flask app object
|
||||
path_to_local_server (str): Locally accessible link
|
||||
share_url (str): Publicly accessible link (if share=True)
|
||||
"""
|
||||
output_directory = tempfile.mkdtemp()
|
||||
# Set up a port to serve the directory containing the static files with interface.
|
||||
server_port, httpd, thread = networking.start_simple_server(
|
||||
self, self.server_name, server_port=self.server_port)
|
||||
path_to_local_server = "http://{}:{}/".format(self.server_name, server_port)
|
||||
networking.build_template(output_directory)
|
||||
config = self.get_config_file()
|
||||
networking.set_config(config)
|
||||
networking.set_meta_tags(self.title, self.description, self.thumbnail)
|
||||
|
||||
server_port, app, thread = networking.start_server(
|
||||
self, self.server_port)
|
||||
path_to_local_server = "http://{}:{}/".format(self.server_name, server_port)
|
||||
self.server_port = server_port
|
||||
self.status = "RUNNING"
|
||||
self.simple_server = httpd
|
||||
|
||||
try:
|
||||
current_pkg_version = pkg_resources.require("gradio")[0].version
|
||||
latest_pkg_version = requests.get(url=PKG_VERSION_URL).json()["version"]
|
||||
if StrictVersion(latest_pkg_version) > StrictVersion(current_pkg_version):
|
||||
print("IMPORTANT: You are using gradio version {}, "
|
||||
"however version {} "
|
||||
"is available, please upgrade.".format(
|
||||
current_pkg_version, latest_pkg_version))
|
||||
print('--------')
|
||||
except: # TODO(abidlabs): don't catch all exceptions
|
||||
pass
|
||||
self.server = app
|
||||
|
||||
utils.version_check()
|
||||
is_colab = utils.colab_check()
|
||||
if not is_colab:
|
||||
if not networking.url_ok(path_to_local_server):
|
||||
@ -381,20 +368,7 @@ class Interface:
|
||||
else:
|
||||
display(IFrame(path_to_local_server, width=1000, height=500))
|
||||
|
||||
config = self.get_config_file()
|
||||
config["share_url"] = share_url
|
||||
|
||||
processed_examples = []
|
||||
if self.examples is not None:
|
||||
for example_set in self.examples:
|
||||
processed_set = []
|
||||
for iface, example in zip(self.input_interfaces, example_set):
|
||||
processed_set.append(iface.process_example(example))
|
||||
processed_examples.append(processed_set)
|
||||
config["examples"] = processed_examples
|
||||
|
||||
networking.set_config(config, output_directory)
|
||||
networking.set_meta_tags(output_directory, self.title, self.description, self.thumbnail)
|
||||
r = requests.get(path_to_local_server + "enable_sharing/" + (share_url or "None"))
|
||||
|
||||
if debug:
|
||||
while True:
|
||||
@ -402,14 +376,15 @@ class Interface:
|
||||
time.sleep(0.1)
|
||||
|
||||
launch_method = 'browser' if inbrowser else 'inline'
|
||||
data = {'launch_method': launch_method,
|
||||
|
||||
if self.analytics_enabled:
|
||||
data = {
|
||||
'launch_method': launch_method,
|
||||
'is_google_colab': is_colab,
|
||||
'is_sharing_on': share,
|
||||
'share_url': share_url,
|
||||
'ip_address': ip_address
|
||||
}
|
||||
|
||||
if self.analytics_enabled:
|
||||
}
|
||||
try:
|
||||
requests.post(analytics_url + 'gradio-launched-analytics/',
|
||||
data=data)
|
||||
@ -420,96 +395,9 @@ class Interface:
|
||||
if not is_in_interactive_mode:
|
||||
self.run_until_interrupted(thread, path_to_local_server)
|
||||
|
||||
return httpd, path_to_local_server, share_url
|
||||
|
||||
def tokenize_text(self, text):
|
||||
leave_one_out_tokens = []
|
||||
tokens = text.split()
|
||||
for idx, _ in enumerate(tokens):
|
||||
new_token_array = tokens.copy()
|
||||
del new_token_array[idx]
|
||||
leave_one_out_tokens.append(new_token_array)
|
||||
return tokens, leave_one_out_tokens
|
||||
return app, path_to_local_server, share_url
|
||||
|
||||
def tokenize_image(self, image):
|
||||
image = np.array(processing_utils.decode_base64_to_image(image))
|
||||
segments_slic = slic(image, n_segments=20, compactness=10, sigma=1)
|
||||
leave_one_out_tokens = []
|
||||
for (i, segVal) in enumerate(np.unique(segments_slic)):
|
||||
mask = segments_slic == segVal
|
||||
white_screen = np.copy(image)
|
||||
white_screen[segments_slic == segVal] = 255
|
||||
leave_one_out_tokens.append((mask, white_screen))
|
||||
return leave_one_out_tokens
|
||||
|
||||
def score_text(self, tokens, leave_one_out_tokens, text):
|
||||
original_label = ""
|
||||
original_confidence = 0
|
||||
tokens = text.split()
|
||||
|
||||
input_text = " ".join(tokens)
|
||||
original_output = self.process([input_text])
|
||||
output = {result["label"] : result["confidence"]
|
||||
for result in original_output[0][0]['confidences']}
|
||||
original_label = original_output[0][0]["label"]
|
||||
original_confidence = output[original_label]
|
||||
|
||||
scores = []
|
||||
for idx, input_text in enumerate(leave_one_out_tokens):
|
||||
input_text = " ".join(input_text)
|
||||
raw_output = self.process([input_text])
|
||||
output = {result["label"] : result["confidence"]
|
||||
for result in raw_output[0][0]['confidences']}
|
||||
scores.append(original_confidence - output[original_label])
|
||||
|
||||
scores_by_char = []
|
||||
for idx, token in enumerate(tokens):
|
||||
if idx != 0:
|
||||
scores_by_char.append((" ", 0))
|
||||
for char in token:
|
||||
scores_by_char.append((char, scores[idx]))
|
||||
return scores_by_char
|
||||
|
||||
def score_image(self, leave_one_out_tokens, image):
|
||||
original_output = self.process([image])
|
||||
output = {result["label"] : result["confidence"]
|
||||
for result in original_output[0][0]['confidences']}
|
||||
original_label = original_output[0][0]["label"]
|
||||
original_confidence = output[original_label]
|
||||
|
||||
image_interface = self.input_interfaces[0]
|
||||
shape = processing_utils.decode_base64_to_image(image).size
|
||||
output_scores = np.full((shape[1], shape[0]), 0.0)
|
||||
|
||||
for mask, input_image in leave_one_out_tokens:
|
||||
input_image_base64 = processing_utils.encode_array_to_base64(
|
||||
input_image)
|
||||
raw_output = self.process([input_image_base64])
|
||||
output = {result["label"] : result["confidence"]
|
||||
for result in raw_output[0][0]['confidences']}
|
||||
score = original_confidence - output[original_label]
|
||||
output_scores += score * mask
|
||||
max_val = np.max(np.abs(output_scores))
|
||||
if max_val > 0:
|
||||
output_scores = output_scores / max_val
|
||||
return output_scores.tolist()
|
||||
|
||||
def simple_explanation(self, x):
|
||||
if isinstance(self.input_interfaces[0], Textbox):
|
||||
tokens, leave_one_out_tokens = self.tokenize_text(x[0])
|
||||
return [self.score_text(tokens, leave_one_out_tokens, x[0])]
|
||||
elif isinstance(self.input_interfaces[0], Image):
|
||||
leave_one_out_tokens = self.tokenize_image(x[0])
|
||||
return [self.score_image(leave_one_out_tokens, x[0])]
|
||||
else:
|
||||
print("Not valid input type")
|
||||
|
||||
def explain(self, x):
|
||||
if self.explain_by == "default":
|
||||
return self.simple_explanation(x)
|
||||
else:
|
||||
preprocessed_x = [input_interface(x_i) for x_i, input_interface in zip(x, self.input_interfaces)]
|
||||
return self.explain_by(*preprocessed_x)
|
||||
|
||||
def reset_all():
|
||||
for io in Interface.get_instances():
|
||||
|
@ -83,6 +83,13 @@ def gradio():
|
||||
def config():
|
||||
return jsonify(app.app_globals["config"])
|
||||
|
||||
@app.route("/enable_sharing/<path:path>", methods=["GET"])
|
||||
def enable_sharing(path):
|
||||
if path == "None":
|
||||
path = None
|
||||
app.app_globals["config"]["share_url"] = path
|
||||
return jsonify(success=True)
|
||||
|
||||
@app.route("/api/predict/", methods=["POST"])
|
||||
def predict():
|
||||
raw_input = request.json["data"]
|
||||
@ -128,14 +135,14 @@ def file(path):
|
||||
return send_file(os.path.join(os.getcwd(), path))
|
||||
|
||||
|
||||
def start_server(interface, server_name=None, server_port=None):
|
||||
def start_server(interface, server_port=None):
|
||||
if server_port is None:
|
||||
server_port = INITIAL_PORT_VALUE
|
||||
port = get_first_available_port(
|
||||
server_port, server_port + TRY_NUM_PORTS
|
||||
)
|
||||
app.interface = interface
|
||||
process = Process(target=app.run, kwargs={"port": port, "debug": True})
|
||||
process = Process(target=app.run, kwargs={"port": port})
|
||||
process.start()
|
||||
return port, app, process
|
||||
|
||||
|
@ -68,8 +68,6 @@ class Label(OutputComponent):
|
||||
Output type: Union[Dict[str, float], str, int, float]
|
||||
'''
|
||||
|
||||
LABEL_KEY = "label"
|
||||
CONFIDENCE_KEY = "confidence"
|
||||
CONFIDENCES_KEY = "confidences"
|
||||
|
||||
def __init__(self, num_top_classes=None, type="auto", label=None):
|
||||
@ -85,7 +83,7 @@ class Label(OutputComponent):
|
||||
|
||||
def postprocess(self, y):
|
||||
if self.type == "label" or (self.type == "auto" and (isinstance(y, str) or isinstance(y, Number))):
|
||||
return {self.LABEL_KEY: str(y)}
|
||||
return {"label": str(y)}
|
||||
elif self.type == "confidences" or (self.type == "auto" and isinstance(y, dict)):
|
||||
sorted_pred = sorted(
|
||||
y.items(),
|
||||
@ -95,11 +93,11 @@ class Label(OutputComponent):
|
||||
if self.num_top_classes is not None:
|
||||
sorted_pred = sorted_pred[:self.num_top_classes]
|
||||
return {
|
||||
self.LABEL_KEY: sorted_pred[0][0],
|
||||
self.CONFIDENCES_KEY: [
|
||||
"label": sorted_pred[0][0],
|
||||
"confidences": [
|
||||
{
|
||||
self.LABEL_KEY: pred[0],
|
||||
self.CONFIDENCE_KEY: pred[1]
|
||||
"label": pred[0],
|
||||
"confidence": pred[1]
|
||||
} for pred in sorted_pred
|
||||
]
|
||||
}
|
||||
|
File diff suppressed because one or more lines are too long
@ -1,7 +1,22 @@
|
||||
import requests
|
||||
import pkg_resources
|
||||
from distutils.version import StrictVersion
|
||||
from IPython import get_ipython
|
||||
analytics_url = 'https://api.gradio.app/'
|
||||
|
||||
PKG_VERSION_URL = "https://gradio.app/api/pkg-version"
|
||||
def version_check():
|
||||
try:
|
||||
current_pkg_version = pkg_resources.require("gradio")[0].version
|
||||
latest_pkg_version = requests.get(url=PKG_VERSION_URL).json()["version"]
|
||||
if StrictVersion(latest_pkg_version) > StrictVersion(current_pkg_version):
|
||||
print("IMPORTANT: You are using gradio version {}, "
|
||||
"however version {} "
|
||||
"is available, please upgrade.".format(
|
||||
current_pkg_version, latest_pkg_version))
|
||||
print('--------')
|
||||
except: # TODO(abidlabs): don't catch all exceptions
|
||||
pass
|
||||
|
||||
def error_analytics(type):
|
||||
"""
|
||||
|
@ -2,69 +2,71 @@ import unittest
|
||||
import gradio as gr
|
||||
import PIL
|
||||
import numpy as np
|
||||
import scipy
|
||||
import os
|
||||
|
||||
class TestTextbox(unittest.TestCase):
|
||||
def test_interface(self):
|
||||
def test_in_interface(self):
|
||||
iface = gr.Interface(lambda x: x[::-1], "textbox", "textbox")
|
||||
assert iface.process(["Hello"])[0] == ["olleH"]
|
||||
self.assertEqual(iface.process(["Hello"])[0], ["olleH"])
|
||||
iface = gr.Interface(lambda x: x*x, "number", "number")
|
||||
assert iface.process(["5"])[0] == [25]
|
||||
self.assertEqual(iface.process(["5"])[0], [25])
|
||||
|
||||
class TestSlider(unittest.TestCase):
|
||||
def test_interface(self):
|
||||
def test_in_interface(self):
|
||||
iface = gr.Interface(lambda x: str(x) + " cats", "slider", "textbox")
|
||||
assert iface.process([4])[0] == ["4 cats"]
|
||||
self.assertEqual(iface.process([4])[0], ["4 cats"])
|
||||
|
||||
|
||||
class TestCheckbox(unittest.TestCase):
|
||||
def test_interface(self):
|
||||
def test_in_interface(self):
|
||||
iface = gr.Interface(lambda x: "yes" if x else "no", "checkbox", "textbox")
|
||||
assert iface.process([False])[0] == ["no"]
|
||||
self.assertEqual(iface.process([False])[0], ["no"])
|
||||
|
||||
|
||||
class TestCheckboxGroup(unittest.TestCase):
|
||||
def test_interface(self):
|
||||
def test_in_interface(self):
|
||||
checkboxes = gr.inputs.CheckboxGroup(["a", "b", "c"])
|
||||
iface = gr.Interface(lambda x: "|".join(x), checkboxes, "textbox")
|
||||
assert iface.process([["a", "c"]])[0] == ["a|c"]
|
||||
assert iface.process([[]])[0] == [""]
|
||||
self.assertEqual(iface.process([["a", "c"]])[0], ["a|c"])
|
||||
self.assertEqual(iface.process([[]])[0], [""])
|
||||
checkboxes = gr.inputs.CheckboxGroup(["a", "b", "c"], type="index")
|
||||
iface = gr.Interface(lambda x: "|".join(map(str, x)), checkboxes, "textbox")
|
||||
assert iface.process([["a", "c"]])[0] == ["0|2"]
|
||||
self.assertEqual(iface.process([["a", "c"]])[0], ["0|2"])
|
||||
|
||||
|
||||
class TestRadio(unittest.TestCase):
|
||||
def test_interface(self):
|
||||
def test_in_interface(self):
|
||||
radio = gr.inputs.Radio(["a", "b", "c"])
|
||||
iface = gr.Interface(lambda x: 2 * x, radio, "textbox")
|
||||
assert iface.process(["c"])[0] == ["cc"]
|
||||
self.assertEqual(iface.process(["c"])[0], ["cc"])
|
||||
radio = gr.inputs.Radio(["a", "b", "c"], type="index")
|
||||
iface = gr.Interface(lambda x: 2 * x, radio, "number")
|
||||
assert iface.process(["c"])[0] == [4]
|
||||
self.assertEqual(iface.process(["c"])[0], [4])
|
||||
|
||||
|
||||
class TestDropdown(unittest.TestCase):
|
||||
def test_interface(self):
|
||||
def test_in_interface(self):
|
||||
dropdown = gr.inputs.Dropdown(["a", "b", "c"])
|
||||
iface = gr.Interface(lambda x: 2 * x, dropdown, "textbox")
|
||||
assert iface.process(["c"])[0] == ["cc"]
|
||||
self.assertEqual(iface.process(["c"])[0], ["cc"])
|
||||
dropdown = gr.inputs.Dropdown(["a", "b", "c"], type="index")
|
||||
iface = gr.Interface(lambda x: 2 * x, dropdown, "number")
|
||||
assert iface.process(["c"])[0] == [4]
|
||||
self.assertEqual(iface.process(["c"])[0], [4])
|
||||
|
||||
|
||||
class TestImage(unittest.TestCase):
|
||||
def test_component(self):
|
||||
def test_as_component(self):
|
||||
x_img = gr.test_data.BASE64_IMAGE
|
||||
image_input = gr.inputs.Image()
|
||||
assert image_input.preprocess(x_img).shape == (68, 61 ,3)
|
||||
self.assertEqual(image_input.preprocess(x_img).shape, (68, 61 ,3))
|
||||
image_input = gr.inputs.Image(image_mode="L", shape=(25, 25))
|
||||
assert image_input.preprocess(x_img).shape == (25, 25)
|
||||
self.assertEqual(image_input.preprocess(x_img).shape, (25, 25))
|
||||
image_input = gr.inputs.Image(shape=(30, 10), type="pil")
|
||||
assert image_input.preprocess(x_img).size == (30, 10)
|
||||
self.assertEqual(image_input.preprocess(x_img).size, (30, 10))
|
||||
|
||||
|
||||
def test_interface(self):
|
||||
def test_in_interface(self):
|
||||
x_img = gr.test_data.BASE64_IMAGE
|
||||
|
||||
def open_and_rotate(img_file):
|
||||
@ -76,31 +78,58 @@ class TestImage(unittest.TestCase):
|
||||
gr.inputs.Image(shape=(30, 10), type="file"),
|
||||
"image")
|
||||
output = iface.process([x_img])[0][0]
|
||||
assert gr.processing_utils.decode_base64_to_image(output).size == (10, 30)
|
||||
self.assertEqual(gr.processing_utils.decode_base64_to_image(output).size, (10, 30))
|
||||
|
||||
|
||||
class TestAudio(unittest.TestCase):
|
||||
def test_component(self):
|
||||
def test_as_component(self):
|
||||
x_wav = gr.test_data.BASE64_AUDIO
|
||||
audio_input = gr.inputs.Audio()
|
||||
output = audio_input.preprocess(x_wav)
|
||||
print(output[0])
|
||||
print(output[1].shape)
|
||||
assert output[0] == 44000
|
||||
assert output[1].shape == (100, 2)
|
||||
self.assertEqual(output[0], 8000)
|
||||
self.assertEqual(output[1].shape, (8046,))
|
||||
|
||||
def test_in_interface(self):
|
||||
x_wav = gr.test_data.BASE64_AUDIO
|
||||
def max_amplitude_from_wav_file(wav_file):
|
||||
_, data = scipy.io.wavfile.read(wav_file.name)
|
||||
return np.max(data)
|
||||
|
||||
def test_interface(self):
|
||||
pass
|
||||
|
||||
iface = gr.Interface(
|
||||
max_amplitude_from_wav_file,
|
||||
gr.inputs.Audio(type="file"),
|
||||
"number")
|
||||
self.assertEqual(iface.process([x_wav])[0], [5239])
|
||||
|
||||
class TestFile(unittest.TestCase):
|
||||
pass
|
||||
def test_in_interface(self):
|
||||
x_file = gr.test_data.BASE64_AUDIO
|
||||
def get_size_of_file(file_obj):
|
||||
return os.path.getsize(file_obj.name)
|
||||
|
||||
iface = gr.Interface(
|
||||
get_size_of_file, "file", "number")
|
||||
self.assertEqual(iface.process([x_file])[0], [16362])
|
||||
|
||||
|
||||
class TestDataframe(unittest.TestCase):
|
||||
pass
|
||||
def test_as_component(self):
|
||||
x_data = [["Tim",12,False],["Jan",24,True]]
|
||||
dataframe_input = gr.inputs.Dataframe(headers=["Name","Age","Member"])
|
||||
output = dataframe_input.preprocess(x_data)
|
||||
self.assertEqual(output["Age"][1], 24)
|
||||
self.assertEqual(output["Member"][0], False)
|
||||
|
||||
def test_in_interface(self):
|
||||
x_data = [[1,2,3],[4,5,6]]
|
||||
iface = gr.Interface(np.max, "numpy", "number")
|
||||
self.assertEqual(iface.process([x_data])[0], [6])
|
||||
|
||||
x_data = [["Tim"], ["Jon"], ["Sal"]]
|
||||
def get_last(l):
|
||||
return l[-1]
|
||||
iface = gr.Interface(get_last, "list", "text")
|
||||
self.assertEqual(iface.process([x_data])[0], ["Sal"])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
4
test/test_interfaces.py
Normal file
4
test/test_interfaces.py
Normal file
@ -0,0 +1,4 @@
|
||||
import unittest
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
@ -1,63 +0,0 @@
|
||||
import unittest
|
||||
from gradio import networking
|
||||
from gradio import inputs
|
||||
from gradio import outputs
|
||||
import socket
|
||||
import tempfile
|
||||
import os
|
||||
import json
|
||||
|
||||
|
||||
class TestGetAvailablePort(unittest.TestCase):
|
||||
def test_get_first_available_port_by_blocking_port(self):
|
||||
initial = 7000
|
||||
final = 8000
|
||||
port_found = False
|
||||
for port in range(initial, final):
|
||||
try:
|
||||
s = socket.socket() # create a socket object
|
||||
s.bind((networking.LOCALHOST_NAME, port)) # Bind to the port
|
||||
s.close()
|
||||
port_found = True
|
||||
break
|
||||
except OSError:
|
||||
pass
|
||||
if port_found:
|
||||
s = socket.socket() # create a socket object
|
||||
s.bind((networking.LOCALHOST_NAME, port)) # Bind to the port
|
||||
new_port = networking.get_first_available_port(initial, final)
|
||||
s.close()
|
||||
self.assertFalse(port == new_port)
|
||||
|
||||
|
||||
# class TestSetSampleData(unittest.TestCase):
|
||||
# def test_set_sample_data(self):
|
||||
# test_array = ["test1", "test2", "test3"]
|
||||
# temp_dir = tempfile.mkdtemp()
|
||||
# inp = inputs.Sketchpad()
|
||||
# out = outputs.Label()
|
||||
# networking.build_template(temp_dir, inp, out)
|
||||
# networking.set_sample_data_in_config_file(temp_dir, test_array)
|
||||
# # We need to come up with a better way so that the config file isn't invalid json unless
|
||||
# # the following parameters are set... (TODO: abidlabs)
|
||||
# networking.set_always_flagged_in_config_file(temp_dir, False)
|
||||
# networking.set_disabled_in_config_file(temp_dir, False)
|
||||
# config_file = os.path.join(temp_dir, 'static/config.json')
|
||||
# with open(config_file) as json_file:
|
||||
# data = json.load(json_file)
|
||||
# self.assertTrue(test_array == data["sample_inputs"])
|
||||
|
||||
# class TestCopyFiles(unittest.TestCase):
|
||||
# def test_copy_files(self):
|
||||
# filename = "a.txt"
|
||||
# with tempfile.TemporaryDirectory() as temp_src:
|
||||
# with open(os.path.join(temp_src, "a.txt"), "w+") as f:
|
||||
# f.write('Hi')
|
||||
# with tempfile.TemporaryDirectory() as temp_dest:
|
||||
# self.assertFalse(os.path.exists(os.path.join(temp_dest, filename)))
|
||||
# networking.copy_files(temp_src, temp_dest)
|
||||
# self.assertTrue(os.path.exists(os.path.join(temp_dest, filename)))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue
Block a user