mirror of
https://github.com/gradio-app/gradio.git
synced 2024-11-27 01:40:20 +08:00
Add support for python client connecting to gradio apps running with self-signed SSL certificates (#7718)
* verify * add changeset * docstring * add changeset * test fixes * add remaining * test fixes * changes --------- Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
parent
188b86b766
commit
6390d0bf6c
6
.changeset/kind-eyes-shake.md
Normal file
6
.changeset/kind-eyes-shake.md
Normal file
@ -0,0 +1,6 @@
|
||||
---
|
||||
"gradio": patch
|
||||
"gradio_client": patch
|
||||
---
|
||||
|
||||
fix:Add support for python client connecting to gradio apps running with self-signed SSL certificates
|
@ -2,10 +2,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import concurrent.futures
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import secrets
|
||||
import shutil
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
@ -81,6 +83,7 @@ class Client:
|
||||
upload_files: bool = True, # TODO: remove and hardcode to False in 1.0
|
||||
download_files: bool = True, # TODO: consider setting to False in 1.0
|
||||
_skip_components: bool = True, # internal parameter to skip values certain components (e.g. State) that do not need to be displayed to users.
|
||||
ssl_verify: bool = True,
|
||||
):
|
||||
"""
|
||||
Parameters:
|
||||
@ -93,6 +96,7 @@ class Client:
|
||||
headers: Additional headers to send to the remote Gradio app on every request. By default only the HF authorization and user-agent headers are sent. These headers will override the default headers if they have the same keys.
|
||||
upload_files: Whether the client should treat input string filepath as files and upload them to the remote server. If False, the client will treat input string filepaths as strings always and not modify them, and files should be passed in explicitly using `gradio_client.file("path/to/file/or/url")` instead. This parameter will be deleted and False will become the default in a future version.
|
||||
download_files: Whether the client should download output files from the remote API and return them as string filepaths on the local machine. If False, the client will return a FileData dataclass object with the filepath on the remote machine instead.
|
||||
ssl_verify: If False, skips certificate validation which allows the client to connect to Gradio apps that are using self-signed certificates.
|
||||
"""
|
||||
self.verbose = verbose
|
||||
self.hf_token = hf_token
|
||||
@ -111,6 +115,7 @@ class Client:
|
||||
)
|
||||
if headers:
|
||||
self.headers.update(headers)
|
||||
self.ssl_verify = ssl_verify
|
||||
self.space_id = None
|
||||
self.cookies: dict[str, str] = {}
|
||||
self.output_dir = (
|
||||
@ -187,7 +192,9 @@ class Client:
|
||||
|
||||
async def stream_messages(self) -> None:
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=httpx.Timeout(timeout=None)) as client:
|
||||
async with httpx.AsyncClient(
|
||||
timeout=httpx.Timeout(timeout=None), verify=self.ssl_verify
|
||||
) as client:
|
||||
async with client.stream(
|
||||
"GET",
|
||||
self.sse_url,
|
||||
@ -227,7 +234,7 @@ class Client:
|
||||
raise e
|
||||
|
||||
async def send_data(self, data, hash_data):
|
||||
async with httpx.AsyncClient() as client:
|
||||
async with httpx.AsyncClient(verify=self.ssl_verify) as client:
|
||||
req = await client.post(
|
||||
self.sse_data_url,
|
||||
json={**data, **hash_data},
|
||||
@ -484,7 +491,12 @@ class Client:
|
||||
else:
|
||||
api_info_url = urllib.parse.urljoin(self.src, utils.RAW_API_INFO_URL)
|
||||
if self.app_version > version.Version("3.36.1"):
|
||||
r = httpx.get(api_info_url, headers=self.headers, cookies=self.cookies)
|
||||
r = httpx.get(
|
||||
api_info_url,
|
||||
headers=self.headers,
|
||||
cookies=self.cookies,
|
||||
verify=self.ssl_verify,
|
||||
)
|
||||
if r.is_success:
|
||||
info = r.json()
|
||||
else:
|
||||
@ -735,6 +747,7 @@ class Client:
|
||||
resp = httpx.post(
|
||||
urllib.parse.urljoin(self.src, utils.LOGIN_URL),
|
||||
data={"username": auth[0], "password": auth[1]},
|
||||
verify=self.ssl_verify,
|
||||
)
|
||||
if not resp.is_success:
|
||||
if resp.status_code == 401:
|
||||
@ -752,6 +765,7 @@ class Client:
|
||||
urllib.parse.urljoin(self.src, utils.CONFIG_URL),
|
||||
headers=self.headers,
|
||||
cookies=self.cookies,
|
||||
verify=self.ssl_verify,
|
||||
)
|
||||
if r.is_success:
|
||||
return r.json()
|
||||
@ -760,7 +774,12 @@ class Client:
|
||||
f"Could not load {self.src} as credentials were not provided. Please login."
|
||||
)
|
||||
else: # to support older versions of Gradio
|
||||
r = httpx.get(self.src, headers=self.headers, cookies=self.cookies)
|
||||
r = httpx.get(
|
||||
self.src,
|
||||
headers=self.headers,
|
||||
cookies=self.cookies,
|
||||
verify=self.ssl_verify,
|
||||
)
|
||||
if not r.is_success:
|
||||
raise ValueError(f"Could not fetch config for {self.src}")
|
||||
# some basic regex to extract the config
|
||||
@ -1126,7 +1145,7 @@ class Endpoint:
|
||||
else:
|
||||
return data
|
||||
|
||||
def _upload_file(self, f: str | dict):
|
||||
def _upload_file(self, f: str | dict) -> dict[str, str]:
|
||||
if isinstance(f, str):
|
||||
warnings.warn(
|
||||
f'The Client is treating: "{f}" as a file path. In future versions, this behavior will not happen automatically. '
|
||||
@ -1137,24 +1156,53 @@ class Endpoint:
|
||||
else:
|
||||
file_path = f["path"]
|
||||
if not utils.is_http_url_like(file_path):
|
||||
file_path = utils.upload_file(
|
||||
file_path=file_path,
|
||||
upload_url=self.client.upload_url,
|
||||
headers=self.client.headers,
|
||||
cookies=self.client.cookies,
|
||||
)
|
||||
with open(file_path, "rb") as f:
|
||||
files = [("files", (Path(file_path).name, f))]
|
||||
r = httpx.post(
|
||||
self.client.upload_url,
|
||||
headers=self.client.headers,
|
||||
cookies=self.client.cookies,
|
||||
verify=self.client.ssl_verify,
|
||||
files=files,
|
||||
)
|
||||
r.raise_for_status()
|
||||
result = r.json()
|
||||
file_path = result[0]
|
||||
return {"path": file_path}
|
||||
|
||||
def _download_file(self, x: dict) -> str | None:
|
||||
return utils.download_file(
|
||||
self.root_url + "file=" + x["path"],
|
||||
save_dir=self.client.output_dir,
|
||||
def _download_file(self, x: dict) -> str:
|
||||
url_path = self.root_url + "file=" + x["path"]
|
||||
if self.client.output_dir is not None:
|
||||
os.makedirs(self.client.output_dir, exist_ok=True)
|
||||
|
||||
sha1 = hashlib.sha1()
|
||||
temp_dir = Path(tempfile.gettempdir()) / secrets.token_hex(20)
|
||||
temp_dir.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
with httpx.stream(
|
||||
"GET",
|
||||
url_path,
|
||||
headers=self.client.headers,
|
||||
cookies=self.client.cookies,
|
||||
)
|
||||
verify=self.client.ssl_verify,
|
||||
follow_redirects=True,
|
||||
) as response:
|
||||
response.raise_for_status()
|
||||
with open(temp_dir / Path(url_path).name, "wb") as f:
|
||||
for chunk in response.iter_bytes(chunk_size=128 * sha1.block_size):
|
||||
sha1.update(chunk)
|
||||
f.write(chunk)
|
||||
|
||||
directory = Path(self.client.output_dir) / sha1.hexdigest()
|
||||
directory.mkdir(exist_ok=True, parents=True)
|
||||
dest = directory / Path(url_path).name
|
||||
shutil.move(temp_dir / Path(url_path).name, dest)
|
||||
return str(dest.resolve())
|
||||
|
||||
async def _sse_fn_v0(self, data: dict, hash_data: dict, helper: Communicator):
|
||||
async with httpx.AsyncClient(timeout=httpx.Timeout(timeout=None)) as client:
|
||||
async with httpx.AsyncClient(
|
||||
timeout=httpx.Timeout(timeout=None), verify=self.client.ssl_verify
|
||||
) as client:
|
||||
return await utils.get_pred_from_sse_v0(
|
||||
client,
|
||||
data,
|
||||
@ -1164,6 +1212,7 @@ class Endpoint:
|
||||
self.client.sse_data_url,
|
||||
self.client.headers,
|
||||
self.client.cookies,
|
||||
self.client.ssl_verify,
|
||||
)
|
||||
|
||||
async def _sse_fn_v1_v2(
|
||||
@ -1179,6 +1228,7 @@ class Endpoint:
|
||||
self.client.pending_messages_per_event,
|
||||
event_id,
|
||||
protocol,
|
||||
self.client.ssl_verify,
|
||||
)
|
||||
|
||||
|
||||
|
@ -95,7 +95,10 @@ class EndpointV3Compatibility:
|
||||
raise ValueError(result["error"])
|
||||
else:
|
||||
response = httpx.post(
|
||||
self.client.api_url, headers=self.client.headers, json=data
|
||||
self.client.api_url,
|
||||
headers=self.client.headers,
|
||||
json=data,
|
||||
verify=self.client.ssl_verify,
|
||||
)
|
||||
result = json.loads(response.content.decode("utf-8"))
|
||||
try:
|
||||
@ -144,7 +147,12 @@ class EndpointV3Compatibility:
|
||||
for f in fs:
|
||||
files.append(("files", (Path(f).name, open(f, "rb")))) # noqa: SIM115
|
||||
indices.append(i)
|
||||
r = httpx.post(self.client.upload_url, headers=self.client.headers, files=files)
|
||||
r = httpx.post(
|
||||
self.client.upload_url,
|
||||
headers=self.client.headers,
|
||||
files=files,
|
||||
verify=self.client.ssl_verify,
|
||||
)
|
||||
if r.status_code != 200:
|
||||
uploaded = file_paths
|
||||
else:
|
||||
|
@ -3,7 +3,6 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import base64
|
||||
import copy
|
||||
import hashlib
|
||||
import json
|
||||
import mimetypes
|
||||
import os
|
||||
@ -353,10 +352,11 @@ async def get_pred_from_sse_v0(
|
||||
sse_data_url: str,
|
||||
headers: dict[str, str],
|
||||
cookies: dict[str, str] | None,
|
||||
ssl_verify: bool,
|
||||
) -> dict[str, Any] | None:
|
||||
done, pending = await asyncio.wait(
|
||||
[
|
||||
asyncio.create_task(check_for_cancel(helper, headers, cookies)),
|
||||
asyncio.create_task(check_for_cancel(helper, headers, cookies, ssl_verify)),
|
||||
asyncio.create_task(
|
||||
stream_sse_v0(
|
||||
client,
|
||||
@ -393,10 +393,11 @@ async def get_pred_from_sse_v1_v2(
|
||||
pending_messages_per_event: dict[str, list[Message | None]],
|
||||
event_id: str,
|
||||
protocol: Literal["sse_v1", "sse_v2", "sse_v2.1"],
|
||||
ssl_verify: bool,
|
||||
) -> dict[str, Any] | None:
|
||||
done, pending = await asyncio.wait(
|
||||
[
|
||||
asyncio.create_task(check_for_cancel(helper, headers, cookies)),
|
||||
asyncio.create_task(check_for_cancel(helper, headers, cookies, ssl_verify)),
|
||||
asyncio.create_task(
|
||||
stream_sse_v1_v2(helper, pending_messages_per_event, event_id, protocol)
|
||||
),
|
||||
@ -421,7 +422,10 @@ async def get_pred_from_sse_v1_v2(
|
||||
|
||||
|
||||
async def check_for_cancel(
|
||||
helper: Communicator, headers: dict[str, str], cookies: dict[str, str] | None
|
||||
helper: Communicator,
|
||||
headers: dict[str, str],
|
||||
cookies: dict[str, str] | None,
|
||||
ssl_verify: bool,
|
||||
):
|
||||
while True:
|
||||
await asyncio.sleep(0.05)
|
||||
@ -429,7 +433,7 @@ async def check_for_cancel(
|
||||
if helper.should_cancel:
|
||||
break
|
||||
if helper.event_id:
|
||||
async with httpx.AsyncClient() as http:
|
||||
async with httpx.AsyncClient(ssl_verify=ssl_verify) as http:
|
||||
await http.post(
|
||||
helper.reset_url,
|
||||
json={"event_id": helper.event_id},
|
||||
@ -625,49 +629,6 @@ def apply_diff(obj, diff):
|
||||
########################
|
||||
|
||||
|
||||
def upload_file(
|
||||
file_path: str,
|
||||
upload_url: str,
|
||||
headers: dict[str, str] | None = None,
|
||||
cookies: dict[str, str] | None = None,
|
||||
):
|
||||
with open(file_path, "rb") as f:
|
||||
files = [("files", (Path(file_path).name, f))]
|
||||
r = httpx.post(upload_url, headers=headers, cookies=cookies, files=files)
|
||||
r.raise_for_status()
|
||||
result = r.json()
|
||||
return result[0]
|
||||
|
||||
|
||||
def download_file(
|
||||
url_path: str,
|
||||
save_dir: str,
|
||||
headers: dict[str, str] | None = None,
|
||||
cookies: dict[str, str] | None = None,
|
||||
) -> str:
|
||||
if save_dir is not None:
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
sha1 = hashlib.sha1()
|
||||
temp_dir = Path(tempfile.gettempdir()) / secrets.token_hex(20)
|
||||
temp_dir.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
with httpx.stream(
|
||||
"GET", url_path, headers=headers, cookies=cookies, follow_redirects=True
|
||||
) as response:
|
||||
response.raise_for_status()
|
||||
with open(temp_dir / Path(url_path).name, "wb") as f:
|
||||
for chunk in response.iter_bytes(chunk_size=128 * sha1.block_size):
|
||||
sha1.update(chunk)
|
||||
f.write(chunk)
|
||||
|
||||
directory = Path(save_dir) / sha1.hexdigest()
|
||||
directory.mkdir(exist_ok=True, parents=True)
|
||||
dest = directory / Path(url_path).name
|
||||
shutil.move(temp_dir / Path(url_path).name, dest)
|
||||
return str(dest.resolve())
|
||||
|
||||
|
||||
def create_tmp_copy_of_file(file_path: str, dir: str | None = None) -> str:
|
||||
directory = Path(dir or tempfile.gettempdir()) / secrets.token_hex(20)
|
||||
directory.mkdir(exist_ok=True, parents=True)
|
||||
|
@ -11,6 +11,7 @@ from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import gradio as gr
|
||||
import httpx
|
||||
import huggingface_hub
|
||||
import pytest
|
||||
import uvicorn
|
||||
@ -1171,6 +1172,27 @@ class TestEndpoints:
|
||||
"file7",
|
||||
]
|
||||
|
||||
@pytest.mark.flaky
|
||||
def test_download_private_file(self, gradio_temp_dir):
|
||||
client = Client(
|
||||
src="gradio/zip_files",
|
||||
)
|
||||
url_path = "https://gradio-tests-not-actually-private-spacev4-sse.hf.space/file=lion.jpg"
|
||||
file = client.endpoints[0]._upload_file(url_path) # type: ignore
|
||||
assert file["path"].endswith(".jpg")
|
||||
|
||||
@pytest.mark.flaky
|
||||
def test_download_tmp_copy_of_file_does_not_save_errors(
|
||||
self, monkeypatch, gradio_temp_dir
|
||||
):
|
||||
client = Client(
|
||||
src="gradio/zip_files",
|
||||
)
|
||||
error_response = httpx.Response(status_code=404)
|
||||
monkeypatch.setattr(httpx, "get", lambda *args, **kwargs: error_response)
|
||||
with pytest.raises(httpx.HTTPStatusError):
|
||||
client.endpoints[0]._download_file({"path": "https://example.com/foo"}) # type: ignore
|
||||
|
||||
|
||||
cpu = huggingface_hub.SpaceHardware.CPU_BASIC
|
||||
|
||||
|
@ -71,26 +71,6 @@ def test_decode_base64_to_file():
|
||||
assert isinstance(temp_file, tempfile._TemporaryFileWrapper)
|
||||
|
||||
|
||||
@pytest.mark.flaky
|
||||
def test_download_private_file(gradio_temp_dir):
|
||||
url_path = (
|
||||
"https://gradio-tests-not-actually-private-spacev4-sse.hf.space/file=lion.jpg"
|
||||
)
|
||||
file = utils.download_file(
|
||||
url_path=url_path,
|
||||
headers={"Authorization": f"Bearer {HF_TOKEN}"},
|
||||
save_dir=str(gradio_temp_dir),
|
||||
)
|
||||
assert Path(file).name.endswith(".jpg")
|
||||
|
||||
|
||||
def test_download_tmp_copy_of_file_does_not_save_errors(monkeypatch, gradio_temp_dir):
|
||||
error_response = httpx.Response(status_code=404)
|
||||
monkeypatch.setattr(httpx, "get", lambda *args, **kwargs: error_response)
|
||||
with pytest.raises(httpx.HTTPStatusError):
|
||||
utils.download_file("https://example.com/foo", save_dir=str(gradio_temp_dir))
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"orig_filename, new_filename",
|
||||
[
|
||||
|
Loading…
Reference in New Issue
Block a user