Python client cancel jobs (#3787)

* Working impl

* Add tests

* formatting

* Fix typo

* Always reset iterator state

* Add httpx

* Fix test

* Fix test

---------

Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
This commit is contained in:
Freddy Boulton 2023-04-09 07:22:49 -07:00 committed by GitHub
parent 1def8df6dd
commit 72c9636370
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 125 additions and 4 deletions

View File

@ -89,6 +89,7 @@ class Client:
self.src.replace("http", "ws", 1), utils.WS_URL
)
self.upload_url = urllib.parse.urljoin(self.src, utils.UPLOAD_URL)
self.reset_url = urllib.parse.urljoin(self.src, utils.RESET_URL)
self.config = self._get_config()
self.session_hash = str(uuid.uuid4())
@ -157,7 +158,10 @@ class Client:
helper = None
if self.endpoints[inferred_fn_index].use_ws:
helper = Communicator(
Lock(), JobStatus(), self.endpoints[inferred_fn_index].deserialize
Lock(),
JobStatus(),
self.endpoints[inferred_fn_index].deserialize,
self.reset_url,
)
end_to_end_fn = self.endpoints[inferred_fn_index].make_end_to_end_fn(helper)
future = self.executor.submit(end_to_end_fn, *args)
@ -829,6 +833,19 @@ class Job(Future):
>> 43.241 # seconds
"""
time = datetime.now()
cancelled = False
if self.communicator:
with self.communicator.lock:
cancelled = self.communicator.should_cancel
if cancelled:
return StatusUpdate(
code=Status.CANCELLED,
rank=0,
queue_size=None,
success=False,
time=time,
eta=None,
)
if self.done():
if not self.future._exception: # type: ignore
return StatusUpdate(
@ -871,3 +888,24 @@ class Job(Future):
def __getattr__(self, name):
"""Forwards any properties to the Future class."""
return getattr(self.future, name)
def cancel(self) -> bool:
"""Cancels the job as best as possible.
If the app you are connecting to has the gradio queue enabled, the job
will be cancelled locally as soon as possible. For apps that do not use the
queue, the job cannot be cancelled if it's been sent to the local executor
(for the time being).
Note: In general, this DOES not stop the process from running in the upstream server
except for the following situations:
1. If the job is queued upstream, it will be removed from the queue and the server will not run the job
2. If the job has iterative outputs, the job will finish as soon as the current iteration finishes running
3. If the job has not been picked up by the queue yet, the queue will not pick up the job
"""
if self.communicator:
with self.communicator.lock:
self.communicator.should_cancel = True
return True
return self.future.cancel()

View File

@ -1,5 +1,6 @@
from __future__ import annotations
import asyncio
import base64
import json
import mimetypes
@ -7,6 +8,7 @@ import os
import pkgutil
import shutil
import tempfile
from concurrent.futures import CancelledError
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
@ -15,12 +17,14 @@ from threading import Lock
from typing import Any, Callable, Dict, List, Tuple
import fsspec.asyn
import httpx
import requests
from websockets.legacy.protocol import WebSocketCommonProtocol
API_URL = "/api/predict/"
WS_URL = "/queue/join"
UPLOAD_URL = "/upload"
RESET_URL = "/reset"
DUPLICATE_URL = "https://huggingface.co/spaces/{}?duplicate=true"
STATE_COMPONENT = "state"
@ -56,6 +60,7 @@ class Status(Enum):
PROCESSING = "PROCESSSING"
ITERATING = "ITERATING"
FINISHED = "FINISHED"
CANCELLED = "CANCELLED"
@staticmethod
def ordering(status: "Status") -> int:
@ -69,6 +74,7 @@ class Status(Enum):
Status.PROCESSING,
Status.ITERATING,
Status.FINISHED,
Status.CANCELLED,
]
return order.index(status)
@ -130,6 +136,8 @@ class Communicator:
lock: Lock
job: JobStatus
deserialize: Callable[..., Tuple]
reset_url: str
should_cancel: bool = False
########################
@ -157,7 +165,27 @@ async def get_pred_from_ws(
completed = False
resp = {}
while not completed:
msg = await websocket.recv()
# Receive message in the background so that we can
# cancel even while running a long pred
task = asyncio.create_task(websocket.recv())
while not task.done():
if helper:
with helper.lock:
if helper.should_cancel:
# Need to reset the iterator state since the client
# will not reset the session
async with httpx.AsyncClient() as http:
reset = http.post(
helper.reset_url, json=json.loads(hash_data)
)
# Retrieve cancel exception from task
# otherwise will get nasty warning in console
task.cancel()
await asyncio.gather(task, reset, return_exceptions=True)
raise CancelledError()
# Need to suspend this coroutine so that task actually runs
await asyncio.sleep(0.01)
msg = task.result()
resp = json.loads(msg)
if helper:
with helper.lock:

View File

@ -3,4 +3,5 @@ websockets
packaging
fsspec
huggingface_hub>=0.13.0
typing_extensions
typing_extensions
httpx

View File

@ -3,7 +3,7 @@ import os
import pathlib
import tempfile
import time
from concurrent.futures import TimeoutError
from concurrent.futures import CancelledError, TimeoutError
from datetime import datetime, timedelta
from unittest.mock import MagicMock, patch
@ -150,6 +150,60 @@ class TestPredictionsFromSpaces:
)
assert pathlib.Path(job.result()).exists()
@pytest.mark.flaky
def test_cancel_from_client_queued(self):
client = Client(src="gradio-tests/test-cancel-from-client")
start = time.time()
job = client.submit(api_name="/long")
while not job.done():
if job.status().code == Status.STARTING:
job.cancel()
break
with pytest.raises(CancelledError):
job.result()
# The whole prediction takes 10 seconds to run
# and does not iterate. So this tests that we can cancel
# halfway through a prediction
assert time.time() - start < 10
assert job.status().code == Status.CANCELLED
job = client.submit(api_name="/iterate")
iteration_count = 0
while not job.done():
if job.status().code == Status.ITERATING:
iteration_count += 1
if iteration_count == 3:
job.cancel()
break
time.sleep(0.5)
# Result for iterative jobs is always the first result
assert job.result() == 0
# The whole prediction takes 10 seconds to run
# and does not iterate. So this tests that we can cancel
# halfway through a prediction
assert time.time() - start < 10
# Test that we did not iterate all the way to the end
assert all(o in [0, 1, 2, 3, 4, 5] for o in job.outputs())
assert job.status().code == Status.CANCELLED
@pytest.mark.flaky
def test_cancel_subsequent_jobs_state_reset(self):
client = Client("abidlabs/test-yield")
job1 = client.submit("abcdefefadsadfs")
time.sleep(5)
job1.cancel()
assert len(job1.outputs()) < len("abcdefefadsadfs")
assert job1.status().code == Status.CANCELLED
job2 = client.submit("abcd")
while not job2.done():
time.sleep(0.1)
# Ran all iterations from scratch
assert job2.status().code == Status.FINISHED
assert len(job2.outputs()) == 5
def test_upload_file_private_space(self):
client = Client(