From 8464aa725868c9d550c54d5945cf72a76d537afb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=96mer=20Faruk=20=C3=96zdemir?= Date: Fri, 15 Apr 2022 00:24:14 +0300 Subject: [PATCH] Refactor component shortcuts (#995) * custom-components - create template components - changes in PKG and requires comes from scripts/install_gradio.sh * custom-components - tweaks * update-components - tweaks * update-components - fix get_block_name * update-components - add webcam demo * custom-components - make use of get_block_name function whenever possible * custom-components - tweaks * refactor-component-shortcuts - no description whatsoever :D * refactor-component-shortcuts - tweaks * refactor-component-shortcuts - create shortcut function "component" * refactor-component-shortcuts - reformat * refactor-component-shortcuts - tweaks * refactor-component-shortcuts - tweaks --- demo/blocks_textarea/run.py | 5 +- demo/blocks_webcam/run.py | 3 +- gradio/__init__.py | 1 + gradio/components.py | 166 +++++++++--------------------------- gradio/templates.py | 59 +++++++++---- 5 files changed, 85 insertions(+), 149 deletions(-) diff --git a/demo/blocks_textarea/run.py b/demo/blocks_textarea/run.py index 9554bfe913..adf142a43d 100644 --- a/demo/blocks_textarea/run.py +++ b/demo/blocks_textarea/run.py @@ -1,12 +1,11 @@ import gradio as gr -from gradio import Templates - def greet(name): return "Hello " + name + "!!" -demo = gr.Interface(fn=greet, inputs=Templates.TextArea(), outputs=Templates.TextArea()) +demo = gr.Interface(fn=greet, inputs=gr.component("textarea"), outputs=gr.component("textarea")) + if __name__ == "__main__": demo.launch() diff --git a/demo/blocks_webcam/run.py b/demo/blocks_webcam/run.py index 8a62ecd8a8..bae9953e0d 100644 --- a/demo/blocks_webcam/run.py +++ b/demo/blocks_webcam/run.py @@ -7,8 +7,7 @@ from gradio import Templates def snap(image): return np.flipud(image) - -demo = gr.Interface(snap, Templates.Webcam(), gr.Image()) +demo = gr.Interface(snap, gr.component("webcam"), gr.component("image")) if __name__ == "__main__": demo.launch() diff --git a/gradio/__init__.py b/gradio/__init__.py index bd65170e20..6f78bbd324 100644 --- a/gradio/__init__.py +++ b/gradio/__init__.py @@ -30,6 +30,7 @@ from gradio.components import ( Timeseries, Variable, Video, + component, ) from gradio.flagging import ( CSVLogger, diff --git a/gradio/components.py b/gradio/components.py index 1b72832773..d6aedcc6ce 100644 --- a/gradio/components.py +++ b/gradio/components.py @@ -70,13 +70,6 @@ class Component(Block): "interactive": self.interactive, } - @classmethod - def get_shortcut_implementations(cls): - """ - Return dictionary of shortcut implementations - """ - return {} - def save_flagged( self, dir: str, label: Optional[str], data: Any, encryption_key: bool ) -> Any: @@ -131,12 +124,25 @@ class Component(Block): return {"name": file, "data": data} @classmethod - def get_all_shortcut_implementations(cls): - shortcuts = {} + def get_component_shortcut(cls, str_shortcut: str) -> Optional[Component]: + """ + Creates a component, where class name equals to str_shortcut. + + @param str_shortcut: string shortcut of a component + @return: + True, found_class or + False, None + """ + # Make it suitable with class names + str_shortcut = str_shortcut.replace("_", "") for sub_cls in cls.__subclasses__(): - for shortcut, parameters in sub_cls.get_shortcut_implementations().items(): - shortcuts[shortcut] = (sub_cls, parameters) - return shortcuts + if sub_cls.__name__.lower() == str_shortcut: + return sub_cls() + # For template components + for sub_sub_cls in sub_cls.__subclasses__(): + if sub_sub_cls.__name__.lower() == str_shortcut: + return sub_sub_cls() + return None # Input Functionalities def preprocess(self, x: Any) -> Any: @@ -276,13 +282,6 @@ class Textbox(Component): **super().get_template_context(), } - @classmethod - def get_shortcut_implementations(cls): - return { - "text": {}, - "textbox": {"lines": 7}, - } - # Input Functionalities def preprocess(self, x: str | None) -> Any: """ @@ -437,12 +436,6 @@ class Number(Component): def get_template_context(self): return {"default_value": self.default_value, **super().get_template_context()} - @classmethod - def get_shortcut_implementations(cls): - return { - "number": {}, - } - def preprocess(self, x: float | None) -> Optional[float]: """ Parameters: @@ -590,12 +583,6 @@ class Slider(Component): **super().get_template_context(), } - @classmethod - def get_shortcut_implementations(cls): - return { - "slider": {}, - } - def preprocess(self, x: float) -> float: """ Parameters: @@ -694,12 +681,6 @@ class Checkbox(Component): def get_template_context(self): return {"default_value": self.default_value, **super().get_template_context()} - @classmethod - def get_shortcut_implementations(cls): - return { - "checkbox": {}, - } - def preprocess(self, x: bool) -> bool: """ Parameters: @@ -1095,21 +1076,6 @@ class Image(Component): label=label, requires_permissions=requires_permissions, **kwargs ) - @classmethod - def get_shortcut_implementations(cls): - return { - "image": {}, - "webcam": {"source": "webcam"}, - "sketchpad": { - "image_mode": "L", - "source": "canvas", - "shape": (28, 28), - "invert_colors": True, - }, - "plot": {"type": "plot"}, - "pil": {"type": "pil"}, - } - def get_template_context(self): return { "image_mode": self.image_mode, @@ -1410,13 +1376,6 @@ class Video(Component): self.source = source super().__init__(label=label, css=css, **kwargs) - @classmethod - def get_shortcut_implementations(cls): - return { - "video": {}, - "playable_video": {"type": "mp4"}, - } - def get_template_context(self): return { "source": self.source, @@ -1592,14 +1551,6 @@ class Audio(Component): **super().get_template_context(), } - @classmethod - def get_shortcut_implementations(cls): - return { - "audio": {}, - "microphone": {"source": "microphone"}, - "mic": {"source": "microphone"}, - } - def preprocess_example(self, x): return {"name": x, "data": None, "is_example": True} @@ -1899,13 +1850,6 @@ class File(Component): **super().get_template_context(), } - @classmethod - def get_shortcut_implementations(cls): - return { - "file": {}, - "files": {"file_count": "multiple"}, - } - def preprocess_example(self, x): return {"name": x, "data": None, "is_example": True} @@ -2081,15 +2025,6 @@ class Dataframe(Component): **super().get_template_context(), } - @classmethod - def get_shortcut_implementations(cls): - return { - "dataframe": {"type": "pandas"}, - "numpy": {"type": "numpy"}, - "matrix": {"type": "array"}, - "list": {"type": "array", "col_count": 1}, - } - def preprocess(self, x: List[List[str | Number | bool]]): """ Parameters: @@ -2218,12 +2153,6 @@ class Timeseries(Component): **super().get_template_context(), } - @classmethod - def get_shortcut_implementations(cls): - return { - "timeseries": {}, - } - def preprocess_example(self, x): return {"name": x, "is_example": True} @@ -2305,12 +2234,6 @@ class Variable(Component): def get_template_context(self): return {"default_value": self.default_value, **super().get_template_context()} - @classmethod - def get_shortcut_implementations(cls): - return { - "state": {}, - } - # Only Output Components class Label(Component): @@ -2392,12 +2315,6 @@ class Label(Component): return y raise ValueError("Unable to deserialize output: {}".format(y)) - @classmethod - def get_shortcut_implementations(cls): - return { - "label": {}, - } - def save_flagged(self, dir, label, data, encryption_key): """ Returns: (Union[str, Dict[str, number]]): Either a string representing the main category label, or a dictionary with category keys mapping to confidence levels. @@ -2494,12 +2411,6 @@ class HighlightedText(Component): **super().get_template_context(), } - @classmethod - def get_shortcut_implementations(cls): - return { - "highlight": {}, - } - def postprocess(self, y): """ Parameters: @@ -2568,12 +2479,6 @@ class JSON(Component): else: return y - @classmethod - def get_shortcut_implementations(cls): - return { - "json": {}, - } - def save_flagged(self, dir, label, data, encryption_key): return json.dumps(data) @@ -2628,12 +2533,6 @@ class HTML(Component): """ return x - @classmethod - def get_shortcut_implementations(cls): - return { - "html": {}, - } - def change(self, fn: Callable, inputs: List[Component], outputs: List[Component]): """ Parameters: @@ -2762,12 +2661,6 @@ class Chatbot(Component): def get_template_context(self): return {"default_value": self.default_value, **super().get_template_context()} - @classmethod - def get_shortcut_implementations(cls): - return { - "chatbot": {}, - } - def postprocess(self, y): """ Parameters: @@ -2958,10 +2851,29 @@ class Interpretation(Component): } +def component(str_shortcut: str) -> (bool, Optional[Component]): + """ + Creates a component, where class name equals to str_shortcut. + + @param str_shortcut: string shortcut of a component + @return: + True, found_class or + False, None + """ + component = Component.get_component_shortcut(str_shortcut) + if component is None: + raise ValueError(f"No such component: {str_shortcut}") + else: + return component + + def get_component_instance(comp: str | dict | Component): if isinstance(comp, str): - shortcut = Component.get_all_shortcut_implementations()[comp] - return shortcut[0](**shortcut[1], without_rendering=True) + component = Component.get_component_shortcut(comp) + if component is None: + raise ValueError(f"No such component: {comp}") + else: + return component elif isinstance( comp, dict ): # a dict with `name` as the input component type and other keys as parameters diff --git a/gradio/templates.py b/gradio/templates.py index 86d9062907..516ea2e5ef 100644 --- a/gradio/templates.py +++ b/gradio/templates.py @@ -1,12 +1,17 @@ -from gradio.components import Audio as C_Audio -from gradio.components import Dataframe as C_Dataframe -from gradio.components import File as C_File -from gradio.components import Image as C_Image -from gradio.components import Textbox as C_Textbox -from gradio.components import Video as C_Video +from gradio import components -class TextArea(C_Textbox): +class Text(components.Textbox): + def __init__(self, **kwargs): + """ + Custom component + @param kwargs: + """ + self.is_template = True + super().__init__(lines=1, **kwargs) + + +class TextArea(components.Textbox): def __init__(self, **kwargs): """ Custom component @@ -16,7 +21,7 @@ class TextArea(C_Textbox): super().__init__(lines=7, **kwargs) -class Webcam(C_Image): +class Webcam(components.Image): def __init__(self, **kwargs): """ Custom component @@ -26,7 +31,7 @@ class Webcam(C_Image): super().__init__(source="webcam", **kwargs) -class Sketchpad(C_Image): +class Sketchpad(components.Image): def __init__(self, **kwargs): """ Custom component @@ -42,7 +47,7 @@ class Sketchpad(C_Image): ) -class Plot(C_Image): +class Plot(components.Image): def __init__(self, **kwargs): """ Custom component @@ -52,7 +57,7 @@ class Plot(C_Image): super().__init__(type="plot", **kwargs) -class Pil(C_Image): +class Pil(components.Image): def __init__(self, **kwargs): """ Custom component @@ -62,7 +67,7 @@ class Pil(C_Image): super().__init__(type="pil", **kwargs) -class PlayableVideo(C_Video): +class PlayableVideo(components.Video): def __init__(self, **kwargs): """ Custom component @@ -72,7 +77,7 @@ class PlayableVideo(C_Video): super().__init__(type="mp4", **kwargs) -class Microphone(C_Audio): +class Microphone(components.Audio): def __init__(self, **kwargs): """ Custom component @@ -82,7 +87,17 @@ class Microphone(C_Audio): super().__init__(source="microphone", **kwargs) -class C_Files(C_File): +class Mic(components.Audio): + def __init__(self, **kwargs): + """ + Custom component + @param kwargs: + """ + self.is_template = True + super().__init__(source="microphone", **kwargs) + + +class Files(components.File): def __init__(self, **kwargs): """ Custom component @@ -92,7 +107,7 @@ class C_Files(C_File): super().__init__(file_count="multiple", **kwargs) -class Numpy(C_Dataframe): +class Numpy(components.Dataframe): def __init__(self, **kwargs): """ Custom component @@ -102,7 +117,7 @@ class Numpy(C_Dataframe): super().__init__(type="numpy", **kwargs) -class Matrix(C_Dataframe): +class Matrix(components.Dataframe): def __init__(self, **kwargs): """ Custom component @@ -112,7 +127,7 @@ class Matrix(C_Dataframe): super().__init__(type="array", **kwargs) -class List(C_Dataframe): +class List(components.Dataframe): def __init__(self, **kwargs): """ Custom component @@ -120,3 +135,13 @@ class List(C_Dataframe): """ self.is_template = True super().__init__(type="array", col_count=1, **kwargs) + + +class Highlight(components.HighlightedText): + def __init__(self, **kwargs): + """ + Custom component + @param kwargs: + """ + self.is_template = True + super().__init__(**kwargs)