mirror of
https://github.com/gradio-app/gradio.git
synced 2024-12-21 02:19:59 +08:00
support auto output datatype
This commit is contained in:
parent
95c7b78b60
commit
c6585b5da6
@ -294,12 +294,6 @@ class Image(InputComponent):
|
||||
im.save(file_obj.name)
|
||||
return file_obj
|
||||
|
||||
def process_example(self, example):
|
||||
if os.path.exists(example):
|
||||
return processing_utils.encode_file_to_base64(example)
|
||||
else:
|
||||
return example
|
||||
|
||||
def rebuild(self, dir, data):
|
||||
"""
|
||||
Default rebuild method to decode a base64 image
|
||||
|
@ -197,6 +197,17 @@ def serve_files_in_background(interface, port, directory_to_serve=None, server_n
|
||||
else:
|
||||
self.send_error(404, 'Path not found: {}'.format(self.path))
|
||||
|
||||
|
||||
def do_GET(self):
|
||||
if self.path.startswith("/file/"):
|
||||
self.send_response(200)
|
||||
self.end_headers()
|
||||
with open(self.path[6:], "rb") as f:
|
||||
self.wfile.write(f.read())
|
||||
else:
|
||||
super().do_GET()
|
||||
|
||||
|
||||
class HTTPServer(BaseHTTPServer):
|
||||
"""The main server, you pass in base_path which is the path you want to serve requests from"""
|
||||
|
||||
|
@ -15,6 +15,9 @@ import warnings
|
||||
import tempfile
|
||||
import scipy
|
||||
import os
|
||||
import pandas as pd
|
||||
import PIL
|
||||
from types import ModuleType
|
||||
|
||||
class OutputComponent(Component):
|
||||
"""
|
||||
@ -28,10 +31,10 @@ class Textbox(OutputComponent):
|
||||
Output type: Union[str, float, int]
|
||||
'''
|
||||
|
||||
def __init__(self, type="str", label=None):
|
||||
def __init__(self, type="auto", label=None):
|
||||
'''
|
||||
Parameters:
|
||||
type (str): Type of value to be passed to component. "str" expects a string, "number" expects a float value.
|
||||
type (str): Type of value to be passed to component. "str" expects a string, "number" expects a float value, "auto" detects return type.
|
||||
label (str): component name in interface.
|
||||
'''
|
||||
self.type = type
|
||||
@ -51,7 +54,7 @@ class Textbox(OutputComponent):
|
||||
}
|
||||
|
||||
def postprocess(self, y):
|
||||
if self.type == "str":
|
||||
if self.type == "str" or self.type == "auto":
|
||||
return y
|
||||
elif self.type == "number":
|
||||
return str(y)
|
||||
@ -69,19 +72,21 @@ class Label(OutputComponent):
|
||||
CONFIDENCE_KEY = "confidence"
|
||||
CONFIDENCES_KEY = "confidences"
|
||||
|
||||
def __init__(self, num_top_classes=None, label=None):
|
||||
def __init__(self, num_top_classes=None, type="auto", label=None):
|
||||
'''
|
||||
Parameters:
|
||||
num_top_classes (int): number of most confident classes to show.
|
||||
type (str): Type of value to be passed to component. "value" expects a single out label, "confidences" expects a dictionary mapping labels to confidence scores, "auto" detects return type.
|
||||
label (str): component name in interface.
|
||||
'''
|
||||
self.num_top_classes = num_top_classes
|
||||
self.type = type
|
||||
super().__init__(label)
|
||||
|
||||
def postprocess(self, y):
|
||||
if isinstance(y, str) or isinstance(y, Number):
|
||||
return {"label": str(y)}
|
||||
elif isinstance(y, dict):
|
||||
if self.type == "label" or (self.type == "auto" and (isinstance(y, str) or isinstance(y, Number))):
|
||||
return {self.LABEL_KEY: str(y)}
|
||||
elif self.type == "confidences" or (self.type == "auto" and isinstance(y, dict)):
|
||||
sorted_pred = sorted(
|
||||
y.items(),
|
||||
key=operator.itemgetter(1),
|
||||
@ -98,8 +103,6 @@ class Label(OutputComponent):
|
||||
} for pred in sorted_pred
|
||||
]
|
||||
}
|
||||
elif isinstance(y, int) or isinstance(y, float):
|
||||
return {self.LABEL_KEY: str(y)}
|
||||
else:
|
||||
raise ValueError("The `Label` output interface expects one of: a string label, or an int label, a "
|
||||
"float label, or a dictionary whose keys are labels and values are confidences.")
|
||||
@ -123,10 +126,10 @@ class Image(OutputComponent):
|
||||
Output type: Union[numpy.array, PIL.Image, str, matplotlib.pyplot]
|
||||
'''
|
||||
|
||||
def __init__(self, type="numpy", plot=False, label=None):
|
||||
def __init__(self, type="auto", plot=False, label=None):
|
||||
'''
|
||||
Parameters:
|
||||
type (str): Type of value to be passed to component. "numpy" expects a numpy array with shape (width, height, 3), "pil" expects a PIL image object, "file" expects a file path to the saved image, "plot" expects a matplotlib.pyplot object.
|
||||
type (str): Type of value to be passed to component. "numpy" expects a numpy array with shape (width, height, 3), "pil" expects a PIL image object, "file" expects a file path to the saved image, "plot" expects a matplotlib.pyplot object, "auto" detects return type.
|
||||
plot (bool): DEPRECATED. Whether to expect a plot to be returned by the function.
|
||||
label (str): component name in interface.
|
||||
'''
|
||||
@ -146,16 +149,27 @@ class Image(OutputComponent):
|
||||
}
|
||||
|
||||
def postprocess(self, y):
|
||||
if self.type in ["numpy", "pil"]:
|
||||
if self.type == "pil":
|
||||
if self.type == "auto":
|
||||
if isinstance(y, np.ndarray):
|
||||
dtype = "numpy"
|
||||
elif isinstance(y, PIL.Image.Image):
|
||||
dtype = "pil"
|
||||
elif isinstance(y, str):
|
||||
dtype = "file"
|
||||
elif isinstance(y, ModuleType):
|
||||
dtype = "plot"
|
||||
else:
|
||||
dtype = self.type
|
||||
if dtype in ["numpy", "pil"]:
|
||||
if dtype == "pil":
|
||||
y = np.array(y)
|
||||
return processing_utils.encode_array_to_base64(y)
|
||||
elif self.type == "file":
|
||||
elif dtype == "file":
|
||||
return processing_utils.encode_file_to_base64(y)
|
||||
elif self.type == "plot":
|
||||
elif dtype == "plot":
|
||||
return processing_utils.encode_plot_to_base64(y)
|
||||
else:
|
||||
raise ValueError("Unknown type: " + self.type + ". Please choose from: 'numpy', 'pil', 'file', 'plot'.")
|
||||
raise ValueError("Unknown type: " + dtype + ". Please choose from: 'numpy', 'pil', 'file', 'plot'.")
|
||||
|
||||
def rebuild(self, dir, data):
|
||||
"""
|
||||
@ -235,10 +249,10 @@ class Audio(OutputComponent):
|
||||
Output type: Union[Tuple[int, numpy.array], str]
|
||||
'''
|
||||
|
||||
def __init__(self, type="numpy", label=None):
|
||||
def __init__(self, type="auto", label=None):
|
||||
'''
|
||||
Parameters:
|
||||
type (str): Type of value to be passed to component. "numpy" returns a 2-set tuple with an integer sample_rate and the data numpy.array of shape (samples, 2), "file" returns a temporary file path to the saved wav audio file.
|
||||
type (str): Type of value to be passed to component. "numpy" returns a 2-set tuple with an integer sample_rate and the data numpy.array of shape (samples, 2), "file" returns a temporary file path to the saved wav audio file, "auto" detects return type.
|
||||
label (str): component name in interface.
|
||||
'''
|
||||
self.type = type
|
||||
@ -256,8 +270,8 @@ class Audio(OutputComponent):
|
||||
}
|
||||
|
||||
def postprocess(self, y):
|
||||
if self.type in ["numpy", "file"]:
|
||||
if self.type == "numpy":
|
||||
if self.type in ["numpy", "file", "auto"]:
|
||||
if self.type == "numpy" or (self.type == "auto" and isinstance(y, tuple)):
|
||||
file = tempfile.NamedTemporaryFile()
|
||||
scipy.io.wavfile.write(file, y[0], y[1])
|
||||
y = file.name
|
||||
@ -348,11 +362,11 @@ class Dataframe(OutputComponent):
|
||||
Output type: Union[pandas.DataFrame, numpy.array, List[Union[str, float]], List[List[Union[str, float]]]]
|
||||
"""
|
||||
|
||||
def __init__(self, headers=None, type="pandas", label=None):
|
||||
def __init__(self, headers=None, type="auto", label=None):
|
||||
'''
|
||||
Parameters:
|
||||
headers (List[str]): Header names to dataframe.
|
||||
type (str): Type of value to be passed to component. "pandas" for pandas dataframe, "numpy" for numpy array, or "array" for Python array.
|
||||
type (str): Type of value to be passed to component. "pandas" for pandas dataframe, "numpy" for numpy array, or "array" for Python array, "auto" detects return type.
|
||||
label (str): component name in interface.
|
||||
'''
|
||||
self.type = type
|
||||
@ -369,17 +383,26 @@ class Dataframe(OutputComponent):
|
||||
@classmethod
|
||||
def get_shortcut_implementations(cls):
|
||||
return {
|
||||
"dataframe": {"type": "pandas"},
|
||||
"dataframe": {},
|
||||
"numpy": {"type": "numpy"},
|
||||
"matrix": {"type": "array"},
|
||||
"list": {"type": "array"},
|
||||
}
|
||||
|
||||
def postprocess(self, y):
|
||||
if self.type == "pandas":
|
||||
if self.type == "auto":
|
||||
if isinstance(y, pd.core.frame.DataFrame):
|
||||
dtype = "pandas"
|
||||
elif isinstance(y, np.ndarray):
|
||||
dtype = "numpy"
|
||||
elif isinstance(y, list):
|
||||
dtype = "array"
|
||||
else:
|
||||
dtype = self.type
|
||||
if dtype == "pandas":
|
||||
return {"headers": list(y.columns), "data": y.values.tolist()}
|
||||
elif self.type in ("numpy", "array"):
|
||||
if self.type == "numpy":
|
||||
elif dtype in ("numpy", "array"):
|
||||
if dtype == "numpy":
|
||||
y = y.tolist()
|
||||
if len(y) == 0 or not isinstance(y[0], list):
|
||||
y = [y]
|
||||
|
@ -1,4 +1,4 @@
|
||||
.input_text {
|
||||
textarea.input_text, input.input_text {
|
||||
resize: none;
|
||||
width: 100%;
|
||||
font-size: 18px;
|
||||
@ -11,6 +11,3 @@
|
||||
padding: 4px;
|
||||
font-family: monospace;
|
||||
}
|
||||
.input_text_saliency {
|
||||
display: none;
|
||||
}
|
||||
|
@ -7,17 +7,13 @@ body#lib {
|
||||
button, input[type="submit"], input[type="reset"], input[type="text"], input[type="button"], select[type="submit"] {
|
||||
border: none;
|
||||
font: inherit;
|
||||
cursor: pointer;
|
||||
outline: inherit;
|
||||
-webkit-appearance: none;
|
||||
}
|
||||
select {
|
||||
font: inherit;
|
||||
}
|
||||
label, input[type=radio], input[type=checkbox], select, input[type=range] {
|
||||
cursor: pointer;
|
||||
}
|
||||
button, input[type="submit"], input[type="reset"], input[type="button"], select[type="submit"] {
|
||||
label, input[type=radio], input[type=checkbox], select, input[type=range], button, input[type="submit"], input[type="reset"], input[type="button"], select[type="submit"] {
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||
function gradio(config, fn, target) {
|
||||
function gradio(config, fn, target, example_file_path) {
|
||||
target = $(target);
|
||||
target.html(`
|
||||
<div class="share invisible">
|
||||
@ -42,6 +42,7 @@ function gradio(config, fn, target) {
|
||||
io_master.fn = fn
|
||||
io_master.target = target;
|
||||
io_master.config = config;
|
||||
io_master.example_file_path = example_file_path;
|
||||
|
||||
let input_to_object_map = {
|
||||
"csv" : {},
|
||||
@ -243,7 +244,7 @@ function gradio(config, fn, target) {
|
||||
|
||||
return io_master;
|
||||
}
|
||||
function gradio_url(config, url, target) {
|
||||
function gradio_url(config, url, target, example_file_path) {
|
||||
return gradio(config, function(data) {
|
||||
return new Promise((resolve, reject) => {
|
||||
$.ajax({type: "POST",
|
||||
@ -253,7 +254,7 @@ function gradio_url(config, url, target) {
|
||||
error: reject,
|
||||
});
|
||||
});
|
||||
}, target);
|
||||
}, target, example_file_path);
|
||||
}
|
||||
function saveAs(uri, filename) {
|
||||
var link = document.createElement('a');
|
||||
|
@ -215,25 +215,28 @@ const image_input = {
|
||||
}
|
||||
},
|
||||
load_example_preview: function(data) {
|
||||
return "<img src="+data+" height=100>"
|
||||
return "<img src='"+this.io_master.example_file_path+data+"' height=100>"
|
||||
},
|
||||
load_example: function(data) {
|
||||
load_example: function(example_data) {
|
||||
example_data = this.io_master.example_file_path + example_data;
|
||||
let io = this;
|
||||
if (this.source == "canvas") {
|
||||
this.clear();
|
||||
let ctx = this.context;
|
||||
var img = new Image;
|
||||
let dimension = this.target.find(".canvas_holder canvas").width();
|
||||
img.onload = function(){
|
||||
ctx.clearRect(0,0,dimension,dimension);
|
||||
ctx.drawImage(img,0,0,dimension,dimension);
|
||||
};
|
||||
img.src = data;
|
||||
} else {
|
||||
io.target.find(".upload_zone").hide();
|
||||
io.target.find(".image_display").removeClass("hide");
|
||||
io.set_image_data(data, /*update_editor=*/true);
|
||||
io.state = "IMAGE_LOADED";
|
||||
}
|
||||
toDataURL(example_data, function(data) {
|
||||
if (io.source == "canvas") {
|
||||
io.clear();
|
||||
let ctx = this.context;
|
||||
var img = new Image;
|
||||
let dimension = io.target.find(".canvas_holder canvas").width();
|
||||
img.onload = function(){
|
||||
ctx.clearRect(0,0,dimension,dimension);
|
||||
ctx.drawImage(img,0,0,dimension,dimension);
|
||||
};
|
||||
img.src = data;
|
||||
} else {
|
||||
io.target.find(".upload_zone").hide();
|
||||
io.target.find(".image_display").removeClass("hide");
|
||||
io.set_image_data(data, /*update_editor=*/true);
|
||||
io.state = "IMAGE_LOADED";
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -1,9 +1,11 @@
|
||||
const textbox_input = {
|
||||
html: `<textarea class="input_text"></textarea>
|
||||
<div class='input_text_saliency'></div>`,
|
||||
html: `<textarea class="input_text"></textarea>`,
|
||||
one_line_html: `<input type="text" class="input_text">`,
|
||||
init: function(opts) {
|
||||
if (opts.lines) {
|
||||
if (opts.lines > 1) {
|
||||
this.target.find(".input_text").attr("rows", opts.lines).css("height", "auto");
|
||||
} else {
|
||||
this.target.html(this.one_line_html);
|
||||
}
|
||||
if (opts.placeholder) {
|
||||
this.target.find(".input_text").attr("placeholder", opts.placeholder)
|
||||
@ -16,22 +18,8 @@ const textbox_input = {
|
||||
text = this.target.find(".input_text").val();
|
||||
this.io_master.input(this.id, text);
|
||||
},
|
||||
output: function(data) {
|
||||
this.target.find(".input_text").hide();
|
||||
this.target.find(".input_text_saliency").show();
|
||||
this.target.find(".input_text_saliency").empty();
|
||||
let html = '';
|
||||
let text = this.target.find(".input_text").val();
|
||||
let index = 0;
|
||||
data.forEach(function(value, index) {
|
||||
html += `<span style='background-color:rgba(75,150,255,${value})'>${text.charAt(index)}</span>`;
|
||||
})
|
||||
$(".input_text_saliency").html(html);
|
||||
},
|
||||
clear: function() {
|
||||
this.target.find(".input_text").val("");
|
||||
this.target.find(".input_text_saliency").hide();
|
||||
this.target.find(".input_text").show();
|
||||
},
|
||||
load_example: function(data) {
|
||||
this.target.find(".input_text").val(data);
|
||||
|
@ -6,6 +6,20 @@ String.prototype.format = function() {
|
||||
return a
|
||||
}
|
||||
|
||||
function toDataURL(url, callback) {
|
||||
var xhr = new XMLHttpRequest();
|
||||
xhr.onload = function() {
|
||||
var reader = new FileReader();
|
||||
reader.onloadend = function() {
|
||||
callback(reader.result);
|
||||
}
|
||||
reader.readAsDataURL(xhr.response);
|
||||
};
|
||||
xhr.open('GET', url);
|
||||
xhr.responseType = 'blob';
|
||||
xhr.send();
|
||||
}
|
||||
|
||||
function resizeImage(base64Str, max_width, max_height, callback) {
|
||||
var img = new Image();
|
||||
img.src = base64Str;
|
||||
|
@ -118,7 +118,7 @@
|
||||
<script src="/static/js/gradio.js"></script>
|
||||
<script>
|
||||
$.getJSON("static/config.json", function(config) {
|
||||
io = gradio_url(config, "/api/predict/", "#interface_target");
|
||||
io = gradio_url(config, "/api/predict/", "#interface_target", "/file/");
|
||||
});
|
||||
const copyToClipboard = str => {
|
||||
const el = document.createElement('textarea');
|
||||
|
@ -11,11 +11,10 @@ def image_mod(image):
|
||||
|
||||
gr.Interface(image_mod,
|
||||
gr.inputs.Image(type="pil"),
|
||||
gr.outputs.Image(type="pil"),
|
||||
"image",
|
||||
examples=[
|
||||
["images/cheetah1.jpg"],
|
||||
["images/cheetah2.jpg"],
|
||||
["images/lion.jpg"],
|
||||
],
|
||||
live=True,
|
||||
).launch(share=True)
|
||||
|
@ -294,12 +294,6 @@ class Image(InputComponent):
|
||||
im.save(file_obj.name)
|
||||
return file_obj
|
||||
|
||||
def process_example(self, example):
|
||||
if os.path.exists(example):
|
||||
return processing_utils.encode_file_to_base64(example)
|
||||
else:
|
||||
return example
|
||||
|
||||
def rebuild(self, dir, data):
|
||||
"""
|
||||
Default rebuild method to decode a base64 image
|
||||
|
@ -197,6 +197,17 @@ def serve_files_in_background(interface, port, directory_to_serve=None, server_n
|
||||
else:
|
||||
self.send_error(404, 'Path not found: {}'.format(self.path))
|
||||
|
||||
|
||||
def do_GET(self):
|
||||
if self.path.startswith("/file/"):
|
||||
self.send_response(200)
|
||||
self.end_headers()
|
||||
with open(self.path[6:], "rb") as f:
|
||||
self.wfile.write(f.read())
|
||||
else:
|
||||
super().do_GET()
|
||||
|
||||
|
||||
class HTTPServer(BaseHTTPServer):
|
||||
"""The main server, you pass in base_path which is the path you want to serve requests from"""
|
||||
|
||||
|
@ -15,6 +15,9 @@ import warnings
|
||||
import tempfile
|
||||
import scipy
|
||||
import os
|
||||
import pandas as pd
|
||||
import PIL
|
||||
from types import ModuleType
|
||||
|
||||
class OutputComponent(Component):
|
||||
"""
|
||||
@ -28,10 +31,10 @@ class Textbox(OutputComponent):
|
||||
Output type: Union[str, float, int]
|
||||
'''
|
||||
|
||||
def __init__(self, type="str", label=None):
|
||||
def __init__(self, type="auto", label=None):
|
||||
'''
|
||||
Parameters:
|
||||
type (str): Type of value to be passed to component. "str" expects a string, "number" expects a float value.
|
||||
type (str): Type of value to be passed to component. "str" expects a string, "number" expects a float value, "auto" detects return type.
|
||||
label (str): component name in interface.
|
||||
'''
|
||||
self.type = type
|
||||
@ -53,7 +56,7 @@ class Textbox(OutputComponent):
|
||||
def postprocess(self, y):
|
||||
if self.type == "str":
|
||||
return y
|
||||
elif self.type == "number":
|
||||
elif self.type == "number" or self.type == "auto":
|
||||
return str(y)
|
||||
else:
|
||||
raise ValueError("Unknown type: " + self.type + ". Please choose from: 'str', 'number'")
|
||||
@ -69,19 +72,21 @@ class Label(OutputComponent):
|
||||
CONFIDENCE_KEY = "confidence"
|
||||
CONFIDENCES_KEY = "confidences"
|
||||
|
||||
def __init__(self, num_top_classes=None, label=None):
|
||||
def __init__(self, num_top_classes=None, type="auto", label=None):
|
||||
'''
|
||||
Parameters:
|
||||
num_top_classes (int): number of most confident classes to show.
|
||||
type (str): Type of value to be passed to component. "value" expects a single out label, "confidences" expects a dictionary mapping labels to confidence scores, "auto" detects return type.
|
||||
label (str): component name in interface.
|
||||
'''
|
||||
self.num_top_classes = num_top_classes
|
||||
self.type = type
|
||||
super().__init__(label)
|
||||
|
||||
def postprocess(self, y):
|
||||
if isinstance(y, str) or isinstance(y, Number):
|
||||
return {"label": str(y)}
|
||||
elif isinstance(y, dict):
|
||||
if self.type == "label" or (self.type == "auto" and (isinstance(y, str) or isinstance(y, Number))):
|
||||
return {self.LABEL_KEY: str(y)}
|
||||
elif self.type == "confidences" or (self.type == "auto" and isinstance(y, dict)):
|
||||
sorted_pred = sorted(
|
||||
y.items(),
|
||||
key=operator.itemgetter(1),
|
||||
@ -98,8 +103,6 @@ class Label(OutputComponent):
|
||||
} for pred in sorted_pred
|
||||
]
|
||||
}
|
||||
elif isinstance(y, int) or isinstance(y, float):
|
||||
return {self.LABEL_KEY: str(y)}
|
||||
else:
|
||||
raise ValueError("The `Label` output interface expects one of: a string label, or an int label, a "
|
||||
"float label, or a dictionary whose keys are labels and values are confidences.")
|
||||
@ -123,10 +126,10 @@ class Image(OutputComponent):
|
||||
Output type: Union[numpy.array, PIL.Image, str, matplotlib.pyplot]
|
||||
'''
|
||||
|
||||
def __init__(self, type="numpy", plot=False, label=None):
|
||||
def __init__(self, type="auto", plot=False, label=None):
|
||||
'''
|
||||
Parameters:
|
||||
type (str): Type of value to be passed to component. "numpy" expects a numpy array with shape (width, height, 3), "pil" expects a PIL image object, "file" expects a file path to the saved image, "plot" expects a matplotlib.pyplot object.
|
||||
type (str): Type of value to be passed to component. "numpy" expects a numpy array with shape (width, height, 3), "pil" expects a PIL image object, "file" expects a file path to the saved image, "plot" expects a matplotlib.pyplot object, "auto" detects return type.
|
||||
plot (bool): DEPRECATED. Whether to expect a plot to be returned by the function.
|
||||
label (str): component name in interface.
|
||||
'''
|
||||
@ -146,16 +149,27 @@ class Image(OutputComponent):
|
||||
}
|
||||
|
||||
def postprocess(self, y):
|
||||
if self.type in ["numpy", "pil"]:
|
||||
if self.type == "pil":
|
||||
if self.type == "auto":
|
||||
if isinstance(y, np.ndarray):
|
||||
dtype = "numpy"
|
||||
elif isinstance(y, PIL.Image.Image):
|
||||
dtype = "pil"
|
||||
elif isinstance(y, str):
|
||||
dtype = "file"
|
||||
elif isinstance(y, ModuleType):
|
||||
dtype = "plot"
|
||||
else:
|
||||
dtype = self.type
|
||||
if dtype in ["numpy", "pil"]:
|
||||
if dtype == "pil":
|
||||
y = np.array(y)
|
||||
return processing_utils.encode_array_to_base64(y)
|
||||
elif self.type == "file":
|
||||
elif dtype == "file":
|
||||
return processing_utils.encode_file_to_base64(y)
|
||||
elif self.type == "plot":
|
||||
elif dtype == "plot":
|
||||
return processing_utils.encode_plot_to_base64(y)
|
||||
else:
|
||||
raise ValueError("Unknown type: " + self.type + ". Please choose from: 'numpy', 'pil', 'file', 'plot'.")
|
||||
raise ValueError("Unknown type: " + dtype + ". Please choose from: 'numpy', 'pil', 'file', 'plot'.")
|
||||
|
||||
def rebuild(self, dir, data):
|
||||
"""
|
||||
@ -235,10 +249,10 @@ class Audio(OutputComponent):
|
||||
Output type: Union[Tuple[int, numpy.array], str]
|
||||
'''
|
||||
|
||||
def __init__(self, type="numpy", label=None):
|
||||
def __init__(self, type="auto", label=None):
|
||||
'''
|
||||
Parameters:
|
||||
type (str): Type of value to be passed to component. "numpy" returns a 2-set tuple with an integer sample_rate and the data numpy.array of shape (samples, 2), "file" returns a temporary file path to the saved wav audio file.
|
||||
type (str): Type of value to be passed to component. "numpy" returns a 2-set tuple with an integer sample_rate and the data numpy.array of shape (samples, 2), "file" returns a temporary file path to the saved wav audio file, "auto" detects return type.
|
||||
label (str): component name in interface.
|
||||
'''
|
||||
self.type = type
|
||||
@ -256,8 +270,8 @@ class Audio(OutputComponent):
|
||||
}
|
||||
|
||||
def postprocess(self, y):
|
||||
if self.type in ["numpy", "file"]:
|
||||
if self.type == "numpy":
|
||||
if self.type in ["numpy", "file", "auto"]:
|
||||
if self.type == "numpy" or (self.type == "auto" and isinstance(y, tuple)):
|
||||
file = tempfile.NamedTemporaryFile()
|
||||
scipy.io.wavfile.write(file, y[0], y[1])
|
||||
y = file.name
|
||||
@ -348,11 +362,11 @@ class Dataframe(OutputComponent):
|
||||
Output type: Union[pandas.DataFrame, numpy.array, List[Union[str, float]], List[List[Union[str, float]]]]
|
||||
"""
|
||||
|
||||
def __init__(self, headers=None, type="pandas", label=None):
|
||||
def __init__(self, headers=None, type="auto", label=None):
|
||||
'''
|
||||
Parameters:
|
||||
headers (List[str]): Header names to dataframe.
|
||||
type (str): Type of value to be passed to component. "pandas" for pandas dataframe, "numpy" for numpy array, or "array" for Python array.
|
||||
type (str): Type of value to be passed to component. "pandas" for pandas dataframe, "numpy" for numpy array, or "array" for Python array, "auto" detects return type.
|
||||
label (str): component name in interface.
|
||||
'''
|
||||
self.type = type
|
||||
@ -369,17 +383,26 @@ class Dataframe(OutputComponent):
|
||||
@classmethod
|
||||
def get_shortcut_implementations(cls):
|
||||
return {
|
||||
"dataframe": {"type": "pandas"},
|
||||
"dataframe": {},
|
||||
"numpy": {"type": "numpy"},
|
||||
"matrix": {"type": "array"},
|
||||
"list": {"type": "array"},
|
||||
}
|
||||
|
||||
def postprocess(self, y):
|
||||
if self.type == "pandas":
|
||||
if self.type == "auto":
|
||||
if isinstance(y, pd.core.frame.DataFrame):
|
||||
dtype = "pandas"
|
||||
elif isinstance(y, np.ndarray):
|
||||
dtype = "numpy"
|
||||
elif isinstance(y, list):
|
||||
dtype = "array"
|
||||
else:
|
||||
dtype = self.type
|
||||
if dtype == "pandas":
|
||||
return {"headers": list(y.columns), "data": y.values.tolist()}
|
||||
elif self.type in ("numpy", "array"):
|
||||
if self.type == "numpy":
|
||||
elif dtype in ("numpy", "array"):
|
||||
if dtype == "numpy":
|
||||
y = y.tolist()
|
||||
if len(y) == 0 or not isinstance(y[0], list):
|
||||
y = [y]
|
||||
|
@ -1,4 +1,4 @@
|
||||
.input_text {
|
||||
textarea.input_text, input.input_text {
|
||||
resize: none;
|
||||
width: 100%;
|
||||
font-size: 18px;
|
||||
@ -11,6 +11,3 @@
|
||||
padding: 4px;
|
||||
font-family: monospace;
|
||||
}
|
||||
.input_text_saliency {
|
||||
display: none;
|
||||
}
|
||||
|
@ -1,4 +1,4 @@
|
||||
function gradio(config, fn, target) {
|
||||
function gradio(config, fn, target, example_file_path) {
|
||||
target = $(target);
|
||||
target.html(`
|
||||
<div class="share invisible">
|
||||
@ -42,6 +42,7 @@ function gradio(config, fn, target) {
|
||||
io_master.fn = fn
|
||||
io_master.target = target;
|
||||
io_master.config = config;
|
||||
io_master.example_file_path = example_file_path;
|
||||
|
||||
let input_to_object_map = {
|
||||
"csv" : {},
|
||||
@ -243,7 +244,7 @@ function gradio(config, fn, target) {
|
||||
|
||||
return io_master;
|
||||
}
|
||||
function gradio_url(config, url, target) {
|
||||
function gradio_url(config, url, target, example_file_path) {
|
||||
return gradio(config, function(data) {
|
||||
return new Promise((resolve, reject) => {
|
||||
$.ajax({type: "POST",
|
||||
@ -253,7 +254,7 @@ function gradio_url(config, url, target) {
|
||||
error: reject,
|
||||
});
|
||||
});
|
||||
}, target);
|
||||
}, target, example_file_path);
|
||||
}
|
||||
function saveAs(uri, filename) {
|
||||
var link = document.createElement('a');
|
||||
|
@ -215,25 +215,28 @@ const image_input = {
|
||||
}
|
||||
},
|
||||
load_example_preview: function(data) {
|
||||
return "<img src="+data+" height=100>"
|
||||
return "<img src='"+this.io_master.example_file_path+data+"' height=100>"
|
||||
},
|
||||
load_example: function(data) {
|
||||
load_example: function(example_data) {
|
||||
example_data = this.io_master.example_file_path + example_data;
|
||||
let io = this;
|
||||
if (this.source == "canvas") {
|
||||
this.clear();
|
||||
let ctx = this.context;
|
||||
var img = new Image;
|
||||
let dimension = this.target.find(".canvas_holder canvas").width();
|
||||
img.onload = function(){
|
||||
ctx.clearRect(0,0,dimension,dimension);
|
||||
ctx.drawImage(img,0,0,dimension,dimension);
|
||||
};
|
||||
img.src = data;
|
||||
} else {
|
||||
io.target.find(".upload_zone").hide();
|
||||
io.target.find(".image_display").removeClass("hide");
|
||||
io.set_image_data(data, /*update_editor=*/true);
|
||||
io.state = "IMAGE_LOADED";
|
||||
}
|
||||
toDataURL(example_data, function(data) {
|
||||
if (io.source == "canvas") {
|
||||
io.clear();
|
||||
let ctx = this.context;
|
||||
var img = new Image;
|
||||
let dimension = io.target.find(".canvas_holder canvas").width();
|
||||
img.onload = function(){
|
||||
ctx.clearRect(0,0,dimension,dimension);
|
||||
ctx.drawImage(img,0,0,dimension,dimension);
|
||||
};
|
||||
img.src = data;
|
||||
} else {
|
||||
io.target.find(".upload_zone").hide();
|
||||
io.target.find(".image_display").removeClass("hide");
|
||||
io.set_image_data(data, /*update_editor=*/true);
|
||||
io.state = "IMAGE_LOADED";
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -1,9 +1,11 @@
|
||||
const textbox_input = {
|
||||
html: `<textarea class="input_text"></textarea>
|
||||
<div class='input_text_saliency'></div>`,
|
||||
html: `<textarea class="input_text"></textarea>`,
|
||||
one_line_html: `<input type="text" class="input_text">`,
|
||||
init: function(opts) {
|
||||
if (opts.lines) {
|
||||
if (opts.lines > 1) {
|
||||
this.target.find(".input_text").attr("rows", opts.lines).css("height", "auto");
|
||||
} else {
|
||||
this.target.html(this.one_line_html);
|
||||
}
|
||||
if (opts.placeholder) {
|
||||
this.target.find(".input_text").attr("placeholder", opts.placeholder)
|
||||
@ -16,22 +18,8 @@ const textbox_input = {
|
||||
text = this.target.find(".input_text").val();
|
||||
this.io_master.input(this.id, text);
|
||||
},
|
||||
output: function(data) {
|
||||
this.target.find(".input_text").hide();
|
||||
this.target.find(".input_text_saliency").show();
|
||||
this.target.find(".input_text_saliency").empty();
|
||||
let html = '';
|
||||
let text = this.target.find(".input_text").val();
|
||||
let index = 0;
|
||||
data.forEach(function(value, index) {
|
||||
html += `<span style='background-color:rgba(75,150,255,${value})'>${text.charAt(index)}</span>`;
|
||||
})
|
||||
$(".input_text_saliency").html(html);
|
||||
},
|
||||
clear: function() {
|
||||
this.target.find(".input_text").val("");
|
||||
this.target.find(".input_text_saliency").hide();
|
||||
this.target.find(".input_text").show();
|
||||
},
|
||||
load_example: function(data) {
|
||||
this.target.find(".input_text").val(data);
|
||||
|
@ -6,6 +6,20 @@ String.prototype.format = function() {
|
||||
return a
|
||||
}
|
||||
|
||||
function toDataURL(url, callback) {
|
||||
var xhr = new XMLHttpRequest();
|
||||
xhr.onload = function() {
|
||||
var reader = new FileReader();
|
||||
reader.onloadend = function() {
|
||||
callback(reader.result);
|
||||
}
|
||||
reader.readAsDataURL(xhr.response);
|
||||
};
|
||||
xhr.open('GET', url);
|
||||
xhr.responseType = 'blob';
|
||||
xhr.send();
|
||||
}
|
||||
|
||||
function resizeImage(base64Str, max_width, max_height, callback) {
|
||||
var img = new Image();
|
||||
img.src = base64Str;
|
||||
|
@ -118,7 +118,7 @@
|
||||
<script src="/static/js/gradio.js"></script>
|
||||
<script>
|
||||
$.getJSON("static/config.json", function(config) {
|
||||
io = gradio_url(config, "/api/predict/", "#interface_target");
|
||||
io = gradio_url(config, "/api/predict/", "#interface_target", "/file/");
|
||||
});
|
||||
const copyToClipboard = str => {
|
||||
const el = document.createElement('textarea');
|
||||
|
Loading…
Reference in New Issue
Block a user