Add support for python client connecting to gradio apps running with self-signed SSL certificates (#7718)

* verify

* add changeset

* docstring

* add changeset

* test fixes

* add remaining

* test fixes

* changes

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
Abubakar Abid 2024-03-15 13:57:23 -07:00 committed by GitHub
parent 188b86b766
commit 6390d0bf6c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 114 additions and 87 deletions

View File

@ -0,0 +1,6 @@
---
"gradio": patch
"gradio_client": patch
---
fix:Add support for python client connecting to gradio apps running with self-signed SSL certificates

View File

@ -2,10 +2,12 @@
from __future__ import annotations
import concurrent.futures
import hashlib
import json
import os
import re
import secrets
import shutil
import tempfile
import threading
import time
@ -81,6 +83,7 @@ class Client:
upload_files: bool = True, # TODO: remove and hardcode to False in 1.0
download_files: bool = True, # TODO: consider setting to False in 1.0
_skip_components: bool = True, # internal parameter to skip values of certain components (e.g. State) that do not need to be displayed to users.
ssl_verify: bool = True,
):
"""
Parameters:
@ -93,6 +96,7 @@ class Client:
headers: Additional headers to send to the remote Gradio app on every request. By default only the HF authorization and user-agent headers are sent. These headers will override the default headers if they have the same keys.
upload_files: Whether the client should treat input string filepath as files and upload them to the remote server. If False, the client will treat input string filepaths as strings always and not modify them, and files should be passed in explicitly using `gradio_client.file("path/to/file/or/url")` instead. This parameter will be deleted and False will become the default in a future version.
download_files: Whether the client should download output files from the remote API and return them as string filepaths on the local machine. If False, the client will return a FileData dataclass object with the filepath on the remote machine instead.
ssl_verify: If False, skips certificate validation which allows the client to connect to Gradio apps that are using self-signed certificates.
"""
self.verbose = verbose
self.hf_token = hf_token
@ -111,6 +115,7 @@ class Client:
)
if headers:
self.headers.update(headers)
self.ssl_verify = ssl_verify
self.space_id = None
self.cookies: dict[str, str] = {}
self.output_dir = (
@ -187,7 +192,9 @@ class Client:
async def stream_messages(self) -> None:
try:
async with httpx.AsyncClient(timeout=httpx.Timeout(timeout=None)) as client:
async with httpx.AsyncClient(
timeout=httpx.Timeout(timeout=None), verify=self.ssl_verify
) as client:
async with client.stream(
"GET",
self.sse_url,
@ -227,7 +234,7 @@ class Client:
raise e
async def send_data(self, data, hash_data):
async with httpx.AsyncClient() as client:
async with httpx.AsyncClient(verify=self.ssl_verify) as client:
req = await client.post(
self.sse_data_url,
json={**data, **hash_data},
@ -484,7 +491,12 @@ class Client:
else:
api_info_url = urllib.parse.urljoin(self.src, utils.RAW_API_INFO_URL)
if self.app_version > version.Version("3.36.1"):
r = httpx.get(api_info_url, headers=self.headers, cookies=self.cookies)
r = httpx.get(
api_info_url,
headers=self.headers,
cookies=self.cookies,
verify=self.ssl_verify,
)
if r.is_success:
info = r.json()
else:
@ -735,6 +747,7 @@ class Client:
resp = httpx.post(
urllib.parse.urljoin(self.src, utils.LOGIN_URL),
data={"username": auth[0], "password": auth[1]},
verify=self.ssl_verify,
)
if not resp.is_success:
if resp.status_code == 401:
@ -752,6 +765,7 @@ class Client:
urllib.parse.urljoin(self.src, utils.CONFIG_URL),
headers=self.headers,
cookies=self.cookies,
verify=self.ssl_verify,
)
if r.is_success:
return r.json()
@ -760,7 +774,12 @@ class Client:
f"Could not load {self.src} as credentials were not provided. Please login."
)
else: # to support older versions of Gradio
r = httpx.get(self.src, headers=self.headers, cookies=self.cookies)
r = httpx.get(
self.src,
headers=self.headers,
cookies=self.cookies,
verify=self.ssl_verify,
)
if not r.is_success:
raise ValueError(f"Could not fetch config for {self.src}")
# some basic regex to extract the config
@ -1126,7 +1145,7 @@ class Endpoint:
else:
return data
def _upload_file(self, f: str | dict):
def _upload_file(self, f: str | dict) -> dict[str, str]:
if isinstance(f, str):
warnings.warn(
f'The Client is treating: "{f}" as a file path. In future versions, this behavior will not happen automatically. '
@ -1137,24 +1156,53 @@ class Endpoint:
else:
file_path = f["path"]
if not utils.is_http_url_like(file_path):
file_path = utils.upload_file(
file_path=file_path,
upload_url=self.client.upload_url,
headers=self.client.headers,
cookies=self.client.cookies,
)
with open(file_path, "rb") as f:
files = [("files", (Path(file_path).name, f))]
r = httpx.post(
self.client.upload_url,
headers=self.client.headers,
cookies=self.client.cookies,
verify=self.client.ssl_verify,
files=files,
)
r.raise_for_status()
result = r.json()
file_path = result[0]
return {"path": file_path}
def _download_file(self, x: dict) -> str | None:
return utils.download_file(
self.root_url + "file=" + x["path"],
save_dir=self.client.output_dir,
def _download_file(self, x: dict) -> str:
url_path = self.root_url + "file=" + x["path"]
if self.client.output_dir is not None:
os.makedirs(self.client.output_dir, exist_ok=True)
sha1 = hashlib.sha1()
temp_dir = Path(tempfile.gettempdir()) / secrets.token_hex(20)
temp_dir.mkdir(exist_ok=True, parents=True)
with httpx.stream(
"GET",
url_path,
headers=self.client.headers,
cookies=self.client.cookies,
)
verify=self.client.ssl_verify,
follow_redirects=True,
) as response:
response.raise_for_status()
with open(temp_dir / Path(url_path).name, "wb") as f:
for chunk in response.iter_bytes(chunk_size=128 * sha1.block_size):
sha1.update(chunk)
f.write(chunk)
directory = Path(self.client.output_dir) / sha1.hexdigest()
directory.mkdir(exist_ok=True, parents=True)
dest = directory / Path(url_path).name
shutil.move(temp_dir / Path(url_path).name, dest)
return str(dest.resolve())
async def _sse_fn_v0(self, data: dict, hash_data: dict, helper: Communicator):
async with httpx.AsyncClient(timeout=httpx.Timeout(timeout=None)) as client:
async with httpx.AsyncClient(
timeout=httpx.Timeout(timeout=None), verify=self.client.ssl_verify
) as client:
return await utils.get_pred_from_sse_v0(
client,
data,
@ -1164,6 +1212,7 @@ class Endpoint:
self.client.sse_data_url,
self.client.headers,
self.client.cookies,
self.client.ssl_verify,
)
async def _sse_fn_v1_v2(
@ -1179,6 +1228,7 @@ class Endpoint:
self.client.pending_messages_per_event,
event_id,
protocol,
self.client.ssl_verify,
)

View File

@ -95,7 +95,10 @@ class EndpointV3Compatibility:
raise ValueError(result["error"])
else:
response = httpx.post(
self.client.api_url, headers=self.client.headers, json=data
self.client.api_url,
headers=self.client.headers,
json=data,
verify=self.client.ssl_verify,
)
result = json.loads(response.content.decode("utf-8"))
try:
@ -144,7 +147,12 @@ class EndpointV3Compatibility:
for f in fs:
files.append(("files", (Path(f).name, open(f, "rb")))) # noqa: SIM115
indices.append(i)
r = httpx.post(self.client.upload_url, headers=self.client.headers, files=files)
r = httpx.post(
self.client.upload_url,
headers=self.client.headers,
files=files,
verify=self.client.ssl_verify,
)
if r.status_code != 200:
uploaded = file_paths
else:

View File

@ -3,7 +3,6 @@ from __future__ import annotations
import asyncio
import base64
import copy
import hashlib
import json
import mimetypes
import os
@ -353,10 +352,11 @@ async def get_pred_from_sse_v0(
sse_data_url: str,
headers: dict[str, str],
cookies: dict[str, str] | None,
ssl_verify: bool,
) -> dict[str, Any] | None:
done, pending = await asyncio.wait(
[
asyncio.create_task(check_for_cancel(helper, headers, cookies)),
asyncio.create_task(check_for_cancel(helper, headers, cookies, ssl_verify)),
asyncio.create_task(
stream_sse_v0(
client,
@ -393,10 +393,11 @@ async def get_pred_from_sse_v1_v2(
pending_messages_per_event: dict[str, list[Message | None]],
event_id: str,
protocol: Literal["sse_v1", "sse_v2", "sse_v2.1"],
ssl_verify: bool,
) -> dict[str, Any] | None:
done, pending = await asyncio.wait(
[
asyncio.create_task(check_for_cancel(helper, headers, cookies)),
asyncio.create_task(check_for_cancel(helper, headers, cookies, ssl_verify)),
asyncio.create_task(
stream_sse_v1_v2(helper, pending_messages_per_event, event_id, protocol)
),
@ -421,7 +422,10 @@ async def get_pred_from_sse_v1_v2(
async def check_for_cancel(
helper: Communicator, headers: dict[str, str], cookies: dict[str, str] | None
helper: Communicator,
headers: dict[str, str],
cookies: dict[str, str] | None,
ssl_verify: bool,
):
while True:
await asyncio.sleep(0.05)
@ -429,7 +433,7 @@ async def check_for_cancel(
if helper.should_cancel:
break
if helper.event_id:
async with httpx.AsyncClient() as http:
async with httpx.AsyncClient(verify=ssl_verify) as http:
await http.post(
helper.reset_url,
json={"event_id": helper.event_id},
@ -625,49 +629,6 @@ def apply_diff(obj, diff):
########################
def upload_file(
file_path: str,
upload_url: str,
headers: dict[str, str] | None = None,
cookies: dict[str, str] | None = None,
):
with open(file_path, "rb") as f:
files = [("files", (Path(file_path).name, f))]
r = httpx.post(upload_url, headers=headers, cookies=cookies, files=files)
r.raise_for_status()
result = r.json()
return result[0]
def download_file(
url_path: str,
save_dir: str,
headers: dict[str, str] | None = None,
cookies: dict[str, str] | None = None,
) -> str:
if save_dir is not None:
os.makedirs(save_dir, exist_ok=True)
sha1 = hashlib.sha1()
temp_dir = Path(tempfile.gettempdir()) / secrets.token_hex(20)
temp_dir.mkdir(exist_ok=True, parents=True)
with httpx.stream(
"GET", url_path, headers=headers, cookies=cookies, follow_redirects=True
) as response:
response.raise_for_status()
with open(temp_dir / Path(url_path).name, "wb") as f:
for chunk in response.iter_bytes(chunk_size=128 * sha1.block_size):
sha1.update(chunk)
f.write(chunk)
directory = Path(save_dir) / sha1.hexdigest()
directory.mkdir(exist_ok=True, parents=True)
dest = directory / Path(url_path).name
shutil.move(temp_dir / Path(url_path).name, dest)
return str(dest.resolve())
def create_tmp_copy_of_file(file_path: str, dir: str | None = None) -> str:
directory = Path(dir or tempfile.gettempdir()) / secrets.token_hex(20)
directory.mkdir(exist_ok=True, parents=True)

View File

@ -11,6 +11,7 @@ from pathlib import Path
from unittest.mock import MagicMock, patch
import gradio as gr
import httpx
import huggingface_hub
import pytest
import uvicorn
@ -1171,6 +1172,27 @@ class TestEndpoints:
"file7",
]
@pytest.mark.flaky
def test_download_private_file(self, gradio_temp_dir):
client = Client(
src="gradio/zip_files",
)
url_path = "https://gradio-tests-not-actually-private-spacev4-sse.hf.space/file=lion.jpg"
file = client.endpoints[0]._upload_file(url_path) # type: ignore
assert file["path"].endswith(".jpg")
@pytest.mark.flaky
def test_download_tmp_copy_of_file_does_not_save_errors(
self, monkeypatch, gradio_temp_dir
):
client = Client(
src="gradio/zip_files",
)
error_response = httpx.Response(status_code=404)
monkeypatch.setattr(httpx, "get", lambda *args, **kwargs: error_response)
with pytest.raises(httpx.HTTPStatusError):
client.endpoints[0]._download_file({"path": "https://example.com/foo"}) # type: ignore
cpu = huggingface_hub.SpaceHardware.CPU_BASIC

View File

@ -71,26 +71,6 @@ def test_decode_base64_to_file():
assert isinstance(temp_file, tempfile._TemporaryFileWrapper)
@pytest.mark.flaky
def test_download_private_file(gradio_temp_dir):
url_path = (
"https://gradio-tests-not-actually-private-spacev4-sse.hf.space/file=lion.jpg"
)
file = utils.download_file(
url_path=url_path,
headers={"Authorization": f"Bearer {HF_TOKEN}"},
save_dir=str(gradio_temp_dir),
)
assert Path(file).name.endswith(".jpg")
def test_download_tmp_copy_of_file_does_not_save_errors(monkeypatch, gradio_temp_dir):
error_response = httpx.Response(status_code=404)
monkeypatch.setattr(httpx, "get", lambda *args, **kwargs: error_response)
with pytest.raises(httpx.HTTPStatusError):
utils.download_file("https://example.com/foo", save_dir=str(gradio_temp_dir))
@pytest.mark.parametrize(
"orig_filename, new_filename",
[