Fix lazy caching (#10124)

* fixes

* add changeset

* changes

* add changeset

* test

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
Abubakar Abid 2024-12-04 16:00:28 -06:00 committed by GitHub
parent b02c8b7d4f
commit 5d61c7b701
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 24 additions and 75 deletions

View File

@@ -0,0 +1,5 @@
---
"gradio": patch
---
fix:Fix lazy caching

View File

@@ -294,6 +294,17 @@ class Examples:
self._get_processed_example(example)
)
if self.cache_examples == "lazy":
print(
f"Will cache examples in '{utils.abspath(self.cached_folder)}' directory at first use.",
end="",
)
if Path(self.cached_file).exists():
print(
"If method or examples have changed since last caching, delete this folder to reset cache."
)
print("\n")
def _get_processed_example(self, example):
"""
This function is used to get the post-processed example values, ready to be used
@@ -332,7 +343,7 @@ class Examples:
if self.root_block:
self.root_block.extra_startup_events.append(self._start_caching)
if self.cache_examples == True: # noqa: E712
if self.cache_examples:
def load_example_with_output(example_tuple):
example_id, example_value = example_tuple
@@ -380,10 +391,7 @@ class Examples:
show_api=False,
)
if self.cache_examples == "lazy":
self.lazy_cache()
if self.run_on_click and self.cache_examples == False: # noqa: E712
if self.run_on_click:
if self.fn is None:
raise ValueError(
"Cannot run_on_click if no function is provided"
@@ -448,71 +456,6 @@ class Examples:
else:
await self.cache()
def lazy_cache(self) -> None:
print(
f"Will cache examples in '{utils.abspath(self.cached_folder)}' directory at first use. ",
end="",
)
if Path(self.cached_file).exists():
print(
"If method or examples have changed since last caching, delete this folder to reset cache.",
end="",
)
print("\n\n")
self.cache_logger.setup(self.outputs, self.cached_folder)
if inspect.iscoroutinefunction(self.fn) or inspect.isasyncgenfunction(self.fn):
lazy_cache_fn = self.async_lazy_cache
else:
lazy_cache_fn = self.sync_lazy_cache
self.cache_event = self.load_input_event.then(
lazy_cache_fn,
inputs=[self.dataset] + list(self.inputs),
outputs=self.outputs,
postprocess=False,
api_name=self.api_name,
show_api=False,
)
async def async_lazy_cache(
self, example_value: tuple[int, list[Any]], *input_values
):
example_index, _ = example_value
cached_index = self._get_cached_index_if_cached(example_index)
if cached_index is not None:
output = self.load_from_cache(cached_index)
yield output[0] if len(self.outputs) == 1 else output
return
output = [None] * len(self.outputs)
if inspect.isasyncgenfunction(self.fn):
fn = self.fn
else:
fn = utils.async_fn_to_generator(self.fn)
async for output in fn(*input_values):
output = await self._postprocess_output(output)
yield output[0] if len(self.outputs) == 1 else output
self.cache_logger.flag(output)
with open(self.cached_indices_file, "a") as f:
f.write(f"{example_index}\n")
def sync_lazy_cache(self, example_value: tuple[int, list[Any]], *input_values):
example_index, _ = example_value
cached_index = self._get_cached_index_if_cached(example_index)
if cached_index is not None:
output = self.load_from_cache(cached_index)
yield output[0] if len(self.outputs) == 1 else output
return
output = [None] * len(self.outputs)
if inspect.isgeneratorfunction(self.fn):
fn = self.fn
else:
fn = utils.sync_fn_to_generator(self.fn)
for output in fn(*input_values):
output = client_utils.synchronize_async(self._postprocess_output, output)
yield output[0] if len(self.outputs) == 1 else output
self.cache_logger.flag(output)
with open(self.cached_indices_file, "a") as f:
f.write(f"{example_index}\n")
async def cache(self, example_id: int | None = None) -> None:
"""
Caches examples so that their predictions can be shown immediately.
@@ -599,7 +542,7 @@ class Examples:
example_id: The id of the example to process (zero-indexed).
"""
if self.cache_examples == "lazy":
if cached_index := self._get_cached_index_if_cached(example_id) is None:
if (cached_index := self._get_cached_index_if_cached(example_id)) is None:
client_utils.synchronize_async(self.cache, example_id)
with open(self.cached_indices_file, "a") as f:
f.write(f"{example_id}\n")
@@ -610,6 +553,7 @@ class Examples:
with open(self.cached_file, encoding="utf-8") as cache:
examples = list(csv.reader(cache))
example = examples[example_id + 1] # +1 to adjust for header
output = []
if self.outputs is None:

View File

@@ -650,13 +650,13 @@ class TestProcessExamples:
response = client.post(f"{API_PREFIX}/api/load_example/", json={"data": [0]})
data = response.json()["data"]
assert data[0]["value"]["path"].endswith("cheetah1.jpg")
assert data[1]["value"] == "cheetah"
assert data[0]["path"].endswith("cheetah1.jpg")
assert data[1] == "cheetah"
response = client.post(f"{API_PREFIX}/api/load_example/", json={"data": [1]})
data = response.json()["data"]
assert data[0]["value"]["path"].endswith("bus.png")
assert data[1]["value"] == "bus"
assert data[0]["path"].endswith("bus.png")
assert data[1] == "bus"
def test_multiple_file_flagging(tmp_path, connect):