capture session warning

This commit is contained in:
aliabd 2020-06-28 23:37:53 -07:00
parent b000f86261
commit 8cecfc29e4

View File

@ -108,8 +108,15 @@ class Interface:
prediction = predict_fn(*processed_input,
self.context)
else:
prediction = predict_fn(*processed_input,
self.context)
try:
prediction = predict_fn(*processed_input, self.context)
except ValueError:
print("It looks like you might be "
"using tensorflow < 2.0. Please pass "
"capture_session=True in Interface to avoid "
"a 'Tensor is not an element of this graph.' "
"error.")
prediction = predict_fn(*processed_input, self.context)
else:
if self.capture_session:
graph, sess = self.session
@ -117,7 +124,16 @@ class Interface:
with sess.as_default():
prediction = predict_fn(*processed_input)
else:
try:
prediction = predict_fn(*processed_input)
except ValueError:
print("It looks like you might be "
"using tensorflow < 2.0. Please pass "
"capture_session=True in Interface to avoid "
"a 'Tensor is not an element of this graph.' "
"error.")
prediction = predict_fn(*processed_input)
if len(self.output_interfaces) / \
len(self.predict) == 1:
prediction = [prediction]