backend_default_value_refactoring

- refactor default to default_value
This commit is contained in:
Ömer Faruk Özdemir 2022-03-24 19:25:28 +03:00
parent 569491f896
commit 2f4f16a5ad
3 changed files with 85 additions and 52 deletions

View File

@ -11,12 +11,12 @@ import warnings
from types import ModuleType from types import ModuleType
from typing import Any, Callable, Dict, List, Optional, Tuple from typing import Any, Callable, Dict, List, Optional, Tuple
import matplotlib.figure
import numpy as np import numpy as np
import pandas as pd import pandas as pd
import PIL import PIL
from ffmpy import FFmpeg from ffmpy import FFmpeg
from markdown_it import MarkdownIt from markdown_it import MarkdownIt
import matplotlib.figure
from gradio import processing_utils, test_data from gradio import processing_utils, test_data
from gradio.blocks import Block from gradio.blocks import Block
@ -245,7 +245,7 @@ class Textbox(Component):
default_value = str(default_value) default_value = str(default_value)
self.lines = lines self.lines = lines
self.placeholder = placeholder self.placeholder = placeholder
self.default = default_value self.default_value = default_value
self.test_input = default_value self.test_input = default_value
self.interpret_by_tokens = True self.interpret_by_tokens = True
super().__init__(label=label, css=css, **kwargs) super().__init__(label=label, css=css, **kwargs)
@ -254,7 +254,7 @@ class Textbox(Component):
return { return {
"lines": self.lines, "lines": self.lines,
"placeholder": self.placeholder, "placeholder": self.placeholder,
"default": self.default, "default_value": self.default_value,
**super().get_template_context(), **super().get_template_context(),
} }
@ -411,13 +411,13 @@ class Number(Component):
default_value (float): default value. default_value (float): default value.
label (str): component name in interface. label (str): component name in interface.
""" """
self.default = float(default_value) if default_value is not None else None self.default_value = float(default_value) if default_value is not None else None
self.test_input = self.default if self.default is not None else 1 self.test_input = self.default_value if self.default_value is not None else 1
self.interpret_by_tokens = False self.interpret_by_tokens = False
super().__init__(label=label, css=css, **kwargs) super().__init__(label=label, css=css, **kwargs)
def get_template_context(self): def get_template_context(self):
return {"default": self.default, **super().get_template_context()} return {"default_value": self.default_value, **super().get_template_context()}
@classmethod @classmethod
def get_shortcut_implementations(cls): def get_shortcut_implementations(cls):
@ -558,8 +558,8 @@ class Slider(Component):
power = math.floor(math.log10(difference) - 2) power = math.floor(math.log10(difference) - 2)
step = 10**power step = 10**power
self.step = step self.step = step
self.default = minimum if default_value is None else default_value self.default_value = minimum if default_value is None else default_value
self.test_input = self.default self.test_input = self.default_value
self.interpret_by_tokens = False self.interpret_by_tokens = False
super().__init__(label=label, css=css, **kwargs) super().__init__(label=label, css=css, **kwargs)
@ -568,7 +568,7 @@ class Slider(Component):
"minimum": self.minimum, "minimum": self.minimum,
"maximum": self.maximum, "maximum": self.maximum,
"step": self.step, "step": self.step,
"default": self.default, "default_value": self.default_value,
**super().get_template_context(), **super().get_template_context(),
} }
@ -669,12 +669,12 @@ class Checkbox(Component):
label (str): component name in interface. label (str): component name in interface.
""" """
self.test_input = True self.test_input = True
self.default = default_value self.default_value = default_value
self.interpret_by_tokens = False self.interpret_by_tokens = False
super().__init__(label=label, css=css, **kwargs) super().__init__(label=label, css=css, **kwargs)
def get_template_context(self): def get_template_context(self):
return {"default": self.default, **super().get_template_context()} return {"default_value": self.default_value, **super().get_template_context()}
@classmethod @classmethod
def get_shortcut_implementations(cls): def get_shortcut_implementations(cls):
@ -774,7 +774,7 @@ class CheckboxGroup(Component):
): # Mutable parameters shall not be given as default parameters in the function. ): # Mutable parameters shall not be given as default parameters in the function.
default_selected = [] default_selected = []
self.choices = choices self.choices = choices
self.default = default_selected self.default_value = default_selected
self.type = type self.type = type
self.test_input = self.choices self.test_input = self.choices
self.interpret_by_tokens = False self.interpret_by_tokens = False
@ -783,7 +783,7 @@ class CheckboxGroup(Component):
def get_template_context(self): def get_template_context(self):
return { return {
"choices": self.choices, "choices": self.choices,
"default": self.default, "default_value": self.default_value,
**super().get_template_context(), **super().get_template_context(),
} }
@ -900,7 +900,7 @@ class Radio(Component):
self.choices = choices self.choices = choices
self.type = type self.type = type
self.test_input = self.choices[0] self.test_input = self.choices[0]
self.default = ( self.default_value = (
default_selected if default_selected is not None else self.choices[0] default_selected if default_selected is not None else self.choices[0]
) )
self.interpret_by_tokens = False self.interpret_by_tokens = False
@ -909,7 +909,7 @@ class Radio(Component):
def get_template_context(self): def get_template_context(self):
return { return {
"choices": self.choices, "choices": self.choices,
"default": self.default, "default_value": self.default_value,
**super().get_template_context(), **super().get_template_context(),
} }
@ -1023,7 +1023,7 @@ class Image(Component):
def __init__( def __init__(
self, self,
default_value=None, default_value: Optional[str] = None,
*, *,
shape: Tuple[int, int] = None, shape: Tuple[int, int] = None,
image_mode: str = "RGB", image_mode: str = "RGB",
@ -1037,7 +1037,7 @@ class Image(Component):
): ):
""" """
Parameters: Parameters:
default_value(str): IGNORED default_value(str): A path or URL for the default value that Image component is going to take.
shape (Tuple[int, int]): (width, height) shape to crop and resize image to; if None, matches input image size. shape (Tuple[int, int]): (width, height) shape to crop and resize image to; if None, matches input image size.
image_mode (str): "RGB" if color, or "L" if black and white. image_mode (str): "RGB" if color, or "L" if black and white.
invert_colors (bool): whether to invert the image as a preprocessing step. invert_colors (bool): whether to invert the image as a preprocessing step.
@ -1055,6 +1055,8 @@ class Image(Component):
else: else:
self.type = type self.type = type
if default_value is not None:
self.default_value = processing_utils.decode_base64_to_file(default_value)
self.type = type self.type = type
self.output_type = "auto" self.output_type = "auto"
self.shape = shape self.shape = shape
@ -1069,8 +1071,7 @@ class Image(Component):
label=label, requires_permissions=requires_permissions, **kwargs label=label, requires_permissions=requires_permissions, **kwargs
) )
@classmethod def get_shortcut_implementations(self):
def get_shortcut_implementations(cls):
return { return {
"image": {}, "image": {},
"webcam": {"source": "webcam"}, "webcam": {"source": "webcam"},
@ -1082,6 +1083,7 @@ class Image(Component):
}, },
"plot": {"type": "plot"}, "plot": {"type": "plot"},
"pil": {"type": "pil"}, "pil": {"type": "pil"},
"default_value": self.default_value,
} }
def get_template_context(self): def get_template_context(self):
@ -1355,7 +1357,7 @@ class Video(Component):
def __init__( def __init__(
self, self,
default_value="", default_value: str = "",
*, *,
type: Optional[str] = None, type: Optional[str] = None,
source: str = "upload", source: str = "upload",
@ -1365,12 +1367,13 @@ class Video(Component):
): ):
""" """
Parameters: Parameters:
default_value (str): IGNORED default_value(str): A path or URL for the default value that Video component is going to take.
type (str): Type of video format to be returned by component, such as 'avi' or 'mp4'. Use 'mp4' to ensure browser playability. If set to None, video will keep uploaded format. type (str): Type of video format to be returned by component, such as 'avi' or 'mp4'. Use 'mp4' to ensure browser playability. If set to None, video will keep uploaded format.
source (str): Source of video. "upload" creates a box where user can drop an video file, "webcam" allows user to record a video from their webcam. source (str): Source of video. "upload" creates a box where user can drop an video file, "webcam" allows user to record a video from their webcam.
label (str): component name in interface. label (str): component name in interface.
optional (bool): If True, the interface can be submitted with no uploaded video, in which case the input value is None. optional (bool): If True, the interface can be submitted with no uploaded video, in which case the input value is None.
""" """
self.default_value = processing_utils.decode_base64_to_file(default_value)
self.type = type self.type = type
self.source = source self.source = source
super().__init__(label=label, css=css, **kwargs) super().__init__(label=label, css=css, **kwargs)
@ -1385,6 +1388,7 @@ class Video(Component):
def get_template_context(self): def get_template_context(self):
return { return {
"source": self.source, "source": self.source,
"default_value": self.default_value,
**super().get_template_context(), **super().get_template_context(),
} }
@ -1534,6 +1538,7 @@ class Audio(Component):
type (str): The format the image is converted to before being passed into the prediction function. "numpy" converts the image to a numpy array with shape (width, height, 3) and values from 0 to 255, "pil" converts the image to a PIL image object, "file" produces a temporary file object whose path can be retrieved by file_obj.name, "filepath" returns the path directly. type (str): The format the image is converted to before being passed into the prediction function. "numpy" converts the image to a numpy array with shape (width, height, 3) and values from 0 to 255, "pil" converts the image to a PIL image object, "file" produces a temporary file object whose path can be retrieved by file_obj.name, "filepath" returns the path directly.
label (str): component name in interface. label (str): component name in interface.
""" """
self.default_value = processing_utils.decode_base64_to_file(default_value)
self.source = source self.source = source
requires_permissions = source == "microphone" requires_permissions = source == "microphone"
self.type = type self.type = type
@ -1547,6 +1552,7 @@ class Audio(Component):
def get_template_context(self): def get_template_context(self):
return { return {
"source": self.source, # TODO: This did not exist in output template, careful here if an error arrives "source": self.source, # TODO: This did not exist in output template, careful here if an error arrives
"default_value": self.default_value,
**super().get_template_context(), **super().get_template_context(),
} }
@ -1831,13 +1837,14 @@ class File(Component):
): ):
""" """
Parameters: Parameters:
default_value (str): IGNORED default_value (str): Default value given as file path
file_count (str): if single, allows user to upload one file. If "multiple", user uploads multiple files. If "directory", user uploads all files in selected directory. Return type will be list for each file in case of "multiple" or "directory". file_count (str): if single, allows user to upload one file. If "multiple", user uploads multiple files. If "directory", user uploads all files in selected directory. Return type will be list for each file in case of "multiple" or "directory".
type (str): Type of value to be returned by component. "file" returns a temporary file object whose path can be retrieved by file_obj.name, "binary" returns an bytes object. type (str): Type of value to be returned by component. "file" returns a temporary file object whose path can be retrieved by file_obj.name, "binary" returns an bytes object.
label (str): component name in interface. label (str): component name in interface.
""" """
if "keep_filename" in kwargs: if "keep_filename" in kwargs:
warnings.warn("keep_filename is deprecated", DeprecationWarning) warnings.warn("keep_filename is deprecated", DeprecationWarning)
self.default_value = processing_utils.decode_base64_to_file(default_value)
self.file_count = file_count self.file_count = file_count
self.type = type self.type = type
self.test_input = None self.test_input = None
@ -1846,6 +1853,7 @@ class File(Component):
def get_template_context(self): def get_template_context(self):
return { return {
"file_count": self.file_count, "file_count": self.file_count,
"default_value": self.default_value,
**super().get_template_context(), **super().get_template_context(),
} }
@ -1997,7 +2005,7 @@ class Dataframe(Component):
self.col_width = col_width self.col_width = col_width
self.type = type self.type = type
self.output_type = "auto" self.output_type = "auto"
self.default = ( self.default_value = (
default_value default_value
if default_value is not None if default_value is not None
else [[None for _ in range(self.col_count)] for _ in range(self.row_count)] else [[None for _ in range(self.col_count)] for _ in range(self.row_count)]
@ -2026,7 +2034,7 @@ class Dataframe(Component):
"row_count": self.row_count, "row_count": self.row_count,
"col_count": self.col_count, "col_count": self.col_count,
"col_width": self.col_width, "col_width": self.col_width,
"default": self.default, "default_value": self.default_value,
"max_rows": self.max_rows, "max_rows": self.max_rows,
"max_cols": self.max_cols, "max_cols": self.max_cols,
"overflow_row_behaviour": self.overflow_row_behaviour, "overflow_row_behaviour": self.overflow_row_behaviour,
@ -2136,7 +2144,7 @@ class Timeseries(Component):
def __init__( def __init__(
self, self,
default_value=None, default_value: str = "",
*, *,
x: Optional[str] = None, x: Optional[str] = None,
y: str | List[str] = None, y: str | List[str] = None,
@ -2146,11 +2154,13 @@ class Timeseries(Component):
): ):
""" """
Parameters: Parameters:
default_value: IGNORED default_value: File path for the timeseries csv file.
x (str): Column name of x (time) series. None if csv has no headers, in which case first column is x series. x (str): Column name of x (time) series. None if csv has no headers, in which case first column is x series.
y (Union[str, List[str]]): Column name of y series, or list of column names if multiple series. None if csv has no headers, in which case every column after first is a y series. y (Union[str, List[str]]): Column name of y series, or list of column names if multiple series. None if csv has no headers, in which case every column after first is a y series.
label (str): component name in interface. label (str): component name in interface.
""" """
# TODO: Probably incorrect
self.default_value = pd.DataFrame(default_value)
self.x = x self.x = x
if isinstance(y, str): if isinstance(y, str):
y = [y] y = [y]
@ -2247,11 +2257,11 @@ class State(Component):
default_value (Any): the initial value of the state. default_value (Any): the initial value of the state.
label (str): component name in interface (not used). label (str): component name in interface (not used).
""" """
self.default = default_value self.default_value = default_value
super().__init__(label=label, css=css, **kwargs) super().__init__(label=label, css=css, **kwargs)
def get_template_context(self): def get_template_context(self):
return {"default": self.default, **super().get_template_context()} return {"default_value": self.default_value, **super().get_template_context()}
@classmethod @classmethod
def get_shortcut_implementations(cls): def get_shortcut_implementations(cls):
@ -2281,10 +2291,11 @@ class Label(Component):
): ):
""" """
Parameters: Parameters:
default_value(str): IGNORED default_value(str): Default string value
num_top_classes (int): number of most confident classes to show. num_top_classes (int): number of most confident classes to show.
label (str): component name in interface. label (str): component name in interface.
""" """
# TODO: Shall we have a default value for the label component?
self.num_top_classes = num_top_classes self.num_top_classes = num_top_classes
self.output_type = "auto" self.output_type = "auto"
super().__init__(label=label, css=css, **kwargs) super().__init__(label=label, css=css, **kwargs)
@ -2423,11 +2434,12 @@ class HighlightedText(Component):
): ):
""" """
Parameters: Parameters:
default_value (str): IGNORED default_value (str): Default value
color_map (Dict[str, str]): Map between category and respective colors color_map (Dict[str, str]): Map between category and respective colors
label (str): component name in interface. label (str): component name in interface.
show_legend (bool): whether to show span categories in a separate legend or inline. show_legend (bool): whether to show span categories in a separate legend or inline.
""" """
self.default_value = default_value
self.color_map = color_map self.color_map = color_map
self.show_legend = show_legend self.show_legend = show_legend
super().__init__(label=label, css=css, **kwargs) super().__init__(label=label, css=css, **kwargs)
@ -2436,6 +2448,7 @@ class HighlightedText(Component):
return { return {
"color_map": self.color_map, "color_map": self.color_map,
"show_legend": self.show_legend, "show_legend": self.show_legend,
"default_value": self.default_value,
**super().get_template_context(), **super().get_template_context(),
} }
@ -2489,11 +2502,18 @@ class JSON(Component):
): ):
""" """
Parameters: Parameters:
default_value (str): IGNORED default_value (str): Default value
label (str): component name in interface. label (str): component name in interface.
""" """
self.default_value = json.dumps(default_value)
super().__init__(label=label, css=css, **kwargs) super().__init__(label=label, css=css, **kwargs)
def get_template_context(self):
return {
"default_value": self.default_value,
**super().get_template_context(),
}
def postprocess(self, y): def postprocess(self, y):
""" """
Parameters: Parameters:
@ -2545,11 +2565,18 @@ class HTML(Component):
): ):
""" """
Parameters: Parameters:
default_value (str): IGNORED default_value (str): Default value
label (str): component name in interface. label (str): component name in interface.
""" """
self.default_value = default_value
super().__init__(label=label, css=css, **kwargs) super().__init__(label=label, css=css, **kwargs)
def get_template_context(self):
return {
"default_value": self.default_value,
**super().get_template_context(),
}
def postprocess(self, x): def postprocess(self, x):
""" """
Parameters: Parameters:
@ -2598,6 +2625,7 @@ class Carousel(Component):
components (Union[List[OutputComponent], OutputComponent]): Classes of component(s) that will be scrolled through. components (Union[List[OutputComponent], OutputComponent]): Classes of component(s) that will be scrolled through.
label (str): component name in interface. label (str): component name in interface.
""" """
# TODO: Shall we havea default value in carousel?
if not isinstance(components, list): if not isinstance(components, list):
components = [components] components = [components]
self.components = [ self.components = [
@ -2683,13 +2711,14 @@ class Chatbot(Component):
): ):
""" """
Parameters: Parameters:
default_value (str): IGNORED default_value (str): Default value
label (str): component name in interface (not used). label (str): component name in interface (not used).
""" """
self.default_value = default_value
super().__init__(label=label, css=css, **kwargs) super().__init__(label=label, css=css, **kwargs)
def get_template_context(self): def get_template_context(self):
return {**super().get_template_context()} return {"default_value": self.default_value, **super().get_template_context()}
@classmethod @classmethod
def get_shortcut_implementations(cls): def get_shortcut_implementations(cls):
@ -2728,12 +2757,18 @@ class Markdown(Component):
css: Optional[Dict] = None, css: Optional[Dict] = None,
**kwargs, **kwargs,
): ):
"""
Parameters:
default_value (str): Default value
label (str): component name
css (dict): optional css parameters for the component
"""
super().__init__(label=label, css=css, **kwargs) super().__init__(label=label, css=css, **kwargs)
self.md = MarkdownIt() self.md = MarkdownIt()
self.value = self.md.render(default_value) self.default_value = self.md.render(default_value)
def get_template_context(self): def get_template_context(self):
return {"value": self.value, **super().get_template_context()} return {"default_value": self.default_value, **super().get_template_context()}
class Button(Component): class Button(Component):
@ -2745,11 +2780,17 @@ class Button(Component):
css: Optional[Dict] = None, css: Optional[Dict] = None,
**kwargs, **kwargs,
): ):
"""
Parameters:
default_value (str): Default value
label (str): component name
css (dict): optional css parameters for the component
"""
super().__init__(label=label, css=css, **kwargs) super().__init__(label=label, css=css, **kwargs)
self.value = default_value self.default_value = default_value
def get_template_context(self): def get_template_context(self):
return {"value": self.value, **super().get_template_context()} return {"default_value": self.default_value, **super().get_template_context()}
def click(self, fn: Callable, inputs: List[Component], outputs: List[Component]): def click(self, fn: Callable, inputs: List[Component], outputs: List[Component]):
""" """
@ -2774,12 +2815,12 @@ class DatasetViewer(Component):
): ):
super().__init__(label=label, css=css, **kwargs) super().__init__(label=label, css=css, **kwargs)
self.types = types self.types = types
self.value = default_value self.default_value = default_value
def get_template_context(self): def get_template_context(self):
return { return {
"types": [_type.__class__.__name__.lower() for _type in types], "types": [_type.__class__.__name__.lower() for _type in self.types],
"value": self.value, "default_value": self.default_value,
**super().get_template_context(), **super().get_template_context(),
} }

