From 5d61c7b70131ed0a7e73b883b687b7df5255a17b Mon Sep 17 00:00:00 2001 From: Abubakar Abid Date: Wed, 4 Dec 2024 16:00:28 -0600 Subject: [PATCH] Fix lazy caching (#10124) * fixes * add changeset * changes * add changeset * test --------- Co-authored-by: gradio-pr-bot --- .changeset/lucky-rings-like.md | 5 ++ gradio/helpers.py | 86 ++++++---------------------------- test/test_helpers.py | 8 ++-- 3 files changed, 24 insertions(+), 75 deletions(-) create mode 100644 .changeset/lucky-rings-like.md diff --git a/.changeset/lucky-rings-like.md b/.changeset/lucky-rings-like.md new file mode 100644 index 0000000000..bde54a1c6f --- /dev/null +++ b/.changeset/lucky-rings-like.md @@ -0,0 +1,5 @@ +--- +"gradio": patch +--- + +fix:Fix lazy caching diff --git a/gradio/helpers.py b/gradio/helpers.py index df12310a19..a2109fbdc8 100644 --- a/gradio/helpers.py +++ b/gradio/helpers.py @@ -294,6 +294,17 @@ class Examples: self._get_processed_example(example) ) + if self.cache_examples == "lazy": + print( + f"Will cache examples in '{utils.abspath(self.cached_folder)}' directory at first use.", + end="", + ) + if Path(self.cached_file).exists(): + print( + "If method or examples have changed since last caching, delete this folder to reset cache." + ) + print("\n") + def _get_processed_example(self, example): """ This function is used to get the post-processed example values, ready to be used @@ -332,7 +343,7 @@ class Examples: if self.root_block: self.root_block.extra_startup_events.append(self._start_caching) - if self.cache_examples == True: # noqa: E712 + if self.cache_examples: def load_example_with_output(example_tuple): example_id, example_value = example_tuple @@ -380,10 +391,7 @@ class Examples: show_api=False, ) - if self.cache_examples == "lazy": - self.lazy_cache() - - if self.run_on_click and self.cache_examples == False: # noqa: E712 + if self.run_on_click: if self.fn is None: raise ValueError( "Cannot run_on_click if no function is provided" @@ -448,71 +456,6 @@ class Examples: else: await self.cache() - def lazy_cache(self) -> None: - print( - f"Will cache examples in '{utils.abspath(self.cached_folder)}' directory at first use. ", - end="", - ) - if Path(self.cached_file).exists(): - print( - "If method or examples have changed since last caching, delete this folder to reset cache.", - end="", - ) - print("\n\n") - self.cache_logger.setup(self.outputs, self.cached_folder) - if inspect.iscoroutinefunction(self.fn) or inspect.isasyncgenfunction(self.fn): - lazy_cache_fn = self.async_lazy_cache - else: - lazy_cache_fn = self.sync_lazy_cache - self.cache_event = self.load_input_event.then( - lazy_cache_fn, - inputs=[self.dataset] + list(self.inputs), - outputs=self.outputs, - postprocess=False, - api_name=self.api_name, - show_api=False, - ) - - async def async_lazy_cache( - self, example_value: tuple[int, list[Any]], *input_values - ): - example_index, _ = example_value - cached_index = self._get_cached_index_if_cached(example_index) - if cached_index is not None: - output = self.load_from_cache(cached_index) - yield output[0] if len(self.outputs) == 1 else output - return - output = [None] * len(self.outputs) - if inspect.isasyncgenfunction(self.fn): - fn = self.fn - else: - fn = utils.async_fn_to_generator(self.fn) - async for output in fn(*input_values): - output = await self._postprocess_output(output) - yield output[0] if len(self.outputs) == 1 else output - self.cache_logger.flag(output) - with open(self.cached_indices_file, "a") as f: - f.write(f"{example_index}\n") - - def sync_lazy_cache(self, example_value: tuple[int, list[Any]], *input_values): - example_index, _ = example_value - cached_index = self._get_cached_index_if_cached(example_index) - if cached_index is not None: - output = self.load_from_cache(cached_index) - yield output[0] if len(self.outputs) == 1 else output - return - output = [None] * len(self.outputs) - if inspect.isgeneratorfunction(self.fn): - fn = self.fn - else: - fn = utils.sync_fn_to_generator(self.fn) - for output in fn(*input_values): - output = client_utils.synchronize_async(self._postprocess_output, output) - yield output[0] if len(self.outputs) == 1 else output - self.cache_logger.flag(output) - with open(self.cached_indices_file, "a") as f: - f.write(f"{example_index}\n") - async def cache(self, example_id: int | None = None) -> None: """ Caches examples so that their predictions can be shown immediately. @@ -599,7 +542,7 @@ class Examples: example_id: The id of the example to process (zero-indexed). """ if self.cache_examples == "lazy": - if cached_index := self._get_cached_index_if_cached(example_id) is None: + if (cached_index := self._get_cached_index_if_cached(example_id)) is None: client_utils.synchronize_async(self.cache, example_id) with open(self.cached_indices_file, "a") as f: f.write(f"{example_id}\n") @@ -610,6 +553,7 @@ class Examples: with open(self.cached_file, encoding="utf-8") as cache: examples = list(csv.reader(cache)) + example = examples[example_id + 1] # +1 to adjust for header output = [] if self.outputs is None: diff --git a/test/test_helpers.py b/test/test_helpers.py index 2a5c6bf62c..e811e11658 100644 --- a/test/test_helpers.py +++ b/test/test_helpers.py @@ -650,13 +650,13 @@ class TestProcessExamples: response = client.post(f"{API_PREFIX}/api/load_example/", json={"data": [0]}) data = response.json()["data"] - assert data[0]["value"]["path"].endswith("cheetah1.jpg") - assert data[1]["value"] == "cheetah" + assert data[0]["path"].endswith("cheetah1.jpg") + assert data[1] == "cheetah" response = client.post(f"{API_PREFIX}/api/load_example/", json={"data": [1]}) data = response.json()["data"] - assert data[0]["value"]["path"].endswith("bus.png") - assert data[1]["value"] == "bus" + assert data[0]["path"].endswith("bus.png") + assert data[1] == "bus" def test_multiple_file_flagging(tmp_path, connect):