Make fix in #7444 (Block /file= filepaths that could expose credentials on Windows) more general (#7453)

* test routes

* chagne

* add changeset

* add changeset

* type fixes

* fix typing issues

* typed dict

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
Abubakar Abid 2024-02-16 10:09:31 -08:00 committed by GitHub
parent f52cab634b
commit ba747adb87
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 40 additions and 10 deletions

View File

@ -0,0 +1,5 @@
---
"gradio": patch
---
feat:Make fix in #7444 (Block /file= filepaths that could expose credentials on Windows) more general

View File

@ -3,6 +3,7 @@ from __future__ import annotations
import hashlib
import hmac
import json
import re
import shutil
from collections import deque
from dataclasses import dataclass as python_dataclass
@ -466,12 +467,10 @@ class GradioMultiPartParser:
self._current_partial_header_value = b""
def on_headers_finished(self) -> None:
disposition, options = parse_options_header(
self._current_part.content_disposition
)
_, options = parse_options_header(self._current_part.content_disposition or b"")
try:
self._current_part.field_name = _user_safe_decode(
options[b"name"], self._charset
options[b"name"], str(self._charset)
)
except KeyError as e:
raise MultiPartException(
@ -483,7 +482,7 @@ class GradioMultiPartParser:
raise MultiPartException(
f"Too many files. Maximum number of files is {self.max_files}."
)
filename = _user_safe_decode(options[b"filename"], self._charset)
filename = _user_safe_decode(options[b"filename"], str(self._charset))
tempfile = NamedTemporaryFile(delete=False)
self._files_to_close_on_error.append(tempfile)
self._current_part.file = GradioUploadFile(
@ -516,7 +515,7 @@ class GradioMultiPartParser:
raise MultiPartException("Missing boundary in multipart.") from e
# Callbacks dictionary.
callbacks = {
callbacks: multipart.multipart.MultipartCallbacks = {
"on_part_begin": self.on_part_begin,
"on_part_data": self.on_part_data,
"on_part_end": self.on_part_end,
@ -579,3 +578,11 @@ def update_root_in_config(config: dict, root: str) -> dict:
def compare_passwords_securely(input_password: str, correct_password: str) -> bool:
return hmac.compare_digest(input_password.encode(), correct_password.encode())
def starts_with_protocol(string: str) -> bool:
"""This regex matches strings that start with a scheme (one or more characters not including colon, slash, or space)
followed by ://
"""
pattern = r"^[a-zA-Z][a-zA-Z0-9+\-.]*://"
return re.match(pattern, string) is not None

View File

@ -429,8 +429,7 @@ class App(FastAPI):
url=path_or_url, status_code=status.HTTP_302_FOUND
)
invalid_prefixes = ["//", "file://", "ftp://", "sftp://", "smb://"]
if any(path_or_url.startswith(prefix) for prefix in invalid_prefixes):
if route_utils.starts_with_protocol(path_or_url):
raise HTTPException(403, f"File not allowed: {path_or_url}.")
abs_path = utils.abspath(path_or_url)
@ -779,7 +778,7 @@ class App(FastAPI):
):
content_type_header = request.headers.get("Content-Type")
content_type: bytes
content_type, _ = parse_options_header(content_type_header)
content_type, _ = parse_options_header(content_type_header or "")
if content_type != b"multipart/form-data":
raise HTTPException(status_code=400, detail="Invalid content type.")

View File

@ -15,7 +15,7 @@ packaging
pandas>=1.0,<3.0
pillow>=8.0,<11.0
pydantic>=2.0
python-multipart # required for fastapi forms
python-multipart>=0.0.9 # required for fastapi forms
pydub
pyyaml>=5.0,<7.0
semantic_version~=2.0

View File

@ -29,6 +29,7 @@ from gradio.route_utils import (
FnIndexInferError,
compare_passwords_securely,
get_root_url,
starts_with_protocol,
)
@ -920,3 +921,21 @@ def test_compare_passwords_securely():
assert compare_passwords_securely(password1, password1)
assert not compare_passwords_securely(password1, password2)
assert compare_passwords_securely(password2, password2)
@pytest.mark.parametrize(
"string, expected",
[
("http://localhost:7860/", True),
("https://localhost:7860/", True),
("ftp://localhost:7860/", True),
("smb://example.com", True),
("ipfs://QmTzQ1Nj5R9BzF1djVQv8gvzZxVkJb1vhrLcXL1QyJzZE", True),
("usr/local/bin", False),
("localhost:7860", False),
("localhost", False),
("C:/Users/username", False),
],
)
def test_starts_with_protocol(string, expected):
assert starts_with_protocol(string) == expected