added interpretation capture session for TF1

This commit is contained in:
Abubakar Abid 2020-10-05 06:47:36 -05:00
parent 6942d2e44d
commit 570550d3ca
2 changed files with 32 additions and 27 deletions

View File

@ -3,8 +3,6 @@ This is the core file in the `gradio` package, and defines the Interface class,
interface using the input and output types.
"""
import tempfile
import webbrowser
from gradio.inputs import InputComponent
from gradio.outputs import OutputComponent
from gradio import networking, strings, utils
@ -12,8 +10,8 @@ import gradio.interpretation
import requests
import random
import time
import webbrowser
import inspect
from IPython import get_ipython
import sys
import weakref
import analytics
@ -209,21 +207,14 @@ class Interface:
start = time.time()
if self.capture_session and self.session is not None:
graph, sess = self.session
with graph.as_default():
with sess.as_default():
prediction = predict_fn(*processed_input)
with graph.as_default(), sess.as_default():
prediction = predict_fn(*processed_input)
else:
try:
prediction = predict_fn(*processed_input)
except ValueError as exception:
if str(exception).endswith("is not an element of this "
"graph."):
raise ValueError("It looks like you might be using "
"tensorflow < 2.0. Please "
"pass capture_session=True in "
"Interface to avoid the 'Tensor is "
"not an element of this graph.' "
"error.")
if str(exception).endswith("is not an element of this graph."):
raise ValueError(strings.en["TF1_ERROR"])
else:
raise exception
duration = time.time() - start
@ -238,17 +229,11 @@ class Interface:
else:
return predictions
def process(self, raw_input, predict_fn=None):
def process(self, raw_input):
"""
:param raw_input: a list of raw inputs to process and apply the
prediction(s) on.
:param predict_fn: which function to process. If not provided, all of the model functions are used.
:return:
processed output: a list of processed outputs to return as the
prediction(s).
duration: a list of time deltas measuring inference time for each
prediction fn.
:param raw_input: a list of raw inputs to process and apply the prediction(s) on.
processed output: a list of processed outputs to return as the prediction(s).
duration: a list of time deltas measuring inference time for each prediction fn.
"""
processed_input = [input_interface.preprocess(raw_input[i])
for i, input_interface in enumerate(self.input_interfaces)]
@ -258,6 +243,11 @@ class Interface:
return processed_output, durations
def interpret(self, raw_input):
"""
Runs the interpretation command for the machine learning model. Handles both the "default" out-of-the-box
interpretation for a certain set of UI component types, as well as the custom interpretation case.
:param raw_input: a list of raw inputs to apply the interpretation(s) on.
"""
if self.interpretation == "default":
interpreter = gradio.interpretation.default()
processed_input = []
@ -270,9 +260,22 @@ class Interface:
interpretation = interpreter(self, processed_input)
else:
processed_input = [input_interface.preprocess(raw_input[i])
for i, input_interface in enumerate(self.input_interfaces)]
for i, input_interface in enumerate(self.input_interfaces)]
interpreter = self.interpretation
interpretation = interpreter(*processed_input)
if self.capture_session and self.session is not None:
graph, sess = self.session
with graph.as_default(), sess.as_default():
interpretation = interpreter(*processed_input).tolist()
else:
try:
interpretation = interpreter(*processed_input).tolist()
except ValueError as exception:
if str(exception).endswith("is not an element of this graph."):
raise ValueError(strings.en["TF1_ERROR"])
else:
raise exception
if len(raw_input) == 1:
interpretation = [interpretation]
return interpretation
@ -422,9 +425,9 @@ class Interface:
if not is_in_interactive_mode:
self.run_until_interrupted(thread, path_to_local_server)
return app, path_to_local_server, share_url
def reset_all():
for io in Interface.get_instances():
io.close()

View File

@ -7,4 +7,6 @@ en = {
"PUBLIC_SHARE_TRUE": "To create a public link, set `share=True` in the argument to `launch()`.",
"MODEL_PUBLICLY_AVAILABLE_URL": "Model available publicly at: {} (may take up to a minute for link to be usable)",
"GENERATING_PUBLIC_LINK": "Generating public link (may take a few seconds...):",
"TF1_ERROR": "It looks like you might be using tensorflow < 2.0. Please pass capture_session=True in Interface() to"
" avoid the 'Tensor is not an element of this graph.' error."
}