Enable multi-select on gradio.Dropdown (#2871)

* multiselect dropdown

* fixes

* more fixes

* changes

* changelog

* formatting

* format notebooks

* type fixes

* notebok fix

* remove console log

* notebook fix

* type fix

* Revert "format notebooks"

This reverts commit fb8762ecffbc727425d435262e60be4ea7feec6e.

* notebook fix

* bug fixes

* Update CHANGELOG.md

* Excluding untracked files from demo notebook check action (#2897)

* excluding untracked files from wget

* changelog

* fix setting default values

* typeability and arrow key support

* python types

* reformat

* another type check

* minor fixes + interactive false fix

* change remove token styling

* separate multiselect into separate file

* style fixes

* Update CHANGELOG.md

* Update CHANGELOG.md

* Update CHANGELOG.md

* some more style fixes

* small bug fix

* addressed pr comments

* fix active color highlighting

Co-authored-by: Ali Abdalla <ali.si3luwa@gmail.com>
Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
This commit is contained in:
Dawood Khan 2023-01-04 19:13:46 -05:00 committed by GitHub
parent c02001da7d
commit 9fff1e0fe8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 391 additions and 34 deletions

View File

@ -2,7 +2,7 @@
## New Features:
* Send custom progress updates by adding a `gr.Progress` argument after the input arguments to any function. Example:
### Send custom progress updates by adding a `gr.Progress` argument after the input arguments to any function. Example:
```python
def reverse(word, progress=gr.Progress()):
@ -21,6 +21,15 @@ Progress indicator bar by [@aliabid94](https://github.com/aliabid94) in [PR 2750
* Added `title` argument to `TabbedInterface` by @MohamedAliRashad in [#2888](https://github.com/gradio-app/gradio/pull/2888)
* Add support for specifying file extensions for `gr.File` and `gr.UploadButton`, using `file_types` parameter (e.g `gr.File(file_count="multiple", file_types=["text", ".json", ".csv"])`) by @dawoodkhan82 in [#2901](https://github.com/gradio-app/gradio/pull/2901)
* Added `multiselect` option to `Dropdown` by @dawoodkhan82 in [#2871](https://github.com/gradio-app/gradio/pull/2871)
### With `multiselect` set to `true` a user can now select multiple options from the `gr.Dropdown` component.
```python
gr.Dropdown(["angola", "pakistan", "canada"], multiselect=True, value=["angola"])
```
<img width="610" alt="Screenshot 2023-01-03 at 4 14 36 PM" src="https://user-images.githubusercontent.com/12725292/210442547-c86975c9-4b4f-4b8e-8803-9d96e6a8583a.png">
## Bug Fixes:
* Fixed bug where an error opening an audio file led to a crash by [@FelixDombek](https://github.com/FelixDombek) in [PR 2898](https://github.com/gradio-app/gradio/pull/2898)

View File

@ -1 +1 @@
{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: sentence_builder"]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "\n", "\n", "def sentence_builder(quantity, animal, place, activity_list, morning):\n", " return f\"\"\"The {quantity} {animal}s went to the {place} where they {\" and \".join(activity_list)} until the {\"morning\" if morning else \"night\"}\"\"\"\n", "\n", "\n", "demo = gr.Interface(\n", " sentence_builder,\n", " [\n", " gr.Slider(2, 20, value=4),\n", " gr.Dropdown([\"cat\", \"dog\", \"bird\"]),\n", " gr.Radio([\"park\", \"zoo\", \"road\"]),\n", " gr.CheckboxGroup([\"ran\", \"swam\", \"ate\", \"slept\"]),\n", " gr.Checkbox(label=\"Is it the morning?\"),\n", " ],\n", " \"text\",\n", " examples=[\n", " [2, \"cat\", \"park\", [\"ran\", \"swam\"], True],\n", " [4, \"dog\", \"zoo\", [\"ate\", \"swam\"], False],\n", " [10, \"bird\", \"road\", [\"ran\"], False],\n", " [8, \"cat\", \"zoo\", [\"ate\"], True],\n", " ],\n", ")\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: sentence_builder"]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "\n", "\n", "def sentence_builder(quantity, animal, place, activity_list, morning):\n", " return f\"\"\"The {quantity} {animal}s went to the {place} where they {\" and \".join(activity_list)} until the {\"morning\" if morning else \"night\"}\"\"\"\n", "\n", "\n", "demo = gr.Interface(\n", " sentence_builder,\n", " [\n", " gr.Slider(2, 20, value=4),\n", " gr.Dropdown([\"cat\", \"dog\", \"bird\"]),\n", " gr.Radio([\"park\", \"zoo\", \"road\"]),\n", " gr.Dropdown([\"ran\", \"swam\", \"ate\", \"slept\"], value=[\"swam\", \"slept\"], multiselect=True),\n", " gr.Checkbox(label=\"Is it the morning?\"),\n", " ],\n", " \"text\",\n", " examples=[\n", " [2, \"cat\", \"park\", [\"ran\", \"swam\"], True],\n", " [4, \"dog\", \"zoo\", [\"ate\", \"swam\"], False],\n", " [10, \"bird\", \"road\", [\"ran\"], False],\n", " [8, \"cat\", \"zoo\", [\"ate\"], True],\n", " ],\n", ")\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}

View File

@ -11,7 +11,7 @@ demo = gr.Interface(
gr.Slider(2, 20, value=4),
gr.Dropdown(["cat", "dog", "bird"]),
gr.Radio(["park", "zoo", "road"]),
gr.CheckboxGroup(["ran", "swam", "ate", "slept"]),
gr.Dropdown(["ran", "swam", "ate", "slept"], value=["swam", "slept"], multiselect=True),
gr.Checkbox(label="Is it the morning?"),
],
"text",

View File

@ -1167,9 +1167,9 @@ class Radio(
@document("change", "style")
class Dropdown(Radio):
class Dropdown(Changeable, IOComponent, SimpleSerializable, FormComponent):
"""
Creates a dropdown of which only one entry can be selected.
Creates a dropdown of choices from which entries can be selected.
Preprocessing: passes the value of the selected dropdown entry as a {str} or its index as an {int} into the function, depending on `type`.
Postprocessing: expects a {str} corresponding to the value of the dropdown entry to be selected.
Examples-format: a {str} representing the drop down value to select.
@ -1178,10 +1178,11 @@ class Dropdown(Radio):
def __init__(
self,
choices: List[str] | None = None,
choices: str | List[str] | None = None,
*,
value: str | Callable | None = None,
value: str | List[str] | Callable | None = None,
type: str = "value",
multiselect: bool | None = None,
label: str | None = None,
every: float | None = None,
show_label: bool = True,
@ -1193,8 +1194,9 @@ class Dropdown(Radio):
"""
Parameters:
choices: list of options to select from.
value: default value selected in dropdown. If None, no value is selected by default. If callable, the function will be called whenever the app loads to set the initial value of the component.
value: default value(s) selected in dropdown. If None, no value is selected by default. If callable, the function will be called whenever the app loads to set the initial value of the component.
type: Type of value to be returned by component. "value" returns the string of the choice selected, "index" returns the index of the choice selected.
multiselect: if True, multiple choices can be selected.
label: component name in interface.
every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. Queue must be enabled. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute.
show_label: if True, will display label.
@ -1202,19 +1204,109 @@ class Dropdown(Radio):
visible: If False, component will be hidden.
elem_id: An optional string that is assigned as the id of this component in the HTML DOM. Can be used for targeting CSS styles.
"""
Radio.__init__(
self.choices = choices or []
valid_types = ["value", "index"]
if type not in valid_types:
raise ValueError(
f"Invalid value for parameter `type`: {type}. Please choose from one of: {valid_types}"
)
self.type = type
self.multiselect = multiselect
if multiselect:
if isinstance(value, str):
value = [value]
self.test_input = self.choices[0] if len(self.choices) else None
self.interpret_by_tokens = False
IOComponent.__init__(
self,
value=value,
choices=choices,
type=type,
label=label,
every=every,
show_label=show_label,
interactive=interactive,
visible=visible,
elem_id=elem_id,
value=value,
**kwargs,
)
self.cleared_value = self.value
def get_config(self):
return {
"choices": self.choices,
"value": self.value,
"multiselect": self.multiselect,
**IOComponent.get_config(self),
}
@staticmethod
def update(
value: Any | Literal[_Keywords.NO_VALUE] | None = _Keywords.NO_VALUE,
choices: str | List[str] | None = None,
label: str | None = None,
show_label: bool | None = None,
interactive: bool | None = None,
visible: bool | None = None,
):
updated_config = {
"choices": choices,
"label": label,
"show_label": show_label,
"interactive": interactive,
"visible": visible,
"value": value,
"__type__": "update",
}
return IOComponent.add_interactive_to_config(updated_config, interactive)
def generate_sample(self):
return self.choices[0]
def preprocess(
self, x: str | List[str]
) -> str | int | List[str] | List[int] | None:
"""
Parameters:
x: selected choice(s)
Returns:
selected choice(s) as string or index within choice list or list of string or indices
"""
if self.type == "value":
return x
elif self.type == "index":
if x is None:
return None
elif self.multiselect:
return [self.choices.index(c) for c in x]
else:
if isinstance(x, str):
return self.choices.index(x)
else:
raise ValueError(
"Unknown type: "
+ str(self.type)
+ ". Please choose from: 'value', 'index'."
)
def set_interpret_parameters(self):
"""
Calculates interpretation score of each choice by comparing the output against each of the outputs when alternative choices are selected.
"""
return self
def get_interpretation_neighbors(self, x):
choices = list(self.choices)
choices.remove(x)
return choices, {}
def get_interpretation_scores(
self, x, neighbors, scores: List[float | None], **kwargs
) -> List:
"""
Returns:
Each value represents the interpretation score corresponding to each choice.
"""
scores.insert(self.choices.index(x), None)
return scores
def style(self, *, container: bool | None = None, **kwargs):
"""

View File

@ -554,6 +554,56 @@ class TestRadio:
assert scores == [-2.0, None, 2.0]
class TestDropdown:
def test_component_functions(self):
"""
Preprocess, postprocess, serialize, generate_sample, get_config
"""
dropdown_input = gr.Dropdown(["a", "b", "c"], multiselect=True)
assert dropdown_input.preprocess("a") == "a"
assert dropdown_input.postprocess("a") == "a"
dropdown_input_multiselect = gr.Dropdown(["a", "b", "c"], multiselect=True)
assert dropdown_input_multiselect.preprocess(["a", "c"]) == ["a", "c"]
assert dropdown_input_multiselect.postprocess(["a", "c"]) == ["a", "c"]
assert dropdown_input_multiselect.serialize(["a", "c"], True) == ["a", "c"]
assert isinstance(dropdown_input_multiselect.generate_sample(), str)
dropdown_input_multiselect = gr.Dropdown(
value=["a", "c"],
choices=["a", "b", "c"],
label="Select Your Inputs",
)
assert dropdown_input_multiselect.get_config() == {
"choices": ["a", "b", "c"],
"value": ["a", "c"],
"name": "dropdown",
"show_label": True,
"label": "Select Your Inputs",
"style": {},
"elem_id": None,
"visible": True,
"interactive": None,
"root_url": None,
"multiselect": None,
}
with pytest.raises(ValueError):
gr.Dropdown(["a"], type="unknown")
dropdown = gr.Dropdown(choices=["a", "b"], value="c")
assert dropdown.get_config()["value"] == "c"
assert dropdown.postprocess("a") == "a"
def test_in_interface(self):
"""
Interface, process
"""
checkboxes_input = gr.CheckboxGroup(["a", "b", "c"])
iface = gr.Interface(lambda x: "|".join(x), checkboxes_input, "textbox")
assert iface(["a", "c"]) == "a|c"
assert iface([]) == ""
_ = gr.CheckboxGroup(["a", "b", "c"], type="index")
class TestImage:
def test_component_functions(self):
"""

View File

@ -12,7 +12,7 @@
<div
id={elem_id}
class="overflow-hidden flex flex-col relative col {create_classes(style)}"
class="flex flex-col relative col {create_classes(style)}"
class:gap-4={style.gap !== false}
class:gr-compact={variant === "compact"}
class:gr-panel={variant === "panel"}

View File

@ -8,7 +8,8 @@
export let label: string = "Dropdown";
export let elem_id: string = "";
export let visible: boolean = true;
export let value: string = "";
export let value: string | Array<string> = [];
export let multiselect: boolean = false;
export let choices: Array<string>;
export let show_label: boolean;
export let style: Styles = {};
@ -27,6 +28,7 @@
<Dropdown
bind:value
{choices}
{multiselect}
{label}
{show_label}
on:change

View File

@ -1,5 +1,5 @@
<div
class="gr-form overflow-hidden flex border-solid border bg-gray-200 dark:bg-gray-700 gap-px rounded-lg flex-wrap"
class="gr-form flex border-solid border bg-gray-200 dark:bg-gray-700 gap-px rounded-lg flex-wrap"
style="flex-direction: inherit"
>
<slot />

View File

@ -45,9 +45,9 @@
data-testid={test_id}
id={elem_id}
class:!hidden={visible === false}
class="gr-block gr-box relative w-full overflow-hidden {styles[
variant
]} {styles[color]} {classes}"
class="gr-block gr-box relative w-full {styles[variant]} {styles[
color
]} {classes}"
class:gr-padded={padding}
style={size_style || null}
>

View File

@ -1,27 +1,36 @@
<script lang="ts">
import { createEventDispatcher } from "svelte";
import MultiSelect from "./MultiSelect.svelte";
import { BlockTitle } from "@gradio/atoms";
export let label: string;
export let value: string | undefined = undefined;
export let value: string | Array<string> | undefined = undefined;
export let multiselect: boolean = false;
export let choices: Array<string>;
export let disabled: boolean = false;
export let show_label: boolean;
const dispatch = createEventDispatcher<{ change: string }>();
$: dispatch("change", value);
</script>
<!-- svelte-ignore a11y-label-has-associated-control -->
<label>
<BlockTitle {show_label}>{label}</BlockTitle>
<select
class="gr-box gr-input w-full disabled:cursor-not-allowed"
bind:value
{disabled}
>
{#each choices as choice}
<option>{choice}</option>
{/each}
</select>
{#if !multiselect}
<select
class="gr-box gr-input w-full disabled:cursor-not-allowed"
bind:value
{disabled}
>
{#each choices as choice}
<option>{choice}</option>
{/each}
</select>
{:else}
<MultiSelect
bind:value
{choices}
{multiselect}
{label}
{show_label}
on:change
{disabled}
/>
{/if}
</label>

View File

@ -0,0 +1,195 @@
<script lang="ts">
import { createEventDispatcher } from "svelte";
import { fly } from "svelte/transition";
export let placeholder = "";
export let label: string;
export let value: string | Array<string> | undefined = undefined;
export let multiselect: boolean = false;
export let choices: Array<string>;
export let disabled: boolean = false;
export let show_label: boolean;
const dispatch = createEventDispatcher<{
change: string | Array<string> | undefined;
}>();
let inputValue: string,
activeOption: string,
showOptions = false;
$: filtered = choices.filter((o) =>
inputValue ? o.toLowerCase().includes(inputValue.toLowerCase()) : o
);
$: if (
(activeOption && !filtered.includes(activeOption)) ||
(!activeOption && inputValue)
)
activeOption = filtered[0];
const iconClearPath =
"M19 6.41L17.59 5 12 10.59 6.41 5 5 6.41 10.59 12 5 17.59 6.41 19 12 13.41 17.59 19 19 17.59 13.41 12z";
const iconCheckMarkPath =
"M 2.328125 4.222656 L 27.734375 4.222656 L 27.734375 24.542969 L 2.328125 24.542969 Z M 2.328125 4.222656";
function add(option: string) {
if (Array.isArray(value)) {
value.push(option);
dispatch("change", value);
}
value = value;
}
function remove(option: string) {
if (Array.isArray(value)) {
value = value.filter((v: string) => v !== option);
dispatch("change", value);
}
}
function optionsVisibility(show: boolean) {
if (typeof show === "boolean") {
showOptions = show;
} else {
showOptions = !showOptions;
}
}
function handleBlur(e: any) {
optionsVisibility(false);
}
function handleKeyup(e: any) {
if (e.key === "Enter") {
if (Array.isArray(value) && activeOption != undefined) {
value.includes(activeOption) ? remove(activeOption) : add(activeOption);
inputValue = "";
}
}
if (e.key === "ArrowUp" || e.key === "ArrowDown") {
const increment = e.key === "ArrowUp" ? -1 : 1;
const calcIndex = filtered.indexOf(activeOption) + increment;
activeOption =
calcIndex < 0
? filtered[filtered.length - 1]
: calcIndex === filtered.length
? filtered[0]
: filtered[calcIndex];
}
}
function handleTokenClick(e: any) {
e.preventDefault();
if (e.target.closest(".token-remove")) {
e.stopPropagation();
remove(
e.target.closest(".token").getElementsByTagName("span")[0].textContent
);
} else if (e.target.closest(".remove-all")) {
value = [];
inputValue = "";
} else {
optionsVisibility(true);
}
}
function handleOptionMousedown(e: any) {
const option = e.target.dataset.value;
inputValue = "";
if (option !== undefined) {
if (value?.includes(option)) {
remove(option);
} else {
add(option);
}
}
}
</script>
<div class="relative border rounded-md">
<div
class="items-center flex flex-wrap relative"
class:showOptions
on:click={handleTokenClick}
>
{#if Array.isArray(value)}
{#each value as s}
<div
class="token gr-input-label flex items-center text-gray-700 text-sm space-x-2 border py-1.5 px-3 rounded-lg cursor-pointer bg-white shadow-sm checked:shadow-inner my-1 mx-1"
>
<div
class:hidden={disabled}
class="token-remove items-center bg-gray-400 dark:bg-gray-700 rounded-full fill-white flex justify-center min-w-min p-0.5"
title="Remove {s}"
>
<svg
class="icon-clear"
xmlns="http://www.w3.org/2000/svg"
width="16"
height="16"
viewBox="0 0 24 24"
>
<path d={iconClearPath} />
</svg>
</div>
<span>{s}</span>
</div>
{/each}
{/if}
<div class="items-center flex flex-1 min-w-min border-none">
<input
class="border-none bg-inherit ml-2 text-lg w-full outline-none text-gray-700 dark:text-white disabled:cursor-not-allowed"
{disabled}
autocomplete="off"
bind:value={inputValue}
on:blur={handleBlur}
on:keyup={handleKeyup}
{placeholder}
/>
<div
class:hidden={!value?.length || disabled}
class="remove-all items-center bg-gray-400 dark:bg-gray-700 rounded-full fill-white flex justify-center h-5 ml-1 min-w-min disabled:hidden p-0.5"
title="Remove All"
>
<svg
class="icon-clear"
xmlns="http://www.w3.org/2000/svg"
width="16"
height="16"
viewBox="0 0 24 24"
>
<path d={iconClearPath} />
</svg>
</div>
<svg
class="dropdown-arrow mr-2 min-w-min fill-gray-500 dark:fill-white"
xmlns="http://www.w3.org/2000/svg"
width="18"
height="18"
viewBox="0 0 18 18"><path d="M5 8l4 4 4-4z" /></svg
>
</div>
</div>
{#if showOptions && !disabled}
<ul
class="z-50 shadow ml-0 list-none max-h-32 overflow-auto absolute w-full fill-gray-500 bg-white dark:bg-gray-700 dark:text-white rounded-md"
transition:fly={{ duration: 200, y: 5 }}
on:mousedown|preventDefault={handleOptionMousedown}
>
{#each filtered as choice}
<li
class="cursor-pointer flex p-2 hover:bg-gray-100 dark:hover:bg-gray-600"
class:selected={value?.includes(choice)}
class:active={activeOption === choice}
class:bg-gray-100={activeOption === choice}
class:dark:bg-gray-600={activeOption === choice}
data-value={choice}
>
<span class:invisible={!value?.includes(choice)} class="pr-1"></span>
{choice}
</li>
{/each}
</ul>
{/if}
</div>