2
0
mirror of https://github.com/gradio-app/gradio.git synced 2025-04-18 12:50:30 +08:00

Return final output for generators in `Client.predict`

* Add code

* add changeset

* add changeset

* Add feat changeset

* Fix js code snippet

* Fix changelog

* Add test

* Delete code

* Lint

* Make fix for python client

* Make JS client changes

* Add submit line to error

* Changelog

* Lint

* Undo

* Add dual handling

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
Freddy Boulton 2023-08-09 13:03:56 -04:00 committed by GitHub
parent 667875b244
commit 35856f8b54
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 93 additions and 31 deletions

@ -0,0 +1,17 @@
---
"@gradio/app": minor
"gradio": minor
"gradio_client": minor
---
highlight:
#### Client.predict will now return the final output for streaming endpoints
### This is a breaking change (for gradio_client only)!
Previously, `Client.predict` would only return the first output of an endpoint that streamed results. This was causing confusion for developers who wanted to call these streaming demos via the client.
We realize that developers using the client don't know the internals of whether a demo streams or not, so we're changing the behavior of predict to match developer expectations.
Using `Client.predict` will now return the final output of a streaming endpoint. This will make it even easier to use gradio apps via the client.

@ -3,5 +3,5 @@
"singleQuote": false,
"trailingComma": "none",
"printWidth": 80,
"plugins": ["prettier-plugin-svelte"]
"plugins": ["prettier-plugin-svelte", "prettier-plugin-css-order"]
}

@ -170,7 +170,7 @@ jobs:
if: steps.cache.outputs.cache-hit != 'true' && runner.os == 'Linux'
run: |
. venv/bin/activate
python -m pip install -e . -r test/requirements.txt
bash scripts/install_test_requirements.sh
- name: Install ffmpeg
uses: FedericoCarboni/setup-ffmpeg@v2
- name: Lint (Linux)

