mirror of
https://github.com/gradio-app/gradio.git
synced 2024-12-15 02:11:15 +08:00
Fix lazy caching (#10124)
* fixes * add changeset * changes * add changeset * test --------- Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
parent
b02c8b7d4f
commit
5d61c7b701
5
.changeset/lucky-rings-like.md
Normal file
5
.changeset/lucky-rings-like.md
Normal file
@ -0,0 +1,5 @@
|
||||
---
|
||||
"gradio": patch
|
||||
---
|
||||
|
||||
fix:Fix lazy caching
|
@ -294,6 +294,17 @@ class Examples:
|
||||
self._get_processed_example(example)
|
||||
)
|
||||
|
||||
if self.cache_examples == "lazy":
|
||||
print(
|
||||
f"Will cache examples in '{utils.abspath(self.cached_folder)}' directory at first use.",
|
||||
end="",
|
||||
)
|
||||
if Path(self.cached_file).exists():
|
||||
print(
|
||||
"If method or examples have changed since last caching, delete this folder to reset cache."
|
||||
)
|
||||
print("\n")
|
||||
|
||||
def _get_processed_example(self, example):
|
||||
"""
|
||||
This function is used to get the post-processed example values, ready to be used
|
||||
@ -332,7 +343,7 @@ class Examples:
|
||||
if self.root_block:
|
||||
self.root_block.extra_startup_events.append(self._start_caching)
|
||||
|
||||
if self.cache_examples == True: # noqa: E712
|
||||
if self.cache_examples:
|
||||
|
||||
def load_example_with_output(example_tuple):
|
||||
example_id, example_value = example_tuple
|
||||
@ -380,10 +391,7 @@ class Examples:
|
||||
show_api=False,
|
||||
)
|
||||
|
||||
if self.cache_examples == "lazy":
|
||||
self.lazy_cache()
|
||||
|
||||
if self.run_on_click and self.cache_examples == False: # noqa: E712
|
||||
if self.run_on_click:
|
||||
if self.fn is None:
|
||||
raise ValueError(
|
||||
"Cannot run_on_click if no function is provided"
|
||||
@ -448,71 +456,6 @@ class Examples:
|
||||
else:
|
||||
await self.cache()
|
||||
|
||||
def lazy_cache(self) -> None:
|
||||
print(
|
||||
f"Will cache examples in '{utils.abspath(self.cached_folder)}' directory at first use. ",
|
||||
end="",
|
||||
)
|
||||
if Path(self.cached_file).exists():
|
||||
print(
|
||||
"If method or examples have changed since last caching, delete this folder to reset cache.",
|
||||
end="",
|
||||
)
|
||||
print("\n\n")
|
||||
self.cache_logger.setup(self.outputs, self.cached_folder)
|
||||
if inspect.iscoroutinefunction(self.fn) or inspect.isasyncgenfunction(self.fn):
|
||||
lazy_cache_fn = self.async_lazy_cache
|
||||
else:
|
||||
lazy_cache_fn = self.sync_lazy_cache
|
||||
self.cache_event = self.load_input_event.then(
|
||||
lazy_cache_fn,
|
||||
inputs=[self.dataset] + list(self.inputs),
|
||||
outputs=self.outputs,
|
||||
postprocess=False,
|
||||
api_name=self.api_name,
|
||||
show_api=False,
|
||||
)
|
||||
|
||||
async def async_lazy_cache(
|
||||
self, example_value: tuple[int, list[Any]], *input_values
|
||||
):
|
||||
example_index, _ = example_value
|
||||
cached_index = self._get_cached_index_if_cached(example_index)
|
||||
if cached_index is not None:
|
||||
output = self.load_from_cache(cached_index)
|
||||
yield output[0] if len(self.outputs) == 1 else output
|
||||
return
|
||||
output = [None] * len(self.outputs)
|
||||
if inspect.isasyncgenfunction(self.fn):
|
||||
fn = self.fn
|
||||
else:
|
||||
fn = utils.async_fn_to_generator(self.fn)
|
||||
async for output in fn(*input_values):
|
||||
output = await self._postprocess_output(output)
|
||||
yield output[0] if len(self.outputs) == 1 else output
|
||||
self.cache_logger.flag(output)
|
||||
with open(self.cached_indices_file, "a") as f:
|
||||
f.write(f"{example_index}\n")
|
||||
|
||||
def sync_lazy_cache(self, example_value: tuple[int, list[Any]], *input_values):
|
||||
example_index, _ = example_value
|
||||
cached_index = self._get_cached_index_if_cached(example_index)
|
||||
if cached_index is not None:
|
||||
output = self.load_from_cache(cached_index)
|
||||
yield output[0] if len(self.outputs) == 1 else output
|
||||
return
|
||||
output = [None] * len(self.outputs)
|
||||
if inspect.isgeneratorfunction(self.fn):
|
||||
fn = self.fn
|
||||
else:
|
||||
fn = utils.sync_fn_to_generator(self.fn)
|
||||
for output in fn(*input_values):
|
||||
output = client_utils.synchronize_async(self._postprocess_output, output)
|
||||
yield output[0] if len(self.outputs) == 1 else output
|
||||
self.cache_logger.flag(output)
|
||||
with open(self.cached_indices_file, "a") as f:
|
||||
f.write(f"{example_index}\n")
|
||||
|
||||
async def cache(self, example_id: int | None = None) -> None:
|
||||
"""
|
||||
Caches examples so that their predictions can be shown immediately.
|
||||
@ -599,7 +542,7 @@ class Examples:
|
||||
example_id: The id of the example to process (zero-indexed).
|
||||
"""
|
||||
if self.cache_examples == "lazy":
|
||||
if cached_index := self._get_cached_index_if_cached(example_id) is None:
|
||||
if (cached_index := self._get_cached_index_if_cached(example_id)) is None:
|
||||
client_utils.synchronize_async(self.cache, example_id)
|
||||
with open(self.cached_indices_file, "a") as f:
|
||||
f.write(f"{example_id}\n")
|
||||
@ -610,6 +553,7 @@ class Examples:
|
||||
|
||||
with open(self.cached_file, encoding="utf-8") as cache:
|
||||
examples = list(csv.reader(cache))
|
||||
|
||||
example = examples[example_id + 1] # +1 to adjust for header
|
||||
output = []
|
||||
if self.outputs is None:
|
||||
|
@ -650,13 +650,13 @@ class TestProcessExamples:
|
||||
|
||||
response = client.post(f"{API_PREFIX}/api/load_example/", json={"data": [0]})
|
||||
data = response.json()["data"]
|
||||
assert data[0]["value"]["path"].endswith("cheetah1.jpg")
|
||||
assert data[1]["value"] == "cheetah"
|
||||
assert data[0]["path"].endswith("cheetah1.jpg")
|
||||
assert data[1] == "cheetah"
|
||||
|
||||
response = client.post(f"{API_PREFIX}/api/load_example/", json={"data": [1]})
|
||||
data = response.json()["data"]
|
||||
assert data[0]["value"]["path"].endswith("bus.png")
|
||||
assert data[1]["value"] == "bus"
|
||||
assert data[0]["path"].endswith("bus.png")
|
||||
assert data[1] == "bus"
|
||||
|
||||
|
||||
def test_multiple_file_flagging(tmp_path, connect):
|
||||
|
Loading…
Reference in New Issue
Block a user