Get Intermediate Results from Python Client (#3694)

* Add status + unit test (flaky) for now

* Install client

* Fix tests

* Lint backend + tests

* Add non-queue test

* Fix name

* Use lock instead

* Add simplify implementation + fix tests

* Restore changes to scripts

* Fix README typo

* Fix CI

* Add intermediate results to python client

* Type check

* Typecheck again

* Catch exception

* Thinking

* Don't read generator from config

* Add no-queue test

* Remove unused method

* Fix types

* Remove breakpoint

* Fix code

* Fix test

* Fix tests

* Unpack list

* Add docstring

---------

Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
Freddy Boulton authored on 2023-04-04 14:58:25 -07:00, committed by GitHub
parent c4ad09b631
commit 9325cba14c
3 changed files with 157 additions and 28 deletions
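
Before the diffs, here is a minimal usage sketch of the Job API this commit introduces. It is illustrative only, not part of the commit; the Space and endpoint names are borrowed from the tests further down.

import time

from gradio_client import Client

client = Client(src="gradio/count_generator")

# predict() now returns a Job, a thin wrapper over concurrent.futures.Future.
job = client.predict(3, api_name="/count")

# Poll for intermediate results while the generator endpoint is still running.
while not job.done():
    print(job.status())   # StatusUpdate: code, rank, queue_size, success, eta
    print(job.outputs())  # everything the endpoint has yielded so far
    time.sleep(0.1)

# For queued endpoints, result() returns the first output the job produced.
print(job.result())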

File: client/python/gradio_client/client.py

@@ -5,8 +5,9 @@ import concurrent.futures
 import json
 import re
 import threading
+import time
 import uuid
-from concurrent.futures import Future
+from concurrent.futures import Future, TimeoutError
 from datetime import datetime
 from threading import Lock
 from typing import Any, Callable, Dict, List, Tuple
@@ -76,7 +77,7 @@ class Client:
         api_name: str | None = None,
         fn_index: int | None = None,
         result_callbacks: Callable | List[Callable] | None = None,
-    ) -> Future:
+    ) -> Job:
         """
         Parameters:
             *args: The arguments to pass to the remote API. The order of the arguments must match the order of the inputs in the Gradio app.
@@ -90,7 +91,9 @@ class Client:
         helper = None
         if self.endpoints[inferred_fn_index].use_ws:
-            helper = Communicator(Lock(), JobStatus())
+            helper = Communicator(
+                Lock(), JobStatus(), self.endpoints[inferred_fn_index].deserialize
+            )
         end_to_end_fn = self.endpoints[inferred_fn_index].make_end_to_end_fn(helper)
         future = self.executor.submit(end_to_end_fn, *args)
@@ -389,19 +392,14 @@ class Endpoint:
                 raise utils.InvalidAPIEndpointError()
             inputs = self.serialize(*data)
             predictions = _predict(*inputs)
-            outputs = self.deserialize(*predictions)
-            if (
-                len(
-                    [
-                        oct
-                        for oct in self.output_component_types
-                        if not oct == utils.STATE_COMPONENT
-                    ]
-                )
-                == 1
-            ):
-                return outputs[0]
-            return outputs
+            output = self.deserialize(*predictions)
+            # Append final output only if not already present
+            # for consistency between generators and not generators
+            if helper:
+                with helper.lock:
+                    if not helper.job.outputs:
+                        helper.job.outputs.append(output)
+            return output

         return _inner
@@ -461,11 +459,11 @@ class Endpoint:
         ), f"Expected {len(self.serializers)} arguments, got {len(data)}"
         return tuple([s.serialize(d) for s, d in zip(self.serializers, data)])

-    def deserialize(self, *data) -> Tuple:
+    def deserialize(self, *data) -> Tuple | Any:
         assert len(data) == len(
             self.deserializers
         ), f"Expected {len(self.deserializers)} outputs, got {len(data)}"
-        return tuple(
+        outputs = tuple(
             [
                 s.deserialize(d, hf_token=self.client.hf_token, root_url=self.root_url)
                 for s, d, oct in zip(
@@ -474,6 +472,20 @@ class Endpoint:
                 if not oct == utils.STATE_COMPONENT
             ]
         )
+        if (
+            len(
+                [
+                    oct
+                    for oct in self.output_component_types
+                    if not oct == utils.STATE_COMPONENT
+                ]
+            )
+            == 1
+        ):
+            output = outputs[0]
+        else:
+            output = outputs
+        return output

     def _setup_serializers(self) -> Tuple[List[Serializable], List[Serializable]]:
         inputs = self.dependency["inputs"]
@@ -529,7 +541,10 @@ class Endpoint:
     async def _ws_fn(self, data, hash_data, helper: Communicator):
         async with websockets.connect(  # type: ignore
-            self.client.ws_url, open_timeout=10, extra_headers=self.client.headers
+            self.client.ws_url,
+            open_timeout=10,
+            extra_headers=self.client.headers,
+            max_size=1024 * 1024 * 1024,
         ) as websocket:
             return await utils.get_pred_from_ws(websocket, data, hash_data, helper)
@@ -537,23 +552,89 @@
 class Job(Future):
     """A Job is a thin wrapper over the Future class that can be cancelled."""

-    def __init__(self, future: Future, communicator: Communicator | None = None):
+    def __init__(
+        self,
+        future: Future,
+        communicator: Communicator | None = None,
+    ):
         self.future = future
         self.communicator = communicator

-    def status(self) -> StatusUpdate:
+    def outputs(self) -> List[Tuple | Any]:
+        """Returns a list containing the latest outputs from the Job.
+
+        If the endpoint has multiple output components, the list will contain
+        a tuple of results. Otherwise, it will contain the results without storing them
+        in tuples.
+
+        For endpoints that are queued, this list will contain the final job output even
+        if that endpoint does not use a generator function.
+        """
+        if not self.communicator:
+            return []
+        else:
+            with self.communicator.lock:
+                return self.communicator.job.outputs
+
+    def result(self, timeout=None):
+        """Return the result of the call that the future represents.
+
+        Args:
+            timeout: The number of seconds to wait for the result if the future
+                isn't done. If None, then there is no limit on the wait time.
+
+        Returns:
+            The result of the call that the future represents.
+
+        Raises:
+            CancelledError: If the future was cancelled.
+            TimeoutError: If the future didn't finish executing before the given
+                timeout.
+            Exception: If the call raised then that exception will be raised.
+        """
+        if self.communicator:
+            timeout = timeout or float("inf")
+            if self.future._exception:  # type: ignore
+                raise self.future._exception  # type: ignore
+            with self.communicator.lock:
+                if self.communicator.job.outputs:
+                    return self.communicator.job.outputs[0]
+            start = datetime.now()
+            while True:
+                if (datetime.now() - start).seconds > timeout:
+                    raise TimeoutError()
+                if self.future._exception:  # type: ignore
+                    raise self.future._exception  # type: ignore
+                with self.communicator.lock:
+                    if self.communicator.job.outputs:
+                        return self.communicator.job.outputs[0]
+                time.sleep(0.01)
+        else:
+            return super().result(timeout=timeout)
+
+    def status(self) -> StatusUpdate:
         time = datetime.now()
         if self.done():
-            return StatusUpdate(
-                code=Status.FINISHED,
-                rank=0,
-                queue_size=None,
-                success=None,
-                time=time,
-                eta=None,
-            )
+            if not self.future._exception:  # type: ignore
+                return StatusUpdate(
+                    code=Status.FINISHED,
+                    rank=0,
+                    queue_size=None,
+                    success=True,
+                    time=time,
+                    eta=None,
+                )
+            else:
+                return StatusUpdate(
+                    code=Status.FINISHED,
+                    rank=0,
+                    queue_size=None,
+                    success=False,
+                    time=time,
+                    eta=None,
+                )
         else:
             if not self.communicator:
                 return StatusUpdate(
                     code=Status.PROCESSING,
                     rank=0,
@@ -562,9 +643,9 @@ class Job(Future):
                     time=time,
                     eta=None,
                 )
             else:
                 with self.communicator.lock:
                     return self.communicator.job.latest_status

     def __getattr__(self, name):
         """Forwards any properties to the Future class."""

File: client/python/gradio_client/utils.py

@@ -127,6 +127,7 @@ class Communicator:

     lock: Lock
     job: JobStatus
+    deserialize: Callable[..., Tuple]


 ########################
@@ -166,6 +167,13 @@ async def get_pred_from_ws(
                     time=datetime.now(),
                     eta=resp.get("rank_eta"),
                 )
+                output = resp.get("output", {}).get("data", [])
+                if output and status_update.code != Status.FINISHED:
+                    try:
+                        result = helper.deserialize(*output)
+                    except Exception as e:
+                        result = [e]
+                    helper.job.outputs.append(result)
                 helper.job.latest_status = status_update
         if resp["msg"] == "queue_full":
             raise QueueError("Queue is full! Please try again.")
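
The Communicator gains a deserialize callable so this websocket loop can store ready-to-use values as messages stream in, with every access to the shared outputs list guarded by the lock. Below is a self-contained sketch of that producer/consumer pattern, using simplified stand-ins rather than the real gradio_client classes:

import threading
import time
from dataclasses import dataclass, field
from typing import Any, Callable, List


@dataclass
class JobStatus:
    outputs: List[Any] = field(default_factory=list)


@dataclass
class Communicator:
    lock: threading.Lock
    job: JobStatus
    deserialize: Callable[..., Any]


helper = Communicator(threading.Lock(), JobStatus(), deserialize=lambda *d: d)


def worker():
    # Stands in for get_pred_from_ws: append each deserialized
    # message's output under the lock as it "arrives".
    for i in range(3):
        time.sleep(0.1)
        with helper.lock:
            helper.job.outputs.append(helper.deserialize(str(i)))


t = threading.Thread(target=worker)
t.start()
while t.is_alive():
    with helper.lock:  # same locking discipline as Job.outputs()
        print("outputs so far:", list(helper.job.outputs))
    time.sleep(0.05)
t.join()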

File: client/python/test/test_client.py

@@ -2,6 +2,7 @@ import json
 import os
 import pathlib
 import time
+from concurrent.futures import TimeoutError
 from datetime import datetime, timedelta
 from unittest.mock import patch
@@ -81,7 +82,7 @@ class TestPredictionsFromSpaces:
     def test_job_status_queue_disabled(self):
         statuses = []
         client = Client(src="freddyaboulton/sentiment-classification")
-        job = client.predict("I love the gradio python client", fn_index=0)
+        job = client.predict("I love the gradio python client", api_name="/classify")
         while not job.done():
             time.sleep(0.02)
             statuses.append(job.status())
@@ -89,6 +90,45 @@ class TestPredictionsFromSpaces:
         assert all(s.code in [Status.PROCESSING, Status.FINISHED] for s in statuses)

+    @pytest.mark.flaky
+    def test_intermediate_outputs(self):
+        client = Client(src="gradio/count_generator")
+        job = client.predict(3, api_name="/count")
+        while not job.done():
+            time.sleep(0.1)
+        assert job.outputs() == [str(i) for i in range(3)]
+
+    @pytest.mark.flaky
+    def test_timeout(self):
+        with pytest.raises(TimeoutError):
+            client = Client(src="gradio/count_generator")
+            job = client.predict(api_name="/sleep")
+            job.result(timeout=0.05)
+
+    @pytest.mark.flaky
+    def test_timeout_no_queue(self):
+        with pytest.raises(TimeoutError):
+            client = Client(src="freddyaboulton/sentiment-classification")
+            job = client.predict(api_name="/sleep")
+            job.result(timeout=0.1)
+
+    @pytest.mark.flaky
+    def test_raises_exception(self):
+        with pytest.raises(Exception):
+            client = Client(src="freddyaboulton/calculator")
+            job = client.predict("foo", "add", 9, fn_index=0)
+            job.result()
+
+    @pytest.mark.flaky
+    def test_raises_exception_no_queue(self):
+        with pytest.raises(Exception):
+            client = Client(src="freddyaboulton/sentiment-classification")
+            job = client.predict([5], api_name="/sleep")
+            job.result()
+
     def test_job_output_video(self):
         client = Client(src="gradio/video_component")
         job = client.predict(