From 8d5e05cdbee0e1240495e3d8947b6bc05e693bf8 Mon Sep 17 00:00:00 2001 From: Freddy Boulton Date: Fri, 13 May 2022 22:04:11 -0400 Subject: [PATCH] Add precision to Number, backend only (#1125) * integer type * Add integer parameter for Number * Pass down integer prop to Number instead: * Format + update test * Update interpretation neighbors code + docstrings * Rename prop to type as opposed to integer * Update error message + test * Update docstring * Implement precision * Add test + format * Add test for precision=2 * Change round logic for precision=0 * integer type * Add integer parameter for Number * Pass down integer prop to Number instead: * Format + update test * Rename prop to type as opposed to integer * Implement precision * Add test + format * Address docstrings * Fix test + rebase --- gradio/components.py | 72 +++++++++++++++++++++++++-------- test/test_components.py | 89 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 145 insertions(+), 16 deletions(-) diff --git a/gradio/components.py b/gradio/components.py index c5e9e53ed1..7a5f8ca158 100644 --- a/gradio/components.py +++ b/gradio/components.py @@ -631,8 +631,8 @@ class Number(Changeable, Submittable, IOComponent): Component creates a field for user to enter numeric input or display numeric output. Provides a number as an argument to the wrapped function. Can be used as an output as well. - Input type: float - Output type: float + Input type: float or int. + Output type: float or int. Demos: tax_calculator, titanic_survival """ @@ -645,6 +645,7 @@ class Number(Changeable, Submittable, IOComponent): interactive: Optional[bool] = None, visible: bool = True, elem_id: Optional[str] = None, + precision: Optional[int] = None, **kwargs, ): """ @@ -653,8 +654,11 @@ class Number(Changeable, Submittable, IOComponent): label (Optional[str]): component name in interface. show_label (bool): if True, will display label. visible (bool): If False, component will be hidden. + precision (Optional[int]): Precision to round input/output to. If set to 0, will + round to nearest integer and covert type to int. If None, no rounding happens. """ - self.value = float(value) if value is not None else None + self.value = self.round_to_precision(value, precision) + self.precision = precision self.test_input = self.value if self.value is not None else 1 self.interpret_by_tokens = False IOComponent.__init__( @@ -667,6 +671,30 @@ class Number(Changeable, Submittable, IOComponent): **kwargs, ) + @staticmethod + def round_to_precision( + num: float | int | None, precision: int | None + ) -> float | int | None: + """ + Round to a given precision. + + If precision is None, no rounding happens. If 0, num is converted to int. + + Parameters: + num (float | int): Number to round. + precision (int | None): Precision to round to. + Returns: + (float | int): rounded number + """ + if num is None: + return None + if precision is None: + return float(num) + elif precision == 0: + return int(round(num, precision)) + else: + return round(num, precision) + def get_config(self): return { "value": self.value, @@ -689,26 +717,26 @@ class Number(Changeable, Submittable, IOComponent): "__type__": "update", } - def preprocess(self, x: float | None) -> Optional[float]: + def preprocess(self, x: int | float | None) -> int | float | None: """ Parameters: - x (string): numeric input as a string + x (int | float | None): numeric input as a string Returns: - (float): number representing function input + (int | float | None): number representing function input """ if x is None: return None - return float(x) + return self.round_to_precision(x, self.precision) - def preprocess_example(self, x: float | None) -> float | None: + def preprocess_example(self, x: int | float | None) -> int | float | None: """ Returns: - (float): Number representing function input + (int | float | None): Number representing function input """ if x is None: return None else: - return float(x) + return self.round_to_precision(x, self.precision) def set_interpret_parameters( self, steps: int = 3, delta: float = 1, delta_type: str = "percent" @@ -725,14 +753,21 @@ class Number(Changeable, Submittable, IOComponent): self.interpretation_delta_type = delta_type return self - def get_interpretation_neighbors(self, x: float) -> Tuple[List[float], Dict]: - x = float(x) + def get_interpretation_neighbors(self, x: float | int) -> Tuple[List[float], Dict]: + x = self.round_to_precision(x, self.precision) if self.interpretation_delta_type == "percent": delta = 1.0 * self.interpretation_delta * x / 100 elif self.interpretation_delta_type == "absolute": delta = self.interpretation_delta else: delta = self.interpretation_delta + if self.precision == 0 and math.floor(delta) != delta: + raise ValueError( + f"Delta value {delta} is not an integer and precision=0. Cannot generate valid set of neighbors. " + "If delta_type='percent', pick a value of delta such that x * delta is an integer. " + "If delta_type='absolute', pick a value of delta that is an integer." + ) + # run_interpretation will preprocess the neighbors so no need to covert to int here negatives = (x + np.arange(-self.interpretation_steps, 0) * delta).tolist() positives = (x + np.arange(1, self.interpretation_steps + 1) * delta).tolist() return negatives + positives, {} @@ -748,18 +783,23 @@ class Number(Changeable, Submittable, IOComponent): interpretation.insert(int(len(interpretation) / 2), [x, None]) return interpretation - def generate_sample(self) -> float: - return 1.0 + def generate_sample(self) -> int | float: + return self.round_to_precision(1, self.precision) # Output Functionalities - def postprocess(self, y: float | None): + def postprocess(self, y: int | float | None) -> int | float | None: """ Any postprocessing needed to be performed on function output. + + Parameters: + y (int | float | None): numeric output + Returns: + (int | float | None): number representing function output """ if y is None: return None else: - return float(y) + return self.round_to_precision(y, self.precision) def deserialize(self, y): """ diff --git a/test/test_components.py b/test/test_components.py index dff0d92d71..3b254218b1 100644 --- a/test/test_components.py +++ b/test/test_components.py @@ -195,6 +195,71 @@ class TestNumber(unittest.TestCase): }, ) + def test_component_functions_integer(self): + """ + Preprocess, postprocess, serialize, save_flagged, restore_flagged, generate_sample, set_interpret_parameters, get_interpretation_neighbors, get_template_context + + """ + numeric_input = gr.Number(precision=0, value=42) + self.assertEqual(numeric_input.preprocess(3), 3) + self.assertEqual(numeric_input.preprocess(None), None) + self.assertEqual(numeric_input.preprocess_example(3), 3) + self.assertEqual(numeric_input.postprocess(3), 3) + self.assertEqual(numeric_input.postprocess(2.85), 3) + self.assertEqual(numeric_input.postprocess(None), None) + self.assertEqual(numeric_input.serialize(3, True), 3) + with tempfile.TemporaryDirectory() as tmpdirname: + to_save = numeric_input.save_flagged(tmpdirname, "numeric_input", 3, None) + self.assertEqual(to_save, 3) + restored = numeric_input.restore_flagged(tmpdirname, to_save, None) + self.assertEqual(restored, 3) + self.assertIsInstance(numeric_input.generate_sample(), int) + numeric_input.set_interpret_parameters(steps=3, delta=1, delta_type="absolute") + self.assertEqual( + numeric_input.get_interpretation_neighbors(1), + ([-2.0, -1.0, 0.0, 2.0, 3.0, 4.0], {}), + ) + numeric_input.set_interpret_parameters(steps=3, delta=1, delta_type="percent") + self.assertEqual( + numeric_input.get_interpretation_neighbors(100), + ([97.0, 98.0, 99.0, 101.0, 102.0, 103.0], {}), + ) + with self.assertRaises(ValueError) as error: + numeric_input.get_interpretation_neighbors(1) + assert error.msg == "Cannot generate valid set of neighbors" + numeric_input.set_interpret_parameters( + steps=3, delta=1.24, delta_type="absolute" + ) + with self.assertRaises(ValueError) as error: + numeric_input.get_interpretation_neighbors(4) + assert error.msg == "Cannot generate valid set of neighbors" + self.assertEqual( + numeric_input.get_config(), + { + "value": 42, + "name": "number", + "show_label": True, + "label": None, + "style": {}, + "elem_id": None, + "visible": True, + "interactive": None, + }, + ) + + def test_component_functions_precision(self): + """ + Preprocess, postprocess, serialize, save_flagged, restore_flagged, generate_sample, set_interpret_parameters, get_interpretation_neighbors, get_template_context + + """ + numeric_input = gr.Number(precision=2, value=42.3428) + self.assertEqual(numeric_input.preprocess(3.231241), 3.23) + self.assertEqual(numeric_input.preprocess(None), None) + self.assertEqual(numeric_input.preprocess_example(-42.1241), -42.12) + self.assertEqual(numeric_input.postprocess(5.6784), 5.68) + self.assertEqual(numeric_input.postprocess(2.1421), 2.14) + self.assertEqual(numeric_input.postprocess(None), None) + def test_in_interface_as_input(self): """ Interface, process, interpret @@ -218,6 +283,30 @@ class TestNumber(unittest.TestCase): ], ) + def test_precision_0_in_interface(self): + """ + Interface, process, interpret + """ + iface = gr.Interface(lambda x: x**2, gr.Number(precision=0), "textbox") + self.assertEqual(iface.process([2]), ["4"]) + iface = gr.Interface( + lambda x: x**2, "number", gr.Number(precision=0), interpretation="default" + ) + # Output gets rounded to 4 for all input so no change + scores = iface.interpret([2])[0]["interpretation"] + self.assertEqual( + scores, + [ + (1.94, 0.0), + (1.96, 0.0), + (1.98, 0.0), + [2, None], + (2.02, 0.0), + (2.04, 0.0), + (2.06, 0.0), + ], + ) + def test_in_interface_as_output(self): """ Interface, process, interpret