Refactor component shortcuts (#995)

* custom-components
- create template components
- changes in PKG and requires comes from scripts/install_gradio.sh

* custom-components
- tweaks

* update-components
- tweaks

* update-components
- fix get_block_name

* update-components
- add webcam demo

* custom-components
- make use of get_block_name function whenever possible

* custom-components
- tweaks

* refactor-component-shortcuts
- no description whatsoever :D

* refactor-component-shortcuts
- tweaks

* refactor-component-shortcuts
- create shortcut function "component"

* refactor-component-shortcuts
- reformat

* refactor-component-shortcuts
- tweaks

* refactor-component-shortcuts
- tweaks
This commit is contained in:
Ömer Faruk Özdemir 2022-04-15 00:24:14 +03:00 committed by GitHub
parent ad75b06f9a
commit 8464aa7258
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 85 additions and 149 deletions

View File

@ -1,12 +1,11 @@
import gradio as gr
from gradio import Templates
def greet(name):
return "Hello " + name + "!!"
demo = gr.Interface(fn=greet, inputs=Templates.TextArea(), outputs=Templates.TextArea())
demo = gr.Interface(fn=greet, inputs=gr.component("textarea"), outputs=gr.component("textarea"))
if __name__ == "__main__":
demo.launch()

View File

@ -7,8 +7,7 @@ from gradio import Templates
def snap(image):
return np.flipud(image)
demo = gr.Interface(snap, Templates.Webcam(), gr.Image())
demo = gr.Interface(snap, gr.component("webcam"), gr.component("image"))
if __name__ == "__main__":
demo.launch()

View File

@ -30,6 +30,7 @@ from gradio.components import (
Timeseries,
Variable,
Video,
component,
)
from gradio.flagging import (
CSVLogger,

View File

@ -70,13 +70,6 @@ class Component(Block):
"interactive": self.interactive,
}
@classmethod
def get_shortcut_implementations(cls):
"""
Return dictionary of shortcut implementations
"""
return {}
def save_flagged(
self, dir: str, label: Optional[str], data: Any, encryption_key: bool
) -> Any:
@ -131,12 +124,25 @@ class Component(Block):
return {"name": file, "data": data}
@classmethod
def get_all_shortcut_implementations(cls):
shortcuts = {}
def get_component_shortcut(cls, str_shortcut: str) -> Optional[Component]:
"""
Creates a component, where class name equals to str_shortcut.
@param str_shortcut: string shortcut of a component
@return:
True, found_class or
False, None
"""
# Make it suitable with class names
str_shortcut = str_shortcut.replace("_", "")
for sub_cls in cls.__subclasses__():
for shortcut, parameters in sub_cls.get_shortcut_implementations().items():
shortcuts[shortcut] = (sub_cls, parameters)
return shortcuts
if sub_cls.__name__.lower() == str_shortcut:
return sub_cls()
# For template components
for sub_sub_cls in sub_cls.__subclasses__():
if sub_sub_cls.__name__.lower() == str_shortcut:
return sub_sub_cls()
return None
# Input Functionalities
def preprocess(self, x: Any) -> Any:
@ -276,13 +282,6 @@ class Textbox(Component):
**super().get_template_context(),
}
@classmethod
def get_shortcut_implementations(cls):
return {
"text": {},
"textbox": {"lines": 7},
}
# Input Functionalities
def preprocess(self, x: str | None) -> Any:
"""
@ -437,12 +436,6 @@ class Number(Component):
def get_template_context(self):
return {"default_value": self.default_value, **super().get_template_context()}
@classmethod
def get_shortcut_implementations(cls):
return {
"number": {},
}
def preprocess(self, x: float | None) -> Optional[float]:
"""
Parameters:
@ -590,12 +583,6 @@ class Slider(Component):
**super().get_template_context(),
}
@classmethod
def get_shortcut_implementations(cls):
return {
"slider": {},
}
def preprocess(self, x: float) -> float:
"""
Parameters:
@ -694,12 +681,6 @@ class Checkbox(Component):
def get_template_context(self):
return {"default_value": self.default_value, **super().get_template_context()}
@classmethod
def get_shortcut_implementations(cls):
return {
"checkbox": {},
}
def preprocess(self, x: bool) -> bool:
"""
Parameters:
@ -1095,21 +1076,6 @@ class Image(Component):
label=label, requires_permissions=requires_permissions, **kwargs
)
@classmethod
def get_shortcut_implementations(cls):
return {
"image": {},
"webcam": {"source": "webcam"},
"sketchpad": {
"image_mode": "L",
"source": "canvas",
"shape": (28, 28),
"invert_colors": True,
},
"plot": {"type": "plot"},
"pil": {"type": "pil"},
}
def get_template_context(self):
return {
"image_mode": self.image_mode,
@ -1410,13 +1376,6 @@ class Video(Component):
self.source = source
super().__init__(label=label, css=css, **kwargs)
@classmethod
def get_shortcut_implementations(cls):
return {
"video": {},
"playable_video": {"type": "mp4"},
}
def get_template_context(self):
return {
"source": self.source,
@ -1592,14 +1551,6 @@ class Audio(Component):
**super().get_template_context(),
}
@classmethod
def get_shortcut_implementations(cls):
return {
"audio": {},
"microphone": {"source": "microphone"},
"mic": {"source": "microphone"},
}
def preprocess_example(self, x):
return {"name": x, "data": None, "is_example": True}
@ -1899,13 +1850,6 @@ class File(Component):
**super().get_template_context(),
}
@classmethod
def get_shortcut_implementations(cls):
return {
"file": {},
"files": {"file_count": "multiple"},
}
def preprocess_example(self, x):
return {"name": x, "data": None, "is_example": True}
@ -2081,15 +2025,6 @@ class Dataframe(Component):
**super().get_template_context(),
}
@classmethod
def get_shortcut_implementations(cls):
return {
"dataframe": {"type": "pandas"},
"numpy": {"type": "numpy"},
"matrix": {"type": "array"},
"list": {"type": "array", "col_count": 1},
}
def preprocess(self, x: List[List[str | Number | bool]]):
"""
Parameters:
@ -2218,12 +2153,6 @@ class Timeseries(Component):
**super().get_template_context(),
}
@classmethod
def get_shortcut_implementations(cls):
return {
"timeseries": {},
}
def preprocess_example(self, x):
return {"name": x, "is_example": True}
@ -2305,12 +2234,6 @@ class Variable(Component):
def get_template_context(self):
return {"default_value": self.default_value, **super().get_template_context()}
@classmethod
def get_shortcut_implementations(cls):
return {
"state": {},
}
# Only Output Components
class Label(Component):
@ -2392,12 +2315,6 @@ class Label(Component):
return y
raise ValueError("Unable to deserialize output: {}".format(y))
@classmethod
def get_shortcut_implementations(cls):
return {
"label": {},
}
def save_flagged(self, dir, label, data, encryption_key):
"""
Returns: (Union[str, Dict[str, number]]): Either a string representing the main category label, or a dictionary with category keys mapping to confidence levels.
@ -2494,12 +2411,6 @@ class HighlightedText(Component):
**super().get_template_context(),
}
@classmethod
def get_shortcut_implementations(cls):
return {
"highlight": {},
}
def postprocess(self, y):
"""
Parameters:
@ -2568,12 +2479,6 @@ class JSON(Component):
else:
return y
@classmethod
def get_shortcut_implementations(cls):
return {
"json": {},
}
def save_flagged(self, dir, label, data, encryption_key):
return json.dumps(data)
@ -2628,12 +2533,6 @@ class HTML(Component):
"""
return x
@classmethod
def get_shortcut_implementations(cls):
return {
"html": {},
}
def change(self, fn: Callable, inputs: List[Component], outputs: List[Component]):
"""
Parameters:
@ -2762,12 +2661,6 @@ class Chatbot(Component):
def get_template_context(self):
return {"default_value": self.default_value, **super().get_template_context()}
@classmethod
def get_shortcut_implementations(cls):
return {
"chatbot": {},
}
def postprocess(self, y):
"""
Parameters:
@ -2958,10 +2851,29 @@ class Interpretation(Component):
}
def component(str_shortcut: str) -> (bool, Optional[Component]):
"""
Creates a component, where class name equals to str_shortcut.
@param str_shortcut: string shortcut of a component
@return:
True, found_class or
False, None
"""
component = Component.get_component_shortcut(str_shortcut)
if component is None:
raise ValueError(f"No such component: {str_shortcut}")
else:
return component
def get_component_instance(comp: str | dict | Component):
if isinstance(comp, str):
shortcut = Component.get_all_shortcut_implementations()[comp]
return shortcut[0](**shortcut[1], without_rendering=True)
component = Component.get_component_shortcut(comp)
if component is None:
raise ValueError(f"No such component: {comp}")
else:
return component
elif isinstance(
comp, dict
): # a dict with `name` as the input component type and other keys as parameters

