From a828b6088cc8d7066ed22e85156a16ec1d61fffe Mon Sep 17 00:00:00 2001 From: Freddy Boulton Date: Wed, 5 Apr 2023 11:19:05 -0700 Subject: [PATCH] Make Client Jobs Iterable (#3762) * Add iterator * Break if done * Add test for early termination --- client/python/gradio_client/client.py | 21 +++++++++++++++++++++ client/python/test/test_client.py | 12 ++++++++++++ 2 files changed, 33 insertions(+) diff --git a/client/python/gradio_client/client.py b/client/python/gradio_client/client.py index 29aec2bf34..736e63a96b 100644 --- a/client/python/gradio_client/client.py +++ b/client/python/gradio_client/client.py @@ -559,6 +559,27 @@ class Job(Future): ): self.future = future self.communicator = communicator + self._counter = 0 + + def __iter__(self) -> Job: + return self + + def __next__(self) -> Tuple | Any: + if not self.communicator: + raise StopIteration() + + with self.communicator.lock: + if self.communicator.job.latest_status.code == Status.FINISHED: + raise StopIteration() + + while True: + with self.communicator.lock: + if len(self.communicator.job.outputs) == self._counter + 1: + o = self.communicator.job.outputs[self._counter] + self._counter += 1 + return o + if self.communicator.job.latest_status.code == Status.FINISHED: + raise StopIteration() def outputs(self) -> List[Tuple | Any]: """Returns a list containing the latest outputs from the Job. diff --git a/client/python/test/test_client.py b/client/python/test/test_client.py index ca253e6e03..a782d96111 100644 --- a/client/python/test/test_client.py +++ b/client/python/test/test_client.py @@ -101,6 +101,18 @@ class TestPredictionsFromSpaces: assert job.outputs() == [str(i) for i in range(3)] + outputs = [] + for o in client.predict(3, api_name="/count"): + outputs.append(o) + assert outputs == [str(i) for i in range(3)] + + @pytest.mark.flaky + def test_break_in_loop_if_error(self): + calculator = Client(src="gradio/calculator") + job = calculator.predict("foo", "add", 4, fn_index=0) + output = [o for o in job] + assert output == [] + @pytest.mark.flaky def test_timeout(self): with pytest.raises(TimeoutError):