Support the use of custom authentication mechanism, timeouts, and other httpx parameters in Python Client (#8862)

* gradio Client now supports the use of a custom authentication mechanism with httpx

* Fix formatting issues

* Replace the specific parameter `httpx_auth` with the more general parameter `httpx_kwargs`.

* add changeset

* typing

* future

---------

Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
valgai 2024-07-23 23:44:00 +02:00 committed by GitHub
parent 7f1a78c49e
commit ac132e3cbc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 49 additions and 6 deletions

View File

@ -0,0 +1,6 @@
---
"gradio": minor
"gradio_client": minor
---
feat:Support the use of custom authentication mechanism, timeouts, and other `httpx` parameters in Python Client

View File

@ -79,6 +79,7 @@ class Client:
max_workers: int = 40,
verbose: bool = True,
auth: tuple[str, str] | None = None,
httpx_kwargs: dict[str, Any] | None = None,
*,
headers: dict[str, str] | None = None,
download_files: str | Path | Literal[False] = DEFAULT_TEMP_DIR,
@ -94,6 +95,7 @@ class Client:
headers: Additional headers to send to the remote Gradio app on every request. By default only the HF authorization and user-agent headers are sent. This parameter will override the default headers if they have the same keys.
download_files: Directory where the client should download output files on the local machine from the remote API. By default, uses the value of the GRADIO_TEMP_DIR environment variable which, if not set by the user, is a temporary directory on your machine. If False, the client does not download files and returns a FileData dataclass object with the filepath on the remote machine instead.
ssl_verify: If False, skips certificate validation which allows the client to connect to Gradio apps that are using self-signed certificates.
httpx_kwargs: Additional keyword arguments to pass to `httpx.Client`, `httpx.stream`, `httpx.get` and `httpx.post`. This can be used to set timeouts, proxies, http auth, etc.
"""
self.verbose = verbose
self.hf_token = hf_token
@ -143,6 +145,7 @@ class Client:
if self.verbose:
print(f"Loaded as API: {self.src}")
self.httpx_kwargs = {} if httpx_kwargs is None else httpx_kwargs
if auth is not None:
self._login(auth)
@ -202,13 +205,15 @@ class Client:
while True:
url = self.heartbeat_url.format(session_hash=self.session_hash)
try:
httpx_kwargs = self.httpx_kwargs.copy()
httpx_kwargs.setdefault("timeout", 20)
with httpx.stream(
"GET",
url,
headers=self.headers,
cookies=self.cookies,
verify=self.ssl_verify,
timeout=20,
**httpx_kwargs,
) as response:
for _ in response.iter_lines():
if self._refresh_heartbeat.is_set():
@ -223,8 +228,11 @@ class Client:
self, protocol: Literal["sse_v1", "sse_v2", "sse_v2.1", "sse_v3"]
) -> None:
try:
httpx_kwargs = self.httpx_kwargs.copy()
httpx_kwargs.setdefault("timeout", httpx.Timeout(timeout=None))
with httpx.Client(
timeout=httpx.Timeout(timeout=None), verify=self.ssl_verify
verify=self.ssl_verify,
**httpx_kwargs,
) as client:
with client.stream(
"GET",
@ -284,6 +292,7 @@ class Client:
headers=self.headers,
cookies=self.cookies,
verify=self.ssl_verify,
**self.httpx_kwargs,
)
if req.status_code == 503:
raise QueueError("Queue is full! Please try again.")
@ -549,6 +558,7 @@ class Client:
headers=self.headers,
cookies=self.cookies,
verify=self.ssl_verify,
**self.httpx_kwargs,
)
if r.is_success:
info = r.json()
@ -561,6 +571,7 @@ class Client:
"config": json.dumps(self.config),
"serialize": False,
},
**self.httpx_kwargs,
)
if fetch.is_success:
info = fetch.json()["api"]
@ -823,6 +834,7 @@ class Client:
urllib.parse.urljoin(self.src, utils.LOGIN_URL),
data={"username": auth[0], "password": auth[1]},
verify=self.ssl_verify,
**self.httpx_kwargs,
)
if not resp.is_success:
if resp.status_code == 401:
@ -841,6 +853,7 @@ class Client:
headers=self.headers,
cookies=self.cookies,
verify=self.ssl_verify,
**self.httpx_kwargs,
)
if r.is_success:
return r.json()
@ -854,6 +867,7 @@ class Client:
headers=self.headers,
cookies=self.cookies,
verify=self.ssl_verify,
**self.httpx_kwargs,
)
if not r.is_success:
raise ValueError(f"Could not fetch config for {self.src}")
@ -1185,6 +1199,7 @@ class Endpoint:
headers=self.client.headers,
cookies=self.client.cookies,
verify=self.client.ssl_verify,
**self.client.httpx_kwargs,
)
return _cancel
@ -1331,6 +1346,7 @@ class Endpoint:
cookies=self.client.cookies,
verify=self.client.ssl_verify,
files=files,
**self.client.httpx_kwargs,
)
r.raise_for_status()
result = r.json()
@ -1360,6 +1376,7 @@ class Endpoint:
cookies=self.client.cookies,
verify=self.client.ssl_verify,
follow_redirects=True,
**self.client.httpx_kwargs,
) as response:
response.raise_for_status()
with open(temp_dir / Path(url_path).name, "wb") as f:
@ -1375,7 +1392,9 @@ class Endpoint:
def _sse_fn_v0(self, data: dict, hash_data: dict, helper: Communicator):
with httpx.Client(
timeout=httpx.Timeout(timeout=None), verify=self.client.ssl_verify
timeout=httpx.Timeout(timeout=None),
verify=self.client.ssl_verify,
**self.client.httpx_kwargs,
) as client:
return utils.get_pred_from_sse_v0(
client,

View File

@ -101,6 +101,7 @@ class EndpointV3Compatibility:
headers=self.client.headers,
json=data,
verify=self.client.ssl_verify,
auth=self.client.httpx_auth,
)
result = json.loads(response.content.decode("utf-8"))
try:
@ -154,6 +155,7 @@ class EndpointV3Compatibility:
headers=self.client.headers,
files=files,
verify=self.client.ssl_verify,
auth=self.client.httpx_auth,
)
if r.status_code != 200:
uploaded = file_paths

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import json
import os
import pathlib
@ -38,11 +40,14 @@ HF_TOKEN = os.getenv("HF_TOKEN") or HfFolder.get_token()
def connect(
demo: gr.Blocks,
download_files: str = DEFAULT_TEMP_DIR,
client_kwargs: dict | None = None,
**kwargs,
):
_, local_url, _ = demo.launch(prevent_thread_lock=True, **kwargs)
if client_kwargs is None:
client_kwargs = {}
try:
yield Client(local_url, download_files=download_files)
yield Client(local_url, download_files=download_files, **client_kwargs)
finally:
# A more verbose version of .close()
# because we should set a timeout
@ -1406,3 +1411,13 @@ def test_upstream_exceptions(count_generator_demo_exception):
match="The upstream Gradio app has raised an exception but has not enabled verbose error reporting.",
):
client.predict(7, api_name="/count")
def test_httpx_kwargs(increment_demo):
    """Check that `httpx_kwargs` given to the Client reach `httpx.post`.

    A client is created through `connect` with
    `httpx_kwargs={"timeout": 5}`; `httpx.post` is then swapped for a
    `MagicMock` and a prediction is attempted. The prediction is expected
    to raise (presumably because the mocked response is unusable), so only
    the keyword arguments recorded by the mock are inspected.
    """
    client_kwargs = {"httpx_kwargs": {"timeout": 5}}
    with connect(increment_demo, client_kwargs=client_kwargs) as client:
        with patch("httpx.post", MagicMock()) as mock_post:
            with pytest.raises(Exception):
                client.predict(1, api_name="/increment_with_queue")
            # The most recent call to the patched httpx.post must carry the
            # user-supplied timeout.
            forwarded_kwargs = mock_post.call_args.kwargs
            assert forwarded_kwargs["timeout"] == 5

View File

@ -62,7 +62,7 @@ with gr.Blocks() as demo:
df = df[(df["time"] >= start) & (df["time"] <= end)]
df["time"] = pd.to_datetime(df["time"], unit="s")
unique_users = len(df["session_hash"].unique())
unique_users = len(df["session_hash"].unique()) # type: ignore
total_requests = len(df)
process_time = round(df["process_time"].mean(), 2)
@ -74,7 +74,8 @@ with gr.Blocks() as demo:
if duration >= 60 * 60 * 3
else "1m"
)
df = df.drop(columns=["session_hash"])
df = df.drop(columns=["session_hash"]) # type: ignore
assert isinstance(df, pd.DataFrame) # noqa: S101
return (
gr.BarPlot(value=df, x_bin=x_bin, x_lim=[start, end]),
unique_users,