mirror of
https://github.com/gradio-app/gradio.git
synced 2025-01-30 11:00:11 +08:00
capture_session flag
This commit is contained in:
parent
1bbd84163e
commit
0ed08ffe56
@ -311,9 +311,10 @@ class ImageIn(AbstractInput):
|
||||
im = np.array(im).flatten()
|
||||
im = im * self.scale + self.shift
|
||||
if self.num_channels is None:
|
||||
array = im.reshape(self.image_width, self.image_height)
|
||||
array = im.reshape(1, self.image_width, self.image_height)
|
||||
else:
|
||||
array = im.reshape(self.image_width, self.image_height, self.num_channels)
|
||||
array = im.reshape(1, self.image_width, self.image_height, \
|
||||
self.num_channels)
|
||||
return array
|
||||
|
||||
def rebuild_flagged(self, dir, msg):
|
||||
|
@ -16,6 +16,7 @@ import requests
|
||||
import random
|
||||
import time
|
||||
from IPython import get_ipython
|
||||
import tensorflow as tf
|
||||
|
||||
LOCALHOST_IP = "127.0.0.1"
|
||||
TRY_NUM_PORTS = 100
|
||||
@ -30,7 +31,7 @@ class Interface:
|
||||
|
||||
def __init__(self, fn, inputs, outputs, saliency=None, verbose=False,
|
||||
live=False, show_input=True, show_output=True,
|
||||
load_fn=None):
|
||||
load_fn=None, capture_session=False):
|
||||
"""
|
||||
:param fn: a function that will process the input panel data from the interface and return the output panel data.
|
||||
:param inputs: a string or `AbstractInput` representing the input interface.
|
||||
@ -73,6 +74,8 @@ class Interface:
|
||||
self.show_input = show_input
|
||||
self.show_output = show_output
|
||||
self.flag_hash = random.getrandbits(32)
|
||||
self.capture_session = capture_session
|
||||
self.session = None
|
||||
|
||||
def update_config_file(self, output_directory):
|
||||
config = {
|
||||
@ -154,6 +157,10 @@ class Interface:
|
||||
context = self.load_fn() if self.load_fn else None
|
||||
self.context = context
|
||||
|
||||
if self.capture_session:
|
||||
self.session = tf.get_default_graph(), \
|
||||
tf.keras.backend.get_session()
|
||||
|
||||
# If an existing interface is running with this instance, close it.
|
||||
if self.status == "RUNNING":
|
||||
if self.verbose:
|
||||
|
@ -140,11 +140,25 @@ def serve_files_in_background(interface, port, directory_to_serve=None):
|
||||
predictions = []
|
||||
for predict_fn in interface.predict:
|
||||
if interface.context:
|
||||
prediction = predict_fn(*processed_input,
|
||||
interface.context)
|
||||
if interface.capture_session:
|
||||
graph, sess = interface.session
|
||||
with graph.as_default():
|
||||
with sess.as_default():
|
||||
prediction = predict_fn(*processed_input,
|
||||
interface.context)
|
||||
else:
|
||||
prediction = predict_fn(*processed_input,
|
||||
interface.context)
|
||||
else:
|
||||
prediction = predict_fn(*processed_input)
|
||||
if len(interface.output_interfaces) / len(interface.predict) == 1:
|
||||
if interface.capture_session:
|
||||
graph, sess = interface.session
|
||||
with graph.as_default():
|
||||
with sess.as_default():
|
||||
prediction = predict_fn(*processed_input)
|
||||
else:
|
||||
prediction = predict_fn(*processed_input)
|
||||
if len(interface.output_interfaces) / \
|
||||
len(interface.predict) == 1:
|
||||
prediction = [prediction]
|
||||
predictions.extend(prediction)
|
||||
processed_output = [output_interface.postprocess(predictions[i]) for i, output_interface in enumerate(interface.output_interfaces)]
|
||||
|
Loading…
Reference in New Issue
Block a user