PR review changes

This commit is contained in:
aliabd 2020-07-29 00:00:14 -07:00
parent 05099e7cf6
commit 96543e3ee7
4 changed files with 29 additions and 30 deletions

View File

@ -63,11 +63,11 @@ class AbstractInput(ABC):
"""
return {}
def rebuild_flagged(self, dir, msg):
def rebuild(self, dir, data):
"""
All interfaces should define a method that rebuilds the flagged input when it's passed back (i.e. rebuilds image from base64)
"""
return msg
return data
class Textbox(AbstractInput):
"""
@ -295,11 +295,11 @@ class Image(AbstractInput):
else:
return example
def rebuild_flagged(self, dir, msg):
def rebuild(self, dir, data):
"""
Default rebuild method to decode a base64 image
"""
im = preprocessing_utils.decode_base64_to_image(msg)
im = preprocessing_utils.decode_base64_to_image(data)
timestamp = datetime.datetime.now()
filename = f'input_{timestamp.strftime("%Y-%m-%d-%H-%M-%S")}.png'
im.save(f'{dir}/{filename}', 'PNG')
@ -356,11 +356,11 @@ class Sketchpad(AbstractInput):
def process_example(self, example):
return preprocessing_utils.convert_file_to_base64(example)
def rebuild_flagged(self, dir, msg):
def rebuild(self, dir, data):
"""
Default rebuild method to decode a base64 image
"""
im = preprocessing_utils.decode_base64_to_image(msg)
im = preprocessing_utils.decode_base64_to_image(data)
timestamp = datetime.datetime.now()
filename = f'input_{timestamp.strftime("%Y-%m-%d-%H-%M-%S")}.png'
im.save(f'{dir}/{filename}', 'PNG')
@ -403,11 +403,11 @@ class Webcam(AbstractInput):
im, (self.image_width, self.image_height))
return np.array(im)
def rebuild_flagged(self, dir, msg):
def rebuild(self, dir, data):
"""
Default rebuild method to decode a base64 image
"""
im = preprocessing_utils.decode_base64_to_image(msg)
im = preprocessing_utils.decode_base64_to_image(data)
timestamp = datetime.datetime.now()
filename = f'input_{timestamp.strftime("%Y-%m-%d-%H-%M-%S")}.png'
im.save(f'{dir}/{filename}', 'PNG')

View File

@ -30,9 +30,6 @@ try:
except requests.ConnectionError:
ip_address = "No internet connection"
FLAGGING_DIRECTORY = 'flagged/'
class Interface:
"""
Interfaces are created with Gradio using the `gradio.Interface()` function.
@ -43,7 +40,8 @@ class Interface:
live=False, show_input=True, show_output=True,
capture_session=False, title=None, description=None,
thumbnail=None, server_port=None, server_name=networking.LOCALHOST_NAME,
allow_screenshot=True, allow_flagging=True):
allow_screenshot=True, allow_flagging=True,
flagging_dir="flagged"):
"""
Parameters:
fn (Callable): the function to wrap an interface around.
@ -104,6 +102,7 @@ class Interface:
self.simple_server = None
self.allow_screenshot = allow_screenshot
self.allow_flagging = allow_flagging
self.flagging_dir = flagging_dir
Interface.instances.add(self)
data = {'fn': fn,
@ -125,15 +124,15 @@ class Interface:
if self.allow_flagging:
if self.title is not None:
dir_name = "_".join(self.title.split(" ")) + "_1"
dir_name = "_".join(self.title.split(" "))
else:
dir_name = "_".join([fn.__name__ for fn in self.predict]) + \
"_1"
i = 1
while os.path.exists(FLAGGING_DIRECTORY + dir_name):
i += 1
dir_name = dir_name[:-2] + "_" + str(i)
self.flagging_dir = FLAGGING_DIRECTORY + dir_name
dir_name = "_".join([fn.__name__ for fn in self.predict])
index = 1
while os.path.exists(self.flagging_dir + "/" + dir_name +
"_{}".format(index)):
index += 1
self.flagging_dir = self.flagging_dir + "/" + dir_name + \
"_{}".format(index)
try:
requests.post(analytics_url + 'gradio-initiated-analytics/',

View File

@ -35,7 +35,6 @@ CONFIG_FILE = "static/config.json"
ASSOCIATION_PATH_IN_STATIC = "static/apple-app-site-association"
ASSOCIATION_PATH_IN_ROOT = "apple-app-site-association"
FLAGGING_FILENAME = 'flagged.txt'
analytics.write_key = "uxIFddIEuuUcFLf9VgH2teTEtPlWdkNy"
analytics_url = 'https://api.gradio.app/'
@ -187,15 +186,16 @@ def serve_files_in_background(interface, port, directory_to_serve=None, server_n
msg = json.loads(data_string)
os.makedirs(interface.flagging_dir, exist_ok=True)
output = {'inputs': [interface.input_interfaces[
i].rebuild_flagged(
i].rebuild(
interface.flagging_dir, msg['data']['input_data']) for i
in range(len(interface.input_interfaces))],
'outputs': [interface.output_interfaces[
i].rebuild_flagged(
i].rebuild(
interface.flagging_dir, msg['data']['output_data']) for i
in range(len(interface.output_interfaces))]}
with open(os.path.join(interface.flagging_dir, FLAGGING_FILENAME), 'a+') as f:
with open("{}/log.txt".format(interface.flagging_dir),
'a+') as f:
f.write(json.dumps(output))
f.write("\n")

View File

@ -44,11 +44,11 @@ class AbstractOutput(ABC):
"""
return {}
def rebuild_flagged(self, dir, msg):
def rebuild(self, dir, data):
"""
All interfaces should define a method that rebuilds the flagged input when it's passed back (i.e. rebuilds image from base64)
"""
return msg
return data
class Textbox(AbstractOutput):
@ -136,11 +136,11 @@ class Label(AbstractOutput):
"label": {},
}
def rebuild_flagged(self, dir, msg):
def rebuild(self, dir, data):
"""
Default rebuild method for label
"""
return json.loads(msg)
return json.loads(data)
class Image(AbstractOutput):
'''
@ -180,11 +180,11 @@ class Image(AbstractOutput):
raise ValueError(
"The `Image` output interface (with plt=False) expects a numpy array.")
def rebuild_flagged(self, dir, msg):
def rebuild(self, dir, data):
"""
Default rebuild method to decode a base64 image
"""
im = preprocessing_utils.decode_base64_to_image(msg)
im = preprocessing_utils.decode_base64_to_image(data)
timestamp = datetime.datetime.now()
filename = 'output_{}.png'.format(timestamp.
strftime("%Y-%m-%d-%H-%M-%S"))