mirror of
https://github.com/gradio-app/gradio.git
synced 2024-12-09 02:00:44 +08:00
Add precision to Number, backend only (#1125)
* integer type * Add integer parameter for Number * Pass down integer prop to Number instead: * Format + update test * Update interpretation neighbors code + docstrings * Rename prop to type as opposed to integer * Update error message + test * Update docstring * Implement precision * Add test + format * Add test for precision=2 * Change round logic for precision=0 * integer type * Add integer parameter for Number * Pass down integer prop to Number instead: * Format + update test * Rename prop to type as opposed to integer * Implement precision * Add test + format * Address docstrings * Fix test + rebase
This commit is contained in:
parent
97929ee795
commit
8d5e05cdbe
@ -631,8 +631,8 @@ class Number(Changeable, Submittable, IOComponent):
|
||||
Component creates a field for user to enter numeric input or display numeric output. Provides a number as an argument to the wrapped function.
|
||||
Can be used as an output as well.
|
||||
|
||||
Input type: float
|
||||
Output type: float
|
||||
Input type: float or int.
|
||||
Output type: float or int.
|
||||
Demos: tax_calculator, titanic_survival
|
||||
"""
|
||||
|
||||
@ -645,6 +645,7 @@ class Number(Changeable, Submittable, IOComponent):
|
||||
interactive: Optional[bool] = None,
|
||||
visible: bool = True,
|
||||
elem_id: Optional[str] = None,
|
||||
precision: Optional[int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@ -653,8 +654,11 @@ class Number(Changeable, Submittable, IOComponent):
|
||||
label (Optional[str]): component name in interface.
|
||||
show_label (bool): if True, will display label.
|
||||
visible (bool): If False, component will be hidden.
|
||||
precision (Optional[int]): Precision to round input/output to. If set to 0, will
|
||||
round to nearest integer and covert type to int. If None, no rounding happens.
|
||||
"""
|
||||
self.value = float(value) if value is not None else None
|
||||
self.value = self.round_to_precision(value, precision)
|
||||
self.precision = precision
|
||||
self.test_input = self.value if self.value is not None else 1
|
||||
self.interpret_by_tokens = False
|
||||
IOComponent.__init__(
|
||||
@ -667,6 +671,30 @@ class Number(Changeable, Submittable, IOComponent):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def round_to_precision(
|
||||
num: float | int | None, precision: int | None
|
||||
) -> float | int | None:
|
||||
"""
|
||||
Round to a given precision.
|
||||
|
||||
If precision is None, no rounding happens. If 0, num is converted to int.
|
||||
|
||||
Parameters:
|
||||
num (float | int): Number to round.
|
||||
precision (int | None): Precision to round to.
|
||||
Returns:
|
||||
(float | int): rounded number
|
||||
"""
|
||||
if num is None:
|
||||
return None
|
||||
if precision is None:
|
||||
return float(num)
|
||||
elif precision == 0:
|
||||
return int(round(num, precision))
|
||||
else:
|
||||
return round(num, precision)
|
||||
|
||||
def get_config(self):
|
||||
return {
|
||||
"value": self.value,
|
||||
@ -689,26 +717,26 @@ class Number(Changeable, Submittable, IOComponent):
|
||||
"__type__": "update",
|
||||
}
|
||||
|
||||
def preprocess(self, x: float | None) -> Optional[float]:
|
||||
def preprocess(self, x: int | float | None) -> int | float | None:
|
||||
"""
|
||||
Parameters:
|
||||
x (string): numeric input as a string
|
||||
x (int | float | None): numeric input as a string
|
||||
Returns:
|
||||
(float): number representing function input
|
||||
(int | float | None): number representing function input
|
||||
"""
|
||||
if x is None:
|
||||
return None
|
||||
return float(x)
|
||||
return self.round_to_precision(x, self.precision)
|
||||
|
||||
def preprocess_example(self, x: float | None) -> float | None:
|
||||
def preprocess_example(self, x: int | float | None) -> int | float | None:
|
||||
"""
|
||||
Returns:
|
||||
(float): Number representing function input
|
||||
(int | float | None): Number representing function input
|
||||
"""
|
||||
if x is None:
|
||||
return None
|
||||
else:
|
||||
return float(x)
|
||||
return self.round_to_precision(x, self.precision)
|
||||
|
||||
def set_interpret_parameters(
|
||||
self, steps: int = 3, delta: float = 1, delta_type: str = "percent"
|
||||
@ -725,14 +753,21 @@ class Number(Changeable, Submittable, IOComponent):
|
||||
self.interpretation_delta_type = delta_type
|
||||
return self
|
||||
|
||||
def get_interpretation_neighbors(self, x: float) -> Tuple[List[float], Dict]:
|
||||
x = float(x)
|
||||
def get_interpretation_neighbors(self, x: float | int) -> Tuple[List[float], Dict]:
|
||||
x = self.round_to_precision(x, self.precision)
|
||||
if self.interpretation_delta_type == "percent":
|
||||
delta = 1.0 * self.interpretation_delta * x / 100
|
||||
elif self.interpretation_delta_type == "absolute":
|
||||
delta = self.interpretation_delta
|
||||
else:
|
||||
delta = self.interpretation_delta
|
||||
if self.precision == 0 and math.floor(delta) != delta:
|
||||
raise ValueError(
|
||||
f"Delta value {delta} is not an integer and precision=0. Cannot generate valid set of neighbors. "
|
||||
"If delta_type='percent', pick a value of delta such that x * delta is an integer. "
|
||||
"If delta_type='absolute', pick a value of delta that is an integer."
|
||||
)
|
||||
# run_interpretation will preprocess the neighbors so no need to covert to int here
|
||||
negatives = (x + np.arange(-self.interpretation_steps, 0) * delta).tolist()
|
||||
positives = (x + np.arange(1, self.interpretation_steps + 1) * delta).tolist()
|
||||
return negatives + positives, {}
|
||||
@ -748,18 +783,23 @@ class Number(Changeable, Submittable, IOComponent):
|
||||
interpretation.insert(int(len(interpretation) / 2), [x, None])
|
||||
return interpretation
|
||||
|
||||
def generate_sample(self) -> float:
|
||||
return 1.0
|
||||
def generate_sample(self) -> int | float:
|
||||
return self.round_to_precision(1, self.precision)
|
||||
|
||||
# Output Functionalities
|
||||
def postprocess(self, y: float | None):
|
||||
def postprocess(self, y: int | float | None) -> int | float | None:
|
||||
"""
|
||||
Any postprocessing needed to be performed on function output.
|
||||
|
||||
Parameters:
|
||||
y (int | float | None): numeric output
|
||||
Returns:
|
||||
(int | float | None): number representing function output
|
||||
"""
|
||||
if y is None:
|
||||
return None
|
||||
else:
|
||||
return float(y)
|
||||
return self.round_to_precision(y, self.precision)
|
||||
|
||||
def deserialize(self, y):
|
||||
"""
|
||||
|
@ -195,6 +195,71 @@ class TestNumber(unittest.TestCase):
|
||||
},
|
||||
)
|
||||
|
||||
def test_component_functions_integer(self):
|
||||
"""
|
||||
Preprocess, postprocess, serialize, save_flagged, restore_flagged, generate_sample, set_interpret_parameters, get_interpretation_neighbors, get_template_context
|
||||
|
||||
"""
|
||||
numeric_input = gr.Number(precision=0, value=42)
|
||||
self.assertEqual(numeric_input.preprocess(3), 3)
|
||||
self.assertEqual(numeric_input.preprocess(None), None)
|
||||
self.assertEqual(numeric_input.preprocess_example(3), 3)
|
||||
self.assertEqual(numeric_input.postprocess(3), 3)
|
||||
self.assertEqual(numeric_input.postprocess(2.85), 3)
|
||||
self.assertEqual(numeric_input.postprocess(None), None)
|
||||
self.assertEqual(numeric_input.serialize(3, True), 3)
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
to_save = numeric_input.save_flagged(tmpdirname, "numeric_input", 3, None)
|
||||
self.assertEqual(to_save, 3)
|
||||
restored = numeric_input.restore_flagged(tmpdirname, to_save, None)
|
||||
self.assertEqual(restored, 3)
|
||||
self.assertIsInstance(numeric_input.generate_sample(), int)
|
||||
numeric_input.set_interpret_parameters(steps=3, delta=1, delta_type="absolute")
|
||||
self.assertEqual(
|
||||
numeric_input.get_interpretation_neighbors(1),
|
||||
([-2.0, -1.0, 0.0, 2.0, 3.0, 4.0], {}),
|
||||
)
|
||||
numeric_input.set_interpret_parameters(steps=3, delta=1, delta_type="percent")
|
||||
self.assertEqual(
|
||||
numeric_input.get_interpretation_neighbors(100),
|
||||
([97.0, 98.0, 99.0, 101.0, 102.0, 103.0], {}),
|
||||
)
|
||||
with self.assertRaises(ValueError) as error:
|
||||
numeric_input.get_interpretation_neighbors(1)
|
||||
assert error.msg == "Cannot generate valid set of neighbors"
|
||||
numeric_input.set_interpret_parameters(
|
||||
steps=3, delta=1.24, delta_type="absolute"
|
||||
)
|
||||
with self.assertRaises(ValueError) as error:
|
||||
numeric_input.get_interpretation_neighbors(4)
|
||||
assert error.msg == "Cannot generate valid set of neighbors"
|
||||
self.assertEqual(
|
||||
numeric_input.get_config(),
|
||||
{
|
||||
"value": 42,
|
||||
"name": "number",
|
||||
"show_label": True,
|
||||
"label": None,
|
||||
"style": {},
|
||||
"elem_id": None,
|
||||
"visible": True,
|
||||
"interactive": None,
|
||||
},
|
||||
)
|
||||
|
||||
def test_component_functions_precision(self):
|
||||
"""
|
||||
Preprocess, postprocess, serialize, save_flagged, restore_flagged, generate_sample, set_interpret_parameters, get_interpretation_neighbors, get_template_context
|
||||
|
||||
"""
|
||||
numeric_input = gr.Number(precision=2, value=42.3428)
|
||||
self.assertEqual(numeric_input.preprocess(3.231241), 3.23)
|
||||
self.assertEqual(numeric_input.preprocess(None), None)
|
||||
self.assertEqual(numeric_input.preprocess_example(-42.1241), -42.12)
|
||||
self.assertEqual(numeric_input.postprocess(5.6784), 5.68)
|
||||
self.assertEqual(numeric_input.postprocess(2.1421), 2.14)
|
||||
self.assertEqual(numeric_input.postprocess(None), None)
|
||||
|
||||
def test_in_interface_as_input(self):
|
||||
"""
|
||||
Interface, process, interpret
|
||||
@ -218,6 +283,30 @@ class TestNumber(unittest.TestCase):
|
||||
],
|
||||
)
|
||||
|
||||
def test_precision_0_in_interface(self):
|
||||
"""
|
||||
Interface, process, interpret
|
||||
"""
|
||||
iface = gr.Interface(lambda x: x**2, gr.Number(precision=0), "textbox")
|
||||
self.assertEqual(iface.process([2]), ["4"])
|
||||
iface = gr.Interface(
|
||||
lambda x: x**2, "number", gr.Number(precision=0), interpretation="default"
|
||||
)
|
||||
# Output gets rounded to 4 for all input so no change
|
||||
scores = iface.interpret([2])[0]["interpretation"]
|
||||
self.assertEqual(
|
||||
scores,
|
||||
[
|
||||
(1.94, 0.0),
|
||||
(1.96, 0.0),
|
||||
(1.98, 0.0),
|
||||
[2, None],
|
||||
(2.02, 0.0),
|
||||
(2.04, 0.0),
|
||||
(2.06, 0.0),
|
||||
],
|
||||
)
|
||||
|
||||
def test_in_interface_as_output(self):
|
||||
"""
|
||||
Interface, process, interpret
|
||||
|
Loading…
Reference in New Issue
Block a user