diff --git a/.changeset/real-words-listen.md b/.changeset/real-words-listen.md new file mode 100644 index 0000000000..46e5d72396 --- /dev/null +++ b/.changeset/real-words-listen.md @@ -0,0 +1,5 @@ +--- +"gradio": patch +--- + +fix:Fix example loading issue diff --git a/gradio/helpers.py b/gradio/helpers.py index ce23efe3b6..df12310a19 100644 --- a/gradio/helpers.py +++ b/gradio/helpers.py @@ -295,11 +295,25 @@ class Examples: ) def _get_processed_example(self, example): + """ + This function is used to get the post-processed example values, ready to be used + in the frontend for each input component. For example, if the input components are + image components, the post-processed example values will be the a list of ImageData dictionaries + with the path, url, size, mime_type, orig_name, and is_stream keys. For any input components + that should be skipped (b/c they are None for all samples), they will simply be absent + from the returned list + + Parameters: + example: a list of example values for each input component, excluding those components + that have all None values + """ if example in self.non_none_processed_examples: return self.non_none_processed_examples[example] with utils.set_directory(self.working_directory): sub = [] - for component, sample in zip(self.inputs, example, strict=False): + for component, sample in zip( + self.inputs_with_examples, example, strict=False + ): prediction_value = component.postprocess(sample) if isinstance(prediction_value, (GradioRootModel, GradioModel)): prediction_value = prediction_value.model_dump() @@ -309,9 +323,7 @@ class Examples: postprocess=True, ) sub.append(prediction_value) - return [ - ex for (ex, keep) in zip(sub, self.input_has_examples, strict=False) if keep - ] + return sub def create(self) -> None: """Creates the Dataset component to hold the examples""" @@ -332,7 +344,7 @@ class Examples: self.cache_event = self.load_input_event = self.dataset.click( load_example_with_output, inputs=[self.dataset], - outputs=self.inputs_with_examples + self.outputs, # type: ignore + outputs=self.inputs + self.outputs, # type: ignore show_progress="hidden", postprocess=False, queue=False, @@ -360,7 +372,7 @@ class Examples: self.load_input_event = self.dataset.click( load_example, inputs=[self.dataset], - outputs=self.inputs_with_examples, # type: ignore + outputs=self.inputs_with_examples, show_progress="hidden", postprocess=False, queue=False, @@ -378,8 +390,8 @@ class Examples: ) self.load_input_event.then( self.fn, - inputs=self.inputs, # type: ignore - outputs=self.outputs, # type: ignore + inputs=self.inputs, + outputs=self.outputs, show_api=False, ) else: @@ -546,8 +558,8 @@ class Examples: _, fn_index = self.root_block.default_config.set_event_trigger( [EventListenerMethod(Context.root_block, "load")], fn=fn, - inputs=self.inputs_with_examples, # type: ignore - outputs=self.outputs, # type: ignore + inputs=self.inputs, + outputs=self.outputs, preprocess=self.preprocess and not self._api_mode, postprocess=self.postprocess and not self._api_mode, batch=self.batch, @@ -555,11 +567,13 @@ class Examples: if self.outputs is None: raise ValueError("self.outputs is missing") - for i, example in enumerate(self.examples): + for i, example in enumerate(self.non_none_examples): if example_id is not None and i != example_id: continue - print(f"Caching example {i + 1}/{len(self.examples)}") processed_input = self._get_processed_example(example) + for index, keep in enumerate(self.input_has_examples): + if not keep: + processed_input.insert(index, None) if self.batch: processed_input = [[value] for value in processed_input] with utils.MatplotlibBackendMananger(): diff --git a/test/test_helpers.py b/test/test_helpers.py index ec27f37f2a..2a5c6bf62c 100644 --- a/test/test_helpers.py +++ b/test/test_helpers.py @@ -932,3 +932,35 @@ def test_check_event_data_in_cache(): }, ), ) + + +def test_examples_no_cache_optional_inputs(): + def foo(a, b, c, d): + return {"a": a, "b": b, "c": c, "d": d} + + io = gr.Interface( + foo, + ["text", "text", "text", "text"], + "json", + cache_examples=False, + examples=[["a", "b", None, "d"], ["a", "b", None, "de"]], + ) + + try: + app, _, _ = io.launch(prevent_thread_lock=True) + + client = TestClient(app) + with client as c: + for i in range(2): + response = c.post( + f"{API_PREFIX}/run/predict/", + json={ + "data": [i], + "fn_index": 6, + "trigger_id": 19, + "session_hash": "test", + }, + ) + assert response.status_code == 200 + finally: + io.close()