mirror of
https://github.com/gradio-app/gradio.git
synced 2025-01-12 10:34:32 +08:00
brought back sessions for TF 1.x
This commit is contained in:
parent
286ea40794
commit
85b0d2df34
@ -206,8 +206,16 @@ class Interface:
|
|||||||
|
|
||||||
if capture_session is not None:
|
if capture_session is not None:
|
||||||
warnings.warn("The `capture_session` parameter in the `Interface`"
|
warnings.warn("The `capture_session` parameter in the `Interface`"
|
||||||
" is deprecated and has no effect.")
|
" is deprecated and may be removed in the future.")
|
||||||
|
try:
|
||||||
|
import tensorflow as tf
|
||||||
|
self.session = tf.get_default_graph(), \
|
||||||
|
tf.keras.backend.get_session()
|
||||||
|
except (ImportError, AttributeError):
|
||||||
|
# If they are using TF >= 2.0 or don't have TF,
|
||||||
|
# just ignore this parameter.
|
||||||
|
pass
|
||||||
|
|
||||||
if server_name is not None or server_port is not None:
|
if server_name is not None or server_port is not None:
|
||||||
raise DeprecationWarning(
|
raise DeprecationWarning(
|
||||||
"The `server_name` and `server_port` parameters in `Interface`"
|
"The `server_name` and `server_port` parameters in `Interface`"
|
||||||
@ -400,7 +408,12 @@ class Interface:
|
|||||||
|
|
||||||
for predict_fn in self.predict:
|
for predict_fn in self.predict:
|
||||||
start = time.time()
|
start = time.time()
|
||||||
prediction = predict_fn(*processed_input)
|
if self.capture_session and self.session is not None: # For TF 1.x
|
||||||
|
graph, sess = self.session
|
||||||
|
with graph.as_default(), sess.as_default():
|
||||||
|
prediction = predict_fn(*processed_input)
|
||||||
|
else:
|
||||||
|
prediction = predict_fn(*processed_input)
|
||||||
duration = time.time() - start
|
duration = time.time() - start
|
||||||
|
|
||||||
if len(self.output_components) == len(self.predict):
|
if len(self.output_components) == len(self.predict):
|
||||||
|
@ -115,7 +115,12 @@ def run_interpret(interface, raw_input):
|
|||||||
processed_input = [input_component.preprocess(raw_input[i])
|
processed_input = [input_component.preprocess(raw_input[i])
|
||||||
for i, input_component in enumerate(interface.input_components)]
|
for i, input_component in enumerate(interface.input_components)]
|
||||||
interpreter = interface.interpretation
|
interpreter = interface.interpretation
|
||||||
interpretation = interpreter(*processed_input)
|
if interface.capture_session and interface.session is not None:
|
||||||
|
graph, sess = interface.session
|
||||||
|
with graph.as_default(), sess.as_default():
|
||||||
|
interpretation = interpreter(*processed_input)
|
||||||
|
else:
|
||||||
|
interpretation = interpreter(*processed_input)
|
||||||
if len(raw_input) == 1:
|
if len(raw_input) == 1:
|
||||||
interpretation = [interpretation]
|
interpretation = [interpretation]
|
||||||
return interpretation, []
|
return interpretation, []
|
||||||
|
Loading…
Reference in New Issue
Block a user