mirror of
https://github.com/gradio-app/gradio.git
synced 2025-03-13 11:57:29 +08:00
Merge branch 'master' of https://github.com/gradio-app/gradio-UI
This commit is contained in:
commit
a78e34e75b
@ -32,7 +32,7 @@ def launch_interface(args):
|
||||
pass
|
||||
|
||||
def service_shutdown(signum, frame):
|
||||
print('Shutting server down due to signal %d' % signum)
|
||||
print('Shutting server down due to signal {}'.format(signum))
|
||||
httpd.shutdown()
|
||||
raise ServiceExit
|
||||
|
||||
|
@ -11,7 +11,7 @@ from PIL import Image, ImageOps
|
||||
import time
|
||||
import warnings
|
||||
import json
|
||||
|
||||
import datetime
|
||||
|
||||
# Where to find the static resources associated with each template.
|
||||
# BASE_INPUT_INTERFACE_TEMPLATE_PATH = 'static/js/interfaces/input/{}.js'
|
||||
@ -113,8 +113,9 @@ class Sketchpad(AbstractInput):
|
||||
"""
|
||||
|
||||
im = preprocessing_utils.decode_base64_to_image(msg)
|
||||
timestamp = time.time()*1000
|
||||
filename = f'input_{timestamp}.png'
|
||||
|
||||
timestamp = datetime.datetime.now()
|
||||
filename = f'input_{timestamp.strftime("%Y-%m-%d-%H-%M-%S")}.png'
|
||||
im.save(f'{dir}/{filename}', 'PNG')
|
||||
return filename
|
||||
|
||||
@ -159,8 +160,8 @@ class Webcam(AbstractInput):
|
||||
"""
|
||||
inp = msg['data']['input']
|
||||
im = preprocessing_utils.decode_base64_to_image(inp)
|
||||
timestamp = time.time()*1000
|
||||
filename = f'input_{timestamp}.png'
|
||||
timestamp = datetime.datetime.now()
|
||||
filename = f'input_{timestamp.strftime("%Y-%m-%d-%H-%M-%S")}.png'
|
||||
im.save(f'{dir}/{filename}', 'PNG')
|
||||
return filename
|
||||
|
||||
@ -186,11 +187,7 @@ class Textbox(AbstractInput):
|
||||
"""
|
||||
Default rebuild method for text saves it .txt file
|
||||
"""
|
||||
timestamp = time.time()*1000
|
||||
filename = f'input_{timestamp}'
|
||||
with open(f'{dir}/{filename}.txt','w') as f:
|
||||
f.write(msg)
|
||||
return filename
|
||||
return json.loads(msg)
|
||||
|
||||
def get_sample_inputs(self):
|
||||
return self.sample_inputs
|
||||
@ -240,8 +237,8 @@ class ImageIn(AbstractInput):
|
||||
Default rebuild method to decode a base64 image
|
||||
"""
|
||||
im = preprocessing_utils.decode_base64_to_image(msg)
|
||||
timestamp = time.time()*1000
|
||||
filename = f'input_{timestamp}.png'
|
||||
timestamp = datetime.datetime.now()
|
||||
filename = f'input_{timestamp.strftime("%Y-%m-%d-%H-%M-%S")}.png'
|
||||
im.save(f'{dir}/{filename}', 'PNG')
|
||||
return filename
|
||||
|
||||
@ -250,8 +247,8 @@ class ImageIn(AbstractInput):
|
||||
"""
|
||||
"""
|
||||
timestamp = time.time()*1000
|
||||
filename = f'input_{timestamp}.png'
|
||||
img.save(f'{dir}/{filename}', 'PNG')
|
||||
filename = 'input_{}.png'.format(timestamp)
|
||||
img.save('{}/{}'.format(dir, filename), 'PNG')
|
||||
return filename
|
||||
|
||||
|
||||
|
@ -28,7 +28,8 @@ class Interface:
|
||||
the appropriate inputs and outputs
|
||||
"""
|
||||
|
||||
def __init__(self, fn, inputs, outputs, verbose=False, live=False, show_input=True, show_output=True):
|
||||
def __init__(self, fn, inputs, outputs, saliency=None, verbose=False,
|
||||
live=False, show_input=True, show_output=True):
|
||||
"""
|
||||
:param fn: a function that will process the input panel data from the interface and return the output panel data.
|
||||
:param inputs: a string or `AbstractInput` representing the input interface.
|
||||
@ -58,19 +59,23 @@ class Interface:
|
||||
self.output_interfaces = [get_output_instance(i) for i in outputs]
|
||||
else:
|
||||
self.output_interfaces = [get_output_instance(outputs)]
|
||||
if not isinstance(fn, list):
|
||||
fn = [fn]
|
||||
self.output_interfaces *= len(fn)
|
||||
self.predict = fn
|
||||
self.verbose = verbose
|
||||
self.status = "OFF"
|
||||
self.saliency = None
|
||||
self.saliency = saliency
|
||||
self.live = live
|
||||
self.show_input = show_input
|
||||
self.show_output = show_output
|
||||
|
||||
self.flag_hash = random.getrandbits(32)
|
||||
|
||||
def update_config_file(self, output_directory):
|
||||
config = {
|
||||
"input_interfaces": [iface.__class__.__name__.lower() for iface in self.input_interfaces],
|
||||
"output_interfaces": [iface.__class__.__name__.lower() for iface in self.output_interfaces],
|
||||
"function_count": len(self.predict),
|
||||
"live": self.live,
|
||||
"show_input": self.show_input,
|
||||
"show_output": self.show_output,
|
||||
@ -94,7 +99,7 @@ class Interface:
|
||||
for m, msg in enumerate(validation_inputs):
|
||||
if self.verbose:
|
||||
print(
|
||||
f"Validating samples: {m+1}/{n} ["
|
||||
"Validating samples: {}/{} [".format(m+1, n)
|
||||
+ "=" * (m + 1)
|
||||
+ "." * (n - m - 1)
|
||||
+ "]",
|
||||
@ -173,9 +178,10 @@ class Interface:
|
||||
current_pkg_version = pkg_resources.require("gradio")[0].version
|
||||
latest_pkg_version = requests.get(url=PKG_VERSION_URL).json()["version"]
|
||||
if StrictVersion(latest_pkg_version) > StrictVersion(current_pkg_version):
|
||||
print(f"IMPORTANT: You are using gradio version {current_pkg_version}, "
|
||||
f"however version {latest_pkg_version} "
|
||||
f"is available, please upgrade.")
|
||||
print("IMPORTANT: You are using gradio version {}, "
|
||||
"however version {} "
|
||||
"is available, please upgrade.".format(
|
||||
current_pkg_version, latest_pkg_version))
|
||||
print('--------')
|
||||
except: # TODO(abidlabs): don't catch all exceptions
|
||||
pass
|
||||
|
@ -137,14 +137,16 @@ def serve_files_in_background(interface, port, directory_to_serve=None):
|
||||
msg = json.loads(data_string)
|
||||
raw_input = msg["data"]
|
||||
processed_input = [input_interface.preprocess(raw_input[i]) for i, input_interface in enumerate(interface.input_interfaces)]
|
||||
prediction = interface.predict(*processed_input)
|
||||
if len(interface.input_interfaces) == 1:
|
||||
prediction = [prediction]
|
||||
processed_output = [output_interface.postprocess(prediction[i]) for i, output_interface in enumerate(interface.output_interfaces)]
|
||||
predictions = []
|
||||
for predict_fn in interface.predict:
|
||||
prediction = predict_fn(*processed_input)
|
||||
if len(interface.output_interfaces) == 1:
|
||||
prediction = [prediction]
|
||||
predictions.extend(prediction)
|
||||
processed_output = [output_interface.postprocess(predictions[i]) for i, output_interface in enumerate(interface.output_interfaces)]
|
||||
output = {"action": "output", "data": processed_output}
|
||||
if interface.saliency is not None:
|
||||
import numpy as np
|
||||
saliency = interface.saliency(interface, interface.model_obj, raw_input, processed_input, prediction, processed_output)
|
||||
saliency = interface.saliency(raw_input, prediction)
|
||||
output['saliency'] = saliency.tolist()
|
||||
# if interface.always_flag:
|
||||
# msg = json.loads(data_string)
|
||||
@ -164,17 +166,92 @@ def serve_files_in_background(interface, port, directory_to_serve=None):
|
||||
self._set_headers()
|
||||
data_string = self.rfile.read(int(self.headers["Content-Length"]))
|
||||
msg = json.loads(data_string)
|
||||
flag_dir = os.path.join(FLAGGING_DIRECTORY, str(interface.hash))
|
||||
os.makedirs(FLAGGING_DIRECTORY, exist_ok=True)
|
||||
output = {'input': interface.input_interface.rebuild_flagged(flag_dir, msg['data']['input_data']),
|
||||
'output': interface.output_interface.rebuild_flagged(flag_dir, msg['data']['output_data']),
|
||||
flag_dir = os.path.join(FLAGGING_DIRECTORY,
|
||||
str(interface.flag_hash))
|
||||
os.makedirs(flag_dir, exist_ok=True)
|
||||
output = {'inputs': [interface.input_interfaces[
|
||||
i].rebuild_flagged(
|
||||
flag_dir, msg['data']['input_data']) for i
|
||||
in range(len(interface.input_interfaces))],
|
||||
'outputs': [interface.output_interfaces[
|
||||
i].rebuild_flagged(
|
||||
flag_dir, msg['data']['output_data']) for i
|
||||
in range(len(interface.output_interfaces))],
|
||||
'message': msg['data']['message']}
|
||||
|
||||
with open(os.path.join(flag_dir, FLAGGING_FILENAME), 'a+') as f:
|
||||
f.write(json.dumps(output))
|
||||
f.write("\n")
|
||||
|
||||
#TODO(abidlabs): clean this up
|
||||
elif self.path == "/api/auto/rotation":
|
||||
from gradio import validation_data, preprocessing_utils
|
||||
import numpy as np
|
||||
|
||||
self._set_headers()
|
||||
data_string = self.rfile.read(int(self.headers["Content-Length"]))
|
||||
msg = json.loads(data_string)
|
||||
img_orig = preprocessing_utils.decode_base64_to_image(msg["data"])
|
||||
img_orig = img_orig.convert('RGB')
|
||||
img_orig = img_orig.resize((224, 224))
|
||||
|
||||
flag_dir = os.path.join(directory_to_serve, FLAGGING_DIRECTORY)
|
||||
os.makedirs(flag_dir, exist_ok=True)
|
||||
|
||||
for deg in range(-180, 180+45, 45):
|
||||
img = img_orig.rotate(deg)
|
||||
img_array = np.array(img) / 127.5 - 1
|
||||
prediction = interface.predict(np.expand_dims(img_array, axis=0))
|
||||
processed_output = interface.output_interface.postprocess(prediction)
|
||||
output = {'input': interface.input_interface.save_to_file(flag_dir, img),
|
||||
'output': interface.output_interface.rebuild_flagged(
|
||||
flag_dir, {'data': {'output': processed_output}}),
|
||||
'message': 'rotation by {} degrees'.format(
|
||||
deg)}
|
||||
|
||||
with open(os.path.join(flag_dir, FLAGGING_FILENAME), 'a+') as f:
|
||||
f.write(json.dumps(output))
|
||||
f.write("\n")
|
||||
|
||||
# Prepare return json dictionary.
|
||||
self.wfile.write(json.dumps({}).encode())
|
||||
|
||||
elif self.path == "/api/auto/lighting":
|
||||
from gradio import validation_data, preprocessing_utils
|
||||
import numpy as np
|
||||
from PIL import ImageEnhance
|
||||
|
||||
self._set_headers()
|
||||
data_string = self.rfile.read(int(self.headers["Content-Length"]))
|
||||
msg = json.loads(data_string)
|
||||
img_orig = preprocessing_utils.decode_base64_to_image(msg["data"])
|
||||
img_orig = img_orig.convert('RGB')
|
||||
img_orig = img_orig.resize((224, 224))
|
||||
enhancer = ImageEnhance.Brightness(img_orig)
|
||||
|
||||
flag_dir = os.path.join(directory_to_serve, FLAGGING_DIRECTORY)
|
||||
os.makedirs(flag_dir, exist_ok=True)
|
||||
|
||||
for i in range(9):
|
||||
img = enhancer.enhance(i/4)
|
||||
img_array = np.array(img) / 127.5 - 1
|
||||
prediction = interface.predict(np.expand_dims(img_array, axis=0))
|
||||
processed_output = interface.output_interface.postprocess(prediction)
|
||||
output = {'input': interface.input_interface.save_to_file(flag_dir, img),
|
||||
'output': interface.output_interface.rebuild_flagged(
|
||||
flag_dir, {'data': {'output': processed_output}}),
|
||||
'message': 'brighting adjustment by a factor '
|
||||
'of {}'.format(i)}
|
||||
|
||||
with open(os.path.join(flag_dir, FLAGGING_FILENAME), 'a+') as f:
|
||||
f.write(json.dumps(output))
|
||||
f.write("\n")
|
||||
|
||||
# Prepare return json dictionary.
|
||||
self.wfile.write(json.dumps({}).encode())
|
||||
|
||||
else:
|
||||
self.send_error(404, 'Path not found: %s' % self.path)
|
||||
self.send_error(404, 'Path not found: {}'.format(self.path))
|
||||
|
||||
class HTTPServer(BaseHTTPServer):
|
||||
"""The main server, you pass in base_path which is the path you want to serve requests from"""
|
||||
|
@ -7,7 +7,7 @@ automatically added to a registry, which allows them to be easily referenced in
|
||||
from abc import ABC, abstractmethod
|
||||
import numpy as np
|
||||
import json
|
||||
from gradio import imagenet_class_labels, preprocessing_utils
|
||||
from gradio import preprocessing_utils
|
||||
import datetime
|
||||
|
||||
# Where to find the static resources associated with each template.
|
||||
@ -124,8 +124,9 @@ class Image(AbstractOutput):
|
||||
"""
|
||||
im = preprocessing_utils.decode_base64_to_image(msg)
|
||||
timestamp = datetime.datetime.now()
|
||||
filename = f'output_{timestamp.strftime("%Y-%m-%d-%H-%M-%S")}.png'
|
||||
im.save(f'{dir}/{filename}', 'PNG')
|
||||
filename = 'output_{}.png'.format(timestamp.
|
||||
strftime("%Y-%m-%d-%H-%M-%S"))
|
||||
im.save('{}/{}'.format(dir, filename), 'PNG')
|
||||
return filename
|
||||
|
||||
|
||||
|
@ -12,15 +12,17 @@
|
||||
text-transform: uppercase;
|
||||
font-family: Arial;
|
||||
color: #888;
|
||||
padding: 6px;
|
||||
padding: 6px 6px 0;
|
||||
font-size: 14px;
|
||||
font-weight: bold;
|
||||
display: flex;
|
||||
}
|
||||
.input_interfaces, .output_interfaces {
|
||||
margin-bottom: 16px;
|
||||
}
|
||||
.interface {
|
||||
height: 360px;
|
||||
margin-bottom: 16px;
|
||||
padding: 0 6px 6px;
|
||||
padding: 8px;
|
||||
display: flex;
|
||||
flex-flow: column;
|
||||
}
|
||||
@ -33,6 +35,10 @@
|
||||
}
|
||||
.panel_buttons {
|
||||
display: flex;
|
||||
margin-left: -8px;
|
||||
}
|
||||
.panel_buttons > * {
|
||||
margin-left: 8px;
|
||||
}
|
||||
.submit {
|
||||
display: none;
|
||||
@ -58,9 +64,6 @@
|
||||
padding: 8px !important;
|
||||
background-color: #EEEEEE !important;
|
||||
}
|
||||
.clear, .flag {
|
||||
margin-left: 8px;
|
||||
}
|
||||
|
||||
.upload_zone {
|
||||
font-weight: bold;
|
||||
|
@ -41,6 +41,9 @@ $.getJSON("static/config.json", function(data) {
|
||||
_id++;
|
||||
}
|
||||
for (let i = 0; i < config["output_interfaces"].length; i++) {
|
||||
if (i != 0 && i % (config["output_interfaces"].length / config.function_count) == 0) {
|
||||
$(".output_interfaces").append("<hr>");
|
||||
}
|
||||
output_interface = Object.create(output_to_object_map[
|
||||
config["output_interfaces"][i]]);
|
||||
$(".output_interfaces").append(`
|
||||
|
@ -19,12 +19,13 @@ def handler(chan, host, port):
|
||||
try:
|
||||
sock.connect((host, port))
|
||||
except Exception as e:
|
||||
verbose("Forwarding request to %s:%d failed: %r" % (host, port, e))
|
||||
verbose("Forwarding request to {}:{} failed: {}".format(host, port, e))
|
||||
return
|
||||
|
||||
verbose(
|
||||
"Connected! Tunnel open %r -> %r -> %r"
|
||||
% (chan.origin_addr, chan.getpeername(), (host, port))
|
||||
"Connected! Tunnel open {} -> {} -> {}".format(chan.origin_addr,
|
||||
chan.getpeername(),
|
||||
(host, port))
|
||||
)
|
||||
while True:
|
||||
r, w, x = select.select([sock, chan], [], [])
|
||||
@ -40,7 +41,7 @@ def handler(chan, host, port):
|
||||
sock.send(data)
|
||||
chan.close()
|
||||
sock.close()
|
||||
verbose("Tunnel closed from %r" % (chan.origin_addr,))
|
||||
verbose("Tunnel closed from {}".format(chan.origin_addr,))
|
||||
|
||||
|
||||
def reverse_forward_tunnel(server_port, remote_host, remote_port, transport):
|
||||
@ -65,7 +66,8 @@ def create_tunnel(payload, local_server, local_server_port):
|
||||
client.set_missing_host_key_policy(paramiko.WarningPolicy())
|
||||
|
||||
verbose(
|
||||
"Connecting to ssh host %s:%d ..." % (payload["host"], int(payload["port"]))
|
||||
"Connecting to ssh host {}:{} ...".format(payload["host"], int(payload[
|
||||
"port"]))
|
||||
)
|
||||
try:
|
||||
with warnings.catch_warnings():
|
||||
@ -78,14 +80,16 @@ def create_tunnel(payload, local_server, local_server_port):
|
||||
)
|
||||
except Exception as e:
|
||||
print(
|
||||
"*** Failed to connect to %s:%d: %r"
|
||||
% (payload["host"], int(payload["port"]), e)
|
||||
"*** Failed to connect to {}:{}: {}}".format(payload["host"],
|
||||
int(payload["port"]), e)
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
verbose(
|
||||
"Now forwarding remote port %d to %s:%d ..."
|
||||
% (int(payload["remote_port"]), local_server, local_server_port)
|
||||
"Now forwarding remote port {} to {}:{} ...".format(int(payload[
|
||||
"remote_port"]),
|
||||
local_server,
|
||||
local_server_port)
|
||||
)
|
||||
|
||||
thread = threading.Thread(
|
||||
|
48
demo/GPT-2-Demo.py
Normal file
48
demo/GPT-2-Demo.py
Normal file
@ -0,0 +1,48 @@
|
||||
#!/usr/bin/env python
|
||||
# coding: utf-8
|
||||
|
||||
# In[2]:
|
||||
|
||||
|
||||
# installing transformers
|
||||
# !pip install -q git+https://github.com/huggingface/transformers.git
|
||||
# !pip install -q tensorflow==2.1
|
||||
|
||||
|
||||
# In[12]:
|
||||
|
||||
|
||||
import tensorflow as tf
|
||||
from transformers import TFGPT2LMHeadModel, GPT2Tokenizer
|
||||
import gradio
|
||||
|
||||
|
||||
# In[4]:
|
||||
|
||||
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
|
||||
|
||||
# add the EOS token as PAD token to avoid warnings
|
||||
model = TFGPT2LMHeadModel.from_pretrained("gpt2", pad_token_id=tokenizer.eos_token_id)
|
||||
|
||||
|
||||
# In[15]:
|
||||
|
||||
|
||||
def predict(inp):
|
||||
input_ids = tokenizer.encode(inp, return_tensors='tf')
|
||||
beam_output = model.generate(input_ids, max_length=49, num_beams=5, no_repeat_ngram_size=2, early_stopping=True)
|
||||
output = tokenizer.decode(beam_output[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
|
||||
return ".".join(output.split(".")[:-1]) + "."
|
||||
|
||||
# In[18]:
|
||||
|
||||
|
||||
gradio.Interface(predict,"textbox","textbox").launch(inbrowser=True)
|
||||
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
|
||||
|
@ -1,7 +1,14 @@
|
||||
import gradio as gr
|
||||
import numpy as np
|
||||
from time import time
|
||||
|
||||
def flip(image):
|
||||
return np.flipud(image)
|
||||
start = time()
|
||||
return np.flipud(image), time() - start
|
||||
|
||||
gr.Interface(flip, "imagein", "image").launch()
|
||||
def flip2(image):
|
||||
start = time()
|
||||
return np.fliplr(image), time() - start
|
||||
|
||||
|
||||
gr.Interface([flip, flip2], "imagein", ["image", "textbox"]).launch()
|
BIN
dist/gradio-0.9.0-py3.7.egg
vendored
BIN
dist/gradio-0.9.0-py3.7.egg
vendored
Binary file not shown.
@ -11,7 +11,7 @@ from PIL import Image, ImageOps
|
||||
import time
|
||||
import warnings
|
||||
import json
|
||||
|
||||
import datetime
|
||||
|
||||
# Where to find the static resources associated with each template.
|
||||
# BASE_INPUT_INTERFACE_TEMPLATE_PATH = 'static/js/interfaces/input/{}.js'
|
||||
@ -113,8 +113,9 @@ class Sketchpad(AbstractInput):
|
||||
"""
|
||||
|
||||
im = preprocessing_utils.decode_base64_to_image(msg)
|
||||
timestamp = time.time()*1000
|
||||
filename = f'input_{timestamp}.png'
|
||||
|
||||
timestamp = datetime.datetime.now()
|
||||
filename = f'input_{timestamp.strftime("%Y-%m-%d-%H-%M-%S")}.png'
|
||||
im.save(f'{dir}/{filename}', 'PNG')
|
||||
return filename
|
||||
|
||||
@ -159,8 +160,8 @@ class Webcam(AbstractInput):
|
||||
"""
|
||||
inp = msg['data']['input']
|
||||
im = preprocessing_utils.decode_base64_to_image(inp)
|
||||
timestamp = time.time()*1000
|
||||
filename = f'input_{timestamp}.png'
|
||||
timestamp = datetime.datetime.now()
|
||||
filename = f'input_{timestamp.strftime("%Y-%m-%d-%H-%M-%S")}.png'
|
||||
im.save(f'{dir}/{filename}', 'PNG')
|
||||
return filename
|
||||
|
||||
@ -186,11 +187,7 @@ class Textbox(AbstractInput):
|
||||
"""
|
||||
Default rebuild method for text saves it .txt file
|
||||
"""
|
||||
timestamp = time.time()*1000
|
||||
filename = f'input_{timestamp}'
|
||||
with open(f'{dir}/{filename}.txt','w') as f:
|
||||
f.write(msg)
|
||||
return filename
|
||||
return json.loads(msg)
|
||||
|
||||
def get_sample_inputs(self):
|
||||
return self.sample_inputs
|
||||
@ -240,8 +237,8 @@ class ImageIn(AbstractInput):
|
||||
Default rebuild method to decode a base64 image
|
||||
"""
|
||||
im = preprocessing_utils.decode_base64_to_image(msg)
|
||||
timestamp = time.time()*1000
|
||||
filename = f'input_{timestamp}.png'
|
||||
timestamp = datetime.datetime.now()
|
||||
filename = f'input_{timestamp.strftime("%Y-%m-%d-%H-%M-%S")}.png'
|
||||
im.save(f'{dir}/{filename}', 'PNG')
|
||||
return filename
|
||||
|
||||
@ -250,8 +247,8 @@ class ImageIn(AbstractInput):
|
||||
"""
|
||||
"""
|
||||
timestamp = time.time()*1000
|
||||
filename = f'input_{timestamp}.png'
|
||||
img.save(f'{dir}/{filename}', 'PNG')
|
||||
filename = 'input_{}.png'.format(timestamp)
|
||||
img.save('{}/{}'.format(dir, filename), 'PNG')
|
||||
return filename
|
||||
|
||||
|
||||
|
@ -59,6 +59,9 @@ class Interface:
|
||||
self.output_interfaces = [get_output_instance(i) for i in outputs]
|
||||
else:
|
||||
self.output_interfaces = [get_output_instance(outputs)]
|
||||
if not isinstance(fn, list):
|
||||
fn = [fn]
|
||||
self.output_interfaces *= len(fn)
|
||||
self.predict = fn
|
||||
self.verbose = verbose
|
||||
self.status = "OFF"
|
||||
@ -66,12 +69,13 @@ class Interface:
|
||||
self.live = live
|
||||
self.show_input = show_input
|
||||
self.show_output = show_output
|
||||
|
||||
self.flag_hash = random.getrandbits(32)
|
||||
|
||||
def update_config_file(self, output_directory):
|
||||
config = {
|
||||
"input_interfaces": [iface.__class__.__name__.lower() for iface in self.input_interfaces],
|
||||
"output_interfaces": [iface.__class__.__name__.lower() for iface in self.output_interfaces],
|
||||
"function_count": len(self.predict),
|
||||
"live": self.live,
|
||||
"show_input": self.show_input,
|
||||
"show_output": self.show_output,
|
||||
@ -95,7 +99,7 @@ class Interface:
|
||||
for m, msg in enumerate(validation_inputs):
|
||||
if self.verbose:
|
||||
print(
|
||||
f"Validating samples: {m+1}/{n} ["
|
||||
"Validating samples: {}/{} [".format(m+1, n)
|
||||
+ "=" * (m + 1)
|
||||
+ "." * (n - m - 1)
|
||||
+ "]",
|
||||
@ -174,9 +178,10 @@ class Interface:
|
||||
current_pkg_version = pkg_resources.require("gradio")[0].version
|
||||
latest_pkg_version = requests.get(url=PKG_VERSION_URL).json()["version"]
|
||||
if StrictVersion(latest_pkg_version) > StrictVersion(current_pkg_version):
|
||||
print(f"IMPORTANT: You are using gradio version {current_pkg_version}, "
|
||||
f"however version {latest_pkg_version} "
|
||||
f"is available, please upgrade.")
|
||||
print("IMPORTANT: You are using gradio version {}, "
|
||||
"however version {} "
|
||||
"is available, please upgrade.".format(
|
||||
current_pkg_version, latest_pkg_version))
|
||||
print('--------')
|
||||
except: # TODO(abidlabs): don't catch all exceptions
|
||||
pass
|
||||
|
@ -137,10 +137,13 @@ def serve_files_in_background(interface, port, directory_to_serve=None):
|
||||
msg = json.loads(data_string)
|
||||
raw_input = msg["data"]
|
||||
processed_input = [input_interface.preprocess(raw_input[i]) for i, input_interface in enumerate(interface.input_interfaces)]
|
||||
prediction = interface.predict(*processed_input)
|
||||
if len(interface.input_interfaces) == 1:
|
||||
prediction = [prediction]
|
||||
processed_output = [output_interface.postprocess(prediction[i]) for i, output_interface in enumerate(interface.output_interfaces)]
|
||||
predictions = []
|
||||
for predict_fn in interface.predict:
|
||||
prediction = predict_fn(*processed_input)
|
||||
if len(interface.output_interfaces) == 1:
|
||||
prediction = [prediction]
|
||||
predictions.extend(prediction)
|
||||
processed_output = [output_interface.postprocess(predictions[i]) for i, output_interface in enumerate(interface.output_interfaces)]
|
||||
output = {"action": "output", "data": processed_output}
|
||||
if interface.saliency is not None:
|
||||
saliency = interface.saliency(raw_input, prediction)
|
||||
@ -163,17 +166,92 @@ def serve_files_in_background(interface, port, directory_to_serve=None):
|
||||
self._set_headers()
|
||||
data_string = self.rfile.read(int(self.headers["Content-Length"]))
|
||||
msg = json.loads(data_string)
|
||||
flag_dir = os.path.join(FLAGGING_DIRECTORY, str(interface.hash))
|
||||
os.makedirs(FLAGGING_DIRECTORY, exist_ok=True)
|
||||
output = {'input': interface.input_interface.rebuild_flagged(flag_dir, msg['data']['input_data']),
|
||||
'output': interface.output_interface.rebuild_flagged(flag_dir, msg['data']['output_data']),
|
||||
flag_dir = os.path.join(FLAGGING_DIRECTORY,
|
||||
str(interface.flag_hash))
|
||||
os.makedirs(flag_dir, exist_ok=True)
|
||||
output = {'inputs': [interface.input_interfaces[
|
||||
i].rebuild_flagged(
|
||||
flag_dir, msg['data']['input_data']) for i
|
||||
in range(len(interface.input_interfaces))],
|
||||
'outputs': [interface.output_interfaces[
|
||||
i].rebuild_flagged(
|
||||
flag_dir, msg['data']['output_data']) for i
|
||||
in range(len(interface.output_interfaces))],
|
||||
'message': msg['data']['message']}
|
||||
|
||||
with open(os.path.join(flag_dir, FLAGGING_FILENAME), 'a+') as f:
|
||||
f.write(json.dumps(output))
|
||||
f.write("\n")
|
||||
|
||||
#TODO(abidlabs): clean this up
|
||||
elif self.path == "/api/auto/rotation":
|
||||
from gradio import validation_data, preprocessing_utils
|
||||
import numpy as np
|
||||
|
||||
self._set_headers()
|
||||
data_string = self.rfile.read(int(self.headers["Content-Length"]))
|
||||
msg = json.loads(data_string)
|
||||
img_orig = preprocessing_utils.decode_base64_to_image(msg["data"])
|
||||
img_orig = img_orig.convert('RGB')
|
||||
img_orig = img_orig.resize((224, 224))
|
||||
|
||||
flag_dir = os.path.join(directory_to_serve, FLAGGING_DIRECTORY)
|
||||
os.makedirs(flag_dir, exist_ok=True)
|
||||
|
||||
for deg in range(-180, 180+45, 45):
|
||||
img = img_orig.rotate(deg)
|
||||
img_array = np.array(img) / 127.5 - 1
|
||||
prediction = interface.predict(np.expand_dims(img_array, axis=0))
|
||||
processed_output = interface.output_interface.postprocess(prediction)
|
||||
output = {'input': interface.input_interface.save_to_file(flag_dir, img),
|
||||
'output': interface.output_interface.rebuild_flagged(
|
||||
flag_dir, {'data': {'output': processed_output}}),
|
||||
'message': 'rotation by {} degrees'.format(
|
||||
deg)}
|
||||
|
||||
with open(os.path.join(flag_dir, FLAGGING_FILENAME), 'a+') as f:
|
||||
f.write(json.dumps(output))
|
||||
f.write("\n")
|
||||
|
||||
# Prepare return json dictionary.
|
||||
self.wfile.write(json.dumps({}).encode())
|
||||
|
||||
elif self.path == "/api/auto/lighting":
|
||||
from gradio import validation_data, preprocessing_utils
|
||||
import numpy as np
|
||||
from PIL import ImageEnhance
|
||||
|
||||
self._set_headers()
|
||||
data_string = self.rfile.read(int(self.headers["Content-Length"]))
|
||||
msg = json.loads(data_string)
|
||||
img_orig = preprocessing_utils.decode_base64_to_image(msg["data"])
|
||||
img_orig = img_orig.convert('RGB')
|
||||
img_orig = img_orig.resize((224, 224))
|
||||
enhancer = ImageEnhance.Brightness(img_orig)
|
||||
|
||||
flag_dir = os.path.join(directory_to_serve, FLAGGING_DIRECTORY)
|
||||
os.makedirs(flag_dir, exist_ok=True)
|
||||
|
||||
for i in range(9):
|
||||
img = enhancer.enhance(i/4)
|
||||
img_array = np.array(img) / 127.5 - 1
|
||||
prediction = interface.predict(np.expand_dims(img_array, axis=0))
|
||||
processed_output = interface.output_interface.postprocess(prediction)
|
||||
output = {'input': interface.input_interface.save_to_file(flag_dir, img),
|
||||
'output': interface.output_interface.rebuild_flagged(
|
||||
flag_dir, {'data': {'output': processed_output}}),
|
||||
'message': 'brighting adjustment by a factor '
|
||||
'of {}'.format(i)}
|
||||
|
||||
with open(os.path.join(flag_dir, FLAGGING_FILENAME), 'a+') as f:
|
||||
f.write(json.dumps(output))
|
||||
f.write("\n")
|
||||
|
||||
# Prepare return json dictionary.
|
||||
self.wfile.write(json.dumps({}).encode())
|
||||
|
||||
else:
|
||||
self.send_error(404, 'Path not found: %s' % self.path)
|
||||
self.send_error(404, 'Path not found: {}'.format(self.path))
|
||||
|
||||
class HTTPServer(BaseHTTPServer):
|
||||
"""The main server, you pass in base_path which is the path you want to serve requests from"""
|
||||
|
@ -124,8 +124,9 @@ class Image(AbstractOutput):
|
||||
"""
|
||||
im = preprocessing_utils.decode_base64_to_image(msg)
|
||||
timestamp = datetime.datetime.now()
|
||||
filename = f'output_{timestamp.strftime("%Y-%m-%d-%H-%M-%S")}.png'
|
||||
im.save(f'{dir}/{filename}', 'PNG')
|
||||
filename = 'output_{}.png'.format(timestamp.
|
||||
strftime("%Y-%m-%d-%H-%M-%S"))
|
||||
im.save('{}/{}'.format(dir, filename), 'PNG')
|
||||
return filename
|
||||
|
||||
|
||||
|
@ -12,15 +12,17 @@
|
||||
text-transform: uppercase;
|
||||
font-family: Arial;
|
||||
color: #888;
|
||||
padding: 6px;
|
||||
padding: 6px 6px 0;
|
||||
font-size: 14px;
|
||||
font-weight: bold;
|
||||
display: flex;
|
||||
}
|
||||
.input_interfaces, .output_interfaces {
|
||||
margin-bottom: 16px;
|
||||
}
|
||||
.interface {
|
||||
height: 360px;
|
||||
margin-bottom: 16px;
|
||||
padding: 0 6px 6px;
|
||||
padding: 8px;
|
||||
display: flex;
|
||||
flex-flow: column;
|
||||
}
|
||||
@ -33,6 +35,10 @@
|
||||
}
|
||||
.panel_buttons {
|
||||
display: flex;
|
||||
margin-left: -8px;
|
||||
}
|
||||
.panel_buttons > * {
|
||||
margin-left: 8px;
|
||||
}
|
||||
.submit {
|
||||
display: none;
|
||||
@ -58,9 +64,6 @@
|
||||
padding: 8px !important;
|
||||
background-color: #EEEEEE !important;
|
||||
}
|
||||
.clear, .flag {
|
||||
margin-left: 8px;
|
||||
}
|
||||
|
||||
.upload_zone {
|
||||
font-weight: bold;
|
||||
|
@ -41,6 +41,9 @@ $.getJSON("static/config.json", function(data) {
|
||||
_id++;
|
||||
}
|
||||
for (let i = 0; i < config["output_interfaces"].length; i++) {
|
||||
if (i != 0 && i % (config["output_interfaces"].length / config.function_count) == 0) {
|
||||
$(".output_interfaces").append("<hr>");
|
||||
}
|
||||
output_interface = Object.create(output_to_object_map[
|
||||
config["output_interfaces"][i]]);
|
||||
$(".output_interfaces").append(`
|
||||
|
@ -19,12 +19,13 @@ def handler(chan, host, port):
|
||||
try:
|
||||
sock.connect((host, port))
|
||||
except Exception as e:
|
||||
verbose("Forwarding request to %s:%d failed: %r" % (host, port, e))
|
||||
verbose("Forwarding request to {}:{} failed: {}".format(host, port, e))
|
||||
return
|
||||
|
||||
verbose(
|
||||
"Connected! Tunnel open %r -> %r -> %r"
|
||||
% (chan.origin_addr, chan.getpeername(), (host, port))
|
||||
"Connected! Tunnel open {} -> {} -> {}".format(chan.origin_addr,
|
||||
chan.getpeername(),
|
||||
(host, port))
|
||||
)
|
||||
while True:
|
||||
r, w, x = select.select([sock, chan], [], [])
|
||||
@ -40,7 +41,7 @@ def handler(chan, host, port):
|
||||
sock.send(data)
|
||||
chan.close()
|
||||
sock.close()
|
||||
verbose("Tunnel closed from %r" % (chan.origin_addr,))
|
||||
verbose("Tunnel closed from {}".format(chan.origin_addr,))
|
||||
|
||||
|
||||
def reverse_forward_tunnel(server_port, remote_host, remote_port, transport):
|
||||
@ -65,7 +66,8 @@ def create_tunnel(payload, local_server, local_server_port):
|
||||
client.set_missing_host_key_policy(paramiko.WarningPolicy())
|
||||
|
||||
verbose(
|
||||
"Connecting to ssh host %s:%d ..." % (payload["host"], int(payload["port"]))
|
||||
"Connecting to ssh host {}:{} ...".format(payload["host"], int(payload[
|
||||
"port"]))
|
||||
)
|
||||
try:
|
||||
with warnings.catch_warnings():
|
||||
@ -78,14 +80,16 @@ def create_tunnel(payload, local_server, local_server_port):
|
||||
)
|
||||
except Exception as e:
|
||||
print(
|
||||
"*** Failed to connect to %s:%d: %r"
|
||||
% (payload["host"], int(payload["port"]), e)
|
||||
"*** Failed to connect to {}:{}: {}}".format(payload["host"],
|
||||
int(payload["port"]), e)
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
verbose(
|
||||
"Now forwarding remote port %d to %s:%d ..."
|
||||
% (int(payload["remote_port"]), local_server, local_server_port)
|
||||
"Now forwarding remote port {} to {}:{} ...".format(int(payload[
|
||||
"remote_port"]),
|
||||
local_server,
|
||||
local_server_port)
|
||||
)
|
||||
|
||||
thread = threading.Thread(
|
||||
|
2
static/flagged/1617829583/data.txt
Normal file
2
static/flagged/1617829583/data.txt
Normal file
@ -0,0 +1,2 @@
|
||||
{"inputs": [["It was all a dream. I used to read word up magazine. "]], "outputs": [["It was all a dream. I used to read word up magazine. It was like, \"Oh my God, this is going to be a big deal.\" And then I read it and I thought, Oh my god, I can't believe I'm reading this.\n\nSo I went back and read the book, and it was a really good book. And I think it's one of the best books I've read in a long time."]], "message": "Biggie smalls"}
|
||||
{"inputs": [["I went to Sudan last week and "]], "outputs": [["I went to Sudan last week and met with the president of the Sudanese government,\" he said.\n\n\"He told me that he was going to send a delegation to the United Nations to discuss the situation in Sudan, and I told him that I would be happy to meet with him."]], "message": "Sudan"}
|
Loading…
Reference in New Issue
Block a user