Make Client Jobs Iterable (#3762)

* Add iterator

* Break if done

* Add test for early termination
This commit is contained in:
Freddy Boulton 2023-04-05 11:19:05 -07:00 committed by GitHub
parent 7c1db51d95
commit a828b6088c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 33 additions and 0 deletions

View File

@ -559,6 +559,27 @@ class Job(Future):
):
self.future = future
self.communicator = communicator
self._counter = 0
def __iter__(self) -> Job:
return self
def __next__(self) -> Tuple | Any:
if not self.communicator:
raise StopIteration()
with self.communicator.lock:
if self.communicator.job.latest_status.code == Status.FINISHED:
raise StopIteration()
while True:
with self.communicator.lock:
if len(self.communicator.job.outputs) == self._counter + 1:
o = self.communicator.job.outputs[self._counter]
self._counter += 1
return o
if self.communicator.job.latest_status.code == Status.FINISHED:
raise StopIteration()
def outputs(self) -> List[Tuple | Any]:
"""Returns a list containing the latest outputs from the Job.

View File

@ -101,6 +101,18 @@ class TestPredictionsFromSpaces:
assert job.outputs() == [str(i) for i in range(3)]
outputs = []
for o in client.predict(3, api_name="/count"):
outputs.append(o)
assert outputs == [str(i) for i in range(3)]
@pytest.mark.flaky
def test_break_in_loop_if_error(self):
calculator = Client(src="gradio/calculator")
job = calculator.predict("foo", "add", 4, fn_index=0)
output = [o for o in job]
assert output == []
@pytest.mark.flaky
def test_timeout(self):
with pytest.raises(TimeoutError):