View File

@ -1,12 +1,17 @@
from gradio.components import Audio as C_Audio
from gradio.components import Dataframe as C_Dataframe
from gradio.components import File as C_File
from gradio.components import Image as C_Image
from gradio.components import Textbox as C_Textbox
from gradio.components import Video as C_Video
from gradio import components
class TextArea(C_Textbox):
class Text(components.Textbox):
def __init__(self, **kwargs):
"""
Custom component
@param kwargs:
"""
self.is_template = True
super().__init__(lines=1, **kwargs)
class TextArea(components.Textbox):
def __init__(self, **kwargs):
"""
Custom component
@ -16,7 +21,7 @@ class TextArea(C_Textbox):
super().__init__(lines=7, **kwargs)
class Webcam(C_Image):
class Webcam(components.Image):
def __init__(self, **kwargs):
"""
Custom component
@ -26,7 +31,7 @@ class Webcam(C_Image):
super().__init__(source="webcam", **kwargs)
class Sketchpad(C_Image):
class Sketchpad(components.Image):
def __init__(self, **kwargs):
"""
Custom component
@ -42,7 +47,7 @@ class Sketchpad(C_Image):
)
class Plot(C_Image):
class Plot(components.Image):
def __init__(self, **kwargs):
"""
Custom component
@ -52,7 +57,7 @@ class Plot(C_Image):
super().__init__(type="plot", **kwargs)
class Pil(C_Image):
class Pil(components.Image):
def __init__(self, **kwargs):
"""
Custom component
@ -62,7 +67,7 @@ class Pil(C_Image):
super().__init__(type="pil", **kwargs)
class PlayableVideo(C_Video):
class PlayableVideo(components.Video):
def __init__(self, **kwargs):
"""
Custom component
@ -72,7 +77,7 @@ class PlayableVideo(C_Video):
super().__init__(type="mp4", **kwargs)
class Microphone(C_Audio):
class Microphone(components.Audio):
def __init__(self, **kwargs):
"""
Custom component
@ -82,7 +87,17 @@ class Microphone(C_Audio):
super().__init__(source="microphone", **kwargs)
class C_Files(C_File):
class Mic(components.Audio):
def __init__(self, **kwargs):
"""
Custom component
@param kwargs:
"""
self.is_template = True
super().__init__(source="microphone", **kwargs)
class Files(components.File):
def __init__(self, **kwargs):
"""
Custom component
@ -92,7 +107,7 @@ class C_Files(C_File):
super().__init__(file_count="multiple", **kwargs)
class Numpy(C_Dataframe):
class Numpy(components.Dataframe):
def __init__(self, **kwargs):
"""
Custom component
@ -102,7 +117,7 @@ class Numpy(C_Dataframe):
super().__init__(type="numpy", **kwargs)
class Matrix(C_Dataframe):
class Matrix(components.Dataframe):
def __init__(self, **kwargs):
"""
Custom component
@ -112,7 +127,7 @@ class Matrix(C_Dataframe):
super().__init__(type="array", **kwargs)
class List(C_Dataframe):
class List(components.Dataframe):
def __init__(self, **kwargs):
"""
Custom component
@ -120,3 +135,13 @@ class List(C_Dataframe):
"""
self.is_template = True
super().__init__(type="array", col_count=1, **kwargs)
class Highlight(components.HighlightedText):
def __init__(self, **kwargs):
"""
Custom component
@param kwargs:
"""
self.is_template = True
super().__init__(**kwargs)