Support the use of custom authentication mechanism, timeouts, and other httpx parameters in Python Client (#8862)

* gradio Client now supports the use of custom authentication mechanisms with httpx
* Fix formatting issues
* Replace the specific parameter `httpx_auth` with a more general parameter `httpx_kwargs`
* add changeset
* typing
* future

Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
commit ac132e3cbc
parent 7f1a78c49e

.changeset/clean-eagles-taste.md (new file, +6 lines)
@@ -0,0 +1,6 @@
+---
+"gradio": minor
+"gradio_client": minor
+---
+
+feat:Support the use of custom authentication mechanism, timeouts, and other `httpx` parameters in Python Client
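To make the scope of the change concrete, here is a hedged usage sketch of the new parameter; the URL, credentials, and timeout below are placeholders rather than values taken from this commit.

import httpx
from gradio_client import Client

# Any keyword accepted by httpx.Client / httpx.get / httpx.post / httpx.stream
# can be forwarded through `httpx_kwargs`; it applies to every request the
# client makes (config fetch, predictions, file downloads, heartbeats).
client = Client(
    "http://localhost:7860",  # placeholder: a locally running Gradio app
    httpx_kwargs={
        "timeout": httpx.Timeout(30.0),  # per-request timeout
        "auth": ("user", "pass"),        # HTTP basic auth, or any httpx.Auth instance
    },
)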
@@ -79,6 +79,7 @@ class Client:
         max_workers: int = 40,
         verbose: bool = True,
         auth: tuple[str, str] | None = None,
+        httpx_kwargs: dict[str, Any] | None = None,
         *,
         headers: dict[str, str] | None = None,
         download_files: str | Path | Literal[False] = DEFAULT_TEMP_DIR,
@@ -94,6 +95,7 @@ class Client:
             headers: Additional headers to send to the remote Gradio app on every request. By default only the HF authorization and user-agent headers are sent. This parameter will override the default headers if they have the same keys.
             download_files: Directory where the client should download output files on the local machine from the remote API. By default, uses the value of the GRADIO_TEMP_DIR environment variable which, if not set by the user, is a temporary directory on your machine. If False, the client does not download files and returns a FileData dataclass object with the filepath on the remote machine instead.
             ssl_verify: If False, skips certificate validation which allows the client to connect to Gradio apps that are using self-signed certificates.
+            httpx_kwargs: Additional keyword arguments to pass to `httpx.Client`, `httpx.stream`, `httpx.get` and `httpx.post`. This can be used to set timeouts, proxies, http auth, etc.
         """
         self.verbose = verbose
         self.hf_token = hf_token
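Because `httpx_kwargs` is forwarded verbatim, a custom authentication mechanism can be supplied as an `httpx.Auth` implementation under the `auth` key. A minimal sketch, with a made-up header name and token:

import httpx
from gradio_client import Client

class TokenHeaderAuth(httpx.Auth):
    # Illustrative httpx.Auth subclass that attaches a custom header to every request.
    def __init__(self, token: str):
        self.token = token

    def auth_flow(self, request: httpx.Request):
        # auth_flow is the hook httpx calls for each outgoing request.
        request.headers["X-Api-Token"] = self.token
        yield request

client = Client(
    "http://localhost:7860",  # placeholder URL
    httpx_kwargs={"auth": TokenHeaderAuth("my-secret-token")},
)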
@@ -143,6 +145,7 @@ class Client:
         if self.verbose:
             print(f"Loaded as API: {self.src} ✔")

+        self.httpx_kwargs = {} if httpx_kwargs is None else httpx_kwargs
         if auth is not None:
             self._login(auth)

@@ -202,13 +205,15 @@ class Client:
         while True:
             url = self.heartbeat_url.format(session_hash=self.session_hash)
             try:
+                httpx_kwargs = self.httpx_kwargs.copy()
+                httpx_kwargs.setdefault("timeout", 20)
                 with httpx.stream(
                     "GET",
                     url,
                     headers=self.headers,
                     cookies=self.cookies,
                     verify=self.ssl_verify,
-                    timeout=20,
+                    **httpx_kwargs,
                 ) as response:
                     for _ in response.iter_lines():
                         if self._refresh_heartbeat.is_set():
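The pattern above (and in the stream handler in the next hunk) copies the stored kwargs and fills in a per-call default with `setdefault`, so a user-supplied value always takes precedence over the hard-coded one. A standalone sketch of the same idea, with illustrative names:

def merged_kwargs(user_kwargs: dict, **defaults) -> dict:
    # Copy the user's kwargs and add defaults only where a key is missing.
    merged = user_kwargs.copy()
    for key, value in defaults.items():
        merged.setdefault(key, value)
    return merged

assert merged_kwargs({"timeout": 5}, timeout=20)["timeout"] == 5  # user value wins
assert merged_kwargs({}, timeout=20)["timeout"] == 20             # default fills the gap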
@@ -223,8 +228,11 @@ class Client:
         self, protocol: Literal["sse_v1", "sse_v2", "sse_v2.1", "sse_v3"]
     ) -> None:
         try:
+            httpx_kwargs = self.httpx_kwargs.copy()
+            httpx_kwargs.setdefault("timeout", httpx.Timeout(timeout=None))
             with httpx.Client(
-                timeout=httpx.Timeout(timeout=None), verify=self.ssl_verify
+                verify=self.ssl_verify,
+                **httpx_kwargs,
             ) as client:
                 with client.stream(
                     "GET",
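Here the default becomes `httpx.Timeout(timeout=None)`, i.e. no time limit, because the server-sent-event stream is expected to stay open indefinitely; a caller can still override it through `httpx_kwargs`. A hedged example that bounds only the connect and write phases while leaving reads unbounded (the URL and values are illustrative):

import httpx
from gradio_client import Client

client = Client(
    "http://localhost:7860",  # placeholder URL
    httpx_kwargs={
        # Fail fast if the server is unreachable, but never time out while
        # waiting for the next server-sent event on an open stream.
        "timeout": httpx.Timeout(connect=10.0, read=None, write=10.0, pool=10.0),
    },
)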
@@ -284,6 +292,7 @@ class Client:
             headers=self.headers,
             cookies=self.cookies,
             verify=self.ssl_verify,
+            **self.httpx_kwargs,
         )
         if req.status_code == 503:
             raise QueueError("Queue is full! Please try again.")
@@ -549,6 +558,7 @@ class Client:
             headers=self.headers,
             cookies=self.cookies,
             verify=self.ssl_verify,
+            **self.httpx_kwargs,
         )
         if r.is_success:
             info = r.json()
@@ -561,6 +571,7 @@ class Client:
                 "config": json.dumps(self.config),
                 "serialize": False,
             },
+            **self.httpx_kwargs,
         )
         if fetch.is_success:
             info = fetch.json()["api"]
@@ -823,6 +834,7 @@ class Client:
             urllib.parse.urljoin(self.src, utils.LOGIN_URL),
             data={"username": auth[0], "password": auth[1]},
             verify=self.ssl_verify,
+            **self.httpx_kwargs,
         )
         if not resp.is_success:
             if resp.status_code == 401:
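Worth noting in the `_login` hunk above: the existing `auth` parameter performs a Gradio-level login against the app's login route, while HTTP-level authentication now travels through `httpx_kwargs`. The two can be combined; a hedged sketch with placeholder URL and credentials:

from gradio_client import Client

client = Client(
    "http://localhost:7860",              # placeholder: a Gradio app launched with auth=...
    auth=("gradio-user", "gradio-pass"),  # Gradio login (posted to the login route)
    httpx_kwargs={"auth": ("basic-user", "basic-pass")},  # HTTP basic auth on every request
)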
@@ -841,6 +853,7 @@ class Client:
             headers=self.headers,
             cookies=self.cookies,
             verify=self.ssl_verify,
+            **self.httpx_kwargs,
         )
         if r.is_success:
             return r.json()
@@ -854,6 +867,7 @@ class Client:
             headers=self.headers,
             cookies=self.cookies,
             verify=self.ssl_verify,
+            **self.httpx_kwargs,
         )
         if not r.is_success:
             raise ValueError(f"Could not fetch config for {self.src}")
@@ -1185,6 +1199,7 @@ class Endpoint:
                 headers=self.client.headers,
                 cookies=self.client.cookies,
                 verify=self.client.ssl_verify,
+                **self.client.httpx_kwargs,
             )

         return _cancel
@@ -1331,6 +1346,7 @@ class Endpoint:
             cookies=self.client.cookies,
             verify=self.client.ssl_verify,
             files=files,
+            **self.client.httpx_kwargs,
         )
         r.raise_for_status()
         result = r.json()
@@ -1360,6 +1376,7 @@ class Endpoint:
             cookies=self.client.cookies,
             verify=self.client.ssl_verify,
             follow_redirects=True,
+            **self.client.httpx_kwargs,
         ) as response:
             response.raise_for_status()
             with open(temp_dir / Path(url_path).name, "wb") as f:
@@ -1375,7 +1392,9 @@ class Endpoint:

     def _sse_fn_v0(self, data: dict, hash_data: dict, helper: Communicator):
         with httpx.Client(
-            timeout=httpx.Timeout(timeout=None), verify=self.client.ssl_verify
+            timeout=httpx.Timeout(timeout=None),
+            verify=self.client.ssl_verify,
+            **self.client.httpx_kwargs,
         ) as client:
             return utils.get_pred_from_sse_v0(
                 client,

@@ -101,6 +101,7 @@ class EndpointV3Compatibility:
             headers=self.client.headers,
             json=data,
             verify=self.client.ssl_verify,
+            **self.client.httpx_kwargs,
         )
         result = json.loads(response.content.decode("utf-8"))
         try:
@@ -154,6 +155,7 @@ class EndpointV3Compatibility:
             headers=self.client.headers,
             files=files,
             verify=self.client.ssl_verify,
+            **self.client.httpx_kwargs,
         )
         if r.status_code != 200:
             uploaded = file_paths

@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import json
 import os
 import pathlib
@@ -38,11 +40,14 @@ HF_TOKEN = os.getenv("HF_TOKEN") or HfFolder.get_token()
 def connect(
     demo: gr.Blocks,
     download_files: str = DEFAULT_TEMP_DIR,
+    client_kwargs: dict | None = None,
     **kwargs,
 ):
     _, local_url, _ = demo.launch(prevent_thread_lock=True, **kwargs)
+    if client_kwargs is None:
+        client_kwargs = {}
     try:
-        yield Client(local_url, download_files=download_files)
+        yield Client(local_url, download_files=download_files, **client_kwargs)
     finally:
         # A more verbose version of .close()
         # because we should set a timeout
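The new `client_kwargs` parameter lets the test helper forward arbitrary `Client` constructor arguments, not just `httpx_kwargs`, to the temporarily launched app. A hedged sketch of how a test might use it (the fixture name and values are illustrative):

with connect(
    increment_demo,
    client_kwargs={
        "httpx_kwargs": {"timeout": 10},  # forwarded to every httpx call
        "verbose": False,                 # any other Client argument also passes through
    },
) as client:
    client.view_api()  # inspect the temporarily launched app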
@@ -1406,3 +1411,13 @@ def test_upstream_exceptions(count_generator_demo_exception):
         match="The upstream Gradio app has raised an exception but has not enabled verbose error reporting.",
     ):
         client.predict(7, api_name="/count")
+
+
+def test_httpx_kwargs(increment_demo):
+    with connect(
+        increment_demo, client_kwargs={"httpx_kwargs": {"timeout": 5}}
+    ) as client:
+        with patch("httpx.post", MagicMock()) as mock_post:
+            with pytest.raises(Exception):
+                client.predict(1, api_name="/increment_with_queue")
+            assert mock_post.call_args.kwargs["timeout"] == 5

@@ -62,7 +62,7 @@ with gr.Blocks() as demo:
         df = df[(df["time"] >= start) & (df["time"] <= end)]
         df["time"] = pd.to_datetime(df["time"], unit="s")

-        unique_users = len(df["session_hash"].unique())
+        unique_users = len(df["session_hash"].unique())  # type: ignore
         total_requests = len(df)
         process_time = round(df["process_time"].mean(), 2)

@@ -74,7 +74,8 @@ with gr.Blocks() as demo:
             if duration >= 60 * 60 * 3
             else "1m"
         )
-        df = df.drop(columns=["session_hash"])
+        df = df.drop(columns=["session_hash"])  # type: ignore
+        assert isinstance(df, pd.DataFrame)  # noqa: S101
         return (
             gr.BarPlot(value=df, x_bin=x_bin, x_lim=[start, end]),
             unique_users,