fix multiple io + fn

Ali Abid 2020-06-11 10:12:18 -07:00
parent e0bcf324a8
commit da7056d137
7 changed files with 118 additions and 42 deletions

@@ -11,7 +11,7 @@ from PIL import Image, ImageOps
import time
import warnings
import json
import datetime
# Where to find the static resources associated with each template.
# BASE_INPUT_INTERFACE_TEMPLATE_PATH = 'static/js/interfaces/input/{}.js'
@@ -113,8 +113,9 @@ class Sketchpad(AbstractInput):
"""
im = preprocessing_utils.decode_base64_to_image(msg)
timestamp = time.time()*1000
filename = f'input_{timestamp}.png'
timestamp = datetime.datetime.now()
filename = f'input_{timestamp.strftime("%Y-%m-%d-%H-%M-%S")}.png'
im.save(f'{dir}/{filename}', 'PNG')
return filename
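For reference, a minimal stand-alone sketch of the new naming scheme shared by Sketchpad, Webcam and ImageIn (assumes Pillow; save_input_image is an illustrative helper, not a gradio API):

import datetime
from PIL import Image

def save_input_image(im: Image.Image, dir: str) -> str:
    # Human-readable timestamp replaces the old time.time()*1000 float suffix.
    timestamp = datetime.datetime.now()
    filename = f'input_{timestamp.strftime("%Y-%m-%d-%H-%M-%S")}.png'
    im.save(f'{dir}/{filename}', 'PNG')
    return filename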
@@ -159,8 +160,8 @@ class Webcam(AbstractInput):
"""
inp = msg['data']['input']
im = preprocessing_utils.decode_base64_to_image(inp)
timestamp = time.time()*1000
filename = f'input_{timestamp}.png'
timestamp = datetime.datetime.now()
filename = f'input_{timestamp.strftime("%Y-%m-%d-%H-%M-%S")}.png'
im.save(f'{dir}/{filename}', 'PNG')
return filename
@@ -186,11 +187,7 @@ class Textbox(AbstractInput):
"""
Default rebuild method for text saves it .txt file
"""
timestamp = time.time()*1000
filename = f'input_{timestamp}'
with open(f'{dir}/{filename}.txt','w') as f:
f.write(msg)
return filename
return json.loads(msg)
def get_sample_inputs(self):
return self.sample_inputs
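In other words, a flagged Textbox payload is now parsed back into a Python value instead of being written out as a .txt file; a small illustration (the payload value is made up):

import json

msg = '"some flagged text"'   # JSON-encoded textbox payload from the frontend (illustrative)
rebuilt = json.loads(msg)     # -> 'some flagged text'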
@@ -240,8 +237,8 @@ class ImageIn(AbstractInput):
Default rebuild method to decode a base64 image
"""
im = preprocessing_utils.decode_base64_to_image(msg)
timestamp = time.time()*1000
filename = f'input_{timestamp}.png'
timestamp = datetime.datetime.now()
filename = f'input_{timestamp.strftime("%Y-%m-%d-%H-%M-%S")}.png'
im.save(f'{dir}/{filename}', 'PNG')
return filename
@@ -250,8 +247,8 @@ class ImageIn(AbstractInput):
"""
"""
timestamp = time.time()*1000
filename = f'input_{timestamp}.png'
img.save(f'{dir}/{filename}', 'PNG')
filename = 'input_{}.png'.format(timestamp)
img.save('{}/{}'.format(dir, filename), 'PNG')
return filename

@@ -28,7 +28,8 @@ class Interface:
the appropriate inputs and outputs
"""
def __init__(self, fn, inputs, outputs, verbose=False, live=False, show_input=True, show_output=True):
def __init__(self, fn, inputs, outputs, saliency=None, verbose=False,
live=False, show_input=True, show_output=True):
"""
:param fn: a function that will process the input panel data from the interface and return the output panel data.
:param inputs: a string or `AbstractInput` representing the input interface.
@@ -64,11 +65,11 @@ class Interface:
self.predict = fn
self.verbose = verbose
self.status = "OFF"
self.saliency = None
self.saliency = saliency
self.live = live
self.show_input = show_input
self.show_output = show_output
self.flag_hash = random.getrandbits(32)
def update_config_file(self, output_directory):
config = {
@@ -98,7 +99,7 @@ class Interface:
for m, msg in enumerate(validation_inputs):
if self.verbose:
print(
f"Validating samples: {m+1}/{n} ["
"Validating samples: {}/{} [".format(m+1, n)
+ "=" * (m + 1)
+ "." * (n - m - 1)
+ "]",
@@ -177,9 +178,10 @@ class Interface:
current_pkg_version = pkg_resources.require("gradio")[0].version
latest_pkg_version = requests.get(url=PKG_VERSION_URL).json()["version"]
if StrictVersion(latest_pkg_version) > StrictVersion(current_pkg_version):
print(f"IMPORTANT: You are using gradio version {current_pkg_version}, "
f"however version {latest_pkg_version} "
f"is available, please upgrade.")
print("IMPORTANT: You are using gradio version {}, "
"however version {} "
"is available, please upgrade.".format(
current_pkg_version, latest_pkg_version))
print('--------')
except: # TODO(abidlabs): don't catch all exceptions
pass
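For context, a hedged usage sketch of the new constructor signature: the model function, the saliency callable, and the 'sketchpad'/'label' shortcut strings are placeholders/assumptions, and per the networking change further down the saliency callable receives the raw input and the prediction:

import numpy as np
import gradio

def classify(img_array):
    # placeholder model fn; returns per-class confidences
    return {"three": 0.8, "eight": 0.2}

def saliency_fn(raw_input, prediction):
    # placeholder saliency map; must be array-like so the server can call .tolist()
    return np.zeros((28, 28))

io = gradio.Interface(fn=classify, inputs="sketchpad", outputs="label",
                      saliency=saliency_fn)
io.launch()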

@@ -143,12 +143,10 @@ def serve_files_in_background(interface, port, directory_to_serve=None):
if len(interface.output_interfaces) == 1:
prediction = [prediction]
predictions.extend(prediction)
print(predictions)
processed_output = [output_interface.postprocess(predictions[i]) for i, output_interface in enumerate(interface.output_interfaces)]
output = {"action": "output", "data": processed_output}
if interface.saliency is not None:
import numpy as np
saliency = interface.saliency(interface, interface.model_obj, raw_input, processed_input, prediction, processed_output)
saliency = interface.saliency(raw_input, prediction)
output['saliency'] = saliency.tolist()
# if interface.always_flag:
# msg = json.loads(data_string)
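A sketch of the response payload assembled above, with placeholder data, showing why the saliency array goes through .tolist():

import json
import numpy as np

processed_output = [{"cat": 0.7, "dog": 0.3}]          # illustrative postprocessed outputs
output = {"action": "output", "data": processed_output}
saliency = np.random.rand(224, 224)                    # stand-in for interface.saliency(raw_input, prediction)
output["saliency"] = saliency.tolist()                 # nested lists are JSON-serialisable; ndarrays are not
payload = json.dumps(output)                           # sent back to the browser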
@@ -168,17 +166,92 @@ def serve_files_in_background(interface, port, directory_to_serve=None):
self._set_headers()
data_string = self.rfile.read(int(self.headers["Content-Length"]))
msg = json.loads(data_string)
flag_dir = os.path.join(FLAGGING_DIRECTORY, str(interface.hash))
os.makedirs(FLAGGING_DIRECTORY, exist_ok=True)
output = {'input': interface.input_interface.rebuild_flagged(flag_dir, msg['data']['input_data']),
'output': interface.output_interface.rebuild_flagged(flag_dir, msg['data']['output_data']),
flag_dir = os.path.join(FLAGGING_DIRECTORY,
str(interface.flag_hash))
os.makedirs(flag_dir, exist_ok=True)
output = {'inputs': [interface.input_interfaces[
i].rebuild_flagged(
flag_dir, msg['data']['input_data']) for i
in range(len(interface.input_interfaces))],
'outputs': [interface.output_interfaces[
i].rebuild_flagged(
flag_dir, msg['data']['output_data']) for i
in range(len(interface.output_interfaces))],
'message': msg['data']['message']}
with open(os.path.join(flag_dir, FLAGGING_FILENAME), 'a+') as f:
f.write(json.dumps(output))
f.write("\n")
#TODO(abidlabs): clean this up
elif self.path == "/api/auto/rotation":
from gradio import validation_data, preprocessing_utils
import numpy as np
self._set_headers()
data_string = self.rfile.read(int(self.headers["Content-Length"]))
msg = json.loads(data_string)
img_orig = preprocessing_utils.decode_base64_to_image(msg["data"])
img_orig = img_orig.convert('RGB')
img_orig = img_orig.resize((224, 224))
flag_dir = os.path.join(directory_to_serve, FLAGGING_DIRECTORY)
os.makedirs(flag_dir, exist_ok=True)
for deg in range(-180, 180+45, 45):
img = img_orig.rotate(deg)
img_array = np.array(img) / 127.5 - 1
prediction = interface.predict(np.expand_dims(img_array, axis=0))
processed_output = interface.output_interface.postprocess(prediction)
output = {'input': interface.input_interface.save_to_file(flag_dir, img),
'output': interface.output_interface.rebuild_flagged(
flag_dir, {'data': {'output': processed_output}}),
'message': 'rotation by {} degrees'.format(
deg)}
with open(os.path.join(flag_dir, FLAGGING_FILENAME), 'a+') as f:
f.write(json.dumps(output))
f.write("\n")
# Prepare return json dictionary.
self.wfile.write(json.dumps({}).encode())
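# Stand-alone sketch (illustrative, assumes Pillow and numpy): the rotation
# sweep above, factored out of the handler. `predict` stands in for
# interface.predict and expects a (1, 224, 224, 3) batch scaled to [-1, 1].
import numpy as np
from PIL import Image

def rotation_sweep(img_orig: Image.Image, predict):
    img = img_orig.convert('RGB').resize((224, 224))
    results = []
    for deg in range(-180, 180 + 45, 45):          # -180, -135, ..., +180 degrees
        rotated = img.rotate(deg)
        img_array = np.array(rotated) / 127.5 - 1  # scale pixel values to [-1, 1]
        results.append((deg, predict(np.expand_dims(img_array, axis=0))))
    return results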
elif self.path == "/api/auto/lighting":
from gradio import validation_data, preprocessing_utils
import numpy as np
from PIL import ImageEnhance
self._set_headers()
data_string = self.rfile.read(int(self.headers["Content-Length"]))
msg = json.loads(data_string)
img_orig = preprocessing_utils.decode_base64_to_image(msg["data"])
img_orig = img_orig.convert('RGB')
img_orig = img_orig.resize((224, 224))
enhancer = ImageEnhance.Brightness(img_orig)
flag_dir = os.path.join(directory_to_serve, FLAGGING_DIRECTORY)
os.makedirs(flag_dir, exist_ok=True)
for i in range(9):
img = enhancer.enhance(i/4)
img_array = np.array(img) / 127.5 - 1
prediction = interface.predict(np.expand_dims(img_array, axis=0))
processed_output = interface.output_interface.postprocess(prediction)
output = {'input': interface.input_interface.save_to_file(flag_dir, img),
'output': interface.output_interface.rebuild_flagged(
flag_dir, {'data': {'output': processed_output}}),
'message': 'brightness adjustment by a factor '
'of {}'.format(i / 4)}
with open(os.path.join(flag_dir, FLAGGING_FILENAME), 'a+') as f:
f.write(json.dumps(output))
f.write("\n")
# Prepare return json dictionary.
self.wfile.write(json.dumps({}).encode())
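# Companion sketch (illustrative): the brightness sweep above. Factors i/4 run
# through 0.0, 0.25, ..., 2.0; ImageEnhance.Brightness(img).enhance(1.0) returns
# the original image and enhance(0.0) a black one.
from PIL import Image, ImageEnhance

def brightness_sweep(img_orig: Image.Image):
    img = img_orig.convert('RGB').resize((224, 224))
    enhancer = ImageEnhance.Brightness(img)
    return [enhancer.enhance(i / 4) for i in range(9)]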
else:
self.send_error(404, 'Path not found: %s' % self.path)
self.send_error(404, 'Path not found: {}'.format(self.path))
class HTTPServer(BaseHTTPServer):
"""The main server, you pass in base_path which is the path you want to serve requests from"""

@@ -7,7 +7,7 @@ automatically added to a registry, which allows them to be easily referenced in
from abc import ABC, abstractmethod
import numpy as np
import json
from gradio import imagenet_class_labels, preprocessing_utils
from gradio import preprocessing_utils
import datetime
# Where to find the static resources associated with each template.
@@ -124,8 +124,9 @@ class Image(AbstractOutput):
"""
im = preprocessing_utils.decode_base64_to_image(msg)
timestamp = datetime.datetime.now()
filename = f'output_{timestamp.strftime("%Y-%m-%d-%H-%M-%S")}.png'
im.save(f'{dir}/{filename}', 'PNG')
filename = 'output_{}.png'.format(timestamp.
strftime("%Y-%m-%d-%H-%M-%S"))
im.save('{}/{}'.format(dir, filename), 'PNG')
return filename
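For context, the helper used throughout, preprocessing_utils.decode_base64_to_image, turns a base64/data-URL payload into a PIL image; a minimal stand-alone equivalent, assuming a standard payload, might look like this:

import base64
import io
from PIL import Image

def decode_base64_to_image(msg: str) -> Image.Image:
    # Strip an optional "data:image/png;base64," prefix, then decode.
    header, _, data = msg.partition(",")
    return Image.open(io.BytesIO(base64.b64decode(data or header)))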

@@ -19,12 +19,13 @@ def handler(chan, host, port):
try:
sock.connect((host, port))
except Exception as e:
verbose("Forwarding request to %s:%d failed: %r" % (host, port, e))
verbose("Forwarding request to {}:{} failed: {}".format(host, port, e))
return
verbose(
"Connected! Tunnel open %r -> %r -> %r"
% (chan.origin_addr, chan.getpeername(), (host, port))
"Connected! Tunnel open {} -> {} -> {}".format(chan.origin_addr,
chan.getpeername(),
(host, port))
)
while True:
r, w, x = select.select([sock, chan], [], [])
@@ -40,7 +41,7 @@ def handler(chan, host, port):
sock.send(data)
chan.close()
sock.close()
verbose("Tunnel closed from %r" % (chan.origin_addr,))
verbose("Tunnel closed from {}".format(chan.origin_addr,))
def reverse_forward_tunnel(server_port, remote_host, remote_port, transport):
@@ -65,7 +66,8 @@ def create_tunnel(payload, local_server, local_server_port):
client.set_missing_host_key_policy(paramiko.WarningPolicy())
verbose(
"Connecting to ssh host %s:%d ..." % (payload["host"], int(payload["port"]))
"Connecting to ssh host {}:{} ...".format(payload["host"], int(payload[
"port"]))
)
try:
with warnings.catch_warnings():
@@ -78,14 +80,16 @@ def create_tunnel(payload, local_server, local_server_port):
)
except Exception as e:
print(
"*** Failed to connect to %s:%d: %r"
% (payload["host"], int(payload["port"]), e)
"*** Failed to connect to {}:{}: {}}".format(payload["host"],
int(payload["port"]), e)
)
sys.exit(1)
verbose(
"Now forwarding remote port %d to %s:%d ..."
% (int(payload["remote_port"]), local_server, local_server_port)
"Now forwarding remote port {} to {}:{} ...".format(int(payload[
"remote_port"]),
local_server,
local_server_port)
)
thread = threading.Thread(

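One note on the formatting swap in the tunnelling changes above: %s/%d and {} produce identical text, but %r corresponds to {!r}, not plain {} (which calls str() rather than repr()), so the tunnel-open log lines change slightly. A quick check of the equivalent cases:

host, port = "localhost", 7860
assert "Connecting to ssh host %s:%d ..." % (host, port) == \
       "Connecting to ssh host {}:{} ...".format(host, port)

addr = ("127.0.0.1", 40222)
assert "%r" % (addr,) == "{!r}".format(addr)   # repr-style needs {!r}, not {}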
Binary file not shown.

@@ -143,7 +143,6 @@ def serve_files_in_background(interface, port, directory_to_serve=None):
if len(interface.output_interfaces) == 1:
prediction = [prediction]
predictions.extend(prediction)
print(predictions)
processed_output = [output_interface.postprocess(predictions[i]) for i, output_interface in enumerate(interface.output_interfaces)]
output = {"action": "output", "data": processed_output}
if interface.saliency is not None: