mirror of
https://github.com/gradio-app/gradio.git
synced 2025-01-12 10:34:32 +08:00
added interpretation capture session for TF1
This commit is contained in:
parent
6942d2e44d
commit
570550d3ca
@ -3,8 +3,6 @@ This is the core file in the `gradio` package, and defines the Interface class,
|
||||
interface using the input and output types.
|
||||
"""
|
||||
|
||||
import tempfile
|
||||
import webbrowser
|
||||
from gradio.inputs import InputComponent
|
||||
from gradio.outputs import OutputComponent
|
||||
from gradio import networking, strings, utils
|
||||
@ -12,8 +10,8 @@ import gradio.interpretation
|
||||
import requests
|
||||
import random
|
||||
import time
|
||||
import webbrowser
|
||||
import inspect
|
||||
from IPython import get_ipython
|
||||
import sys
|
||||
import weakref
|
||||
import analytics
|
||||
@ -209,21 +207,14 @@ class Interface:
|
||||
start = time.time()
|
||||
if self.capture_session and self.session is not None:
|
||||
graph, sess = self.session
|
||||
with graph.as_default():
|
||||
with sess.as_default():
|
||||
prediction = predict_fn(*processed_input)
|
||||
with graph.as_default(), sess.as_default():
|
||||
prediction = predict_fn(*processed_input)
|
||||
else:
|
||||
try:
|
||||
prediction = predict_fn(*processed_input)
|
||||
except ValueError as exception:
|
||||
if str(exception).endswith("is not an element of this "
|
||||
"graph."):
|
||||
raise ValueError("It looks like you might be using "
|
||||
"tensorflow < 2.0. Please "
|
||||
"pass capture_session=True in "
|
||||
"Interface to avoid the 'Tensor is "
|
||||
"not an element of this graph.' "
|
||||
"error.")
|
||||
if str(exception).endswith("is not an element of this graph."):
|
||||
raise ValueError(strings.en["TF1_ERROR"])
|
||||
else:
|
||||
raise exception
|
||||
duration = time.time() - start
|
||||
@ -238,17 +229,11 @@ class Interface:
|
||||
else:
|
||||
return predictions
|
||||
|
||||
|
||||
def process(self, raw_input, predict_fn=None):
|
||||
def process(self, raw_input):
|
||||
"""
|
||||
:param raw_input: a list of raw inputs to process and apply the
|
||||
prediction(s) on.
|
||||
:param predict_fn: which function to process. If not provided, all of the model functions are used.
|
||||
:return:
|
||||
processed output: a list of processed outputs to return as the
|
||||
prediction(s).
|
||||
duration: a list of time deltas measuring inference time for each
|
||||
prediction fn.
|
||||
:param raw_input: a list of raw inputs to process and apply the prediction(s) on.
|
||||
processed output: a list of processed outputs to return as the prediction(s).
|
||||
duration: a list of time deltas measuring inference time for each prediction fn.
|
||||
"""
|
||||
processed_input = [input_interface.preprocess(raw_input[i])
|
||||
for i, input_interface in enumerate(self.input_interfaces)]
|
||||
@ -258,6 +243,11 @@ class Interface:
|
||||
return processed_output, durations
|
||||
|
||||
def interpret(self, raw_input):
|
||||
"""
|
||||
Runs the interpretation command for the machine learning model. Handles both the "default" out-of-the-box
|
||||
interpretation for a certain set of UI component types, as well as the custom interpretation case.
|
||||
:param raw_input: a list of raw inputs to apply the interpretation(s) on.
|
||||
"""
|
||||
if self.interpretation == "default":
|
||||
interpreter = gradio.interpretation.default()
|
||||
processed_input = []
|
||||
@ -270,9 +260,22 @@ class Interface:
|
||||
interpretation = interpreter(self, processed_input)
|
||||
else:
|
||||
processed_input = [input_interface.preprocess(raw_input[i])
|
||||
for i, input_interface in enumerate(self.input_interfaces)]
|
||||
for i, input_interface in enumerate(self.input_interfaces)]
|
||||
interpreter = self.interpretation
|
||||
interpretation = interpreter(*processed_input)
|
||||
|
||||
if self.capture_session and self.session is not None:
|
||||
graph, sess = self.session
|
||||
with graph.as_default(), sess.as_default():
|
||||
interpretation = interpreter(*processed_input).tolist()
|
||||
else:
|
||||
try:
|
||||
interpretation = interpreter(*processed_input).tolist()
|
||||
except ValueError as exception:
|
||||
if str(exception).endswith("is not an element of this graph."):
|
||||
raise ValueError(strings.en["TF1_ERROR"])
|
||||
else:
|
||||
raise exception
|
||||
|
||||
if len(raw_input) == 1:
|
||||
interpretation = [interpretation]
|
||||
return interpretation
|
||||
@ -422,9 +425,9 @@ class Interface:
|
||||
if not is_in_interactive_mode:
|
||||
self.run_until_interrupted(thread, path_to_local_server)
|
||||
|
||||
|
||||
return app, path_to_local_server, share_url
|
||||
|
||||
|
||||
def reset_all():
|
||||
for io in Interface.get_instances():
|
||||
io.close()
|
||||
|
@ -7,4 +7,6 @@ en = {
|
||||
"PUBLIC_SHARE_TRUE": "To create a public link, set `share=True` in the argument to `launch()`.",
|
||||
"MODEL_PUBLICLY_AVAILABLE_URL": "Model available publicly at: {} (may take up to a minute for link to be usable)",
|
||||
"GENERATING_PUBLIC_LINK": "Generating public link (may take a few seconds...):",
|
||||
"TF1_ERROR": "It looks like you might be using tensorflow < 2.0. Please pass capture_session=True in Interface() to"
|
||||
" avoid the 'Tensor is not an element of this graph.' error."
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user