mirror of
https://github.com/gradio-app/gradio.git
synced 2025-01-24 10:54:04 +08:00
moved optional logic to get_config_file & added some typing
This commit is contained in:
parent
c0aa8898b0
commit
6c52e531a2
@ -37,8 +37,7 @@ class InputComponent(Component):
|
||||
Constructs an input component.
|
||||
"""
|
||||
self.set_interpret_parameters()
|
||||
if optional is True:
|
||||
label = InputComponent.label_as_optional(label)
|
||||
self.optional = optional
|
||||
super().__init__(label, requires_permissions)
|
||||
|
||||
def preprocess(self, x: Any) -> Any:
|
||||
@ -100,10 +99,12 @@ class InputComponent(Component):
|
||||
Returns a sample value of the input that would be accepted by the api. Used for api documentation.
|
||||
"""
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def label_as_optional(label: str) -> str:
|
||||
return f"{label}(Optional)"
|
||||
|
||||
def get_template_context(self):
|
||||
return {
|
||||
"optional": self.optional,
|
||||
**super().get_template_context(),
|
||||
}
|
||||
|
||||
|
||||
class Textbox(InputComponent):
|
||||
@ -266,7 +267,7 @@ class Number(InputComponent):
|
||||
Parameters:
|
||||
default (float): default value.
|
||||
label (str): component name in interface.
|
||||
optional (bool):
|
||||
optional (bool): If True, the interface can be submitted with no value for this component.
|
||||
"""
|
||||
self.default = default
|
||||
self.test_input = default if default is not None else 1
|
||||
@ -282,13 +283,15 @@ class Number(InputComponent):
|
||||
"number": {},
|
||||
}
|
||||
|
||||
def preprocess(self, x: Number) -> float:
|
||||
def preprocess(self, x: Optional[Number]) -> Optional[float]:
|
||||
"""
|
||||
Parameters:
|
||||
x (number): numeric input
|
||||
x (string): numeric input as a string
|
||||
Returns:
|
||||
(float): number representing function input
|
||||
"""
|
||||
if self.optional and x is None:
|
||||
return None
|
||||
return float(x)
|
||||
|
||||
def preprocess_example(self, x: float) -> float:
|
||||
@ -459,7 +462,7 @@ class Checkbox(InputComponent):
|
||||
"checkbox": {},
|
||||
}
|
||||
|
||||
def preprocess(self, x):
|
||||
def preprocess(self, x: bool) -> bool:
|
||||
"""
|
||||
Parameters:
|
||||
x (bool): boolean input
|
||||
@ -533,7 +536,7 @@ class CheckboxGroup(InputComponent):
|
||||
**super().get_template_context(),
|
||||
}
|
||||
|
||||
def preprocess(self, x):
|
||||
def preprocess(self, x: List[str]) -> List[str] | List[int]:
|
||||
"""
|
||||
Parameters:
|
||||
x (List[str]): list of selected choices
|
||||
@ -630,7 +633,7 @@ class Radio(InputComponent):
|
||||
**super().get_template_context(),
|
||||
}
|
||||
|
||||
def preprocess(self, x):
|
||||
def preprocess(self, x: str) -> str | int:
|
||||
"""
|
||||
Parameters:
|
||||
x (str): selected choice
|
||||
@ -706,7 +709,7 @@ class Dropdown(InputComponent):
|
||||
**super().get_template_context(),
|
||||
}
|
||||
|
||||
def preprocess(self, x):
|
||||
def preprocess(self, x: str) -> str | int:
|
||||
"""
|
||||
Parameters:
|
||||
x (str): selected choice
|
||||
@ -782,7 +785,6 @@ class Image(InputComponent):
|
||||
requires_permissions = source == "webcam"
|
||||
self.tool = tool
|
||||
self.type = type
|
||||
self.optional = optional
|
||||
self.invert_colors = invert_colors
|
||||
self.test_input = test_data.BASE64_IMAGE
|
||||
self.interpret_by_tokens = True
|
||||
@ -811,12 +813,12 @@ class Image(InputComponent):
|
||||
**super().get_template_context(),
|
||||
}
|
||||
|
||||
def preprocess(self, x):
|
||||
def preprocess(self, x: Optional[str]) -> np.array | PIL.Image | str | None:
|
||||
"""
|
||||
Parameters:
|
||||
x (str): base64 url data
|
||||
Returns:
|
||||
(Union[numpy.array, PIL.Image, file-object]): image in requested format
|
||||
(Union[numpy.array, PIL.Image, filepath]): image in requested format
|
||||
"""
|
||||
if x is None:
|
||||
return x
|
||||
@ -1007,7 +1009,6 @@ class Video(InputComponent):
|
||||
"""
|
||||
self.type = type
|
||||
self.source = source
|
||||
self.optional = optional
|
||||
super().__init__(label, optional=optional)
|
||||
|
||||
@classmethod
|
||||
@ -1026,7 +1027,7 @@ class Video(InputComponent):
|
||||
def preprocess_example(self, x):
|
||||
return {"name": x, "data": None, "is_example": True}
|
||||
|
||||
def preprocess(self, x):
|
||||
def preprocess(self, x: Dict[str, str] | None) -> str | None:
|
||||
"""
|
||||
Parameters:
|
||||
x (Dict[name: str, data: str]): JSON object with filename as 'name' property and base64 data as 'data' property
|
||||
@ -1095,7 +1096,6 @@ class Audio(InputComponent):
|
||||
self.source = source
|
||||
requires_permissions = source == "microphone"
|
||||
self.type = type
|
||||
self.optional = optional
|
||||
self.test_input = test_data.BASE64_AUDIO
|
||||
self.interpret_by_tokens = True
|
||||
super().__init__(label, requires_permissions, optional=optional)
|
||||
@ -1118,12 +1118,12 @@ class Audio(InputComponent):
|
||||
def preprocess_example(self, x):
|
||||
return {"name": x, "data": None, "is_example": True}
|
||||
|
||||
def preprocess(self, x):
|
||||
def preprocess(self, x: Dict[str, str] | None) -> Tuple[int, np.array] | str | None:
|
||||
"""
|
||||
Parameters:
|
||||
x (Dict[name: str, data: str]): JSON object with filename as 'name' property and base64 data as 'data' property
|
||||
Returns:
|
||||
(Union[Tuple[int, numpy.array], file-object, numpy.array]): audio in requested format
|
||||
(Union[Tuple[int, numpy.array], str, numpy.array]): audio in requested format
|
||||
"""
|
||||
if x is None:
|
||||
return x
|
||||
@ -1309,7 +1309,6 @@ class File(InputComponent):
|
||||
self.file_count = file_count
|
||||
self.type = type
|
||||
self.test_input = None
|
||||
self.optional = optional
|
||||
super().__init__(label, optional=optional)
|
||||
|
||||
def get_template_context(self):
|
||||
@ -1329,7 +1328,7 @@ class File(InputComponent):
|
||||
def preprocess_example(self, x):
|
||||
return {"name": x, "data": None, "is_example": True}
|
||||
|
||||
def preprocess(self, x):
|
||||
def preprocess(self, x: List[Dict[str, str]] | None):
|
||||
"""
|
||||
Parameters:
|
||||
x (List[Dict[name: str, data: str]]): List of JSON objects with filename as 'name' property and base64 data as 'data' property
|
||||
@ -1459,7 +1458,7 @@ class Dataframe(InputComponent):
|
||||
"list": {"type": "array", "col_count": 1},
|
||||
}
|
||||
|
||||
def preprocess(self, x):
|
||||
def preprocess(self, x: List[List[str | Number | bool]]):
|
||||
"""
|
||||
Parameters:
|
||||
x (List[List[Union[str, number, bool]]]): 2D array of str, numeric, or bool data
|
||||
@ -1522,7 +1521,6 @@ class Timeseries(InputComponent):
|
||||
if isinstance(y, str):
|
||||
y = [y]
|
||||
self.y = y
|
||||
self.optional = optional
|
||||
super().__init__(label, optional=optional)
|
||||
|
||||
def get_template_context(self):
|
||||
@ -1542,7 +1540,7 @@ class Timeseries(InputComponent):
|
||||
def preprocess_example(self, x):
|
||||
return {"name": x, "is_example": True}
|
||||
|
||||
def preprocess(self, x):
|
||||
def preprocess(self, x: Dict | None) -> pd.DataFrame | None:
|
||||
"""
|
||||
Parameters:
|
||||
x (Dict[data: List[List[Union[str, number, bool]]], headers: List[str], range: List[number]]): Dict with keys 'data': 2D array of str, numeric, or bool data, 'headers': list of strings for header names, 'range': optional two element list designating start of end of subrange.
|
||||
|
@ -45,8 +45,8 @@
|
||||
</script>
|
||||
<script src="https://cdnjs.cloudflare.com/ajax/libs/iframe-resizer/4.3.1/iframeResizer.contentWindow.min.js"></script>
|
||||
<title>Gradio</title>
|
||||
<script type="module" crossorigin src="/assets/index.1efa643a.js"></script>
|
||||
<link rel="modulepreload" href="/assets/vendor.a0afec2a.js">
|
||||
<script type="module" crossorigin src="/assets/index.1b31b66f.js"></script>
|
||||
<link rel="modulepreload" href="/assets/vendor.fe13b00e.js">
|
||||
<link rel="stylesheet" href="/assets/vendor.327fceeb.css">
|
||||
<link rel="stylesheet" href="/assets/index.dee61218.css">
|
||||
</head>
|
||||
|
@ -217,6 +217,8 @@ def get_config_file(interface: Interface) -> Dict[str, Any]:
|
||||
for iface, param in zip(config["input_components"], param_names):
|
||||
if not iface["label"]:
|
||||
iface["label"] = param.replace("_", " ")
|
||||
if iface["optional"]:
|
||||
iface["label"] += " (optional)"
|
||||
for i, iface in enumerate(config["output_components"]):
|
||||
outputs_per_function = int(
|
||||
len(interface.output_components) / len(interface.predict)
|
||||
|
Loading…
Reference in New Issue
Block a user