fixed capture session code

This commit is contained in:
Abubakar Abid 2020-07-09 15:15:34 -05:00
parent 58493e4a2b
commit 6c1648b287

View File

@ -107,8 +107,16 @@ class Interface:
'host_name': hostname, 'host_name': hostname,
'ip_address': ip_address 'ip_address': ip_address
} }
if self.capture_session:
try:
import tensorflow as tf
self.session = tf.get_default_graph(), \
tf.keras.backend.get_session()
except (ImportError, AttributeError): # If they are using TF >= 2.0 or don't have TF, just ignore this.
pass
try: try:
print("try initiated")
requests.post(analytics_url + 'gradio-initiated-analytics/', requests.post(analytics_url + 'gradio-initiated-analytics/',
data=data) data=data)
except requests.ConnectionError: except requests.ConnectionError:
@ -144,7 +152,6 @@ class Interface:
return config return config
def process(self, raw_input): def process(self, raw_input):
processed_input = [input_interface.preprocess( processed_input = [input_interface.preprocess(
raw_input[i]) for i, input_interface in raw_input[i]) for i, input_interface in
@ -262,14 +269,6 @@ class Interface:
# if validate and not self.validate_flag: # if validate and not self.validate_flag:
# self.validate() # self.validate()
if self.capture_session:
try:
import tensorflow as tf
self.session = tf.get_default_graph(), \
tf.keras.backend.get_session()
except (ImportError, AttributeError): # If they are using TF >= 2.0 or don't have TF, just ignore this.
pass
output_directory = tempfile.mkdtemp() output_directory = tempfile.mkdtemp()
# Set up a port to serve the directory containing the static files with interface. # Set up a port to serve the directory containing the static files with interface.
server_port, httpd = networking.start_simple_server(self, output_directory, self.server_name) server_port, httpd = networking.start_simple_server(self, output_directory, self.server_name)