mirror of
https://github.com/gradio-app/gradio.git
synced 2025-01-24 10:54:04 +08:00
changes
This commit is contained in:
parent
49bd10b69b
commit
1b567e7144
@ -506,9 +506,9 @@ class Interface(Blocks):
|
||||
with Row():
|
||||
flag_btn = Button("Flag")
|
||||
submit_btn.click(
|
||||
lambda *args: self.process(args)[0][0]
|
||||
lambda *args: self.run_prediction(args, return_duration=False)[0]
|
||||
if len(self.output_components) == 1
|
||||
else self.process(args)[0],
|
||||
else self.run_prediction(args, return_duration=False),
|
||||
self.input_components,
|
||||
self.output_components,
|
||||
)
|
||||
@ -629,9 +629,19 @@ class Interface(Blocks):
|
||||
processed output: a list of processed outputs to return as the prediction(s).
|
||||
duration: a list of time deltas measuring inference time for each prediction fn.
|
||||
"""
|
||||
processed_input = [
|
||||
input_component.preprocess(raw_input[i])
|
||||
for i, input_component in enumerate(self.input_components)
|
||||
]
|
||||
predictions, durations = self.run_prediction(
|
||||
raw_input, return_duration=True
|
||||
processed_input, return_duration=True
|
||||
)
|
||||
processed_output = [
|
||||
output_component.postprocess(predictions[i])
|
||||
if predictions[i] is not None
|
||||
else None
|
||||
for i, output_component in enumerate(self.output_components)
|
||||
]
|
||||
avg_durations = []
|
||||
for i, duration in enumerate(durations):
|
||||
self.predict_durations[i][0] += duration
|
||||
@ -642,7 +652,7 @@ class Interface(Blocks):
|
||||
if hasattr(self, "config"):
|
||||
self.config["avg_durations"] = avg_durations
|
||||
|
||||
return predictions, durations
|
||||
return processed_output, durations
|
||||
|
||||
def interpret(self, raw_input: List[Any]) -> List[Any]:
|
||||
return interpretation.run_interpret(self, raw_input)
|
||||
|
Loading…
Reference in New Issue
Block a user