mirror of
https://github.com/gradio-app/gradio.git
synced 2025-01-24 10:54:04 +08:00
mid progress
This commit is contained in:
parent
6d7b6e8122
commit
7dc1a83ce3
@ -292,6 +292,7 @@ class Image(InputComponent):
|
||||
|
||||
def preprocess(self, x):
|
||||
im = processing_utils.decode_base64_to_image(x)
|
||||
fmt = im.format
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
im = im.convert(self.image_mode)
|
||||
@ -305,7 +306,7 @@ class Image(InputComponent):
|
||||
elif self.type == "numpy":
|
||||
return np.array(im)
|
||||
elif self.type == "file":
|
||||
file_obj = tempfile.NamedTemporaryFile()
|
||||
file_obj = tempfile.NamedTemporaryFile(suffix="."+fmt)
|
||||
im.save(file_obj.name)
|
||||
return file_obj
|
||||
else:
|
||||
|
@ -55,9 +55,9 @@ class Textbox(OutputComponent):
|
||||
|
||||
def postprocess(self, y):
|
||||
if self.type == "str" or self.type == "auto":
|
||||
return y
|
||||
elif self.type == "number":
|
||||
return str(y)
|
||||
elif self.type == "number":
|
||||
return y
|
||||
else:
|
||||
raise ValueError("Unknown type: " + self.type + ". Please choose from: 'str', 'number'")
|
||||
|
||||
|
@ -24,6 +24,7 @@ function gradio(config, fn, target, example_file_path) {
|
||||
<div class="output_interfaces">
|
||||
</div>
|
||||
<div class="panel_buttons">
|
||||
<input class="interpret panel_button" type="button" value="INTERPRET"/>
|
||||
<input class="screenshot panel_button" type="button" value="SCREENSHOT"/>
|
||||
<div class="screenshot_logo">
|
||||
<img src="/static/img/logo_inline.png">
|
||||
@ -165,17 +166,14 @@ function gradio(config, fn, target, example_file_path) {
|
||||
io_master.last_output = null;
|
||||
});
|
||||
|
||||
if (config["allow_screenshot"] && !config["allow_flagging"]) {
|
||||
if (config["allow_screenshot"]) {
|
||||
target.find(".screenshot").css("visibility", "visible");
|
||||
target.find(".flag").css("display", "none")
|
||||
}
|
||||
if (!config["allow_screenshot"] && config["allow_flagging"]) {
|
||||
if (config["allow_flagging"]) {
|
||||
target.find(".flag").css("visibility", "visible");
|
||||
target.find(".screenshot").css("display", "none")
|
||||
}
|
||||
if (config["allow_screenshot"] && config["allow_flagging"]) {
|
||||
target.find(".screenshot").css("visibility", "visible");
|
||||
target.find(".flag").css("visibility", "visible")
|
||||
if (config["allow_interpretation"]) {
|
||||
target.find(".interpret").css("visibility", "visible");
|
||||
}
|
||||
if (config["examples"]) {
|
||||
target.find(".examples").removeClass("invisible");
|
||||
|
File diff suppressed because one or more lines are too long
@ -11,5 +11,5 @@ def reverse_audio(audio):
|
||||
|
||||
io = gr.Interface(reverse_audio, "microphone", "audio")
|
||||
|
||||
io.test_launch()
|
||||
# io.test_launch()
|
||||
io.launch()
|
||||
|
@ -292,6 +292,7 @@ class Image(InputComponent):
|
||||
|
||||
def preprocess(self, x):
|
||||
im = processing_utils.decode_base64_to_image(x)
|
||||
fmt = im.format
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
im = im.convert(self.image_mode)
|
||||
@ -305,7 +306,7 @@ class Image(InputComponent):
|
||||
elif self.type == "numpy":
|
||||
return np.array(im)
|
||||
elif self.type == "file":
|
||||
file_obj = tempfile.NamedTemporaryFile()
|
||||
file_obj = tempfile.NamedTemporaryFile(suffix="."+fmt)
|
||||
im.save(file_obj.name)
|
||||
return file_obj
|
||||
else:
|
||||
|
@ -7,9 +7,15 @@ import tempfile
|
||||
import webbrowser
|
||||
|
||||
from gradio.inputs import InputComponent
|
||||
from gradio.inputs import Image
|
||||
from gradio.inputs import Textbox
|
||||
from gradio.outputs import OutputComponent
|
||||
from gradio import networking, strings, utils
|
||||
from gradio import networking, strings, utils, processing_utils
|
||||
from distutils.version import StrictVersion
|
||||
from skimage.segmentation import slic
|
||||
from skimage.util import img_as_float
|
||||
from gradio import processing_utils
|
||||
import PIL
|
||||
import pkg_resources
|
||||
import requests
|
||||
import random
|
||||
@ -20,6 +26,7 @@ import sys
|
||||
import weakref
|
||||
import analytics
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
PKG_VERSION_URL = "https://gradio.app/api/pkg-version"
|
||||
analytics.write_key = "uxIFddIEuuUcFLf9VgH2teTEtPlWdkNy"
|
||||
@ -46,7 +53,7 @@ class Interface:
|
||||
|
||||
def __init__(self, fn, inputs, outputs, verbose=False, examples=None,
|
||||
live=False, show_input=True, show_output=True,
|
||||
capture_session=False, title=None, description=None,
|
||||
capture_session=False, explain_by=None, title=None, description=None,
|
||||
thumbnail=None, server_port=None, server_name=networking.LOCALHOST_NAME,
|
||||
allow_screenshot=True, allow_flagging=True,
|
||||
flagging_dir="flagged", analytics_enabled=True):
|
||||
@ -110,6 +117,7 @@ class Interface:
|
||||
self.show_output = show_output
|
||||
self.flag_hash = random.getrandbits(32)
|
||||
self.capture_session = capture_session
|
||||
self.explain_by = explain_by
|
||||
self.session = None
|
||||
self.server_name = server_name
|
||||
self.title = title
|
||||
@ -177,7 +185,8 @@ class Interface:
|
||||
"description": self.description,
|
||||
"thumbnail": self.thumbnail,
|
||||
"allow_screenshot": self.allow_screenshot,
|
||||
"allow_flagging": self.allow_flagging
|
||||
"allow_flagging": self.allow_flagging,
|
||||
"allow_interpretation": self.explain_by is not None
|
||||
}
|
||||
try:
|
||||
param_names = inspect.getfullargspec(self.predict[0])[0]
|
||||
@ -190,7 +199,6 @@ class Interface:
|
||||
iface[1]["label"] = ret_name
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
return config
|
||||
|
||||
def process(self, raw_input, predict_fn=None):
|
||||
@ -210,7 +218,7 @@ class Interface:
|
||||
durations = []
|
||||
for predict_fn in self.predict:
|
||||
start = time.time()
|
||||
if self.capture_session and not (self.session is None):
|
||||
if self.capture_session and self.session is not None:
|
||||
graph, sess = self.session
|
||||
with graph.as_default():
|
||||
with sess.as_default():
|
||||
@ -284,7 +292,7 @@ class Interface:
|
||||
output_directory = tempfile.mkdtemp()
|
||||
# Set up a port to serve the directory containing the static files with interface.
|
||||
server_port, httpd, thread = networking.start_simple_server(
|
||||
self, output_directory, self.server_name, server_port=self.server_port)
|
||||
self, self.server_name, server_port=self.server_port)
|
||||
path_to_local_server = "http://{}:{}/".format(self.server_name, server_port)
|
||||
networking.build_template(output_directory)
|
||||
|
||||
@ -414,6 +422,94 @@ class Interface:
|
||||
|
||||
return httpd, path_to_local_server, share_url
|
||||
|
||||
def tokenize_text(self, text):
|
||||
leave_one_out_tokens = []
|
||||
tokens = text.split()
|
||||
for idx, _ in enumerate(tokens):
|
||||
new_token_array = tokens.copy()
|
||||
del new_token_array[idx]
|
||||
leave_one_out_tokens.append(new_token_array)
|
||||
return tokens, leave_one_out_tokens
|
||||
|
||||
def tokenize_image(self, image):
|
||||
image = np.array(processing_utils.decode_base64_to_image(image))
|
||||
segments_slic = slic(image, n_segments=20, compactness=10, sigma=1)
|
||||
leave_one_out_tokens = []
|
||||
for (i, segVal) in enumerate(np.unique(segments_slic)):
|
||||
mask = segments_slic == segVal
|
||||
white_screen = np.copy(image)
|
||||
white_screen[segments_slic == segVal] = 255
|
||||
leave_one_out_tokens.append((mask, white_screen))
|
||||
return leave_one_out_tokens
|
||||
|
||||
def score_text(self, tokens, leave_one_out_tokens, text):
|
||||
original_label = ""
|
||||
original_confidence = 0
|
||||
tokens = text.split()
|
||||
|
||||
input_text = " ".join(tokens)
|
||||
original_output = self.process([input_text])
|
||||
output = {result["label"] : result["confidence"]
|
||||
for result in original_output[0][0]['confidences']}
|
||||
original_label = original_output[0][0]["label"]
|
||||
original_confidence = output[original_label]
|
||||
|
||||
scores = []
|
||||
for idx, input_text in enumerate(leave_one_out_tokens):
|
||||
input_text = " ".join(input_text)
|
||||
raw_output = self.process([input_text])
|
||||
output = {result["label"] : result["confidence"]
|
||||
for result in raw_output[0][0]['confidences']}
|
||||
scores.append(original_confidence - output[original_label])
|
||||
|
||||
scores_by_char = []
|
||||
for idx, token in enumerate(tokens):
|
||||
if idx != 0:
|
||||
scores_by_char.append((" ", 0))
|
||||
for char in token:
|
||||
scores_by_char.append((char, scores[idx]))
|
||||
return scores_by_char
|
||||
|
||||
def score_image(self, leave_one_out_tokens, image):
|
||||
original_output = self.process([image])
|
||||
output = {result["label"] : result["confidence"]
|
||||
for result in original_output[0][0]['confidences']}
|
||||
original_label = original_output[0][0]["label"]
|
||||
original_confidence = output[original_label]
|
||||
|
||||
image_interface = self.input_interfaces[0]
|
||||
shape = processing_utils.decode_base64_to_image(image).size
|
||||
output_scores = np.full((shape[1], shape[0]), 0.0)
|
||||
|
||||
for mask, input_image in leave_one_out_tokens:
|
||||
input_image_base64 = processing_utils.encode_array_to_base64(
|
||||
input_image)
|
||||
raw_output = self.process([input_image_base64])
|
||||
output = {result["label"] : result["confidence"]
|
||||
for result in raw_output[0][0]['confidences']}
|
||||
score = original_confidence - output[original_label]
|
||||
output_scores += score * mask
|
||||
max_val = np.max(np.abs(output_scores))
|
||||
if max_val > 0:
|
||||
output_scores = output_scores / max_val
|
||||
return output_scores.tolist()
|
||||
|
||||
def simple_explanation(self, x):
|
||||
if isinstance(self.input_interfaces[0], Textbox):
|
||||
tokens, leave_one_out_tokens = self.tokenize_text(x[0])
|
||||
return [self.score_text(tokens, leave_one_out_tokens, x[0])]
|
||||
elif isinstance(self.input_interfaces[0], Image):
|
||||
leave_one_out_tokens = self.tokenize_image(x[0])
|
||||
return [self.score_image(leave_one_out_tokens, x[0])]
|
||||
else:
|
||||
print("Not valid input type")
|
||||
|
||||
def explain(self, x):
|
||||
if self.explain_by == "default":
|
||||
return self.simple_explanation(x)
|
||||
else:
|
||||
preprocessed_x = [input_interface(x_i) for x_i, input_interface in zip(x, self.input_interfaces)]
|
||||
return self.explain_by(*preprocessed_x)
|
||||
|
||||
def reset_all():
|
||||
for io in Interface.get_instances():
|
||||
|
@ -128,7 +128,7 @@ def file(path):
|
||||
return send_file(os.path.join(os.getcwd(), path))
|
||||
|
||||
|
||||
def start_server(interface, directory_to_serve=None, server_name=None, server_port=None):
|
||||
def start_server(interface, server_name=None, server_port=None):
|
||||
if server_port is None:
|
||||
server_port = INITIAL_PORT_VALUE
|
||||
port = get_first_available_port(
|
||||
|
@ -54,10 +54,10 @@ class Textbox(OutputComponent):
|
||||
}
|
||||
|
||||
def postprocess(self, y):
|
||||
if self.type == "str":
|
||||
return y
|
||||
elif self.type == "number" or self.type == "auto":
|
||||
if self.type == "str" or self.type == "auto":
|
||||
return str(y)
|
||||
elif self.type == "number":
|
||||
return y
|
||||
else:
|
||||
raise ValueError("Unknown type: " + self.type + ". Please choose from: 'str', 'number'")
|
||||
|
||||
|
File diff suppressed because one or more lines are too long
@ -1,47 +1,97 @@
|
||||
import unittest
|
||||
import gradio as gr
|
||||
|
||||
import PIL
|
||||
import numpy as np
|
||||
|
||||
class TestTextbox(unittest.TestCase):
|
||||
def test_component(self):
|
||||
box = gr.inputs.Textbox()
|
||||
assert box.preprocess("Hello") == "Hello"
|
||||
box = gr.inputs.Textbox(type="str")
|
||||
assert box.preprocess(125) == 125
|
||||
|
||||
def test_interface(self):
|
||||
iface = gr.Interface(lambda x: x[::-1], "textbox", "textbox")
|
||||
assert interface.process("Hello") == "olleH"
|
||||
assert iface.process(["Hello"])[0] == ["olleH"]
|
||||
iface = gr.Interface(lambda x: x*x, "number", "number")
|
||||
assert interface.process(5) == 25
|
||||
|
||||
assert iface.process(["5"])[0] == [25]
|
||||
|
||||
class TestSlider(unittest.TestCase):
|
||||
pass
|
||||
def test_interface(self):
|
||||
iface = gr.Interface(lambda x: str(x) + " cats", "slider", "textbox")
|
||||
assert iface.process([4])[0] == ["4 cats"]
|
||||
|
||||
|
||||
class TestCheckbox(unittest.TestCase):
|
||||
pass
|
||||
def test_interface(self):
|
||||
iface = gr.Interface(lambda x: "yes" if x else "no", "checkbox", "textbox")
|
||||
assert iface.process([False])[0] == ["no"]
|
||||
|
||||
|
||||
class TestCheckboxGroup(unittest.TestCase):
|
||||
pass
|
||||
def test_interface(self):
|
||||
checkboxes = gr.inputs.CheckboxGroup(["a", "b", "c"])
|
||||
iface = gr.Interface(lambda x: "|".join(x), checkboxes, "textbox")
|
||||
assert iface.process([["a", "c"]])[0] == ["a|c"]
|
||||
assert iface.process([[]])[0] == [""]
|
||||
checkboxes = gr.inputs.CheckboxGroup(["a", "b", "c"], type="index")
|
||||
iface = gr.Interface(lambda x: "|".join(map(str, x)), checkboxes, "textbox")
|
||||
assert iface.process([["a", "c"]])[0] == ["0|2"]
|
||||
|
||||
|
||||
class TestRadio(unittest.TestCase):
|
||||
pass
|
||||
def test_interface(self):
|
||||
radio = gr.inputs.Radio(["a", "b", "c"])
|
||||
iface = gr.Interface(lambda x: 2 * x, radio, "textbox")
|
||||
assert iface.process(["c"])[0] == ["cc"]
|
||||
radio = gr.inputs.Radio(["a", "b", "c"], type="index")
|
||||
iface = gr.Interface(lambda x: 2 * x, radio, "number")
|
||||
assert iface.process(["c"])[0] == [4]
|
||||
|
||||
|
||||
class TestDropdown(unittest.TestCase):
|
||||
pass
|
||||
def test_interface(self):
|
||||
dropdown = gr.inputs.Dropdown(["a", "b", "c"])
|
||||
iface = gr.Interface(lambda x: 2 * x, dropdown, "textbox")
|
||||
assert iface.process(["c"])[0] == ["cc"]
|
||||
dropdown = gr.inputs.Dropdown(["a", "b", "c"], type="index")
|
||||
iface = gr.Interface(lambda x: 2 * x, dropdown, "number")
|
||||
assert iface.process(["c"])[0] == [4]
|
||||
|
||||
|
||||
class TestImage(unittest.TestCase):
|
||||
pass
|
||||
def test_component(self):
|
||||
x_img = gr.test_data.BASE64_IMAGE
|
||||
image_input = gr.inputs.Image()
|
||||
assert image_input.preprocess(x_img).shape == (68, 61 ,3)
|
||||
image_input = gr.inputs.Image(image_mode="L", shape=(25, 25))
|
||||
assert image_input.preprocess(x_img).shape == (25, 25)
|
||||
image_input = gr.inputs.Image(shape=(30, 10), type="pil")
|
||||
assert image_input.preprocess(x_img).size == (30, 10)
|
||||
|
||||
|
||||
def test_interface(self):
|
||||
x_img = gr.test_data.BASE64_IMAGE
|
||||
|
||||
def open_and_rotate(img_file):
|
||||
img = PIL.Image.open(img_file)
|
||||
return img.rotate(90, expand=True)
|
||||
|
||||
iface = gr.Interface(
|
||||
open_and_rotate,
|
||||
gr.inputs.Image(shape=(30, 10), type="file"),
|
||||
"image")
|
||||
output = iface.process([x_img])[0][0]
|
||||
assert gr.processing_utils.decode_base64_to_image(output).size == (10, 30)
|
||||
|
||||
|
||||
class TestAudio(unittest.TestCase):
|
||||
pass
|
||||
def test_component(self):
|
||||
x_wav = gr.test_data.BASE64_AUDIO
|
||||
audio_input = gr.inputs.Audio()
|
||||
output = audio_input.preprocess(x_wav)
|
||||
print(output[0])
|
||||
print(output[1].shape)
|
||||
assert output[0] == 44000
|
||||
assert output[1].shape == (100, 2)
|
||||
|
||||
|
||||
def test_interface(self):
|
||||
pass
|
||||
|
||||
|
||||
class TestFile(unittest.TestCase):
|
||||
|
Loading…
Reference in New Issue
Block a user