Restore Interpretation, Live, Auth, Queueing (#915)

Restore Interpretation, Live, Auth, Queueing

Co-authored-by: Ali Abid <aliabid94@gmail.com>
Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
This commit is contained in:
aliabid94 2022-04-04 15:47:51 -07:00 committed by GitHub
parent e5ea806b5e
commit 5c44ae3536
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
49 changed files with 421 additions and 304 deletions

View File

@ -14,9 +14,7 @@ def diff_texts(text1, text2):
demo = gr.Interface(
diff_texts,
[
gr.Textbox(
lines=3, default="The quick brown fox jumped over the lazy dogs."
),
gr.Textbox(lines=3, default="The quick brown fox jumped over the lazy dogs."),
gr.Textbox(lines=3, default="The fast brown fox jumps over lazy dogs."),
],
gr.HighlightedText(),

View File

@ -18,9 +18,7 @@ def recognize_digit(image):
return {str(i): prediction[i] for i in range(10)}
im = gradio.Image(
shape=(28, 28), image_mode="L", invert_colors=False, source="canvas"
)
im = gradio.Image(shape=(28, 28), image_mode="L", invert_colors=False, source="canvas")
demo = gr.Interface(
recognize_digit,

View File

@ -26,9 +26,7 @@ demo = gr.Interface(
plot_forecast,
[
gr.Radio([2025, 2030, 2035, 2040], label="Project to:"),
gr.CheckboxGroup(
["Google", "Microsoft", "Gradio"], label="Company Selection"
),
gr.CheckboxGroup(["Google", "Microsoft", "Gradio"], label="Company Selection"),
gr.Slider(minimum=1, maximum=100, label="Noise Level"),
gr.Checkbox(label="Show Legend"),
gr.Dropdown(["cross", "line", "circle"], label="Style"),

View File

@ -35,7 +35,7 @@ demo = gr.Interface(
fn=gender_of_sentence,
inputs=gr.Textbox(default="She went to his house to get her keys."),
outputs="label",
interpretation=interpret_gender
interpretation=interpret_gender,
)
if __name__ == "__main__":

View File

@ -20,8 +20,9 @@ def classify_image(inp):
image = gr.Image(shape=(224, 224))
label = gr.Label(num_top_classes=3)
demo = gr.Interface(fn=classify_image, inputs=image, outputs=label,
interpretation="default")
demo = gr.Interface(
fn=classify_image, inputs=image, outputs=label, interpretation="default"
)
if __name__ == "__main__":
demo.launch()
demo.launch()

View File

@ -4,6 +4,7 @@ import gradio as gr
def image_mod(text):
return text[::-1]
demo = gr.Blocks()
with demo:

View File

@ -37,7 +37,7 @@ def main_note(audio):
if pitch not in volume_per_pitch:
volume_per_pitch[pitch] = 0
volume_per_pitch[pitch] += 1.0 * volume / total_volume
volume_per_pitch = {k:float(v) for k,v in volume_per_pitch.items()}
volume_per_pitch = {k: float(v) for k, v in volume_per_pitch.items()}
return volume_per_pitch

View File

@ -31,9 +31,7 @@ demo = gr.Interface(
outbreak,
[
gr.Slider(minimum=1, maximum=4, default_value=3.2, label="R"),
gr.Dropdown(
["January", "February", "March", "April", "May"], label="Month"
),
gr.Dropdown(["January", "February", "March", "April", "May"], label="Month"),
gr.CheckboxGroup(["USA", "Canada", "Mexico", "UK"], label="Countries"),
gr.Checkbox(label="Social Distancing?"),
],

View File

@ -5,25 +5,28 @@ import gradio as gr
asr = pipeline("automatic-speech-recognition", "facebook/wav2vec2-base-960h")
classifier = pipeline("text-classification")
def speech_to_text(speech):
text = asr(speech)["text"]
return text
def text_to_sentiment(text):
return classifier(text)[0]["label"]
demo = gr.Blocks()
with demo:
m = gr.Audio(type="filepath")
t = gr.Textbox()
l = gr.Label()
b1 = gr.Button("Recognize Speech")
b2 = gr.Button("Classify Sentiment")
b1.click(speech_to_text, inputs=m, outputs=t)
b2.click(text_to_sentiment, inputs=t, outputs=l)
if __name__ == "__main__":
demo.launch()
demo.launch()

View File

@ -69,6 +69,8 @@ predictions = clf.predict(X_test)
def predict_survival(passenger_class, is_male, age, company, fare, embark_point):
if passenger_class is None or embark_point is None:
return None
df = pd.DataFrame.from_dict(
{
"Pclass": [passenger_class + 1],
@ -93,9 +95,7 @@ demo = gr.Interface(
gr.Dropdown(["first", "second", "third"], type="index"),
"checkbox",
gr.Slider(minimum=0, maximum=80),
gr.CheckboxGroup(
["Sibling", "Child"], label="Travelling with (select all)"
),
gr.CheckboxGroup(["Sibling", "Child"], label="Travelling with (select all)"),
gr.Number(),
gr.Radio(["S", "C", "Q"], type="index"),
],
@ -106,6 +106,7 @@ demo = gr.Interface(
["third", True, 30, ["Child"], 20, "S"],
],
interpretation="default",
live=True,
)
if __name__ == "__main__":

View File

@ -8,7 +8,7 @@ ct_model = lambda diseases, img: {disease: 0.1 for disease in diseases}
with gr.Blocks() as demo:
gr.Markdown(
"""
"""
# Detect Disease From Scan
With this model you can lorem ipsum
- ipsum 1

View File

@ -33,6 +33,8 @@ class Block:
fn: Callable,
inputs: List[Component],
outputs: List[Component],
preprocess=True,
queue=False,
) -> None:
"""
Adds an event to the component's dependencies.
@ -49,24 +51,26 @@ class Block:
if not isinstance(outputs, list):
outputs = [outputs]
Context.root_block.fns.append(fn)
Context.root_block.fns.append((fn, preprocess))
Context.root_block.dependencies.append(
{
"targets": [self._id],
"trigger": event_name,
"inputs": [block._id for block in inputs],
"outputs": [block._id for block in outputs],
"queue": queue,
}
)
class BlockContext(Block):
def __init__(self, css: Optional[Dict[str, str]] = None):
def __init__(self, visible: bool = True, css: Optional[Dict[str, str]] = None):
"""
css: Css rules to apply to block.
"""
self.children = []
self.css = css if css is not None else {}
self.visible = visible
super().__init__()
def __enter__(self):
@ -78,26 +82,29 @@ class BlockContext(Block):
Context.block = self.parent
def get_template_context(self):
return {"css": self.css}
return {"css": self.css, "default_value": self.visible}
def postprocess(self, y):
return y
class Row(BlockContext):
def __init__(self, css: Optional[str] = None):
def __init__(self, visible: bool = True, css: Optional[Dict[str, str]] = None):
"""
css: Css rules to apply to block.
"""
super().__init__(css)
super().__init__(visible, css)
def get_template_context(self):
return {"type": "row", **super().get_template_context()}
class Column(BlockContext):
def __init__(self, css: Optional[str] = None):
def __init__(self, visible: bool = True, css: Optional[Dict[str, str]] = None):
"""
css: Css rules to apply to block.
"""
super().__init__(css)
super().__init__(visible, css)
def get_template_context(self):
return {
@ -107,11 +114,11 @@ class Column(BlockContext):
class Tabs(BlockContext):
def __init__(self, css: Optional[dict] = None):
def __init__(self, visible: bool = True, css: Optional[Dict[str, str]] = None):
"""
css: css rules to apply to block.
"""
super().__init__(css)
super().__init__(visible, css)
def change(self, fn: Callable, inputs: List[Component], outputs: List[Component]):
"""
@ -125,11 +132,13 @@ class Tabs(BlockContext):
class TabItem(BlockContext):
def __init__(self, label, css: Optional[str] = None):
def __init__(
self, label, visible: bool = True, css: Optional[Dict[str, str]] = None
):
"""
css: Css rules to apply to block.
"""
super().__init__(css)
super().__init__(visible, css)
self.label = label
def get_template_context(self):
@ -167,14 +176,17 @@ class Blocks(Launchable, BlockContext):
def process_api(self, data: Dict[str, Any], username: str = None) -> Dict[str, Any]:
raw_input = data["data"]
fn_index = data["fn_index"]
fn = self.fns[fn_index]
fn, preprocess = self.fns[fn_index]
dependency = self.dependencies[fn_index]
processed_input = [
self.blocks[input_id].preprocess(raw_input[i])
for i, input_id in enumerate(dependency["inputs"])
]
predictions = fn(*processed_input)
if preprocess:
processed_input = [
self.blocks[input_id].preprocess(raw_input[i])
for i, input_id in enumerate(dependency["inputs"])
]
predictions = fn(*processed_input)
else:
predictions = fn(*raw_input)
if len(dependency["outputs"]) == 1:
predictions = (predictions,)
processed_output = [

View File

@ -923,7 +923,10 @@ class Radio(Component):
if self.type == "value":
return x
elif self.type == "index":
return self.choices.index(x)
if x is None:
return None
else:
return self.choices.index(x)
else:
raise ValueError(
"Unknown type: "
@ -2820,7 +2823,13 @@ class Button(Component):
def get_template_context(self):
return {"default_value": self.default_value, **super().get_template_context()}
def click(self, fn: Callable, inputs: List[Component], outputs: List[Component]):
def click(
self,
fn: Callable,
inputs: List[Component],
outputs: List[Component],
queue=False,
):
"""
Parameters:
fn: Callable function
@ -2828,7 +2837,19 @@ class Button(Component):
outputs: List of outputs
Returns: None
"""
self.set_event_trigger("click", fn, inputs, outputs)
self.set_event_trigger("click", fn, inputs, outputs, queue=queue)
def _click_no_preprocess(
self, fn: Callable, inputs: List[Component], outputs: List[Component]
):
"""
Parameters:
fn: Callable function
inputs: List of inputs
outputs: List of outputs
Returns: None
"""
self.set_event_trigger("click", fn, inputs, outputs, preprocess=False)
class Dataset(Component):
@ -2873,6 +2894,29 @@ class Dataset(Component):
self.set_event_trigger("click", fn, inputs, outputs)
class Interpretation(Component):
"""
Used to create an interpretation widget for a component.
"""
def __init__(
self,
component: Component,
*,
label: Optional[str] = None,
css: Optional[Dict] = None,
**kwargs,
):
super().__init__(label=label, css=css, **kwargs)
self.component = component
def get_template_context(self):
return {
"component": self.component.__class__.__name__.lower(),
"component_props": self.component.get_template_context(),
}
# TODO: (faruk) does this take component or interface as a input?
# see this line in Carousel
# self.components = [get_component_instance(component) for component in components]

View File

@ -24,6 +24,7 @@ from gradio.components import (
Button,
Component,
Dataset,
Interpretation,
Markdown,
get_component_instance,
)
@ -489,11 +490,20 @@ class Interface(Blocks):
"border-radius": "0.5rem",
}
):
for component in self.input_components:
component.render()
input_component_column = Column()
with input_component_column:
for component in self.input_components:
component.render()
if self.interpretation:
interpret_component_column = Column(visible=False)
interpretation_set = []
with interpret_component_column:
for component in self.input_components:
interpretation_set.append(Interpretation(component))
with Row():
clear_btn = Button("Clear")
submit_btn = Button("Submit")
if not self.live:
submit_btn = Button("Submit")
with Column(
css={
"background-color": "rgb(249,250,251)",
@ -505,18 +515,39 @@ class Interface(Blocks):
component.render()
with Row():
flag_btn = Button("Flag")
submit_btn.click(
if self.interpretation:
interpretation_btn = Button("Interpret")
submit_fn = (
lambda *args: self.run_prediction(args, return_duration=False)[0]
if len(self.output_components) == 1
else self.run_prediction(args, return_duration=False),
self.input_components,
self.output_components,
else self.run_prediction(args, return_duration=False)
)
if self.live:
for component in self.input_components:
component.change(
submit_fn, self.input_components, self.output_components
)
else:
submit_btn.click(
submit_fn,
self.input_components,
self.output_components,
queue=self.enable_queue,
)
clear_btn.click(
lambda: [None]
* (len(self.input_components) + len(self.output_components)),
lambda: [
component.default_value
if hasattr(component, "default_value")
else None
for component in self.input_components + self.output_components
]
+ [True]
+ ([False] if self.interpretation else []),
[],
self.input_components + self.output_components,
self.input_components
+ self.output_components
+ [input_component_column]
+ ([interpret_component_column] if self.interpretation else []),
)
if self.examples:
examples = Dataset(
@ -530,6 +561,13 @@ class Interface(Blocks):
inputs=self.input_components + self.output_components,
outputs=[],
)
if self.interpretation:
interpretation_btn._click_no_preprocess(
lambda *data: self.interpret(data) + [False, True],
inputs=self.input_components + self.output_components,
outputs=interpretation_set
+ [input_component_column, interpret_component_column],
)
def __call__(self, *params):
if (
@ -655,7 +693,12 @@ class Interface(Blocks):
return processed_output, durations
def interpret(self, raw_input: List[Any]) -> List[Any]:
return interpretation.run_interpret(self, raw_input)
return [
{"original": raw_value, "interpretation": interpretation}
for interpretation, raw_value in zip(
interpretation.run_interpret(self, raw_input)[0], raw_input
)
]
def test_launch(self) -> None:
"""

View File

@ -153,7 +153,7 @@ def run_interpret(interface, raw_input):
scores.append(None)
alternative_outputs.append([])
else:
raise ValueError("Uknown intepretation method: {}".format(interp))
raise ValueError("Unknown intepretation method: {}".format(interp))
return scores, alternative_outputs
else: # custom interpretation function
processed_input = [

View File

@ -45,9 +45,9 @@
</script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/iframe-resizer/4.3.1/iframeResizer.contentWindow.min.js"></script>
<title>Gradio</title>
<script type="module" crossorigin src="./assets/index.6ccb5958.js"></script>
<link rel="modulepreload" href="./assets/vendor.c988cbcf.js">
<link rel="stylesheet" href="./assets/index.cdf32a5f.css">
<script type="module" crossorigin src="./assets/index.a16d1a2b.js"></script>
<link rel="modulepreload" href="./assets/vendor.47c5eb78.js">
<link rel="stylesheet" href="./assets/index.f6d9a6c5.css">
</head>
<body style="height: 100%; margin: 0; padding: 0">

0
scripts/format_frontend.sh Normal file → Executable file
View File

View File

@ -22,9 +22,17 @@ XRAY_CONFIG = {
"css": {},
},
},
{"id": 3, "type": "tabs", "props": {"css": {}}},
{"id": 4, "type": "tabitem", "props": {"label": "X-ray", "css": {}}},
{"id": 5, "type": "row", "props": {"type": "row", "css": {}}},
{"id": 3, "type": "tabs", "props": {"css": {}, "default_value": True}},
{
"id": 4,
"type": "tabitem",
"props": {"label": "X-ray", "css": {}, "default_value": True},
},
{
"id": 5,
"type": "row",
"props": {"type": "row", "css": {}, "default_value": True},
},
{
"id": 6,
"type": "image",
@ -54,8 +62,16 @@ XRAY_CONFIG = {
"css": {"background-color": "red", "--hover-color": "orange"},
},
},
{"id": 9, "type": "tabitem", "props": {"label": "CT Scan", "css": {}}},
{"id": 10, "type": "row", "props": {"type": "row", "css": {}}},
{
"id": 9,
"type": "tabitem",
"props": {"label": "CT Scan", "css": {}, "default_value": True},
},
{
"id": 10,
"type": "row",
"props": {"type": "row", "css": {}, "default_value": True},
},
{
"id": 11,
"type": "image",
@ -127,7 +143,19 @@ XRAY_CONFIG = {
],
},
"dependencies": [
{"targets": [8], "trigger": "click", "inputs": [2, 6], "outputs": [7]},
{"targets": [13], "trigger": "click", "inputs": [2, 11], "outputs": [12]},
{
"targets": [8],
"trigger": "click",
"inputs": [2, 6],
"outputs": [7],
"queue": False,
},
{
"targets": [13],
"trigger": "click",
"inputs": [2, 11],
"outputs": [12],
"queue": False,
},
],
}

View File

@ -69,40 +69,34 @@ class TestTextbox(unittest.TestCase):
"number",
interpretation="default",
)
scores, alternative_outputs = iface.interpret(
scores = iface.interpret(
["Return the length of the longest word in this sentence"]
)
)[0]["interpretation"]
self.assertEqual(
scores,
[
[
("Return", 0.0),
(" ", 0),
("the", 0.0),
(" ", 0),
("length", 0.0),
(" ", 0),
("of", 0.0),
(" ", 0),
("the", 0.0),
(" ", 0),
("longest", 0.0),
(" ", 0),
("word", 0.0),
(" ", 0),
("in", 0.0),
(" ", 0),
("this", 0.0),
(" ", 0),
("sentence", 1.0),
(" ", 0),
]
("Return", 0.0),
(" ", 0),
("the", 0.0),
(" ", 0),
("length", 0.0),
(" ", 0),
("of", 0.0),
(" ", 0),
("the", 0.0),
(" ", 0),
("longest", 0.0),
(" ", 0),
("word", 0.0),
(" ", 0),
("in", 0.0),
(" ", 0),
("this", 0.0),
(" ", 0),
("sentence", 1.0),
(" ", 0),
],
)
self.assertEqual(
alternative_outputs,
[[[8], [8], [8], [8], [8], [8], [8], [8], [8], [7]]],
)
class TestNumber(unittest.TestCase):
@ -139,32 +133,17 @@ class TestNumber(unittest.TestCase):
iface = gr.Interface(
lambda x: x**2, "number", "number", interpretation="default"
)
scores, alternative_outputs = iface.interpret([2])
scores = iface.interpret([2])[0]["interpretation"]
self.assertEqual(
scores,
[
[
(1.94, -0.23640000000000017),
(1.96, -0.15840000000000032),
(1.98, -0.07960000000000012),
[2, None],
(2.02, 0.08040000000000003),
(2.04, 0.16159999999999997),
(2.06, 0.24359999999999982),
]
],
)
self.assertEqual(
alternative_outputs,
[
[
[3.7636],
[3.8415999999999997],
[3.9204],
[4.0804],
[4.1616],
[4.2436],
]
(1.94, -0.23640000000000017),
(1.96, -0.15840000000000032),
(1.98, -0.07960000000000012),
[2, None],
(2.02, 0.08040000000000003),
(2.04, 0.16159999999999997),
(2.06, 0.24359999999999982),
],
)
@ -204,35 +183,18 @@ class TestSlider(unittest.TestCase):
iface = gr.Interface(
lambda x: x**2, "slider", "number", interpretation="default"
)
scores, alternative_outputs = iface.interpret([2])
scores = iface.interpret([2])[0]["interpretation"]
self.assertEqual(
scores,
[
[
-4.0,
200.08163265306123,
812.3265306122449,
1832.7346938775513,
3261.3061224489797,
5098.040816326531,
7342.938775510205,
9996.0,
]
],
)
self.assertEqual(
alternative_outputs,
[
[
[0.0],
[204.08163265306123],
[816.3265306122449],
[1836.7346938775513],
[3265.3061224489797],
[5102.040816326531],
[7346.938775510205],
[10000.0],
]
-4.0,
200.08163265306123,
812.3265306122449,
1832.7346938775513,
3261.3061224489797,
5098.040816326531,
7342.938775510205,
9996.0,
],
)
@ -266,12 +228,10 @@ class TestCheckbox(unittest.TestCase):
iface = gr.Interface(
lambda x: 1 if x else 0, "checkbox", "number", interpretation="default"
)
scores, alternative_outputs = iface.interpret([False])
self.assertEqual(scores, [(None, 1.0)])
self.assertEqual(alternative_outputs, [[[1]]])
scores, alternative_outputs = iface.interpret([True])
self.assertEqual(scores, [(-1.0, None)])
self.assertEqual(alternative_outputs, [[[0]]])
scores = iface.interpret([False])[0]["interpretation"]
self.assertEqual(scores, (None, 1.0))
scores = iface.interpret([True])[0]["interpretation"]
self.assertEqual(scores, (-1.0, None))
class TestCheckboxGroup(unittest.TestCase):
@ -351,9 +311,8 @@ class TestRadio(unittest.TestCase):
lambda x: 2 * x, radio_input, "number", interpretation="default"
)
self.assertEqual(iface.process(["c"])[0], [4])
scores, alternative_outputs = iface.interpret(["b"])
self.assertEqual(scores, [[-2.0, None, 2.0]])
self.assertEqual(alternative_outputs, [[[0], [4]]])
scores = iface.interpret(["b"])[0]["interpretation"]
self.assertEqual(scores, [-2.0, None, 2.0])
class TestDropdown(unittest.TestCase):
@ -396,9 +355,8 @@ class TestDropdown(unittest.TestCase):
lambda x: 2 * x, dropdown, "number", interpretation="default"
)
self.assertEqual(iface.process(["c"])[0], [4])
scores, alternative_outputs = iface.interpret(["b"])
self.assertEqual(scores, [[-2.0, None, 2.0]])
self.assertEqual(alternative_outputs, [[[0], [4]]])
scores = iface.interpret(["b"])[0]["interpretation"]
self.assertEqual(scores, [-2.0, None, 2.0])
class TestImage(unittest.TestCase):
@ -478,23 +436,15 @@ class TestImage(unittest.TestCase):
iface = gr.Interface(
lambda x: np.sum(x), image_input, "number", interpretation="default"
)
scores, alternative_outputs = iface.interpret([img])
self.assertEqual(scores, gr.test_data.SUM_PIXELS_INTERPRETATION["scores"])
self.assertEqual(
alternative_outputs,
gr.test_data.SUM_PIXELS_INTERPRETATION["alternative_outputs"],
)
scores = iface.interpret([img])[0]["interpretation"]
self.assertEqual(scores, gr.test_data.SUM_PIXELS_INTERPRETATION["scores"][0])
iface = gr.Interface(
lambda x: np.sum(x), image_input, "label", interpretation="shap"
)
scores, alternative_outputs = iface.interpret([img])
scores = iface.interpret([img])[0]["interpretation"]
self.assertEqual(
len(scores[0]),
len(gr.test_data.SUM_PIXELS_SHAP_INTERPRETATION["scores"][0]),
)
self.assertEqual(
len(alternative_outputs[0]),
len(gr.test_data.SUM_PIXELS_SHAP_INTERPRETATION["alternative_outputs"][0]),
len(gr.test_data.SUM_PIXELS_SHAP_INTERPRETATION["scores"][0][0]),
)
image_input = gr.inputs.Image(shape=(30, 10))
iface = gr.Interface(

View File

@ -100,8 +100,8 @@ class TestInterface(unittest.TestCase):
def test_interface_none_interp(self):
interface = Interface(lambda x: x, "textbox", "label", interpretation=[None])
scores, alternative_outputs = interface.interpret(["quickest brown fox"])
self.assertIsNone(scores[0])
scores = interface.interpret(["quickest brown fox"])[0]["interpretation"]
self.assertIsNone(scores)
@mock.patch("webbrowser.open")
def test_interface_browser(self, mock_browser):

View File

@ -17,7 +17,9 @@ class TestDefault(unittest.TestCase):
text_interface = Interface(
max_word_len, "textbox", "label", interpretation="default"
)
interpretation = text_interface.interpret(["quickest brown fox"])[0][0]
interpretation = text_interface.interpret(["quickest brown fox"])[0][
"interpretation"
]
self.assertGreater(
interpretation[0][1], 0
) # Checks to see if the first word has >0 score.
@ -32,13 +34,12 @@ class TestShapley(unittest.TestCase):
text_interface = Interface(
max_word_len, "textbox", "label", interpretation="shapley"
)
interpretation = text_interface.interpret(["quickest brown fox"])[0][0]
interpretation = text_interface.interpret(["quickest brown fox"])[0][
"interpretation"
][0]
self.assertGreater(
interpretation[0][1], 0
interpretation[1], 0
) # Checks to see if the first word has >0 score.
self.assertEqual(
interpretation[-1][1], 0
) # Checks to see if the last word has 0 score.
class TestCustom(unittest.TestCase):
@ -48,9 +49,11 @@ class TestCustom(unittest.TestCase):
text_interface = Interface(
max_word_len, "textbox", "label", interpretation=custom
)
result = text_interface.interpret(["quickest brown fox"])[0][0]
result = text_interface.interpret(["quickest brown fox"])[0]["interpretation"][
0
]
self.assertEqual(
result[0][1], 1
result[1], 1
) # Checks to see if the first letter has score of 1.
def test_custom_img(self):
@ -59,7 +62,9 @@ class TestCustom(unittest.TestCase):
img_interface = Interface(
max_pixel_value, "image", "label", interpretation=custom
)
result = img_interface.interpret([gradio.test_data.BASE64_IMAGE])[0][0]
result = img_interface.interpret([gradio.test_data.BASE64_IMAGE])[0][
"interpretation"
]
expected_result = np.asarray(
decode_base64_to_image(gradio.test_data.BASE64_IMAGE).convert("RGB")
).tolist()

View File

@ -28,6 +28,7 @@
targets: Array<number>;
inputs: Array<string>;
outputs: Array<string>;
queue: boolean;
}
export let root: string;
@ -114,7 +115,7 @@
async function handle_mount({ detail }) {
await tick();
dependencies.forEach(({ targets, trigger, inputs, outputs }, i) => {
dependencies.forEach(({ targets, trigger, inputs, outputs, queue }, i) => {
const target_instances: [number, Instance][] = targets.map((t) => [
t,
instance_map[t]
@ -125,10 +126,15 @@
if (handled_dependencies[i]?.includes(id) || !instance) return;
// console.log(trigger, target_instances, instance);
instance?.$on(trigger, () => {
fn("predict", {
fn_index: i,
data: inputs.map((id) => instance_map[id].value)
}).then((output) => {
fn(
"predict",
{
fn_index: i,
data: inputs.map((id) => instance_map[id].value)
},
queue,
() => {}
).then((output) => {
output.data.forEach((value, i) => {
instance_map[outputs[i]].value = value;
});

View File

@ -1,4 +1,3 @@
export { default as Component } from "./Audio.svelte";
export { default as Interpretation } from "./Interpretation.svelte";
export { loadAsFile } from "../utils/example_processors";
export const modes = ["static", "dynamic"];

View File

@ -1,3 +1,2 @@
export { default as Component } from "./Checkbox.svelte";
export { default as Interpretation } from "./Interpretation.svelte";
export const modes = ["static", "dynamic"];

View File

@ -1,3 +1,2 @@
export { default as Component } from "./CheckboxGroup.svelte";
export { default as Interpretation } from "./Interpretation.svelte";
export const modes = ["static", "dynamic"];

View File

@ -1,7 +1,10 @@
<script lang="ts">
export let value: boolean = true;
export let style: string = "";
if ($$props.default) value = $$props.default;
</script>
<div {style} class="flex flex-1 flex-col gap-4">
<div {style} class:hidden={!value} class="flex flex-1 flex-col gap-4">
<slot />
</div>

View File

@ -1,3 +1,2 @@
export { default as Component } from "./Dropdown.svelte";
export { default as Interpretation } from "./Interpretation.svelte";
export const modes = ["dynamic"];

View File

@ -1,5 +1,4 @@
export { default as Component } from "./Image.svelte";
export { default as ExampleComponent } from "../Dataset/ExampleComponents/Image.svelte";
export { default as Interpretation } from "./Interpretation.svelte";
export { loadAsData } from "../utils/example_processors";
export const modes = ["static", "dynamic"];

View File

@ -0,0 +1,21 @@
<script lang="ts">
import { component_map } from "./directory";
export let component: string;
export let component_props: Record<string, any>;
export let value: any;
export let theme: string;
</script>
{#if value}
<svelte:component
this={component_map[component]}
{theme}
{...component_props}
original={value.original}
interpretation={value.interpretation}
/>
{/if}
<style lang="postcss" global>
</style>

View File

@ -1,17 +1,10 @@
<script lang="ts">
import type { AudioData } from "@gradio/audio";
import { getSaliencyColor } from "../utils/helpers";
export let value: AudioData;
import { getSaliencyColor } from "../utils";
export let interpretation: Array<number>;
export let theme: string;
export let style: string | null;
</script>
<div class="input-audio" {theme} {style}>
<audio class="w-full" controls>
<source src={value.data} />
</audio>
<div class="input-audio" {theme}>
<div class="interpret_range flex">
{#each interpretation as interpret_value}
<div

View File

@ -1,16 +1,15 @@
<script lang="ts">
import { getSaliencyColor } from "../utils/helpers";
import { getSaliencyColor } from "../utils";
export let value: boolean;
export let original: boolean;
export let interpretation: [number, number];
export let theme: string;
export let style: string | null;
</script>
<div class="input-checkbox inline-block" {theme} {style}>
<div class="input-checkbox inline-block" {theme}>
<button
class="checkbox-item py-2 px-3 rounded cursor-pointer flex gap-1"
class:selected={value}
class:selected={original}
>
<div
class="checkbox w-4 h-4 bg-white flex items-center justify-center border border-gray-400 box-border"

View File

@ -1,19 +1,18 @@
<script lang="ts">
import { getSaliencyColor } from "../utils/helpers";
import { getSaliencyColor } from "../utils";
export let value: Array<string>;
export let original: Array<string>;
export let interpretation: Array<[number, number]>;
export let choices: Array<string>;
export let theme: string;
export let style: string | null;
</script>
<div class="input-checkbox-group flex flex-wrap gap-2" {theme} {style}>
<div class="input-checkbox-group flex flex-wrap gap-2" {theme}>
{#each choices as choice, i}
<button
class="checkbox-item py-2 px-3 font-semibold rounded cursor-pointer flex items-center gap-1"
class:selected={value.includes(choice)}
class:selected={original.includes(choice)}
>
<div
class="checkbox w-4 h-4 bg-white flex items-center justify-center border border-gray-400 box-border"

View File

@ -1,14 +1,13 @@
<script lang="ts">
import { getSaliencyColor } from "../utils/helpers";
import { getSaliencyColor } from "../utils";
export let value: string;
export let original: string;
export let interpretation: Array<number>;
export let theme: string;
export let style: string | null;
export let choices: Array<string>;
</script>
<div class="input-dropdown" {theme} {style}>
<div class="input-dropdown" {theme}>
<ul class="dropdown-menu">
{#each choices as choice, i}
<li

View File

@ -1,8 +1,8 @@
<script lang="ts">
import { getObjectFitSize, getSaliencyColor } from "../utils/helpers";
import { getSaliencyColor, getObjectFitSize } from "../utils";
import { afterUpdate } from "svelte";
export let value: string;
export let original: string;
export let interpretation: Array<Array<number>>;
export let shape: undefined | [number, number];
@ -70,7 +70,11 @@
<canvas bind:this={saliency_layer} />
</div>
<!-- svelte-ignore a11y-missing-attribute -->
<img class="w-full h-full object-contain" bind:this={image} src={value} />
<img
class="w-full h-full object-contain"
bind:this={image}
src={original}
/>
</div>
</div>

View File

@ -1,10 +1,9 @@
<script lang="ts">
import { getSaliencyColor } from "../utils/helpers";
import { getSaliencyColor } from "../utils";
export let value: string;
export let original: string;
export let interpretation: Array<[number, number]>;
export let theme: string;
export let style: string | null;
</script>
<div class="input-number">

View File

@ -1,18 +1,17 @@
<script lang="ts">
import { getSaliencyColor } from "../utils/helpers";
import { getSaliencyColor } from "../utils";
export let value: string;
export let original: string;
export let interpretation: Array<number>;
export let theme: string;
export let style: string | null;
export let choices: Array<string>;
</script>
<div class="input-radio flex flex-wrap gap-2" {theme} {style}>
<div class="input-radio flex flex-wrap gap-2" {theme}>
{#each choices as choice, i}
<button
class="radio-item py-2 px-3 font-semibold rounded cursor-pointer flex items-center gap-2"
class:selected={value === choice}
class:selected={original === choice}
>
<div
class="radio-circle w-4 h-4 rounded-full box-border"

View File

@ -1,21 +1,20 @@
<script lang="ts">
import { getSaliencyColor } from "../utils/helpers";
import { getSaliencyColor } from "../utils";
export let value: number;
export let original: number;
export let interpretation: Array<number>;
export let theme: string;
export let style: string | null;
export let minimum: number;
export let maximum: number;
export let step: number;
</script>
<div class="input-slider text-center" {theme} {style}>
<div class="input-slider text-center" {theme}>
<input
type="range"
class="range w-full appearance-none transition rounded h-4"
disabled
{value}
{original}
min={minimum}
max={maximum}
{step}
@ -28,7 +27,9 @@
/>
{/each}
</div>
<div class="value inline-block mx-auto mt-1 px-2 py-0.5 rounded">{value}</div>
<div class="original inline-block mx-auto mt-1 px-2 py-0.5 rounded">
{original}
</div>
</div>
<style lang="postcss">

View File

@ -1,16 +1,11 @@
<script lang="ts">
import { getSaliencyColor } from "../../utils/helpers";
import { getSaliencyColor } from "../utils";
export let interpretation: Array<[string, number]>;
export let theme: string;
export let style: string | null;
</script>
<div
class="input-text w-full rounded box-border p-2 break-word"
{theme}
{style}
>
<div class="input-text w-full rounded box-border p-2 break-word" {theme}>
{#each interpretation as [text, saliency]}
<span
class="textspan p-1 bg-opacity-20 dark:bg-opacity-80"

View File

@ -0,0 +1,21 @@
import InterpretationNumber from "./InterpretationComponents/Number.svelte";
import InterpretationDropdown from "./InterpretationComponents/Dropdown.svelte";
import InterpretationCheckbox from "./InterpretationComponents/Checkbox.svelte";
import InterpretationCheckboxGroup from "./InterpretationComponents/CheckboxGroup.svelte";
import InterpretationSlider from "./InterpretationComponents/Slider.svelte";
import InterpretationRadio from "./InterpretationComponents/Radio.svelte";
import InterpretationImage from "./InterpretationComponents/Image.svelte";
import InterpretationAudio from "./InterpretationComponents/Audio.svelte";
import InterpretationTextbox from "./InterpretationComponents/Textbox.svelte";
export const component_map = {
audio: InterpretationAudio,
dropdown: InterpretationDropdown,
checkbox: InterpretationCheckbox,
checkboxgroup: InterpretationCheckboxGroup,
number: InterpretationNumber,
slider: InterpretationSlider,
radio: InterpretationRadio,
image: InterpretationImage,
textbox: InterpretationTextbox
};

View File

@ -0,0 +1,2 @@
export { default as Component } from "./Interpretation.svelte";
export const modes = ["dynamic"];

View File

@ -0,0 +1,59 @@
export const getSaliencyColor = (value: number): string => {
var color: [number, number, number] | null = null;
if (value < 0) {
color = [52, 152, 219];
} else {
color = [231, 76, 60];
}
return colorToString(interpolate(Math.abs(value), [255, 255, 255], color));
};
const interpolate = (
val: number,
rgb1: [number, number, number],
rgb2: [number, number, number]
): [number, number, number] => {
if (val > 1) {
val = 1;
}
val = Math.sqrt(val);
var rgb: [number, number, number] = [0, 0, 0];
var i;
for (i = 0; i < 3; i++) {
rgb[i] = Math.round(rgb1[i] * (1.0 - val) + rgb2[i] * val);
}
return rgb;
};
const colorToString = (rgb: [number, number, number]): string => {
return "rgb(" + rgb[0] + ", " + rgb[1] + ", " + rgb[2] + ")";
};
export const getObjectFitSize = (
contains: boolean /* true = contain, false = cover */,
containerWidth: number,
containerHeight: number,
width: number,
height: number
) => {
var doRatio = width / height;
var cRatio = containerWidth / containerHeight;
var targetWidth = 0;
var targetHeight = 0;
var test = contains ? doRatio > cRatio : doRatio < cRatio;
if (test) {
targetWidth = containerWidth;
targetHeight = targetWidth / doRatio;
} else {
targetHeight = containerHeight;
targetWidth = targetHeight * doRatio;
}
return {
width: targetWidth,
height: targetHeight,
x: (containerWidth - targetWidth) / 2,
y: (containerHeight - targetHeight) / 2
};
};

View File

@ -21,6 +21,6 @@
$: value, dispatch("change");
</script>
{#if value !== undefined}
{#if value !== undefined && value !== null}
<Label {theme} {style} {value} />
{/if}

View File

@ -1,3 +1,2 @@
export { default as Component } from "./Number.svelte";
export { default as Interpretation } from "./Interpretation.svelte";
export const modes = ["static", "dynamic"];

View File

@ -1,3 +1,2 @@
export { default as Component } from "./Radio.svelte";
export { default as Interpretation } from "./Interpretation.svelte";
export const modes = ["static", "dynamic"];

View File

@ -1,7 +1,10 @@
<script lang="ts">
export let value: boolean;
export let style: string = "";
if ($$props.default_value) value = $$props.default_value;
</script>
<div {style} class="flex flex-row gap-4">
<div {style} class:hidden={!value} class="flex flex-row gap-4">
<slot />
</div>

View File

@ -1,3 +1,2 @@
export { default as Component } from "./Slider.svelte";
export { default as Interpretation } from "./Interpretation.svelte";
export const modes = ["static", "dynamic"];

View File

@ -1,4 +1,4 @@
export const component_map = {
export const component_map: Record<string, any> = {
audio: () => import("./Audio"),
button: () => import("./Button"),
carousel: () => import("./Carousel"),
@ -14,6 +14,7 @@ export const component_map = {
highlightedtext: () => import("./HighlightedText"),
html: () => import("./HTML"),
image: () => import("./Image"),
interpretation: () => import("./Interpretation"),
json: () => import("./Json"),
label: () => import("./Label"),
number: () => import("./Number"),

View File

@ -58,63 +58,3 @@ export const prettyBytes = (bytes: number): string => {
let unit = units[i];
return bytes.toFixed(1) + " " + unit;
};
export const getSaliencyColor = (value: number): string => {
var color: [number, number, number] | null = null;
if (value < 0) {
color = [52, 152, 219];
} else {
color = [231, 76, 60];
}
return colorToString(interpolate(Math.abs(value), [255, 255, 255], color));
};
const interpolate = (
val: number,
rgb1: [number, number, number],
rgb2: [number, number, number]
): [number, number, number] => {
if (val > 1) {
val = 1;
}
val = Math.sqrt(val);
var rgb: [number, number, number] = [0, 0, 0];
var i;
for (i = 0; i < 3; i++) {
rgb[i] = Math.round(rgb1[i] * (1.0 - val) + rgb2[i] * val);
}
return rgb;
};
const colorToString = (rgb: [number, number, number]): string => {
return "rgb(" + rgb[0] + ", " + rgb[1] + ", " + rgb[2] + ")";
};
export const getObjectFitSize = (
contains: boolean /* true = contain, false = cover */,
containerWidth: number,
containerHeight: number,
width: number,
height: number
) => {
var doRatio = width / height;
var cRatio = containerWidth / containerHeight;
var targetWidth = 0;
var targetHeight = 0;
var test = contains ? doRatio > cRatio : doRatio < cRatio;
if (test) {
targetWidth = containerWidth;
targetHeight = targetWidth / doRatio;
} else {
targetHeight = containerHeight;
targetWidth = targetHeight * doRatio;
}
return {
width: targetWidth,
height: targetHeight,
x: (containerWidth - targetWidth) / 2,
y: (containerHeight - targetHeight) / 2
};
};

View File

@ -21,6 +21,7 @@ interface Component {
}
interface Config {
auth_required: boolean | undefined;
allow_flagging: string;
allow_interpretation: boolean;
allow_screenshot: boolean;
@ -78,7 +79,7 @@ window.launchGradio = (config: Config, element_query: string) => {
style.innerHTML = config.css;
document.head.appendChild(style);
}
if (config.detail === "Not authenticated") {
if (config.detail === "Not authenticated" || config.auth_required) {
new Login({
target: target,
props: config
@ -123,7 +124,7 @@ window.launchGradioFromSpaces = async (space: string, target: string) => {
async function get_config() {
if (BUILD_MODE === "dev" || location.origin === "http://localhost:3000") {
let config = await fetch(BACKEND_URL + "/config");
let config = await fetch(BACKEND_URL + "config");
config = await config.json();
return config;
} else {