removed load_fn, refactored capture_session

aliabd 2020-06-29 14:34:53 -07:00
parent a04aca35f4
commit 6c5f2444c2
2 changed files with 21 additions and 67 deletions
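For orientation before the diff: `load_fn` was a callable whose return value was stored as `self.context` and passed as an extra trailing argument to every predict function; this commit removes it. `capture_session` stays, holding a (graph, session) pair that each prediction re-enters so models built on TensorFlow < 2.0 avoid the "Tensor is not an element of this graph." error. The sketch below is a minimal, hedged illustration of that re-entry pattern, assuming a TensorFlow 1.x backend: only the `graph.as_default()` / `sess.as_default()` nesting mirrors the diff, while `build_model` and the way the pair is captured are illustrative assumptions rather than gradio's actual code.

```python
import tensorflow as tf  # assumes TensorFlow 1.x, where graphs and sessions are explicit


def build_model():
    # Hypothetical stand-in for a user's TF1/Keras model constructor.
    return tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])


model = build_model()

# Capture the graph and session the model was built in (one common TF1 recipe;
# how gradio obtains this pair is not shown in this diff).
session_pair = tf.get_default_graph(), tf.keras.backend.get_session()


def predict(*processed_input):
    # Re-enter the captured pair around the call, as process() does when
    # capture_session=True, so the model's tensors resolve even when the
    # current default graph is a different one.
    graph, sess = session_pair
    with graph.as_default():
        with sess.as_default():
            return model.predict(*processed_input)
```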

@@ -1,19 +0,0 @@
"""
This file is used by launch models on a hosted service, like `GradioHub`
"""
import tempfile
import traceback
import webbrowser
import gradio.inputs
import gradio.outputs
from gradio import networking, strings
from distutils.version import StrictVersion
import pkg_resources
import requests
import random
import time
def launch_from_config(path):
pass

@@ -29,8 +29,7 @@ class Interface:
"""
def __init__(self, fn, inputs, outputs, saliency=None, verbose=False, examples=None,
live=False, show_input=True, show_output=True,
load_fn=None, capture_session=False, title=None, description=None,
live=False, show_input=True, show_output=True, capture_session=False, title=None, description=None,
server_name=LOCALHOST_IP):
"""
:param fn: a function that will process the input panel data from the interface and return the output panel data.
@@ -68,8 +67,6 @@
fn = [fn]
self.output_interfaces *= len(fn)
self.predict = fn
self.load_fn = load_fn
self.context = None
self.verbose = verbose
self.status = "OFF"
self.saliency = saliency
@@ -102,51 +99,29 @@
def process(self, raw_input):
processed_input = [input_interface.preprocess(
raw_input[i]) for i, input_interface in enumerate(self.input_interfaces)]
raw_input[i]) for i, input_interface in
enumerate(self.input_interfaces)]
predictions = []
for predict_fn in self.predict:
if self.context:
if self.capture_session:
graph, sess = self.session
with graph.as_default():
with sess.as_default():
prediction = predict_fn(*processed_input,
self.context)
else:
try:
prediction = predict_fn(*processed_input, self.context)
except ValueError as exception:
if str(exception).endswith("is not an element "
"of this graph."):
print("It looks like you might be "
"using tensorflow < 2.0. Please pass "
"capture_session=True in Interface to avoid "
"the 'Tensor is not an element of this "
"graph.' "
"error.")
prediction = predict_fn(*processed_input)
else:
prediction = predict_fn(*processed_input)
else:
if self.capture_session:
graph, sess = self.session
with graph.as_default():
with sess.as_default():
prediction = predict_fn(*processed_input)
else:
try:
if self.capture_session:
graph, sess = self.session
with graph.as_default():
with sess.as_default():
prediction = predict_fn(*processed_input)
except ValueError as exception:
if str(exception).endswith("is not an element "
"of this graph."):
print("It looks like you might be "
"using tensorflow < 2.0. Please pass "
"capture_session=True in Interface to avoid "
"the 'Tensor is not an element of this graph.' "
"error.")
prediction = predict_fn(*processed_input)
else:
prediction = predict_fn(*processed_input)
else:
try:
prediction = predict_fn(*processed_input)
except ValueError as exception:
if str(exception).endswith("is not an element of this "
"graph."):
raise ValueError("It looks like you might be using "
"tensorflow < 2.0. Please "
"pass capture_session=True in "
"Interface to avoid the 'Tensor is "
"not an element of this graph.' "
"error.")
else:
raise exception
if len(self.output_interfaces) / \
len(self.predict) == 1:
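After this hunk, the user-facing rule is simpler: build the Interface with capture_session=True when the predict function wraps a TensorFlow < 2.0 model, otherwise process() re-raises the "Tensor is not an element of this graph." ValueError with that hint. A hedged usage sketch follows; capture_session=True and the error it avoids come from the diff, while the model, the `score` function, and the Textbox components are illustrative assumptions.

```python
import gradio as gr
import tensorflow as tf  # TF 1.x assumed, which is the capture_session use case

model = tf.keras.Sequential([
    tf.keras.layers.Dense(8, activation="relu", input_shape=(4,)),
    tf.keras.layers.Dense(1),
])


def score(text):
    # Hypothetical predict fn: derive a few numeric features so the model runs.
    features = [[len(text), text.count(" "), text.count("."), 1.0]]
    return str(float(model.predict(features)[0][0]))


# capture_session=True asks the Interface to keep the model's graph/session
# and re-enter them around score(), instead of hitting the TF1 "not an
# element of this graph" error from a different default graph.
io = gr.Interface(fn=score, inputs=gr.inputs.Textbox(), outputs=gr.outputs.Textbox(),
                  capture_session=True)
io.launch()
```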
@@ -218,8 +193,6 @@
"""
# if validate and not self.validate_flag:
# self.validate()
context = self.load_fn() if self.load_fn else None
self.context = context
if self.capture_session:
import tensorflow as tf