mirror of
https://github.com/gradio-app/gradio.git
synced 2025-03-07 11:46:51 +08:00
Add status for Python Client Jobs (#3645)
* Add status + unit test (flaky) for now
* Install client
* Fix tests
* Lint backend + tests
* Add non-queue test
* Fix name
* Use lock instead
* Add simplified implementation + fix tests
* Restore changes to scripts
* Fix README typo
* Fix CI
* Add two concurrent tests
This commit is contained in:
parent cbb84927a7
commit f73155ed42

.github/workflows/backend.yml (vendored)
@@ -135,6 +135,6 @@ jobs:
      - name: Run tests
        shell: bash
        run: |
-          coverage run -m pytest -m "${{ matrix.test-type }}"
+          coverage run -m pytest -m "${{ matrix.test-type }}" --ignore=client
          coverage xml
@@ -9,11 +9,11 @@ Here's the entire code to do it:

```python
import gradio_client as grc

-client = grc.Client("stability-ai/stable-diffusion")
-job = client.predict("a hyperrealistic portrait of a cat wearing cyberpunk armor")
+client = grc.Client("stabilityai/stable-diffusion")
+job = client.predict("a hyperrealistic portrait of a cat wearing cyberpunk armor", "", fn_index=1)
job.result()

->> https://stabilityai-stable-diffusion.hf.space/kjbcxadsk3ada9k/image.png # URL to generated image
+>> /Users/usersname/b8c26657-df87-4508-aa75-eb37cd38735f # Path to generated gallery of images

```
@@ -56,7 +56,6 @@ client = grc.Client(src="btd372-js72hd.gradio.app")

The simplest way to make a prediction is simply to call the `.predict()` function with the appropriate arguments and then immediately call `.result()`, like this:

-
```python
import gradio_client as grc
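# (Hedged completion of this truncated snippet - the diff hunk cuts off here.
# The Space below is the abidlabs/en2fr example used elsewhere in this README.)
client = grc.Client(space="abidlabs/en2fr")
job = client.predict("Hello")
print(job.result())  # .result() blocks until the prediction completes
```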
@@ -73,7 +72,6 @@ One should note that `.result()` is a *blocking* operation as it waits for the op

In many cases, you may be better off letting the job run asynchronously and waiting to call `.result()` when you need the results of the prediction. For example:

-
```python
import gradio_client as grc
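# (Hedged completion - the hunk cuts off here; same assumed en2fr Space.)
client = grc.Client(space="abidlabs/en2fr")
job = client.predict("Hello")  # returns immediately; the job runs in a background thread
# ... do other work here ...
print(job.result())  # block only once the value is actually needed
```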
@@ -92,12 +90,12 @@ job.result()

Alternatively, one can add callbacks to perform actions after the job has completed running, like this:

```python
import gradio_client as grc


def print_result(x):
-    print(x"The translated result is: {x}")
+    print(f"The translated result is: {x}")

client = grc.Client(space="abidlabs/en2fr")
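# (Hedged completion - the hunk cuts off here. `result_callbacks` is the
# parameter handled in client.py below; the exact call signature is an
# assumption.)
job = client.predict("Hello", result_callbacks=[print_result])
```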
@@ -7,6 +7,8 @@ import re

import threading
import uuid
from concurrent.futures import Future
+from datetime import datetime
+from threading import Lock
from typing import Any, Callable, Dict, List, Tuple

import huggingface_hub

@@ -17,6 +19,7 @@ from packaging import version

from gradio_client import serializing, utils
from gradio_client.serializing import Serializable
+from gradio_client.utils import Communicator, JobStatus, Status, StatusUpdate


class Client:

@@ -83,9 +86,13 @@ class Client:
        if api_name:
            fn_index = self._infer_fn_index(api_name)

-        end_to_end_fn = self.endpoints[fn_index].end_to_end_fn
+        helper = None
+        if self.endpoints[fn_index].use_ws:
+            helper = Communicator(Lock(), JobStatus())
+        end_to_end_fn = self.endpoints[fn_index].make_end_to_end_fn(helper)
        future = self.executor.submit(end_to_end_fn, *args)
-        job = Job(future)
+        job = Job(future, communicator=helper)

        if result_callbacks:
            if isinstance(result_callbacks, Callable):
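For illustration, a minimal usage sketch of the call path above; the `gradio/calculator` Space name is borrowed from the tests further down, and network access is assumed:

```python
from gradio_client import Client

client = Client(src="gradio/calculator")
job = client.predict(5, "add", 4)  # submits end_to_end_fn to the executor; returns a Job at once
print(job.done())                  # Future methods are forwarded via Job.__getattr__
print(job.result())                # blocks until the worker finishes; prints 9
```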
@@ -326,42 +333,51 @@ class Endpoint:

        return {"parameters": parameters, "returns": returns}

-    def end_to_end_fn(self, *data):
-        if not self.is_valid:
-            raise utils.InvalidAPIEndpointError()
-        inputs = self.serialize(*data)
-        predictions = self.predict(*inputs)
-        outputs = self.deserialize(*predictions)
-        if len(self.dependency["outputs"]) == 1:
-            return outputs[0]
-        return outputs
+    def make_end_to_end_fn(self, helper: Communicator | None = None):
+        _predict = self.make_predict(helper)
+
+        def _inner(*data):
+            if not self.is_valid:
+                raise utils.InvalidAPIEndpointError()
+            inputs = self.serialize(*data)
+            predictions = _predict(*inputs)
+            outputs = self.deserialize(*predictions)
+            if len(self.dependency["outputs"]) == 1:
+                return outputs[0]
+            return outputs
+
+        return _inner

-    def predict(self, *data) -> Tuple:
-        data = json.dumps({"data": data, "fn_index": self.fn_index})
-        hash_data = json.dumps(
-            {"fn_index": self.fn_index, "session_hash": str(uuid.uuid4())}
-        )
-        if self.use_ws:
-            result = utils.synchronize_async(self._ws_fn, data, hash_data)
-            output = result["data"]
-        else:
-            response = requests.post(self.api_url, headers=self.headers, data=data)
-            result = json.loads(response.content.decode("utf-8"))
-            try:
-                output = result["data"]
-            except KeyError:
-                if "error" in result and "429" in result["error"]:
-                    raise utils.TooManyRequestsError(
-                        "Too many requests to the Hugging Face API"
-                    )
-                raise KeyError(
-                    f"Could not find 'data' key in response. Response received: {result}"
-                )
-        return tuple(output)
+    def make_predict(self, helper: Communicator | None = None):
+        def _predict(*data) -> Tuple:
+            data = json.dumps({"data": data, "fn_index": self.fn_index})
+            hash_data = json.dumps(
+                {"fn_index": self.fn_index, "session_hash": str(uuid.uuid4())}
+            )
+            if self.use_ws:
+                result = utils.synchronize_async(self._ws_fn, data, hash_data, helper)
+                output = result["data"]
+            else:
+                response = requests.post(self.api_url, headers=self.headers, data=data)
+                result = json.loads(response.content.decode("utf-8"))
+                try:
+                    output = result["data"]
+                except KeyError:
+                    if "error" in result and "429" in result["error"]:
+                        raise utils.TooManyRequestsError(
+                            "Too many requests to the Hugging Face API"
+                        )
+                    raise KeyError(
+                        f"Could not find 'data' key in response. Response received: {result}"
+                    )
+            return tuple(output)
+
+        return _predict

    def _predict_resolve(self, *data) -> Any:
        """Needed for gradio.load(), which has a slightly different signature for serializing/deserializing"""
-        outputs = self.predict(*data)
+        outputs = self.make_predict()(*data)
        if len(self.dependency["outputs"]) == 1:
            return outputs[0]
        return outputs
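The factory methods above exist so that each submitted job can carry its own `Communicator`: the returned closures capture `helper`, letting the worker report status without changing the prediction call signature. A minimal, self-contained sketch of that pattern (hypothetical names, not the library's API):

```python
from __future__ import annotations

from threading import Lock


class Channel:
    """Hypothetical stand-in for Communicator + JobStatus."""

    def __init__(self):
        self.lock = Lock()
        self.latest = "STARTING"


def make_fn(helper: Channel | None = None):
    def _inner(x):
        if helper:
            with helper.lock:  # the worker writes status under the lock
                helper.latest = "PROCESSING"
        return x * 2

    return _inner


channel = Channel()
fn = make_fn(channel)  # each job gets its own channel
print(fn(21), channel.latest)  # 42 PROCESSING
```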
@@ -433,18 +449,44 @@ class Endpoint:
        dependency_uses_queue = dependency.get("queue", False) is not False
        return queue_enabled and queue_uses_websocket and dependency_uses_queue

-    async def _ws_fn(self, data, hash_data):
+    async def _ws_fn(self, data, hash_data, helper: Communicator):
        async with websockets.connect(  # type: ignore
            self.ws_url, open_timeout=10, extra_headers=self.headers
        ) as websocket:
-            return await utils.get_pred_from_ws(websocket, data, hash_data)
+            return await utils.get_pred_from_ws(websocket, data, hash_data, helper)


class Job(Future):
    """A Job is a thin wrapper over the Future class that can be cancelled."""

-    def __init__(self, future: Future):
+    def __init__(self, future: Future, communicator: Communicator | None = None):
        self.future = future
+        self.communicator = communicator
+
+    def status(self) -> StatusUpdate:
+        if not self.communicator:
+            time = datetime.now()
+            if self.done():
+                return StatusUpdate(
+                    code=Status.FINISHED,
+                    rank=0,
+                    queue_size=None,
+                    success=None,
+                    time=time,
+                    eta=None,
+                )
+            else:
+                return StatusUpdate(
+                    code=Status.PROCESSING,
+                    rank=0,
+                    queue_size=None,
+                    success=None,
+                    time=time,
+                    eta=None,
+                )
+        else:
+            with self.communicator.lock:
+                return self.communicator.job.latest_status

    def __getattr__(self, name):
        """Forwards any properties to the Future class."""
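As a usage sketch, the new `status()` method can be polled while a job runs; the Space name is again borrowed from the tests below:

```python
import time

from gradio_client import Client

client = Client(src="gradio/calculator")
job = client.predict(5, "add", 4)
while not job.done():
    update = job.status()  # a StatusUpdate dataclass (defined in utils.py below)
    print(update.code, update.rank, update.eta)
    time.sleep(0.1)
print(job.result())
```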
@@ -7,8 +7,12 @@ import os

import pkgutil
import shutil
import tempfile
+from dataclasses import dataclass, field
+from datetime import datetime
+from enum import Enum
from pathlib import Path
-from typing import Any, Callable, Dict, Tuple
+from threading import Lock
+from typing import Any, Callable, Dict, List, Tuple

import fsspec.asyn
import requests
@@ -38,6 +42,92 @@ class InvalidAPIEndpointError(Exception):
    pass


+class Status(Enum):
+    """Status codes presented to client users."""
+
+    STARTING = "STARTING"
+    JOINING_QUEUE = "JOINING_QUEUE"
+    QUEUE_FULL = "QUEUE_FULL"
+    IN_QUEUE = "IN_QUEUE"
+    SENDING_DATA = "SENDING_DATA"
+    PROCESSING = "PROCESSING"
+    ITERATING = "ITERATING"
+    FINISHED = "FINISHED"
+
+    @staticmethod
+    def ordering(status: "Status") -> int:
+        """Order of messages. Helpful for testing."""
+        order = [
+            Status.STARTING,
+            Status.JOINING_QUEUE,
+            Status.QUEUE_FULL,
+            Status.IN_QUEUE,
+            Status.SENDING_DATA,
+            Status.PROCESSING,
+            Status.ITERATING,
+            Status.FINISHED,
+        ]
+        return order.index(status)
+
+    def __lt__(self, other: "Status"):
+        return self.ordering(self) < self.ordering(other)
+
+    @staticmethod
+    def msg_to_status(msg: str) -> "Status":
+        """Map the raw message from the backend to the status code presented to users."""
+        return {
+            "send_hash": Status.JOINING_QUEUE,
+            "queue_full": Status.QUEUE_FULL,
+            "estimation": Status.IN_QUEUE,
+            "send_data": Status.SENDING_DATA,
+            "process_starts": Status.PROCESSING,
+            "process_generating": Status.ITERATING,
+            "process_completed": Status.FINISHED,
+        }[msg]
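For illustration, a small sketch of how these two helpers behave, assuming the module layout in this diff:

```python
from gradio_client.utils import Status

# msg_to_status maps raw queue messages to user-facing codes
assert Status.msg_to_status("estimation") == Status.IN_QUEUE

# ordering()/__lt__ make Status values sortable, which the tests below rely on
# when checking that codes arrive in non-decreasing order
assert Status.JOINING_QUEUE < Status.PROCESSING < Status.FINISHED
```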
+
+
+@dataclass
+class StatusUpdate:
+    """Update message sent from the worker thread to the Job on the main thread."""
+
+    code: Status
+    rank: int | None
+    queue_size: int | None
+    eta: float | None
+    success: bool | None
+    time: datetime | None
+
+
+def create_initial_status_update():
+    return StatusUpdate(
+        code=Status.STARTING,
+        rank=None,
+        queue_size=None,
+        eta=None,
+        success=None,
+        time=datetime.now(),
+    )
+
+
+@dataclass
+class JobStatus:
+    """The job status.
+
+    Keeps track of the latest status update and intermediate outputs (not yet implemented).
+    """
+
+    latest_status: StatusUpdate = field(default_factory=create_initial_status_update)
+    outputs: List[Any] = field(default_factory=list)
+
+
+@dataclass
+class Communicator:
+    """Helper class for communicating between the worker thread and the main thread."""
+
+    lock: Lock
+    job: JobStatus
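A minimal sketch of the intended handshake, assuming the definitions above: the worker thread publishes the latest `StatusUpdate` under the lock, and the main thread (via `Job.status()`) reads it back under the same lock:

```python
from datetime import datetime
from threading import Lock

from gradio_client.utils import Communicator, JobStatus, Status, StatusUpdate

comm = Communicator(Lock(), JobStatus())

# Worker-thread side: publish the latest status
with comm.lock:
    comm.job.latest_status = StatusUpdate(
        code=Status.PROCESSING,
        rank=None,
        queue_size=None,
        eta=None,
        success=None,
        time=datetime.now(),
    )

# Main-thread side: read it back (this is what Job.status() does)
with comm.lock:
    print(comm.job.latest_status.code)  # Status.PROCESSING
```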


########################
# Network utils
########################

@@ -55,13 +145,27 @@ def is_valid_url(possible_url: str) -> bool:


async def get_pred_from_ws(
-    websocket: WebSocketCommonProtocol, data: str, hash_data: str
+    websocket: WebSocketCommonProtocol,
+    data: str,
+    hash_data: str,
+    helper: Communicator | None = None,
) -> Dict[str, Any]:
    completed = False
    resp = {}
    while not completed:
        msg = await websocket.recv()
        resp = json.loads(msg)
+        if helper:
+            with helper.lock:
+                status_update = StatusUpdate(
+                    code=Status.msg_to_status(resp["msg"]),
+                    queue_size=resp.get("queue_size"),
+                    rank=resp.get("rank", None),
+                    success=resp.get("success"),
+                    time=datetime.now(),
+                    eta=resp.get("rank_eta"),
+                )
+                helper.job.latest_status = status_update
        if resp["msg"] == "queue_full":
            raise QueueError("Queue is full! Please try again.")
        if resp["msg"] == "send_hash":
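To make the mapping concrete, here is a hedged example of an "estimation" message of the shape this loop consumes (field values illustrative) and the `StatusUpdate` it would produce:

```python
from datetime import datetime

from gradio_client.utils import Status, StatusUpdate

resp = {"msg": "estimation", "rank": 2, "queue_size": 3, "rank_eta": 4.2}
update = StatusUpdate(
    code=Status.msg_to_status(resp["msg"]),  # Status.IN_QUEUE
    queue_size=resp.get("queue_size"),
    rank=resp.get("rank", None),
    success=resp.get("success"),  # None; only present on completion messages
    time=datetime.now(),
    eta=resp.get("rank_eta"),
)
print(update)
```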
@@ -1,8 +1,15 @@
import json
import os
+import time
+from datetime import datetime, timedelta
+from unittest.mock import patch

import pytest

from gradio_client import Client
+from gradio_client.utils import Communicator, Status, StatusUpdate

os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"

HF_TOKEN = "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"  # Intentionally revealing this key for testing purposes
@@ -20,6 +27,190 @@ class TestPredictionsFromSpaces:
        output = client.predict("abc").result()
        assert output == "abc"

+    @pytest.mark.flaky
+    def test_job_status(self):
+        statuses = []
+        client = Client(src="gradio/calculator")
+        job = client.predict(5, "add", 4)
+        while not job.done():
+            time.sleep(0.1)
+            statuses.append(job.status())
+
+        assert statuses
+        # Messages are sorted by time
+        assert sorted([s.time for s in statuses if s]) == [
+            s.time for s in statuses if s
+        ]
+        assert sorted([s.code for s in statuses if s]) == [
+            s.code for s in statuses if s
+        ]
+
+    @pytest.mark.flaky
+    def test_job_status_queue_disabled(self):
+        statuses = []
+        client = Client(src="freddyaboulton/sentiment-classification")
+        job = client.predict("I love the gradio python client")
+        while not job.done():
+            time.sleep(0.02)
+            statuses.append(job.status())
+        statuses.append(job.status())
+        assert all(s.code in [Status.PROCESSING, Status.FINISHED] for s in statuses)
+
+
+class TestStatusUpdates:
+    @patch("gradio_client.client.Endpoint.make_end_to_end_fn")
+    def test_messages_passed_correctly(self, mock_make_end_to_end_fn):
+
+        now = datetime.now()
+
+        messages = [
+            StatusUpdate(
+                code=Status.STARTING,
+                eta=None,
+                rank=None,
+                success=None,
+                queue_size=None,
+                time=now,
+            ),
+            StatusUpdate(
+                code=Status.SENDING_DATA,
+                eta=None,
+                rank=None,
+                success=None,
+                queue_size=None,
+                time=now + timedelta(seconds=1),
+            ),
+            StatusUpdate(
+                code=Status.IN_QUEUE,
+                eta=3,
+                rank=2,
+                queue_size=2,
+                success=None,
+                time=now + timedelta(seconds=2),
+            ),
+            StatusUpdate(
+                code=Status.IN_QUEUE,
+                eta=2,
+                rank=1,
+                queue_size=1,
+                success=None,
+                time=now + timedelta(seconds=3),
+            ),
+            StatusUpdate(
+                code=Status.ITERATING,
+                eta=None,
+                rank=None,
+                queue_size=None,
+                success=None,
+                time=now + timedelta(seconds=3),
+            ),
+            StatusUpdate(
+                code=Status.FINISHED,
+                eta=None,
+                rank=None,
+                queue_size=None,
+                success=True,
+                time=now + timedelta(seconds=4),
+            ),
+        ]
+
+        class MockEndToEndFunction:
+            def __init__(self, communicator: Communicator):
+                self.communicator = communicator
+
+            def __call__(self, *args, **kwargs):
+                for m in messages:
+                    with self.communicator.lock:
+                        self.communicator.job.latest_status = m
+                    time.sleep(0.1)
+
+        mock_make_end_to_end_fn.side_effect = MockEndToEndFunction
+
+        client = Client(src="gradio/calculator")
+        job = client.predict(5, "add", 6, fn_index=0)
+
+        statuses = []
+        while not job.done():
+            statuses.append(job.status())
+            time.sleep(0.09)
+
+        assert all(s in messages for s in statuses)
+
+    @patch("gradio_client.client.Endpoint.make_end_to_end_fn")
+    def test_messages_correct_two_concurrent(self, mock_make_end_to_end_fn):
+
+        now = datetime.now()
+
+        messages_1 = [
+            StatusUpdate(
+                code=Status.STARTING,
+                eta=None,
+                rank=None,
+                success=None,
+                queue_size=None,
+                time=now,
+            ),
+            StatusUpdate(
+                code=Status.FINISHED,
+                eta=None,
+                rank=None,
+                queue_size=None,
+                success=True,
+                time=now + timedelta(seconds=4),
+            ),
+        ]
+
+        messages_2 = [
+            StatusUpdate(
+                code=Status.IN_QUEUE,
+                eta=3,
+                rank=2,
+                queue_size=2,
+                success=None,
+                time=now + timedelta(seconds=2),
+            ),
+            StatusUpdate(
+                code=Status.IN_QUEUE,
+                eta=2,
+                rank=1,
+                queue_size=1,
+                success=None,
+                time=now + timedelta(seconds=3),
+            ),
+        ]
+
+        class MockEndToEndFunction:
+            n_counts = 0
+
+            def __init__(self, communicator: Communicator):
+                self.communicator = communicator
+                self.messages = (
+                    messages_1 if MockEndToEndFunction.n_counts == 0 else messages_2
+                )
+                MockEndToEndFunction.n_counts += 1
+
+            def __call__(self, *args, **kwargs):
+                for m in self.messages:
+                    with self.communicator.lock:
+                        print(f"here: {m}")
+                        self.communicator.job.latest_status = m
+                    time.sleep(0.1)
+
+        mock_make_end_to_end_fn.side_effect = MockEndToEndFunction
+
+        client = Client(src="gradio/calculator")
+        job_1 = client.predict(5, "add", 6, fn_index=0)
+        job_2 = client.predict(11, "subtract", 1, fn_index=0)
+
+        statuses_1 = []
+        statuses_2 = []
+        while not (job_1.done() and job_2.done()):
+            statuses_1.append(job_1.status())
+            statuses_2.append(job_2.status())
+            time.sleep(0.05)
+
+        assert all(s in messages_1 for s in statuses_1)


class TestEndpoints:
    @pytest.mark.flaky
@@ -6,4 +6,4 @@ source scripts/helpers.sh
pip_required

echo "Installing Gradio..."
-pip install -e .
+pip install -e .