mirror of
https://github.com/gradio-app/gradio.git
synced 2025-01-12 10:34:32 +08:00
brought back sessions for TF 1.x
This commit is contained in:
parent
286ea40794
commit
85b0d2df34
@ -206,8 +206,16 @@ class Interface:
|
||||
|
||||
if capture_session is not None:
|
||||
warnings.warn("The `capture_session` parameter in the `Interface`"
|
||||
" is deprecated and has no effect.")
|
||||
|
||||
" is deprecated and may be removed in the future.")
|
||||
try:
|
||||
import tensorflow as tf
|
||||
self.session = tf.get_default_graph(), \
|
||||
tf.keras.backend.get_session()
|
||||
except (ImportError, AttributeError):
|
||||
# If they are using TF >= 2.0 or don't have TF,
|
||||
# just ignore this parameter.
|
||||
pass
|
||||
|
||||
if server_name is not None or server_port is not None:
|
||||
raise DeprecationWarning(
|
||||
"The `server_name` and `server_port` parameters in `Interface`"
|
||||
@ -400,7 +408,12 @@ class Interface:
|
||||
|
||||
for predict_fn in self.predict:
|
||||
start = time.time()
|
||||
prediction = predict_fn(*processed_input)
|
||||
if self.capture_session and self.session is not None: # For TF 1.x
|
||||
graph, sess = self.session
|
||||
with graph.as_default(), sess.as_default():
|
||||
prediction = predict_fn(*processed_input)
|
||||
else:
|
||||
prediction = predict_fn(*processed_input)
|
||||
duration = time.time() - start
|
||||
|
||||
if len(self.output_components) == len(self.predict):
|
||||
|
@ -115,7 +115,12 @@ def run_interpret(interface, raw_input):
|
||||
processed_input = [input_component.preprocess(raw_input[i])
|
||||
for i, input_component in enumerate(interface.input_components)]
|
||||
interpreter = interface.interpretation
|
||||
interpretation = interpreter(*processed_input)
|
||||
if interface.capture_session and interface.session is not None:
|
||||
graph, sess = interface.session
|
||||
with graph.as_default(), sess.as_default():
|
||||
interpretation = interpreter(*processed_input)
|
||||
else:
|
||||
interpretation = interpreter(*processed_input)
|
||||
if len(raw_input) == 1:
|
||||
interpretation = [interpretation]
|
||||
return interpretation, []
|
||||
|
Loading…
Reference in New Issue
Block a user