mirror of
https://github.com/gradio-app/gradio.git
synced 2024-11-27 01:40:20 +08:00
Support gr.Progress() in python client (#3924)
* Add progress message * CHANGELOG * Dont use pydantic * Docs + local test * Add gr to requirements * Remove editable install * make a bit softer
This commit is contained in:
parent
6b854eefd4
commit
d835c9a816
@ -2,7 +2,7 @@
|
||||
|
||||
## New Features:
|
||||
|
||||
No changes to highlight.
|
||||
- Progress Updates from `gr.Progress()` can be accessed via `job.status().progress_data` by @freddyaboulton](https://github.com/freddyaboulton) in [PR 3924](https://github.com/gradio-app/gradio/pull/3924)
|
||||
|
||||
## Bug Fixes:
|
||||
|
||||
|
@ -949,7 +949,11 @@ class Job(Future):
|
||||
def status(self) -> StatusUpdate:
|
||||
"""
|
||||
Returns the latest status update from the Job in the form of a StatusUpdate
|
||||
object, which contains the following fields: code, rank, queue_size, success, time, eta.
|
||||
object, which contains the following fields: code, rank, queue_size, success, time, eta, and progress_data.
|
||||
|
||||
progress_data is a list of updates emitted by the gr.Progress() tracker of the event handler. Each element
|
||||
of the list has the following fields: index, length, unit, progress, desc. If the event handler does not have
|
||||
a gr.Progress() tracker, the progress_data field will be None.
|
||||
|
||||
Example:
|
||||
from gradio_client import Client
|
||||
@ -973,6 +977,7 @@ class Job(Future):
|
||||
success=False,
|
||||
time=time,
|
||||
eta=None,
|
||||
progress_data=None,
|
||||
)
|
||||
if self.done():
|
||||
if not self.future._exception: # type: ignore
|
||||
@ -983,6 +988,7 @@ class Job(Future):
|
||||
success=True,
|
||||
time=time,
|
||||
eta=None,
|
||||
progress_data=None,
|
||||
)
|
||||
else:
|
||||
return StatusUpdate(
|
||||
@ -992,6 +998,7 @@ class Job(Future):
|
||||
success=False,
|
||||
time=time,
|
||||
eta=None,
|
||||
progress_data=None,
|
||||
)
|
||||
else:
|
||||
if not self.communicator:
|
||||
@ -1002,6 +1009,7 @@ class Job(Future):
|
||||
success=None,
|
||||
time=time,
|
||||
eta=None,
|
||||
progress_data=None,
|
||||
)
|
||||
else:
|
||||
with self.communicator.lock:
|
||||
|
@ -14,7 +14,7 @@ from datetime import datetime
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from threading import Lock
|
||||
from typing import Any, Callable, Dict, List, Tuple
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
import fsspec.asyn
|
||||
import httpx
|
||||
@ -79,6 +79,7 @@ class Status(Enum):
|
||||
SENDING_DATA = "SENDING_DATA"
|
||||
PROCESSING = "PROCESSING"
|
||||
ITERATING = "ITERATING"
|
||||
PROGRESS = "PROGRESS"
|
||||
FINISHED = "FINISHED"
|
||||
CANCELLED = "CANCELLED"
|
||||
|
||||
@ -92,6 +93,7 @@ class Status(Enum):
|
||||
Status.IN_QUEUE,
|
||||
Status.SENDING_DATA,
|
||||
Status.PROCESSING,
|
||||
Status.PROGRESS,
|
||||
Status.ITERATING,
|
||||
Status.FINISHED,
|
||||
Status.CANCELLED,
|
||||
@ -112,9 +114,32 @@ class Status(Enum):
|
||||
"process_starts": Status.PROCESSING,
|
||||
"process_generating": Status.ITERATING,
|
||||
"process_completed": Status.FINISHED,
|
||||
"progress": Status.PROGRESS,
|
||||
}[msg]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProgressUnit:
|
||||
index: Optional[int]
|
||||
length: Optional[int]
|
||||
unit: Optional[str]
|
||||
progress: Optional[float]
|
||||
desc: Optional[str]
|
||||
|
||||
@classmethod
|
||||
def from_ws_msg(cls, data: List[Dict]) -> List["ProgressUnit"]:
|
||||
return [
|
||||
cls(
|
||||
index=d.get("index"),
|
||||
length=d.get("length"),
|
||||
unit=d.get("unit"),
|
||||
progress=d.get("progress"),
|
||||
desc=d.get("desc"),
|
||||
)
|
||||
for d in data
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class StatusUpdate:
|
||||
"""Update message sent from the worker thread to the Job on the main thread."""
|
||||
@ -125,6 +150,7 @@ class StatusUpdate:
|
||||
eta: float | None
|
||||
success: bool | None
|
||||
time: datetime | None
|
||||
progress_data: List[ProgressUnit] | None
|
||||
|
||||
|
||||
def create_initial_status_update():
|
||||
@ -135,6 +161,7 @@ def create_initial_status_update():
|
||||
eta=None,
|
||||
success=None,
|
||||
time=datetime.now(),
|
||||
progress_data=None,
|
||||
)
|
||||
|
||||
|
||||
@ -209,6 +236,7 @@ async def get_pred_from_ws(
|
||||
resp = json.loads(msg)
|
||||
if helper:
|
||||
with helper.lock:
|
||||
has_progress = "progress_data" in resp
|
||||
status_update = StatusUpdate(
|
||||
code=Status.msg_to_status(resp["msg"]),
|
||||
queue_size=resp.get("queue_size"),
|
||||
@ -216,6 +244,9 @@ async def get_pred_from_ws(
|
||||
success=resp.get("success"),
|
||||
time=datetime.now(),
|
||||
eta=resp.get("rank_eta"),
|
||||
progress_data=ProgressUnit.from_ws_msg(resp["progress_data"])
|
||||
if has_progress
|
||||
else None,
|
||||
)
|
||||
output = resp.get("output", {}).get("data", [])
|
||||
if output and status_update.code != Status.FINISHED:
|
||||
|
@ -9,5 +9,4 @@ black --check test gradio_client
|
||||
pyright gradio_client/*.py
|
||||
|
||||
echo "Testing..."
|
||||
python -m pip install -e ../../. # Install gradio from local source (as the latest version may not yet be published to PyPI)
|
||||
python -m pytest test
|
||||
|
@ -3,3 +3,4 @@ pytest-asyncio
|
||||
pytest==7.1.2
|
||||
ruff==0.0.260
|
||||
pyright==1.1.298
|
||||
gradio
|
||||
|
@ -8,12 +8,13 @@ from concurrent.futures import CancelledError, TimeoutError
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import gradio as gr
|
||||
import pytest
|
||||
from huggingface_hub.utils import RepositoryNotFoundError
|
||||
|
||||
from gradio_client import Client
|
||||
from gradio_client.serializing import SimpleSerializable
|
||||
from gradio_client.utils import Communicator, Status, StatusUpdate
|
||||
from gradio_client.utils import Communicator, ProgressUnit, Status, StatusUpdate
|
||||
|
||||
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
|
||||
|
||||
@ -96,6 +97,7 @@ class TestPredictionsFromSpaces:
|
||||
statuses.append(job.status())
|
||||
statuses.append(job.status())
|
||||
assert all(s.code in [Status.PROCESSING, Status.FINISHED] for s in statuses)
|
||||
assert not any(s.progress_data for s in statuses)
|
||||
|
||||
@pytest.mark.flaky
|
||||
def test_intermediate_outputs(
|
||||
@ -157,6 +159,40 @@ class TestPredictionsFromSpaces:
|
||||
)
|
||||
assert pathlib.Path(job.result()).exists()
|
||||
|
||||
def test_progress_updates(self):
|
||||
def my_function(x, progress=gr.Progress()):
|
||||
progress(0, desc="Starting...")
|
||||
for i in progress.tqdm(range(20)):
|
||||
time.sleep(0.1)
|
||||
return x
|
||||
|
||||
demo = gr.Interface(my_function, gr.Textbox(), gr.Textbox()).queue(
|
||||
concurrency_count=20
|
||||
)
|
||||
_, local_url, _ = demo.launch(prevent_thread_lock=True)
|
||||
|
||||
try:
|
||||
client = Client(src=local_url)
|
||||
job = client.submit("hello", api_name="/predict")
|
||||
statuses = []
|
||||
while not job.done():
|
||||
statuses.append(job.status())
|
||||
time.sleep(0.02)
|
||||
assert any(s.code == Status.PROGRESS for s in statuses)
|
||||
assert any(s.progress_data is not None for s in statuses)
|
||||
all_progress_data = [
|
||||
p for s in statuses if s.progress_data for p in s.progress_data
|
||||
]
|
||||
count = 0
|
||||
for i in range(20):
|
||||
unit = ProgressUnit(
|
||||
index=i, length=20, unit="steps", progress=None, desc=None
|
||||
)
|
||||
count += unit in all_progress_data
|
||||
assert count
|
||||
finally:
|
||||
demo.close()
|
||||
|
||||
@pytest.mark.flaky
|
||||
def test_cancel_from_client_queued(self):
|
||||
client = Client(src="gradio-tests/test-cancel-from-client")
|
||||
@ -284,6 +320,7 @@ class TestStatusUpdates:
|
||||
success=None,
|
||||
queue_size=None,
|
||||
time=now,
|
||||
progress_data=None,
|
||||
),
|
||||
StatusUpdate(
|
||||
code=Status.SENDING_DATA,
|
||||
@ -292,6 +329,7 @@ class TestStatusUpdates:
|
||||
success=None,
|
||||
queue_size=None,
|
||||
time=now + timedelta(seconds=1),
|
||||
progress_data=None,
|
||||
),
|
||||
StatusUpdate(
|
||||
code=Status.IN_QUEUE,
|
||||
@ -300,6 +338,7 @@ class TestStatusUpdates:
|
||||
queue_size=2,
|
||||
success=None,
|
||||
time=now + timedelta(seconds=2),
|
||||
progress_data=None,
|
||||
),
|
||||
StatusUpdate(
|
||||
code=Status.IN_QUEUE,
|
||||
@ -308,6 +347,7 @@ class TestStatusUpdates:
|
||||
queue_size=1,
|
||||
success=None,
|
||||
time=now + timedelta(seconds=3),
|
||||
progress_data=None,
|
||||
),
|
||||
StatusUpdate(
|
||||
code=Status.ITERATING,
|
||||
@ -316,6 +356,7 @@ class TestStatusUpdates:
|
||||
queue_size=None,
|
||||
success=None,
|
||||
time=now + timedelta(seconds=3),
|
||||
progress_data=None,
|
||||
),
|
||||
StatusUpdate(
|
||||
code=Status.FINISHED,
|
||||
@ -324,6 +365,7 @@ class TestStatusUpdates:
|
||||
queue_size=None,
|
||||
success=True,
|
||||
time=now + timedelta(seconds=4),
|
||||
progress_data=None,
|
||||
),
|
||||
]
|
||||
|
||||
@ -362,6 +404,7 @@ class TestStatusUpdates:
|
||||
success=None,
|
||||
queue_size=None,
|
||||
time=now,
|
||||
progress_data=None,
|
||||
),
|
||||
StatusUpdate(
|
||||
code=Status.FINISHED,
|
||||
@ -370,6 +413,7 @@ class TestStatusUpdates:
|
||||
queue_size=None,
|
||||
success=True,
|
||||
time=now + timedelta(seconds=4),
|
||||
progress_data=None,
|
||||
),
|
||||
]
|
||||
|
||||
@ -381,6 +425,7 @@ class TestStatusUpdates:
|
||||
queue_size=2,
|
||||
success=None,
|
||||
time=now + timedelta(seconds=2),
|
||||
progress_data=None,
|
||||
),
|
||||
StatusUpdate(
|
||||
code=Status.IN_QUEUE,
|
||||
@ -389,6 +434,7 @@ class TestStatusUpdates:
|
||||
queue_size=1,
|
||||
success=None,
|
||||
time=now + timedelta(seconds=3),
|
||||
progress_data=None,
|
||||
),
|
||||
]
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user