View File

@ -489,12 +489,3 @@ class State(C_State):
DeprecationWarning, DeprecationWarning,
) )
super().__init__(default_value=default, label=label, optional=optional) super().__init__(default_value=default, label=label, optional=optional)
def get_template_context(self):
return {"default": self.default, **super().get_template_context()}
@classmethod
def get_shortcut_implementations(cls):
return {
"state": {},
}

View File

@ -186,9 +186,9 @@ class Interface(Launchable):
isinstance(i, i_State) for i in self.input_components isinstance(i, i_State) for i in self.input_components
].index(True) ].index(True)
state: i_State = self.input_components[state_param_index] state: i_State = self.input_components[state_param_index]
if state.default is None: if state.default_value is None:
default = utils.get_default_args(fn[0])[state_param_index] default = utils.get_default_args(fn[0])[state_param_index]
state.default = default state.default_value = default
if ( if (
interpretation is None interpretation is None
@ -643,6 +643,7 @@ class Interface(Launchable):
"flag_index": flag_index, "flag_index": flag_index,
} }
# TODO: Remove duplicate process_api, Ali Abid what is it for?
def process_api(self, data: Dict[str, Any], username: str = None) -> Dict[str, Any]: def process_api(self, data: Dict[str, Any], username: str = None) -> Dict[str, Any]:
class RequestApi: class RequestApi:
SUBMIT = 0 SUBMIT = 0