mirror of
https://github.com/gradio-app/gradio.git
synced 2025-02-05 11:10:03 +08:00
Some tweaks to the Client (#4230)
* parameter names * tweaks * separate out serialize * fix * changelog * fix * fix * improve test
This commit is contained in:
parent
6bace9765c
commit
d6c93228d9
@ -1,13 +1,23 @@
|
||||
# Upcoming Release
|
||||
|
||||
# 0.2.4
|
||||
|
||||
## New Features:
|
||||
|
||||
## Bug Fixes:
|
||||
- Fixes parameter names not showing underscores by [@abidlabs](https://github.com/abidlabs) in [PR 4230](https://github.com/gradio-app/gradio/pull/4230)
|
||||
- Fixes issue in which state was not handled correctly if `serialize=False` by [@abidlabs](https://github.com/abidlabs) in [PR 4230](https://github.com/gradio-app/gradio/pull/4230)
|
||||
|
||||
## Breaking Changes:
|
||||
|
||||
No changes to highlight.
|
||||
|
||||
## Full Changelog:
|
||||
|
||||
No changes to highlight.
|
||||
|
||||
# 0.2.4
|
||||
|
||||
## Bug Fixes:
|
||||
- Fixes missing serialization classes for several components: `Barplot`, `Lineplot`, `Scatterplot`, `AnnotatedImage`, `Interpretation` by [@abidlabs](https://github.com/freddyaboulton) in [PR 4167](https://github.com/gradio-app/gradio/pull/4167)
|
||||
- Fixes missing serialization classes for several components: `Barplot`, `Lineplot`, `Scatterplot`, `AnnotatedImage`, `Interpretation` by [@abidlabs](https://github.com/abidlabs) in [PR 4167](https://github.com/gradio-app/gradio/pull/4167)
|
||||
|
||||
## Documentation Changes:
|
||||
|
||||
|
@ -295,7 +295,7 @@ class Client:
|
||||
helper = Communicator(
|
||||
Lock(),
|
||||
JobStatus(),
|
||||
self.endpoints[inferred_fn_index].deserialize,
|
||||
self.endpoints[inferred_fn_index].process_predictions,
|
||||
self.reset_url,
|
||||
)
|
||||
end_to_end_fn = self.endpoints[inferred_fn_index].make_end_to_end_fn(helper)
|
||||
@ -439,7 +439,7 @@ class Client:
|
||||
human_info += self._render_endpoints_info(int(fn_index), endpoint_info)
|
||||
else:
|
||||
if num_unnamed_endpoints > 0:
|
||||
human_info += f"\nUnnamed API endpoints: {num_unnamed_endpoints}, to view, run Client.view_api(`all_endpoints=True`)\n"
|
||||
human_info += f"\nUnnamed API endpoints: {num_unnamed_endpoints}, to view, run Client.view_api(all_endpoints=True)\n"
|
||||
|
||||
if print_info:
|
||||
print(human_info)
|
||||
@ -616,11 +616,11 @@ class Endpoint:
|
||||
def _inner(*data):
|
||||
if not self.is_valid:
|
||||
raise utils.InvalidAPIEndpointError()
|
||||
data = self.insert_state(*data)
|
||||
if self.client.serialize:
|
||||
data = self.serialize(*data)
|
||||
predictions = _predict(*data)
|
||||
if self.client.serialize:
|
||||
predictions = self.deserialize(*predictions)
|
||||
predictions = self.process_predictions(*predictions)
|
||||
# Append final output only if not already present
|
||||
# for consistency between generators and not generators
|
||||
if helper:
|
||||
@ -745,39 +745,22 @@ class Endpoint:
|
||||
data[i] = files[file_counter]
|
||||
file_counter += 1
|
||||
|
||||
def serialize(self, *data) -> tuple:
|
||||
def insert_state(self, *data) -> tuple:
|
||||
data = list(data)
|
||||
for i, input_component_type in enumerate(self.input_component_types):
|
||||
if input_component_type == utils.STATE_COMPONENT:
|
||||
data.insert(i, None)
|
||||
assert len(data) == len(
|
||||
self.serializers
|
||||
), f"Expected {len(self.serializers)} arguments, got {len(data)}"
|
||||
return tuple(data)
|
||||
|
||||
files = [
|
||||
f
|
||||
for f, t in zip(data, self.input_component_types)
|
||||
if t in ["file", "uploadbutton"]
|
||||
def remove_state(self, *data) -> tuple:
|
||||
data = [
|
||||
d
|
||||
for d, oct in zip(data, self.output_component_types)
|
||||
if oct != utils.STATE_COMPONENT
|
||||
]
|
||||
uploaded_files = self._upload(files)
|
||||
self._add_uploaded_files_to_data(uploaded_files, data)
|
||||
return tuple(data)
|
||||
|
||||
o = tuple([s.serialize(d) for s, d in zip(self.serializers, data)])
|
||||
return o
|
||||
|
||||
def deserialize(self, *data) -> tuple | Any:
|
||||
assert len(data) == len(
|
||||
self.deserializers
|
||||
), f"Expected {len(self.deserializers)} outputs, got {len(data)}"
|
||||
outputs = tuple(
|
||||
[
|
||||
s.deserialize(d, hf_token=self.client.hf_token, root_url=self.root_url)
|
||||
for s, d, oct in zip(
|
||||
self.deserializers, data, self.output_component_types
|
||||
)
|
||||
if oct != utils.STATE_COMPONENT
|
||||
]
|
||||
)
|
||||
def reduce_singleton_output(self, *data) -> Any:
|
||||
if (
|
||||
len(
|
||||
[
|
||||
@ -788,10 +771,44 @@ class Endpoint:
|
||||
)
|
||||
== 1
|
||||
):
|
||||
output = outputs[0]
|
||||
return data[0]
|
||||
else:
|
||||
output = outputs
|
||||
return output
|
||||
return data
|
||||
|
||||
def serialize(self, *data) -> tuple:
|
||||
assert len(data) == len(
|
||||
self.serializers
|
||||
), f"Expected {len(self.serializers)} arguments, got {len(data)}"
|
||||
|
||||
files = [
|
||||
f
|
||||
for f, t in zip(data, self.input_component_types)
|
||||
if t in ["file", "uploadbutton"]
|
||||
]
|
||||
uploaded_files = self._upload(files)
|
||||
self._add_uploaded_files_to_data(uploaded_files, list(data))
|
||||
|
||||
o = tuple([s.serialize(d) for s, d in zip(self.serializers, data)])
|
||||
return o
|
||||
|
||||
def deserialize(self, *data) -> tuple:
|
||||
assert len(data) == len(
|
||||
self.deserializers
|
||||
), f"Expected {len(self.deserializers)} outputs, got {len(data)}"
|
||||
outputs = tuple(
|
||||
[
|
||||
s.deserialize(d, hf_token=self.client.hf_token, root_url=self.root_url)
|
||||
for s, d in zip(self.deserializers, data)
|
||||
]
|
||||
)
|
||||
return outputs
|
||||
|
||||
def process_predictions(self, *predictions):
|
||||
if self.client.serialize:
|
||||
predictions = self.deserialize(*predictions)
|
||||
predictions = self.remove_state(*predictions)
|
||||
predictions = self.reduce_singleton_output(*predictions)
|
||||
return predictions
|
||||
|
||||
def _setup_serializers(self) -> tuple[list[Serializable], list[Serializable]]:
|
||||
inputs = self.dependency["inputs"]
|
||||
|
@ -182,7 +182,7 @@ class Communicator:
|
||||
|
||||
lock: Lock
|
||||
job: JobStatus
|
||||
deserialize: Callable[..., tuple]
|
||||
prediction_processor: Callable[..., tuple]
|
||||
reset_url: str
|
||||
should_cancel: bool = False
|
||||
|
||||
@ -251,7 +251,7 @@ async def get_pred_from_ws(
|
||||
output = resp.get("output", {}).get("data", [])
|
||||
if output and status_update.code != Status.FINISHED:
|
||||
try:
|
||||
result = helper.deserialize(*output)
|
||||
result = helper.prediction_processor(*output)
|
||||
except Exception as e:
|
||||
result = [e]
|
||||
helper.job.outputs.append(result)
|
||||
@ -380,10 +380,10 @@ def strip_invalid_filename_characters(filename: str, max_bytes: int = 200) -> st
|
||||
return filename
|
||||
|
||||
|
||||
def sanitize_parameter_names(original_param_name: str) -> str:
|
||||
"""Strips invalid characters from a parameter name and replaces spaces with underscores."""
|
||||
def sanitize_parameter_names(original_name: str) -> str:
|
||||
"""Cleans up a Python parameter name to make the API info more readable."""
|
||||
return (
|
||||
"".join([char for char in original_param_name if char.isalnum() or char in " "])
|
||||
"".join([char for char in original_name if char.isalnum() or char in " _"])
|
||||
.replace(" ", "_")
|
||||
.lower()
|
||||
)
|
||||
|
@ -1 +1 @@
|
||||
0.2.4
|
||||
0.2.5
|
||||
|
@ -181,6 +181,26 @@ def file_io_demo():
|
||||
return demo
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def stateful_chatbot():
|
||||
with gr.Blocks() as demo:
|
||||
chatbot = gr.Chatbot()
|
||||
msg = gr.Textbox()
|
||||
clear = gr.Button("Clear")
|
||||
st = gr.State([1, 2, 3])
|
||||
|
||||
def respond(message, st, chat_history):
|
||||
assert st[0] == 1 and st[1] == 2 and st[2] == 3
|
||||
bot_message = "I love you"
|
||||
chat_history.append((message, bot_message))
|
||||
return "", chat_history
|
||||
|
||||
msg.submit(respond, [msg, st, chatbot], [msg, chatbot], api_name="submit")
|
||||
clear.click(lambda: None, None, chatbot, queue=False)
|
||||
demo.queue()
|
||||
return demo
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def all_components():
|
||||
classes_to_check = gr.components.Component.__subclasses__()
|
||||
|
@ -23,10 +23,10 @@ HF_TOKEN = "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" # Intentionally reveali
|
||||
|
||||
|
||||
@contextmanager
|
||||
def connect(demo: gr.Blocks):
|
||||
def connect(demo: gr.Blocks, serialize: bool = True):
|
||||
_, local_url, _ = demo.launch(prevent_thread_lock=True)
|
||||
try:
|
||||
yield Client(local_url)
|
||||
yield Client(local_url, serialize=serialize)
|
||||
finally:
|
||||
# A more verbose version of .close()
|
||||
# because we should set a timeout
|
||||
@ -39,7 +39,7 @@ def connect(demo: gr.Blocks):
|
||||
demo.server.thread.join(timeout=1)
|
||||
|
||||
|
||||
class TestPredictionsFromSpaces:
|
||||
class TestClientPredictions:
|
||||
@pytest.mark.flaky
|
||||
def test_raise_error_invalid_state(self):
|
||||
with pytest.raises(ValueError, match="invalid state"):
|
||||
@ -172,7 +172,6 @@ class TestPredictionsFromSpaces:
|
||||
assert pathlib.Path(job.result()).exists()
|
||||
|
||||
def test_progress_updates(self, progress_demo):
|
||||
|
||||
with connect(progress_demo) as client:
|
||||
job = client.submit("hello", api_name="/predict")
|
||||
statuses = []
|
||||
@ -246,7 +245,6 @@ class TestPredictionsFromSpaces:
|
||||
|
||||
@pytest.mark.flaky
|
||||
def test_upload_file_private_space(self):
|
||||
|
||||
client = Client(
|
||||
src="gradio-tests/not-actually-private-file-upload", hf_token=HF_TOKEN
|
||||
)
|
||||
@ -308,11 +306,17 @@ class TestPredictionsFromSpaces:
|
||||
client.submit(1, "foo", f.name, fn_index=0).result()
|
||||
serialize.assert_called_once_with(1, "foo", f.name)
|
||||
|
||||
def test_state_without_serialize(self, stateful_chatbot):
|
||||
with connect(stateful_chatbot, serialize=False) as client:
|
||||
initial_history = [["", None]]
|
||||
message = "Hello"
|
||||
ret = client.predict(message, initial_history, api_name="/submit")
|
||||
assert ret == ("", [["", None], ["Hello", "I love you"]])
|
||||
|
||||
|
||||
class TestStatusUpdates:
|
||||
@patch("gradio_client.client.Endpoint.make_end_to_end_fn")
|
||||
def test_messages_passed_correctly(self, mock_make_end_to_end_fn):
|
||||
|
||||
now = datetime.now()
|
||||
|
||||
messages = [
|
||||
@ -396,7 +400,6 @@ class TestStatusUpdates:
|
||||
|
||||
@patch("gradio_client.client.Endpoint.make_end_to_end_fn")
|
||||
def test_messages_correct_two_concurrent(self, mock_make_end_to_end_fn):
|
||||
|
||||
now = datetime.now()
|
||||
|
||||
messages_1 = [
|
||||
|
Loading…
Reference in New Issue
Block a user