diff --git a/.changeset/clean-eagles-taste.md b/.changeset/clean-eagles-taste.md new file mode 100644 index 0000000000..65a18ad8dc --- /dev/null +++ b/.changeset/clean-eagles-taste.md @@ -0,0 +1,6 @@ +--- +"gradio": minor +"gradio_client": minor +--- + +feat:Support the use of custom authentication mechanism, timeouts, and other `httpx` parameters in Python Client diff --git a/client/python/gradio_client/client.py b/client/python/gradio_client/client.py index 87d4ada6ce..934d498da0 100644 --- a/client/python/gradio_client/client.py +++ b/client/python/gradio_client/client.py @@ -79,6 +79,7 @@ class Client: max_workers: int = 40, verbose: bool = True, auth: tuple[str, str] | None = None, + httpx_kwargs: dict[str, Any] | None = None, *, headers: dict[str, str] | None = None, download_files: str | Path | Literal[False] = DEFAULT_TEMP_DIR, @@ -94,6 +95,7 @@ class Client: headers: Additional headers to send to the remote Gradio app on every request. By default only the HF authorization and user-agent headers are sent. This parameter will override the default headers if they have the same keys. download_files: Directory where the client should download output files on the local machine from the remote API. By default, uses the value of the GRADIO_TEMP_DIR environment variable which, if not set by the user, is a temporary directory on your machine. If False, the client does not download files and returns a FileData dataclass object with the filepath on the remote machine instead. ssl_verify: If False, skips certificate validation which allows the client to connect to Gradio apps that are using self-signed certificates. + httpx_kwargs: Additional keyword arguments to pass to `httpx.Client`, `httpx.stream`, `httpx.get` and `httpx.post`. This can be used to set timeouts, proxies, http auth, etc. 
""" self.verbose = verbose self.hf_token = hf_token @@ -143,6 +145,7 @@ class Client: if self.verbose: print(f"Loaded as API: {self.src} ✔") + self.httpx_kwargs = {} if httpx_kwargs is None else httpx_kwargs if auth is not None: self._login(auth) @@ -202,13 +205,15 @@ class Client: while True: url = self.heartbeat_url.format(session_hash=self.session_hash) try: + httpx_kwargs = self.httpx_kwargs.copy() + httpx_kwargs.setdefault("timeout", 20) with httpx.stream( "GET", url, headers=self.headers, cookies=self.cookies, verify=self.ssl_verify, - timeout=20, + **httpx_kwargs, ) as response: for _ in response.iter_lines(): if self._refresh_heartbeat.is_set(): @@ -223,8 +228,11 @@ class Client: self, protocol: Literal["sse_v1", "sse_v2", "sse_v2.1", "sse_v3"] ) -> None: try: + httpx_kwargs = self.httpx_kwargs.copy() + httpx_kwargs.setdefault("timeout", httpx.Timeout(timeout=None)) with httpx.Client( - timeout=httpx.Timeout(timeout=None), verify=self.ssl_verify + verify=self.ssl_verify, + **httpx_kwargs, ) as client: with client.stream( "GET", @@ -284,6 +292,7 @@ class Client: headers=self.headers, cookies=self.cookies, verify=self.ssl_verify, + **self.httpx_kwargs, ) if req.status_code == 503: raise QueueError("Queue is full! 
Please try again.") @@ -549,6 +558,7 @@ class Client: headers=self.headers, cookies=self.cookies, verify=self.ssl_verify, + **self.httpx_kwargs, ) if r.is_success: info = r.json() @@ -561,6 +571,7 @@ class Client: "config": json.dumps(self.config), "serialize": False, }, + **self.httpx_kwargs, ) if fetch.is_success: info = fetch.json()["api"] @@ -823,6 +834,7 @@ class Client: urllib.parse.urljoin(self.src, utils.LOGIN_URL), data={"username": auth[0], "password": auth[1]}, verify=self.ssl_verify, + **self.httpx_kwargs, ) if not resp.is_success: if resp.status_code == 401: @@ -841,6 +853,7 @@ class Client: headers=self.headers, cookies=self.cookies, verify=self.ssl_verify, + **self.httpx_kwargs, ) if r.is_success: return r.json() @@ -854,6 +867,7 @@ class Client: headers=self.headers, cookies=self.cookies, verify=self.ssl_verify, + **self.httpx_kwargs, ) if not r.is_success: raise ValueError(f"Could not fetch config for {self.src}") @@ -1185,6 +1199,7 @@ class Endpoint: headers=self.client.headers, cookies=self.client.cookies, verify=self.client.ssl_verify, + **self.client.httpx_kwargs, ) return _cancel @@ -1331,6 +1346,7 @@ class Endpoint: cookies=self.client.cookies, verify=self.client.ssl_verify, files=files, + **self.client.httpx_kwargs, ) r.raise_for_status() result = r.json() @@ -1360,6 +1376,7 @@ class Endpoint: cookies=self.client.cookies, verify=self.client.ssl_verify, follow_redirects=True, + **self.client.httpx_kwargs, ) as response: response.raise_for_status() with open(temp_dir / Path(url_path).name, "wb") as f: @@ -1375,7 +1392,10 @@ class Endpoint: def _sse_fn_v0(self, data: dict, hash_data: dict, helper: Communicator): + httpx_kwargs = self.client.httpx_kwargs.copy() + httpx_kwargs.setdefault("timeout", httpx.Timeout(timeout=None)) with httpx.Client( - timeout=httpx.Timeout(timeout=None), verify=self.client.ssl_verify + verify=self.client.ssl_verify, + **httpx_kwargs, ) as client: return utils.get_pred_from_sse_v0( client, diff --git a/client/python/gradio_client/compatibility.py 
b/client/python/gradio_client/compatibility.py index 7eaae1c331..c9d11c0afa 100644 --- a/client/python/gradio_client/compatibility.py +++ b/client/python/gradio_client/compatibility.py @@ -101,6 +101,7 @@ class EndpointV3Compatibility: headers=self.client.headers, json=data, verify=self.client.ssl_verify, + **self.client.httpx_kwargs, ) result = json.loads(response.content.decode("utf-8")) try: @@ -154,6 +155,7 @@ class EndpointV3Compatibility: headers=self.client.headers, files=files, verify=self.client.ssl_verify, + **self.client.httpx_kwargs, ) if r.status_code != 200: uploaded = file_paths diff --git a/client/python/test/test_client.py b/client/python/test/test_client.py index 240f0018bf..4bb9120b5b 100644 --- a/client/python/test/test_client.py +++ b/client/python/test/test_client.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json import os import pathlib @@ -38,11 +40,14 @@ HF_TOKEN = os.getenv("HF_TOKEN") or HfFolder.get_token() def connect( demo: gr.Blocks, download_files: str = DEFAULT_TEMP_DIR, + client_kwargs: dict | None = None, **kwargs, ): _, local_url, _ = demo.launch(prevent_thread_lock=True, **kwargs) + if client_kwargs is None: + client_kwargs = {} try: - yield Client(local_url, download_files=download_files) + yield Client(local_url, download_files=download_files, **client_kwargs) finally: # A more verbose version of .close() # because we should set a timeout @@ -1406,3 +1411,13 @@ def test_upstream_exceptions(count_generator_demo_exception): match="The upstream Gradio app has raised an exception but has not enabled verbose error reporting.", ): client.predict(7, api_name="/count") + + +def test_httpx_kwargs(increment_demo): + with connect( + increment_demo, client_kwargs={"httpx_kwargs": {"timeout": 5}} + ) as client: + with patch("httpx.post", MagicMock()) as mock_post: + with pytest.raises(Exception): + client.predict(1, api_name="/increment_with_queue") + assert mock_post.call_args.kwargs["timeout"] == 5 diff --git 
a/gradio/monitoring_dashboard.py b/gradio/monitoring_dashboard.py index fd39ae1da6..a18a2e5088 100644 --- a/gradio/monitoring_dashboard.py +++ b/gradio/monitoring_dashboard.py @@ -62,7 +62,7 @@ with gr.Blocks() as demo: df = df[(df["time"] >= start) & (df["time"] <= end)] df["time"] = pd.to_datetime(df["time"], unit="s") - unique_users = len(df["session_hash"].unique()) + unique_users = len(df["session_hash"].unique()) # type: ignore total_requests = len(df) process_time = round(df["process_time"].mean(), 2) @@ -74,7 +74,8 @@ with gr.Blocks() as demo: if duration >= 60 * 60 * 3 else "1m" ) - df = df.drop(columns=["session_hash"]) + df = df.drop(columns=["session_hash"]) # type: ignore + assert isinstance(df, pd.DataFrame) # noqa: S101 return ( gr.BarPlot(value=df, x_bin=x_bin, x_lim=[start, end]), unique_users,