mirror of
https://github.com/gradio-app/gradio.git
synced 2025-03-25 12:10:31 +08:00
Progress indicator bar (#997)
Create Status Tracker component to report progress on function calls Co-authored-by: Ali Abid <aliabid94@gmail.com>
This commit is contained in:
parent
f51117487f
commit
1c2f430a7e
@ -1,10 +1,13 @@
|
||||
import gradio as gr
|
||||
|
||||
|
||||
def greet(name):
|
||||
return "Hello " + name + "!!"
|
||||
|
||||
|
||||
demo = gr.Interface(fn=greet, inputs=gr.component("textarea"), outputs=gr.component("textarea"))
|
||||
demo = gr.Interface(
|
||||
fn=greet, inputs=gr.component("textarea"), outputs=gr.component("textarea")
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -7,6 +7,7 @@ from gradio import Templates
|
||||
def snap(image):
|
||||
return np.flipud(image)
|
||||
|
||||
|
||||
demo = gr.Interface(snap, gr.component("webcam"), gr.component("image"))
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -1,9 +1,16 @@
|
||||
import random
|
||||
|
||||
import gradio as gr
|
||||
import random
|
||||
import time
|
||||
|
||||
xray_model = lambda diseases, img: {disease: random.random() for disease in diseases}
|
||||
ct_model = lambda diseases, img: {disease: 0.1 for disease in diseases}
|
||||
|
||||
def xray_model(diseases, img):
|
||||
time.sleep(4)
|
||||
return {disease: random.random() for disease in diseases}
|
||||
|
||||
|
||||
def ct_model(diseases, img):
|
||||
time.sleep(3)
|
||||
return {disease: 0.1 for disease in diseases}
|
||||
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
@ -25,8 +32,12 @@ With this model you can lorem ipsum
|
||||
xray_scan = gr.Image()
|
||||
xray_results = gr.JSON()
|
||||
xray_run = gr.Button("Run")
|
||||
xray_progress = gr.StatusTracker(cover_container=True)
|
||||
xray_run.click(
|
||||
xray_model, inputs=[disease, xray_scan], outputs=xray_results
|
||||
xray_model,
|
||||
inputs=[disease, xray_scan],
|
||||
outputs=xray_results,
|
||||
status_tracker=xray_progress,
|
||||
)
|
||||
|
||||
with gr.TabItem("CT Scan"):
|
||||
@ -34,9 +45,21 @@ With this model you can lorem ipsum
|
||||
ct_scan = gr.Image()
|
||||
ct_results = gr.JSON()
|
||||
ct_run = gr.Button("Run")
|
||||
ct_run.click(ct_model, inputs=[disease, ct_scan], outputs=ct_results)
|
||||
ct_progress = gr.StatusTracker(cover_container=True)
|
||||
ct_run.click(
|
||||
ct_model,
|
||||
inputs=[disease, ct_scan],
|
||||
outputs=ct_results,
|
||||
status_tracker=ct_progress,
|
||||
)
|
||||
|
||||
overall_probability = gr.Textbox()
|
||||
upload_btn = gr.Button("Upload Results")
|
||||
upload_btn.click(
|
||||
lambda ct, xr: time.sleep(5),
|
||||
inputs=[ct_results, xray_results],
|
||||
outputs=[],
|
||||
status_tracker=gr.StatusTracker(),
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo.launch()
|
||||
|
@ -1,4 +1,4 @@
|
||||
Metadata-Version: 2.1
|
||||
Metadata-Version: 1.0
|
||||
Name: gradio
|
||||
Version: 2.7.0b70
|
||||
Summary: Python library for easily interacting with trained machine learning models
|
||||
@ -6,9 +6,6 @@ Home-page: https://github.com/gradio-app/gradio-UI
|
||||
Author: Abubakar Abid, Ali Abid, Ali Abdalla, Dawood Khan, Ahsen Khaliq, Pete Allen, Ömer Faruk Özdemir
|
||||
Author-email: team@gradio.app
|
||||
License: Apache License 2.0
|
||||
Description: UNKNOWN
|
||||
Keywords: machine learning,visualization,reproducibility
|
||||
Platform: UNKNOWN
|
||||
License-File: LICENSE
|
||||
|
||||
UNKNOWN
|
||||
|
||||
|
@ -1,5 +1,5 @@
|
||||
analytics-python
|
||||
aiohttp
|
||||
analytics-python
|
||||
fastapi
|
||||
ffmpy
|
||||
markdown-it-py[linkify,plugins]
|
||||
@ -10,7 +10,7 @@ pandas
|
||||
paramiko
|
||||
pillow
|
||||
pycryptodome
|
||||
python-multipart
|
||||
pydub
|
||||
python-multipart
|
||||
requests
|
||||
uvicorn
|
||||
|
@ -25,6 +25,7 @@ from gradio.components import (
|
||||
Number,
|
||||
Radio,
|
||||
Slider,
|
||||
StatusTracker,
|
||||
Textbox,
|
||||
Timeseries,
|
||||
Variable,
|
||||
|
@ -15,7 +15,7 @@ from gradio.process_examples import cache_interface_examples
|
||||
if TYPE_CHECKING: # Only import for type checking (is False at runtime).
|
||||
from fastapi.applications import FastAPI
|
||||
|
||||
from gradio.components import Component
|
||||
from gradio.components import Component, StatusTracker
|
||||
|
||||
|
||||
class Block:
|
||||
@ -60,6 +60,7 @@ class Block:
|
||||
postprocess: bool = True,
|
||||
queue=False,
|
||||
no_target: bool = False,
|
||||
status_tracker: Optional[StatusTracker] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Adds an event to the component's dependencies.
|
||||
@ -70,7 +71,9 @@ class Block:
|
||||
outputs: output list
|
||||
preprocess: whether to run the preprocess methods of components
|
||||
postprocess: whether to run the postprocess methods of components
|
||||
queue: if True, will store multiple calls in queue and run in order instead of in parallel with multiple threads
|
||||
no_target: if True, sets "targets" to [], used for Blocks "load" event
|
||||
status_tracker: StatusTracker to visualize function progress
|
||||
Returns: None
|
||||
"""
|
||||
# Support for singular parameter
|
||||
@ -79,7 +82,7 @@ class Block:
|
||||
if not isinstance(outputs, list):
|
||||
outputs = [outputs]
|
||||
|
||||
Context.root_block.fns.append((fn, preprocess, postprocess))
|
||||
Context.root_block.fns.append(BlockFunction(fn, preprocess, postprocess))
|
||||
Context.root_block.dependencies.append(
|
||||
{
|
||||
"targets": [self._id] if not no_target else [],
|
||||
@ -87,6 +90,9 @@ class Block:
|
||||
"inputs": [block._id for block in inputs],
|
||||
"outputs": [block._id for block in outputs],
|
||||
"queue": queue,
|
||||
"status_tracker": status_tracker._id
|
||||
if status_tracker is not None
|
||||
else None,
|
||||
}
|
||||
)
|
||||
|
||||
@ -183,6 +189,15 @@ class TabItem(BlockContext):
|
||||
self.set_event_trigger("change", fn, inputs, outputs)
|
||||
|
||||
|
||||
class BlockFunction:
|
||||
def __init__(self, fn: Callable, preprocess: bool, postprocess: bool):
|
||||
self.fn = fn
|
||||
self.preprocess = preprocess
|
||||
self.postprocess = postprocess
|
||||
self.total_runtime = 0
|
||||
self.total_runs = 0
|
||||
|
||||
|
||||
class Blocks(BlockContext):
|
||||
def __init__(
|
||||
self,
|
||||
@ -208,7 +223,7 @@ class Blocks(BlockContext):
|
||||
|
||||
super().__init__()
|
||||
self.blocks = {}
|
||||
self.fns = []
|
||||
self.fns: List[BlockFunction] = []
|
||||
self.dependencies = []
|
||||
self.mode = mode
|
||||
|
||||
@ -238,10 +253,10 @@ class Blocks(BlockContext):
|
||||
"""
|
||||
raw_input = data["data"]
|
||||
fn_index = data["fn_index"]
|
||||
fn, preprocess, postprocess = self.fns[fn_index]
|
||||
block_fn = self.fns[fn_index]
|
||||
dependency = self.dependencies[fn_index]
|
||||
|
||||
if preprocess:
|
||||
if block_fn.preprocess:
|
||||
processed_input = []
|
||||
for i, input_id in enumerate(dependency["inputs"]):
|
||||
block = self.blocks[input_id]
|
||||
@ -249,12 +264,16 @@ class Blocks(BlockContext):
|
||||
processed_input.append(state.get(input_id))
|
||||
else:
|
||||
processed_input.append(block.preprocess(raw_input[i]))
|
||||
predictions = fn(*processed_input)
|
||||
else:
|
||||
predictions = fn(*raw_input)
|
||||
processed_input = raw_input
|
||||
start = time.time()
|
||||
predictions = block_fn.fn(*processed_input)
|
||||
duration = time.time() - start
|
||||
block_fn.total_runtime += duration
|
||||
block_fn.total_runs += 1
|
||||
if len(dependency["outputs"]) == 1:
|
||||
predictions = (predictions,)
|
||||
if postprocess:
|
||||
if block_fn.postprocess:
|
||||
output = []
|
||||
for i, output_id in enumerate(dependency["outputs"]):
|
||||
block = self.blocks[output_id]
|
||||
@ -269,7 +288,11 @@ class Blocks(BlockContext):
|
||||
)
|
||||
else:
|
||||
output = predictions
|
||||
return {"data": output}
|
||||
return {
|
||||
"data": output,
|
||||
"duration": duration,
|
||||
"average_duration": block_fn.total_runtime / block_fn.total_runs,
|
||||
}
|
||||
|
||||
def get_template_context(self):
|
||||
return {"type": "column"}
|
||||
|
@ -387,25 +387,43 @@ class Textbox(Component):
|
||||
"""
|
||||
return x
|
||||
|
||||
def change(self, fn: Callable, inputs: List[Component], outputs: List[Component]):
|
||||
def change(
|
||||
self,
|
||||
fn: Callable,
|
||||
inputs: List[Component],
|
||||
outputs: List[Component],
|
||||
status_tracker: Optional[StatusTracker] = None,
|
||||
):
|
||||
"""
|
||||
Parameters:
|
||||
fn: Callable function
|
||||
inputs: List of inputs
|
||||
outputs: List of outputs
|
||||
status: StatusTracker to visualize function progress
|
||||
Returns: None
|
||||
"""
|
||||
self.set_event_trigger("change", fn, inputs, outputs)
|
||||
self.set_event_trigger(
|
||||
"change", fn, inputs, outputs, status_tracker=status_tracker
|
||||
)
|
||||
|
||||
def submit(self, fn: Callable, inputs: List[Component], outputs: List[Component]):
|
||||
def submit(
|
||||
self,
|
||||
fn: Callable,
|
||||
inputs: List[Component],
|
||||
outputs: List[Component],
|
||||
status_tracker: Optional[StatusTracker] = None,
|
||||
):
|
||||
"""
|
||||
Parameters:
|
||||
fn: Callable function
|
||||
inputs: List of inputs
|
||||
outputs: List of outputs
|
||||
status: StatusTracker to visualize function progress
|
||||
Returns: None
|
||||
"""
|
||||
self.set_event_trigger("submit", fn, inputs, outputs)
|
||||
self.set_event_trigger(
|
||||
"submit", fn, inputs, outputs, status_tracker=status_tracker
|
||||
)
|
||||
|
||||
|
||||
class Number(Component):
|
||||
@ -517,25 +535,43 @@ class Number(Component):
|
||||
"""
|
||||
return y
|
||||
|
||||
def change(self, fn: Callable, inputs: List[Component], outputs: List[Component]):
|
||||
def change(
|
||||
self,
|
||||
fn: Callable,
|
||||
inputs: List[Component],
|
||||
outputs: List[Component],
|
||||
status_tracker: Optional[StatusTracker] = None,
|
||||
):
|
||||
"""
|
||||
Parameters:
|
||||
fn: Callable function
|
||||
inputs: List of inputs
|
||||
outputs: List of outputs
|
||||
status: StatusTracker to visualize function progress
|
||||
Returns: None
|
||||
"""
|
||||
self.set_event_trigger("change", fn, inputs, outputs)
|
||||
self.set_event_trigger(
|
||||
"change", fn, inputs, outputs, status_tracker=status_tracker
|
||||
)
|
||||
|
||||
def submit(self, fn: Callable, inputs: List[Component], outputs: List[Component]):
|
||||
def submit(
|
||||
self,
|
||||
fn: Callable,
|
||||
inputs: List[Component],
|
||||
outputs: List[Component],
|
||||
status_tracker: Optional[StatusTracker] = None,
|
||||
):
|
||||
"""
|
||||
Parameters:
|
||||
fn: Callable function
|
||||
inputs: List of inputs
|
||||
outputs: List of outputs
|
||||
status: StatusTracker to visualize function progress
|
||||
Returns: None
|
||||
"""
|
||||
self.set_event_trigger("submit", fn, inputs, outputs)
|
||||
self.set_event_trigger(
|
||||
"submit", fn, inputs, outputs, status_tracker=status_tracker
|
||||
)
|
||||
|
||||
|
||||
class Slider(Component):
|
||||
@ -643,15 +679,24 @@ class Slider(Component):
|
||||
"""
|
||||
return y
|
||||
|
||||
def change(self, fn: Callable, inputs: List[Component], outputs: List[Component]):
|
||||
def change(
|
||||
self,
|
||||
fn: Callable,
|
||||
inputs: List[Component],
|
||||
outputs: List[Component],
|
||||
status_tracker: Optional[StatusTracker] = None,
|
||||
):
|
||||
"""
|
||||
Parameters:
|
||||
fn: Callable function
|
||||
inputs: List of inputs
|
||||
outputs: List of outputs
|
||||
status: StatusTracker to visualize function progress
|
||||
Returns: None
|
||||
"""
|
||||
self.set_event_trigger("change", fn, inputs, outputs)
|
||||
self.set_event_trigger(
|
||||
"change", fn, inputs, outputs, status_tracker=status_tracker
|
||||
)
|
||||
|
||||
|
||||
class Checkbox(Component):
|
||||
@ -735,15 +780,24 @@ class Checkbox(Component):
|
||||
"""
|
||||
return x
|
||||
|
||||
def change(self, fn: Callable, inputs: List[Component], outputs: List[Component]):
|
||||
def change(
|
||||
self,
|
||||
fn: Callable,
|
||||
inputs: List[Component],
|
||||
outputs: List[Component],
|
||||
status_tracker: Optional[StatusTracker] = None,
|
||||
):
|
||||
"""
|
||||
Parameters:
|
||||
fn: Callable function
|
||||
inputs: List of inputs
|
||||
outputs: List of outputs
|
||||
status: StatusTracker to visualize function progress
|
||||
Returns: None
|
||||
"""
|
||||
self.set_event_trigger("change", fn, inputs, outputs)
|
||||
self.set_event_trigger(
|
||||
"change", fn, inputs, outputs, status_tracker=status_tracker
|
||||
)
|
||||
|
||||
|
||||
class CheckboxGroup(Component):
|
||||
@ -863,15 +917,24 @@ class CheckboxGroup(Component):
|
||||
"""
|
||||
return x
|
||||
|
||||
def change(self, fn: Callable, inputs: List[Component], outputs: List[Component]):
|
||||
def change(
|
||||
self,
|
||||
fn: Callable,
|
||||
inputs: List[Component],
|
||||
outputs: List[Component],
|
||||
status_tracker: Optional[StatusTracker] = None,
|
||||
):
|
||||
"""
|
||||
Parameters:
|
||||
fn: Callable function
|
||||
inputs: List of inputs
|
||||
outputs: List of outputs
|
||||
status: StatusTracker to visualize function progress
|
||||
Returns: None
|
||||
"""
|
||||
self.set_event_trigger("change", fn, inputs, outputs)
|
||||
self.set_event_trigger(
|
||||
"change", fn, inputs, outputs, status_tracker=status_tracker
|
||||
)
|
||||
|
||||
|
||||
class Radio(Component):
|
||||
@ -971,15 +1034,24 @@ class Radio(Component):
|
||||
"""
|
||||
return x
|
||||
|
||||
def change(self, fn: Callable, inputs: List[Component], outputs: List[Component]):
|
||||
def change(
|
||||
self,
|
||||
fn: Callable,
|
||||
inputs: List[Component],
|
||||
outputs: List[Component],
|
||||
status_tracker: Optional[StatusTracker] = None,
|
||||
):
|
||||
"""
|
||||
Parameters:
|
||||
fn: Callable function
|
||||
inputs: List of inputs
|
||||
outputs: List of outputs
|
||||
status: StatusTracker to visualize function progress
|
||||
Returns: None
|
||||
"""
|
||||
self.set_event_trigger("change", fn, inputs, outputs)
|
||||
self.set_event_trigger(
|
||||
"change", fn, inputs, outputs, status_tracker=status_tracker
|
||||
)
|
||||
|
||||
|
||||
class Dropdown(Radio):
|
||||
@ -1312,15 +1384,24 @@ class Image(Component):
|
||||
y = processing_utils.decode_base64_to_file(x).name
|
||||
return y
|
||||
|
||||
def change(self, fn: Callable, inputs: List[Component], outputs: List[Component]):
|
||||
def change(
|
||||
self,
|
||||
fn: Callable,
|
||||
inputs: List[Component],
|
||||
outputs: List[Component],
|
||||
status_tracker: Optional[StatusTracker] = None,
|
||||
):
|
||||
"""
|
||||
Parameters:
|
||||
fn: Callable function
|
||||
inputs: List of inputs
|
||||
outputs: List of outputs
|
||||
status: StatusTracker to visualize function progress
|
||||
Returns: None
|
||||
"""
|
||||
self.set_event_trigger("change", fn, inputs, outputs)
|
||||
self.set_event_trigger(
|
||||
"change", fn, inputs, outputs, status_tracker=status_tracker
|
||||
)
|
||||
|
||||
def edit(self, fn: Callable, inputs: List[Component], outputs: List[Component]):
|
||||
"""
|
||||
@ -1454,15 +1535,24 @@ class Video(Component):
|
||||
def deserialize(self, x):
|
||||
return processing_utils.decode_base64_to_file(x).name
|
||||
|
||||
def change(self, fn: Callable, inputs: List[Component], outputs: List[Component]):
|
||||
def change(
|
||||
self,
|
||||
fn: Callable,
|
||||
inputs: List[Component],
|
||||
outputs: List[Component],
|
||||
status_tracker: Optional[StatusTracker] = None,
|
||||
):
|
||||
"""
|
||||
Parameters:
|
||||
fn: Callable function
|
||||
inputs: List of inputs
|
||||
outputs: List of outputs
|
||||
status: StatusTracker to visualize function progress
|
||||
Returns: None
|
||||
"""
|
||||
self.set_event_trigger("change", fn, inputs, outputs)
|
||||
self.set_event_trigger(
|
||||
"change", fn, inputs, outputs, status_tracker=status_tracker
|
||||
)
|
||||
|
||||
def clear(self, fn: Callable, inputs: List[Component], outputs: List[Component]):
|
||||
"""
|
||||
@ -1747,15 +1837,24 @@ class Audio(Component):
|
||||
def deserialize(self, x):
|
||||
return processing_utils.decode_base64_to_file(x).name
|
||||
|
||||
def change(self, fn: Callable, inputs: List[Component], outputs: List[Component]):
|
||||
def change(
|
||||
self,
|
||||
fn: Callable,
|
||||
inputs: List[Component],
|
||||
outputs: List[Component],
|
||||
status_tracker: Optional[StatusTracker] = None,
|
||||
):
|
||||
"""
|
||||
Parameters:
|
||||
fn: Callable function
|
||||
inputs: List of inputs
|
||||
outputs: List of outputs
|
||||
status: StatusTracker to visualize function progress
|
||||
Returns: None
|
||||
"""
|
||||
self.set_event_trigger("change", fn, inputs, outputs)
|
||||
self.set_event_trigger(
|
||||
"change", fn, inputs, outputs, status_tracker=status_tracker
|
||||
)
|
||||
|
||||
def edit(self, fn: Callable, inputs: List[Component], outputs: List[Component]):
|
||||
"""
|
||||
@ -1925,15 +2024,24 @@ class File(Component):
|
||||
"data": processing_utils.encode_file_to_base64(y),
|
||||
}
|
||||
|
||||
def change(self, fn: Callable, inputs: List[Component], outputs: List[Component]):
|
||||
def change(
|
||||
self,
|
||||
fn: Callable,
|
||||
inputs: List[Component],
|
||||
outputs: List[Component],
|
||||
status_tracker: Optional[StatusTracker] = None,
|
||||
):
|
||||
"""
|
||||
Parameters:
|
||||
fn: Callable function
|
||||
inputs: List of inputs
|
||||
outputs: List of outputs
|
||||
status: StatusTracker to visualize function progress
|
||||
Returns: None
|
||||
"""
|
||||
self.set_event_trigger("change", fn, inputs, outputs)
|
||||
self.set_event_trigger(
|
||||
"change", fn, inputs, outputs, status_tracker=status_tracker
|
||||
)
|
||||
|
||||
def clear(self, fn: Callable, inputs: List[Component], outputs: List[Component]):
|
||||
"""
|
||||
@ -2102,15 +2210,24 @@ class Dataframe(Component):
|
||||
+ ". Please choose from: 'pandas', 'numpy', 'array'."
|
||||
)
|
||||
|
||||
def change(self, fn: Callable, inputs: List[Component], outputs: List[Component]):
|
||||
def change(
|
||||
self,
|
||||
fn: Callable,
|
||||
inputs: List[Component],
|
||||
outputs: List[Component],
|
||||
status_tracker: Optional[StatusTracker] = None,
|
||||
):
|
||||
"""
|
||||
Parameters:
|
||||
fn: Callable function
|
||||
inputs: List of inputs
|
||||
outputs: List of outputs
|
||||
status: StatusTracker to visualize function progress
|
||||
Returns: None
|
||||
"""
|
||||
self.set_event_trigger("change", fn, inputs, outputs)
|
||||
self.set_event_trigger(
|
||||
"change", fn, inputs, outputs, status_tracker=status_tracker
|
||||
)
|
||||
|
||||
|
||||
class Timeseries(Component):
|
||||
@ -2200,15 +2317,24 @@ class Timeseries(Component):
|
||||
"""
|
||||
return {"headers": y.columns.values.tolist(), "data": y.values.tolist()}
|
||||
|
||||
def change(self, fn: Callable, inputs: List[Component], outputs: List[Component]):
|
||||
def change(
|
||||
self,
|
||||
fn: Callable,
|
||||
inputs: List[Component],
|
||||
outputs: List[Component],
|
||||
status_tracker: Optional[StatusTracker] = None,
|
||||
):
|
||||
"""
|
||||
Parameters:
|
||||
fn: Callable function
|
||||
inputs: List of inputs
|
||||
outputs: List of outputs
|
||||
status: StatusTracker to visualize function progress
|
||||
Returns: None
|
||||
"""
|
||||
self.set_event_trigger("change", fn, inputs, outputs)
|
||||
self.set_event_trigger(
|
||||
"change", fn, inputs, outputs, status_tracker=status_tracker
|
||||
)
|
||||
|
||||
|
||||
class Variable(Component):
|
||||
@ -2339,15 +2465,24 @@ class Label(Component):
|
||||
except ValueError:
|
||||
return data
|
||||
|
||||
def change(self, fn: Callable, inputs: List[Component], outputs: List[Component]):
|
||||
def change(
|
||||
self,
|
||||
fn: Callable,
|
||||
inputs: List[Component],
|
||||
outputs: List[Component],
|
||||
status_tracker: Optional[StatusTracker] = None,
|
||||
):
|
||||
"""
|
||||
Parameters:
|
||||
fn: Callable function
|
||||
inputs: List of inputs
|
||||
outputs: List of outputs
|
||||
status: StatusTracker to visualize function progress
|
||||
Returns: None
|
||||
"""
|
||||
self.set_event_trigger("change", fn, inputs, outputs)
|
||||
self.set_event_trigger(
|
||||
"change", fn, inputs, outputs, status_tracker=status_tracker
|
||||
)
|
||||
|
||||
|
||||
class KeyValues(Component):
|
||||
@ -2430,15 +2565,24 @@ class HighlightedText(Component):
|
||||
def restore_flagged(self, dir, data, encryption_key):
|
||||
return json.loads(data)
|
||||
|
||||
def change(self, fn: Callable, inputs: List[Component], outputs: List[Component]):
|
||||
def change(
|
||||
self,
|
||||
fn: Callable,
|
||||
inputs: List[Component],
|
||||
outputs: List[Component],
|
||||
status_tracker: Optional[StatusTracker] = None,
|
||||
):
|
||||
"""
|
||||
Parameters:
|
||||
fn: Callable function
|
||||
inputs: List of inputs
|
||||
outputs: List of outputs
|
||||
status: StatusTracker to visualize function progress
|
||||
Returns: None
|
||||
"""
|
||||
self.set_event_trigger("change", fn, inputs, outputs)
|
||||
self.set_event_trigger(
|
||||
"change", fn, inputs, outputs, status_tracker=status_tracker
|
||||
)
|
||||
|
||||
|
||||
class JSON(Component):
|
||||
@ -2488,15 +2632,24 @@ class JSON(Component):
|
||||
def restore_flagged(self, dir, data, encryption_key):
|
||||
return json.loads(data)
|
||||
|
||||
def change(self, fn: Callable, inputs: List[Component], outputs: List[Component]):
|
||||
def change(
|
||||
self,
|
||||
fn: Callable,
|
||||
inputs: List[Component],
|
||||
outputs: List[Component],
|
||||
status_tracker: Optional[StatusTracker] = None,
|
||||
):
|
||||
"""
|
||||
Parameters:
|
||||
fn: Callable function
|
||||
inputs: List of inputs
|
||||
outputs: List of outputs
|
||||
status: StatusTracker to visualize function progress
|
||||
Returns: None
|
||||
"""
|
||||
self.set_event_trigger("change", fn, inputs, outputs)
|
||||
self.set_event_trigger(
|
||||
"change", fn, inputs, outputs, status_tracker=status_tracker
|
||||
)
|
||||
|
||||
|
||||
class HTML(Component):
|
||||
@ -2536,15 +2689,24 @@ class HTML(Component):
|
||||
"""
|
||||
return x
|
||||
|
||||
def change(self, fn: Callable, inputs: List[Component], outputs: List[Component]):
|
||||
def change(
|
||||
self,
|
||||
fn: Callable,
|
||||
inputs: List[Component],
|
||||
outputs: List[Component],
|
||||
status_tracker: Optional[StatusTracker] = None,
|
||||
):
|
||||
"""
|
||||
Parameters:
|
||||
fn: Callable function
|
||||
inputs: List of inputs
|
||||
outputs: List of outputs
|
||||
status: StatusTracker to visualize function progress
|
||||
Returns: None
|
||||
"""
|
||||
self.set_event_trigger("change", fn, inputs, outputs)
|
||||
self.set_event_trigger(
|
||||
"change", fn, inputs, outputs, status_tracker=status_tracker
|
||||
)
|
||||
|
||||
|
||||
class Carousel(Component):
|
||||
@ -2627,15 +2789,24 @@ class Carousel(Component):
|
||||
for sample_set in json.loads(data)
|
||||
]
|
||||
|
||||
def change(self, fn: Callable, inputs: List[Component], outputs: List[Component]):
|
||||
def change(
|
||||
self,
|
||||
fn: Callable,
|
||||
inputs: List[Component],
|
||||
outputs: List[Component],
|
||||
status_tracker: Optional[StatusTracker] = None,
|
||||
):
|
||||
"""
|
||||
Parameters:
|
||||
fn: Callable function
|
||||
inputs: List of inputs
|
||||
outputs: List of outputs
|
||||
status: StatusTracker to visualize function progress
|
||||
Returns: None
|
||||
"""
|
||||
self.set_event_trigger("change", fn, inputs, outputs)
|
||||
self.set_event_trigger(
|
||||
"change", fn, inputs, outputs, status_tracker=status_tracker
|
||||
)
|
||||
|
||||
|
||||
class Chatbot(Component):
|
||||
@ -2674,15 +2845,24 @@ class Chatbot(Component):
|
||||
"""
|
||||
return y
|
||||
|
||||
def change(self, fn: Callable, inputs: List[Component], outputs: List[Component]):
|
||||
def change(
|
||||
self,
|
||||
fn: Callable,
|
||||
inputs: List[Component],
|
||||
outputs: List[Component],
|
||||
status_tracker: Optional[StatusTracker] = None,
|
||||
):
|
||||
"""
|
||||
Parameters:
|
||||
fn: Callable function
|
||||
inputs: List of inputs
|
||||
outputs: List of outputs
|
||||
status: StatusTracker to visualize function progress
|
||||
Returns: None
|
||||
"""
|
||||
self.set_event_trigger("change", fn, inputs, outputs)
|
||||
self.set_event_trigger(
|
||||
"change", fn, inputs, outputs, status_tracker=status_tracker
|
||||
)
|
||||
|
||||
|
||||
# Static Components
|
||||
@ -2745,27 +2925,48 @@ class Button(Component):
|
||||
inputs: List[Component],
|
||||
outputs: List[Component],
|
||||
queue=False,
|
||||
status_tracker: Optional[StatusTracker] = None,
|
||||
):
|
||||
"""
|
||||
Parameters:
|
||||
fn: Callable function
|
||||
inputs: List of inputs
|
||||
outputs: List of outputs
|
||||
status: StatusTracker to visualize function progress
|
||||
Returns: None
|
||||
"""
|
||||
self.set_event_trigger("click", fn, inputs, outputs, queue=queue)
|
||||
self.set_event_trigger(
|
||||
"click",
|
||||
fn,
|
||||
inputs,
|
||||
outputs,
|
||||
queue=queue,
|
||||
status_tracker=status_tracker,
|
||||
)
|
||||
|
||||
def _click_no_preprocess(
|
||||
self, fn: Callable, inputs: List[Component], outputs: List[Component]
|
||||
self,
|
||||
fn: Callable,
|
||||
inputs: List[Component],
|
||||
outputs: List[Component],
|
||||
status_tracker: Optional[StatusTracker] = None,
|
||||
):
|
||||
"""
|
||||
Parameters:
|
||||
fn: Callable function
|
||||
inputs: List of inputs
|
||||
outputs: List of outputs
|
||||
status: StatusTracker to visualize function progress
|
||||
Returns: None
|
||||
"""
|
||||
self.set_event_trigger("click", fn, inputs, outputs, preprocess=False)
|
||||
self.set_event_trigger(
|
||||
"click",
|
||||
fn,
|
||||
inputs,
|
||||
outputs,
|
||||
preprocess=False,
|
||||
status_tracker=status_tracker,
|
||||
)
|
||||
|
||||
|
||||
class Dataset(Component):
|
||||
@ -2808,27 +3009,48 @@ class Dataset(Component):
|
||||
elif self.type == "values":
|
||||
return self.samples[x]
|
||||
|
||||
def click(self, fn: Callable, inputs: List[Component], outputs: List[Component]):
|
||||
"""
|
||||
Parameters:
|
||||
fn: Callable function
|
||||
inputs: List of inputs
|
||||
outputs: List of outputs
|
||||
Returns: None
|
||||
"""
|
||||
self.set_event_trigger("click", fn, inputs, outputs)
|
||||
|
||||
def _click_no_postprocess(
|
||||
self, fn: Callable, inputs: List[Component], outputs: List[Component]
|
||||
def click(
|
||||
self,
|
||||
fn: Callable,
|
||||
inputs: List[Component],
|
||||
outputs: List[Component],
|
||||
status_tracker: Optional[StatusTracker] = None,
|
||||
):
|
||||
"""
|
||||
Parameters:
|
||||
fn: Callable function
|
||||
inputs: List of inputs
|
||||
outputs: List of outputs
|
||||
status: StatusTracker to visualize function progress
|
||||
Returns: None
|
||||
"""
|
||||
self.set_event_trigger("click", fn, inputs, outputs, postprocess=False)
|
||||
self.set_event_trigger(
|
||||
"click", fn, inputs, outputs, status_tracker=status_tracker
|
||||
)
|
||||
|
||||
def _click_no_postprocess(
|
||||
self,
|
||||
fn: Callable,
|
||||
inputs: List[Component],
|
||||
outputs: List[Component],
|
||||
status_tracker: Optional[StatusTracker] = None,
|
||||
):
|
||||
"""
|
||||
Parameters:
|
||||
fn: Callable function
|
||||
inputs: List of inputs
|
||||
outputs: List of outputs
|
||||
status: StatusTracker to visualize function progress
|
||||
Returns: None
|
||||
"""
|
||||
self.set_event_trigger(
|
||||
"click",
|
||||
fn,
|
||||
inputs,
|
||||
outputs,
|
||||
postprocess=False,
|
||||
status_tracker=status_tracker,
|
||||
)
|
||||
|
||||
|
||||
class Interpretation(Component):
|
||||
@ -2893,3 +3115,32 @@ def get_component_instance(comp: str | dict | Component):
|
||||
raise ValueError(
|
||||
f"Component must provided as a `str` or `dict` or `Component` but is {comp}"
|
||||
)
|
||||
|
||||
|
||||
class StatusTracker(Component):
|
||||
"""
|
||||
Used to indicate status of a function call. Event listeners can bind to a StatusTracker with 'status=' keyword argument.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
cover_container: bool = False,
|
||||
label: Optional[str] = None,
|
||||
css: Optional[Dict] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Parameters:
|
||||
cover_container (bool): If True, will expand to cover parent container while function pending.
|
||||
label (str): component name
|
||||
css (dict): optional css parameters for the component
|
||||
"""
|
||||
super().__init__(label=label, css=css, **kwargs)
|
||||
self.cover_container = cover_container
|
||||
|
||||
def get_template_context(self):
|
||||
return {
|
||||
"cover_container": self.cover_container,
|
||||
**super().get_template_context(),
|
||||
}
|
||||
|
@ -11,7 +11,6 @@ import inspect
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
import warnings
|
||||
import weakref
|
||||
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple
|
||||
@ -27,6 +26,7 @@ from gradio.components import (
|
||||
Dataset,
|
||||
Interpretation,
|
||||
Markdown,
|
||||
StatusTracker,
|
||||
Variable,
|
||||
get_component_instance,
|
||||
)
|
||||
@ -536,6 +536,7 @@ class Interface(Blocks):
|
||||
"border-radius": "0.5rem",
|
||||
}
|
||||
):
|
||||
status_tracker = StatusTracker(cover_container=True)
|
||||
for component in self.output_components:
|
||||
component.render()
|
||||
with Row():
|
||||
@ -543,9 +544,9 @@ class Interface(Blocks):
|
||||
if self.interpretation:
|
||||
interpretation_btn = Button("Interpret")
|
||||
submit_fn = (
|
||||
lambda *args: self.run_prediction(args, return_duration=False)[0]
|
||||
lambda *args: self.run_prediction(args)[0]
|
||||
if len(self.output_components) == 1
|
||||
else self.run_prediction(args, return_duration=False)
|
||||
else self.run_prediction(args)
|
||||
)
|
||||
if self.live:
|
||||
for component in self.input_components:
|
||||
@ -558,6 +559,7 @@ class Interface(Blocks):
|
||||
self.input_components,
|
||||
self.output_components,
|
||||
queue=self.enable_queue,
|
||||
status_tracker=status_tracker,
|
||||
)
|
||||
clear_btn.click(
|
||||
lambda: [
|
||||
@ -613,6 +615,7 @@ class Interface(Blocks):
|
||||
inputs=self.input_components + self.output_components,
|
||||
outputs=interpretation_set
|
||||
+ [input_component_column, interpret_component_column],
|
||||
status_tracker=status_tracker,
|
||||
)
|
||||
|
||||
def __call__(self, *params):
|
||||
@ -621,7 +624,7 @@ class Interface(Blocks):
|
||||
): # skip the preprocessing/postprocessing if sending to a remote API
|
||||
output = self.run_prediction(params, called_directly=True)
|
||||
else:
|
||||
output, _ = self.process(params)
|
||||
output = self.process(params)
|
||||
return output[0] if len(output) == 1 else output
|
||||
|
||||
def __str__(self):
|
||||
@ -662,20 +665,16 @@ class Interface(Blocks):
|
||||
def run_prediction(
|
||||
self,
|
||||
processed_input: List[Any],
|
||||
return_duration: bool = False,
|
||||
called_directly: bool = False,
|
||||
) -> List[Any] | Tuple[List[Any], List[float]]:
|
||||
"""
|
||||
Runs the prediction function with the given (already processed) inputs.
|
||||
Parameters:
|
||||
processed_input (list): A list of processed inputs.
|
||||
return_duration (bool): Whether to return the duration of the prediction.
|
||||
called_directly (bool): Whether the prediction is being called
|
||||
directly (i.e. as a function, not through the GUI).
|
||||
Returns:
|
||||
predictions (list): A list of predictions (not post-processed).
|
||||
durations (list): A list of durations for each prediction
|
||||
(only returned if `return_duration` is True).
|
||||
"""
|
||||
if self.api_mode: # Serialize the input
|
||||
processed_input = [
|
||||
@ -683,18 +682,15 @@ class Interface(Blocks):
|
||||
for i, input_component in enumerate(self.input_components)
|
||||
]
|
||||
predictions = []
|
||||
durations = []
|
||||
output_component_counter = 0
|
||||
|
||||
for predict_fn in self.predict:
|
||||
start = time.time()
|
||||
if self.capture_session and self.session is not None: # For TF 1.x
|
||||
graph, sess = self.session
|
||||
with graph.as_default(), sess.as_default():
|
||||
prediction = predict_fn(*processed_input)
|
||||
else:
|
||||
prediction = predict_fn(*processed_input)
|
||||
duration = time.time() - start
|
||||
|
||||
if len(self.output_components) == len(self.predict):
|
||||
prediction = [prediction]
|
||||
@ -714,13 +710,9 @@ class Interface(Blocks):
|
||||
)
|
||||
output_component_counter += 1
|
||||
|
||||
durations.append(duration)
|
||||
predictions.extend(prediction)
|
||||
|
||||
if return_duration:
|
||||
return predictions, durations
|
||||
else:
|
||||
return predictions
|
||||
return predictions
|
||||
|
||||
def process(self, raw_input: List[Any]) -> Tuple[List[Any], List[float]]:
|
||||
"""
|
||||
@ -736,26 +728,14 @@ class Interface(Blocks):
|
||||
input_component.preprocess(raw_input[i])
|
||||
for i, input_component in enumerate(self.input_components)
|
||||
]
|
||||
predictions, durations = self.run_prediction(
|
||||
processed_input, return_duration=True
|
||||
)
|
||||
predictions = self.run_prediction(processed_input)
|
||||
processed_output = [
|
||||
output_component.postprocess(predictions[i])
|
||||
if predictions[i] is not None
|
||||
else None
|
||||
for i, output_component in enumerate(self.output_components)
|
||||
]
|
||||
avg_durations = []
|
||||
for i, duration in enumerate(durations):
|
||||
self.predict_durations[i][0] += duration
|
||||
self.predict_durations[i][1] += 1
|
||||
avg_durations.append(
|
||||
self.predict_durations[i][0] / self.predict_durations[i][1]
|
||||
)
|
||||
if hasattr(self, "config"):
|
||||
self.config["avg_durations"] = avg_durations
|
||||
|
||||
return processed_output, durations
|
||||
return processed_output
|
||||
|
||||
def interpret(self, raw_input: List[Any]) -> List[Any]:
|
||||
return [
|
||||
|
@ -26,8 +26,8 @@ def process_example(
|
||||
interface.input_components[i].preprocess_example(example)
|
||||
for i, example in enumerate(example_set)
|
||||
]
|
||||
prediction, durations = interface.process(raw_input)
|
||||
return prediction, durations
|
||||
prediction = interface.process(raw_input)
|
||||
return prediction
|
||||
|
||||
|
||||
def cache_interface_examples(interface: Interface) -> None:
|
||||
@ -44,7 +44,7 @@ def cache_interface_examples(interface: Interface) -> None:
|
||||
cache_logger.setup(interface.output_components, CACHED_FOLDER)
|
||||
for example_id, _ in enumerate(interface.examples):
|
||||
try:
|
||||
prediction = process_example(interface, example_id)[0]
|
||||
prediction = process_example(interface, example_id)
|
||||
cache_logger.flag(prediction)
|
||||
except Exception as e:
|
||||
shutil.rmtree(CACHED_FOLDER)
|
||||
|
@ -45,8 +45,8 @@
|
||||
</script>
|
||||
<script src="https://cdnjs.cloudflare.com/ajax/libs/iframe-resizer/4.3.1/iframeResizer.contentWindow.min.js"></script>
|
||||
<title>Gradio</title>
|
||||
<script type="module" crossorigin src="./assets/index.6c5c73c1.js"></script>
|
||||
<link rel="stylesheet" href="./assets/index.38c11487.css">
|
||||
<script type="module" crossorigin src="./assets/index.2c3c09fa.js"></script>
|
||||
<link rel="stylesheet" href="./assets/index.689bcbeb.css">
|
||||
</head>
|
||||
|
||||
<body style="height: 100%; margin: 0; padding: 0">
|
||||
|
@ -168,6 +168,7 @@ XRAY_CONFIG = {
|
||||
"inputs": [2, 6],
|
||||
"outputs": [7],
|
||||
"queue": False,
|
||||
"status_tracker": None,
|
||||
},
|
||||
{
|
||||
"targets": [13],
|
||||
@ -175,6 +176,7 @@ XRAY_CONFIG = {
|
||||
"inputs": [2, 11],
|
||||
"outputs": [12],
|
||||
"queue": False,
|
||||
"status_tracker": None,
|
||||
},
|
||||
{
|
||||
"targets": [],
|
||||
@ -182,6 +184,7 @@ XRAY_CONFIG = {
|
||||
"inputs": [],
|
||||
"outputs": [14],
|
||||
"queue": False,
|
||||
"status_tracker": None,
|
||||
},
|
||||
],
|
||||
}
|
||||
@ -356,6 +359,7 @@ XRAY_CONFIG_DIFF_IDS = {
|
||||
"inputs": [22, 6],
|
||||
"outputs": [7],
|
||||
"queue": False,
|
||||
"status_tracker": None,
|
||||
},
|
||||
{
|
||||
"targets": [13],
|
||||
@ -363,6 +367,7 @@ XRAY_CONFIG_DIFF_IDS = {
|
||||
"inputs": [22, 11],
|
||||
"outputs": [12],
|
||||
"queue": False,
|
||||
"status_tracker": None,
|
||||
},
|
||||
],
|
||||
}
|
||||
@ -537,6 +542,7 @@ XRAY_CONFIG_WITH_MISTAKE = {
|
||||
"inputs": [2, 6],
|
||||
"outputs": [7],
|
||||
"queue": False,
|
||||
"status_tracker": None,
|
||||
},
|
||||
{
|
||||
"targets": [13],
|
||||
@ -544,6 +550,7 @@ XRAY_CONFIG_WITH_MISTAKE = {
|
||||
"inputs": [2, 11],
|
||||
"outputs": [12],
|
||||
"queue": False,
|
||||
"status_tracker": None,
|
||||
},
|
||||
],
|
||||
}
|
||||
|
@ -100,7 +100,7 @@ class TestTextbox(unittest.TestCase):
|
||||
Interface, process, interpret,
|
||||
"""
|
||||
iface = gr.Interface(lambda x: x[::-1], "textbox", "textbox")
|
||||
self.assertEqual(iface.process(["Hello"])[0], ["olleH"])
|
||||
self.assertEqual(iface.process(["Hello"]), ["olleH"])
|
||||
iface = gr.Interface(
|
||||
lambda sentence: max([len(word) for word in sentence.split()]),
|
||||
gr.Textbox(),
|
||||
@ -142,9 +142,9 @@ class TestTextbox(unittest.TestCase):
|
||||
|
||||
"""
|
||||
iface = gr.Interface(lambda x: x[-1], "textbox", gr.Textbox())
|
||||
self.assertEqual(iface.process(["Hello"])[0], ["o"])
|
||||
self.assertEqual(iface.process(["Hello"]), ["o"])
|
||||
iface = gr.Interface(lambda x: x / 2, "number", gr.Textbox())
|
||||
self.assertEqual(iface.process([10])[0], ["5.0"])
|
||||
self.assertEqual(iface.process([10]), ["5.0"])
|
||||
|
||||
|
||||
class TestNumber(unittest.TestCase):
|
||||
@ -193,7 +193,7 @@ class TestNumber(unittest.TestCase):
|
||||
Interface, process, interpret
|
||||
"""
|
||||
iface = gr.Interface(lambda x: x**2, "number", "textbox")
|
||||
self.assertEqual(iface.process([2])[0], ["4.0"])
|
||||
self.assertEqual(iface.process([2]), ["4.0"])
|
||||
iface = gr.Interface(
|
||||
lambda x: x**2, "number", "number", interpretation="default"
|
||||
)
|
||||
@ -216,7 +216,7 @@ class TestNumber(unittest.TestCase):
|
||||
Interface, process, interpret
|
||||
"""
|
||||
iface = gr.Interface(lambda x: int(x) ** 2, "textbox", "number")
|
||||
self.assertEqual(iface.process([2])[0], [4.0])
|
||||
self.assertEqual(iface.process([2]), [4.0])
|
||||
iface = gr.Interface(
|
||||
lambda x: x**2, "number", "number", interpretation="default"
|
||||
)
|
||||
@ -275,7 +275,7 @@ class TestSlider(unittest.TestCase):
|
||||
Interface, process, interpret
|
||||
"""
|
||||
iface = gr.Interface(lambda x: x**2, "slider", "textbox")
|
||||
self.assertEqual(iface.process([2])[0], ["4"])
|
||||
self.assertEqual(iface.process([2]), ["4"])
|
||||
iface = gr.Interface(
|
||||
lambda x: x**2, "slider", "number", interpretation="default"
|
||||
)
|
||||
@ -328,7 +328,7 @@ class TestCheckbox(unittest.TestCase):
|
||||
Interface, process, interpret
|
||||
"""
|
||||
iface = gr.Interface(lambda x: 1 if x else 0, "checkbox", "number")
|
||||
self.assertEqual(iface.process([True])[0], [1])
|
||||
self.assertEqual(iface.process([True]), [1])
|
||||
iface = gr.Interface(
|
||||
lambda x: 1 if x else 0, "checkbox", "number", interpretation="default"
|
||||
)
|
||||
@ -381,8 +381,8 @@ class TestCheckboxGroup(unittest.TestCase):
|
||||
"""
|
||||
checkboxes_input = gr.CheckboxGroup(["a", "b", "c"])
|
||||
iface = gr.Interface(lambda x: "|".join(x), checkboxes_input, "textbox")
|
||||
self.assertEqual(iface.process([["a", "c"]])[0], ["a|c"])
|
||||
self.assertEqual(iface.process([[]])[0], [""])
|
||||
self.assertEqual(iface.process([["a", "c"]]), ["a|c"])
|
||||
self.assertEqual(iface.process([[]]), [""])
|
||||
_ = gr.CheckboxGroup(["a", "b", "c"], type="index")
|
||||
|
||||
|
||||
@ -426,12 +426,12 @@ class TestRadio(unittest.TestCase):
|
||||
"""
|
||||
radio_input = gr.Radio(["a", "b", "c"])
|
||||
iface = gr.Interface(lambda x: 2 * x, radio_input, "textbox")
|
||||
self.assertEqual(iface.process(["c"])[0], ["cc"])
|
||||
self.assertEqual(iface.process(["c"]), ["cc"])
|
||||
radio_input = gr.Radio(["a", "b", "c"], type="index")
|
||||
iface = gr.Interface(
|
||||
lambda x: 2 * x, radio_input, "number", interpretation="default"
|
||||
)
|
||||
self.assertEqual(iface.process(["c"])[0], [4])
|
||||
self.assertEqual(iface.process(["c"]), [4])
|
||||
scores = iface.interpret(["b"])[0]["interpretation"]
|
||||
self.assertEqual(scores, [-2.0, None, 2.0])
|
||||
|
||||
@ -557,7 +557,7 @@ class TestImage(unittest.TestCase):
|
||||
gr.Image(shape=(30, 10), type="file"),
|
||||
"image",
|
||||
)
|
||||
output = iface.process([img])[0][0]
|
||||
output = iface.process([img])[0]
|
||||
self.assertEqual(
|
||||
gr.processing_utils.decode_base64_to_image(output).size, (10, 30)
|
||||
)
|
||||
@ -591,9 +591,7 @@ class TestImage(unittest.TestCase):
|
||||
return np.random.randint(0, 256, (width, height, 3))
|
||||
|
||||
iface = gr.Interface(generate_noise, ["slider", "slider"], "image")
|
||||
self.assertTrue(
|
||||
iface.process([10, 20])[0][0].startswith("data:image/png;base64")
|
||||
)
|
||||
self.assertTrue(iface.process([10, 20])[0].startswith("data:image/png;base64"))
|
||||
|
||||
|
||||
class TestAudio(unittest.TestCase):
|
||||
@ -706,16 +704,16 @@ class TestAudio(unittest.TestCase):
|
||||
return (sr, np.flipud(data))
|
||||
|
||||
iface = gr.Interface(reverse_audio, "audio", "audio")
|
||||
reversed_data = iface.process([deepcopy(media_data.BASE64_AUDIO)])[0][0]
|
||||
reversed_data = iface.process([deepcopy(media_data.BASE64_AUDIO)])[0]
|
||||
reversed_input = {"name": "fake_name", "data": reversed_data}
|
||||
self.assertTrue(reversed_data.startswith("data:audio/wav;base64,UklGRgA/"))
|
||||
self.assertTrue(
|
||||
iface.process([deepcopy(media_data.BASE64_AUDIO)])[0][0].startswith(
|
||||
iface.process([deepcopy(media_data.BASE64_AUDIO)])[0].startswith(
|
||||
"data:audio/wav;base64,UklGRgA/"
|
||||
)
|
||||
)
|
||||
self.maxDiff = None
|
||||
reversed_reversed_data = iface.process([reversed_input])[0][0]
|
||||
reversed_reversed_data = iface.process([reversed_input])[0]
|
||||
similarity = SequenceMatcher(
|
||||
a=reversed_reversed_data, b=media_data.BASE64_AUDIO["data"]
|
||||
).ratio()
|
||||
@ -730,7 +728,7 @@ class TestAudio(unittest.TestCase):
|
||||
return 48000, np.random.randint(-256, 256, (duration, 3)).astype(np.int16)
|
||||
|
||||
iface = gr.Interface(generate_noise, "slider", "audio")
|
||||
self.assertTrue(iface.process([100])[0][0].startswith("data:audio/wav;base64"))
|
||||
self.assertTrue(iface.process([100])[0].startswith("data:audio/wav;base64"))
|
||||
|
||||
|
||||
class TestFile(unittest.TestCase):
|
||||
@ -788,7 +786,7 @@ class TestFile(unittest.TestCase):
|
||||
return os.path.getsize(file_obj.name)
|
||||
|
||||
iface = gr.Interface(get_size_of_file, "file", "number")
|
||||
self.assertEqual(iface.process([[x_file]])[0], [10558])
|
||||
self.assertEqual(iface.process([[x_file]]), [10558])
|
||||
|
||||
def test_as_component_as_output(self):
|
||||
"""
|
||||
@ -802,7 +800,7 @@ class TestFile(unittest.TestCase):
|
||||
|
||||
iface = gr.Interface(write_file, "text", "file")
|
||||
self.assertDictEqual(
|
||||
iface.process(["hello world"])[0][0],
|
||||
iface.process(["hello world"])[0],
|
||||
{
|
||||
"name": "test.txt",
|
||||
"size": 11,
|
||||
@ -943,14 +941,14 @@ class TestDataframe(unittest.TestCase):
|
||||
"""
|
||||
x_data = [[1, 2, 3], [4, 5, 6]]
|
||||
iface = gr.Interface(np.max, "numpy", "number")
|
||||
self.assertEqual(iface.process([x_data])[0], [6])
|
||||
self.assertEqual(iface.process([x_data]), [6])
|
||||
x_data = [["Tim"], ["Jon"], ["Sal"]]
|
||||
|
||||
def get_last(my_list):
|
||||
return my_list[-1]
|
||||
|
||||
iface = gr.Interface(get_last, "list", "text")
|
||||
self.assertEqual(iface.process([x_data])[0], ["Sal"])
|
||||
self.assertEqual(iface.process([x_data]), ["Sal"])
|
||||
|
||||
def test_in_interface_as_output(self):
|
||||
"""
|
||||
@ -961,9 +959,7 @@ class TestDataframe(unittest.TestCase):
|
||||
return array % 2 == 0
|
||||
|
||||
iface = gr.Interface(check_odd, "numpy", "numpy")
|
||||
self.assertEqual(
|
||||
iface.process([[2, 3, 4]])[0][0], {"data": [[True, False, True]]}
|
||||
)
|
||||
self.assertEqual(iface.process([[2, 3, 4]])[0], {"data": [[True, False, True]]})
|
||||
|
||||
|
||||
class TestVideo(unittest.TestCase):
|
||||
@ -1034,7 +1030,7 @@ class TestVideo(unittest.TestCase):
|
||||
"""
|
||||
x_video = deepcopy(media_data.BASE64_VIDEO)
|
||||
iface = gr.Interface(lambda x: x, "video", "playable_video")
|
||||
self.assertEqual(iface.process([x_video])[0][0]["data"], x_video["data"])
|
||||
self.assertEqual(iface.process([x_video])[0]["data"], x_video["data"])
|
||||
|
||||
|
||||
class TestTimeseries(unittest.TestCase):
|
||||
@ -1147,7 +1143,7 @@ class TestTimeseries(unittest.TestCase):
|
||||
}
|
||||
iface = gr.Interface(lambda x: x, timeseries_input, "dataframe")
|
||||
self.assertEqual(
|
||||
iface.process([x_timeseries])[0],
|
||||
iface.process([x_timeseries]),
|
||||
[
|
||||
{
|
||||
"headers": ["time", "retail", "food", "other"],
|
||||
@ -1176,7 +1172,7 @@ class TestTimeseries(unittest.TestCase):
|
||||
}
|
||||
)
|
||||
self.assertEqual(
|
||||
iface.process([df])[0],
|
||||
iface.process([df]),
|
||||
[
|
||||
{
|
||||
"headers": ["time", "retail", "food", "other"],
|
||||
@ -1283,7 +1279,7 @@ class TestLabel(unittest.TestCase):
|
||||
}
|
||||
|
||||
iface = gr.Interface(rgb_distribution, "image", "label")
|
||||
output = iface.process([x_img])[0][0]
|
||||
output = iface.process([x_img])[0]
|
||||
self.assertDictEqual(
|
||||
output,
|
||||
{
|
||||
@ -1346,7 +1342,7 @@ class TestHighlightedText(unittest.TestCase):
|
||||
|
||||
iface = gr.Interface(highlight_vowels, "text", "highlight")
|
||||
self.assertListEqual(
|
||||
iface.process(["Helloooo"])[0][0],
|
||||
iface.process(["Helloooo"])[0],
|
||||
[("H", "non"), ("e", "vowel"), ("ll", "non"), ("oooo", "vowel")],
|
||||
)
|
||||
|
||||
@ -1403,7 +1399,7 @@ class TestJSON(unittest.TestCase):
|
||||
["O", 20],
|
||||
["F", 30],
|
||||
]
|
||||
self.assertDictEqual(iface.process([y_data])[0][0], {"M": 35, "F": 25, "O": 20})
|
||||
self.assertDictEqual(iface.process([y_data])[0], {"M": 35, "F": 25, "O": 20})
|
||||
|
||||
|
||||
class TestHTML(unittest.TestCase):
|
||||
@ -1432,7 +1428,7 @@ class TestHTML(unittest.TestCase):
|
||||
return "<strong>" + text + "</strong>"
|
||||
|
||||
iface = gr.Interface(bold_text, "text", "html")
|
||||
self.assertEqual(iface.process(["test"])[0][0], "<strong>test</strong>")
|
||||
self.assertEqual(iface.process(["test"])[0], "<strong>test</strong>")
|
||||
|
||||
|
||||
class TestCarousel(unittest.TestCase):
|
||||
@ -1511,17 +1507,17 @@ class TestCarousel(unittest.TestCase):
|
||||
|
||||
iface = gr.Interface(report, gr.inputs.Image(type="numpy"), carousel_output)
|
||||
result = iface.process([deepcopy(media_data.BASE64_IMAGE)])
|
||||
self.assertTrue(result[0][0][0][0] == "Red")
|
||||
self.assertTrue(result[0][0][0] == "Red")
|
||||
self.assertTrue(
|
||||
result[0][0][0][1].startswith("data:image/png;base64,iVBORw0KGgoAAA")
|
||||
result[0][0][1].startswith("data:image/png;base64,iVBORw0KGgoAAA")
|
||||
)
|
||||
self.assertTrue(result[0][0][1][0] == "Green")
|
||||
self.assertTrue(result[0][1][0] == "Green")
|
||||
self.assertTrue(
|
||||
result[0][0][1][1].startswith("data:image/png;base64,iVBORw0KGgoAAA")
|
||||
result[0][1][1].startswith("data:image/png;base64,iVBORw0KGgoAAA")
|
||||
)
|
||||
self.assertTrue(result[0][0][2][0] == "Blue")
|
||||
self.assertTrue(result[0][2][0] == "Blue")
|
||||
self.assertTrue(
|
||||
result[0][0][2][1].startswith("data:image/png;base64,iVBORw0KGgoAAA")
|
||||
result[0][2][1].startswith("data:image/png;base64,iVBORw0KGgoAAA")
|
||||
)
|
||||
|
||||
|
||||
|
@ -66,7 +66,7 @@ class TestTextbox(unittest.TestCase):
|
||||
|
||||
def test_in_interface(self):
|
||||
iface = gr.Interface(lambda x: x[::-1], "textbox", "textbox")
|
||||
self.assertEqual(iface.process(["Hello"])[0], ["olleH"])
|
||||
self.assertEqual(iface.process(["Hello"]), ["olleH"])
|
||||
iface = gr.Interface(
|
||||
lambda sentence: max([len(word) for word in sentence.split()]),
|
||||
gr.inputs.Textbox(),
|
||||
@ -139,7 +139,7 @@ class TestNumber(unittest.TestCase):
|
||||
|
||||
def test_in_interface(self):
|
||||
iface = gr.Interface(lambda x: x**2, "number", "textbox")
|
||||
self.assertEqual(iface.process([2])[0], ["4.0"])
|
||||
self.assertEqual(iface.process([2]), ["4.0"])
|
||||
iface = gr.Interface(
|
||||
lambda x: x**2, "number", "number", interpretation="default"
|
||||
)
|
||||
@ -190,7 +190,7 @@ class TestSlider(unittest.TestCase):
|
||||
|
||||
def test_in_interface(self):
|
||||
iface = gr.Interface(lambda x: x**2, "slider", "textbox")
|
||||
self.assertEqual(iface.process([2])[0], ["4"])
|
||||
self.assertEqual(iface.process([2]), ["4"])
|
||||
iface = gr.Interface(
|
||||
lambda x: x**2, "slider", "number", interpretation="default"
|
||||
)
|
||||
@ -236,7 +236,7 @@ class TestCheckbox(unittest.TestCase):
|
||||
|
||||
def test_in_interface(self):
|
||||
iface = gr.Interface(lambda x: 1 if x else 0, "checkbox", "number")
|
||||
self.assertEqual(iface.process([True])[0], [1])
|
||||
self.assertEqual(iface.process([True]), [1])
|
||||
iface = gr.Interface(
|
||||
lambda x: 1 if x else 0, "checkbox", "number", interpretation="default"
|
||||
)
|
||||
@ -281,8 +281,8 @@ class TestCheckboxGroup(unittest.TestCase):
|
||||
def test_in_interface(self):
|
||||
checkboxes_input = gr.inputs.CheckboxGroup(["a", "b", "c"])
|
||||
iface = gr.Interface(lambda x: "|".join(x), checkboxes_input, "textbox")
|
||||
self.assertEqual(iface.process([["a", "c"]])[0], ["a|c"])
|
||||
self.assertEqual(iface.process([[]])[0], [""])
|
||||
self.assertEqual(iface.process([["a", "c"]]), ["a|c"])
|
||||
self.assertEqual(iface.process([[]]), [""])
|
||||
checkboxes_input = gr.inputs.CheckboxGroup(["a", "b", "c"], type="index")
|
||||
|
||||
|
||||
@ -319,12 +319,12 @@ class TestRadio(unittest.TestCase):
|
||||
def test_in_interface(self):
|
||||
radio_input = gr.inputs.Radio(["a", "b", "c"])
|
||||
iface = gr.Interface(lambda x: 2 * x, radio_input, "textbox")
|
||||
self.assertEqual(iface.process(["c"])[0], ["cc"])
|
||||
self.assertEqual(iface.process(["c"]), ["cc"])
|
||||
radio_input = gr.inputs.Radio(["a", "b", "c"], type="index")
|
||||
iface = gr.Interface(
|
||||
lambda x: 2 * x, radio_input, "number", interpretation="default"
|
||||
)
|
||||
self.assertEqual(iface.process(["c"])[0], [4])
|
||||
self.assertEqual(iface.process(["c"]), [4])
|
||||
scores = iface.interpret(["b"])[0]["interpretation"]
|
||||
self.assertEqual(scores, [-2.0, None, 2.0])
|
||||
|
||||
@ -364,12 +364,12 @@ class TestDropdown(unittest.TestCase):
|
||||
def test_in_interface(self):
|
||||
dropdown_input = gr.inputs.Dropdown(["a", "b", "c"])
|
||||
iface = gr.Interface(lambda x: 2 * x, dropdown_input, "textbox")
|
||||
self.assertEqual(iface.process(["c"])[0], ["cc"])
|
||||
self.assertEqual(iface.process(["c"]), ["cc"])
|
||||
dropdown = gr.inputs.Dropdown(["a", "b", "c"], type="index")
|
||||
iface = gr.Interface(
|
||||
lambda x: 2 * x, dropdown, "number", interpretation="default"
|
||||
)
|
||||
self.assertEqual(iface.process(["c"])[0], [4])
|
||||
self.assertEqual(iface.process(["c"]), [4])
|
||||
scores = iface.interpret(["b"])[0]["interpretation"]
|
||||
self.assertEqual(scores, [-2.0, None, 2.0])
|
||||
|
||||
@ -445,7 +445,7 @@ class TestImage(unittest.TestCase):
|
||||
gr.inputs.Image(shape=(30, 10), type="file"),
|
||||
"image",
|
||||
)
|
||||
output = iface.process([img])[0][0]
|
||||
output = iface.process([img])[0]
|
||||
self.assertEqual(
|
||||
gr.processing_utils.decode_base64_to_image(output).size, (10, 30)
|
||||
)
|
||||
@ -574,7 +574,7 @@ class TestFile(unittest.TestCase):
|
||||
return os.path.getsize(file_obj.name)
|
||||
|
||||
iface = gr.Interface(get_size_of_file, "file", "number")
|
||||
self.assertEqual(iface.process([[x_file]])[0], [10558])
|
||||
self.assertEqual(iface.process([[x_file]]), [10558])
|
||||
|
||||
|
||||
class TestDataframe(unittest.TestCase):
|
||||
@ -631,14 +631,14 @@ class TestDataframe(unittest.TestCase):
|
||||
def test_in_interface(self):
|
||||
x_data = [[1, 2, 3], [4, 5, 6]]
|
||||
iface = gr.Interface(np.max, "numpy", "number")
|
||||
self.assertEqual(iface.process([x_data])[0], [6])
|
||||
self.assertEqual(iface.process([x_data]), [6])
|
||||
x_data = [["Tim"], ["Jon"], ["Sal"]]
|
||||
|
||||
def get_last(my_list):
|
||||
return my_list[-1]
|
||||
|
||||
iface = gr.Interface(get_last, "list", "text")
|
||||
self.assertEqual(iface.process([x_data])[0], ["Sal"])
|
||||
self.assertEqual(iface.process([x_data]), ["Sal"])
|
||||
|
||||
|
||||
class TestVideo(unittest.TestCase):
|
||||
@ -680,7 +680,7 @@ class TestVideo(unittest.TestCase):
|
||||
def test_in_interface(self):
|
||||
x_video = media_data.BASE64_VIDEO
|
||||
iface = gr.Interface(lambda x: x, "video", "playable_video")
|
||||
self.assertEqual(iface.process([x_video])[0][0]["data"], x_video["data"])
|
||||
self.assertEqual(iface.process([x_video])[0]["data"], x_video["data"])
|
||||
|
||||
|
||||
class TestTimeseries(unittest.TestCase):
|
||||
@ -729,7 +729,7 @@ class TestTimeseries(unittest.TestCase):
|
||||
}
|
||||
iface = gr.Interface(lambda x: x, timeseries_input, "dataframe")
|
||||
self.assertEqual(
|
||||
iface.process([x_timeseries])[0],
|
||||
iface.process([x_timeseries]),
|
||||
[
|
||||
{
|
||||
"headers": ["time", "retail", "food", "other"],
|
||||
|
@ -17,7 +17,7 @@ class TestSeries(unittest.TestCase):
|
||||
io1 = gr.Interface(lambda x: x + " World", "textbox", gr.outputs.Textbox())
|
||||
io2 = gr.Interface(lambda x: x + "!", "textbox", gr.outputs.Textbox())
|
||||
series = mix.Series(io1, io2)
|
||||
self.assertEqual(series.process(["Hello"])[0], ["Hello World!"])
|
||||
self.assertEqual(series.process(["Hello"]), ["Hello World!"])
|
||||
|
||||
def test_with_external(self):
|
||||
io1 = gr.Interface.load("spaces/abidlabs/image-identity")
|
||||
@ -33,7 +33,7 @@ class TestParallel(unittest.TestCase):
|
||||
io2 = gr.Interface(lambda x: x + " World 2!", "textbox", gr.outputs.Textbox())
|
||||
parallel = mix.Parallel(io1, io2)
|
||||
self.assertEqual(
|
||||
parallel.process(["Hello"])[0], ["Hello World 1!", "Hello World 2!"]
|
||||
parallel.process(["Hello"]), ["Hello World 1!", "Hello World 2!"]
|
||||
)
|
||||
|
||||
def test_with_external(self):
|
||||
|
@ -19,9 +19,9 @@ os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
||||
class TestTextbox(unittest.TestCase):
|
||||
def test_in_interface(self):
|
||||
iface = gr.Interface(lambda x: x[-1], "textbox", gr.outputs.Textbox())
|
||||
self.assertEqual(iface.process(["Hello"])[0], ["o"])
|
||||
self.assertEqual(iface.process(["Hello"]), ["o"])
|
||||
iface = gr.Interface(lambda x: x / 2, "number", gr.outputs.Textbox())
|
||||
self.assertEqual(iface.process([10])[0], ["5.0"])
|
||||
self.assertEqual(iface.process([10]), ["5.0"])
|
||||
|
||||
|
||||
class TestLabel(unittest.TestCase):
|
||||
@ -92,7 +92,7 @@ class TestLabel(unittest.TestCase):
|
||||
}
|
||||
|
||||
iface = gr.Interface(rgb_distribution, "image", "label")
|
||||
output = iface.process([x_img])[0][0]
|
||||
output = iface.process([x_img])[0]
|
||||
self.assertDictEqual(
|
||||
output,
|
||||
{
|
||||
@ -151,9 +151,7 @@ class TestImage(unittest.TestCase):
|
||||
return np.random.randint(0, 256, (width, height, 3))
|
||||
|
||||
iface = gr.Interface(generate_noise, ["slider", "slider"], "image")
|
||||
self.assertTrue(
|
||||
iface.process([10, 20])[0][0].startswith("data:image/png;base64")
|
||||
)
|
||||
self.assertTrue(iface.process([10, 20])[0].startswith("data:image/png;base64"))
|
||||
|
||||
|
||||
class TestVideo(unittest.TestCase):
|
||||
@ -219,7 +217,7 @@ class TestHighlightedText(unittest.TestCase):
|
||||
|
||||
iface = gr.Interface(highlight_vowels, "text", "highlight")
|
||||
self.assertListEqual(
|
||||
iface.process(["Helloooo"])[0][0],
|
||||
iface.process(["Helloooo"])[0],
|
||||
[("H", "non"), ("e", "vowel"), ("ll", "non"), ("oooo", "vowel")],
|
||||
)
|
||||
|
||||
@ -264,7 +262,7 @@ class TestAudio(unittest.TestCase):
|
||||
return 48000, np.random.randint(-256, 256, (duration, 3)).astype(np.int32)
|
||||
|
||||
iface = gr.Interface(generate_noise, "slider", "audio")
|
||||
self.assertTrue(iface.process([100])[0][0].startswith("data:audio/wav;base64"))
|
||||
self.assertTrue(iface.process([100])[0].startswith("data:audio/wav;base64"))
|
||||
|
||||
|
||||
class TestJSON(unittest.TestCase):
|
||||
@ -302,7 +300,7 @@ class TestJSON(unittest.TestCase):
|
||||
["O", 20],
|
||||
["F", 30],
|
||||
]
|
||||
self.assertDictEqual(iface.process([y_data])[0][0], {"M": 35, "F": 25, "O": 20})
|
||||
self.assertDictEqual(iface.process([y_data])[0], {"M": 35, "F": 25, "O": 20})
|
||||
|
||||
|
||||
class TestHTML(unittest.TestCase):
|
||||
@ -311,7 +309,7 @@ class TestHTML(unittest.TestCase):
|
||||
return "<strong>" + text + "</strong>"
|
||||
|
||||
iface = gr.Interface(bold_text, "text", "html")
|
||||
self.assertEqual(iface.process(["test"])[0][0], "<strong>test</strong>")
|
||||
self.assertEqual(iface.process(["test"])[0], "<strong>test</strong>")
|
||||
|
||||
|
||||
class TestFile(unittest.TestCase):
|
||||
@ -323,7 +321,7 @@ class TestFile(unittest.TestCase):
|
||||
|
||||
iface = gr.Interface(write_file, "text", "file")
|
||||
self.assertDictEqual(
|
||||
iface.process(["hello world"])[0][0],
|
||||
iface.process(["hello world"])[0],
|
||||
{
|
||||
"name": "test.txt",
|
||||
"size": 11,
|
||||
@ -408,9 +406,7 @@ class TestDataframe(unittest.TestCase):
|
||||
return array % 2 == 0
|
||||
|
||||
iface = gr.Interface(check_odd, "numpy", "numpy")
|
||||
self.assertEqual(
|
||||
iface.process([[2, 3, 4]])[0][0], {"data": [[True, False, True]]}
|
||||
)
|
||||
self.assertEqual(iface.process([[2, 3, 4]])[0], {"data": [[True, False, True]]})
|
||||
|
||||
|
||||
class TestCarousel(unittest.TestCase):
|
||||
@ -477,7 +473,7 @@ class TestCarousel(unittest.TestCase):
|
||||
|
||||
iface = gr.Interface(report, gr.inputs.Image(type="numpy"), carousel_output)
|
||||
self.assertEqual(
|
||||
iface.process([media_data.BASE64_IMAGE])[0],
|
||||
iface.process([media_data.BASE64_IMAGE]),
|
||||
[
|
||||
[
|
||||
[
|
||||
|
@ -9,7 +9,7 @@ os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
||||
class TestProcessExamples(unittest.TestCase):
|
||||
def test_process_example(self):
|
||||
io = Interface(lambda x: "Hello " + x, "text", "text", examples=[["World"]])
|
||||
prediction, _ = process_examples.process_example(io, 0)
|
||||
prediction = process_examples.process_example(io, 0)
|
||||
self.assertEquals(prediction[0], "Hello World")
|
||||
|
||||
def test_caching(self):
|
||||
|
@ -29,6 +29,8 @@
|
||||
inputs: Array<number>;
|
||||
outputs: Array<number>;
|
||||
queue: boolean;
|
||||
status_tracker: number | null;
|
||||
status?: string;
|
||||
}
|
||||
|
||||
export let root: string;
|
||||
@ -113,6 +115,15 @@
|
||||
});
|
||||
|
||||
let handled_dependencies: Array<number[]> = [];
|
||||
let status_tracker_values: Record<number, string> = {};
|
||||
|
||||
let set_status = (dependency_index: number, status: string) => {
|
||||
dependencies[dependency_index].status = status;
|
||||
let status_tracker_id = dependencies[dependency_index].status_tracker;
|
||||
if (status_tracker_id !== null) {
|
||||
status_tracker_values[status_tracker_id] = status;
|
||||
}
|
||||
};
|
||||
|
||||
async function handle_mount({ detail }) {
|
||||
await tick();
|
||||
@ -139,11 +150,15 @@
|
||||
},
|
||||
queue,
|
||||
() => {}
|
||||
).then((output) => {
|
||||
output.data.forEach((value, i) => {
|
||||
instance_map[outputs[i]].value = value;
|
||||
)
|
||||
.then((output) => {
|
||||
output.data.forEach((value, i) => {
|
||||
instance_map[outputs[i]].value = value;
|
||||
});
|
||||
})
|
||||
.catch((error) => {
|
||||
console.error(error);
|
||||
});
|
||||
});
|
||||
|
||||
handled_dependencies[i] = [-1];
|
||||
}
|
||||
@ -153,6 +168,10 @@
|
||||
if (handled_dependencies[i]?.includes(id) || !instance) return;
|
||||
// console.log(trigger, target_instances, instance);
|
||||
instance?.$on(trigger, () => {
|
||||
if (status === "pending") {
|
||||
return;
|
||||
}
|
||||
set_status(i, "pending");
|
||||
fn(
|
||||
"predict",
|
||||
{
|
||||
@ -161,11 +180,17 @@
|
||||
},
|
||||
queue,
|
||||
() => {}
|
||||
).then((output) => {
|
||||
output.data.forEach((value, i) => {
|
||||
instance_map[outputs[i]].value = value;
|
||||
)
|
||||
.then((output) => {
|
||||
set_status(i, "complete");
|
||||
output.data.forEach((value, i) => {
|
||||
instance_map[outputs[i]].value = value;
|
||||
});
|
||||
})
|
||||
.catch((error) => {
|
||||
set_status(i, "error");
|
||||
console.error(error);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
if (!handled_dependencies[i]) handled_dependencies[i] = [];
|
||||
@ -194,6 +219,7 @@
|
||||
{instance_map}
|
||||
{theme}
|
||||
{root}
|
||||
{status_tracker_values}
|
||||
on:mount={handle_mount}
|
||||
on:destroy={({ detail }) => handle_destroy(detail)}
|
||||
/>
|
||||
|
@ -10,6 +10,7 @@
|
||||
export let theme;
|
||||
export let dynamic_ids: Set<number>;
|
||||
export let has_modes: boolean;
|
||||
export let status_tracker_values: Record<number, string>;
|
||||
|
||||
const dispatch = createEventDispatcher<{ mount: number; destroy: number }>();
|
||||
|
||||
@ -46,6 +47,7 @@
|
||||
{style}
|
||||
{...props}
|
||||
{root}
|
||||
tracked_status={status_tracker_values[id]}
|
||||
>
|
||||
{#if children && children.length}
|
||||
{#each children as { component, id, props, children, has_modes }}
|
||||
@ -59,6 +61,7 @@
|
||||
{children}
|
||||
{dynamic_ids}
|
||||
{has_modes}
|
||||
{status_tracker_values}
|
||||
on:destroy
|
||||
on:mount
|
||||
/>
|
||||
|
@ -6,6 +6,6 @@
|
||||
if (default_value) value = default_value;
|
||||
</script>
|
||||
|
||||
<div {style} class:hidden={!value} class="flex flex-1 flex-col gap-4">
|
||||
<div {style} class:hidden={!value} class="flex flex-1 flex-col gap-4 relative">
|
||||
<slot />
|
||||
</div>
|
||||
|
@ -6,6 +6,6 @@
|
||||
if (default_value) value = default_value;
|
||||
</script>
|
||||
|
||||
<div {style} class:hidden={!value} class="flex flex-row gap-4">
|
||||
<div {style} class:hidden={!value} class="flex flex-row gap-4 relative">
|
||||
<slot />
|
||||
</div>
|
||||
|
@ -0,0 +1,100 @@
|
||||
<script lang="ts">
|
||||
import { onDestroy } from "svelte";
|
||||
|
||||
export let style: string = "";
|
||||
export let cover_container: bool = false;
|
||||
export let eta: number | null = null;
|
||||
export let duration: number = 8.2;
|
||||
export let queue_pos: number | null = 0;
|
||||
export let tracked_status: "complete" | "pending" | "error";
|
||||
|
||||
$: progress = eta === null ? null : Math.min(duration / eta, 1);
|
||||
|
||||
let timer: NodeJS.Timeout = null;
|
||||
let timer_start = 0;
|
||||
let timer_diff = 0;
|
||||
|
||||
const start_timer = () => {
|
||||
timer_start = Date.now();
|
||||
timer_diff = 0;
|
||||
timer = setInterval(() => {
|
||||
timer_diff = (Date.now() - timer_start) / 1000;
|
||||
}, 100);
|
||||
};
|
||||
|
||||
const stop_timer = () => {
|
||||
clearInterval(timer);
|
||||
};
|
||||
|
||||
onDestroy(() => {
|
||||
if (timer) stop_timer();
|
||||
});
|
||||
|
||||
$: {
|
||||
if (tracked_status === "pending") {
|
||||
start_timer();
|
||||
} else {
|
||||
stop_timer();
|
||||
}
|
||||
}
|
||||
</script>
|
||||
|
||||
{#if tracked_status === "pending"}
|
||||
<div class:cover_container {style}>
|
||||
<div class="text-xs font-mono text-gray-400">
|
||||
{#if queue_pos}
|
||||
{queue_pos} in line
|
||||
{:else}
|
||||
{timer_diff.toFixed(1)}s
|
||||
{/if}
|
||||
</div>
|
||||
<div class="border-gray-200 rounded border w-40 h-2 relative">
|
||||
{#if progress === null}
|
||||
<div class="bounce absolute bg-amber-500 shadow-inner h-full w-1/4" />
|
||||
{:else}
|
||||
<div
|
||||
class="blink bg-amber-500 shadow-inner h-full"
|
||||
style="width: {progress * 100}%;"
|
||||
/>
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
{:else if tracked_status === "error"}
|
||||
<div class:cover_container {style}>
|
||||
<span class="text-red-400 font-mono font-semibold text-lg">ERROR</span>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<style lang="postcss">
|
||||
.cover_container {
|
||||
@apply absolute top-0 left-0 w-full h-full z-10 flex flex-col justify-center items-center bg-gray-100 bg-opacity-25;
|
||||
}
|
||||
@keyframes blink {
|
||||
0% {
|
||||
opacity: 100%;
|
||||
}
|
||||
50% {
|
||||
opacity: 60%;
|
||||
}
|
||||
100% {
|
||||
opacity: 100%;
|
||||
}
|
||||
}
|
||||
.blink {
|
||||
animation: blink 2s infinite;
|
||||
}
|
||||
@keyframes bounce {
|
||||
0% {
|
||||
left: 0%;
|
||||
}
|
||||
50% {
|
||||
left: 75%;
|
||||
}
|
||||
100% {
|
||||
left: 0%;
|
||||
}
|
||||
}
|
||||
.bounce {
|
||||
animation: bounce 2s infinite linear;
|
||||
}
|
||||
</style>
|
2
ui/packages/app/src/components/StatusTracker/index.ts
Normal file
2
ui/packages/app/src/components/StatusTracker/index.ts
Normal file
@ -0,0 +1,2 @@
|
||||
export { default as Component } from "./StatusTracker.svelte";
|
||||
export const modes = ["static"];
|
@ -11,6 +11,7 @@ export const component_map: Record<string, any> = {
|
||||
dataset: () => import("./Dataset"),
|
||||
dropdown: () => import("./Dropdown"),
|
||||
file: () => import("./File"),
|
||||
statustracker: () => import("./StatusTracker"),
|
||||
highlightedtext: () => import("./HighlightedText"),
|
||||
html: () => import("./HTML"),
|
||||
image: () => import("./Image"),
|
||||
|
@ -15,7 +15,7 @@
|
||||
</script>
|
||||
|
||||
{#if $selected_tab === id}
|
||||
<div class="p-2 border-2 border-t-0 border-gray-200">
|
||||
<div class="p-2 border-2 border-t-0 border-gray-200 relative">
|
||||
<slot />
|
||||
</div>
|
||||
{/if}
|
||||
|
Loading…
x
Reference in New Issue
Block a user