This commit is contained in:
Ali Abid 2022-03-24 22:41:01 -07:00
parent 49bd10b69b
commit 1b567e7144

View File

@ -506,9 +506,9 @@ class Interface(Blocks):
with Row():
flag_btn = Button("Flag")
submit_btn.click(
lambda *args: self.process(args)[0][0]
lambda *args: self.run_prediction(args, return_duration=False)[0]
if len(self.output_components) == 1
else self.process(args)[0],
else self.run_prediction(args, return_duration=False),
self.input_components,
self.output_components,
)
@ -629,9 +629,19 @@ class Interface(Blocks):
processed output: a list of processed outputs to return as the prediction(s).
duration: a list of time deltas measuring inference time for each prediction fn.
"""
processed_input = [
input_component.preprocess(raw_input[i])
for i, input_component in enumerate(self.input_components)
]
predictions, durations = self.run_prediction(
raw_input, return_duration=True
processed_input, return_duration=True
)
processed_output = [
output_component.postprocess(predictions[i])
if predictions[i] is not None
else None
for i, output_component in enumerate(self.output_components)
]
avg_durations = []
for i, duration in enumerate(durations):
self.predict_durations[i][0] += duration
@ -642,7 +652,7 @@ class Interface(Blocks):
if hasattr(self, "config"):
self.config["avg_durations"] = avg_durations
return predictions, durations
return processed_output, durations
def interpret(self, raw_input: List[Any]) -> List[Any]:
return interpretation.run_interpret(self, raw_input)