Fix example loading issue (#10038)

* fix

* Add code

* add changeset

* Fix bug

* lint

* add changeset

* fix both cache examples=False,True

* format

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
This commit is contained in:
Freddy Boulton 2024-11-27 15:11:06 -05:00 committed by GitHub
parent 458941c508
commit 7d134e0b30
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 63 additions and 12 deletions

View File

@ -0,0 +1,5 @@
---
"gradio": patch
---
fix:Fix example loading issue

View File

@ -295,11 +295,25 @@ class Examples:
)
def _get_processed_example(self, example):
"""
This function is used to get the post-processed example values, ready to be used
in the frontend for each input component. For example, if the input components are
image components, the post-processed example values will be the a list of ImageData dictionaries
with the path, url, size, mime_type, orig_name, and is_stream keys. For any input components
that should be skipped (b/c they are None for all samples), they will simply be absent
from the returned list
Parameters:
example: a list of example values for each input component, excluding those components
that have all None values
"""
if example in self.non_none_processed_examples:
return self.non_none_processed_examples[example]
with utils.set_directory(self.working_directory):
sub = []
for component, sample in zip(self.inputs, example, strict=False):
for component, sample in zip(
self.inputs_with_examples, example, strict=False
):
prediction_value = component.postprocess(sample)
if isinstance(prediction_value, (GradioRootModel, GradioModel)):
prediction_value = prediction_value.model_dump()
@ -309,9 +323,7 @@ class Examples:
postprocess=True,
)
sub.append(prediction_value)
return [
ex for (ex, keep) in zip(sub, self.input_has_examples, strict=False) if keep
]
return sub
def create(self) -> None:
"""Creates the Dataset component to hold the examples"""
@ -332,7 +344,7 @@ class Examples:
self.cache_event = self.load_input_event = self.dataset.click(
load_example_with_output,
inputs=[self.dataset],
outputs=self.inputs_with_examples + self.outputs, # type: ignore
outputs=self.inputs + self.outputs, # type: ignore
show_progress="hidden",
postprocess=False,
queue=False,
@ -360,7 +372,7 @@ class Examples:
self.load_input_event = self.dataset.click(
load_example,
inputs=[self.dataset],
outputs=self.inputs_with_examples, # type: ignore
outputs=self.inputs_with_examples,
show_progress="hidden",
postprocess=False,
queue=False,
@ -378,8 +390,8 @@ class Examples:
)
self.load_input_event.then(
self.fn,
inputs=self.inputs, # type: ignore
outputs=self.outputs, # type: ignore
inputs=self.inputs,
outputs=self.outputs,
show_api=False,
)
else:
@ -546,8 +558,8 @@ class Examples:
_, fn_index = self.root_block.default_config.set_event_trigger(
[EventListenerMethod(Context.root_block, "load")],
fn=fn,
inputs=self.inputs_with_examples, # type: ignore
outputs=self.outputs, # type: ignore
inputs=self.inputs,
outputs=self.outputs,
preprocess=self.preprocess and not self._api_mode,
postprocess=self.postprocess and not self._api_mode,
batch=self.batch,
@ -555,11 +567,13 @@ class Examples:
if self.outputs is None:
raise ValueError("self.outputs is missing")
for i, example in enumerate(self.examples):
for i, example in enumerate(self.non_none_examples):
if example_id is not None and i != example_id:
continue
print(f"Caching example {i + 1}/{len(self.examples)}")
processed_input = self._get_processed_example(example)
for index, keep in enumerate(self.input_has_examples):
if not keep:
processed_input.insert(index, None)
if self.batch:
processed_input = [[value] for value in processed_input]
with utils.MatplotlibBackendMananger():

View File

@ -932,3 +932,35 @@ def test_check_event_data_in_cache():
},
),
)
def test_examples_no_cache_optional_inputs():
def foo(a, b, c, d):
return {"a": a, "b": b, "c": c, "d": d}
io = gr.Interface(
foo,
["text", "text", "text", "text"],
"json",
cache_examples=False,
examples=[["a", "b", None, "d"], ["a", "b", None, "de"]],
)
try:
app, _, _ = io.launch(prevent_thread_lock=True)
client = TestClient(app)
with client as c:
for i in range(2):
response = c.post(
f"{API_PREFIX}/run/predict/",
json={
"data": [i],
"fn_index": 6,
"trigger_id": 19,
"session_hash": "test",
},
)
assert response.status_code == 200
finally:
io.close()