From c977ef1fa85d880c5ac643978e3b505e30ea2bb8 Mon Sep 17 00:00:00 2001 From: Abubakar Abid Date: Wed, 21 Sep 2022 11:53:06 -0500 Subject: [PATCH] Supports `gr.update()` in example caching (#2309) * support for update * formatting * testing * fixes to tests * fixes to tests * fix tests * fix review comments * Update blocks.py --- gradio/blocks.py | 91 +++++++++++++++++++++---------------------- gradio/examples.py | 34 +++++++++++----- gradio/flagging.py | 19 +++++---- gradio/utils.py | 5 +++ test/test_blocks.py | 35 +++++++++++++++++ test/test_examples.py | 50 +++++++++++++++++++++--- 6 files changed, 163 insertions(+), 71 deletions(-) diff --git a/gradio/blocks.py b/gradio/blocks.py index 874cca7ea7..165a9e171c 100644 --- a/gradio/blocks.py +++ b/gradio/blocks.py @@ -44,7 +44,7 @@ if TYPE_CHECKING: # Only import for type checking (is False at runtime). import wandb from fastapi.applications import FastAPI - from gradio.components import Component, StatusTracker + from gradio.components import Component, IOComponent class Block: @@ -183,6 +183,12 @@ class Block: "style": self._style, } + @classmethod + def get_specific_update(cls, generic_update): + del generic_update["__type__"] + generic_update = cls.update(**generic_update) + return generic_update + class BlockContext(Block): def __init__( @@ -304,14 +310,39 @@ def update(**kwargs) -> dict: return kwargs -def is_update(val): - return type(val) is dict and "update" in val.get("__type__", "") - - def skip() -> dict: return update() +def postprocess_update_dict(block: Block, update_dict: Dict): + prediction_value = block.get_specific_update(update_dict) + if prediction_value.get("value") is components._Keywords.NO_VALUE: + prediction_value.pop("value") + prediction_value = delete_none(prediction_value, skip_value=True) + if "value" in prediction_value: + prediction_value["value"] = block.postprocess(prediction_value["value"]) + return prediction_value + + +def convert_update_dict_to_list(outputs_ids: List[int], predictions: Dict) -> List: + keys_are_blocks = [isinstance(key, Block) for key in predictions.keys()] + if all(keys_are_blocks): + reordered_predictions = [skip() for _ in outputs_ids] + for component, value in predictions.items(): + if component._id not in outputs_ids: + return ValueError( + f"Returned component {component} not specified as output of function." + ) + output_index = outputs_ids.index(component._id) + reordered_predictions[output_index] = value + predictions = utils.resolve_singleton(reordered_predictions) + elif any(keys_are_blocks): + raise ValueError( + "Returned dictionary included some keys as Components. Either all keys must be Components to assign Component values, or return a List of values to assign output values in order." + ) + return predictions + + @document("load") class Blocks(BlockContext): """ @@ -656,21 +687,10 @@ class Blocks(BlockContext): dependency = self.dependencies[fn_index] if type(predictions) is dict and len(predictions) > 0: - keys_are_blocks = [isinstance(key, Block) for key in predictions.keys()] - if all(keys_are_blocks): - reordered_predictions = [skip() for _ in dependency["outputs"]] - for component, value in predictions.items(): - if component._id not in dependency["outputs"]: - return ValueError( - f"Returned component {component} not specified as output of function." - ) - output_index = dependency["outputs"].index(component._id) - reordered_predictions[output_index] = value - predictions = utils.resolve_singleton(reordered_predictions) - elif any(keys_are_blocks): - raise ValueError( - "Returned dictionary included some keys as Components. Either all keys must be Components to assign Component values, or return a List of values to assign output values in order." - ) + predictions = convert_update_dict_to_list( + dependency["outputs"], predictions + ) + if len(dependency["outputs"]) == 1: predictions = (predictions,) @@ -682,38 +702,15 @@ class Blocks(BlockContext): break block = self.blocks[output_id] if getattr(block, "stateful", False): - if not is_update(predictions[i]): + if not utils.is_update(predictions[i]): state[output_id] = predictions[i] output.append(None) else: prediction_value = predictions[i] - if is_update(prediction_value): - if prediction_value["__type__"] == "generic_update": - del prediction_value["__type__"] - prediction_value = block.__class__.update( - **prediction_value - ) - # If the prediction is the default (NO_VALUE) enum then the user did - # not specify a value for the 'value' key and we can get rid of it - if ( - prediction_value.get("value") - == components._Keywords.NO_VALUE - ): - prediction_value.pop("value") - prediction_value = delete_none( - prediction_value, skip_value=True - ) - if "value" in prediction_value: - prediction_value["value"] = block.postprocess( - prediction_value["value"] - ) - output_value = prediction_value + if utils.is_update(prediction_value): + output_value = postprocess_update_dict(block, prediction_value) else: - output_value = ( - block.postprocess(prediction_value) - if prediction_value is not None - else None - ) + output_value = block.postprocess(prediction_value) output.append(output_value) else: diff --git a/gradio/examples.py b/gradio/examples.py index 903061ec1b..32ef4ec428 100644 --- a/gradio/examples.py +++ b/gradio/examples.py @@ -3,6 +3,7 @@ Defines helper methods useful for loading and caching Interface examples. """ from __future__ import annotations +import ast import csv import inspect import os @@ -13,6 +14,7 @@ from typing import TYPE_CHECKING, Any, Callable, List, Optional import anyio from gradio import utils +from gradio.blocks import convert_update_dict_to_list, postprocess_update_dict from gradio.components import Dataset from gradio.context import Context from gradio.documentation import document, set_documentation_group @@ -188,11 +190,8 @@ class Examples: for example in self.processed_examples ] if cache_examples: - for ex in non_none_examples: - if ( - len([sample for sample in ex if sample is not None]) - != self.inputs_with_examples - ): + for example in self.examples: + if len([ex for ex in example if ex is not None]) != len(self.inputs): warnings.warn( "Examples are being cached but not all input components have " "example values. This may result in an exception being thrown by " @@ -275,13 +274,23 @@ class Examples: predictions = await self.fn(*processed_input) else: predictions = await anyio.to_thread.run_sync(self.fn, *processed_input) + + output_ids = [output._id for output in self.outputs] + if type(predictions) is dict and len(predictions) > 0: + predictions = convert_update_dict_to_list(output_ids, predictions) + if len(self.outputs) == 1: predictions = [predictions] if not self._api_mode: - predictions = [ - output_component.postprocess(predictions[i]) - for i, output_component in enumerate(self.outputs) - ] + predictions_ = [] + for i, output_component in enumerate(self.outputs): + if utils.is_update(predictions[i]): + predictions_.append( + postprocess_update_dict(output_component, predictions[i]) + ) + else: + predictions_.append(output_component.postprocess(predictions[i])) + predictions = predictions_ return predictions async def load_from_cache(self, example_id: int) -> List[Any]: @@ -294,5 +303,10 @@ class Examples: example = examples[example_id + 1] # +1 to adjust for header output = [] for component, value in zip(self.outputs, example): - output.append(component.serialize(value, self.cached_folder)) + try: + value_as_dict = ast.literal_eval(value) + assert utils.is_update(value_as_dict) + output.append(value_as_dict) + except (ValueError, TypeError, SyntaxError, AssertionError): + output.append(component.serialize(value, self.cached_folder)) return output diff --git a/gradio/flagging.py b/gradio/flagging.py index 31f945999d..a679ad3068 100644 --- a/gradio/flagging.py +++ b/gradio/flagging.py @@ -206,15 +206,18 @@ class CSVLogger(FlaggingCallback): component.label or f"component {idx}" ), ) - csv_data.append( - component.deserialize( - sample, - save_dir=save_dir, - encryption_key=self.encryption_key, + if utils.is_update(sample): + csv_data.append(str(sample)) + else: + csv_data.append( + component.deserialize( + sample, + save_dir=save_dir, + encryption_key=self.encryption_key, + ) + if sample is not None + else "" ) - if sample is not None - else "" - ) csv_data.append(flag_option if flag_option is not None else "") csv_data.append(username if username is not None else "") csv_data.append(str(datetime.datetime.now())) diff --git a/gradio/utils.py b/gradio/utils.py index d917b7dc2e..627206a3a3 100644 --- a/gradio/utils.py +++ b/gradio/utils.py @@ -2,6 +2,7 @@ from __future__ import annotations +import ast import asyncio import copy import inspect @@ -675,3 +676,7 @@ def validate_url(possible_url: str) -> bool: except Exception: pass return False + + +def is_update(val): + return type(val) is dict and "update" in val.get("__type__", "") diff --git a/test/test_blocks.py b/test/test_blocks.py index b9637cebaf..4d93db637b 100644 --- a/test/test_blocks.py +++ b/test/test_blocks.py @@ -408,5 +408,40 @@ class TestCallFunction: assert output["prediction"] == (0, 3) +class TestSpecificUpdate: + def test_without_update(self): + with pytest.raises(KeyError): + gr.Textbox.get_specific_update({"lines": 4}) + + def test_with_update(self): + specific_update = gr.Textbox.get_specific_update( + {"lines": 4, "__type__": "update"} + ) + assert specific_update == { + "lines": 4, + "max_lines": None, + "placeholder": None, + "label": None, + "show_label": None, + "visible": None, + "value": gr.components._Keywords.NO_VALUE, + "__type__": "update", + } + + def test_with_generic_update(self): + specific_update = gr.Video.get_specific_update( + {"visible": True, "value": "test.mp4", "__type__": "generic_update"} + ) + assert specific_update == { + "source": None, + "label": None, + "show_label": None, + "interactive": None, + "visible": True, + "value": "test.mp4", + "__type__": "update", + } + + if __name__ == "__main__": unittest.main() diff --git a/test/test_examples.py b/test/test_examples.py index 4d32ba549d..fd44ce476f 100644 --- a/test/test_examples.py +++ b/test/test_examples.py @@ -1,4 +1,5 @@ import os +import tempfile from unittest.mock import patch import pytest @@ -61,6 +62,7 @@ class TestExamplesDataset: assert examples.dataset.headers == ["im", ""] +@patch("gradio.examples.CACHED_FOLDER", tempfile.mkdtemp()) class TestProcessExamples: @pytest.mark.asyncio async def test_predict_example(self): @@ -85,18 +87,54 @@ class TestProcessExamples: "text", examples=[["World"], ["Dunya"], ["Monde"]], ) - io.launch(prevent_thread_lock=True) await io.examples_handler.cache_interface_examples() prediction = await io.examples_handler.load_from_cache(1) - io.close() assert prediction[0] == "Hello Dunya" + @pytest.mark.asyncio + async def test_caching_with_update(self): + io = gr.Interface( + lambda x: gr.update(visible=False), + "text", + "image", + examples=[["World"], ["Dunya"], ["Monde"]], + ) + await io.examples_handler.cache_interface_examples() + prediction = await io.examples_handler.load_from_cache(1) + assert prediction[0] == {"visible": False, "__type__": "update"} -def test_raise_helpful_error_message_if_providing_partial_examples(tmp_path): - def foo(a, b): - return a + b + @pytest.mark.asyncio + async def test_caching_with_mix_update(self): + io = gr.Interface( + lambda x: [gr.update(lines=4, value="hello"), "test/test_files/bus.png"], + "text", + ["text", "image"], + examples=[["World"], ["Dunya"], ["Monde"]], + ) + await io.examples_handler.cache_interface_examples() + prediction = await io.examples_handler.load_from_cache(1) + assert prediction[0] == {"lines": 4, "value": "hello", "__type__": "update"} + + @pytest.mark.asyncio + async def test_caching_with_dict(self): + text = gr.Textbox() + out = gr.Label() + + io = gr.Interface( + lambda _: {text: gr.update(lines=4), out: "lion"}, + "textbox", + [text, out], + examples=["abc"], + cache_examples=True, + ) + await io.examples_handler.cache_interface_examples() + prediction = await io.examples_handler.load_from_cache(0) + assert prediction == [{"lines": 4, "__type__": "update"}, {"label": "lion"}] + + def test_raise_helpful_error_message_if_providing_partial_examples(self, tmp_path): + def foo(a, b): + return a + b - with patch("gradio.examples.CACHED_FOLDER", tmp_path): with pytest.warns( UserWarning, match="^Examples are being cached but not all input components have",