mirror of
https://github.com/gradio-app/gradio.git
synced 2024-11-27 01:40:20 +08:00
added load_fn
This commit is contained in:
parent
5b9d73d915
commit
63f96985f9
@ -29,7 +29,8 @@ class Interface:
|
||||
"""
|
||||
|
||||
def __init__(self, fn, inputs, outputs, saliency=None, verbose=False,
|
||||
live=False, show_input=True, show_output=True):
|
||||
live=False, show_input=True, show_output=True,
|
||||
load_fn=None):
|
||||
"""
|
||||
:param fn: a function that will process the input panel data from the interface and return the output panel data.
|
||||
:param inputs: a string or `AbstractInput` representing the input interface.
|
||||
@ -63,6 +64,8 @@ class Interface:
|
||||
fn = [fn]
|
||||
self.output_interfaces *= len(fn)
|
||||
self.predict = fn
|
||||
self.load_fn = load_fn
|
||||
self.context = None
|
||||
self.verbose = verbose
|
||||
self.status = "OFF"
|
||||
self.saliency = saliency
|
||||
@ -148,6 +151,8 @@ class Interface:
|
||||
"""
|
||||
# if validate and not self.validate_flag:
|
||||
# self.validate()
|
||||
context = self.load_fn() if self.load_fn else None
|
||||
self.context = context
|
||||
|
||||
# If an existing interface is running with this instance, close it.
|
||||
if self.status == "RUNNING":
|
||||
|
@ -139,7 +139,11 @@ def serve_files_in_background(interface, port, directory_to_serve=None):
|
||||
processed_input = [input_interface.preprocess(raw_input[i]) for i, input_interface in enumerate(interface.input_interfaces)]
|
||||
predictions = []
|
||||
for predict_fn in interface.predict:
|
||||
prediction = predict_fn(*processed_input)
|
||||
if interface.context:
|
||||
prediction = predict_fn(*processed_input,
|
||||
interface.context)
|
||||
else:
|
||||
prediction = predict_fn(*processed_input)
|
||||
if len(interface.output_interfaces) / len(interface.predict) == 1:
|
||||
prediction = [prediction]
|
||||
predictions.extend(prediction)
|
||||
|
Loading…
Reference in New Issue
Block a user