Some tweaks to the Client (#4230)

* parameter names

* tweaks

* separate out serialize

* fix

* changelog

* fix

* fix

* improve test
Abubakar Abid 2023-05-16 14:32:42 -04:00 committed by GitHub
parent 6bace9765c
commit d6c93228d9
6 changed files with 99 additions and 49 deletions

@@ -1,13 +1,23 @@
# Upcoming Release
# 0.2.4
## New Features:
## Bug Fixes:
- Fixes parameter names not showing underscores by [@abidlabs](https://github.com/abidlabs) in [PR 4230](https://github.com/gradio-app/gradio/pull/4230)
- Fixes issue in which state was not handled correctly if `serialize=False` by [@abidlabs](https://github.com/abidlabs) in [PR 4230](https://github.com/gradio-app/gradio/pull/4230)
## Breaking Changes:
No changes to highlight.
## Full Changelog:
No changes to highlight.
# 0.2.4
## Bug Fixes:
- Fixes missing serialization classes for several components: `Barplot`, `Lineplot`, `Scatterplot`, `AnnotatedImage`, `Interpretation` by [@abidlabs](https://github.com/freddyaboulton) in [PR 4167](https://github.com/gradio-app/gradio/pull/4167)
- Fixes missing serialization classes for several components: `Barplot`, `Lineplot`, `Scatterplot`, `AnnotatedImage`, `Interpretation` by [@abidlabs](https://github.com/abidlabs) in [PR 4167](https://github.com/gradio-app/gradio/pull/4167)
## Documentation Changes:
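The second bug fix listed above can be exercised directly against the `stateful_chatbot` fixture added later in this commit. A minimal sketch, assuming that demo is launched locally at `http://127.0.0.1:7860` (the URL and running app are assumptions, not part of this diff): with `serialize=False`, the caller still skips the `gr.State` input, and the client fills that slot in before sending the request.

```python
# Minimal sketch of the `serialize=False` state fix; the URL and endpoint name
# are assumptions that mirror the stateful_chatbot fixture/test added below.
from gradio_client import Client

client = Client("http://127.0.0.1:7860", serialize=False)

# The demo's function takes (message, state, chat_history), but the gr.State
# slot is inserted client-side (insert_state adds None), so only two values
# are passed here.
result = client.predict("Hello", [["", None]], api_name="/submit")
print(result)  # ("", [["", None], ["Hello", "I love you"]])
```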

@@ -295,7 +295,7 @@ class Client:
helper = Communicator(
Lock(),
JobStatus(),
self.endpoints[inferred_fn_index].deserialize,
self.endpoints[inferred_fn_index].process_predictions,
self.reset_url,
)
end_to_end_fn = self.endpoints[inferred_fn_index].make_end_to_end_fn(helper)
@@ -439,7 +439,7 @@ class Client:
human_info += self._render_endpoints_info(int(fn_index), endpoint_info)
else:
if num_unnamed_endpoints > 0:
human_info += f"\nUnnamed API endpoints: {num_unnamed_endpoints}, to view, run Client.view_api(`all_endpoints=True`)\n"
human_info += f"\nUnnamed API endpoints: {num_unnamed_endpoints}, to view, run Client.view_api(all_endpoints=True)\n"
if print_info:
print(human_info)
@@ -616,11 +616,11 @@ class Endpoint:
def _inner(*data):
if not self.is_valid:
raise utils.InvalidAPIEndpointError()
data = self.insert_state(*data)
if self.client.serialize:
data = self.serialize(*data)
predictions = _predict(*data)
if self.client.serialize:
predictions = self.deserialize(*predictions)
predictions = self.process_predictions(*predictions)
# Append final output only if not already present
# for consistency between generators and not generators
if helper:
@@ -745,39 +745,22 @@ class Endpoint:
data[i] = files[file_counter]
file_counter += 1
def serialize(self, *data) -> tuple:
def insert_state(self, *data) -> tuple:
data = list(data)
for i, input_component_type in enumerate(self.input_component_types):
if input_component_type == utils.STATE_COMPONENT:
data.insert(i, None)
assert len(data) == len(
self.serializers
), f"Expected {len(self.serializers)} arguments, got {len(data)}"
return tuple(data)
files = [
f
for f, t in zip(data, self.input_component_types)
if t in ["file", "uploadbutton"]
def remove_state(self, *data) -> tuple:
data = [
d
for d, oct in zip(data, self.output_component_types)
if oct != utils.STATE_COMPONENT
]
uploaded_files = self._upload(files)
self._add_uploaded_files_to_data(uploaded_files, data)
return tuple(data)
o = tuple([s.serialize(d) for s, d in zip(self.serializers, data)])
return o
def deserialize(self, *data) -> tuple | Any:
assert len(data) == len(
self.deserializers
), f"Expected {len(self.deserializers)} outputs, got {len(data)}"
outputs = tuple(
[
s.deserialize(d, hf_token=self.client.hf_token, root_url=self.root_url)
for s, d, oct in zip(
self.deserializers, data, self.output_component_types
)
if oct != utils.STATE_COMPONENT
]
)
def reduce_singleton_output(self, *data) -> Any:
if (
len(
[
@@ -788,10 +771,44 @@ class Endpoint:
)
== 1
):
output = outputs[0]
return data[0]
else:
output = outputs
return output
return data
def serialize(self, *data) -> tuple:
assert len(data) == len(
self.serializers
), f"Expected {len(self.serializers)} arguments, got {len(data)}"
files = [
f
for f, t in zip(data, self.input_component_types)
if t in ["file", "uploadbutton"]
]
uploaded_files = self._upload(files)
self._add_uploaded_files_to_data(uploaded_files, list(data))
o = tuple([s.serialize(d) for s, d in zip(self.serializers, data)])
return o
def deserialize(self, *data) -> tuple:
assert len(data) == len(
self.deserializers
), f"Expected {len(self.deserializers)} outputs, got {len(data)}"
outputs = tuple(
[
s.deserialize(d, hf_token=self.client.hf_token, root_url=self.root_url)
for s, d in zip(self.deserializers, data)
]
)
return outputs
def process_predictions(self, *predictions):
if self.client.serialize:
predictions = self.deserialize(*predictions)
predictions = self.remove_state(*predictions)
predictions = self.reduce_singleton_output(*predictions)
return predictions
def _setup_serializers(self) -> tuple[list[Serializable], list[Serializable]]:
inputs = self.dependency["inputs"]
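For readers following the refactor, the old monolithic `deserialize` is now split into the pieces above and recombined by `process_predictions`: deserialize the raw outputs, drop `gr.State` outputs, then unwrap a lone remaining value. A standalone sketch of that ordering on plain data (the helper and component types below are illustrative stand-ins, not the library's `Endpoint`):

```python
# Illustrative stand-in for the remove_state -> reduce_singleton_output steps
# (hypothetical data; the real process_predictions also runs the deserializers
# first when client.serialize is True).
from typing import Any

STATE = "state"  # stands in for utils.STATE_COMPONENT

def postprocess(outputs: tuple, output_component_types: list[str]) -> Any:
    # remove_state: drop outputs produced by State components
    non_state = tuple(
        o for o, t in zip(outputs, output_component_types) if t != STATE
    )
    # reduce_singleton_output: return the bare value when one output remains
    return non_state[0] if len(non_state) == 1 else non_state

print(postprocess(("hi", [1, 2, 3]), ["textbox", STATE]))           # 'hi'
print(postprocess(("", [["Hi", "Hello"]]), ["textbox", "chatbot"]))  # ('', [['Hi', 'Hello']])
```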

@@ -182,7 +182,7 @@ class Communicator:
lock: Lock
job: JobStatus
deserialize: Callable[..., tuple]
prediction_processor: Callable[..., tuple]
reset_url: str
should_cancel: bool = False
@@ -251,7 +251,7 @@ async def get_pred_from_ws(
output = resp.get("output", {}).get("data", [])
if output and status_update.code != Status.FINISHED:
try:
result = helper.deserialize(*output)
result = helper.prediction_processor(*output)
except Exception as e:
result = [e]
helper.job.outputs.append(result)
@@ -380,10 +380,10 @@ def strip_invalid_filename_characters(filename: str, max_bytes: int = 200) -> str
return filename
def sanitize_parameter_names(original_param_name: str) -> str:
"""Strips invalid characters from a parameter name and replaces spaces with underscores."""
def sanitize_parameter_names(original_name: str) -> str:
"""Cleans up a Python parameter name to make the API info more readable."""
return (
"".join([char for char in original_param_name if char.isalnum() or char in " "])
"".join([char for char in original_name if char.isalnum() or char in " _"])
.replace(" ", "_")
.lower()
)
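The practical effect of the `sanitize_parameter_names` change is that underscores now survive sanitization, so API parameter names are shown as written in the Python function signature. A quick behavior check, with the body copied from the new version above:

```python
# Behavior check for the updated helper (body copied from the new version above).
def sanitize_parameter_names(original_name: str) -> str:
    """Cleans up a Python parameter name to make the API info more readable."""
    return (
        "".join([char for char in original_name if char.isalnum() or char in " _"])
        .replace(" ", "_")
        .lower()
    )

print(sanitize_parameter_names("Input Text"))   # input_text
print(sanitize_parameter_names("image_input"))  # image_input (was "imageinput" before this change)
```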

@@ -1 +1 @@
0.2.4
0.2.5

@@ -181,6 +181,26 @@ def file_io_demo():
return demo
@pytest.fixture
def stateful_chatbot():
with gr.Blocks() as demo:
chatbot = gr.Chatbot()
msg = gr.Textbox()
clear = gr.Button("Clear")
st = gr.State([1, 2, 3])
def respond(message, st, chat_history):
assert st[0] == 1 and st[1] == 2 and st[2] == 3
bot_message = "I love you"
chat_history.append((message, bot_message))
return "", chat_history
msg.submit(respond, [msg, st, chatbot], [msg, chatbot], api_name="submit")
clear.click(lambda: None, None, chatbot, queue=False)
demo.queue()
return demo
@pytest.fixture
def all_components():
classes_to_check = gr.components.Component.__subclasses__()

@@ -23,10 +23,10 @@ HF_TOKEN = "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" # Intentionally revealing
@contextmanager
def connect(demo: gr.Blocks):
def connect(demo: gr.Blocks, serialize: bool = True):
_, local_url, _ = demo.launch(prevent_thread_lock=True)
try:
yield Client(local_url)
yield Client(local_url, serialize=serialize)
finally:
# A more verbose version of .close()
# because we should set a timeout
@@ -39,7 +39,7 @@ def connect(demo: gr.Blocks):
demo.server.thread.join(timeout=1)
class TestPredictionsFromSpaces:
class TestClientPredictions:
@pytest.mark.flaky
def test_raise_error_invalid_state(self):
with pytest.raises(ValueError, match="invalid state"):
@@ -172,7 +172,6 @@ class TestPredictionsFromSpaces:
assert pathlib.Path(job.result()).exists()
def test_progress_updates(self, progress_demo):
with connect(progress_demo) as client:
job = client.submit("hello", api_name="/predict")
statuses = []
@@ -246,7 +245,6 @@ class TestPredictionsFromSpaces:
@pytest.mark.flaky
def test_upload_file_private_space(self):
client = Client(
src="gradio-tests/not-actually-private-file-upload", hf_token=HF_TOKEN
)
@@ -308,11 +306,17 @@ class TestPredictionsFromSpaces:
client.submit(1, "foo", f.name, fn_index=0).result()
serialize.assert_called_once_with(1, "foo", f.name)
def test_state_without_serialize(self, stateful_chatbot):
with connect(stateful_chatbot, serialize=False) as client:
initial_history = [["", None]]
message = "Hello"
ret = client.predict(message, initial_history, api_name="/submit")
assert ret == ("", [["", None], ["Hello", "I love you"]])
class TestStatusUpdates:
@patch("gradio_client.client.Endpoint.make_end_to_end_fn")
def test_messages_passed_correctly(self, mock_make_end_to_end_fn):
now = datetime.now()
messages = [
@@ -396,7 +400,6 @@ class TestStatusUpdates:
@patch("gradio_client.client.Endpoint.make_end_to_end_fn")
def test_messages_correct_two_concurrent(self, mock_make_end_to_end_fn):
now = datetime.now()
messages_1 = [