capture_session flag

This commit is contained in:
aliabd 2020-06-18 09:35:51 -07:00
parent 1bbd84163e
commit 0ed08ffe56
3 changed files with 29 additions and 7 deletions

View File

@ -311,9 +311,10 @@ class ImageIn(AbstractInput):
im = np.array(im).flatten()
im = im * self.scale + self.shift
if self.num_channels is None:
array = im.reshape(self.image_width, self.image_height)
array = im.reshape(1, self.image_width, self.image_height)
else:
array = im.reshape(self.image_width, self.image_height, self.num_channels)
array = im.reshape(1, self.image_width, self.image_height, \
self.num_channels)
return array
def rebuild_flagged(self, dir, msg):

View File

@ -16,6 +16,7 @@ import requests
import random
import time
from IPython import get_ipython
import tensorflow as tf
LOCALHOST_IP = "127.0.0.1"
TRY_NUM_PORTS = 100
@ -30,7 +31,7 @@ class Interface:
def __init__(self, fn, inputs, outputs, saliency=None, verbose=False,
live=False, show_input=True, show_output=True,
load_fn=None):
load_fn=None, capture_session=False):
"""
:param fn: a function that will process the input panel data from the interface and return the output panel data.
:param inputs: a string or `AbstractInput` representing the input interface.
@ -73,6 +74,8 @@ class Interface:
self.show_input = show_input
self.show_output = show_output
self.flag_hash = random.getrandbits(32)
self.capture_session = capture_session
self.session = None
def update_config_file(self, output_directory):
config = {
@ -154,6 +157,10 @@ class Interface:
context = self.load_fn() if self.load_fn else None
self.context = context
if self.capture_session:
self.session = tf.get_default_graph(), \
tf.keras.backend.get_session()
# If an existing interface is running with this instance, close it.
if self.status == "RUNNING":
if self.verbose:

View File

@ -140,11 +140,25 @@ def serve_files_in_background(interface, port, directory_to_serve=None):
predictions = []
for predict_fn in interface.predict:
if interface.context:
prediction = predict_fn(*processed_input,
interface.context)
if interface.capture_session:
graph, sess = interface.session
with graph.as_default():
with sess.as_default():
prediction = predict_fn(*processed_input,
interface.context)
else:
prediction = predict_fn(*processed_input,
interface.context)
else:
prediction = predict_fn(*processed_input)
if len(interface.output_interfaces) / len(interface.predict) == 1:
if interface.capture_session:
graph, sess = interface.session
with graph.as_default():
with sess.as_default():
prediction = predict_fn(*processed_input)
else:
prediction = predict_fn(*processed_input)
if len(interface.output_interfaces) / \
len(interface.predict) == 1:
prediction = [prediction]
predictions.extend(prediction)
processed_output = [output_interface.postprocess(predictions[i]) for i, output_interface in enumerate(interface.output_interfaces)]