@ -341,24 +341,43 @@ export function api_factory(fetch_implementation: typeof fetch): Client {
): Promise<unknown> {
let data_returned = false;
let status_complete = false;
let dependency;
if (typeof endpoint === "number") {
dependency = config.dependencies[endpoint];
} else {
const trimmed_endpoint = endpoint.replace(/^\//, "");
dependency = config.dependencies[api_map[trimmed_endpoint]];
}
if (dependency.types.continuous) {
throw new Error(
"Cannot call predict on this function as it may run forever. Use submit instead"
);
}
return new Promise((res, rej) => {
const app = submit(endpoint, data, event_data);
let result;
app
.on("data", (d) => {
data_returned = true;
// if complete message comes before data, resolve here
if (status_complete) {
app.destroy();
res(d);
}
res(d);
data_returned = true;
result = d;
})
.on("status", (status) => {
if (status.stage === "error") rej(status);
if (status.stage === "complete" && data_returned) {
app.destroy();
}
if (status.stage === "complete") {
status_complete = true;
app.destroy();
// if complete message comes after data, resolve here
if (data_returned) {
res(result);
}
}
});
});

@ -12,7 +12,7 @@ import time
import urllib.parse
import uuid
import warnings
from concurrent.futures import Future, TimeoutError
from concurrent.futures import Future
from datetime import datetime
from pathlib import Path
from threading import Lock
@ -283,6 +283,11 @@ class Client:
client.predict(5, "add", 4, api_name="/predict")
>> 9.0
"""
inferred_fn_index = self._infer_fn_index(api_name, fn_index)
if self.endpoints[inferred_fn_index].is_continuous:
raise ValueError(
"Cannot call predict on this function as it may run forever. Use submit instead."
)
return self.submit(*args, api_name=api_name, fn_index=fn_index).result()
def submit(
@ -761,6 +766,7 @@ class Endpoint:
self.input_component_types = []
self.output_component_types = []
self.root_url = client.src + "/" if not client.src.endswith("/") else client.src
self.is_continuous = dependency.get("types", {}).get("continuous", False)
try:
# Only a real API endpoint if backend_fn is True (so not just a frontend function), serializers are valid,
# and api_name is not False (meaning that the developer has explicitly disabled the API endpoint)
@ -1103,7 +1109,7 @@ class Job(Future):
Parameters:
timeout: The number of seconds to wait for the result if the future isn't done. If None, then there is no limit on the wait time.
Returns:
The result of the call that the future represents.
The result of the call that the future represents. For generator functions, it will return the final iteration.
Example:
from gradio_client import Client
calculator = Client(src="gradio/calculator")
@ -1111,25 +1117,7 @@ class Job(Future):
job.result(timeout=5)
>> 9
"""
if self.communicator:
timeout = timeout or float("inf")
if self.future._exception: # type: ignore
raise self.future._exception # type: ignore
with self.communicator.lock:
if self.communicator.job.outputs:
return self.communicator.job.outputs[0]
start = datetime.now()
while True:
if (datetime.now() - start).seconds > timeout:
raise TimeoutError()
if self.future._exception: # type: ignore
raise self.future._exception # type: ignore
with self.communicator.lock:
if self.communicator.job.outputs:
return self.communicator.job.outputs[0]
time.sleep(0.01)
else:
return super().result(timeout=timeout)
return super().result(timeout=timeout)
def outputs(self) -> list[tuple | Any]:
"""

@ -178,6 +178,32 @@ def count_generator_demo():
return demo.queue()
@pytest.fixture
def count_generator_demo_exception():
    """Demo with a generator endpoint that raises mid-stream plus a
    continuous (`every=3`) endpoint, used to test error propagation and
    the predict-on-continuous-endpoint guard in the client.
    """

    def count(n):
        # Yields 0..n-1 slowly, raising partway through so tests can
        # verify that errors from a streaming endpoint reach the client.
        for i in range(int(n)):
            time.sleep(0.1)
            if i == 5:
                raise ValueError("Oh no!")
            yield i

    def show(n):
        return str(list(range(int(n))))

    with gr.Blocks() as demo:
        with gr.Column():
            num = gr.Number(value=10)
            with gr.Row():
                count_btn = gr.Button("Count")
                count_forever = gr.Button("Count forever")
        with gr.Column():
            out = gr.Textbox()

        # Event wiring must happen inside the Blocks context.
        count_btn.click(count, num, out, api_name="count")
        # every=3 re-runs `show` continuously, making this endpoint
        # unsuitable for Client.predict (it would never complete).
        count_forever.click(show, num, out, api_name="count_forever", every=3)

    # Queue is required for generator / continuous endpoints.
    return demo.queue()
@pytest.fixture
def file_io_demo():
demo = gr.Interface(

@ -133,6 +133,17 @@ class TestClientPredictions:
outputs.append(o)
assert outputs == [str(i) for i in range(3)]
@pytest.mark.flaky
def test_intermediate_outputs_with_exception(self, count_generator_demo_exception):
    """An exception raised mid-stream by a generator endpoint propagates to
    the caller, and calling predict on a continuous (`every=...`) endpoint
    is rejected up front with a ValueError.
    """
    with connect(count_generator_demo_exception) as client:
        # /count raises ValueError("Oh no!") once the generator reaches i == 5.
        with pytest.raises(Exception):
            client.predict(7, api_name="/count")

        # /count_forever runs with every=3 and would never complete, so
        # predict refuses it and tells the caller to use submit instead.
        with pytest.raises(
            ValueError, match="Cannot call predict on this function"
        ):
            client.predict(5, api_name="/count_forever")
def test_break_in_loop_if_error(self, calculator_demo):
with connect(calculator_demo) as client:
job = client.submit("foo", "add", 4, fn_index=0)
@ -229,8 +240,9 @@ class TestClientPredictions:
job.cancel()
break
time.sleep(0.5)
# Result for iterative jobs is always the first result
assert job.result() == 0
# Result for iterative jobs will raise if there is an exception
with pytest.raises(CancelledError):
job.result()
# The whole prediction takes 10 seconds to run
# and does not iterate. So this tests that we can cancel
# halfway through a prediction

@ -7,5 +7,5 @@ pip_required
echo "Installing requirements before running tests..."
pip install --upgrade pip
pip install -r requirements.txt
pip install -r test/requirements.txt
pip install -e client/python