diff --git a/.changeset/ten-lands-change.md b/.changeset/ten-lands-change.md new file mode 100644 index 0000000000..293e412fed --- /dev/null +++ b/.changeset/ten-lands-change.md @@ -0,0 +1,5 @@ +--- +"gradio": patch +--- + +feat:Some tweaks to is_in_or_equal diff --git a/gradio/blocks.py b/gradio/blocks.py index 246d4b5b08..15c764bfba 100644 --- a/gradio/blocks.py +++ b/gradio/blocks.py @@ -56,7 +56,13 @@ from gradio.context import ( get_render_context, set_render_context, ) -from gradio.data_classes import BlocksConfigDict, FileData, GradioModel, GradioRootModel +from gradio.data_classes import ( + BlocksConfigDict, + DeveloperPath, + FileData, + GradioModel, + GradioRootModel, +) from gradio.events import ( EventData, EventListener, @@ -409,7 +415,7 @@ class BlockContext(Block): render=render, ) - TEMPLATE_DIR = "./templates/" + TEMPLATE_DIR = DeveloperPath("./templates/") FRONTEND_DIR = "../../frontend/" @property diff --git a/gradio/components/base.py b/gradio/components/base.py index 93efa94f69..96d3957c6e 100644 --- a/gradio/components/base.py +++ b/gradio/components/base.py @@ -19,7 +19,7 @@ import gradio_client.utils as client_utils from gradio import utils from gradio.blocks import Block, BlockContext from gradio.component_meta import ComponentMeta -from gradio.data_classes import BaseModel, GradioDataModel +from gradio.data_classes import BaseModel, DeveloperPath, GradioDataModel from gradio.events import EventListener from gradio.layouts import Form from gradio.processing_utils import move_files_to_cache @@ -228,7 +228,7 @@ class Component(ComponentBase, Block): self.component_class_id = self.__class__.get_component_class_id() - TEMPLATE_DIR = "./templates/" + TEMPLATE_DIR = DeveloperPath("./templates/") FRONTEND_DIR = "../../frontend/" def get_config(self): diff --git a/gradio/components/file_explorer.py b/gradio/components/file_explorer.py index 1429d50281..9cf3c389d3 100644 --- a/gradio/components/file_explorer.py +++ b/gradio/components/file_explorer.py @@ -11,7 +11,8 @@ from typing import TYPE_CHECKING, Any, Callable, List, Literal, Sequence from gradio_client.documentation import document from gradio.components.base import Component, server -from gradio.data_classes import GradioRootModel +from gradio.data_classes import DeveloperPath, GradioRootModel, UserProvidedPath +from gradio.utils import safe_join if TYPE_CHECKING: from gradio.components import Timer @@ -85,7 +86,7 @@ class FileExplorer(Component): ) root_dir = root self._constructor_args[0]["root_dir"] = root - self.root_dir = os.path.abspath(root_dir) + self.root_dir = DeveloperPath(os.path.abspath(root_dir)) self.glob = glob self.ignore_glob = ignore_glob valid_file_count = ["single", "multiple"] @@ -202,11 +203,8 @@ class FileExplorer(Component): return folders + files - def _safe_join(self, folders): - combined_path = os.path.join(self.root_dir, *folders) - absolute_path = os.path.abspath(combined_path) - if os.path.commonprefix([self.root_dir, absolute_path]) != os.path.abspath( - self.root_dir - ): - raise ValueError("Attempted to navigate outside of root directory") - return absolute_path + def _safe_join(self, folders: list[str]): + if not folders or len(folders) == 0: + return self.root_dir + combined_path = UserProvidedPath(os.path.join(*folders)) + return safe_join(self.root_dir, combined_path) diff --git a/gradio/data_classes.py b/gradio/data_classes.py index 7c1489d2b1..3dedcfe302 100644 --- a/gradio/data_classes.py +++ b/gradio/data_classes.py @@ -8,7 +8,17 @@ import secrets import shutil from abc import ABC, abstractmethod from enum import Enum, auto -from typing import Any, Iterator, List, Literal, Optional, Tuple, TypedDict, Union +from typing import ( + Any, + Iterator, + List, + Literal, + NewType, + Optional, + Tuple, + TypedDict, + Union, +) from fastapi import Request from gradio_client.documentation import document @@ -21,6 +31,9 @@ try: except ImportError: JsonValue = Any +DeveloperPath = NewType("DeveloperPath", str) +UserProvidedPath = NewType("UserProvidedPath", str) + class CancelBody(BaseModel): session_hash: str diff --git a/gradio/exceptions.py b/gradio/exceptions.py index f353e526e6..1ab9dc4c09 100644 --- a/gradio/exceptions.py +++ b/gradio/exceptions.py @@ -98,3 +98,7 @@ class Error(Exception): class ComponentDefinitionError(NotImplementedError): pass + + +class InvalidPathError(ValueError): + pass diff --git a/gradio/route_utils.py b/gradio/route_utils.py index f6e6e64c16..380470238f 100644 --- a/gradio/route_utils.py +++ b/gradio/route_utils.py @@ -42,7 +42,10 @@ from starlette.responses import PlainTextResponse, Response from starlette.types import ASGIApp, Message, Receive, Scope, Send from gradio import processing_utils, utils -from gradio.data_classes import BlocksConfigDict, PredictBody +from gradio.data_classes import ( + BlocksConfigDict, + PredictBody, +) from gradio.exceptions import Error from gradio.helpers import EventData from gradio.state_holder import SessionState diff --git a/gradio/routes.py b/gradio/routes.py index 8aced1bcc8..63d397b526 100644 --- a/gradio/routes.py +++ b/gradio/routes.py @@ -18,7 +18,6 @@ import inspect import json import mimetypes import os -import posixpath import secrets import time import traceback @@ -35,6 +34,7 @@ from typing import ( Optional, Type, Union, + cast, ) import fastapi @@ -73,10 +73,13 @@ from gradio.data_classes import ( ComponentServerBlobBody, ComponentServerJSONBody, DataWithFiles, + DeveloperPath, PredictBody, ResetBody, SimplePredictBody, + UserProvidedPath, ) +from gradio.exceptions import InvalidPathError from gradio.oauth import attach_oauth from gradio.route_utils import ( # noqa: F401 CustomCORSMiddleware, @@ -109,9 +112,18 @@ if TYPE_CHECKING: mimetypes.init() -STATIC_TEMPLATE_LIB = files("gradio").joinpath("templates").as_posix() # type: ignore -STATIC_PATH_LIB = files("gradio").joinpath("templates", "frontend", "static").as_posix() # type: ignore -BUILD_PATH_LIB = files("gradio").joinpath("templates", "frontend", "assets").as_posix() # type: ignore +STATIC_TEMPLATE_LIB = cast( + DeveloperPath, + files("gradio").joinpath("templates").as_posix(), # type: ignore +) +STATIC_PATH_LIB = cast( + DeveloperPath, + files("gradio").joinpath("templates", "frontend", "static").as_posix(), # type: ignore +) +BUILD_PATH_LIB = cast( + DeveloperPath, + files("gradio").joinpath("templates", "frontend", "assets").as_posix(), # type: ignore +) VERSION = get_package_version() @@ -446,7 +458,7 @@ class App(FastAPI): @app.get("/static/{path:path}") def static_resource(path: str): - static_file = safe_join(STATIC_PATH_LIB, path) + static_file = routes_safe_join(STATIC_PATH_LIB, UserProvidedPath(path)) return FileResponse(static_file) @app.get("/custom_component/{id}/{type}/{file_name}") @@ -458,7 +470,6 @@ class App(FastAPI): location = next( (item for item in components if item["component_class_id"] == id), None ) - if location is None: raise HTTPException(status_code=404, detail="Component not found.") @@ -470,9 +481,14 @@ class App(FastAPI): if module_path is None or component_instance is None: raise HTTPException(status_code=404, detail="Component not found.") - path = safe_join( - str(Path(module_path).parent), - f"{component_instance.__class__.TEMPLATE_DIR}/{type}/{file_name}", + requested_path = utils.safe_join( + component_instance.__class__.TEMPLATE_DIR, + UserProvidedPath(f"{type}/{file_name}"), + ) + + path = routes_safe_join( + DeveloperPath(str(Path(module_path).parent)), + UserProvidedPath(requested_path), ) key = f"{id}-{type}-{file_name}" @@ -494,7 +510,7 @@ class App(FastAPI): @app.get("/assets/{path:path}") def build_resource(path: str): - build_file = safe_join(BUILD_PATH_LIB, path) + build_file = routes_safe_join(BUILD_PATH_LIB, UserProvidedPath(path)) return FileResponse(build_file) @app.get("/favicon.ico") @@ -543,7 +559,7 @@ class App(FastAPI): is_dir = abs_path.is_dir() - if in_blocklist or is_dir: + if is_dir or in_blocklist: raise HTTPException(403, f"File not allowed: {path_or_url}.") created_by_app = False @@ -1142,7 +1158,14 @@ class App(FastAPI): name = f"tmp{secrets.token_hex(5)}" directory = Path(app.uploaded_file_dir) / temp_file.sha.hexdigest() directory.mkdir(exist_ok=True, parents=True) - dest = (directory / name).resolve() + try: + dest = utils.safe_join( + DeveloperPath(str(directory)), UserProvidedPath(name) + ) + except InvalidPathError as err: + raise HTTPException( + status_code=400, detail=f"Invalid file name: {name}" + ) from err temp_file.file.close() # we need to move the temp file to the cache directory # but that's possibly blocking and we're in an async function @@ -1153,9 +1176,9 @@ class App(FastAPI): os.rename(temp_file.file.name, dest) except OSError: files_to_copy.append(temp_file.file.name) - locations.append(str(dest)) + locations.append(dest) output_files.append(dest) - blocks.upload_file_set.add(str(dest)) + blocks.upload_file_set.add(dest) if files_to_copy: bg_tasks.add_task( move_uploaded_files_to_cache, files_to_copy, locations @@ -1218,32 +1241,22 @@ class App(FastAPI): ######## -def safe_join(directory: str, path: str) -> str: - """Safely path to a base directory to avoid escaping the base directory. - Borrowed from: werkzeug.security.safe_join""" - _os_alt_seps: List[str] = [ - sep for sep in [os.path.sep, os.path.altsep] if sep is not None and sep != "/" - ] - +def routes_safe_join(directory: DeveloperPath, path: UserProvidedPath) -> str: + """Safely join the user path to the directory while performing some additional http-related checks, + e.g. ensuring that the full path exists on the local file system and is not a directory""" if path == "": - raise HTTPException(400) + raise fastapi.HTTPException(400) if route_utils.starts_with_protocol(path): - raise HTTPException(403) - filename = posixpath.normpath(path) - fullpath = os.path.join(directory, filename) - if ( - any(sep in filename for sep in _os_alt_seps) - or os.path.isabs(filename) - or filename == ".." - or filename.startswith("../") - or os.path.isdir(fullpath) - ): - raise HTTPException(403) - - if not os.path.exists(fullpath): - raise HTTPException(404, "File not found") - - return fullpath + raise fastapi.HTTPException(403) + try: + fullpath = Path(utils.safe_join(directory, path)) + except InvalidPathError as e: + raise fastapi.HTTPException(403) from e + if fullpath.is_dir(): + raise fastapi.HTTPException(403) + if not fullpath.exists(): + raise fastapi.HTTPException(404) + return str(fullpath) def get_types(cls_set: List[Type]): diff --git a/gradio/utils.py b/gradio/utils.py index 7cc8d1badd..573a7a1836 100644 --- a/gradio/utils.py +++ b/gradio/utils.py @@ -14,6 +14,7 @@ import json import json.decoder import os import pkgutil +import posixpath import re import sys import tempfile @@ -38,6 +39,7 @@ from typing import ( Generic, Iterable, Iterator, + List, Literal, Optional, Sequence, @@ -53,8 +55,13 @@ from typing_extensions import ParamSpec import gradio from gradio.context import get_blocks_context -from gradio.data_classes import BlocksConfigDict, FileData -from gradio.exceptions import Error +from gradio.data_classes import ( + BlocksConfigDict, + DeveloperPath, + FileData, + UserProvidedPath, +) +from gradio.exceptions import Error, InvalidPathError from gradio.strings import en if TYPE_CHECKING: # Only import for type checking (is False at runtime). @@ -1056,24 +1063,10 @@ def tex2svg(formula, *_args): def abspath(path: str | Path) -> Path: - """Returns absolute path of a str or Path path, but does not resolve symlinks.""" - path = Path(path) - - if path.is_absolute(): - return path - - # recursively check if there is a symlink within the path - is_symlink = path.is_symlink() or any( - parent.is_symlink() for parent in path.parents - ) - - if is_symlink or path == path.resolve(): # in case path couldn't be resolved - return Path.cwd() / path - else: - return path.resolve() + return Path(os.path.abspath(str(path))) -def is_in_or_equal(path_1: str | Path, path_2: str | Path): +def is_in_or_equal(path_1: str | Path, path_2: str | Path) -> bool: """ True if path_1 is a descendant (i.e. located within) path_2 or if the paths are the same, returns False otherwise. @@ -1090,7 +1083,6 @@ def is_in_or_equal(path_1: str | Path, path_2: str | Path): return ".." not in str(relative_path) except ValueError: return False - return True @document() @@ -1466,3 +1458,23 @@ class UnhashableKeyDict(MutableMapping): def as_list(self): return [v for _, v in self.data] + + +def safe_join(directory: DeveloperPath, path: UserProvidedPath) -> str: + """Safely path to a base directory to avoid escaping the base directory. + Borrowed from: werkzeug.security.safe_join""" + _os_alt_seps: List[str] = [ + sep for sep in [os.path.sep, os.path.altsep] if sep is not None and sep != "/" + ] + + filename = posixpath.normpath(path) + fullpath = os.path.join(directory, filename) + if ( + any(sep in filename for sep in _os_alt_seps) + or os.path.isabs(filename) + or filename == ".." + or filename.startswith("../") + ): + raise InvalidPathError() + + return fullpath diff --git a/test/components/test_file_explorer.py b/test/components/test_file_explorer.py index bc9371f7e0..585fa70192 100644 --- a/test/components/test_file_explorer.py +++ b/test/components/test_file_explorer.py @@ -1,7 +1,10 @@ from pathlib import Path +import pytest + import gradio as gr from gradio.components.file_explorer import FileExplorerData +from gradio.exceptions import InvalidPathError class TestFileExplorer: @@ -61,3 +64,9 @@ class TestFileExplorer: {"name": "file2.txt", "type": "file", "valid": True}, ] assert tree == answer + + def test_file_explorer_prevents_path_traversal(self, tmpdir): + file_explorer = gr.FileExplorer(glob="*.txt", root_dir=Path(tmpdir)) + + with pytest.raises(InvalidPathError): + file_explorer.preprocess(FileExplorerData(root=[["../file.txt"]])) diff --git a/test/requirements.txt b/test/requirements.txt index 67fbfd889d..200bcc1982 100644 --- a/test/requirements.txt +++ b/test/requirements.txt @@ -20,7 +20,7 @@ appnope==0.1.4 # via ipython asyncio==3.4.3 # via -r requirements.in -attrs==21.4.0 +attrs==23.1.0 # via # jsonschema # pytest @@ -106,6 +106,7 @@ huggingface-hub==0.21.4 # gradio-client # tokenizers # transformers +hypothesis==6.108.9 idna==3.3 # via # anyio diff --git a/test/test_routes.py b/test/test_routes.py index 223cebdab6..11fd68d6a8 100644 --- a/test/test_routes.py +++ b/test/test_routes.py @@ -5,6 +5,7 @@ import os import tempfile import time from contextlib import asynccontextmanager, closing +from pathlib import Path from typing import Dict from unittest.mock import patch @@ -1335,3 +1336,45 @@ def test_docs_url(): assert r.status_code == 200 finally: demo.close() + + +def test_file_access(): + with gr.Blocks() as demo: + gr.Markdown("Test") + + allowed_dir = (Path(tempfile.gettempdir()) / "test_file_access_dir").resolve() + allowed_dir.mkdir(parents=True, exist_ok=True) + allowed_file = Path(allowed_dir / "allowed.txt") + allowed_file.touch() + + not_allowed_file = Path(tempfile.gettempdir()) / "not_allowed.txt" + not_allowed_file.touch() + + app, _, _ = demo.launch( + prevent_thread_lock=True, + blocked_paths=["test/test_files"], + allowed_paths=[str(allowed_dir)], + ) + test_client = TestClient(app) + try: + with test_client: + r = test_client.get(f"/file={allowed_dir}/allowed.txt") + assert r.status_code == 200 + r = test_client.get(f"/file={allowed_dir}/../not_allowed.txt") + assert r.status_code == 403 + r = test_client.get("/file=//test/test_files/cheetah1.jpg") + assert r.status_code == 403 + r = test_client.get("/file=test/test_files/cheetah1.jpg") + assert r.status_code == 403 + r = test_client.get("/file=//test/test_files/cheetah1.jpg") + assert r.status_code == 403 + tmp = Path(tempfile.gettempdir()) / "upload_test.txt" + tmp.write_text("Hello") + with open(str(tmp), "rb") as f: + files = {"files": ("..", f)} + response = test_client.post("/upload", files=files) + assert response.status_code == 400 + finally: + demo.close() + not_allowed_file.unlink() + allowed_file.unlink() diff --git a/test/test_utils.py b/test/test_utils.py index 06cc60c02d..d3ad92818d 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -9,6 +9,8 @@ from unittest.mock import MagicMock, patch import numpy as np import pytest +from hypothesis import given, settings +from hypothesis import strategies as st from typing_extensions import Literal from gradio import EventData, Request @@ -369,6 +371,69 @@ def test_is_in_or_equal(): assert not is_in_or_equal("/safe_dir/subdir/../../unsafe_file.txt", "/safe_dir/") +def create_path_string(): + return st.lists( + st.one_of( + st.text( + alphabet="abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_-", + min_size=1, + ), + st.just(".."), + st.just("."), + ), + min_size=1, + max_size=10, # Limit depth to avoid excessively long paths + ).map(lambda x: os.path.join(*x)) + + +def my_check(path_1, path_2): + try: + path_1 = Path(path_1).resolve() + path_2 = Path(path_2).resolve() + _ = path_1.relative_to(path_2) + return True + except ValueError: + return False + + +@settings(derandomize=os.getenv("CI") is not None) +@given( + path_1=create_path_string(), + path_2=create_path_string(), +) +def test_is_in_or_equal_fuzzer(path_1, path_2): + try: + # Convert to absolute paths + abs_path_1 = abspath(path_1) + abs_path_2 = abspath(path_2) + result = is_in_or_equal(abs_path_1, abs_path_2) + assert result == my_check(abs_path_1, abs_path_2) + + except Exception as e: + pytest.fail(f"Exception raised: {e}") + + +# Additional test for known edge cases +@pytest.mark.parametrize( + "path_1,path_2,expected", + [ + ("/AAA/a/../a", "/AAA", True), + ("//AA/a", "/tmp", False), + ("/AAA/..", "/AAA", False), + ("/a/b/c", "/d/e/f", False), + (".", "..", True), + ("..", ".", False), + ("/a/b/./c", "/a/b", True), + ("/a/b/../c", "/a", True), + ("/a/b/c", "/a/b/c/../d", False), + ("/", "/a", False), + ("/a", "/", True), + ], +) +def test_is_in_or_equal_edge_cases(path_1, path_2, expected): + assert is_in_or_equal(path_1, path_2) == expected + + @pytest.mark.parametrize( "path_or_url, extension", [