diff --git a/client/python/CHANGELOG.md b/client/python/CHANGELOG.md index 29025914ae..079c937d03 100644 --- a/client/python/CHANGELOG.md +++ b/client/python/CHANGELOG.md @@ -1,13 +1,23 @@ # Upcoming Release -# 0.2.4 - ## New Features: +## Bug Fixes: +- Fixes parameter names not showing underscores by [@abidlabs](https://github.com/abidlabs) in [PR 4230](https://github.com/gradio-app/gradio/pull/4230) +- Fixes issue in which state was not handled correctly if `serialize=False` by [@abidlabs](https://github.com/abidlabs) in [PR 4230](https://github.com/gradio-app/gradio/pull/4230) + +## Breaking Changes: + No changes to highlight. +## Full Changelog: + +No changes to highlight. + +# 0.2.4 + ## Bug Fixes: -- Fixes missing serialization classes for several components: `Barplot`, `Lineplot`, `Scatterplot`, `AnnotatedImage`, `Interpretation` by [@abidlabs](https://github.com/freddyaboulton) in [PR 4167](https://github.com/gradio-app/gradio/pull/4167) +- Fixes missing serialization classes for several components: `Barplot`, `Lineplot`, `Scatterplot`, `AnnotatedImage`, `Interpretation` by [@abidlabs](https://github.com/abidlabs) in [PR 4167](https://github.com/gradio-app/gradio/pull/4167) ## Documentation Changes: diff --git a/client/python/gradio_client/client.py b/client/python/gradio_client/client.py index 31f55ee888..db4a1f716f 100644 --- a/client/python/gradio_client/client.py +++ b/client/python/gradio_client/client.py @@ -295,7 +295,7 @@ class Client: helper = Communicator( Lock(), JobStatus(), - self.endpoints[inferred_fn_index].deserialize, + self.endpoints[inferred_fn_index].process_predictions, self.reset_url, ) end_to_end_fn = self.endpoints[inferred_fn_index].make_end_to_end_fn(helper) @@ -439,7 +439,7 @@ class Client: human_info += self._render_endpoints_info(int(fn_index), endpoint_info) else: if num_unnamed_endpoints > 0: - human_info += f"\nUnnamed API endpoints: {num_unnamed_endpoints}, to view, run Client.view_api(`all_endpoints=True`)\n" + human_info += f"\nUnnamed API endpoints: {num_unnamed_endpoints}, to view, run Client.view_api(all_endpoints=True)\n" if print_info: print(human_info) @@ -616,11 +616,11 @@ class Endpoint: def _inner(*data): if not self.is_valid: raise utils.InvalidAPIEndpointError() + data = self.insert_state(*data) if self.client.serialize: data = self.serialize(*data) predictions = _predict(*data) - if self.client.serialize: - predictions = self.deserialize(*predictions) + predictions = self.process_predictions(*predictions) # Append final output only if not already present # for consistency between generators and not generators if helper: @@ -745,39 +745,22 @@ class Endpoint: data[i] = files[file_counter] file_counter += 1 - def serialize(self, *data) -> tuple: + def insert_state(self, *data) -> tuple: data = list(data) for i, input_component_type in enumerate(self.input_component_types): if input_component_type == utils.STATE_COMPONENT: data.insert(i, None) - assert len(data) == len( - self.serializers - ), f"Expected {len(self.serializers)} arguments, got {len(data)}" + return tuple(data) - files = [ - f - for f, t in zip(data, self.input_component_types) - if t in ["file", "uploadbutton"] + def remove_state(self, *data) -> tuple: + data = [ + d + for d, oct in zip(data, self.output_component_types) + if oct != utils.STATE_COMPONENT ] - uploaded_files = self._upload(files) - self._add_uploaded_files_to_data(uploaded_files, data) + return tuple(data) - o = tuple([s.serialize(d) for s, d in zip(self.serializers, data)]) - return o - - def deserialize(self, *data) -> tuple | Any: - assert len(data) == len( - self.deserializers - ), f"Expected {len(self.deserializers)} outputs, got {len(data)}" - outputs = tuple( - [ - s.deserialize(d, hf_token=self.client.hf_token, root_url=self.root_url) - for s, d, oct in zip( - self.deserializers, data, self.output_component_types - ) - if oct != utils.STATE_COMPONENT - ] - ) + def reduce_singleton_output(self, *data) -> Any: if ( len( [ @@ -788,10 +771,44 @@ class Endpoint: ) == 1 ): - output = outputs[0] + return data[0] else: - output = outputs - return output + return data + + def serialize(self, *data) -> tuple: + assert len(data) == len( + self.serializers + ), f"Expected {len(self.serializers)} arguments, got {len(data)}" + + files = [ + f + for f, t in zip(data, self.input_component_types) + if t in ["file", "uploadbutton"] + ] + uploaded_files = self._upload(files) + self._add_uploaded_files_to_data(uploaded_files, list(data)) + + o = tuple([s.serialize(d) for s, d in zip(self.serializers, data)]) + return o + + def deserialize(self, *data) -> tuple: + assert len(data) == len( + self.deserializers + ), f"Expected {len(self.deserializers)} outputs, got {len(data)}" + outputs = tuple( + [ + s.deserialize(d, hf_token=self.client.hf_token, root_url=self.root_url) + for s, d in zip(self.deserializers, data) + ] + ) + return outputs + + def process_predictions(self, *predictions): + if self.client.serialize: + predictions = self.deserialize(*predictions) + predictions = self.remove_state(*predictions) + predictions = self.reduce_singleton_output(*predictions) + return predictions def _setup_serializers(self) -> tuple[list[Serializable], list[Serializable]]: inputs = self.dependency["inputs"] diff --git a/client/python/gradio_client/utils.py b/client/python/gradio_client/utils.py index efaff6bd81..069ffdc53b 100644 --- a/client/python/gradio_client/utils.py +++ b/client/python/gradio_client/utils.py @@ -182,7 +182,7 @@ class Communicator: lock: Lock job: JobStatus - deserialize: Callable[..., tuple] + prediction_processor: Callable[..., tuple] reset_url: str should_cancel: bool = False @@ -251,7 +251,7 @@ async def get_pred_from_ws( output = resp.get("output", {}).get("data", []) if output and status_update.code != Status.FINISHED: try: - result = helper.deserialize(*output) + result = helper.prediction_processor(*output) except Exception as e: result = [e] helper.job.outputs.append(result) @@ -380,10 +380,10 @@ def strip_invalid_filename_characters(filename: str, max_bytes: int = 200) -> st return filename -def sanitize_parameter_names(original_param_name: str) -> str: - """Strips invalid characters from a parameter name and replaces spaces with underscores.""" +def sanitize_parameter_names(original_name: str) -> str: + """Cleans up a Python parameter name to make the API info more readable.""" return ( - "".join([char for char in original_param_name if char.isalnum() or char in " "]) + "".join([char for char in original_name if char.isalnum() or char in " _"]) .replace(" ", "_") .lower() ) diff --git a/client/python/gradio_client/version.txt b/client/python/gradio_client/version.txt index abd410582d..3a4036fb45 100644 --- a/client/python/gradio_client/version.txt +++ b/client/python/gradio_client/version.txt @@ -1 +1 @@ -0.2.4 +0.2.5 diff --git a/client/python/test/conftest.py b/client/python/test/conftest.py index a3b06a72ba..e9a5f30961 100644 --- a/client/python/test/conftest.py +++ b/client/python/test/conftest.py @@ -181,6 +181,26 @@ def file_io_demo(): return demo +@pytest.fixture +def stateful_chatbot(): + with gr.Blocks() as demo: + chatbot = gr.Chatbot() + msg = gr.Textbox() + clear = gr.Button("Clear") + st = gr.State([1, 2, 3]) + + def respond(message, st, chat_history): + assert st[0] == 1 and st[1] == 2 and st[2] == 3 + bot_message = "I love you" + chat_history.append((message, bot_message)) + return "", chat_history + + msg.submit(respond, [msg, st, chatbot], [msg, chatbot], api_name="submit") + clear.click(lambda: None, None, chatbot, queue=False) + demo.queue() + return demo + + @pytest.fixture def all_components(): classes_to_check = gr.components.Component.__subclasses__() diff --git a/client/python/test/test_client.py b/client/python/test/test_client.py index de0cf7323e..77a582319b 100644 --- a/client/python/test/test_client.py +++ b/client/python/test/test_client.py @@ -23,10 +23,10 @@ HF_TOKEN = "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" # Intentionally reveali @contextmanager -def connect(demo: gr.Blocks): +def connect(demo: gr.Blocks, serialize: bool = True): _, local_url, _ = demo.launch(prevent_thread_lock=True) try: - yield Client(local_url) + yield Client(local_url, serialize=serialize) finally: # A more verbose version of .close() # because we should set a timeout @@ -39,7 +39,7 @@ def connect(demo: gr.Blocks): demo.server.thread.join(timeout=1) -class TestPredictionsFromSpaces: +class TestClientPredictions: @pytest.mark.flaky def test_raise_error_invalid_state(self): with pytest.raises(ValueError, match="invalid state"): @@ -172,7 +172,6 @@ class TestPredictionsFromSpaces: assert pathlib.Path(job.result()).exists() def test_progress_updates(self, progress_demo): - with connect(progress_demo) as client: job = client.submit("hello", api_name="/predict") statuses = [] @@ -246,7 +245,6 @@ class TestPredictionsFromSpaces: @pytest.mark.flaky def test_upload_file_private_space(self): - client = Client( src="gradio-tests/not-actually-private-file-upload", hf_token=HF_TOKEN ) @@ -308,11 +306,17 @@ class TestPredictionsFromSpaces: client.submit(1, "foo", f.name, fn_index=0).result() serialize.assert_called_once_with(1, "foo", f.name) + def test_state_without_serialize(self, stateful_chatbot): + with connect(stateful_chatbot, serialize=False) as client: + initial_history = [["", None]] + message = "Hello" + ret = client.predict(message, initial_history, api_name="/submit") + assert ret == ("", [["", None], ["Hello", "I love you"]]) + class TestStatusUpdates: @patch("gradio_client.client.Endpoint.make_end_to_end_fn") def test_messages_passed_correctly(self, mock_make_end_to_end_fn): - now = datetime.now() messages = [ @@ -396,7 +400,6 @@ class TestStatusUpdates: @patch("gradio_client.client.Endpoint.make_end_to_end_fn") def test_messages_correct_two_concurrent(self, mock_make_end_to_end_fn): - now = datetime.now() messages_1 = [