Add status for Python Client Jobs (#3645)

* Add status + unit test (flaky) for now

* Install client

* Fix tests

* Lint backend + tests

* Add non-queue test

* Fix name

* Use lock instead

* Add simplify implementation + fix tests

* Restore changes to scripts

* Fix README typo

* Fix CI

* Add two concurrent test
This commit is contained in:
Freddy Boulton 2023-03-29 18:41:12 -04:00 committed by GitHub
parent cbb84927a7
commit f73155ed42
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 381 additions and 46 deletions

View File

@ -135,6 +135,6 @@ jobs:
- name: Run tests
shell: bash
run: |
coverage run -m pytest -m "${{ matrix.test-type }}"
coverage run -m pytest -m "${{ matrix.test-type }}" --ignore=client
coverage xml

View File

@ -9,11 +9,11 @@ Here's the entire code to do it:
```python
import gradio_client as grc
client = grc.Client("stability-ai/stable-diffusion")
job = client.predict("a hyperrealistic portrait of a cat wearing cyberpunk armor")
client = grc.Client("stabilityai/stable-diffusion")
job = client.predict("a hyperrealistic portrait of a cat wearing cyberpunk armor", "", fn_index=1)
job.result()
>> https://stabilityai-stable-diffusion.hf.space/kjbcxadsk3ada9k/image.png # URL to generated image
>> /Users/usersname/b8c26657-df87-4508-aa75-eb37cd38735f # Path to generatoed gallery of images
```
@ -56,7 +56,6 @@ client = grc.Client(src="btd372-js72hd.gradio.app")
The simplest way to make a prediction is simply to call the `.predict()` function with the appropriate arguments and then immediately calling `.result()`, like this:
```python
import gradio_client as grc
@ -73,7 +72,6 @@ Oe should note that `.result()` is a *blocking* operation as it waits for the op
In many cases, you may be better off letting the job run asynchronously and waiting to call `.result()` when you need the results of the prediction. For example:
```python
import gradio_client as grc
@ -92,12 +90,12 @@ job.result()
Alternatively, one can add callbacks to perform actions after the job has completed running, like this:
```python
import gradio_client as grc
def print_result(x):
print(x"The translated result is: {x}")
print("The translated result is: {x}")
client = grc.Client(space="abidlabs/en2fr")

View File

@ -7,6 +7,8 @@ import re
import threading
import uuid
from concurrent.futures import Future
from datetime import datetime
from threading import Lock
from typing import Any, Callable, Dict, List, Tuple
import huggingface_hub
@ -17,6 +19,7 @@ from packaging import version
from gradio_client import serializing, utils
from gradio_client.serializing import Serializable
from gradio_client.utils import Communicator, JobStatus, Status, StatusUpdate
class Client:
@ -83,9 +86,13 @@ class Client:
if api_name:
fn_index = self._infer_fn_index(api_name)
end_to_end_fn = self.endpoints[fn_index].end_to_end_fn
helper = None
if self.endpoints[fn_index].use_ws:
helper = Communicator(Lock(), JobStatus())
end_to_end_fn = self.endpoints[fn_index].make_end_to_end_fn(helper)
future = self.executor.submit(end_to_end_fn, *args)
job = Job(future)
job = Job(future, communicator=helper)
if result_callbacks:
if isinstance(result_callbacks, Callable):
@ -326,42 +333,51 @@ class Endpoint:
return {"parameters": parameters, "returns": returns}
def end_to_end_fn(self, *data):
if not self.is_valid:
raise utils.InvalidAPIEndpointError()
inputs = self.serialize(*data)
predictions = self.predict(*inputs)
outputs = self.deserialize(*predictions)
if len(self.dependency["outputs"]) == 1:
return outputs[0]
return outputs
def make_end_to_end_fn(self, helper: Communicator | None = None):
def predict(self, *data) -> Tuple:
data = json.dumps({"data": data, "fn_index": self.fn_index})
hash_data = json.dumps(
{"fn_index": self.fn_index, "session_hash": str(uuid.uuid4())}
)
if self.use_ws:
result = utils.synchronize_async(self._ws_fn, data, hash_data)
output = result["data"]
else:
response = requests.post(self.api_url, headers=self.headers, data=data)
result = json.loads(response.content.decode("utf-8"))
try:
_predict = self.make_predict(helper)
def _inner(*data):
if not self.is_valid:
raise utils.InvalidAPIEndpointError()
inputs = self.serialize(*data)
predictions = _predict(*inputs)
outputs = self.deserialize(*predictions)
if len(self.dependency["outputs"]) == 1:
return outputs[0]
return outputs
return _inner
def make_predict(self, helper: Communicator | None = None):
def _predict(*data) -> Tuple:
data = json.dumps({"data": data, "fn_index": self.fn_index})
hash_data = json.dumps(
{"fn_index": self.fn_index, "session_hash": str(uuid.uuid4())}
)
if self.use_ws:
result = utils.synchronize_async(self._ws_fn, data, hash_data, helper)
output = result["data"]
except KeyError:
if "error" in result and "429" in result["error"]:
raise utils.TooManyRequestsError(
"Too many requests to the Hugging Face API"
else:
response = requests.post(self.api_url, headers=self.headers, data=data)
result = json.loads(response.content.decode("utf-8"))
try:
output = result["data"]
except KeyError:
if "error" in result and "429" in result["error"]:
raise utils.TooManyRequestsError(
"Too many requests to the Hugging Face API"
)
raise KeyError(
f"Could not find 'data' key in response. Response received: {result}"
)
raise KeyError(
f"Could not find 'data' key in response. Response received: {result}"
)
return tuple(output)
return tuple(output)
return _predict
def _predict_resolve(self, *data) -> Any:
"""Needed for gradio.load(), which has a slightly different signature for serializing/deserializing"""
outputs = self.predict(*data)
outputs = self.make_predict()(*data)
if len(self.dependency["outputs"]) == 1:
return outputs[0]
return outputs
@ -433,18 +449,44 @@ class Endpoint:
dependency_uses_queue = dependency.get("queue", False) is not False
return queue_enabled and queue_uses_websocket and dependency_uses_queue
async def _ws_fn(self, data, hash_data):
async def _ws_fn(self, data, hash_data, helper: Communicator):
async with websockets.connect( # type: ignore
self.ws_url, open_timeout=10, extra_headers=self.headers
) as websocket:
return await utils.get_pred_from_ws(websocket, data, hash_data)
return await utils.get_pred_from_ws(websocket, data, hash_data, helper)
class Job(Future):
"""A Job is a thin wrapper over the Future class that can be cancelled."""
def __init__(self, future: Future):
def __init__(self, future: Future, communicator: Communicator | None = None):
self.future = future
self.communicator = communicator
def status(self) -> StatusUpdate:
if not self.communicator:
time = datetime.now()
if self.done():
return StatusUpdate(
code=Status.FINISHED,
rank=0,
queue_size=None,
success=None,
time=time,
eta=None,
)
else:
return StatusUpdate(
code=Status.PROCESSING,
rank=0,
queue_size=None,
success=None,
time=time,
eta=None,
)
else:
with self.communicator.lock:
return self.communicator.job.latest_status
def __getattr__(self, name):
"""Forwards any properties to the Future class."""

View File

@ -7,8 +7,12 @@ import os
import pkgutil
import shutil
import tempfile
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from pathlib import Path
from typing import Any, Callable, Dict, Tuple
from threading import Lock
from typing import Any, Callable, Dict, List, Tuple
import fsspec.asyn
import requests
@ -38,6 +42,92 @@ class InvalidAPIEndpointError(Exception):
pass
class Status(Enum):
"""Status codes presented to client users."""
STARTING = "STARTING"
JOINING_QUEUE = "JOINING_QUEUE"
QUEUE_FULL = "QUEUE_FULL"
IN_QUEUE = "IN_QUEUE"
SENDING_DATA = "SENDING_DATA"
PROCESSING = "PROCESSSING"
ITERATING = "ITERATING"
FINISHED = "FINISHED"
@staticmethod
def ordering(status: "Status") -> int:
"""Order of messages. Helpful for testing."""
order = [
Status.STARTING,
Status.JOINING_QUEUE,
Status.QUEUE_FULL,
Status.IN_QUEUE,
Status.SENDING_DATA,
Status.PROCESSING,
Status.ITERATING,
Status.FINISHED,
]
return order.index(status)
def __lt__(self, other: "Status"):
return self.ordering(self) < self.ordering(other)
@staticmethod
def msg_to_status(msg: str) -> "Status":
"""Map the raw message from the backend to the status code presented to users."""
return {
"send_hash": Status.JOINING_QUEUE,
"queue_full": Status.QUEUE_FULL,
"estimation": Status.IN_QUEUE,
"send_data": Status.SENDING_DATA,
"process_starts": Status.PROCESSING,
"process_generating": Status.ITERATING,
"process_completed": Status.FINISHED,
}[msg]
@dataclass
class StatusUpdate:
"""Update message sent from the worker thread to the Job on the main thread."""
code: Status
rank: int | None
queue_size: int | None
eta: float | None
success: bool | None
time: datetime | None
def create_initial_status_update():
return StatusUpdate(
code=Status.STARTING,
rank=None,
queue_size=None,
eta=None,
success=None,
time=datetime.now(),
)
@dataclass
class JobStatus:
"""The job status.
Keeps strack of the latest status update and intermediate outputs (not yet implements).
"""
latest_status: StatusUpdate = field(default_factory=create_initial_status_update)
outputs: List[Any] = field(default_factory=list)
@dataclass
class Communicator:
"""Helper class to help communicate between the worker thread and main thread."""
lock: Lock
job: JobStatus
########################
# Network utils
########################
@ -55,13 +145,27 @@ def is_valid_url(possible_url: str) -> bool:
async def get_pred_from_ws(
websocket: WebSocketCommonProtocol, data: str, hash_data: str
websocket: WebSocketCommonProtocol,
data: str,
hash_data: str,
helper: Communicator | None = None,
) -> Dict[str, Any]:
completed = False
resp = {}
while not completed:
msg = await websocket.recv()
resp = json.loads(msg)
if helper:
with helper.lock:
status_update = StatusUpdate(
code=Status.msg_to_status(resp["msg"]),
queue_size=resp.get("queue_size"),
rank=resp.get("rank", None),
success=resp.get("success"),
time=datetime.now(),
eta=resp.get("rank_eta"),
)
helper.job.latest_status = status_update
if resp["msg"] == "queue_full":
raise QueueError("Queue is full! Please try again.")
if resp["msg"] == "send_hash":

View File

@ -1,8 +1,15 @@
import json
import os
import time
from datetime import datetime, timedelta
from unittest.mock import patch
import pytest
from gradio_client import Client
from gradio_client.utils import Communicator, Status, StatusUpdate
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
HF_TOKEN = "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" # Intentionally revealing this key for testing purposes
@ -20,6 +27,190 @@ class TestPredictionsFromSpaces:
output = client.predict("abc").result()
assert output == "abc"
@pytest.mark.flaky
def test_job_status(self):
statuses = []
client = Client(src="gradio/calculator")
job = client.predict(5, "add", 4)
while not job.done():
time.sleep(0.1)
statuses.append(job.status())
assert statuses
# Messages are sorted by time
assert sorted([s.time for s in statuses if s]) == [
s.time for s in statuses if s
]
assert sorted([s.code for s in statuses if s]) == [
s.code for s in statuses if s
]
@pytest.mark.flaky
def test_job_status_queue_disabled(self):
statuses = []
client = Client(src="freddyaboulton/sentiment-classification")
job = client.predict("I love the gradio python client")
while not job.done():
time.sleep(0.02)
statuses.append(job.status())
statuses.append(job.status())
assert all(s.code in [Status.PROCESSING, Status.FINISHED] for s in statuses)
class TestStatusUpdates:
@patch("gradio_client.client.Endpoint.make_end_to_end_fn")
def test_messages_passed_correctly(self, mock_make_end_to_end_fn):
now = datetime.now()
messages = [
StatusUpdate(
code=Status.STARTING,
eta=None,
rank=None,
success=None,
queue_size=None,
time=now,
),
StatusUpdate(
code=Status.SENDING_DATA,
eta=None,
rank=None,
success=None,
queue_size=None,
time=now + timedelta(seconds=1),
),
StatusUpdate(
code=Status.IN_QUEUE,
eta=3,
rank=2,
queue_size=2,
success=None,
time=now + timedelta(seconds=2),
),
StatusUpdate(
code=Status.IN_QUEUE,
eta=2,
rank=1,
queue_size=1,
success=None,
time=now + timedelta(seconds=3),
),
StatusUpdate(
code=Status.ITERATING,
eta=None,
rank=None,
queue_size=None,
success=None,
time=now + timedelta(seconds=3),
),
StatusUpdate(
code=Status.FINISHED,
eta=None,
rank=None,
queue_size=None,
success=True,
time=now + timedelta(seconds=4),
),
]
class MockEndToEndFunction:
def __init__(self, communicator: Communicator):
self.communicator = communicator
def __call__(self, *args, **kwargs):
for m in messages:
with self.communicator.lock:
self.communicator.job.latest_status = m
time.sleep(0.1)
mock_make_end_to_end_fn.side_effect = MockEndToEndFunction
client = Client(src="gradio/calculator")
job = client.predict(5, "add", 6, fn_index=0)
statuses = []
while not job.done():
statuses.append(job.status())
time.sleep(0.09)
assert all(s in messages for s in statuses)
@patch("gradio_client.client.Endpoint.make_end_to_end_fn")
def test_messages_correct_two_concurrent(self, mock_make_end_to_end_fn):
now = datetime.now()
messages_1 = [
StatusUpdate(
code=Status.STARTING,
eta=None,
rank=None,
success=None,
queue_size=None,
time=now,
),
StatusUpdate(
code=Status.FINISHED,
eta=None,
rank=None,
queue_size=None,
success=True,
time=now + timedelta(seconds=4),
),
]
messages_2 = [
StatusUpdate(
code=Status.IN_QUEUE,
eta=3,
rank=2,
queue_size=2,
success=None,
time=now + timedelta(seconds=2),
),
StatusUpdate(
code=Status.IN_QUEUE,
eta=2,
rank=1,
queue_size=1,
success=None,
time=now + timedelta(seconds=3),
),
]
class MockEndToEndFunction:
n_counts = 0
def __init__(self, communicator: Communicator):
self.communicator = communicator
self.messages = (
messages_1 if MockEndToEndFunction.n_counts == 0 else messages_2
)
MockEndToEndFunction.n_counts += 1
def __call__(self, *args, **kwargs):
for m in self.messages:
with self.communicator.lock:
print(f"here: {m}")
self.communicator.job.latest_status = m
time.sleep(0.1)
mock_make_end_to_end_fn.side_effect = MockEndToEndFunction
client = Client(src="gradio/calculator")
job_1 = client.predict(5, "add", 6, fn_index=0)
job_2 = client.predict(11, "subtract", 1, fn_index=0)
statuses_1 = []
statuses_2 = []
while not (job_1.done() and job_2.done()):
statuses_1.append(job_1.status())
statuses_2.append(job_2.status())
time.sleep(0.05)
assert all(s in messages_1 for s in statuses_1)
class TestEndpoints:
@pytest.mark.flaky

View File

@ -6,4 +6,4 @@ source scripts/helpers.sh
pip_required
echo "Installing Gradio..."
pip install -e .
pip install -e .