mirror of
https://github.com/gradio-app/gradio.git
synced 2025-03-07 11:46:51 +08:00
Python client cancel jobs (#3787)
* Working impl * Add tests: * formatting * Fix typo * Always reset iterator state * Add httpx * Fix test * Fix test --------- Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
This commit is contained in:
parent
1def8df6dd
commit
72c9636370
@ -89,6 +89,7 @@ class Client:
|
||||
self.src.replace("http", "ws", 1), utils.WS_URL
|
||||
)
|
||||
self.upload_url = urllib.parse.urljoin(self.src, utils.UPLOAD_URL)
|
||||
self.reset_url = urllib.parse.urljoin(self.src, utils.RESET_URL)
|
||||
self.config = self._get_config()
|
||||
self.session_hash = str(uuid.uuid4())
|
||||
|
||||
@ -157,7 +158,10 @@ class Client:
|
||||
helper = None
|
||||
if self.endpoints[inferred_fn_index].use_ws:
|
||||
helper = Communicator(
|
||||
Lock(), JobStatus(), self.endpoints[inferred_fn_index].deserialize
|
||||
Lock(),
|
||||
JobStatus(),
|
||||
self.endpoints[inferred_fn_index].deserialize,
|
||||
self.reset_url,
|
||||
)
|
||||
end_to_end_fn = self.endpoints[inferred_fn_index].make_end_to_end_fn(helper)
|
||||
future = self.executor.submit(end_to_end_fn, *args)
|
||||
@ -829,6 +833,19 @@ class Job(Future):
|
||||
>> 43.241 # seconds
|
||||
"""
|
||||
time = datetime.now()
|
||||
cancelled = False
|
||||
if self.communicator:
|
||||
with self.communicator.lock:
|
||||
cancelled = self.communicator.should_cancel
|
||||
if cancelled:
|
||||
return StatusUpdate(
|
||||
code=Status.CANCELLED,
|
||||
rank=0,
|
||||
queue_size=None,
|
||||
success=False,
|
||||
time=time,
|
||||
eta=None,
|
||||
)
|
||||
if self.done():
|
||||
if not self.future._exception: # type: ignore
|
||||
return StatusUpdate(
|
||||
@ -871,3 +888,24 @@ class Job(Future):
|
||||
def __getattr__(self, name):
|
||||
"""Forwards any properties to the Future class."""
|
||||
return getattr(self.future, name)
|
||||
|
||||
def cancel(self) -> bool:
|
||||
"""Cancels the job as best as possible.
|
||||
|
||||
If the app you are connecting to has the gradio queue enabled, the job
|
||||
will be cancelled locally as soon as possible. For apps that do not use the
|
||||
queue, the job cannot be cancelled if it's been sent to the local executor
|
||||
(for the time being).
|
||||
|
||||
Note: In general, this DOES not stop the process from running in the upstream server
|
||||
except for the following situations:
|
||||
|
||||
1. If the job is queued upstream, it will be removed from the queue and the server will not run the job
|
||||
2. If the job has iterative outputs, the job will finish as soon as the current iteration finishes running
|
||||
3. If the job has not been picked up by the queue yet, the queue will not pick up the job
|
||||
"""
|
||||
if self.communicator:
|
||||
with self.communicator.lock:
|
||||
self.communicator.should_cancel = True
|
||||
return True
|
||||
return self.future.cancel()
|
||||
|
@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import mimetypes
|
||||
@ -7,6 +8,7 @@ import os
|
||||
import pkgutil
|
||||
import shutil
|
||||
import tempfile
|
||||
from concurrent.futures import CancelledError
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
@ -15,12 +17,14 @@ from threading import Lock
|
||||
from typing import Any, Callable, Dict, List, Tuple
|
||||
|
||||
import fsspec.asyn
|
||||
import httpx
|
||||
import requests
|
||||
from websockets.legacy.protocol import WebSocketCommonProtocol
|
||||
|
||||
API_URL = "/api/predict/"
|
||||
WS_URL = "/queue/join"
|
||||
UPLOAD_URL = "/upload"
|
||||
RESET_URL = "/reset"
|
||||
DUPLICATE_URL = "https://huggingface.co/spaces/{}?duplicate=true"
|
||||
STATE_COMPONENT = "state"
|
||||
|
||||
@ -56,6 +60,7 @@ class Status(Enum):
|
||||
PROCESSING = "PROCESSSING"
|
||||
ITERATING = "ITERATING"
|
||||
FINISHED = "FINISHED"
|
||||
CANCELLED = "CANCELLED"
|
||||
|
||||
@staticmethod
|
||||
def ordering(status: "Status") -> int:
|
||||
@ -69,6 +74,7 @@ class Status(Enum):
|
||||
Status.PROCESSING,
|
||||
Status.ITERATING,
|
||||
Status.FINISHED,
|
||||
Status.CANCELLED,
|
||||
]
|
||||
return order.index(status)
|
||||
|
||||
@ -130,6 +136,8 @@ class Communicator:
|
||||
lock: Lock
|
||||
job: JobStatus
|
||||
deserialize: Callable[..., Tuple]
|
||||
reset_url: str
|
||||
should_cancel: bool = False
|
||||
|
||||
|
||||
########################
|
||||
@ -157,7 +165,27 @@ async def get_pred_from_ws(
|
||||
completed = False
|
||||
resp = {}
|
||||
while not completed:
|
||||
msg = await websocket.recv()
|
||||
# Receive message in the background so that we can
|
||||
# cancel even while running a long pred
|
||||
task = asyncio.create_task(websocket.recv())
|
||||
while not task.done():
|
||||
if helper:
|
||||
with helper.lock:
|
||||
if helper.should_cancel:
|
||||
# Need to reset the iterator state since the client
|
||||
# will not reset the session
|
||||
async with httpx.AsyncClient() as http:
|
||||
reset = http.post(
|
||||
helper.reset_url, json=json.loads(hash_data)
|
||||
)
|
||||
# Retrieve cancel exception from task
|
||||
# otherwise will get nasty warning in console
|
||||
task.cancel()
|
||||
await asyncio.gather(task, reset, return_exceptions=True)
|
||||
raise CancelledError()
|
||||
# Need to suspend this coroutine so that task actually runs
|
||||
await asyncio.sleep(0.01)
|
||||
msg = task.result()
|
||||
resp = json.loads(msg)
|
||||
if helper:
|
||||
with helper.lock:
|
||||
|
@ -4,3 +4,4 @@ packaging
|
||||
fsspec
|
||||
huggingface_hub>=0.13.0
|
||||
typing_extensions
|
||||
httpx
|
@ -3,7 +3,7 @@ import os
|
||||
import pathlib
|
||||
import tempfile
|
||||
import time
|
||||
from concurrent.futures import TimeoutError
|
||||
from concurrent.futures import CancelledError, TimeoutError
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
@ -150,6 +150,60 @@ class TestPredictionsFromSpaces:
|
||||
)
|
||||
assert pathlib.Path(job.result()).exists()
|
||||
|
||||
@pytest.mark.flaky
|
||||
def test_cancel_from_client_queued(self):
|
||||
client = Client(src="gradio-tests/test-cancel-from-client")
|
||||
start = time.time()
|
||||
job = client.submit(api_name="/long")
|
||||
while not job.done():
|
||||
if job.status().code == Status.STARTING:
|
||||
job.cancel()
|
||||
break
|
||||
with pytest.raises(CancelledError):
|
||||
job.result()
|
||||
# The whole prediction takes 10 seconds to run
|
||||
# and does not iterate. So this tests that we can cancel
|
||||
# halfway through a prediction
|
||||
assert time.time() - start < 10
|
||||
assert job.status().code == Status.CANCELLED
|
||||
|
||||
job = client.submit(api_name="/iterate")
|
||||
iteration_count = 0
|
||||
while not job.done():
|
||||
if job.status().code == Status.ITERATING:
|
||||
iteration_count += 1
|
||||
if iteration_count == 3:
|
||||
job.cancel()
|
||||
break
|
||||
time.sleep(0.5)
|
||||
# Result for iterative jobs is always the first result
|
||||
assert job.result() == 0
|
||||
# The whole prediction takes 10 seconds to run
|
||||
# and does not iterate. So this tests that we can cancel
|
||||
# halfway through a prediction
|
||||
assert time.time() - start < 10
|
||||
|
||||
# Test that we did not iterate all the way to the end
|
||||
assert all(o in [0, 1, 2, 3, 4, 5] for o in job.outputs())
|
||||
assert job.status().code == Status.CANCELLED
|
||||
|
||||
@pytest.mark.flaky
|
||||
def test_cancel_subsequent_jobs_state_reset(self):
|
||||
client = Client("abidlabs/test-yield")
|
||||
job1 = client.submit("abcdefefadsadfs")
|
||||
time.sleep(5)
|
||||
job1.cancel()
|
||||
|
||||
assert len(job1.outputs()) < len("abcdefefadsadfs")
|
||||
assert job1.status().code == Status.CANCELLED
|
||||
|
||||
job2 = client.submit("abcd")
|
||||
while not job2.done():
|
||||
time.sleep(0.1)
|
||||
# Ran all iterations from scratch
|
||||
assert job2.status().code == Status.FINISHED
|
||||
assert len(job2.outputs()) == 5
|
||||
|
||||
def test_upload_file_private_space(self):
|
||||
|
||||
client = Client(
|
||||
|
Loading…
Reference in New Issue
Block a user