Mirror of https://github.com/gradio-app/gradio.git
Get Intermediate Results from Python Client (#3694)
* Add status + unit test (flaky) for now
* Install client
* Fix tests
* Lint backend + tests
* Add non-queue test
* Fix name
* Use lock instead
* Add simplified implementation + fix tests
* Restore changes to scripts
* Fix README typo
* Fix CI
* Add intermediate results to python client
* Type check
* Typecheck again
* Catch exception
* Thinking
* Don't read generator from config
* Add no-queue test
* Remove unused method
* Fix types
* Remove breakpoint
* Fix code
* Fix test
* Fix tests
* Unpack list
* Add docstring

---------

Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
Parent: c4ad09b631
Commit: 9325cba14c
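What the change delivers, seen from the client side: Client.predict now returns a Job whose outputs() can be polled for intermediate generator results while the prediction runs. A minimal usage sketch, assuming the gradio/count_generator Space and /count endpoint used by this PR's own tests:

    import time
    from gradio_client import Client

    client = Client(src="gradio/count_generator")
    job = client.predict(3, api_name="/count")  # now a Job, not a bare Future

    while not job.done():
        print(job.status().code)   # e.g. Status.PROCESSING
        print(job.outputs())       # intermediate results collected so far
        time.sleep(0.1)

    print(job.outputs())           # ["0", "1", "2"], per the test added below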
@@ -5,8 +5,9 @@ import concurrent.futures
 import json
 import re
 import threading
+import time
 import uuid
-from concurrent.futures import Future
+from concurrent.futures import Future, TimeoutError
 from datetime import datetime
 from threading import Lock
 from typing import Any, Callable, Dict, List, Tuple
@@ -76,7 +77,7 @@ class Client:
         api_name: str | None = None,
         fn_index: int | None = None,
         result_callbacks: Callable | List[Callable] | None = None,
-    ) -> Future:
+    ) -> Job:
         """
         Parameters:
             *args: The arguments to pass to the remote API. The order of the arguments must match the order of the inputs in the Gradio app.
@@ -90,7 +91,9 @@ class Client:
 
         helper = None
         if self.endpoints[inferred_fn_index].use_ws:
-            helper = Communicator(Lock(), JobStatus())
+            helper = Communicator(
+                Lock(), JobStatus(), self.endpoints[inferred_fn_index].deserialize
+            )
         end_to_end_fn = self.endpoints[inferred_fn_index].make_end_to_end_fn(helper)
         future = self.executor.submit(end_to_end_fn, *args)
 
@@ -389,19 +392,14 @@ class Endpoint:
                 raise utils.InvalidAPIEndpointError()
             inputs = self.serialize(*data)
             predictions = _predict(*inputs)
-            outputs = self.deserialize(*predictions)
-            if (
-                len(
-                    [
-                        oct
-                        for oct in self.output_component_types
-                        if not oct == utils.STATE_COMPONENT
-                    ]
-                )
-                == 1
-            ):
-                return outputs[0]
-            return outputs
+            output = self.deserialize(*predictions)
+            # Append final output only if not already present
+            # for consistency between generators and not generators
+            if helper:
+                with helper.lock:
+                    if not helper.job.outputs:
+                        helper.job.outputs.append(output)
+            return output
 
         return _inner
 
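The guarded append above is what keeps Job.outputs() consistent across endpoint types: generator endpoints already stream their yields into helper.job.outputs over the websocket, while queued non-generator endpoints would otherwise finish with an empty list. Non-queued endpoints get no Communicator at all, so they take the plain Future path; a sketch of that case, reusing a Space name from this PR's tests (illustrative, not normative):

    from gradio_client import Client

    job = Client(src="freddyaboulton/sentiment-classification").predict(
        "I love the gradio python client", api_name="/classify"
    )
    print(job.result())    # final prediction, via plain Future.result()
    print(job.outputs())   # [] -- capturing intermediates requires the queue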
@@ -461,11 +459,11 @@ class Endpoint:
         ), f"Expected {len(self.serializers)} arguments, got {len(data)}"
         return tuple([s.serialize(d) for s, d in zip(self.serializers, data)])
 
-    def deserialize(self, *data) -> Tuple:
+    def deserialize(self, *data) -> Tuple | Any:
         assert len(data) == len(
             self.deserializers
         ), f"Expected {len(self.deserializers)} outputs, got {len(data)}"
-        return tuple(
+        outputs = tuple(
             [
                 s.deserialize(d, hf_token=self.client.hf_token, root_url=self.root_url)
                 for s, d, oct in zip(
@@ -474,6 +472,20 @@ class Endpoint:
                 if not oct == utils.STATE_COMPONENT
             ]
         )
+        if (
+            len(
+                [
+                    oct
+                    for oct in self.output_component_types
+                    if not oct == utils.STATE_COMPONENT
+                ]
+            )
+            == 1
+        ):
+            output = outputs[0]
+        else:
+            output = outputs
+        return output
 
     def _setup_serializers(self) -> Tuple[List[Serializable], List[Serializable]]:
         inputs = self.dependency["inputs"]
@@ -529,7 +541,10 @@ class Endpoint:
 
     async def _ws_fn(self, data, hash_data, helper: Communicator):
         async with websockets.connect(  # type: ignore
-            self.client.ws_url, open_timeout=10, extra_headers=self.client.headers
+            self.client.ws_url,
+            open_timeout=10,
+            extra_headers=self.client.headers,
+            max_size=1024 * 1024 * 1024,
         ) as websocket:
             return await utils.get_pred_from_ws(websocket, data, hash_data, helper)
 
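The new max_size argument matters because the websockets library rejects incoming frames larger than 1 MiB by default (max_size=2**20), which would break large serialized outputs such as media. A standalone sketch of the same idea, with an illustrative URL (max_size and open_timeout are real websockets.connect parameters):

    import websockets

    async def connect_with_large_frames():
        # default max_size is 2**20 (1 MiB); this raises the ceiling to 1 GiB
        async with websockets.connect(
            "ws://localhost:7860/queue/join",
            open_timeout=10,
            max_size=1024 * 1024 * 1024,
        ) as ws:
            return await ws.recv()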
@@ -537,23 +552,89 @@ class Endpoint:
 class Job(Future):
     """A Job is a thin wrapper over the Future class that can be cancelled."""
 
-    def __init__(self, future: Future, communicator: Communicator | None = None):
+    def __init__(
+        self,
+        future: Future,
+        communicator: Communicator | None = None,
+    ):
         self.future = future
         self.communicator = communicator
 
-    def status(self) -> StatusUpdate:
-        time = datetime.now()
-        if self.done():
-            return StatusUpdate(
-                code=Status.FINISHED,
-                rank=0,
-                queue_size=None,
-                success=None,
-                time=time,
-                eta=None,
-            )
-        else:
-            if not self.communicator:
-                return StatusUpdate(
-                    code=Status.PROCESSING,
-                    rank=0,
+    def outputs(self) -> List[Tuple | Any]:
+        """Returns a list containing the latest outputs from the Job.
+
+        If the endpoint has multiple output components, the list will contain
+        a tuple of results. Otherwise, it will contain the results without storing them
+        in tuples.
+
+        For endpoints that are queued, this list will contain the final job output even
+        if that endpoint does not use a generator function.
+        """
+        if not self.communicator:
+            return []
+        else:
+            with self.communicator.lock:
+                return self.communicator.job.outputs
+
+    def result(self, timeout=None):
+        """Return the result of the call that the future represents.
+
+        Args:
+            timeout: The number of seconds to wait for the result if the future
+                isn't done. If None, then there is no limit on the wait time.
+
+        Returns:
+            The result of the call that the future represents.
+
+        Raises:
+            CancelledError: If the future was cancelled.
+            TimeoutError: If the future didn't finish executing before the given
+                timeout.
+            Exception: If the call raised then that exception will be raised.
+        """
+        if self.communicator:
+            timeout = timeout or float("inf")
+            if self.future._exception:  # type: ignore
+                raise self.future._exception  # type: ignore
+            with self.communicator.lock:
+                if self.communicator.job.outputs:
+                    return self.communicator.job.outputs[0]
+            start = datetime.now()
+            while True:
+                if (datetime.now() - start).seconds > timeout:
+                    raise TimeoutError()
+                if self.future._exception:  # type: ignore
+                    raise self.future._exception  # type: ignore
+                with self.communicator.lock:
+                    if self.communicator.job.outputs:
+                        return self.communicator.job.outputs[0]
+                time.sleep(0.01)
+        else:
+            return super().result(timeout=timeout)
+
+    def status(self) -> StatusUpdate:
+        time = datetime.now()
+        if self.done():
+            if not self.future._exception:  # type: ignore
+                return StatusUpdate(
+                    code=Status.FINISHED,
+                    rank=0,
+                    queue_size=None,
+                    success=True,
+                    time=time,
+                    eta=None,
+                )
+            else:
+                return StatusUpdate(
+                    code=Status.FINISHED,
+                    rank=0,
+                    queue_size=None,
+                    success=False,
+                    time=time,
+                    eta=None,
+                )
+        else:
+            if not self.communicator:
+                return StatusUpdate(
+                    code=Status.PROCESSING,
+                    rank=0,
@@ -562,9 +643,9 @@ class Job(Future):
                     time=time,
                     eta=None,
                 )
             else:
                 with self.communicator.lock:
                     return self.communicator.job.latest_status
 
     def __getattr__(self, name):
         """Forwards any properties to the Future class."""
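Taken together, the new Job behaves like a Future with a progress channel: result() blocks only until the first output exists (for a generator, the first yield), while outputs() accumulates every value, final one included. A sketch against the gradio/count_generator Space used by this PR's tests (timing values illustrative):

    import time
    from gradio_client import Client

    job = Client(src="gradio/count_generator").predict(3, api_name="/count")

    first = job.result(timeout=30)  # returns as soon as the first yield lands
    while not job.done():
        time.sleep(0.1)

    assert first == "0"
    assert job.outputs() == ["0", "1", "2"]  # every yield, exactly once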
@@ -127,6 +127,7 @@ class Communicator:
 
     lock: Lock
     job: JobStatus
+    deserialize: Callable[..., Tuple]
 
 
 ########################
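For context, Communicator is the shared channel between the worker thread and the Job; this change threads the endpoint's deserialize through it so the websocket loop can decode intermediate payloads. A sketch of the resulting shape (the @dataclass decorators and the JobStatus fields are inferred from how the diff uses them, not shown in it):

    from dataclasses import dataclass, field
    from threading import Lock
    from typing import Any, Callable, List, Tuple

    @dataclass
    class JobStatus:                           # inferred shape, for illustration
        latest_status: Any = None
        outputs: List[Any] = field(default_factory=list)

    @dataclass
    class Communicator:
        lock: Lock                             # guards `job` across threads
        job: JobStatus                         # mutable state for one prediction
        deserialize: Callable[..., Tuple]      # endpoint decoder (the new field)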
@@ -166,6 +167,13 @@ async def get_pred_from_ws(
                     time=datetime.now(),
                     eta=resp.get("rank_eta"),
                 )
+                output = resp.get("output", {}).get("data", [])
+                if output and status_update.code != Status.FINISHED:
+                    try:
+                        result = helper.deserialize(*output)
+                    except Exception as e:
+                        result = [e]
+                    helper.job.outputs.append(result)
                 helper.job.latest_status = status_update
         if resp["msg"] == "queue_full":
             raise QueueError("Queue is full! Please try again.")
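The fields read in the block above imply the rough shape of a queue websocket message; an illustrative example, with the schema inferred from this function rather than from any documented contract:

    # roughly what an intermediate update from a generator looks like:
    resp = {
        "msg": "process_generating",            # message type (assumed name)
        "rank_eta": 1.5,
        "output": {"data": ["0"]},              # serialized component values
    }
    output = resp.get("output", {}).get("data", [])  # -> ["0"]
    # deserialized via helper.deserialize and appended to helper.job.outputs
    # while the job is still running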
@@ -2,6 +2,7 @@ import json
 import os
 import pathlib
 import time
+from concurrent.futures import TimeoutError
 from datetime import datetime, timedelta
 from unittest.mock import patch
 
@@ -81,7 +82,7 @@ class TestPredictionsFromSpaces:
     def test_job_status_queue_disabled(self):
         statuses = []
         client = Client(src="freddyaboulton/sentiment-classification")
-        job = client.predict("I love the gradio python client", fn_index=0)
+        job = client.predict("I love the gradio python client", api_name="/classify")
         while not job.done():
             time.sleep(0.02)
             statuses.append(job.status())
@@ -89,6 +90,45 @@ class TestPredictionsFromSpaces:
         assert all(s.code in [Status.PROCESSING, Status.FINISHED] for s in statuses)
 
+    @pytest.mark.flaky
+    def test_intermediate_outputs(
+        self,
+    ):
+        client = Client(src="gradio/count_generator")
+        job = client.predict(3, api_name="/count")
+
+        while not job.done():
+            time.sleep(0.1)
+
+        assert job.outputs() == [str(i) for i in range(3)]
+
+    @pytest.mark.flaky
+    def test_timeout(self):
+        with pytest.raises(TimeoutError):
+            client = Client(src="gradio/count_generator")
+            job = client.predict(api_name="/sleep")
+            job.result(timeout=0.05)
+
+    @pytest.mark.flaky
+    def test_timeout_no_queue(self):
+        with pytest.raises(TimeoutError):
+            client = Client(src="freddyaboulton/sentiment-classification")
+            job = client.predict(api_name="/sleep")
+            job.result(timeout=0.1)
+
+    @pytest.mark.flaky
+    def test_raises_exception(self):
+        with pytest.raises(Exception):
+            client = Client(src="freddyaboulton/calculator")
+            job = client.predict("foo", "add", 9, fn_index=0)
+            job.result()
+
+    @pytest.mark.flaky
+    def test_raises_exception_no_queue(self):
+        with pytest.raises(Exception):
+            client = Client(src="freddyaboulton/sentiment-classification")
+            job = client.predict([5], api_name="/sleep")
+            job.result()
+
     def test_job_output_video(self):
         client = Client(src="gradio/video_component")
         job = client.predict(