mirror of
https://github.com/gradio-app/gradio.git
synced 2024-12-21 02:19:59 +08:00
commit
480591ad2b
@ -63,6 +63,11 @@ class AbstractInput(ABC):
|
||||
"""
|
||||
return {}
|
||||
|
||||
def rebuild(self, dir, data):
|
||||
"""
|
||||
All interfaces should define a method that rebuilds the flagged input when it's passed back (i.e. rebuilds image from base64)
|
||||
"""
|
||||
return data
|
||||
|
||||
class Textbox(AbstractInput):
|
||||
"""
|
||||
@ -290,6 +295,16 @@ class Image(AbstractInput):
|
||||
else:
|
||||
return example
|
||||
|
||||
def rebuild(self, dir, data):
|
||||
"""
|
||||
Default rebuild method to decode a base64 image
|
||||
"""
|
||||
im = preprocessing_utils.decode_base64_to_image(data)
|
||||
timestamp = datetime.datetime.now()
|
||||
filename = f'input_{timestamp.strftime("%Y-%m-%d-%H-%M-%S")}.png'
|
||||
im.save(f'{dir}/{filename}', 'PNG')
|
||||
return filename
|
||||
|
||||
|
||||
class Sketchpad(AbstractInput):
|
||||
"""
|
||||
@ -341,6 +356,16 @@ class Sketchpad(AbstractInput):
|
||||
def process_example(self, example):
|
||||
return preprocessing_utils.convert_file_to_base64(example)
|
||||
|
||||
def rebuild(self, dir, data):
|
||||
"""
|
||||
Default rebuild method to decode a base64 image
|
||||
"""
|
||||
im = preprocessing_utils.decode_base64_to_image(data)
|
||||
timestamp = datetime.datetime.now()
|
||||
filename = f'input_{timestamp.strftime("%Y-%m-%d-%H-%M-%S")}.png'
|
||||
im.save(f'{dir}/{filename}', 'PNG')
|
||||
return filename
|
||||
|
||||
|
||||
class Webcam(AbstractInput):
|
||||
"""
|
||||
@ -378,6 +403,16 @@ class Webcam(AbstractInput):
|
||||
im, (self.image_width, self.image_height))
|
||||
return np.array(im)
|
||||
|
||||
def rebuild(self, dir, data):
|
||||
"""
|
||||
Default rebuild method to decode a base64 image
|
||||
"""
|
||||
im = preprocessing_utils.decode_base64_to_image(data)
|
||||
timestamp = datetime.datetime.now()
|
||||
filename = f'input_{timestamp.strftime("%Y-%m-%d-%H-%M-%S")}.png'
|
||||
im.save(f'{dir}/{filename}', 'PNG')
|
||||
return filename
|
||||
|
||||
|
||||
class Microphone(AbstractInput):
|
||||
"""
|
||||
|
@ -20,7 +20,7 @@ from IPython import get_ipython
|
||||
import sys
|
||||
import weakref
|
||||
import analytics
|
||||
|
||||
import os
|
||||
|
||||
PKG_VERSION_URL = "https://gradio.app/api/pkg-version"
|
||||
analytics.write_key = "uxIFddIEuuUcFLf9VgH2teTEtPlWdkNy"
|
||||
@ -30,7 +30,6 @@ try:
|
||||
except requests.ConnectionError:
|
||||
ip_address = "No internet connection"
|
||||
|
||||
|
||||
class Interface:
|
||||
"""
|
||||
Interfaces are created with Gradio using the `gradio.Interface()` function.
|
||||
@ -41,7 +40,8 @@ class Interface:
|
||||
live=False, show_input=True, show_output=True,
|
||||
capture_session=False, title=None, description=None,
|
||||
thumbnail=None, server_port=None, server_name=networking.LOCALHOST_NAME,
|
||||
allow_screenshot=True):
|
||||
allow_screenshot=True, allow_flagging=True,
|
||||
flagging_dir="flagged"):
|
||||
"""
|
||||
Parameters:
|
||||
fn (Callable): the function to wrap an interface around.
|
||||
@ -101,6 +101,8 @@ class Interface:
|
||||
self.server_port = server_port
|
||||
self.simple_server = None
|
||||
self.allow_screenshot = allow_screenshot
|
||||
self.allow_flagging = allow_flagging
|
||||
self.flagging_dir = flagging_dir
|
||||
Interface.instances.add(self)
|
||||
|
||||
data = {'fn': fn,
|
||||
@ -120,6 +122,18 @@ class Interface:
|
||||
except (ImportError, AttributeError): # If they are using TF >= 2.0 or don't have TF, just ignore this.
|
||||
pass
|
||||
|
||||
if self.allow_flagging:
|
||||
if self.title is not None:
|
||||
dir_name = "_".join(self.title.split(" "))
|
||||
else:
|
||||
dir_name = "_".join([fn.__name__ for fn in self.predict])
|
||||
index = 1
|
||||
while os.path.exists(self.flagging_dir + "/" + dir_name +
|
||||
"_{}".format(index)):
|
||||
index += 1
|
||||
self.flagging_dir = self.flagging_dir + "/" + dir_name + \
|
||||
"_{}".format(index)
|
||||
|
||||
try:
|
||||
requests.post(analytics_url + 'gradio-initiated-analytics/',
|
||||
data=data)
|
||||
@ -141,7 +155,8 @@ class Interface:
|
||||
"title": self.title,
|
||||
"description": self.description,
|
||||
"thumbnail": self.thumbnail,
|
||||
"allow_screenshot": self.allow_screenshot
|
||||
"allow_screenshot": self.allow_screenshot,
|
||||
"allow_flagging": self.allow_flagging
|
||||
}
|
||||
try:
|
||||
param_names = inspect.getfullargspec(self.predict[0])[0]
|
||||
|
@ -17,7 +17,6 @@ import requests
|
||||
import sys
|
||||
import analytics
|
||||
|
||||
|
||||
INITIAL_PORT_VALUE = int(os.getenv(
|
||||
'GRADIO_SERVER_PORT', "7860")) # The http server will try to open on port 7860. If not available, 7861, 7862, etc.
|
||||
TRY_NUM_PORTS = int(os.getenv(
|
||||
@ -36,8 +35,6 @@ CONFIG_FILE = "static/config.json"
|
||||
ASSOCIATION_PATH_IN_STATIC = "static/apple-app-site-association"
|
||||
ASSOCIATION_PATH_IN_ROOT = "apple-app-site-association"
|
||||
|
||||
FLAGGING_DIRECTORY = 'static/flagged/'
|
||||
FLAGGING_FILENAME = 'data.txt'
|
||||
analytics.write_key = "uxIFddIEuuUcFLf9VgH2teTEtPlWdkNy"
|
||||
analytics_url = 'https://api.gradio.app/'
|
||||
|
||||
@ -175,16 +172,6 @@ def serve_files_in_background(interface, port, directory_to_serve=None, server_n
|
||||
if interface.saliency is not None:
|
||||
saliency = interface.saliency(raw_input, prediction)
|
||||
output['saliency'] = saliency.tolist()
|
||||
# if interface.always_flag:
|
||||
# msg = json.loads(data_string)
|
||||
# flag_dir = os.path.join(FLAGGING_DIRECTORY, str(interface.hash))
|
||||
# os.makedirs(flag_dir, exist_ok=True)
|
||||
# output_flag = {'input': interface.input_interface.rebuild_flagged(flag_dir, msg['data']),
|
||||
# 'output': interface.output_interface.rebuild_flagged(flag_dir, processed_output),
|
||||
# }
|
||||
# with open(os.path.join(flag_dir, FLAGGING_FILENAME), 'a+') as f:
|
||||
# f.write(json.dumps(output_flag))
|
||||
# f.write("\n")
|
||||
|
||||
self.wfile.write(json.dumps(output).encode())
|
||||
|
||||
@ -197,20 +184,18 @@ def serve_files_in_background(interface, port, directory_to_serve=None, server_n
|
||||
data_string = self.rfile.read(
|
||||
int(self.headers["Content-Length"]))
|
||||
msg = json.loads(data_string)
|
||||
flag_dir = os.path.join(FLAGGING_DIRECTORY,
|
||||
str(interface.flag_hash))
|
||||
os.makedirs(flag_dir, exist_ok=True)
|
||||
os.makedirs(interface.flagging_dir, exist_ok=True)
|
||||
output = {'inputs': [interface.input_interfaces[
|
||||
i].rebuild_flagged(
|
||||
flag_dir, msg['data']['input_data']) for i
|
||||
i].rebuild(
|
||||
interface.flagging_dir, msg['data']['input_data']) for i
|
||||
in range(len(interface.input_interfaces))],
|
||||
'outputs': [interface.output_interfaces[
|
||||
i].rebuild_flagged(
|
||||
flag_dir, msg['data']['output_data']) for i
|
||||
in range(len(interface.output_interfaces))],
|
||||
'message': msg['data']['message']}
|
||||
i].rebuild(
|
||||
interface.flagging_dir, msg['data']['output_data']) for i
|
||||
in range(len(interface.output_interfaces))]}
|
||||
|
||||
with open(os.path.join(flag_dir, FLAGGING_FILENAME), 'a+') as f:
|
||||
with open("{}/log.txt".format(interface.flagging_dir),
|
||||
'a+') as f:
|
||||
f.write(json.dumps(output))
|
||||
f.write("\n")
|
||||
|
||||
|
@ -44,6 +44,12 @@ class AbstractOutput(ABC):
|
||||
"""
|
||||
return {}
|
||||
|
||||
def rebuild(self, dir, data):
|
||||
"""
|
||||
All interfaces should define a method that rebuilds the flagged input when it's passed back (i.e. rebuilds image from base64)
|
||||
"""
|
||||
return data
|
||||
|
||||
|
||||
class Textbox(AbstractOutput):
|
||||
'''
|
||||
@ -130,6 +136,11 @@ class Label(AbstractOutput):
|
||||
"label": {},
|
||||
}
|
||||
|
||||
def rebuild(self, dir, data):
|
||||
"""
|
||||
Default rebuild method for label
|
||||
"""
|
||||
return json.loads(data)
|
||||
|
||||
class Image(AbstractOutput):
|
||||
'''
|
||||
@ -169,11 +180,11 @@ class Image(AbstractOutput):
|
||||
raise ValueError(
|
||||
"The `Image` output interface (with plt=False) expects a numpy array.")
|
||||
|
||||
def rebuild_flagged(self, dir, msg):
|
||||
def rebuild(self, dir, data):
|
||||
"""
|
||||
Default rebuild method to decode a base64 image
|
||||
"""
|
||||
im = preprocessing_utils.decode_base64_to_image(msg)
|
||||
im = preprocessing_utils.decode_base64_to_image(data)
|
||||
timestamp = datetime.datetime.now()
|
||||
filename = 'output_{}.png'.format(timestamp.
|
||||
strftime("%Y-%m-%d-%H-%M-%S"))
|
||||
|
@ -75,6 +75,12 @@ input.submit {
|
||||
input.submit:hover {
|
||||
background-color: #f39c12;
|
||||
}
|
||||
.flag {
|
||||
visibility: hidden;
|
||||
}
|
||||
.flagged {
|
||||
background-color: pink !important;
|
||||
}
|
||||
/* label:hover {
|
||||
background-color: lightgray;
|
||||
} */
|
||||
|
@ -56,12 +56,11 @@ var io_master_template = {
|
||||
this.target.find(".output_interfaces").css("opacity", 1);
|
||||
}
|
||||
},
|
||||
flag: function(message) {
|
||||
flag: function() {
|
||||
var post_data = {
|
||||
'data': {
|
||||
'input_data' : toStringIfObject(this.last_input) ,
|
||||
'output_data' : toStringIfObject(this.last_output),
|
||||
'message' : message
|
||||
'output_data' : toStringIfObject(this.last_output)
|
||||
}
|
||||
}
|
||||
$.ajax({type: "POST",
|
||||
|
@ -22,7 +22,7 @@ function gradio(config, fn, target) {
|
||||
<div class="screenshot_logo">
|
||||
<img src="static/img/logo_inline.png">
|
||||
</div>
|
||||
</div>
|
||||
<input class="flag panel_button" type="button" value="FLAG"/>
|
||||
</div>
|
||||
</div>`);
|
||||
let io_master = Object.create(io_master_template);
|
||||
@ -117,6 +117,7 @@ function gradio(config, fn, target) {
|
||||
output_interface.clear();
|
||||
}
|
||||
target.find(".flag").removeClass("flagged");
|
||||
target.find(".flag").val("FLAG");
|
||||
target.find(".flag_message").empty();
|
||||
target.find(".loading").addClass("invisible");
|
||||
target.find(".loading_time").text("");
|
||||
@ -127,6 +128,9 @@ function gradio(config, fn, target) {
|
||||
if (config["allow_screenshot"]) {
|
||||
target.find(".screenshot").css("visibility", "visible");
|
||||
}
|
||||
if(config["allow_flagging"]){
|
||||
target.find(".flag").css("visibility", "visible");
|
||||
}
|
||||
target.find(".screenshot").click(function() {
|
||||
$(".screenshot").hide();
|
||||
$(".screenshot_logo").show();
|
||||
@ -146,11 +150,22 @@ function gradio(config, fn, target) {
|
||||
target.find(".submit").click(function() {
|
||||
io_master.gather();
|
||||
target.find(".flag").removeClass("flagged");
|
||||
target.find(".flag").val("FLAG");
|
||||
})
|
||||
}
|
||||
if (!config.show_input) {
|
||||
target.find(".input_panel").hide();
|
||||
}
|
||||
}
|
||||
|
||||
target.find(".flag").click(function() {
|
||||
if (io_master.last_output) {
|
||||
target.find(".flag").addClass("flagged");
|
||||
target.find(".flag").val("FLAGGED");
|
||||
io_master.flag();
|
||||
|
||||
// io_master.flag($(".flag_message").val());
|
||||
}
|
||||
})
|
||||
|
||||
return io_master;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user