mirror of
https://github.com/gradio-app/gradio.git
synced 2025-03-13 11:57:29 +08:00
removed load_fn, refactored capture_session
This commit is contained in:
parent
a04aca35f4
commit
6c5f2444c2
@ -1,19 +0,0 @@
|
|||||||
"""
|
|
||||||
This file is used by launch models on a hosted service, like `GradioHub`
|
|
||||||
"""
|
|
||||||
|
|
||||||
import tempfile
|
|
||||||
import traceback
|
|
||||||
import webbrowser
|
|
||||||
|
|
||||||
import gradio.inputs
|
|
||||||
import gradio.outputs
|
|
||||||
from gradio import networking, strings
|
|
||||||
from distutils.version import StrictVersion
|
|
||||||
import pkg_resources
|
|
||||||
import requests
|
|
||||||
import random
|
|
||||||
import time
|
|
||||||
|
|
||||||
def launch_from_config(path):
|
|
||||||
pass
|
|
@ -29,8 +29,7 @@ class Interface:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, fn, inputs, outputs, saliency=None, verbose=False, examples=None,
|
def __init__(self, fn, inputs, outputs, saliency=None, verbose=False, examples=None,
|
||||||
live=False, show_input=True, show_output=True,
|
live=False, show_input=True, show_output=True, capture_session=False, title=None, description=None,
|
||||||
load_fn=None, capture_session=False, title=None, description=None,
|
|
||||||
server_name=LOCALHOST_IP):
|
server_name=LOCALHOST_IP):
|
||||||
"""
|
"""
|
||||||
:param fn: a function that will process the input panel data from the interface and return the output panel data.
|
:param fn: a function that will process the input panel data from the interface and return the output panel data.
|
||||||
@ -68,8 +67,6 @@ class Interface:
|
|||||||
fn = [fn]
|
fn = [fn]
|
||||||
self.output_interfaces *= len(fn)
|
self.output_interfaces *= len(fn)
|
||||||
self.predict = fn
|
self.predict = fn
|
||||||
self.load_fn = load_fn
|
|
||||||
self.context = None
|
|
||||||
self.verbose = verbose
|
self.verbose = verbose
|
||||||
self.status = "OFF"
|
self.status = "OFF"
|
||||||
self.saliency = saliency
|
self.saliency = saliency
|
||||||
@ -102,51 +99,29 @@ class Interface:
|
|||||||
|
|
||||||
def process(self, raw_input):
|
def process(self, raw_input):
|
||||||
processed_input = [input_interface.preprocess(
|
processed_input = [input_interface.preprocess(
|
||||||
raw_input[i]) for i, input_interface in enumerate(self.input_interfaces)]
|
raw_input[i]) for i, input_interface in
|
||||||
|
enumerate(self.input_interfaces)]
|
||||||
predictions = []
|
predictions = []
|
||||||
for predict_fn in self.predict:
|
for predict_fn in self.predict:
|
||||||
if self.context:
|
if self.capture_session:
|
||||||
if self.capture_session:
|
graph, sess = self.session
|
||||||
graph, sess = self.session
|
with graph.as_default():
|
||||||
with graph.as_default():
|
with sess.as_default():
|
||||||
with sess.as_default():
|
|
||||||
prediction = predict_fn(*processed_input,
|
|
||||||
self.context)
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
prediction = predict_fn(*processed_input, self.context)
|
|
||||||
except ValueError as exception:
|
|
||||||
if str(exception).endswith("is not an element "
|
|
||||||
"of this graph."):
|
|
||||||
print("It looks like you might be "
|
|
||||||
"using tensorflow < 2.0. Please pass "
|
|
||||||
"capture_session=True in Interface to avoid "
|
|
||||||
"the 'Tensor is not an element of this "
|
|
||||||
"graph.' "
|
|
||||||
"error.")
|
|
||||||
prediction = predict_fn(*processed_input)
|
|
||||||
else:
|
|
||||||
prediction = predict_fn(*processed_input)
|
|
||||||
else:
|
|
||||||
if self.capture_session:
|
|
||||||
graph, sess = self.session
|
|
||||||
with graph.as_default():
|
|
||||||
with sess.as_default():
|
|
||||||
prediction = predict_fn(*processed_input)
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
prediction = predict_fn(*processed_input)
|
prediction = predict_fn(*processed_input)
|
||||||
except ValueError as exception:
|
else:
|
||||||
if str(exception).endswith("is not an element "
|
try:
|
||||||
"of this graph."):
|
prediction = predict_fn(*processed_input)
|
||||||
print("It looks like you might be "
|
except ValueError as exception:
|
||||||
"using tensorflow < 2.0. Please pass "
|
if str(exception).endswith("is not an element of this "
|
||||||
"capture_session=True in Interface to avoid "
|
"graph."):
|
||||||
"the 'Tensor is not an element of this graph.' "
|
raise ValueError("It looks like you might be using "
|
||||||
"error.")
|
"tensorflow < 2.0. Please "
|
||||||
prediction = predict_fn(*processed_input)
|
"pass capture_session=True in "
|
||||||
else:
|
"Interface to avoid the 'Tensor is "
|
||||||
prediction = predict_fn(*processed_input)
|
"not an element of this graph.' "
|
||||||
|
"error.")
|
||||||
|
else:
|
||||||
|
raise exception
|
||||||
|
|
||||||
if len(self.output_interfaces) / \
|
if len(self.output_interfaces) / \
|
||||||
len(self.predict) == 1:
|
len(self.predict) == 1:
|
||||||
@ -218,8 +193,6 @@ class Interface:
|
|||||||
"""
|
"""
|
||||||
# if validate and not self.validate_flag:
|
# if validate and not self.validate_flag:
|
||||||
# self.validate()
|
# self.validate()
|
||||||
context = self.load_fn() if self.load_fn else None
|
|
||||||
self.context = context
|
|
||||||
|
|
||||||
if self.capture_session:
|
if self.capture_session:
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
Loading…
Reference in New Issue
Block a user