From d835c9a816e16be00794b788958d041c47ec8e4b Mon Sep 17 00:00:00 2001 From: Freddy Boulton Date: Mon, 24 Apr 2023 12:52:10 -0400 Subject: [PATCH] Support gr.Progress() in python client (#3924) * Add progress message * CHANGELOG * Dont use pydantic * Docs + local test * Add gr to requirements * Remove editable install * make a bit softer --- client/python/CHANGELOG.md | 2 +- client/python/gradio_client/client.py | 10 +++++- client/python/gradio_client/utils.py | 33 +++++++++++++++++- client/python/scripts/ci.sh | 1 - client/python/test/requirements.txt | 1 + client/python/test/test_client.py | 48 ++++++++++++++++++++++++++- 6 files changed, 90 insertions(+), 5 deletions(-) diff --git a/client/python/CHANGELOG.md b/client/python/CHANGELOG.md index 1225636c6f..fef0b566c3 100644 --- a/client/python/CHANGELOG.md +++ b/client/python/CHANGELOG.md @@ -2,7 +2,7 @@ ## New Features: -No changes to highlight. +- Progress Updates from `gr.Progress()` can be accessed via `job.status().progress_data` by @freddyaboulton](https://github.com/freddyaboulton) in [PR 3924](https://github.com/gradio-app/gradio/pull/3924) ## Bug Fixes: diff --git a/client/python/gradio_client/client.py b/client/python/gradio_client/client.py index bb8ba98f63..73a85d0d74 100644 --- a/client/python/gradio_client/client.py +++ b/client/python/gradio_client/client.py @@ -949,7 +949,11 @@ class Job(Future): def status(self) -> StatusUpdate: """ Returns the latest status update from the Job in the form of a StatusUpdate - object, which contains the following fields: code, rank, queue_size, success, time, eta. + object, which contains the following fields: code, rank, queue_size, success, time, eta, and progress_data. + + progress_data is a list of updates emitted by the gr.Progress() tracker of the event handler. Each element + of the list has the following fields: index, length, unit, progress, desc. If the event handler does not have + a gr.Progress() tracker, the progress_data field will be None. Example: from gradio_client import Client @@ -973,6 +977,7 @@ class Job(Future): success=False, time=time, eta=None, + progress_data=None, ) if self.done(): if not self.future._exception: # type: ignore @@ -983,6 +988,7 @@ class Job(Future): success=True, time=time, eta=None, + progress_data=None, ) else: return StatusUpdate( @@ -992,6 +998,7 @@ class Job(Future): success=False, time=time, eta=None, + progress_data=None, ) else: if not self.communicator: @@ -1002,6 +1009,7 @@ class Job(Future): success=None, time=time, eta=None, + progress_data=None, ) else: with self.communicator.lock: diff --git a/client/python/gradio_client/utils.py b/client/python/gradio_client/utils.py index d2b02c0316..697f9d191d 100644 --- a/client/python/gradio_client/utils.py +++ b/client/python/gradio_client/utils.py @@ -14,7 +14,7 @@ from datetime import datetime from enum import Enum from pathlib import Path from threading import Lock -from typing import Any, Callable, Dict, List, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple import fsspec.asyn import httpx @@ -79,6 +79,7 @@ class Status(Enum): SENDING_DATA = "SENDING_DATA" PROCESSING = "PROCESSING" ITERATING = "ITERATING" + PROGRESS = "PROGRESS" FINISHED = "FINISHED" CANCELLED = "CANCELLED" @@ -92,6 +93,7 @@ class Status(Enum): Status.IN_QUEUE, Status.SENDING_DATA, Status.PROCESSING, + Status.PROGRESS, Status.ITERATING, Status.FINISHED, Status.CANCELLED, @@ -112,9 +114,32 @@ class Status(Enum): "process_starts": Status.PROCESSING, "process_generating": Status.ITERATING, "process_completed": Status.FINISHED, + "progress": Status.PROGRESS, }[msg] +@dataclass +class ProgressUnit: + index: Optional[int] + length: Optional[int] + unit: Optional[str] + progress: Optional[float] + desc: Optional[str] + + @classmethod + def from_ws_msg(cls, data: List[Dict]) -> List["ProgressUnit"]: + return [ + cls( + index=d.get("index"), + length=d.get("length"), + unit=d.get("unit"), + progress=d.get("progress"), + desc=d.get("desc"), + ) + for d in data + ] + + @dataclass class StatusUpdate: """Update message sent from the worker thread to the Job on the main thread.""" @@ -125,6 +150,7 @@ class StatusUpdate: eta: float | None success: bool | None time: datetime | None + progress_data: List[ProgressUnit] | None def create_initial_status_update(): @@ -135,6 +161,7 @@ def create_initial_status_update(): eta=None, success=None, time=datetime.now(), + progress_data=None, ) @@ -209,6 +236,7 @@ async def get_pred_from_ws( resp = json.loads(msg) if helper: with helper.lock: + has_progress = "progress_data" in resp status_update = StatusUpdate( code=Status.msg_to_status(resp["msg"]), queue_size=resp.get("queue_size"), @@ -216,6 +244,9 @@ async def get_pred_from_ws( success=resp.get("success"), time=datetime.now(), eta=resp.get("rank_eta"), + progress_data=ProgressUnit.from_ws_msg(resp["progress_data"]) + if has_progress + else None, ) output = resp.get("output", {}).get("data", []) if output and status_update.code != Status.FINISHED: diff --git a/client/python/scripts/ci.sh b/client/python/scripts/ci.sh index f2a49271f0..b1a1bed536 100644 --- a/client/python/scripts/ci.sh +++ b/client/python/scripts/ci.sh @@ -9,5 +9,4 @@ black --check test gradio_client pyright gradio_client/*.py echo "Testing..." -python -m pip install -e ../../. # Install gradio from local source (as the latest version may not yet be published to PyPI) python -m pytest test diff --git a/client/python/test/requirements.txt b/client/python/test/requirements.txt index 11ddf3453f..b5d71dc8e8 100644 --- a/client/python/test/requirements.txt +++ b/client/python/test/requirements.txt @@ -3,3 +3,4 @@ pytest-asyncio pytest==7.1.2 ruff==0.0.260 pyright==1.1.298 +gradio diff --git a/client/python/test/test_client.py b/client/python/test/test_client.py index fbbf8169e2..08392e4f5e 100644 --- a/client/python/test/test_client.py +++ b/client/python/test/test_client.py @@ -8,12 +8,13 @@ from concurrent.futures import CancelledError, TimeoutError from datetime import datetime, timedelta from unittest.mock import MagicMock, patch +import gradio as gr import pytest from huggingface_hub.utils import RepositoryNotFoundError from gradio_client import Client from gradio_client.serializing import SimpleSerializable -from gradio_client.utils import Communicator, Status, StatusUpdate +from gradio_client.utils import Communicator, ProgressUnit, Status, StatusUpdate os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1" @@ -96,6 +97,7 @@ class TestPredictionsFromSpaces: statuses.append(job.status()) statuses.append(job.status()) assert all(s.code in [Status.PROCESSING, Status.FINISHED] for s in statuses) + assert not any(s.progress_data for s in statuses) @pytest.mark.flaky def test_intermediate_outputs( @@ -157,6 +159,40 @@ class TestPredictionsFromSpaces: ) assert pathlib.Path(job.result()).exists() + def test_progress_updates(self): + def my_function(x, progress=gr.Progress()): + progress(0, desc="Starting...") + for i in progress.tqdm(range(20)): + time.sleep(0.1) + return x + + demo = gr.Interface(my_function, gr.Textbox(), gr.Textbox()).queue( + concurrency_count=20 + ) + _, local_url, _ = demo.launch(prevent_thread_lock=True) + + try: + client = Client(src=local_url) + job = client.submit("hello", api_name="/predict") + statuses = [] + while not job.done(): + statuses.append(job.status()) + time.sleep(0.02) + assert any(s.code == Status.PROGRESS for s in statuses) + assert any(s.progress_data is not None for s in statuses) + all_progress_data = [ + p for s in statuses if s.progress_data for p in s.progress_data + ] + count = 0 + for i in range(20): + unit = ProgressUnit( + index=i, length=20, unit="steps", progress=None, desc=None + ) + count += unit in all_progress_data + assert count + finally: + demo.close() + @pytest.mark.flaky def test_cancel_from_client_queued(self): client = Client(src="gradio-tests/test-cancel-from-client") @@ -284,6 +320,7 @@ class TestStatusUpdates: success=None, queue_size=None, time=now, + progress_data=None, ), StatusUpdate( code=Status.SENDING_DATA, @@ -292,6 +329,7 @@ class TestStatusUpdates: success=None, queue_size=None, time=now + timedelta(seconds=1), + progress_data=None, ), StatusUpdate( code=Status.IN_QUEUE, @@ -300,6 +338,7 @@ class TestStatusUpdates: queue_size=2, success=None, time=now + timedelta(seconds=2), + progress_data=None, ), StatusUpdate( code=Status.IN_QUEUE, @@ -308,6 +347,7 @@ class TestStatusUpdates: queue_size=1, success=None, time=now + timedelta(seconds=3), + progress_data=None, ), StatusUpdate( code=Status.ITERATING, @@ -316,6 +356,7 @@ class TestStatusUpdates: queue_size=None, success=None, time=now + timedelta(seconds=3), + progress_data=None, ), StatusUpdate( code=Status.FINISHED, @@ -324,6 +365,7 @@ class TestStatusUpdates: queue_size=None, success=True, time=now + timedelta(seconds=4), + progress_data=None, ), ] @@ -362,6 +404,7 @@ class TestStatusUpdates: success=None, queue_size=None, time=now, + progress_data=None, ), StatusUpdate( code=Status.FINISHED, @@ -370,6 +413,7 @@ class TestStatusUpdates: queue_size=None, success=True, time=now + timedelta(seconds=4), + progress_data=None, ), ] @@ -381,6 +425,7 @@ class TestStatusUpdates: queue_size=2, success=None, time=now + timedelta(seconds=2), + progress_data=None, ), StatusUpdate( code=Status.IN_QUEUE, @@ -389,6 +434,7 @@ class TestStatusUpdates: queue_size=1, success=None, time=now + timedelta(seconds=3), + progress_data=None, ), ]