mirror of
https://github.com/gradio-app/gradio.git
synced 2025-01-18 10:44:33 +08:00
Make Client Jobs Iterable (#3762)
* Add iterator * Break if done * Add test for early termination
This commit is contained in:
parent
7c1db51d95
commit
a828b6088c
@ -559,6 +559,27 @@ class Job(Future):
|
||||
):
|
||||
self.future = future
|
||||
self.communicator = communicator
|
||||
self._counter = 0
|
||||
|
||||
def __iter__(self) -> Job:
|
||||
return self
|
||||
|
||||
def __next__(self) -> Tuple | Any:
|
||||
if not self.communicator:
|
||||
raise StopIteration()
|
||||
|
||||
with self.communicator.lock:
|
||||
if self.communicator.job.latest_status.code == Status.FINISHED:
|
||||
raise StopIteration()
|
||||
|
||||
while True:
|
||||
with self.communicator.lock:
|
||||
if len(self.communicator.job.outputs) == self._counter + 1:
|
||||
o = self.communicator.job.outputs[self._counter]
|
||||
self._counter += 1
|
||||
return o
|
||||
if self.communicator.job.latest_status.code == Status.FINISHED:
|
||||
raise StopIteration()
|
||||
|
||||
def outputs(self) -> List[Tuple | Any]:
|
||||
"""Returns a list containing the latest outputs from the Job.
|
||||
|
@ -101,6 +101,18 @@ class TestPredictionsFromSpaces:
|
||||
|
||||
assert job.outputs() == [str(i) for i in range(3)]
|
||||
|
||||
outputs = []
|
||||
for o in client.predict(3, api_name="/count"):
|
||||
outputs.append(o)
|
||||
assert outputs == [str(i) for i in range(3)]
|
||||
|
||||
@pytest.mark.flaky
|
||||
def test_break_in_loop_if_error(self):
|
||||
calculator = Client(src="gradio/calculator")
|
||||
job = calculator.predict("foo", "add", 4, fn_index=0)
|
||||
output = [o for o in job]
|
||||
assert output == []
|
||||
|
||||
@pytest.mark.flaky
|
||||
def test_timeout(self):
|
||||
with pytest.raises(TimeoutError):
|
||||
|
Loading…
Reference in New Issue
Block a user