Progress indicator bar (#997)

Create Status Tracker component to report progress on function calls

Co-authored-by: Ali Abid <aliabid94@gmail.com>
This commit is contained in:
aliabid94 2022-04-15 02:20:19 -07:00 committed by GitHub
parent f51117487f
commit 1c2f430a7e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
25 changed files with 609 additions and 199 deletions

View File

@ -1,10 +1,13 @@
import gradio as gr
def greet(name):
return "Hello " + name + "!!"
demo = gr.Interface(fn=greet, inputs=gr.component("textarea"), outputs=gr.component("textarea"))
demo = gr.Interface(
fn=greet, inputs=gr.component("textarea"), outputs=gr.component("textarea")
)
if __name__ == "__main__":

View File

@ -7,6 +7,7 @@ from gradio import Templates
def snap(image):
return np.flipud(image)
demo = gr.Interface(snap, gr.component("webcam"), gr.component("image"))
if __name__ == "__main__":

View File

@ -1,9 +1,16 @@
import random
import gradio as gr
import random
import time
xray_model = lambda diseases, img: {disease: random.random() for disease in diseases}
ct_model = lambda diseases, img: {disease: 0.1 for disease in diseases}
def xray_model(diseases, img):
time.sleep(4)
return {disease: random.random() for disease in diseases}
def ct_model(diseases, img):
time.sleep(3)
return {disease: 0.1 for disease in diseases}
with gr.Blocks() as demo:
@ -25,8 +32,12 @@ With this model you can lorem ipsum
xray_scan = gr.Image()
xray_results = gr.JSON()
xray_run = gr.Button("Run")
xray_progress = gr.StatusTracker(cover_container=True)
xray_run.click(
xray_model, inputs=[disease, xray_scan], outputs=xray_results
xray_model,
inputs=[disease, xray_scan],
outputs=xray_results,
status_tracker=xray_progress,
)
with gr.TabItem("CT Scan"):
@ -34,9 +45,21 @@ With this model you can lorem ipsum
ct_scan = gr.Image()
ct_results = gr.JSON()
ct_run = gr.Button("Run")
ct_run.click(ct_model, inputs=[disease, ct_scan], outputs=ct_results)
ct_progress = gr.StatusTracker(cover_container=True)
ct_run.click(
ct_model,
inputs=[disease, ct_scan],
outputs=ct_results,
status_tracker=ct_progress,
)
overall_probability = gr.Textbox()
upload_btn = gr.Button("Upload Results")
upload_btn.click(
lambda ct, xr: time.sleep(5),
inputs=[ct_results, xray_results],
outputs=[],
status_tracker=gr.StatusTracker(),
)
if __name__ == "__main__":
demo.launch()

View File

@ -1,4 +1,4 @@
Metadata-Version: 2.1
Metadata-Version: 1.0
Name: gradio
Version: 2.7.0b70
Summary: Python library for easily interacting with trained machine learning models
@ -6,9 +6,6 @@ Home-page: https://github.com/gradio-app/gradio-UI
Author: Abubakar Abid, Ali Abid, Ali Abdalla, Dawood Khan, Ahsen Khaliq, Pete Allen, Ömer Faruk Özdemir
Author-email: team@gradio.app
License: Apache License 2.0
Description: UNKNOWN
Keywords: machine learning,visualization,reproducibility
Platform: UNKNOWN
License-File: LICENSE
UNKNOWN

View File

@ -1,5 +1,5 @@
analytics-python
aiohttp
analytics-python
fastapi
ffmpy
markdown-it-py[linkify,plugins]
@ -10,7 +10,7 @@ pandas
paramiko
pillow
pycryptodome
python-multipart
pydub
python-multipart
requests
uvicorn

View File

@ -25,6 +25,7 @@ from gradio.components import (
Number,
Radio,
Slider,
StatusTracker,
Textbox,
Timeseries,
Variable,

View File

@ -15,7 +15,7 @@ from gradio.process_examples import cache_interface_examples
if TYPE_CHECKING: # Only import for type checking (is False at runtime).
from fastapi.applications import FastAPI
from gradio.components import Component
from gradio.components import Component, StatusTracker
class Block:
@ -60,6 +60,7 @@ class Block:
postprocess: bool = True,
queue=False,
no_target: bool = False,
status_tracker: Optional[StatusTracker] = None,
) -> None:
"""
Adds an event to the component's dependencies.
@ -70,7 +71,9 @@ class Block:
outputs: output list
preprocess: whether to run the preprocess methods of components
postprocess: whether to run the postprocess methods of components
queue: if True, will store multiple calls in queue and run in order instead of in parallel with multiple threads
no_target: if True, sets "targets" to [], used for Blocks "load" event
status_tracker: StatusTracker to visualize function progress
Returns: None
"""
# Support for singular parameter
@ -79,7 +82,7 @@ class Block:
if not isinstance(outputs, list):
outputs = [outputs]
Context.root_block.fns.append((fn, preprocess, postprocess))
Context.root_block.fns.append(BlockFunction(fn, preprocess, postprocess))
Context.root_block.dependencies.append(
{
"targets": [self._id] if not no_target else [],
@ -87,6 +90,9 @@ class Block:
"inputs": [block._id for block in inputs],
"outputs": [block._id for block in outputs],
"queue": queue,
"status_tracker": status_tracker._id
if status_tracker is not None
else None,
}
)
@ -183,6 +189,15 @@ class TabItem(BlockContext):
self.set_event_trigger("change", fn, inputs, outputs)
class BlockFunction:
def __init__(self, fn: Callable, preprocess: bool, postprocess: bool):
self.fn = fn
self.preprocess = preprocess
self.postprocess = postprocess
self.total_runtime = 0
self.total_runs = 0
class Blocks(BlockContext):
def __init__(
self,
@ -208,7 +223,7 @@ class Blocks(BlockContext):
super().__init__()
self.blocks = {}
self.fns = []
self.fns: List[BlockFunction] = []
self.dependencies = []
self.mode = mode
@ -238,10 +253,10 @@ class Blocks(BlockContext):
"""
raw_input = data["data"]
fn_index = data["fn_index"]
fn, preprocess, postprocess = self.fns[fn_index]
block_fn = self.fns[fn_index]
dependency = self.dependencies[fn_index]
if preprocess:
if block_fn.preprocess:
processed_input = []
for i, input_id in enumerate(dependency["inputs"]):
block = self.blocks[input_id]
@ -249,12 +264,16 @@ class Blocks(BlockContext):
processed_input.append(state.get(input_id))
else:
processed_input.append(block.preprocess(raw_input[i]))
predictions = fn(*processed_input)
else:
predictions = fn(*raw_input)
processed_input = raw_input
start = time.time()
predictions = block_fn.fn(*processed_input)
duration = time.time() - start
block_fn.total_runtime += duration
block_fn.total_runs += 1
if len(dependency["outputs"]) == 1:
predictions = (predictions,)
if postprocess:
if block_fn.postprocess:
output = []
for i, output_id in enumerate(dependency["outputs"]):
block = self.blocks[output_id]
@ -269,7 +288,11 @@ class Blocks(BlockContext):
)
else:
output = predictions
return {"data": output}
return {
"data": output,
"duration": duration,
"average_duration": block_fn.total_runtime / block_fn.total_runs,
}
def get_template_context(self):
return {"type": "column"}

View File

@ -387,25 +387,43 @@ class Textbox(Component):
"""
return x
def change(self, fn: Callable, inputs: List[Component], outputs: List[Component]):
def change(
self,
fn: Callable,
inputs: List[Component],
outputs: List[Component],
status_tracker: Optional[StatusTracker] = None,
):
"""
Parameters:
fn: Callable function
inputs: List of inputs
outputs: List of outputs
status: StatusTracker to visualize function progress
Returns: None
"""
self.set_event_trigger("change", fn, inputs, outputs)
self.set_event_trigger(
"change", fn, inputs, outputs, status_tracker=status_tracker
)
def submit(self, fn: Callable, inputs: List[Component], outputs: List[Component]):
def submit(
self,
fn: Callable,
inputs: List[Component],
outputs: List[Component],
status_tracker: Optional[StatusTracker] = None,
):
"""
Parameters:
fn: Callable function
inputs: List of inputs
outputs: List of outputs
status: StatusTracker to visualize function progress
Returns: None
"""
self.set_event_trigger("submit", fn, inputs, outputs)
self.set_event_trigger(
"submit", fn, inputs, outputs, status_tracker=status_tracker
)
class Number(Component):
@ -517,25 +535,43 @@ class Number(Component):
"""
return y
def change(self, fn: Callable, inputs: List[Component], outputs: List[Component]):
def change(
self,
fn: Callable,
inputs: List[Component],
outputs: List[Component],
status_tracker: Optional[StatusTracker] = None,
):
"""
Parameters:
fn: Callable function
inputs: List of inputs
outputs: List of outputs
status: StatusTracker to visualize function progress
Returns: None
"""
self.set_event_trigger("change", fn, inputs, outputs)
self.set_event_trigger(
"change", fn, inputs, outputs, status_tracker=status_tracker
)
def submit(self, fn: Callable, inputs: List[Component], outputs: List[Component]):
def submit(
self,
fn: Callable,
inputs: List[Component],
outputs: List[Component],
status_tracker: Optional[StatusTracker] = None,
):
"""
Parameters:
fn: Callable function
inputs: List of inputs
outputs: List of outputs
status: StatusTracker to visualize function progress
Returns: None
"""
self.set_event_trigger("submit", fn, inputs, outputs)
self.set_event_trigger(
"submit", fn, inputs, outputs, status_tracker=status_tracker
)
class Slider(Component):
@ -643,15 +679,24 @@ class Slider(Component):
"""
return y
def change(self, fn: Callable, inputs: List[Component], outputs: List[Component]):
def change(
self,
fn: Callable,
inputs: List[Component],
outputs: List[Component],
status_tracker: Optional[StatusTracker] = None,
):
"""
Parameters:
fn: Callable function
inputs: List of inputs
outputs: List of outputs
status: StatusTracker to visualize function progress
Returns: None
"""
self.set_event_trigger("change", fn, inputs, outputs)
self.set_event_trigger(
"change", fn, inputs, outputs, status_tracker=status_tracker
)
class Checkbox(Component):
@ -735,15 +780,24 @@ class Checkbox(Component):
"""
return x
def change(self, fn: Callable, inputs: List[Component], outputs: List[Component]):
def change(
self,
fn: Callable,
inputs: List[Component],
outputs: List[Component],
status_tracker: Optional[StatusTracker] = None,
):
"""
Parameters:
fn: Callable function
inputs: List of inputs
outputs: List of outputs
status: StatusTracker to visualize function progress
Returns: None
"""
self.set_event_trigger("change", fn, inputs, outputs)
self.set_event_trigger(
"change", fn, inputs, outputs, status_tracker=status_tracker
)
class CheckboxGroup(Component):
@ -863,15 +917,24 @@ class CheckboxGroup(Component):
"""
return x
def change(self, fn: Callable, inputs: List[Component], outputs: List[Component]):
def change(
self,
fn: Callable,
inputs: List[Component],
outputs: List[Component],
status_tracker: Optional[StatusTracker] = None,
):
"""
Parameters:
fn: Callable function
inputs: List of inputs
outputs: List of outputs
status: StatusTracker to visualize function progress
Returns: None
"""
self.set_event_trigger("change", fn, inputs, outputs)
self.set_event_trigger(
"change", fn, inputs, outputs, status_tracker=status_tracker
)
class Radio(Component):
@ -971,15 +1034,24 @@ class Radio(Component):
"""
return x
def change(self, fn: Callable, inputs: List[Component], outputs: List[Component]):
def change(
self,
fn: Callable,
inputs: List[Component],
outputs: List[Component],
status_tracker: Optional[StatusTracker] = None,
):
"""
Parameters:
fn: Callable function
inputs: List of inputs
outputs: List of outputs
status: StatusTracker to visualize function progress
Returns: None
"""
self.set_event_trigger("change", fn, inputs, outputs)
self.set_event_trigger(
"change", fn, inputs, outputs, status_tracker=status_tracker
)
class Dropdown(Radio):
@ -1312,15 +1384,24 @@ class Image(Component):
y = processing_utils.decode_base64_to_file(x).name
return y
def change(self, fn: Callable, inputs: List[Component], outputs: List[Component]):
def change(
self,
fn: Callable,
inputs: List[Component],
outputs: List[Component],
status_tracker: Optional[StatusTracker] = None,
):
"""
Parameters:
fn: Callable function
inputs: List of inputs
outputs: List of outputs
status: StatusTracker to visualize function progress
Returns: None
"""
self.set_event_trigger("change", fn, inputs, outputs)
self.set_event_trigger(
"change", fn, inputs, outputs, status_tracker=status_tracker
)
def edit(self, fn: Callable, inputs: List[Component], outputs: List[Component]):
"""
@ -1454,15 +1535,24 @@ class Video(Component):
def deserialize(self, x):
return processing_utils.decode_base64_to_file(x).name
def change(self, fn: Callable, inputs: List[Component], outputs: List[Component]):
def change(
self,
fn: Callable,
inputs: List[Component],
outputs: List[Component],
status_tracker: Optional[StatusTracker] = None,
):
"""
Parameters:
fn: Callable function
inputs: List of inputs
outputs: List of outputs
status: StatusTracker to visualize function progress
Returns: None
"""
self.set_event_trigger("change", fn, inputs, outputs)
self.set_event_trigger(
"change", fn, inputs, outputs, status_tracker=status_tracker
)
def clear(self, fn: Callable, inputs: List[Component], outputs: List[Component]):
"""
@ -1747,15 +1837,24 @@ class Audio(Component):
def deserialize(self, x):
return processing_utils.decode_base64_to_file(x).name
def change(self, fn: Callable, inputs: List[Component], outputs: List[Component]):
def change(
self,
fn: Callable,
inputs: List[Component],
outputs: List[Component],
status_tracker: Optional[StatusTracker] = None,
):
"""
Parameters:
fn: Callable function
inputs: List of inputs
outputs: List of outputs
status: StatusTracker to visualize function progress
Returns: None
"""
self.set_event_trigger("change", fn, inputs, outputs)
self.set_event_trigger(
"change", fn, inputs, outputs, status_tracker=status_tracker
)
def edit(self, fn: Callable, inputs: List[Component], outputs: List[Component]):
"""
@ -1925,15 +2024,24 @@ class File(Component):
"data": processing_utils.encode_file_to_base64(y),
}
def change(self, fn: Callable, inputs: List[Component], outputs: List[Component]):
def change(
self,
fn: Callable,
inputs: List[Component],
outputs: List[Component],
status_tracker: Optional[StatusTracker] = None,
):
"""
Parameters:
fn: Callable function
inputs: List of inputs
outputs: List of outputs
status: StatusTracker to visualize function progress
Returns: None
"""
self.set_event_trigger("change", fn, inputs, outputs)
self.set_event_trigger(
"change", fn, inputs, outputs, status_tracker=status_tracker
)
def clear(self, fn: Callable, inputs: List[Component], outputs: List[Component]):
"""
@ -2102,15 +2210,24 @@ class Dataframe(Component):
+ ". Please choose from: 'pandas', 'numpy', 'array'."
)
def change(self, fn: Callable, inputs: List[Component], outputs: List[Component]):
def change(
self,
fn: Callable,
inputs: List[Component],
outputs: List[Component],
status_tracker: Optional[StatusTracker] = None,
):
"""
Parameters:
fn: Callable function
inputs: List of inputs
outputs: List of outputs
status: StatusTracker to visualize function progress
Returns: None
"""
self.set_event_trigger("change", fn, inputs, outputs)
self.set_event_trigger(
"change", fn, inputs, outputs, status_tracker=status_tracker
)
class Timeseries(Component):
@ -2200,15 +2317,24 @@ class Timeseries(Component):
"""
return {"headers": y.columns.values.tolist(), "data": y.values.tolist()}
def change(self, fn: Callable, inputs: List[Component], outputs: List[Component]):
def change(
self,
fn: Callable,
inputs: List[Component],
outputs: List[Component],
status_tracker: Optional[StatusTracker] = None,
):
"""
Parameters:
fn: Callable function
inputs: List of inputs
outputs: List of outputs
status: StatusTracker to visualize function progress
Returns: None
"""
self.set_event_trigger("change", fn, inputs, outputs)
self.set_event_trigger(
"change", fn, inputs, outputs, status_tracker=status_tracker
)
class Variable(Component):
@ -2339,15 +2465,24 @@ class Label(Component):
except ValueError:
return data
def change(self, fn: Callable, inputs: List[Component], outputs: List[Component]):
def change(
self,
fn: Callable,
inputs: List[Component],
outputs: List[Component],
status_tracker: Optional[StatusTracker] = None,
):
"""
Parameters:
fn: Callable function
inputs: List of inputs
outputs: List of outputs
status: StatusTracker to visualize function progress
Returns: None
"""
self.set_event_trigger("change", fn, inputs, outputs)
self.set_event_trigger(
"change", fn, inputs, outputs, status_tracker=status_tracker
)
class KeyValues(Component):
@ -2430,15 +2565,24 @@ class HighlightedText(Component):
def restore_flagged(self, dir, data, encryption_key):
return json.loads(data)
def change(self, fn: Callable, inputs: List[Component], outputs: List[Component]):
def change(
self,
fn: Callable,
inputs: List[Component],
outputs: List[Component],
status_tracker: Optional[StatusTracker] = None,
):
"""
Parameters:
fn: Callable function
inputs: List of inputs
outputs: List of outputs
status: StatusTracker to visualize function progress
Returns: None
"""
self.set_event_trigger("change", fn, inputs, outputs)
self.set_event_trigger(
"change", fn, inputs, outputs, status_tracker=status_tracker
)
class JSON(Component):
@ -2488,15 +2632,24 @@ class JSON(Component):
def restore_flagged(self, dir, data, encryption_key):
return json.loads(data)
def change(self, fn: Callable, inputs: List[Component], outputs: List[Component]):
def change(
self,
fn: Callable,
inputs: List[Component],
outputs: List[Component],
status_tracker: Optional[StatusTracker] = None,
):
"""
Parameters:
fn: Callable function
inputs: List of inputs
outputs: List of outputs
status: StatusTracker to visualize function progress
Returns: None
"""
self.set_event_trigger("change", fn, inputs, outputs)
self.set_event_trigger(
"change", fn, inputs, outputs, status_tracker=status_tracker
)
class HTML(Component):
@ -2536,15 +2689,24 @@ class HTML(Component):
"""
return x
def change(self, fn: Callable, inputs: List[Component], outputs: List[Component]):
def change(
self,
fn: Callable,
inputs: List[Component],
outputs: List[Component],
status_tracker: Optional[StatusTracker] = None,
):
"""
Parameters:
fn: Callable function
inputs: List of inputs
outputs: List of outputs
status: StatusTracker to visualize function progress
Returns: None
"""
self.set_event_trigger("change", fn, inputs, outputs)
self.set_event_trigger(
"change", fn, inputs, outputs, status_tracker=status_tracker
)
class Carousel(Component):
@ -2627,15 +2789,24 @@ class Carousel(Component):
for sample_set in json.loads(data)
]
def change(self, fn: Callable, inputs: List[Component], outputs: List[Component]):
def change(
self,
fn: Callable,
inputs: List[Component],
outputs: List[Component],
status_tracker: Optional[StatusTracker] = None,
):
"""
Parameters:
fn: Callable function
inputs: List of inputs
outputs: List of outputs
status: StatusTracker to visualize function progress
Returns: None
"""
self.set_event_trigger("change", fn, inputs, outputs)
self.set_event_trigger(
"change", fn, inputs, outputs, status_tracker=status_tracker
)
class Chatbot(Component):
@ -2674,15 +2845,24 @@ class Chatbot(Component):
"""
return y
def change(self, fn: Callable, inputs: List[Component], outputs: List[Component]):
def change(
self,
fn: Callable,
inputs: List[Component],
outputs: List[Component],
status_tracker: Optional[StatusTracker] = None,
):
"""
Parameters:
fn: Callable function
inputs: List of inputs
outputs: List of outputs
status: StatusTracker to visualize function progress
Returns: None
"""
self.set_event_trigger("change", fn, inputs, outputs)
self.set_event_trigger(
"change", fn, inputs, outputs, status_tracker=status_tracker
)
# Static Components
@ -2745,27 +2925,48 @@ class Button(Component):
inputs: List[Component],
outputs: List[Component],
queue=False,
status_tracker: Optional[StatusTracker] = None,
):
"""
Parameters:
fn: Callable function
inputs: List of inputs
outputs: List of outputs
status: StatusTracker to visualize function progress
Returns: None
"""
self.set_event_trigger("click", fn, inputs, outputs, queue=queue)
self.set_event_trigger(
"click",
fn,
inputs,
outputs,
queue=queue,
status_tracker=status_tracker,
)
def _click_no_preprocess(
self, fn: Callable, inputs: List[Component], outputs: List[Component]
self,
fn: Callable,
inputs: List[Component],
outputs: List[Component],
status_tracker: Optional[StatusTracker] = None,
):
"""
Parameters:
fn: Callable function
inputs: List of inputs
outputs: List of outputs
status: StatusTracker to visualize function progress
Returns: None
"""
self.set_event_trigger("click", fn, inputs, outputs, preprocess=False)
self.set_event_trigger(
"click",
fn,
inputs,
outputs,
preprocess=False,
status_tracker=status_tracker,
)
class Dataset(Component):
@ -2808,27 +3009,48 @@ class Dataset(Component):
elif self.type == "values":
return self.samples[x]
def click(self, fn: Callable, inputs: List[Component], outputs: List[Component]):
"""
Parameters:
fn: Callable function
inputs: List of inputs
outputs: List of outputs
Returns: None
"""
self.set_event_trigger("click", fn, inputs, outputs)
def _click_no_postprocess(
self, fn: Callable, inputs: List[Component], outputs: List[Component]
def click(
self,
fn: Callable,
inputs: List[Component],
outputs: List[Component],
status_tracker: Optional[StatusTracker] = None,
):
"""
Parameters:
fn: Callable function
inputs: List of inputs
outputs: List of outputs
status: StatusTracker to visualize function progress
Returns: None
"""
self.set_event_trigger("click", fn, inputs, outputs, postprocess=False)
self.set_event_trigger(
"click", fn, inputs, outputs, status_tracker=status_tracker
)
def _click_no_postprocess(
self,
fn: Callable,
inputs: List[Component],
outputs: List[Component],
status_tracker: Optional[StatusTracker] = None,
):
"""
Parameters:
fn: Callable function
inputs: List of inputs
outputs: List of outputs
status: StatusTracker to visualize function progress
Returns: None
"""
self.set_event_trigger(
"click",
fn,
inputs,
outputs,
postprocess=False,
status_tracker=status_tracker,
)
class Interpretation(Component):
@ -2893,3 +3115,32 @@ def get_component_instance(comp: str | dict | Component):
raise ValueError(
f"Component must provided as a `str` or `dict` or `Component` but is {comp}"
)
class StatusTracker(Component):
"""
Used to indicate status of a function call. Event listeners can bind to a StatusTracker with 'status=' keyword argument.
"""
def __init__(
self,
*,
cover_container: bool = False,
label: Optional[str] = None,
css: Optional[Dict] = None,
**kwargs,
):
"""
Parameters:
cover_container (bool): If True, will expand to cover parent container while function pending.
label (str): component name
css (dict): optional css parameters for the component
"""
super().__init__(label=label, css=css, **kwargs)
self.cover_container = cover_container
def get_template_context(self):
return {
"cover_container": self.cover_container,
**super().get_template_context(),
}

View File

@ -11,7 +11,6 @@ import inspect
import os
import random
import re
import time
import warnings
import weakref
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple
@ -27,6 +26,7 @@ from gradio.components import (
Dataset,
Interpretation,
Markdown,
StatusTracker,
Variable,
get_component_instance,
)
@ -536,6 +536,7 @@ class Interface(Blocks):
"border-radius": "0.5rem",
}
):
status_tracker = StatusTracker(cover_container=True)
for component in self.output_components:
component.render()
with Row():
@ -543,9 +544,9 @@ class Interface(Blocks):
if self.interpretation:
interpretation_btn = Button("Interpret")
submit_fn = (
lambda *args: self.run_prediction(args, return_duration=False)[0]
lambda *args: self.run_prediction(args)[0]
if len(self.output_components) == 1
else self.run_prediction(args, return_duration=False)
else self.run_prediction(args)
)
if self.live:
for component in self.input_components:
@ -558,6 +559,7 @@ class Interface(Blocks):
self.input_components,
self.output_components,
queue=self.enable_queue,
status_tracker=status_tracker,
)
clear_btn.click(
lambda: [
@ -613,6 +615,7 @@ class Interface(Blocks):
inputs=self.input_components + self.output_components,
outputs=interpretation_set
+ [input_component_column, interpret_component_column],
status_tracker=status_tracker,
)
def __call__(self, *params):
@ -621,7 +624,7 @@ class Interface(Blocks):
): # skip the preprocessing/postprocessing if sending to a remote API
output = self.run_prediction(params, called_directly=True)
else:
output, _ = self.process(params)
output = self.process(params)
return output[0] if len(output) == 1 else output
def __str__(self):
@ -662,20 +665,16 @@ class Interface(Blocks):
def run_prediction(
self,
processed_input: List[Any],
return_duration: bool = False,
called_directly: bool = False,
) -> List[Any] | Tuple[List[Any], List[float]]:
"""
Runs the prediction function with the given (already processed) inputs.
Parameters:
processed_input (list): A list of processed inputs.
return_duration (bool): Whether to return the duration of the prediction.
called_directly (bool): Whether the prediction is being called
directly (i.e. as a function, not through the GUI).
Returns:
predictions (list): A list of predictions (not post-processed).
durations (list): A list of durations for each prediction
(only returned if `return_duration` is True).
"""
if self.api_mode: # Serialize the input
processed_input = [
@ -683,18 +682,15 @@ class Interface(Blocks):
for i, input_component in enumerate(self.input_components)
]
predictions = []
durations = []
output_component_counter = 0
for predict_fn in self.predict:
start = time.time()
if self.capture_session and self.session is not None: # For TF 1.x
graph, sess = self.session
with graph.as_default(), sess.as_default():
prediction = predict_fn(*processed_input)
else:
prediction = predict_fn(*processed_input)
duration = time.time() - start
if len(self.output_components) == len(self.predict):
prediction = [prediction]
@ -714,13 +710,9 @@ class Interface(Blocks):
)
output_component_counter += 1
durations.append(duration)
predictions.extend(prediction)
if return_duration:
return predictions, durations
else:
return predictions
return predictions
def process(self, raw_input: List[Any]) -> Tuple[List[Any], List[float]]:
"""
@ -736,26 +728,14 @@ class Interface(Blocks):
input_component.preprocess(raw_input[i])
for i, input_component in enumerate(self.input_components)
]
predictions, durations = self.run_prediction(
processed_input, return_duration=True
)
predictions = self.run_prediction(processed_input)
processed_output = [
output_component.postprocess(predictions[i])
if predictions[i] is not None
else None
for i, output_component in enumerate(self.output_components)
]
avg_durations = []
for i, duration in enumerate(durations):
self.predict_durations[i][0] += duration
self.predict_durations[i][1] += 1
avg_durations.append(
self.predict_durations[i][0] / self.predict_durations[i][1]
)
if hasattr(self, "config"):
self.config["avg_durations"] = avg_durations
return processed_output, durations
return processed_output
def interpret(self, raw_input: List[Any]) -> List[Any]:
return [

View File

@ -26,8 +26,8 @@ def process_example(
interface.input_components[i].preprocess_example(example)
for i, example in enumerate(example_set)
]
prediction, durations = interface.process(raw_input)
return prediction, durations
prediction = interface.process(raw_input)
return prediction
def cache_interface_examples(interface: Interface) -> None:
@ -44,7 +44,7 @@ def cache_interface_examples(interface: Interface) -> None:
cache_logger.setup(interface.output_components, CACHED_FOLDER)
for example_id, _ in enumerate(interface.examples):
try:
prediction = process_example(interface, example_id)[0]
prediction = process_example(interface, example_id)
cache_logger.flag(prediction)
except Exception as e:
shutil.rmtree(CACHED_FOLDER)

View File

@ -45,8 +45,8 @@
</script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/iframe-resizer/4.3.1/iframeResizer.contentWindow.min.js"></script>
<title>Gradio</title>
<script type="module" crossorigin src="./assets/index.6c5c73c1.js"></script>
<link rel="stylesheet" href="./assets/index.38c11487.css">
<script type="module" crossorigin src="./assets/index.2c3c09fa.js"></script>
<link rel="stylesheet" href="./assets/index.689bcbeb.css">
</head>
<body style="height: 100%; margin: 0; padding: 0">

View File

@ -168,6 +168,7 @@ XRAY_CONFIG = {
"inputs": [2, 6],
"outputs": [7],
"queue": False,
"status_tracker": None,
},
{
"targets": [13],
@ -175,6 +176,7 @@ XRAY_CONFIG = {
"inputs": [2, 11],
"outputs": [12],
"queue": False,
"status_tracker": None,
},
{
"targets": [],
@ -182,6 +184,7 @@ XRAY_CONFIG = {
"inputs": [],
"outputs": [14],
"queue": False,
"status_tracker": None,
},
],
}
@ -356,6 +359,7 @@ XRAY_CONFIG_DIFF_IDS = {
"inputs": [22, 6],
"outputs": [7],
"queue": False,
"status_tracker": None,
},
{
"targets": [13],
@ -363,6 +367,7 @@ XRAY_CONFIG_DIFF_IDS = {
"inputs": [22, 11],
"outputs": [12],
"queue": False,
"status_tracker": None,
},
],
}
@ -537,6 +542,7 @@ XRAY_CONFIG_WITH_MISTAKE = {
"inputs": [2, 6],
"outputs": [7],
"queue": False,
"status_tracker": None,
},
{
"targets": [13],
@ -544,6 +550,7 @@ XRAY_CONFIG_WITH_MISTAKE = {
"inputs": [2, 11],
"outputs": [12],
"queue": False,
"status_tracker": None,
},
],
}

View File

@ -100,7 +100,7 @@ class TestTextbox(unittest.TestCase):
Interface, process, interpret,
"""
iface = gr.Interface(lambda x: x[::-1], "textbox", "textbox")
self.assertEqual(iface.process(["Hello"])[0], ["olleH"])
self.assertEqual(iface.process(["Hello"]), ["olleH"])
iface = gr.Interface(
lambda sentence: max([len(word) for word in sentence.split()]),
gr.Textbox(),
@ -142,9 +142,9 @@ class TestTextbox(unittest.TestCase):
"""
iface = gr.Interface(lambda x: x[-1], "textbox", gr.Textbox())
self.assertEqual(iface.process(["Hello"])[0], ["o"])
self.assertEqual(iface.process(["Hello"]), ["o"])
iface = gr.Interface(lambda x: x / 2, "number", gr.Textbox())
self.assertEqual(iface.process([10])[0], ["5.0"])
self.assertEqual(iface.process([10]), ["5.0"])
class TestNumber(unittest.TestCase):
@ -193,7 +193,7 @@ class TestNumber(unittest.TestCase):
Interface, process, interpret
"""
iface = gr.Interface(lambda x: x**2, "number", "textbox")
self.assertEqual(iface.process([2])[0], ["4.0"])
self.assertEqual(iface.process([2]), ["4.0"])
iface = gr.Interface(
lambda x: x**2, "number", "number", interpretation="default"
)
@ -216,7 +216,7 @@ class TestNumber(unittest.TestCase):
Interface, process, interpret
"""
iface = gr.Interface(lambda x: int(x) ** 2, "textbox", "number")
self.assertEqual(iface.process([2])[0], [4.0])
self.assertEqual(iface.process([2]), [4.0])
iface = gr.Interface(
lambda x: x**2, "number", "number", interpretation="default"
)
@ -275,7 +275,7 @@ class TestSlider(unittest.TestCase):
Interface, process, interpret
"""
iface = gr.Interface(lambda x: x**2, "slider", "textbox")
self.assertEqual(iface.process([2])[0], ["4"])
self.assertEqual(iface.process([2]), ["4"])
iface = gr.Interface(
lambda x: x**2, "slider", "number", interpretation="default"
)
@ -328,7 +328,7 @@ class TestCheckbox(unittest.TestCase):
Interface, process, interpret
"""
iface = gr.Interface(lambda x: 1 if x else 0, "checkbox", "number")
self.assertEqual(iface.process([True])[0], [1])
self.assertEqual(iface.process([True]), [1])
iface = gr.Interface(
lambda x: 1 if x else 0, "checkbox", "number", interpretation="default"
)
@ -381,8 +381,8 @@ class TestCheckboxGroup(unittest.TestCase):
"""
checkboxes_input = gr.CheckboxGroup(["a", "b", "c"])
iface = gr.Interface(lambda x: "|".join(x), checkboxes_input, "textbox")
self.assertEqual(iface.process([["a", "c"]])[0], ["a|c"])
self.assertEqual(iface.process([[]])[0], [""])
self.assertEqual(iface.process([["a", "c"]]), ["a|c"])
self.assertEqual(iface.process([[]]), [""])
_ = gr.CheckboxGroup(["a", "b", "c"], type="index")
@ -426,12 +426,12 @@ class TestRadio(unittest.TestCase):
"""
radio_input = gr.Radio(["a", "b", "c"])
iface = gr.Interface(lambda x: 2 * x, radio_input, "textbox")
self.assertEqual(iface.process(["c"])[0], ["cc"])
self.assertEqual(iface.process(["c"]), ["cc"])
radio_input = gr.Radio(["a", "b", "c"], type="index")
iface = gr.Interface(
lambda x: 2 * x, radio_input, "number", interpretation="default"
)
self.assertEqual(iface.process(["c"])[0], [4])
self.assertEqual(iface.process(["c"]), [4])
scores = iface.interpret(["b"])[0]["interpretation"]
self.assertEqual(scores, [-2.0, None, 2.0])
@ -557,7 +557,7 @@ class TestImage(unittest.TestCase):
gr.Image(shape=(30, 10), type="file"),
"image",
)
output = iface.process([img])[0][0]
output = iface.process([img])[0]
self.assertEqual(
gr.processing_utils.decode_base64_to_image(output).size, (10, 30)
)
@ -591,9 +591,7 @@ class TestImage(unittest.TestCase):
return np.random.randint(0, 256, (width, height, 3))
iface = gr.Interface(generate_noise, ["slider", "slider"], "image")
self.assertTrue(
iface.process([10, 20])[0][0].startswith("data:image/png;base64")
)
self.assertTrue(iface.process([10, 20])[0].startswith("data:image/png;base64"))
class TestAudio(unittest.TestCase):
@ -706,16 +704,16 @@ class TestAudio(unittest.TestCase):
return (sr, np.flipud(data))
iface = gr.Interface(reverse_audio, "audio", "audio")
reversed_data = iface.process([deepcopy(media_data.BASE64_AUDIO)])[0][0]
reversed_data = iface.process([deepcopy(media_data.BASE64_AUDIO)])[0]
reversed_input = {"name": "fake_name", "data": reversed_data}
self.assertTrue(reversed_data.startswith("data:audio/wav;base64,UklGRgA/"))
self.assertTrue(
iface.process([deepcopy(media_data.BASE64_AUDIO)])[0][0].startswith(
iface.process([deepcopy(media_data.BASE64_AUDIO)])[0].startswith(
"data:audio/wav;base64,UklGRgA/"
)
)
self.maxDiff = None
reversed_reversed_data = iface.process([reversed_input])[0][0]
reversed_reversed_data = iface.process([reversed_input])[0]
similarity = SequenceMatcher(
a=reversed_reversed_data, b=media_data.BASE64_AUDIO["data"]
).ratio()
@ -730,7 +728,7 @@ class TestAudio(unittest.TestCase):
return 48000, np.random.randint(-256, 256, (duration, 3)).astype(np.int16)
iface = gr.Interface(generate_noise, "slider", "audio")
self.assertTrue(iface.process([100])[0][0].startswith("data:audio/wav;base64"))
self.assertTrue(iface.process([100])[0].startswith("data:audio/wav;base64"))
class TestFile(unittest.TestCase):
@ -788,7 +786,7 @@ class TestFile(unittest.TestCase):
return os.path.getsize(file_obj.name)
iface = gr.Interface(get_size_of_file, "file", "number")
self.assertEqual(iface.process([[x_file]])[0], [10558])
self.assertEqual(iface.process([[x_file]]), [10558])
def test_as_component_as_output(self):
"""
@ -802,7 +800,7 @@ class TestFile(unittest.TestCase):
iface = gr.Interface(write_file, "text", "file")
self.assertDictEqual(
iface.process(["hello world"])[0][0],
iface.process(["hello world"])[0],
{
"name": "test.txt",
"size": 11,
@ -943,14 +941,14 @@ class TestDataframe(unittest.TestCase):
"""
x_data = [[1, 2, 3], [4, 5, 6]]
iface = gr.Interface(np.max, "numpy", "number")
self.assertEqual(iface.process([x_data])[0], [6])
self.assertEqual(iface.process([x_data]), [6])
x_data = [["Tim"], ["Jon"], ["Sal"]]
def get_last(my_list):
return my_list[-1]
iface = gr.Interface(get_last, "list", "text")
self.assertEqual(iface.process([x_data])[0], ["Sal"])
self.assertEqual(iface.process([x_data]), ["Sal"])
def test_in_interface_as_output(self):
"""
@ -961,9 +959,7 @@ class TestDataframe(unittest.TestCase):
return array % 2 == 0
iface = gr.Interface(check_odd, "numpy", "numpy")
self.assertEqual(
iface.process([[2, 3, 4]])[0][0], {"data": [[True, False, True]]}
)
self.assertEqual(iface.process([[2, 3, 4]])[0], {"data": [[True, False, True]]})
class TestVideo(unittest.TestCase):
@ -1034,7 +1030,7 @@ class TestVideo(unittest.TestCase):
"""
x_video = deepcopy(media_data.BASE64_VIDEO)
iface = gr.Interface(lambda x: x, "video", "playable_video")
self.assertEqual(iface.process([x_video])[0][0]["data"], x_video["data"])
self.assertEqual(iface.process([x_video])[0]["data"], x_video["data"])
class TestTimeseries(unittest.TestCase):
@ -1147,7 +1143,7 @@ class TestTimeseries(unittest.TestCase):
}
iface = gr.Interface(lambda x: x, timeseries_input, "dataframe")
self.assertEqual(
iface.process([x_timeseries])[0],
iface.process([x_timeseries]),
[
{
"headers": ["time", "retail", "food", "other"],
@ -1176,7 +1172,7 @@ class TestTimeseries(unittest.TestCase):
}
)
self.assertEqual(
iface.process([df])[0],
iface.process([df]),
[
{
"headers": ["time", "retail", "food", "other"],
@ -1283,7 +1279,7 @@ class TestLabel(unittest.TestCase):
}
iface = gr.Interface(rgb_distribution, "image", "label")
output = iface.process([x_img])[0][0]
output = iface.process([x_img])[0]
self.assertDictEqual(
output,
{
@ -1346,7 +1342,7 @@ class TestHighlightedText(unittest.TestCase):
iface = gr.Interface(highlight_vowels, "text", "highlight")
self.assertListEqual(
iface.process(["Helloooo"])[0][0],
iface.process(["Helloooo"])[0],
[("H", "non"), ("e", "vowel"), ("ll", "non"), ("oooo", "vowel")],
)
@ -1403,7 +1399,7 @@ class TestJSON(unittest.TestCase):
["O", 20],
["F", 30],
]
self.assertDictEqual(iface.process([y_data])[0][0], {"M": 35, "F": 25, "O": 20})
self.assertDictEqual(iface.process([y_data])[0], {"M": 35, "F": 25, "O": 20})
class TestHTML(unittest.TestCase):
@ -1432,7 +1428,7 @@ class TestHTML(unittest.TestCase):
return "<strong>" + text + "</strong>"
iface = gr.Interface(bold_text, "text", "html")
self.assertEqual(iface.process(["test"])[0][0], "<strong>test</strong>")
self.assertEqual(iface.process(["test"])[0], "<strong>test</strong>")
class TestCarousel(unittest.TestCase):
@ -1511,17 +1507,17 @@ class TestCarousel(unittest.TestCase):
iface = gr.Interface(report, gr.inputs.Image(type="numpy"), carousel_output)
result = iface.process([deepcopy(media_data.BASE64_IMAGE)])
self.assertTrue(result[0][0][0][0] == "Red")
self.assertTrue(result[0][0][0] == "Red")
self.assertTrue(
result[0][0][0][1].startswith("data:image/png;base64,iVBORw0KGgoAAA")
result[0][0][1].startswith("data:image/png;base64,iVBORw0KGgoAAA")
)
self.assertTrue(result[0][0][1][0] == "Green")
self.assertTrue(result[0][1][0] == "Green")
self.assertTrue(
result[0][0][1][1].startswith("data:image/png;base64,iVBORw0KGgoAAA")
result[0][1][1].startswith("data:image/png;base64,iVBORw0KGgoAAA")
)
self.assertTrue(result[0][0][2][0] == "Blue")
self.assertTrue(result[0][2][0] == "Blue")
self.assertTrue(
result[0][0][2][1].startswith("data:image/png;base64,iVBORw0KGgoAAA")
result[0][2][1].startswith("data:image/png;base64,iVBORw0KGgoAAA")
)

View File

@ -66,7 +66,7 @@ class TestTextbox(unittest.TestCase):
def test_in_interface(self):
iface = gr.Interface(lambda x: x[::-1], "textbox", "textbox")
self.assertEqual(iface.process(["Hello"])[0], ["olleH"])
self.assertEqual(iface.process(["Hello"]), ["olleH"])
iface = gr.Interface(
lambda sentence: max([len(word) for word in sentence.split()]),
gr.inputs.Textbox(),
@ -139,7 +139,7 @@ class TestNumber(unittest.TestCase):
def test_in_interface(self):
iface = gr.Interface(lambda x: x**2, "number", "textbox")
self.assertEqual(iface.process([2])[0], ["4.0"])
self.assertEqual(iface.process([2]), ["4.0"])
iface = gr.Interface(
lambda x: x**2, "number", "number", interpretation="default"
)
@ -190,7 +190,7 @@ class TestSlider(unittest.TestCase):
def test_in_interface(self):
iface = gr.Interface(lambda x: x**2, "slider", "textbox")
self.assertEqual(iface.process([2])[0], ["4"])
self.assertEqual(iface.process([2]), ["4"])
iface = gr.Interface(
lambda x: x**2, "slider", "number", interpretation="default"
)
@ -236,7 +236,7 @@ class TestCheckbox(unittest.TestCase):
def test_in_interface(self):
iface = gr.Interface(lambda x: 1 if x else 0, "checkbox", "number")
self.assertEqual(iface.process([True])[0], [1])
self.assertEqual(iface.process([True]), [1])
iface = gr.Interface(
lambda x: 1 if x else 0, "checkbox", "number", interpretation="default"
)
@ -281,8 +281,8 @@ class TestCheckboxGroup(unittest.TestCase):
def test_in_interface(self):
checkboxes_input = gr.inputs.CheckboxGroup(["a", "b", "c"])
iface = gr.Interface(lambda x: "|".join(x), checkboxes_input, "textbox")
self.assertEqual(iface.process([["a", "c"]])[0], ["a|c"])
self.assertEqual(iface.process([[]])[0], [""])
self.assertEqual(iface.process([["a", "c"]]), ["a|c"])
self.assertEqual(iface.process([[]]), [""])
checkboxes_input = gr.inputs.CheckboxGroup(["a", "b", "c"], type="index")
@ -319,12 +319,12 @@ class TestRadio(unittest.TestCase):
def test_in_interface(self):
radio_input = gr.inputs.Radio(["a", "b", "c"])
iface = gr.Interface(lambda x: 2 * x, radio_input, "textbox")
self.assertEqual(iface.process(["c"])[0], ["cc"])
self.assertEqual(iface.process(["c"]), ["cc"])
radio_input = gr.inputs.Radio(["a", "b", "c"], type="index")
iface = gr.Interface(
lambda x: 2 * x, radio_input, "number", interpretation="default"
)
self.assertEqual(iface.process(["c"])[0], [4])
self.assertEqual(iface.process(["c"]), [4])
scores = iface.interpret(["b"])[0]["interpretation"]
self.assertEqual(scores, [-2.0, None, 2.0])
@ -364,12 +364,12 @@ class TestDropdown(unittest.TestCase):
def test_in_interface(self):
dropdown_input = gr.inputs.Dropdown(["a", "b", "c"])
iface = gr.Interface(lambda x: 2 * x, dropdown_input, "textbox")
self.assertEqual(iface.process(["c"])[0], ["cc"])
self.assertEqual(iface.process(["c"]), ["cc"])
dropdown = gr.inputs.Dropdown(["a", "b", "c"], type="index")
iface = gr.Interface(
lambda x: 2 * x, dropdown, "number", interpretation="default"
)
self.assertEqual(iface.process(["c"])[0], [4])
self.assertEqual(iface.process(["c"]), [4])
scores = iface.interpret(["b"])[0]["interpretation"]
self.assertEqual(scores, [-2.0, None, 2.0])
@ -445,7 +445,7 @@ class TestImage(unittest.TestCase):
gr.inputs.Image(shape=(30, 10), type="file"),
"image",
)
output = iface.process([img])[0][0]
output = iface.process([img])[0]
self.assertEqual(
gr.processing_utils.decode_base64_to_image(output).size, (10, 30)
)
@ -574,7 +574,7 @@ class TestFile(unittest.TestCase):
return os.path.getsize(file_obj.name)
iface = gr.Interface(get_size_of_file, "file", "number")
self.assertEqual(iface.process([[x_file]])[0], [10558])
self.assertEqual(iface.process([[x_file]]), [10558])
class TestDataframe(unittest.TestCase):
@ -631,14 +631,14 @@ class TestDataframe(unittest.TestCase):
def test_in_interface(self):
x_data = [[1, 2, 3], [4, 5, 6]]
iface = gr.Interface(np.max, "numpy", "number")
self.assertEqual(iface.process([x_data])[0], [6])
self.assertEqual(iface.process([x_data]), [6])
x_data = [["Tim"], ["Jon"], ["Sal"]]
def get_last(my_list):
return my_list[-1]
iface = gr.Interface(get_last, "list", "text")
self.assertEqual(iface.process([x_data])[0], ["Sal"])
self.assertEqual(iface.process([x_data]), ["Sal"])
class TestVideo(unittest.TestCase):
@ -680,7 +680,7 @@ class TestVideo(unittest.TestCase):
def test_in_interface(self):
x_video = media_data.BASE64_VIDEO
iface = gr.Interface(lambda x: x, "video", "playable_video")
self.assertEqual(iface.process([x_video])[0][0]["data"], x_video["data"])
self.assertEqual(iface.process([x_video])[0]["data"], x_video["data"])
class TestTimeseries(unittest.TestCase):
@ -729,7 +729,7 @@ class TestTimeseries(unittest.TestCase):
}
iface = gr.Interface(lambda x: x, timeseries_input, "dataframe")
self.assertEqual(
iface.process([x_timeseries])[0],
iface.process([x_timeseries]),
[
{
"headers": ["time", "retail", "food", "other"],

View File

@ -17,7 +17,7 @@ class TestSeries(unittest.TestCase):
io1 = gr.Interface(lambda x: x + " World", "textbox", gr.outputs.Textbox())
io2 = gr.Interface(lambda x: x + "!", "textbox", gr.outputs.Textbox())
series = mix.Series(io1, io2)
self.assertEqual(series.process(["Hello"])[0], ["Hello World!"])
self.assertEqual(series.process(["Hello"]), ["Hello World!"])
def test_with_external(self):
io1 = gr.Interface.load("spaces/abidlabs/image-identity")
@ -33,7 +33,7 @@ class TestParallel(unittest.TestCase):
io2 = gr.Interface(lambda x: x + " World 2!", "textbox", gr.outputs.Textbox())
parallel = mix.Parallel(io1, io2)
self.assertEqual(
parallel.process(["Hello"])[0], ["Hello World 1!", "Hello World 2!"]
parallel.process(["Hello"]), ["Hello World 1!", "Hello World 2!"]
)
def test_with_external(self):

View File

@ -19,9 +19,9 @@ os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
class TestTextbox(unittest.TestCase):
def test_in_interface(self):
iface = gr.Interface(lambda x: x[-1], "textbox", gr.outputs.Textbox())
self.assertEqual(iface.process(["Hello"])[0], ["o"])
self.assertEqual(iface.process(["Hello"]), ["o"])
iface = gr.Interface(lambda x: x / 2, "number", gr.outputs.Textbox())
self.assertEqual(iface.process([10])[0], ["5.0"])
self.assertEqual(iface.process([10]), ["5.0"])
class TestLabel(unittest.TestCase):
@ -92,7 +92,7 @@ class TestLabel(unittest.TestCase):
}
iface = gr.Interface(rgb_distribution, "image", "label")
output = iface.process([x_img])[0][0]
output = iface.process([x_img])[0]
self.assertDictEqual(
output,
{
@ -151,9 +151,7 @@ class TestImage(unittest.TestCase):
return np.random.randint(0, 256, (width, height, 3))
iface = gr.Interface(generate_noise, ["slider", "slider"], "image")
self.assertTrue(
iface.process([10, 20])[0][0].startswith("data:image/png;base64")
)
self.assertTrue(iface.process([10, 20])[0].startswith("data:image/png;base64"))
class TestVideo(unittest.TestCase):
@ -219,7 +217,7 @@ class TestHighlightedText(unittest.TestCase):
iface = gr.Interface(highlight_vowels, "text", "highlight")
self.assertListEqual(
iface.process(["Helloooo"])[0][0],
iface.process(["Helloooo"])[0],
[("H", "non"), ("e", "vowel"), ("ll", "non"), ("oooo", "vowel")],
)
@ -264,7 +262,7 @@ class TestAudio(unittest.TestCase):
return 48000, np.random.randint(-256, 256, (duration, 3)).astype(np.int32)
iface = gr.Interface(generate_noise, "slider", "audio")
self.assertTrue(iface.process([100])[0][0].startswith("data:audio/wav;base64"))
self.assertTrue(iface.process([100])[0].startswith("data:audio/wav;base64"))
class TestJSON(unittest.TestCase):
@ -302,7 +300,7 @@ class TestJSON(unittest.TestCase):
["O", 20],
["F", 30],
]
self.assertDictEqual(iface.process([y_data])[0][0], {"M": 35, "F": 25, "O": 20})
self.assertDictEqual(iface.process([y_data])[0], {"M": 35, "F": 25, "O": 20})
class TestHTML(unittest.TestCase):
@ -311,7 +309,7 @@ class TestHTML(unittest.TestCase):
return "<strong>" + text + "</strong>"
iface = gr.Interface(bold_text, "text", "html")
self.assertEqual(iface.process(["test"])[0][0], "<strong>test</strong>")
self.assertEqual(iface.process(["test"])[0], "<strong>test</strong>")
class TestFile(unittest.TestCase):
@ -323,7 +321,7 @@ class TestFile(unittest.TestCase):
iface = gr.Interface(write_file, "text", "file")
self.assertDictEqual(
iface.process(["hello world"])[0][0],
iface.process(["hello world"])[0],
{
"name": "test.txt",
"size": 11,
@ -408,9 +406,7 @@ class TestDataframe(unittest.TestCase):
return array % 2 == 0
iface = gr.Interface(check_odd, "numpy", "numpy")
self.assertEqual(
iface.process([[2, 3, 4]])[0][0], {"data": [[True, False, True]]}
)
self.assertEqual(iface.process([[2, 3, 4]])[0], {"data": [[True, False, True]]})
class TestCarousel(unittest.TestCase):
@ -477,7 +473,7 @@ class TestCarousel(unittest.TestCase):
iface = gr.Interface(report, gr.inputs.Image(type="numpy"), carousel_output)
self.assertEqual(
iface.process([media_data.BASE64_IMAGE])[0],
iface.process([media_data.BASE64_IMAGE]),
[
[
[

View File

@ -9,7 +9,7 @@ os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
class TestProcessExamples(unittest.TestCase):
def test_process_example(self):
io = Interface(lambda x: "Hello " + x, "text", "text", examples=[["World"]])
prediction, _ = process_examples.process_example(io, 0)
prediction = process_examples.process_example(io, 0)
self.assertEquals(prediction[0], "Hello World")
def test_caching(self):

View File

@ -29,6 +29,8 @@
inputs: Array<number>;
outputs: Array<number>;
queue: boolean;
status_tracker: number | null;
status?: string;
}
export let root: string;
@ -113,6 +115,15 @@
});
let handled_dependencies: Array<number[]> = [];
let status_tracker_values: Record<number, string> = {};
let set_status = (dependency_index: number, status: string) => {
dependencies[dependency_index].status = status;
let status_tracker_id = dependencies[dependency_index].status_tracker;
if (status_tracker_id !== null) {
status_tracker_values[status_tracker_id] = status;
}
};
async function handle_mount({ detail }) {
await tick();
@ -139,11 +150,15 @@
},
queue,
() => {}
).then((output) => {
output.data.forEach((value, i) => {
instance_map[outputs[i]].value = value;
)
.then((output) => {
output.data.forEach((value, i) => {
instance_map[outputs[i]].value = value;
});
})
.catch((error) => {
console.error(error);
});
});
handled_dependencies[i] = [-1];
}
@ -153,6 +168,10 @@
if (handled_dependencies[i]?.includes(id) || !instance) return;
// console.log(trigger, target_instances, instance);
instance?.$on(trigger, () => {
if (status === "pending") {
return;
}
set_status(i, "pending");
fn(
"predict",
{
@ -161,11 +180,17 @@
},
queue,
() => {}
).then((output) => {
output.data.forEach((value, i) => {
instance_map[outputs[i]].value = value;
)
.then((output) => {
set_status(i, "complete");
output.data.forEach((value, i) => {
instance_map[outputs[i]].value = value;
});
})
.catch((error) => {
set_status(i, "error");
console.error(error);
});
});
});
if (!handled_dependencies[i]) handled_dependencies[i] = [];
@ -194,6 +219,7 @@
{instance_map}
{theme}
{root}
{status_tracker_values}
on:mount={handle_mount}
on:destroy={({ detail }) => handle_destroy(detail)}
/>

View File

@ -10,6 +10,7 @@
export let theme;
export let dynamic_ids: Set<number>;
export let has_modes: boolean;
export let status_tracker_values: Record<number, string>;
const dispatch = createEventDispatcher<{ mount: number; destroy: number }>();
@ -46,6 +47,7 @@
{style}
{...props}
{root}
tracked_status={status_tracker_values[id]}
>
{#if children && children.length}
{#each children as { component, id, props, children, has_modes }}
@ -59,6 +61,7 @@
{children}
{dynamic_ids}
{has_modes}
{status_tracker_values}
on:destroy
on:mount
/>

View File

@ -6,6 +6,6 @@
if (default_value) value = default_value;
</script>
<div {style} class:hidden={!value} class="flex flex-1 flex-col gap-4">
<div {style} class:hidden={!value} class="flex flex-1 flex-col gap-4 relative">
<slot />
</div>

View File

@ -6,6 +6,6 @@
if (default_value) value = default_value;
</script>
<div {style} class:hidden={!value} class="flex flex-row gap-4">
<div {style} class:hidden={!value} class="flex flex-row gap-4 relative">
<slot />
</div>

View File

@ -0,0 +1,100 @@
<script lang="ts">
import { onDestroy } from "svelte";
export let style: string = "";
export let cover_container: bool = false;
export let eta: number | null = null;
export let duration: number = 8.2;
export let queue_pos: number | null = 0;
export let tracked_status: "complete" | "pending" | "error";
$: progress = eta === null ? null : Math.min(duration / eta, 1);
let timer: NodeJS.Timeout = null;
let timer_start = 0;
let timer_diff = 0;
const start_timer = () => {
timer_start = Date.now();
timer_diff = 0;
timer = setInterval(() => {
timer_diff = (Date.now() - timer_start) / 1000;
}, 100);
};
const stop_timer = () => {
clearInterval(timer);
};
onDestroy(() => {
if (timer) stop_timer();
});
$: {
if (tracked_status === "pending") {
start_timer();
} else {
stop_timer();
}
}
</script>
{#if tracked_status === "pending"}
<div class:cover_container {style}>
<div class="text-xs font-mono text-gray-400">
{#if queue_pos}
{queue_pos} in line
{:else}
{timer_diff.toFixed(1)}s
{/if}
</div>
<div class="border-gray-200 rounded border w-40 h-2 relative">
{#if progress === null}
<div class="bounce absolute bg-amber-500 shadow-inner h-full w-1/4" />
{:else}
<div
class="blink bg-amber-500 shadow-inner h-full"
style="width: {progress * 100}%;"
/>
{/if}
</div>
</div>
{:else if tracked_status === "error"}
<div class:cover_container {style}>
<span class="text-red-400 font-mono font-semibold text-lg">ERROR</span>
</div>
{/if}
<style lang="postcss">
.cover_container {
@apply absolute top-0 left-0 w-full h-full z-10 flex flex-col justify-center items-center bg-gray-100 bg-opacity-25;
}
@keyframes blink {
0% {
opacity: 100%;
}
50% {
opacity: 60%;
}
100% {
opacity: 100%;
}
}
.blink {
animation: blink 2s infinite;
}
@keyframes bounce {
0% {
left: 0%;
}
50% {
left: 75%;
}
100% {
left: 0%;
}
}
.bounce {
animation: bounce 2s infinite linear;
}
</style>

View File

@ -0,0 +1,2 @@
export { default as Component } from "./StatusTracker.svelte";
export const modes = ["static"];

View File

@ -11,6 +11,7 @@ export const component_map: Record<string, any> = {
dataset: () => import("./Dataset"),
dropdown: () => import("./Dropdown"),
file: () => import("./File"),
statustracker: () => import("./StatusTracker"),
highlightedtext: () => import("./HighlightedText"),
html: () => import("./HTML"),
image: () => import("./Image"),

View File

@ -15,7 +15,7 @@
</script>
{#if $selected_tab === id}
<div class="p-2 border-2 border-t-0 border-gray-200">
<div class="p-2 border-2 border-t-0 border-gray-200 relative">
<slot />
</div>
{/if}