From f107c9b32ffb679ee0b079742ce882c5dca7c71f Mon Sep 17 00:00:00 2001 From: aliabd Date: Tue, 16 Jun 2020 09:15:54 -0700 Subject: [PATCH] added load_fn --- gradio/interface.py | 7 ++++++- gradio/networking.py | 6 +++++- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/gradio/interface.py b/gradio/interface.py index daf10ff6dd..947fcf6085 100644 --- a/gradio/interface.py +++ b/gradio/interface.py @@ -29,7 +29,8 @@ class Interface: """ def __init__(self, fn, inputs, outputs, saliency=None, verbose=False, - live=False, show_input=True, show_output=True): + live=False, show_input=True, show_output=True, + load_fn=None): """ :param fn: a function that will process the input panel data from the interface and return the output panel data. :param inputs: a string or `AbstractInput` representing the input interface. @@ -63,6 +64,8 @@ class Interface: fn = [fn] self.output_interfaces *= len(fn) self.predict = fn + self.load_fn = load_fn + self.context = None self.verbose = verbose self.status = "OFF" self.saliency = saliency @@ -148,6 +151,8 @@ class Interface: """ # if validate and not self.validate_flag: # self.validate() + context = self.load_fn() if self.load_fn else None + self.context = context # If an existing interface is running with this instance, close it. if self.status == "RUNNING": diff --git a/gradio/networking.py b/gradio/networking.py index ac8e2c99bd..3ba2f0754f 100644 --- a/gradio/networking.py +++ b/gradio/networking.py @@ -139,7 +139,11 @@ def serve_files_in_background(interface, port, directory_to_serve=None): processed_input = [input_interface.preprocess(raw_input[i]) for i, input_interface in enumerate(interface.input_interfaces)] predictions = [] for predict_fn in interface.predict: - prediction = predict_fn(*processed_input) + if interface.context: + prediction = predict_fn(*processed_input, + interface.context) + else: + prediction = predict_fn(*processed_input) if len(interface.output_interfaces) / len(interface.predict) == 1: prediction = [prediction] predictions.extend(prediction)