mirror of
https://github.com/gradio-app/gradio.git
synced 2025-04-18 12:50:30 +08:00
Return final output for generators in Client.predict (#5057)
* Add code * add changeset * add changeset * Add feat changeset * Fix js code snippet * Fix changelog * Add test * Delete code * Lint * Make fix for python client * Make JS client changes * Add submit line to error * Changelog * Lint * Undo * Add dual handling --------- Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
parent
667875b244
commit
35856f8b54
.changeset
.config
.github/workflows
client
scripts
17
.changeset/two-spies-shine.md
Normal file
17
.changeset/two-spies-shine.md
Normal file
@ -0,0 +1,17 @@
|
||||
---
|
||||
"@gradio/app": minor
|
||||
"gradio": minor
|
||||
"gradio_client": minor
|
||||
---
|
||||
|
||||
highlight:
|
||||
|
||||
#### Client.predict will now return the final output for streaming endpoints
|
||||
|
||||
### This is a breaking change (for gradio_client only)!
|
||||
|
||||
Previously, `Client.predict` would only return the first output of an endpoint that streamed results. This was causing confusion for developers that wanted to call these streaming demos via the client.
|
||||
|
||||
We realize that developers using the client don't know the internals of whether a demo streams or not, so we're changing the behavior of predict to match developer expectations.
|
||||
|
||||
Using `Client.predict` will now return the final output of a streaming endpoint. This will make it even easier to use gradio apps via the client.
|
@ -3,5 +3,5 @@
|
||||
"singleQuote": false,
|
||||
"trailingComma": "none",
|
||||
"printWidth": 80,
|
||||
"plugins": ["prettier-plugin-svelte"]
|
||||
"plugins": ["prettier-plugin-svelte", "prettier-plugin-css-order"]
|
||||
}
|
||||
|
2
.github/workflows/backend.yml
vendored
2
.github/workflows/backend.yml
vendored
@ -170,7 +170,7 @@ jobs:
|
||||
if: steps.cache.outputs.cache-hit != 'true' && runner.os == 'Linux'
|
||||
run: |
|
||||
. venv/bin/activate
|
||||
python -m pip install -e . -r test/requirements.txt
|
||||
bash scripts/install_test_requirements.sh
|
||||
- name: Install ffmpeg
|
||||
uses: FedericoCarboni/setup-ffmpeg@v2
|
||||
- name: Lint (Linux)
|
||||
|
@ -341,24 +341,43 @@ export function api_factory(fetch_implementation: typeof fetch): Client {
|
||||
): Promise<unknown> {
|
||||
let data_returned = false;
|
||||
let status_complete = false;
|
||||
let dependency;
|
||||
if (typeof endpoint === "number") {
|
||||
dependency = config.dependencies[endpoint];
|
||||
} else {
|
||||
const trimmed_endpoint = endpoint.replace(/^\//, "");
|
||||
dependency = config.dependencies[api_map[trimmed_endpoint]];
|
||||
}
|
||||
|
||||
if (dependency.types.continuous) {
|
||||
throw new Error(
|
||||
"Cannot call predict on this function as it may run forever. Use submit instead"
|
||||
);
|
||||
}
|
||||
|
||||
return new Promise((res, rej) => {
|
||||
const app = submit(endpoint, data, event_data);
|
||||
let result;
|
||||
|
||||
app
|
||||
.on("data", (d) => {
|
||||
data_returned = true;
|
||||
// if complete message comes before data, resolve here
|
||||
if (status_complete) {
|
||||
app.destroy();
|
||||
res(d);
|
||||
}
|
||||
res(d);
|
||||
data_returned = true;
|
||||
result = d;
|
||||
})
|
||||
.on("status", (status) => {
|
||||
if (status.stage === "error") rej(status);
|
||||
if (status.stage === "complete" && data_returned) {
|
||||
app.destroy();
|
||||
}
|
||||
if (status.stage === "complete") {
|
||||
status_complete = true;
|
||||
app.destroy();
|
||||
// if complete message comes after data, resolve here
|
||||
if (data_returned) {
|
||||
res(result);
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
|
@ -12,7 +12,7 @@ import time
|
||||
import urllib.parse
|
||||
import uuid
|
||||
import warnings
|
||||
from concurrent.futures import Future, TimeoutError
|
||||
from concurrent.futures import Future
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from threading import Lock
|
||||
@ -283,6 +283,11 @@ class Client:
|
||||
client.predict(5, "add", 4, api_name="/predict")
|
||||
>> 9.0
|
||||
"""
|
||||
inferred_fn_index = self._infer_fn_index(api_name, fn_index)
|
||||
if self.endpoints[inferred_fn_index].is_continuous:
|
||||
raise ValueError(
|
||||
"Cannot call predict on this function as it may run forever. Use submit instead."
|
||||
)
|
||||
return self.submit(*args, api_name=api_name, fn_index=fn_index).result()
|
||||
|
||||
def submit(
|
||||
@ -761,6 +766,7 @@ class Endpoint:
|
||||
self.input_component_types = []
|
||||
self.output_component_types = []
|
||||
self.root_url = client.src + "/" if not client.src.endswith("/") else client.src
|
||||
self.is_continuous = dependency.get("types", {}).get("continuous", False)
|
||||
try:
|
||||
# Only a real API endpoint if backend_fn is True (so not just a frontend function), serializers are valid,
|
||||
# and api_name is not False (meaning that the developer has explicitly disabled the API endpoint)
|
||||
@ -1103,7 +1109,7 @@ class Job(Future):
|
||||
Parameters:
|
||||
timeout: The number of seconds to wait for the result if the future isn't done. If None, then there is no limit on the wait time.
|
||||
Returns:
|
||||
The result of the call that the future represents.
|
||||
The result of the call that the future represents. For generator functions, it will return the final iteration.
|
||||
Example:
|
||||
from gradio_client import Client
|
||||
calculator = Client(src="gradio/calculator")
|
||||
@ -1111,25 +1117,7 @@ class Job(Future):
|
||||
job.result(timeout=5)
|
||||
>> 9
|
||||
"""
|
||||
if self.communicator:
|
||||
timeout = timeout or float("inf")
|
||||
if self.future._exception: # type: ignore
|
||||
raise self.future._exception # type: ignore
|
||||
with self.communicator.lock:
|
||||
if self.communicator.job.outputs:
|
||||
return self.communicator.job.outputs[0]
|
||||
start = datetime.now()
|
||||
while True:
|
||||
if (datetime.now() - start).seconds > timeout:
|
||||
raise TimeoutError()
|
||||
if self.future._exception: # type: ignore
|
||||
raise self.future._exception # type: ignore
|
||||
with self.communicator.lock:
|
||||
if self.communicator.job.outputs:
|
||||
return self.communicator.job.outputs[0]
|
||||
time.sleep(0.01)
|
||||
else:
|
||||
return super().result(timeout=timeout)
|
||||
return super().result(timeout=timeout)
|
||||
|
||||
def outputs(self) -> list[tuple | Any]:
|
||||
"""
|
||||
|
@ -178,6 +178,32 @@ def count_generator_demo():
|
||||
return demo.queue()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def count_generator_demo_exception():
|
||||
def count(n):
|
||||
for i in range(int(n)):
|
||||
time.sleep(0.1)
|
||||
if i == 5:
|
||||
raise ValueError("Oh no!")
|
||||
yield i
|
||||
|
||||
def show(n):
|
||||
return str(list(range(int(n))))
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
with gr.Column():
|
||||
num = gr.Number(value=10)
|
||||
with gr.Row():
|
||||
count_btn = gr.Button("Count")
|
||||
count_forever = gr.Button("Count forever")
|
||||
with gr.Column():
|
||||
out = gr.Textbox()
|
||||
|
||||
count_btn.click(count, num, out, api_name="count")
|
||||
count_forever.click(show, num, out, api_name="count_forever", every=3)
|
||||
return demo.queue()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def file_io_demo():
|
||||
demo = gr.Interface(
|
||||
|
@ -133,6 +133,17 @@ class TestClientPredictions:
|
||||
outputs.append(o)
|
||||
assert outputs == [str(i) for i in range(3)]
|
||||
|
||||
@pytest.mark.flaky
|
||||
def test_intermediate_outputs_with_exception(self, count_generator_demo_exception):
|
||||
with connect(count_generator_demo_exception) as client:
|
||||
with pytest.raises(Exception):
|
||||
client.predict(7, api_name="/count")
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match="Cannot call predict on this function"
|
||||
):
|
||||
client.predict(5, api_name="/count_forever")
|
||||
|
||||
def test_break_in_loop_if_error(self, calculator_demo):
|
||||
with connect(calculator_demo) as client:
|
||||
job = client.submit("foo", "add", 4, fn_index=0)
|
||||
@ -229,8 +240,9 @@ class TestClientPredictions:
|
||||
job.cancel()
|
||||
break
|
||||
time.sleep(0.5)
|
||||
# Result for iterative jobs is always the first result
|
||||
assert job.result() == 0
|
||||
# Result for iterative jobs will raise there is an exception
|
||||
with pytest.raises(CancelledError):
|
||||
job.result()
|
||||
# The whole prediction takes 10 seconds to run
|
||||
# and does not iterate. So this tests that we can cancel
|
||||
# halfway through a prediction
|
||||
|
@ -7,5 +7,5 @@ pip_required
|
||||
|
||||
echo "Installing requirements before running tests..."
|
||||
pip install --upgrade pip
|
||||
pip install -r requirements.txt
|
||||
pip install -r test/requirements.txt
|
||||
pip install -e client/python
|
||||
|
Loading…
x
Reference in New Issue
Block a user