brought back sessions for TF 1.x

This commit is contained in:
Abubakar Abid 2022-01-12 13:49:46 -06:00
parent 286ea40794
commit 85b0d2df34
2 changed files with 22 additions and 4 deletions

View File

@ -206,8 +206,16 @@ class Interface:
if capture_session is not None:
warnings.warn("The `capture_session` parameter in the `Interface`"
" is deprecated and has no effect.")
" is deprecated and may be removed in the future.")
try:
import tensorflow as tf
self.session = tf.get_default_graph(), \
tf.keras.backend.get_session()
except (ImportError, AttributeError):
# If they are using TF >= 2.0 or don't have TF,
# just ignore this parameter.
pass
if server_name is not None or server_port is not None:
raise DeprecationWarning(
"The `server_name` and `server_port` parameters in `Interface`"
@ -400,7 +408,12 @@ class Interface:
for predict_fn in self.predict:
start = time.time()
prediction = predict_fn(*processed_input)
if self.capture_session and self.session is not None: # For TF 1.x
graph, sess = self.session
with graph.as_default(), sess.as_default():
prediction = predict_fn(*processed_input)
else:
prediction = predict_fn(*processed_input)
duration = time.time() - start
if len(self.output_components) == len(self.predict):

View File

@ -115,7 +115,12 @@ def run_interpret(interface, raw_input):
processed_input = [input_component.preprocess(raw_input[i])
for i, input_component in enumerate(interface.input_components)]
interpreter = interface.interpretation
interpretation = interpreter(*processed_input)
if interface.capture_session and interface.session is not None:
graph, sess = interface.session
with graph.as_default(), sess.as_default():
interpretation = interpreter(*processed_input)
else:
interpretation = interpreter(*processed_input)
if len(raw_input) == 1:
interpretation = [interpretation]
return interpretation, []