diff --git a/gradio/interface.py b/gradio/interface.py index e17c8f71db..adfa406a27 100644 --- a/gradio/interface.py +++ b/gradio/interface.py @@ -506,9 +506,9 @@ class Interface(Blocks): with Row(): flag_btn = Button("Flag") submit_btn.click( - lambda *args: self.process(args)[0][0] + lambda *args: self.run_prediction(args, return_duration=False)[0] if len(self.output_components) == 1 - else self.process(args)[0], + else self.run_prediction(args, return_duration=False), self.input_components, self.output_components, ) @@ -629,9 +629,19 @@ class Interface(Blocks): processed output: a list of processed outputs to return as the prediction(s). duration: a list of time deltas measuring inference time for each prediction fn. """ + processed_input = [ + input_component.preprocess(raw_input[i]) + for i, input_component in enumerate(self.input_components) + ] predictions, durations = self.run_prediction( - raw_input, return_duration=True + processed_input, return_duration=True ) + processed_output = [ + output_component.postprocess(predictions[i]) + if predictions[i] is not None + else None + for i, output_component in enumerate(self.output_components) + ] avg_durations = [] for i, duration in enumerate(durations): self.predict_durations[i][0] += duration @@ -642,7 +652,7 @@ class Interface(Blocks): if hasattr(self, "config"): self.config["avg_durations"] = avg_durations - return predictions, durations + return processed_output, durations def interpret(self, raw_input: List[Any]) -> List[Any]: return interpretation.run_interpret(self, raw_input)