mirror of
https://github.com/gradio-app/gradio.git
synced 2024-12-15 02:11:15 +08:00
Fix example loading issue (#10038)
* fix * Add code * add changeset * Fix bug * lint * add changeset * fix both cache examples=False,True * format --------- Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com> Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
This commit is contained in:
parent
458941c508
commit
7d134e0b30
5
.changeset/real-words-listen.md
Normal file
5
.changeset/real-words-listen.md
Normal file
@ -0,0 +1,5 @@
|
||||
---
|
||||
"gradio": patch
|
||||
---
|
||||
|
||||
fix:Fix example loading issue
|
@ -295,11 +295,25 @@ class Examples:
|
||||
)
|
||||
|
||||
def _get_processed_example(self, example):
|
||||
"""
|
||||
This function is used to get the post-processed example values, ready to be used
|
||||
in the frontend for each input component. For example, if the input components are
|
||||
image components, the post-processed example values will be the a list of ImageData dictionaries
|
||||
with the path, url, size, mime_type, orig_name, and is_stream keys. For any input components
|
||||
that should be skipped (b/c they are None for all samples), they will simply be absent
|
||||
from the returned list
|
||||
|
||||
Parameters:
|
||||
example: a list of example values for each input component, excluding those components
|
||||
that have all None values
|
||||
"""
|
||||
if example in self.non_none_processed_examples:
|
||||
return self.non_none_processed_examples[example]
|
||||
with utils.set_directory(self.working_directory):
|
||||
sub = []
|
||||
for component, sample in zip(self.inputs, example, strict=False):
|
||||
for component, sample in zip(
|
||||
self.inputs_with_examples, example, strict=False
|
||||
):
|
||||
prediction_value = component.postprocess(sample)
|
||||
if isinstance(prediction_value, (GradioRootModel, GradioModel)):
|
||||
prediction_value = prediction_value.model_dump()
|
||||
@ -309,9 +323,7 @@ class Examples:
|
||||
postprocess=True,
|
||||
)
|
||||
sub.append(prediction_value)
|
||||
return [
|
||||
ex for (ex, keep) in zip(sub, self.input_has_examples, strict=False) if keep
|
||||
]
|
||||
return sub
|
||||
|
||||
def create(self) -> None:
|
||||
"""Creates the Dataset component to hold the examples"""
|
||||
@ -332,7 +344,7 @@ class Examples:
|
||||
self.cache_event = self.load_input_event = self.dataset.click(
|
||||
load_example_with_output,
|
||||
inputs=[self.dataset],
|
||||
outputs=self.inputs_with_examples + self.outputs, # type: ignore
|
||||
outputs=self.inputs + self.outputs, # type: ignore
|
||||
show_progress="hidden",
|
||||
postprocess=False,
|
||||
queue=False,
|
||||
@ -360,7 +372,7 @@ class Examples:
|
||||
self.load_input_event = self.dataset.click(
|
||||
load_example,
|
||||
inputs=[self.dataset],
|
||||
outputs=self.inputs_with_examples, # type: ignore
|
||||
outputs=self.inputs_with_examples,
|
||||
show_progress="hidden",
|
||||
postprocess=False,
|
||||
queue=False,
|
||||
@ -378,8 +390,8 @@ class Examples:
|
||||
)
|
||||
self.load_input_event.then(
|
||||
self.fn,
|
||||
inputs=self.inputs, # type: ignore
|
||||
outputs=self.outputs, # type: ignore
|
||||
inputs=self.inputs,
|
||||
outputs=self.outputs,
|
||||
show_api=False,
|
||||
)
|
||||
else:
|
||||
@ -546,8 +558,8 @@ class Examples:
|
||||
_, fn_index = self.root_block.default_config.set_event_trigger(
|
||||
[EventListenerMethod(Context.root_block, "load")],
|
||||
fn=fn,
|
||||
inputs=self.inputs_with_examples, # type: ignore
|
||||
outputs=self.outputs, # type: ignore
|
||||
inputs=self.inputs,
|
||||
outputs=self.outputs,
|
||||
preprocess=self.preprocess and not self._api_mode,
|
||||
postprocess=self.postprocess and not self._api_mode,
|
||||
batch=self.batch,
|
||||
@ -555,11 +567,13 @@ class Examples:
|
||||
|
||||
if self.outputs is None:
|
||||
raise ValueError("self.outputs is missing")
|
||||
for i, example in enumerate(self.examples):
|
||||
for i, example in enumerate(self.non_none_examples):
|
||||
if example_id is not None and i != example_id:
|
||||
continue
|
||||
print(f"Caching example {i + 1}/{len(self.examples)}")
|
||||
processed_input = self._get_processed_example(example)
|
||||
for index, keep in enumerate(self.input_has_examples):
|
||||
if not keep:
|
||||
processed_input.insert(index, None)
|
||||
if self.batch:
|
||||
processed_input = [[value] for value in processed_input]
|
||||
with utils.MatplotlibBackendMananger():
|
||||
|
@ -932,3 +932,35 @@ def test_check_event_data_in_cache():
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def test_examples_no_cache_optional_inputs():
|
||||
def foo(a, b, c, d):
|
||||
return {"a": a, "b": b, "c": c, "d": d}
|
||||
|
||||
io = gr.Interface(
|
||||
foo,
|
||||
["text", "text", "text", "text"],
|
||||
"json",
|
||||
cache_examples=False,
|
||||
examples=[["a", "b", None, "d"], ["a", "b", None, "de"]],
|
||||
)
|
||||
|
||||
try:
|
||||
app, _, _ = io.launch(prevent_thread_lock=True)
|
||||
|
||||
client = TestClient(app)
|
||||
with client as c:
|
||||
for i in range(2):
|
||||
response = c.post(
|
||||
f"{API_PREFIX}/run/predict/",
|
||||
json={
|
||||
"data": [i],
|
||||
"fn_index": 6,
|
||||
"trigger_id": 19,
|
||||
"session_hash": "test",
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
finally:
|
||||
io.close()
|
||||
|
Loading…
Reference in New Issue
Block a user