mirror of
https://github.com/gradio-app/gradio.git
synced 2025-01-06 10:25:17 +08:00
PR review changes
This commit is contained in:
parent
05099e7cf6
commit
96543e3ee7
@ -63,11 +63,11 @@ class AbstractInput(ABC):
|
||||
"""
|
||||
return {}
|
||||
|
||||
def rebuild_flagged(self, dir, msg):
|
||||
def rebuild(self, dir, data):
|
||||
"""
|
||||
All interfaces should define a method that rebuilds the flagged input when it's passed back (i.e. rebuilds image from base64)
|
||||
"""
|
||||
return msg
|
||||
return data
|
||||
|
||||
class Textbox(AbstractInput):
|
||||
"""
|
||||
@ -295,11 +295,11 @@ class Image(AbstractInput):
|
||||
else:
|
||||
return example
|
||||
|
||||
def rebuild_flagged(self, dir, msg):
|
||||
def rebuild(self, dir, data):
|
||||
"""
|
||||
Default rebuild method to decode a base64 image
|
||||
"""
|
||||
im = preprocessing_utils.decode_base64_to_image(msg)
|
||||
im = preprocessing_utils.decode_base64_to_image(data)
|
||||
timestamp = datetime.datetime.now()
|
||||
filename = f'input_{timestamp.strftime("%Y-%m-%d-%H-%M-%S")}.png'
|
||||
im.save(f'{dir}/{filename}', 'PNG')
|
||||
@ -356,11 +356,11 @@ class Sketchpad(AbstractInput):
|
||||
def process_example(self, example):
|
||||
return preprocessing_utils.convert_file_to_base64(example)
|
||||
|
||||
def rebuild_flagged(self, dir, msg):
|
||||
def rebuild(self, dir, data):
|
||||
"""
|
||||
Default rebuild method to decode a base64 image
|
||||
"""
|
||||
im = preprocessing_utils.decode_base64_to_image(msg)
|
||||
im = preprocessing_utils.decode_base64_to_image(data)
|
||||
timestamp = datetime.datetime.now()
|
||||
filename = f'input_{timestamp.strftime("%Y-%m-%d-%H-%M-%S")}.png'
|
||||
im.save(f'{dir}/{filename}', 'PNG')
|
||||
@ -403,11 +403,11 @@ class Webcam(AbstractInput):
|
||||
im, (self.image_width, self.image_height))
|
||||
return np.array(im)
|
||||
|
||||
def rebuild_flagged(self, dir, msg):
|
||||
def rebuild(self, dir, data):
|
||||
"""
|
||||
Default rebuild method to decode a base64 image
|
||||
"""
|
||||
im = preprocessing_utils.decode_base64_to_image(msg)
|
||||
im = preprocessing_utils.decode_base64_to_image(data)
|
||||
timestamp = datetime.datetime.now()
|
||||
filename = f'input_{timestamp.strftime("%Y-%m-%d-%H-%M-%S")}.png'
|
||||
im.save(f'{dir}/{filename}', 'PNG')
|
||||
|
@ -30,9 +30,6 @@ try:
|
||||
except requests.ConnectionError:
|
||||
ip_address = "No internet connection"
|
||||
|
||||
FLAGGING_DIRECTORY = 'flagged/'
|
||||
|
||||
|
||||
class Interface:
|
||||
"""
|
||||
Interfaces are created with Gradio using the `gradio.Interface()` function.
|
||||
@ -43,7 +40,8 @@ class Interface:
|
||||
live=False, show_input=True, show_output=True,
|
||||
capture_session=False, title=None, description=None,
|
||||
thumbnail=None, server_port=None, server_name=networking.LOCALHOST_NAME,
|
||||
allow_screenshot=True, allow_flagging=True):
|
||||
allow_screenshot=True, allow_flagging=True,
|
||||
flagging_dir="flagged"):
|
||||
"""
|
||||
Parameters:
|
||||
fn (Callable): the function to wrap an interface around.
|
||||
@ -104,6 +102,7 @@ class Interface:
|
||||
self.simple_server = None
|
||||
self.allow_screenshot = allow_screenshot
|
||||
self.allow_flagging = allow_flagging
|
||||
self.flagging_dir = flagging_dir
|
||||
Interface.instances.add(self)
|
||||
|
||||
data = {'fn': fn,
|
||||
@ -125,15 +124,15 @@ class Interface:
|
||||
|
||||
if self.allow_flagging:
|
||||
if self.title is not None:
|
||||
dir_name = "_".join(self.title.split(" ")) + "_1"
|
||||
dir_name = "_".join(self.title.split(" "))
|
||||
else:
|
||||
dir_name = "_".join([fn.__name__ for fn in self.predict]) + \
|
||||
"_1"
|
||||
i = 1
|
||||
while os.path.exists(FLAGGING_DIRECTORY + dir_name):
|
||||
i += 1
|
||||
dir_name = dir_name[:-2] + "_" + str(i)
|
||||
self.flagging_dir = FLAGGING_DIRECTORY + dir_name
|
||||
dir_name = "_".join([fn.__name__ for fn in self.predict])
|
||||
index = 1
|
||||
while os.path.exists(self.flagging_dir + "/" + dir_name +
|
||||
"_{}".format(index)):
|
||||
index += 1
|
||||
self.flagging_dir = self.flagging_dir + "/" + dir_name + \
|
||||
"_{}".format(index)
|
||||
|
||||
try:
|
||||
requests.post(analytics_url + 'gradio-initiated-analytics/',
|
||||
|
@ -35,7 +35,6 @@ CONFIG_FILE = "static/config.json"
|
||||
ASSOCIATION_PATH_IN_STATIC = "static/apple-app-site-association"
|
||||
ASSOCIATION_PATH_IN_ROOT = "apple-app-site-association"
|
||||
|
||||
FLAGGING_FILENAME = 'flagged.txt'
|
||||
analytics.write_key = "uxIFddIEuuUcFLf9VgH2teTEtPlWdkNy"
|
||||
analytics_url = 'https://api.gradio.app/'
|
||||
|
||||
@ -187,15 +186,16 @@ def serve_files_in_background(interface, port, directory_to_serve=None, server_n
|
||||
msg = json.loads(data_string)
|
||||
os.makedirs(interface.flagging_dir, exist_ok=True)
|
||||
output = {'inputs': [interface.input_interfaces[
|
||||
i].rebuild_flagged(
|
||||
i].rebuild(
|
||||
interface.flagging_dir, msg['data']['input_data']) for i
|
||||
in range(len(interface.input_interfaces))],
|
||||
'outputs': [interface.output_interfaces[
|
||||
i].rebuild_flagged(
|
||||
i].rebuild(
|
||||
interface.flagging_dir, msg['data']['output_data']) for i
|
||||
in range(len(interface.output_interfaces))]}
|
||||
|
||||
with open(os.path.join(interface.flagging_dir, FLAGGING_FILENAME), 'a+') as f:
|
||||
with open("{}/log.txt".format(interface.flagging_dir),
|
||||
'a+') as f:
|
||||
f.write(json.dumps(output))
|
||||
f.write("\n")
|
||||
|
||||
|
@ -44,11 +44,11 @@ class AbstractOutput(ABC):
|
||||
"""
|
||||
return {}
|
||||
|
||||
def rebuild_flagged(self, dir, msg):
|
||||
def rebuild(self, dir, data):
|
||||
"""
|
||||
All interfaces should define a method that rebuilds the flagged input when it's passed back (i.e. rebuilds image from base64)
|
||||
"""
|
||||
return msg
|
||||
return data
|
||||
|
||||
|
||||
class Textbox(AbstractOutput):
|
||||
@ -136,11 +136,11 @@ class Label(AbstractOutput):
|
||||
"label": {},
|
||||
}
|
||||
|
||||
def rebuild_flagged(self, dir, msg):
|
||||
def rebuild(self, dir, data):
|
||||
"""
|
||||
Default rebuild method for label
|
||||
"""
|
||||
return json.loads(msg)
|
||||
return json.loads(data)
|
||||
|
||||
class Image(AbstractOutput):
|
||||
'''
|
||||
@ -180,11 +180,11 @@ class Image(AbstractOutput):
|
||||
raise ValueError(
|
||||
"The `Image` output interface (with plt=False) expects a numpy array.")
|
||||
|
||||
def rebuild_flagged(self, dir, msg):
|
||||
def rebuild(self, dir, data):
|
||||
"""
|
||||
Default rebuild method to decode a base64 image
|
||||
"""
|
||||
im = preprocessing_utils.decode_base64_to_image(msg)
|
||||
im = preprocessing_utils.decode_base64_to_image(data)
|
||||
timestamp = datetime.datetime.now()
|
||||
filename = 'output_{}.png'.format(timestamp.
|
||||
strftime("%Y-%m-%d-%H-%M-%S"))
|
||||
|
Loading…
Reference in New Issue
Block a user