mirror of
https://github.com/gradio-app/gradio.git
synced 2024-11-27 01:40:20 +08:00
changes
This commit is contained in:
parent
398f5560d1
commit
0ef247aa7a
@ -55,8 +55,9 @@ class Block:
|
||||
|
||||
|
||||
class BlockContext(Block):
|
||||
def __init__(self):
|
||||
def __init__(self, css: Optional[str] = None):
|
||||
self.children = []
|
||||
self.css = css if css is not None else {}
|
||||
super().__init__()
|
||||
|
||||
def __enter__(self):
|
||||
@ -66,18 +67,36 @@ class BlockContext(Block):
|
||||
def __exit__(self, *args):
|
||||
Context.block = self.parent
|
||||
|
||||
def get_template_context(self):
|
||||
return {
|
||||
"css": self.css
|
||||
}
|
||||
|
||||
|
||||
|
||||
class Row(BlockContext):
|
||||
def __init__(self, css: Optional[str] = None):
|
||||
super().__init__(css)
|
||||
|
||||
def get_template_context(self):
|
||||
return {"type": "row"}
|
||||
return {"type": "row", **super().get_template_context()}
|
||||
|
||||
|
||||
class Column(BlockContext):
|
||||
def __init__(self, css: Optional[str] = None):
|
||||
super().__init__(css)
|
||||
|
||||
def get_template_context(self):
|
||||
return {"type": "column"}
|
||||
return {
|
||||
"type": "column",
|
||||
**super().get_template_context(),
|
||||
}
|
||||
|
||||
|
||||
class Tabs(BlockContext):
|
||||
def __init__(self, css: Optional[str] = None):
|
||||
super().__init__(css)
|
||||
|
||||
def change(self, fn: Callable, inputs: List[Component], outputs: List[Component]):
|
||||
"""
|
||||
Parameters:
|
||||
@ -90,12 +109,13 @@ class Tabs(BlockContext):
|
||||
|
||||
|
||||
class TabItem(BlockContext):
|
||||
def __init__(self, label):
|
||||
def __init__(self, label, css: Optional[str] = None):
|
||||
super().__init__(css)
|
||||
self.label = label
|
||||
super(TabItem, self).__init__()
|
||||
|
||||
def get_template_context(self):
|
||||
return {"label": self.label}
|
||||
return {"label": self.label, **super().get_template_context()}
|
||||
|
||||
def change(self, fn: Callable, inputs: List[Component], outputs: List[Component]):
|
||||
"""
|
||||
|
@ -2765,21 +2765,25 @@ class Button(Component):
|
||||
class DatasetViewer(Component):
|
||||
def __init__(
|
||||
self,
|
||||
types: List[Component],
|
||||
default_value: List[List[Any]],
|
||||
*,
|
||||
components: List[Component],
|
||||
samples: List[List[Any]],
|
||||
value: Optional[Number] = None,
|
||||
label: Optional[str] = None,
|
||||
css: Optional[Dict] = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(label=label, css=css, **kwargs)
|
||||
self.types = types
|
||||
self.value = default_value
|
||||
self.components = components
|
||||
self.headers = [c.label for c in components]
|
||||
self.samples = samples
|
||||
self.value = value
|
||||
|
||||
def get_template_context(self):
|
||||
return {
|
||||
"types": [_type.__class__.__name__.lower() for _type in types],
|
||||
"value": self.value,
|
||||
"components": [component.__class__.__name__.lower() for component in self.components],
|
||||
"headers": self.headers,
|
||||
"samples": self.samples,
|
||||
**super().get_template_context(),
|
||||
}
|
||||
|
||||
|
@ -10,7 +10,11 @@ from typing import Any, List, Optional
|
||||
|
||||
import gradio as gr
|
||||
from gradio import encryptor
|
||||
from gradio import components
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from gradio.components import Component
|
||||
|
||||
class FlaggingCallback(ABC):
|
||||
"""
|
||||
@ -18,7 +22,7 @@ class FlaggingCallback(ABC):
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def setup(self, flagging_dir: str):
|
||||
def setup(self, components: List[Component], flagging_dir: str):
|
||||
"""
|
||||
This method should be overridden and ensure that everything is set up correctly for flag().
|
||||
This method gets called once at the beginning of the Interface.launch() method.
|
||||
@ -30,9 +34,7 @@ class FlaggingCallback(ABC):
|
||||
@abstractmethod
|
||||
def flag(
|
||||
self,
|
||||
interface: gr.Interface,
|
||||
input_data: List[Any],
|
||||
output_data: List[Any],
|
||||
flag_data: List[Any],
|
||||
flag_option: Optional[str] = None,
|
||||
flag_index: Optional[int] = None,
|
||||
username: Optional[str] = None,
|
||||
@ -42,8 +44,7 @@ class FlaggingCallback(ABC):
|
||||
This gets called every time the <flag> button is pressed.
|
||||
Parameters:
|
||||
interface: The Interface object that is being used to launch the flagging interface.
|
||||
input_data: The input data to be flagged.
|
||||
output_data: The output data to be flagged.
|
||||
flag_data: The data to be flagged.
|
||||
flag_option (optional): In the case that flagging_options are provided, the flag option that is being used.
|
||||
flag_index (optional): The index of the sample that is being flagged.
|
||||
username (optional): The username of the user that is flagging the data, if logged in.
|
||||
@ -59,15 +60,14 @@ class SimpleCSVLogger(FlaggingCallback):
|
||||
provided for illustrative purposes.
|
||||
"""
|
||||
|
||||
def setup(self, flagging_dir: str):
|
||||
def setup(self, components: List[Component], flagging_dir: str):
|
||||
self.components = components
|
||||
self.flagging_dir = flagging_dir
|
||||
os.makedirs(flagging_dir, exist_ok=True)
|
||||
|
||||
def flag(
|
||||
self,
|
||||
interface: gr.Interface,
|
||||
input_data: List[Any],
|
||||
output_data: List[Any],
|
||||
flag_data: List[Any],
|
||||
flag_option: Optional[str] = None,
|
||||
flag_index: Optional[int] = None,
|
||||
username: Optional[str] = None,
|
||||
@ -76,26 +76,15 @@ class SimpleCSVLogger(FlaggingCallback):
|
||||
log_filepath = "{}/log.csv".format(flagging_dir)
|
||||
|
||||
csv_data = []
|
||||
for i, input in enumerate(interface.input_components):
|
||||
for component, sample in zip(self.components, flag_data):
|
||||
csv_data.append(
|
||||
input.save_flagged(
|
||||
component.save_flagged(
|
||||
flagging_dir,
|
||||
interface.config["input_components"][i]["label"],
|
||||
input_data[i],
|
||||
component.label,
|
||||
sample,
|
||||
None,
|
||||
)
|
||||
)
|
||||
for i, output in enumerate(interface.output_components):
|
||||
csv_data.append(
|
||||
output.save_flagged(
|
||||
flagging_dir,
|
||||
interface.config["output_components"][i]["label"],
|
||||
output_data[i],
|
||||
None,
|
||||
)
|
||||
if output_data[i] is not None
|
||||
else ""
|
||||
)
|
||||
|
||||
with open(log_filepath, "a", newline="") as csvfile:
|
||||
writer = csv.writer(csvfile)
|
||||
@ -112,71 +101,44 @@ class CSVLogger(FlaggingCallback):
|
||||
Logs the input and output data to a CSV file. Supports encryption.
|
||||
"""
|
||||
|
||||
def setup(self, flagging_dir: str):
|
||||
def setup(self, components: List[Component], flagging_dir: str, encryption_key: Optional[str] = None):
|
||||
self.components = components
|
||||
self.flagging_dir = flagging_dir
|
||||
self.encryption_key = encryption_key
|
||||
os.makedirs(flagging_dir, exist_ok=True)
|
||||
|
||||
def flag(
|
||||
self,
|
||||
interface: gr.Interface,
|
||||
input_data: List[Any],
|
||||
output_data: List[Any],
|
||||
flag_data: List[Any],
|
||||
flag_option: Optional[str] = None,
|
||||
flag_index: Optional[int] = None,
|
||||
username: Optional[str] = None,
|
||||
) -> int:
|
||||
flagging_dir = self.flagging_dir
|
||||
log_fp = "{}/log.csv".format(flagging_dir)
|
||||
encryption_key = interface.encryption_key if interface.encrypt else None
|
||||
is_new = not os.path.exists(log_fp)
|
||||
output_only_mode = input_data is None
|
||||
|
||||
if flag_index is None:
|
||||
csv_data = []
|
||||
if not output_only_mode:
|
||||
for i, input in enumerate(interface.input_components):
|
||||
csv_data.append(
|
||||
input.save_flagged(
|
||||
flagging_dir,
|
||||
interface.config["input_components"][i]["label"],
|
||||
input_data[i],
|
||||
encryption_key,
|
||||
)
|
||||
)
|
||||
for i, output in enumerate(interface.output_components):
|
||||
for component, sample in zip(self.components, flag_data):
|
||||
csv_data.append(
|
||||
output.save_flagged(
|
||||
component.save_flagged(
|
||||
flagging_dir,
|
||||
interface.config["output_components"][i]["label"],
|
||||
output_data[i],
|
||||
encryption_key,
|
||||
component.label,
|
||||
sample,
|
||||
self.encryption_key,
|
||||
)
|
||||
if output_data[i] is not None
|
||||
if sample is not None
|
||||
else ""
|
||||
)
|
||||
if not output_only_mode:
|
||||
if flag_option is not None:
|
||||
csv_data.append(flag_option)
|
||||
if username is not None:
|
||||
csv_data.append(username)
|
||||
csv_data.append(str(datetime.datetime.now()))
|
||||
csv_data.append(flag_option if flag_option is not None else "")
|
||||
csv_data.append(username if username is not None else "")
|
||||
csv_data.append(str(datetime.datetime.now()))
|
||||
if is_new:
|
||||
headers = []
|
||||
if not output_only_mode:
|
||||
headers += [
|
||||
interface["label"]
|
||||
for interface in interface.config["input_components"]
|
||||
]
|
||||
headers += [
|
||||
interface["label"]
|
||||
for interface in interface.config["output_components"]
|
||||
]
|
||||
if not output_only_mode:
|
||||
if interface.flagging_options is not None:
|
||||
headers.append("flag")
|
||||
if username is not None:
|
||||
headers.append("username")
|
||||
headers.append("timestamp")
|
||||
headers = [
|
||||
component.label
|
||||
for component in self.components
|
||||
] + ["flag", "username", "timestamp"]
|
||||
|
||||
def replace_flag_at_index(file_content):
|
||||
file_content = io.StringIO(file_content)
|
||||
@ -189,13 +151,13 @@ class CSVLogger(FlaggingCallback):
|
||||
writer.writerows(content)
|
||||
return output.getvalue()
|
||||
|
||||
if interface.encrypt:
|
||||
if self.encryption_key:
|
||||
output = io.StringIO()
|
||||
if not is_new:
|
||||
with open(log_fp, "rb") as csvfile:
|
||||
encrypted_csv = csvfile.read()
|
||||
decrypted_csv = encryptor.decrypt(
|
||||
interface.encryption_key, encrypted_csv
|
||||
self.encryption_key, encrypted_csv
|
||||
)
|
||||
file_content = decrypted_csv.decode()
|
||||
if flag_index is not None:
|
||||
@ -209,7 +171,7 @@ class CSVLogger(FlaggingCallback):
|
||||
with open(log_fp, "wb") as csvfile:
|
||||
csvfile.write(
|
||||
encryptor.encrypt(
|
||||
interface.encryption_key, output.getvalue().encode()
|
||||
self.encryption_key, output.getvalue().encode()
|
||||
)
|
||||
)
|
||||
else:
|
||||
@ -264,7 +226,7 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
|
||||
self.dataset_private = private
|
||||
self.verbose = verbose
|
||||
|
||||
def setup(self, flagging_dir: str):
|
||||
def setup(self, components: List[Component], flagging_dir: str):
|
||||
"""
|
||||
Params:
|
||||
flagging_dir (str): local directory where the dataset is cloned,
|
||||
@ -285,6 +247,7 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
|
||||
exist_ok=True,
|
||||
)
|
||||
self.path_to_dataset_repo = path_to_dataset_repo # e.g. "https://huggingface.co/datasets/abidlabs/test-audio-10"
|
||||
self.components = components
|
||||
self.flagging_dir = flagging_dir
|
||||
self.dataset_dir = os.path.join(flagging_dir, self.dataset_name)
|
||||
self.repo = huggingface_hub.Repository(
|
||||
@ -300,9 +263,7 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
|
||||
|
||||
def flag(
|
||||
self,
|
||||
interface: gr.Interface,
|
||||
input_data: List[Any],
|
||||
output_data: List[Any],
|
||||
flag_data: List[Any],
|
||||
flag_option: Optional[str] = None,
|
||||
flag_index: Optional[int] = None,
|
||||
username: Optional[str] = None,
|
||||
@ -325,12 +286,9 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
|
||||
if is_new:
|
||||
headers = []
|
||||
|
||||
for i, component in enumerate(interface.input_components):
|
||||
component_label = interface.config["input_components"][i][
|
||||
"label"
|
||||
] or "Input_{}".format(i)
|
||||
headers.append(component_label)
|
||||
infos["flagged"]["features"][component_label] = {
|
||||
for component, sample in zip(self.components, flag_data):
|
||||
headers.append(component.label)
|
||||
infos["flagged"]["features"][component.label] = {
|
||||
"dtype": "string",
|
||||
"_type": "Value",
|
||||
}
|
||||
@ -343,66 +301,26 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
|
||||
] = {"_type": _type}
|
||||
break
|
||||
|
||||
for i, component in enumerate(interface.output_components):
|
||||
component_label = interface.config["output_components"][i][
|
||||
"label"
|
||||
] or "Output_{}".format(i)
|
||||
headers.append(component_label)
|
||||
infos["flagged"]["features"][component_label] = {
|
||||
"dtype": "string",
|
||||
"_type": "Value",
|
||||
}
|
||||
if isinstance(component, tuple(file_preview_types)):
|
||||
headers.append(component_label + " file")
|
||||
for _component, _type in file_preview_types.items():
|
||||
if isinstance(component, _component):
|
||||
infos["flagged"]["features"][
|
||||
component_label + " file"
|
||||
] = {"_type": _type}
|
||||
break
|
||||
|
||||
if interface.flagging_options is not None:
|
||||
headers.append("flag")
|
||||
infos["flagged"]["features"]["flag"] = {
|
||||
"dtype": "string",
|
||||
"_type": "Value",
|
||||
}
|
||||
headers.append("flag")
|
||||
infos["flagged"]["features"]["flag"] = {
|
||||
"dtype": "string",
|
||||
"_type": "Value",
|
||||
}
|
||||
|
||||
writer.writerow(headers)
|
||||
|
||||
# Generate the row corresponding to the flagged sample
|
||||
csv_data = []
|
||||
for i, component in enumerate(interface.input_components):
|
||||
label = interface.config["input_components"][i][
|
||||
"label"
|
||||
] or "Input_{}".format(i)
|
||||
for component, sample in zip(self.components, flag_data):
|
||||
filepath = component.save_flagged(
|
||||
self.dataset_dir, label, input_data[i], None
|
||||
)
|
||||
self.dataset_dir, component.label, sample, None
|
||||
) if sample is not None else ""
|
||||
csv_data.append(filepath)
|
||||
if isinstance(component, tuple(file_preview_types)):
|
||||
csv_data.append(
|
||||
"{}/resolve/main/{}".format(self.path_to_dataset_repo, filepath)
|
||||
)
|
||||
for i, component in enumerate(interface.output_components):
|
||||
label = interface.config["output_components"][i][
|
||||
"label"
|
||||
] or "Output_{}".format(i)
|
||||
filepath = (
|
||||
component.save_flagged(
|
||||
self.dataset_dir, label, output_data[i], None
|
||||
)
|
||||
if output_data[i] is not None
|
||||
else ""
|
||||
)
|
||||
csv_data.append(filepath)
|
||||
if isinstance(component, tuple(file_preview_types)):
|
||||
csv_data.append(
|
||||
"{}/resolve/main/{}".format(self.path_to_dataset_repo, filepath)
|
||||
)
|
||||
if flag_option is not None:
|
||||
csv_data.append(flag_option)
|
||||
|
||||
csv_data.append(flag_option if flag_option is not None else "")
|
||||
writer.writerow(csv_data)
|
||||
|
||||
if is_new:
|
||||
|
@ -18,8 +18,14 @@ from markdown_it import MarkdownIt
|
||||
from mdit_py_plugins.footnote import footnote_plugin
|
||||
|
||||
from gradio import interpretation, utils
|
||||
from gradio.blocks import BlockContext, Column, Row
|
||||
from gradio.components import Button, Component, Markdown, get_component_instance
|
||||
from gradio.blocks import BlockContext, Block, Column, Row, Blocks
|
||||
from gradio.components import (
|
||||
Button,
|
||||
Component,
|
||||
Markdown,
|
||||
DatasetViewer,
|
||||
get_component_instance,
|
||||
)
|
||||
from gradio.external import load_from_pipeline, load_interface # type: ignore
|
||||
from gradio.flagging import CSVLogger, FlaggingCallback # type: ignore
|
||||
from gradio.inputs import State as i_State # type: ignore
|
||||
@ -27,6 +33,7 @@ from gradio.launchable import Launchable
|
||||
from gradio.outputs import State as o_State # type: ignore
|
||||
from gradio.process_examples import load_from_cache, process_example
|
||||
from gradio.routes import predict
|
||||
import inspect
|
||||
|
||||
if TYPE_CHECKING: # Only import for type checking (is False at runtime).
|
||||
import flask
|
||||
@ -447,10 +454,82 @@ class Interface(Launchable):
|
||||
if self.analytics_enabled:
|
||||
utils.initiated_analytics(data)
|
||||
|
||||
# Alert user if a more recent version of the library exists
|
||||
utils.version_check()
|
||||
Interface.instances.add(self)
|
||||
|
||||
param_names = inspect.getfullargspec(self.predict[0])[0]
|
||||
for component, param_name in zip(self.input_components, param_names):
|
||||
if component.label is None:
|
||||
component.label = param_name
|
||||
for i, component in enumerate(self.output_components):
|
||||
if component.label is None:
|
||||
component.label = "output_" + str(i)
|
||||
|
||||
if self.allow_flagging != "never":
|
||||
self.flagging_callback.setup(
|
||||
self.input_components + self.output_components, self.flagging_dir
|
||||
)
|
||||
|
||||
self.blocks = Blocks()
|
||||
with self.blocks:
|
||||
if self.title:
|
||||
Markdown(
|
||||
"<h1 style='text-align: center; margin-bottom: 1rem'>"
|
||||
+ self.title
|
||||
+ "</h1>"
|
||||
)
|
||||
if self.description:
|
||||
Markdown(self.description)
|
||||
with Row():
|
||||
with Column(
|
||||
css={
|
||||
"background-color": "rgb(249,250,251)",
|
||||
"padding": "0.5rem",
|
||||
"border-radius": "0.5rem",
|
||||
}
|
||||
):
|
||||
for component in self.input_components:
|
||||
Block.__init__(component)
|
||||
with Row():
|
||||
submit_btn = Button("Submit")
|
||||
clear_btn = Button("Clear")
|
||||
with Column(
|
||||
css={
|
||||
"background-color": "rgb(249,250,251)",
|
||||
"padding": "0.5rem",
|
||||
"border-radius": "0.5rem",
|
||||
}
|
||||
):
|
||||
for component in self.output_components:
|
||||
Block.__init__(component)
|
||||
with Row():
|
||||
flag_btn = Button("Flag")
|
||||
if self.examples:
|
||||
examples = DatasetViewer(
|
||||
components=self.input_components, samples=self.examples
|
||||
)
|
||||
submit_btn.click(
|
||||
lambda *args: self.process(args)[0][0]
|
||||
if len(self.output_components) == 1
|
||||
else self.process(args)[0],
|
||||
self.input_components,
|
||||
self.output_components,
|
||||
)
|
||||
clear_btn.click(
|
||||
lambda: [None]
|
||||
* (len(self.input_components) + len(self.output_components)),
|
||||
[],
|
||||
self.input_components + self.output_components,
|
||||
)
|
||||
examples.click(
|
||||
lambda x: x, inputs=[examples], outputs=self.input_components
|
||||
)
|
||||
flag_btn.click(
|
||||
lambda *flag_data: self.flagging_callback.flag(flag_data),
|
||||
inputs=self.input_components + self.output_components,
|
||||
outputs=[],
|
||||
)
|
||||
|
||||
def __call__(self, *params):
|
||||
if (
|
||||
self.api_mode
|
||||
@ -477,80 +556,7 @@ class Interface(Launchable):
|
||||
return repr
|
||||
|
||||
def get_config_file(self):
|
||||
components = []
|
||||
layout = {"id": 0, "children": []}
|
||||
dependencies = []
|
||||
|
||||
def add_component(parent, component):
|
||||
id = len(components) + 1
|
||||
components.append(
|
||||
{
|
||||
"id": len(components) + 1,
|
||||
"type": component.__class__.__name__.lower(),
|
||||
"props": component.get_template_context(),
|
||||
}
|
||||
)
|
||||
layout_context = {"id": id}
|
||||
if isinstance(component, BlockContext):
|
||||
layout_context["children"] = []
|
||||
parent["children"].append(layout_context)
|
||||
return layout_context
|
||||
|
||||
if self.title:
|
||||
add_component(layout, Markdown("<h1>" + self.title + "</h1>"))
|
||||
if self.description:
|
||||
add_component(layout, Markdown(self.description))
|
||||
panel_row = add_component(layout, Row())
|
||||
input_panel = add_component(panel_row, Column())
|
||||
input_ids = []
|
||||
for component in self.input_components:
|
||||
input_id = add_component(input_panel, component)["id"]
|
||||
input_ids.append(input_id)
|
||||
input_panel_btns = add_component(input_panel, Row())
|
||||
submit_btn = add_component(input_panel_btns, Button("Submit"))
|
||||
clear_btn = add_component(input_panel_btns, Button("Clear"))
|
||||
|
||||
output_panel = add_component(panel_row, Column())
|
||||
output_ids = []
|
||||
for component in self.output_components:
|
||||
output_id = add_component(output_panel, component)["id"]
|
||||
output_ids.append(output_id)
|
||||
output_panel_btns = add_component(output_panel, Row())
|
||||
flag_btn = add_component(output_panel_btns, Button("Flag"))
|
||||
dependencies.append(
|
||||
{
|
||||
"id": 0,
|
||||
"trigger": "click",
|
||||
"targets": [submit_btn["id"]],
|
||||
"inputs": input_ids,
|
||||
"outputs": output_ids,
|
||||
}
|
||||
)
|
||||
dependencies.append(
|
||||
{
|
||||
"id": 1,
|
||||
"trigger": "click",
|
||||
"targets": [clear_btn["id"]],
|
||||
"inputs": [],
|
||||
"outputs": input_ids + output_ids,
|
||||
}
|
||||
)
|
||||
dependencies.append(
|
||||
{
|
||||
"id": 2,
|
||||
"trigger": "click",
|
||||
"targets": [flag_btn["id"]],
|
||||
"inputs": input_ids + output_ids,
|
||||
"outputs": [],
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
"mode": "blocks",
|
||||
"components": components,
|
||||
"layout": layout,
|
||||
"dependencies": dependencies,
|
||||
}
|
||||
return self.blocks.get_config_file()
|
||||
|
||||
def run_prediction(
|
||||
self,
|
||||
@ -615,52 +621,8 @@ class Interface(Launchable):
|
||||
else:
|
||||
return predictions
|
||||
|
||||
def process_api(self, data: Dict[str, Any], username: str = None) -> Dict[str, Any]:
|
||||
flag_index = None
|
||||
if data.get("example_id") is not None:
|
||||
example_id = data["example_id"]
|
||||
if self.cache_examples:
|
||||
prediction = load_from_cache(self, example_id)
|
||||
durations = None
|
||||
else:
|
||||
prediction, durations = process_example(self, example_id)
|
||||
else:
|
||||
raw_input = data["data"]
|
||||
prediction, durations = self.process(raw_input)
|
||||
if self.allow_flagging == "auto":
|
||||
flag_index = self.flagging_callback.flag(
|
||||
self,
|
||||
raw_input,
|
||||
prediction,
|
||||
flag_option="" if self.flagging_options else None,
|
||||
username=username,
|
||||
)
|
||||
|
||||
return {
|
||||
"data": prediction,
|
||||
"durations": durations,
|
||||
"avg_durations": self.config.get("avg_durations"),
|
||||
"flag_index": flag_index,
|
||||
}
|
||||
|
||||
def process_api(self, data: Dict[str, Any], username: str = None) -> Dict[str, Any]:
|
||||
class RequestApi:
|
||||
SUBMIT = 0
|
||||
CLEAR = 1
|
||||
FLAG = 2
|
||||
|
||||
raw_input = data["data"]
|
||||
fn_index = data["fn_index"]
|
||||
if fn_index == RequestApi.SUBMIT:
|
||||
prediction, durations = self.process(raw_input)
|
||||
return {"data": prediction}
|
||||
elif fn_index == RequestApi.CLEAR:
|
||||
return {
|
||||
"data": [None]
|
||||
* (len(self.input_components) + len(self.output_components))
|
||||
}
|
||||
elif fn_index == RequestApi.FLAG: # flag
|
||||
pass
|
||||
def process_api(self, *args) -> Dict[str, Any]:
|
||||
return self.blocks.process_api(*args)
|
||||
|
||||
def process(self, raw_input: List[Any]) -> Tuple[List[Any], List[float]]:
|
||||
"""
|
||||
@ -716,11 +678,6 @@ class Interface(Launchable):
|
||||
print("PASSED")
|
||||
continue
|
||||
|
||||
def launch(self, **args):
|
||||
if self.allow_flagging != "never":
|
||||
self.flagging_callback.setup(self.flagging_dir)
|
||||
return super().launch(**args)
|
||||
|
||||
def integrate(self, comet_ml=None, wandb=None, mlflow=None) -> None:
|
||||
"""
|
||||
A catch-all method for integrating with other libraries.
|
||||
|
@ -45,9 +45,9 @@
|
||||
</script>
|
||||
<script src="https://cdnjs.cloudflare.com/ajax/libs/iframe-resizer/4.3.1/iframeResizer.contentWindow.min.js"></script>
|
||||
<title>Gradio</title>
|
||||
<script type="module" crossorigin src="./assets/index.b8adfcc9.js"></script>
|
||||
<script type="module" crossorigin src="./assets/index.9a03ef16.js"></script>
|
||||
<link rel="modulepreload" href="./assets/vendor.cb9b505c.js">
|
||||
<link rel="stylesheet" href="./assets/index.7e32d9ef.css">
|
||||
<link rel="stylesheet" href="./assets/index.cd197090.css">
|
||||
</head>
|
||||
|
||||
<body style="height: 100%; margin: 0; padding: 0">
|
||||
|
@ -124,7 +124,6 @@
|
||||
if (handled_dependencies[i]?.includes(id) || !instance) return;
|
||||
// console.log(trigger, target_instances, instance);
|
||||
instance?.$on(trigger, () => {
|
||||
console.log("boo");
|
||||
fn("predict", {
|
||||
fn_index: i,
|
||||
data: inputs.map((id) => instance_map[id].value)
|
||||
|
@ -1,3 +1,7 @@
|
||||
<div class="flex flex-col flex-1 gap-4">
|
||||
<script lang="ts">
|
||||
export let style:string = "";
|
||||
</script>
|
||||
|
||||
<div {style} class="flex flex-1 flex-col gap-4">
|
||||
<slot />
|
||||
</div>
|
||||
|
@ -1,17 +1,162 @@
|
||||
<script lang="ts">
|
||||
export let types: Array<string>;
|
||||
import { createEventDispatcher } from "svelte";
|
||||
import { component_map } from "./directory";
|
||||
|
||||
export let components: Array<string>;
|
||||
export let headers: Array<string>;
|
||||
export let value: Array<Array<any>>;
|
||||
export let samples: Array<Array<any>>;
|
||||
export let value: Number | null = null;
|
||||
export let samples_dir: string = "file/";
|
||||
export let samples_per_page: number = 10;
|
||||
|
||||
export let theme: string;
|
||||
export let style: string | null;
|
||||
|
||||
const dispatch = createEventDispatcher<{ click: number }>();
|
||||
|
||||
let sample_id: number | null = null;
|
||||
let page = 0;
|
||||
let gallery = headers.length === 1;
|
||||
let paginate = samples.length > samples_per_page;
|
||||
|
||||
let selected_samples: Array<Array<unknown>>;
|
||||
let page_count: number;
|
||||
let visible_pages: Array<number> = [];
|
||||
$: {
|
||||
if (paginate) {
|
||||
visible_pages = [];
|
||||
selected_samples = samples.slice(
|
||||
page * samples_per_page,
|
||||
(page + 1) * samples_per_page
|
||||
);
|
||||
page_count = Math.ceil(samples.length / samples_per_page);
|
||||
[0, page, page_count - 1].forEach((anchor) => {
|
||||
for (let i = anchor - 2; i <= anchor + 2; i++) {
|
||||
if (i >= 0 && i < page_count && !visible_pages.includes(i)) {
|
||||
if (
|
||||
visible_pages.length > 0 &&
|
||||
i - visible_pages[visible_pages.length - 1] > 1
|
||||
) {
|
||||
visible_pages.push(-1);
|
||||
}
|
||||
visible_pages.push(i);
|
||||
}
|
||||
}
|
||||
});
|
||||
} else {
|
||||
selected_samples = samples.slice();
|
||||
}
|
||||
}
|
||||
</script>
|
||||
|
||||
<table>
|
||||
<thead>
|
||||
{#each headers as header}
|
||||
<th>{header}</th>
|
||||
<div
|
||||
class="samples-holder mt-4 inline-block max-w-full"
|
||||
class:gallery
|
||||
class:overflow-x-auto={!gallery}
|
||||
>
|
||||
{#if gallery}
|
||||
<div class="samples-gallery flex gap-2 flex-wrap">
|
||||
{#each selected_samples as sample_row, i}
|
||||
<button
|
||||
class="sample cursor-pointer p-2 rounded bg-gray-50 dark:bg-gray-700 transition"
|
||||
class:selected={i + page * samples_per_page === sample_id}
|
||||
on:click={() => {
|
||||
value = samples[i];
|
||||
dispatch("click", i + page * samples_per_page);
|
||||
}}
|
||||
>
|
||||
<svelte:component
|
||||
this={component_map[components[0]]}
|
||||
{theme}
|
||||
value={sample_row[0]}
|
||||
{samples_dir}
|
||||
/>
|
||||
</button>
|
||||
{/each}
|
||||
</div>
|
||||
{:else}
|
||||
<table
|
||||
class="samples-table table-auto p-2 bg-gray-50 dark:bg-gray-600 rounded max-w-full border-collapse"
|
||||
>
|
||||
<thead class="border-b-2 dark:border-gray-600">
|
||||
<tr>
|
||||
{#each headers as header}
|
||||
<th class="py-2 px-4">
|
||||
{header}
|
||||
</th>
|
||||
{/each}
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{#each selected_samples as sample_row, i}
|
||||
<tr
|
||||
class="cursor-pointer transition"
|
||||
class:selected={i + page * samples_per_page === sample_id}
|
||||
on:click={() => {
|
||||
value = samples[i];
|
||||
dispatch("click", i + page * samples_per_page);
|
||||
}}
|
||||
>
|
||||
{#each sample_row as sample_cell, j}
|
||||
<td class="py-2 px-4">
|
||||
<svelte:component
|
||||
this={component_map[components[j]]}
|
||||
{theme}
|
||||
value={sample_cell}
|
||||
{samples_dir}
|
||||
/>
|
||||
</td>
|
||||
{/each}
|
||||
</tr>
|
||||
{/each}
|
||||
</tbody>
|
||||
</table>
|
||||
{/if}
|
||||
</div>
|
||||
{#if paginate}
|
||||
<div class="flex gap-2 items-center mt-4">
|
||||
Pages:
|
||||
{#each visible_pages as visible_page}
|
||||
{#if visible_page === -1}
|
||||
<div>...</div>
|
||||
{:else}
|
||||
<button
|
||||
class="page"
|
||||
class:font-bold={page === visible_page}
|
||||
on:click={() => (page = visible_page)}
|
||||
>
|
||||
{visible_page + 1}
|
||||
</button>
|
||||
{/if}
|
||||
{/each}
|
||||
</thead>
|
||||
|
||||
</table>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<style lang="postcss" global>
|
||||
.samples-holder:not(.gallery) {
|
||||
@apply shadow;
|
||||
.samples-table {
|
||||
@apply rounded dark:bg-gray-700;
|
||||
thead {
|
||||
@apply border-gray-300 dark:border-gray-600;
|
||||
}
|
||||
tbody tr:hover {
|
||||
@apply bg-amber-500 dark:bg-red-700 text-white;
|
||||
}
|
||||
}
|
||||
}
|
||||
.samples-holder .samples-gallery {
|
||||
.sample {
|
||||
@apply shadow;
|
||||
}
|
||||
.sample:hover {
|
||||
@apply bg-amber-500 text-white;
|
||||
}
|
||||
}
|
||||
.samples-table tr.selected {
|
||||
@apply font-semibold;
|
||||
}
|
||||
.page {
|
||||
@apply py-1 px-2 bg-gray-100 dark:bg-gray-700 rounded;
|
||||
}
|
||||
</style>
|
||||
|
@ -1,15 +1,9 @@
|
||||
import ExampleNumber from "./ExampleComponents/Number.svelte"
|
||||
import ExampleDropdown from "./ExampleComponents/Dropdown.svelte"
|
||||
import ExampleRadio from "./ExampleComponents/Radio.svelte"
|
||||
|
||||
export const component_map = {
|
||||
audio: () => import("./ExampleComponents/Audio.svelte"),
|
||||
checkbox: () => import("./ExampleComponents/Checkbox.svelte"),
|
||||
checkboxgroup: () => import("./ExampleComponents/CheckboxGroup.svelte"),
|
||||
dropdown: () => import("./ExampleComponents/Dropdown.svelte"),
|
||||
file: () => import("./ExampleComponents/File.svelte"),
|
||||
highlightedtext: () => import("./ExampleComponents/HighlightedText.svelte"),
|
||||
html: () => import("./ExampleComponents/HTML.svelte"),
|
||||
image: () => import("./ExampleComponents/Image.svelte"),
|
||||
number: () => import("./ExampleComponents/Number.svelte"),
|
||||
radio: () => import("./ExampleComponents/Radio.svelte"),
|
||||
slider: () => import("./ExampleComponents/slider.svelte"),
|
||||
textbox: () => import("./ExampleComponents/Textbox.svelte"),
|
||||
video: () => import("./ExampleComponents/Video.svelte")
|
||||
dropdown: ExampleDropdown,
|
||||
number: ExampleNumber,
|
||||
radio: ExampleRadio,
|
||||
};
|
||||
|
2
ui/packages/app/src/components/DatasetViewer/index.ts
Normal file
2
ui/packages/app/src/components/DatasetViewer/index.ts
Normal file
@ -0,0 +1,2 @@
|
||||
export { default as Component } from "./DatasetViewer.svelte";
|
||||
export const modes = ["dynamic"];
|
@ -1,3 +1,7 @@
|
||||
<div class="flex flex-row gap-4">
|
||||
<script lang="ts">
|
||||
export let style:string = "";
|
||||
</script>
|
||||
|
||||
<div {style} class="flex flex-row gap-4">
|
||||
<slot />
|
||||
</div>
|
||||
|
@ -8,6 +8,7 @@ export const component_map = {
|
||||
checkboxgroup: () => import("./CheckboxGroup"),
|
||||
column: () => import("./Column"),
|
||||
dataframe: () => import("./DataFrame"),
|
||||
datasetviewer: () => import("./DatasetViewer"),
|
||||
dropdown: () => import("./Dropdown"),
|
||||
file: () => import("./File"),
|
||||
highlightedtext: () => import("./HighlightedText"),
|
||||
|
@ -11,6 +11,6 @@
|
||||
$: value, dispatch("change");
|
||||
</script>
|
||||
|
||||
<div class="output-markdown prose" {theme}>
|
||||
<div class="output-markdown prose" style="max-width: 100%" {theme}>
|
||||
{@html value}
|
||||
</div>
|
||||
|
Loading…
Reference in New Issue
Block a user