mirror of
https://github.com/gradio-app/gradio.git
synced 2025-04-06 12:30:29 +08:00
Use safehttpx.get()
instead of async_get_with_secure_transport()
(#9795)
* changes * add changeset * format * format * format * add changeset * remove tests * bump --------- Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
parent
5e89b6d23a
commit
ff5be457dc
5
.changeset/green-rings-create.md
Normal file
5
.changeset/green-rings-create.md
Normal file
@ -0,0 +1,5 @@
|
||||
---
|
||||
"gradio": minor
|
||||
---
|
||||
|
||||
feat:Use `safehttpx.get()` instead of `async_get_with_secure_transport()`
|
@ -9,8 +9,6 @@ import logging
|
||||
import mimetypes
|
||||
import os
|
||||
import shutil
|
||||
import socket
|
||||
import ssl
|
||||
import subprocess
|
||||
import tempfile
|
||||
import warnings
|
||||
@ -24,6 +22,7 @@ from urllib.parse import urlparse
|
||||
import aiofiles
|
||||
import httpx
|
||||
import numpy as np
|
||||
import safehttpx as sh
|
||||
from gradio_client import utils as client_utils
|
||||
from PIL import Image, ImageOps, ImageSequence, PngImagePlugin
|
||||
|
||||
@ -326,84 +325,6 @@ def lru_cache_async(maxsize: int = 128):
|
||||
return decorator
|
||||
|
||||
|
||||
@lru_cache_async(maxsize=256)
|
||||
async def async_resolve_hostname_google(hostname: str) -> list[str]:
|
||||
async with httpx.AsyncClient() as client:
|
||||
try:
|
||||
response_v4 = await client.get(
|
||||
f"https://dns.google/resolve?name={hostname}&type=A"
|
||||
)
|
||||
response_v6 = await client.get(
|
||||
f"https://dns.google/resolve?name={hostname}&type=AAAA"
|
||||
)
|
||||
|
||||
ips = []
|
||||
for response in [response_v4.json(), response_v6.json()]:
|
||||
ips.extend([answer["data"] for answer in response.get("Answer", [])])
|
||||
return ips
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
|
||||
class AsyncSecureTransport(httpx.AsyncHTTPTransport):
|
||||
def __init__(self, verified_ip: str):
|
||||
self.verified_ip = verified_ip
|
||||
super().__init__()
|
||||
|
||||
async def connect(
|
||||
self,
|
||||
hostname: str,
|
||||
port: int,
|
||||
_timeout: float | None = None,
|
||||
ssl_context: ssl.SSLContext | None = None,
|
||||
**_kwargs: Any,
|
||||
):
|
||||
loop = asyncio.get_event_loop()
|
||||
sock = await loop.getaddrinfo(self.verified_ip, port)
|
||||
sock = socket.socket(sock[0][0], sock[0][1])
|
||||
await loop.sock_connect(sock, (self.verified_ip, port))
|
||||
if ssl_context:
|
||||
sock = ssl_context.wrap_socket(sock, server_hostname=hostname)
|
||||
return sock
|
||||
|
||||
|
||||
async def async_validate_url(url: str) -> str:
|
||||
hostname = urlparse(url).hostname
|
||||
if not hostname:
|
||||
raise ValueError(f"URL {url} does not have a valid hostname")
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
addrinfo = await loop.getaddrinfo(hostname, None)
|
||||
except socket.gaierror as e:
|
||||
raise ValueError(f"Unable to resolve hostname {hostname}: {e}") from e
|
||||
|
||||
for family, _, _, _, sockaddr in addrinfo:
|
||||
ip_address = sockaddr[0]
|
||||
if family in (socket.AF_INET, socket.AF_INET6) and is_public_ip(ip_address):
|
||||
return ip_address
|
||||
|
||||
if not wasm_utils.IS_WASM:
|
||||
for ip_address in await async_resolve_hostname_google(hostname):
|
||||
if is_public_ip(ip_address):
|
||||
return ip_address
|
||||
|
||||
raise ValueError(f"Hostname {hostname} failed validation")
|
||||
|
||||
|
||||
async def async_get_with_secure_transport(
|
||||
url: str, trust_hostname: bool = False
|
||||
) -> httpx.Response:
|
||||
if wasm_utils.IS_WASM:
|
||||
transport = PyodideHttpTransport()
|
||||
elif trust_hostname:
|
||||
transport = None
|
||||
else:
|
||||
verified_ip = await async_validate_url(url)
|
||||
transport = AsyncSecureTransport(verified_ip)
|
||||
async with httpx.AsyncClient(transport=transport) as client:
|
||||
return await client.get(url, follow_redirects=False)
|
||||
|
||||
|
||||
async def async_ssrf_protected_download(url: str, cache_dir: str) -> str:
|
||||
temp_dir = Path(cache_dir) / hash_url(url)
|
||||
temp_dir.mkdir(exist_ok=True, parents=True)
|
||||
@ -416,8 +337,8 @@ async def async_ssrf_protected_download(url: str, cache_dir: str) -> str:
|
||||
parsed_url = urlparse(url)
|
||||
hostname = parsed_url.hostname
|
||||
|
||||
response = await async_get_with_secure_transport(
|
||||
url, trust_hostname=hostname in PUBLIC_HOSTNAME_WHITELIST
|
||||
response = await sh.get(
|
||||
url, domain_whitelist=PUBLIC_HOSTNAME_WHITELIST, _transport=async_transport
|
||||
)
|
||||
|
||||
while response.is_redirect:
|
||||
@ -427,7 +348,11 @@ async def async_ssrf_protected_download(url: str, cache_dir: str) -> str:
|
||||
if not redirect_parsed.hostname:
|
||||
redirect_url = f"{parsed_url.scheme}://{hostname}{redirect_url}"
|
||||
|
||||
response = await async_get_with_secure_transport(redirect_url)
|
||||
response = await sh.get(
|
||||
redirect_url,
|
||||
domain_whitelist=PUBLIC_HOSTNAME_WHITELIST,
|
||||
_transport=async_transport,
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"Failed to download file. Status code: {response.status_code}")
|
||||
|
@ -18,6 +18,7 @@ python-multipart>=0.0.9,!=0.0.13 # required for fastapi forms. 0.0.13 was yanke
|
||||
pydub
|
||||
pyyaml>=5.0,<7.0
|
||||
ruff>=0.2.2; sys.platform != 'emscripten'
|
||||
safehttpx>=0.1.1,<1.0
|
||||
semantic_version~=2.0
|
||||
starlette>=0.40.0,<1.0; sys.platform != 'emscripten'
|
||||
tomlkit==0.12.0
|
||||
|
@ -408,37 +408,6 @@ async def test_json_data_not_moved_to_cache():
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"url",
|
||||
[
|
||||
"https://localhost",
|
||||
"http://127.0.0.1/file/a/b/c",
|
||||
"http://[::1]",
|
||||
"https://192.168.0.1",
|
||||
"http://10.0.0.1?q=a",
|
||||
"http://192.168.1.250.nip.io",
|
||||
],
|
||||
)
|
||||
async def test_local_urls_fail(url):
|
||||
with pytest.raises(ValueError, match="failed validation"):
|
||||
await processing_utils.async_validate_url(url)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"url",
|
||||
[
|
||||
"https://google.com",
|
||||
"https://8.8.8.8/",
|
||||
"http://93.184.215.14.nip.io/",
|
||||
"https://huggingface.co/datasets/dylanebert/3dgs/resolve/main/luigi/luigi.ply",
|
||||
],
|
||||
)
|
||||
async def test_public_urls_pass(url):
|
||||
await processing_utils.async_validate_url(url)
|
||||
|
||||
|
||||
def test_public_request_pass():
|
||||
tempdir = tempfile.TemporaryDirectory()
|
||||
file = processing_utils.ssrf_protected_download(
|
||||
|
Loading…
x
Reference in New Issue
Block a user