Support gr.Progress() in python client (#3924)

* Add progress message

* CHANGELOG

* Dont use pydantic

* Docs + local test

* Add gr to requirements

* Remove editable install

* make a bit softer
This commit is contained in:
Freddy Boulton 2023-04-24 12:52:10 -04:00 committed by GitHub
parent 6b854eefd4
commit d835c9a816
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 90 additions and 5 deletions

View File

@ -2,7 +2,7 @@
## New Features:
No changes to highlight.
- Progress Updates from `gr.Progress()` can be accessed via `job.status().progress_data` by @freddyaboulton](https://github.com/freddyaboulton) in [PR 3924](https://github.com/gradio-app/gradio/pull/3924)
## Bug Fixes:

View File

@ -949,7 +949,11 @@ class Job(Future):
def status(self) -> StatusUpdate:
"""
Returns the latest status update from the Job in the form of a StatusUpdate
object, which contains the following fields: code, rank, queue_size, success, time, eta.
object, which contains the following fields: code, rank, queue_size, success, time, eta, and progress_data.
progress_data is a list of updates emitted by the gr.Progress() tracker of the event handler. Each element
of the list has the following fields: index, length, unit, progress, desc. If the event handler does not have
a gr.Progress() tracker, the progress_data field will be None.
Example:
from gradio_client import Client
@ -973,6 +977,7 @@ class Job(Future):
success=False,
time=time,
eta=None,
progress_data=None,
)
if self.done():
if not self.future._exception: # type: ignore
@ -983,6 +988,7 @@ class Job(Future):
success=True,
time=time,
eta=None,
progress_data=None,
)
else:
return StatusUpdate(
@ -992,6 +998,7 @@ class Job(Future):
success=False,
time=time,
eta=None,
progress_data=None,
)
else:
if not self.communicator:
@ -1002,6 +1009,7 @@ class Job(Future):
success=None,
time=time,
eta=None,
progress_data=None,
)
else:
with self.communicator.lock:

View File

@ -14,7 +14,7 @@ from datetime import datetime
from enum import Enum
from pathlib import Path
from threading import Lock
from typing import Any, Callable, Dict, List, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple
import fsspec.asyn
import httpx
@ -79,6 +79,7 @@ class Status(Enum):
SENDING_DATA = "SENDING_DATA"
PROCESSING = "PROCESSING"
ITERATING = "ITERATING"
PROGRESS = "PROGRESS"
FINISHED = "FINISHED"
CANCELLED = "CANCELLED"
@ -92,6 +93,7 @@ class Status(Enum):
Status.IN_QUEUE,
Status.SENDING_DATA,
Status.PROCESSING,
Status.PROGRESS,
Status.ITERATING,
Status.FINISHED,
Status.CANCELLED,
@ -112,9 +114,32 @@ class Status(Enum):
"process_starts": Status.PROCESSING,
"process_generating": Status.ITERATING,
"process_completed": Status.FINISHED,
"progress": Status.PROGRESS,
}[msg]
@dataclass
class ProgressUnit:
index: Optional[int]
length: Optional[int]
unit: Optional[str]
progress: Optional[float]
desc: Optional[str]
@classmethod
def from_ws_msg(cls, data: List[Dict]) -> List["ProgressUnit"]:
return [
cls(
index=d.get("index"),
length=d.get("length"),
unit=d.get("unit"),
progress=d.get("progress"),
desc=d.get("desc"),
)
for d in data
]
@dataclass
class StatusUpdate:
"""Update message sent from the worker thread to the Job on the main thread."""
@ -125,6 +150,7 @@ class StatusUpdate:
eta: float | None
success: bool | None
time: datetime | None
progress_data: List[ProgressUnit] | None
def create_initial_status_update():
@ -135,6 +161,7 @@ def create_initial_status_update():
eta=None,
success=None,
time=datetime.now(),
progress_data=None,
)
@ -209,6 +236,7 @@ async def get_pred_from_ws(
resp = json.loads(msg)
if helper:
with helper.lock:
has_progress = "progress_data" in resp
status_update = StatusUpdate(
code=Status.msg_to_status(resp["msg"]),
queue_size=resp.get("queue_size"),
@ -216,6 +244,9 @@ async def get_pred_from_ws(
success=resp.get("success"),
time=datetime.now(),
eta=resp.get("rank_eta"),
progress_data=ProgressUnit.from_ws_msg(resp["progress_data"])
if has_progress
else None,
)
output = resp.get("output", {}).get("data", [])
if output and status_update.code != Status.FINISHED:

View File

@ -9,5 +9,4 @@ black --check test gradio_client
pyright gradio_client/*.py
echo "Testing..."
python -m pip install -e ../../. # Install gradio from local source (as the latest version may not yet be published to PyPI)
python -m pytest test

View File

@ -3,3 +3,4 @@ pytest-asyncio
pytest==7.1.2
ruff==0.0.260
pyright==1.1.298
gradio

View File

@ -8,12 +8,13 @@ from concurrent.futures import CancelledError, TimeoutError
from datetime import datetime, timedelta
from unittest.mock import MagicMock, patch
import gradio as gr
import pytest
from huggingface_hub.utils import RepositoryNotFoundError
from gradio_client import Client
from gradio_client.serializing import SimpleSerializable
from gradio_client.utils import Communicator, Status, StatusUpdate
from gradio_client.utils import Communicator, ProgressUnit, Status, StatusUpdate
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
@ -96,6 +97,7 @@ class TestPredictionsFromSpaces:
statuses.append(job.status())
statuses.append(job.status())
assert all(s.code in [Status.PROCESSING, Status.FINISHED] for s in statuses)
assert not any(s.progress_data for s in statuses)
@pytest.mark.flaky
def test_intermediate_outputs(
@ -157,6 +159,40 @@ class TestPredictionsFromSpaces:
)
assert pathlib.Path(job.result()).exists()
def test_progress_updates(self):
def my_function(x, progress=gr.Progress()):
progress(0, desc="Starting...")
for i in progress.tqdm(range(20)):
time.sleep(0.1)
return x
demo = gr.Interface(my_function, gr.Textbox(), gr.Textbox()).queue(
concurrency_count=20
)
_, local_url, _ = demo.launch(prevent_thread_lock=True)
try:
client = Client(src=local_url)
job = client.submit("hello", api_name="/predict")
statuses = []
while not job.done():
statuses.append(job.status())
time.sleep(0.02)
assert any(s.code == Status.PROGRESS for s in statuses)
assert any(s.progress_data is not None for s in statuses)
all_progress_data = [
p for s in statuses if s.progress_data for p in s.progress_data
]
count = 0
for i in range(20):
unit = ProgressUnit(
index=i, length=20, unit="steps", progress=None, desc=None
)
count += unit in all_progress_data
assert count
finally:
demo.close()
@pytest.mark.flaky
def test_cancel_from_client_queued(self):
client = Client(src="gradio-tests/test-cancel-from-client")
@ -284,6 +320,7 @@ class TestStatusUpdates:
success=None,
queue_size=None,
time=now,
progress_data=None,
),
StatusUpdate(
code=Status.SENDING_DATA,
@ -292,6 +329,7 @@ class TestStatusUpdates:
success=None,
queue_size=None,
time=now + timedelta(seconds=1),
progress_data=None,
),
StatusUpdate(
code=Status.IN_QUEUE,
@ -300,6 +338,7 @@ class TestStatusUpdates:
queue_size=2,
success=None,
time=now + timedelta(seconds=2),
progress_data=None,
),
StatusUpdate(
code=Status.IN_QUEUE,
@ -308,6 +347,7 @@ class TestStatusUpdates:
queue_size=1,
success=None,
time=now + timedelta(seconds=3),
progress_data=None,
),
StatusUpdate(
code=Status.ITERATING,
@ -316,6 +356,7 @@ class TestStatusUpdates:
queue_size=None,
success=None,
time=now + timedelta(seconds=3),
progress_data=None,
),
StatusUpdate(
code=Status.FINISHED,
@ -324,6 +365,7 @@ class TestStatusUpdates:
queue_size=None,
success=True,
time=now + timedelta(seconds=4),
progress_data=None,
),
]
@ -362,6 +404,7 @@ class TestStatusUpdates:
success=None,
queue_size=None,
time=now,
progress_data=None,
),
StatusUpdate(
code=Status.FINISHED,
@ -370,6 +413,7 @@ class TestStatusUpdates:
queue_size=None,
success=True,
time=now + timedelta(seconds=4),
progress_data=None,
),
]
@ -381,6 +425,7 @@ class TestStatusUpdates:
queue_size=2,
success=None,
time=now + timedelta(seconds=2),
progress_data=None,
),
StatusUpdate(
code=Status.IN_QUEUE,
@ -389,6 +434,7 @@ class TestStatusUpdates:
queue_size=1,
success=None,
time=now + timedelta(seconds=3),
progress_data=None,
),
]