Allow gr.Templates to take in arguments (#2600)

* template customization

* changelog

* fixes

* import

* added test

* formatting

* explicit parameters

* updated changelog

* fix typing

* fix test
This commit is contained in:
Abubakar Abid 2022-11-04 09:08:17 -07:00 committed by GitHub
parent 85e5fd0f62
commit 218fb9fa65
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 459 additions and 80 deletions

View File

@ -16,7 +16,7 @@ No changes to highlight.
No changes to highlight.
## Full Changelog:
No changes to highlight.
* Allow `gr.Templates` to accept parameters to override the defaults by [@abidlabs](https://github.com/abidlabs) in [PR 2600](https://github.com/gradio-app/gradio/pull/2600)
## Contributors Shoutout:
No changes to highlight.
@ -57,6 +57,7 @@ No changes to highlight.
## Full Changelog:
* Add `api_name` to `Blocks.__call__` by [@freddyaboulton](https://github.com/freddyaboulton) in [PR 2593](https://github.com/gradio-app/gradio/pull/2593)
## Contributors Shoutout:
No changes to highlight.

View File

@ -23,6 +23,7 @@ from gradio.components import (
Dropdown,
File,
Gallery,
Highlight,
Highlightedtext,
HighlightedText,
Image,
@ -37,6 +38,7 @@ from gradio.components import (
Slider,
State,
StatusTracker,
Text,
Textbox,
TimeSeries,
Timeseries,
@ -60,7 +62,6 @@ from gradio.mix import Parallel, Series
from gradio.routes import mount_gradio_app
from gradio.templates import (
Files,
Highlight,
ImageMask,
ImagePaint,
List,
@ -72,7 +73,6 @@ from gradio.templates import (
Pil,
PlayableVideo,
Sketchpad,
Text,
TextArea,
Webcam,
)

View File

@ -3980,8 +3980,10 @@ def get_component_instance(comp: str | dict | Component, render=True) -> Compone
)
Text = Textbox
DataFrame = Dataframe
Highlightedtext = HighlightedText
Highlight = HighlightedText
Checkboxgroup = CheckboxGroup
TimeSeries = Timeseries
Json = JSON

View File

@ -1,17 +1,14 @@
from __future__ import annotations
import typing
from typing import Any, Callable, Optional, Tuple
import numpy as np
import PIL
from gradio import components
class Text(components.Textbox):
"""
Sets: lines=1
"""
is_template = True
def __init__(self, **kwargs):
super().__init__(lines=1, **kwargs)
class TextArea(components.Textbox):
"""
Sets: lines=7
@ -19,73 +16,256 @@ class TextArea(components.Textbox):
is_template = True
def __init__(self, **kwargs):
super().__init__(lines=7, **kwargs)
def __init__(
self,
value: Optional[str | Callable] = "",
*,
lines: int = 7,
max_lines: int = 20,
placeholder: Optional[str] = None,
label: Optional[str] = None,
show_label: bool = True,
interactive: Optional[bool] = None,
visible: bool = True,
elem_id: Optional[str] = None,
**kwargs,
):
super().__init__(
value=value,
lines=lines,
max_lines=max_lines,
placeholder=placeholder,
label=label,
show_label=show_label,
interactive=interactive,
visible=visible,
elem_id=elem_id,
**kwargs,
)
class Webcam(components.Image):
"""
Sets: source="webcam"
Sets: source="webcam", interactive=True
"""
is_template = True
def __init__(self, **kwargs):
super().__init__(source="webcam", interactive=True, **kwargs)
def __init__(
self,
value: Optional[str | PIL.Image | np.narray] = None,
*,
shape: Tuple[int, int] = None,
image_mode: str = "RGB",
invert_colors: bool = False,
source: str = "webcam",
tool: str = None,
type: str = "numpy",
label: Optional[str] = None,
show_label: bool = True,
interactive: Optional[bool] = True,
visible: bool = True,
streaming: bool = False,
elem_id: Optional[str] = None,
mirror_webcam: bool = True,
**kwargs,
):
super().__init__(
value=value,
shape=shape,
image_mode=image_mode,
invert_colors=invert_colors,
source=source,
tool=tool,
type=type,
label=label,
show_label=show_label,
interactive=interactive,
visibile=visible,
streaming=streaming,
elem_id=elem_id,
mirror_webcam=mirror_webcam,
**kwargs,
)
class Sketchpad(components.Image):
"""
Sets: image_mode="L", source="canvas", shape=(28, 28), invert_colors=True
Sets: image_mode="L", source="canvas", shape=(28, 28), invert_colors=True, interactive=True
"""
is_template = True
def __init__(self, **kwargs):
def __init__(
self,
value: Optional[str | PIL.Image | np.narray] = None,
*,
shape: Tuple[int, int] = (28, 28),
image_mode: str = "L",
invert_colors: bool = True,
source: str = "canvas",
tool: str = None,
type: str = "numpy",
label: Optional[str] = None,
show_label: bool = True,
interactive: Optional[bool] = True,
visible: bool = True,
streaming: bool = False,
elem_id: Optional[str] = None,
mirror_webcam: bool = True,
**kwargs,
):
super().__init__(
image_mode="L",
source="canvas",
shape=(28, 28),
invert_colors=True,
interactive=True,
**kwargs
value=value,
shape=shape,
image_mode=image_mode,
invert_colors=invert_colors,
source=source,
tool=tool,
type=type,
label=label,
show_label=show_label,
interactive=interactive,
visibile=visible,
streaming=streaming,
elem_id=elem_id,
mirror_webcam=mirror_webcam,
**kwargs,
)
class Paint(components.Image):
"""
Sets: source="canvas", tool="color-sketch"
Sets: source="canvas", tool="color-sketch", interactive=True
"""
is_template = True
def __init__(self, **kwargs):
def __init__(
self,
value: Optional[str | PIL.Image | np.narray] = None,
*,
shape: Tuple[int, int] = None,
image_mode: str = "RGB",
invert_colors: bool = False,
source: str = "canvas",
tool: str = "color-sketch",
type: str = "numpy",
label: Optional[str] = None,
show_label: bool = True,
interactive: Optional[bool] = True,
visible: bool = True,
streaming: bool = False,
elem_id: Optional[str] = None,
mirror_webcam: bool = True,
**kwargs,
):
super().__init__(
source="canvas", tool="color-sketch", interactive=True, **kwargs
value=value,
shape=shape,
image_mode=image_mode,
invert_colors=invert_colors,
source=source,
tool=tool,
type=type,
label=label,
show_label=show_label,
interactive=interactive,
visibile=visible,
streaming=streaming,
elem_id=elem_id,
mirror_webcam=mirror_webcam,
**kwargs,
)
class ImageMask(components.Image):
"""
Sets: source="canvas", tool="sketch"
Sets: source="upload", tool="sketch", interactive=True
"""
is_template = True
def __init__(self, **kwargs):
super().__init__(source="upload", tool="sketch", interactive=True, **kwargs)
def __init__(
self,
value: Optional[str | PIL.Image | np.narray] = None,
*,
shape: Tuple[int, int] = None,
image_mode: str = "RGB",
invert_colors: bool = False,
source: str = "upload",
tool: str = "sketch",
type: str = "numpy",
label: Optional[str] = None,
show_label: bool = True,
interactive: Optional[bool] = True,
visible: bool = True,
streaming: bool = False,
elem_id: Optional[str] = None,
mirror_webcam: bool = True,
**kwargs,
):
super().__init__(
value=value,
shape=shape,
image_mode=image_mode,
invert_colors=invert_colors,
source=source,
tool=tool,
type=type,
label=label,
show_label=show_label,
interactive=interactive,
visibile=visible,
streaming=streaming,
elem_id=elem_id,
mirror_webcam=mirror_webcam,
**kwargs,
)
class ImagePaint(components.Image):
"""
Sets: source="upload", tool="color-sketch"
Sets: source="upload", tool="color-sketch", interactive=True
"""
is_template = True
def __init__(self, **kwargs):
def __init__(
self,
value: Optional[str | PIL.Image | np.narray] = None,
*,
shape: Tuple[int, int] = None,
image_mode: str = "RGB",
invert_colors: bool = False,
source: str = "upload",
tool: str = "color-sketch",
type: str = "numpy",
label: Optional[str] = None,
show_label: bool = True,
interactive: Optional[bool] = True,
visible: bool = True,
streaming: bool = False,
elem_id: Optional[str] = None,
mirror_webcam: bool = True,
**kwargs,
):
super().__init__(
source="upload", tool="color-sketch", interactive=True, **kwargs
value=value,
shape=shape,
image_mode=image_mode,
invert_colors=invert_colors,
source=source,
tool=tool,
type=type,
label=label,
show_label=show_label,
interactive=interactive,
visibile=visible,
streaming=streaming,
elem_id=elem_id,
mirror_webcam=mirror_webcam,
**kwargs,
)
@ -96,8 +276,42 @@ class Pil(components.Image):
is_template = True
def __init__(self, **kwargs):
super().__init__(type="pil", **kwargs)
def __init__(
self,
value: Optional[str | PIL.Image | np.narray] = None,
*,
shape: Tuple[int, int] = None,
image_mode: str = "RGB",
invert_colors: bool = False,
source: str = "upload",
tool: str = None,
type: str = "pil",
label: Optional[str] = None,
show_label: bool = True,
interactive: Optional[bool] = None,
visible: bool = True,
streaming: bool = False,
elem_id: Optional[str] = None,
mirror_webcam: bool = True,
**kwargs,
):
super().__init__(
value=value,
shape=shape,
image_mode=image_mode,
invert_colors=invert_colors,
source=source,
tool=tool,
type=type,
label=label,
show_label=show_label,
interactive=interactive,
visibile=visible,
streaming=streaming,
elem_id=elem_id,
mirror_webcam=mirror_webcam,
**kwargs,
)
class PlayableVideo(components.Video):
@ -107,8 +321,32 @@ class PlayableVideo(components.Video):
is_template = True
def __init__(self, **kwargs):
super().__init__(format="mp4", **kwargs)
def __init__(
self,
value: Optional[str | Callable] = None,
*,
format: Optional[str] = "mp4",
source: str = "upload",
label: Optional[str] = None,
show_label: bool = True,
interactive: Optional[bool] = None,
visible: bool = True,
elem_id: Optional[str] = None,
mirror_webcam: bool = True,
**kwargs,
):
super().__init__(
value=value,
format=format,
source=source,
label=label,
show_label=show_label,
interactive=interactive,
visible=visible,
elem_id=elem_id,
mirror_webcam=mirror_webcam,
**kwargs,
)
class Microphone(components.Audio):
@ -118,19 +356,32 @@ class Microphone(components.Audio):
is_template = True
def __init__(self, **kwargs):
super().__init__(source="microphone", **kwargs)
class Mic(components.Audio):
"""
Sets: source="microphone"
"""
is_template = True
def __init__(self, **kwargs):
super().__init__(source="microphone", **kwargs)
def __init__(
self,
value: Optional[str | Tuple[int, np.array] | Callable] = None,
*,
source: str = "microphone",
type: str = "numpy",
label: Optional[str] = None,
show_label: bool = True,
interactive: Optional[bool] = None,
visible: bool = True,
streaming: bool = False,
elem_id: Optional[str] = None,
**kwargs,
):
super().__init__(
value=value,
source=source,
type=type,
label=label,
show_label=show_label,
interactive=interactive,
visible=visible,
streaming=streaming,
elem_id=elem_id,
**kwargs,
)
class Files(components.File):
@ -140,8 +391,30 @@ class Files(components.File):
is_template = True
def __init__(self, **kwargs):
super().__init__(file_count="multiple", **kwargs)
def __init__(
self,
value: Optional[str | typing.List[str] | Callable] = None,
*,
file_count: str = "multiple",
type: str = "file",
label: Optional[str] = None,
show_label: bool = True,
interactive: Optional[bool] = None,
visible: bool = True,
elem_id: Optional[str] = None,
**kwargs,
):
super().__init__(
value=value,
file_count=file_count,
type=type,
label=label,
show_label=show_label,
interactive=interactive,
visible=visible,
elem_id=elem_id,
**kwargs,
)
class Numpy(components.Dataframe):
@ -151,8 +424,44 @@ class Numpy(components.Dataframe):
is_template = True
def __init__(self, **kwargs):
super().__init__(type="numpy", **kwargs)
def __init__(
self,
value: Optional[typing.List[typing.List[Any]] | Callable] = None,
*,
headers: Optional[typing.List[str]] = None,
row_count: int | Tuple[int, str] = (1, "dynamic"),
col_count: Optional[int | Tuple[int, str]] = None,
datatype: str | typing.List[str] = "str",
type: str = "numpy",
max_rows: Optional[int] = 20,
max_cols: Optional[int] = None,
overflow_row_behaviour: str = "paginate",
label: Optional[str] = None,
show_label: bool = True,
interactive: Optional[bool] = None,
visible: bool = True,
elem_id: Optional[str] = None,
wrap: bool = False,
**kwargs,
):
super().__init__(
value=value,
headers=headers,
row_count=row_count,
col_count=col_count,
datatype=datatype,
type=type,
max_rows=max_rows,
max_cols=max_cols,
overflow_row_behaviour=overflow_row_behaviour,
label=label,
show_label=show_label,
interactive=interactive,
visible=visible,
elem_id=elem_id,
wrap=wrap,
**kwargs,
)
class Matrix(components.Dataframe):
@ -162,35 +471,91 @@ class Matrix(components.Dataframe):
is_template = True
def __init__(self, **kwargs):
"""
Custom component
@param kwargs:
"""
super().__init__(type="array", **kwargs)
def __init__(
self,
value: Optional[typing.List[typing.List[Any]] | Callable] = None,
*,
headers: Optional[typing.List[str]] = None,
row_count: int | Tuple[int, str] = (1, "dynamic"),
col_count: Optional[int | Tuple[int, str]] = None,
datatype: str | typing.List[str] = "str",
type: str = "array",
max_rows: Optional[int] = 20,
max_cols: Optional[int] = None,
overflow_row_behaviour: str = "paginate",
label: Optional[str] = None,
show_label: bool = True,
interactive: Optional[bool] = None,
visible: bool = True,
elem_id: Optional[str] = None,
wrap: bool = False,
**kwargs,
):
super().__init__(
value=value,
headers=headers,
row_count=row_count,
col_count=col_count,
datatype=datatype,
type=type,
max_rows=max_rows,
max_cols=max_cols,
overflow_row_behaviour=overflow_row_behaviour,
label=label,
show_label=show_label,
interactive=interactive,
visible=visible,
elem_id=elem_id,
wrap=wrap,
**kwargs,
)
class List(components.Dataframe):
"""
Sets: type="array"
Sets: type="array", col_count=1
"""
is_template = True
def __init__(self, **kwargs):
"""
Custom component
@param kwargs:
"""
super().__init__(type="array", col_count=1, **kwargs)
def __init__(
self,
value: Optional[typing.List[typing.List[Any]] | Callable] = None,
*,
headers: Optional[typing.List[str]] = None,
row_count: int | Tuple[int, str] = (1, "dynamic"),
col_count: Optional[int | Tuple[int, str]] = 1,
datatype: str | typing.List[str] = "str",
type: str = "array",
max_rows: Optional[int] = 20,
max_cols: Optional[int] = None,
overflow_row_behaviour: str = "paginate",
label: Optional[str] = None,
show_label: bool = True,
interactive: Optional[bool] = None,
visible: bool = True,
elem_id: Optional[str] = None,
wrap: bool = False,
**kwargs,
):
super().__init__(
value=value,
headers=headers,
row_count=row_count,
col_count=col_count,
datatype=datatype,
type=type,
max_rows=max_rows,
max_cols=max_cols,
overflow_row_behaviour=overflow_row_behaviour,
label=label,
show_label=show_label,
interactive=interactive,
visible=visible,
elem_id=elem_id,
wrap=wrap,
**kwargs,
)
class Highlight(components.HighlightedText):
is_template = True
def __init__(self, **kwargs):
"""
Custom component
@param kwargs:
"""
super().__init__(**kwargs)
Mic = Microphone

View File

@ -33,7 +33,7 @@ class TestComponent(unittest.TestCase):
"""
component
"""
assert isinstance(gr.components.component("text"), gr.templates.Text)
assert isinstance(gr.components.component("textarea"), gr.templates.TextArea)
def test_raise_warnings():
@ -165,6 +165,17 @@ class TestTextbox(unittest.TestCase):
component = gr.Textbox("abc")
self.assertEqual(component.get_config().get("value"), "abc")
def test_override_template(self):
"""
override template
"""
component = gr.TextArea(value="abc")
self.assertEqual(component.get_config().get("value"), "abc")
self.assertEqual(component.get_config().get("lines"), 7)
component = gr.TextArea(value="abc", lines=4)
self.assertEqual(component.get_config().get("value"), "abc")
self.assertEqual(component.get_config().get("lines"), 4)
class TestNumber(unittest.TestCase):
def test_component_functions(self):