diff --git a/.changeset/gentle-socks-pay.md b/.changeset/gentle-socks-pay.md new file mode 100644 index 0000000000..acf4fb82d6 --- /dev/null +++ b/.changeset/gentle-socks-pay.md @@ -0,0 +1,5 @@ +--- +"gradio": patch +--- + +fix:Ensure JSON component outputs handled properly in postprocess diff --git a/gradio/components/base.py b/gradio/components/base.py index 99ecdbd515..7bdb81974b 100644 --- a/gradio/components/base.py +++ b/gradio/components/base.py @@ -19,7 +19,7 @@ import gradio_client.utils as client_utils from gradio import utils from gradio.blocks import Block, BlockContext from gradio.component_meta import ComponentMeta -from gradio.data_classes import GradioDataModel +from gradio.data_classes import GradioDataModel, JsonData from gradio.events import EventListener from gradio.layouts import Form from gradio.processing_utils import move_files_to_cache @@ -298,6 +298,8 @@ class Component(ComponentBase, Block): payload = self.data_model.from_json(payload) Path(flag_dir).mkdir(exist_ok=True) payload = payload.copy_to_dir(flag_dir).model_dump() + if isinstance(payload, JsonData): + payload = payload.model_dump() if not isinstance(payload, str): payload = json.dumps(payload) return payload diff --git a/gradio/components/json_component.py b/gradio/components/json_component.py index 1d50074c98..6f70d4e87e 100644 --- a/gradio/components/json_component.py +++ b/gradio/components/json_component.py @@ -9,6 +9,7 @@ import orjson from gradio_client.documentation import document from gradio.components.base import Component +from gradio.data_classes import JsonData from gradio.events import Events @@ -77,7 +78,7 @@ class JSON(Component): """ return payload - def postprocess(self, value: dict | list | str | None) -> dict | list | None: + def postprocess(self, value: dict | list | str | None) -> JsonData | None: """ Parameters: value: Expects a valid JSON `str` -- or a `list` or `dict` that can be serialized to a JSON string. The `list` or `dict` value can contain numpy arrays. @@ -87,16 +88,19 @@ class JSON(Component): if value is None: return None if isinstance(value, str): - return orjson.loads(value) + return JsonData(orjson.loads(value)) else: # Use orjson to convert NumPy arrays and datetime objects to JSON. # This ensures a backward compatibility with the previous behavior. # See https://github.com/gradio-app/gradio/pull/8041 - return orjson.loads( - orjson.dumps( - value, - option=orjson.OPT_SERIALIZE_NUMPY | orjson.OPT_PASSTHROUGH_DATETIME, - default=str, + return JsonData( + orjson.loads( + orjson.dumps( + value, + option=orjson.OPT_SERIALIZE_NUMPY + | orjson.OPT_PASSTHROUGH_DATETIME, + default=str, + ) ) ) diff --git a/gradio/data_classes.py b/gradio/data_classes.py index e82da146bc..bcdd12f851 100644 --- a/gradio/data_classes.py +++ b/gradio/data_classes.py @@ -16,7 +16,7 @@ from gradio_client.utils import traverse from . import wasm_utils if not wasm_utils.IS_WASM or TYPE_CHECKING: - from pydantic import BaseModel, RootModel, ValidationError + from pydantic import BaseModel, JsonValue, RootModel, ValidationError else: # XXX: Currently Pyodide V2 is not available on Pyodide, # so we install V1 for the Wasm version. @@ -25,6 +25,8 @@ else: from pydantic import BaseModel as BaseModelV1 from pydantic import ValidationError, schema_of + JsonValue = Any + # Map V2 method calls to V1 implementations. # Ref: https://docs.pydantic.dev/latest/migration/#changes-to-pydanticbasemodel class BaseModelMeta(type(BaseModelV1)): @@ -161,6 +163,12 @@ class GradioBaseModel(ABC): pass +class JsonData(RootModel): + """JSON data returned from a component that should not be modified further.""" + + root: JsonValue + + class GradioModel(GradioBaseModel, BaseModel): @classmethod def from_json(cls, x) -> GradioModel: diff --git a/gradio/helpers.py b/gradio/helpers.py index 4c9c7b57ad..ce63a36632 100644 --- a/gradio/helpers.py +++ b/gradio/helpers.py @@ -933,6 +933,7 @@ def special_args( ): event_data_index = i if inputs is not None and event_data is not None: + processing_utils.check_all_files_in_cache(event_data._data) inputs.insert(i, type_hint(event_data.target, event_data._data)) elif ( param.default is not param.empty and inputs is not None and len(inputs) <= i diff --git a/gradio/processing_utils.py b/gradio/processing_utils.py index 40ac2b967a..3c0d4b7bd5 100644 --- a/gradio/processing_utils.py +++ b/gradio/processing_utils.py @@ -20,7 +20,8 @@ from gradio_client import utils as client_utils from PIL import Image, ImageOps, PngImagePlugin from gradio import utils, wasm_utils -from gradio.data_classes import FileData, GradioModel, GradioRootModel +from gradio.data_classes import FileData, GradioModel, GradioRootModel, JsonData +from gradio.exceptions import Error from gradio.utils import abspath, get_upload_folder, is_in_or_equal with warnings.catch_warnings(): @@ -341,6 +342,20 @@ def move_resource_to_block_cache( return block.move_resource_to_block_cache(url_or_file_path) +def check_all_files_in_cache(data: JsonData): + def _in_cache(d: dict): + if ( + (path := d.get("path", "")) + and not client_utils.is_http_url_like(path) + and not is_in_or_equal(path, get_upload_folder()) + ): + raise Error( + f"File {path} is not in the cache folder and cannot be accessed." + ) + + client_utils.traverse(data, _in_cache, client_utils.is_file_obj) + + def move_files_to_cache( data: Any, block: Block, @@ -475,7 +490,6 @@ async def async_move_files_to_cache( if isinstance(data, (GradioRootModel, GradioModel)): data = data.model_dump() - return await client_utils.async_traverse( data, _move_to_cache, client_utils.is_file_obj ) diff --git a/test/components/test_json.py b/test/components/test_json.py index 722dc809b8..cfa6430a58 100644 --- a/test/components/test_json.py +++ b/test/components/test_json.py @@ -75,7 +75,7 @@ class TestJSON: await iface.process_api( 0, [{"data": y_data, "headers": ["gender", "age"]}], state={} ) - )["data"][0] == { + )["data"][0].model_dump() == { "M": 35, "F": 25, "O": 20, @@ -95,5 +95,8 @@ class TestJSON: def test_postprocess_returns_json_serializable_value(self, value, expected): json_component = gr.JSON() postprocessed_value = json_component.postprocess(value) - assert postprocessed_value == expected - assert json.loads(json.dumps(postprocessed_value)) == expected + if postprocessed_value is None: + assert value is None + else: + assert postprocessed_value.model_dump() == expected + assert json.loads(json.dumps(postprocessed_value.model_dump())) == expected diff --git a/test/test_helpers.py b/test/test_helpers.py index f25043d1b3..2a31cbd7f1 100644 --- a/test/test_helpers.py +++ b/test/test_helpers.py @@ -17,7 +17,7 @@ from starlette.testclient import TestClient from tqdm import tqdm import gradio as gr -from gradio import utils +from gradio import helpers, utils @patch("gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp())) @@ -879,3 +879,21 @@ async def test_info_isolation(async_handler: bool): assert alice_logs == "Hello Alice" assert bob_logs == "Hello Bob" + + +def test_check_event_data_in_cache(): + def get_select_index(evt: gr.SelectData): + return evt.index + + with pytest.raises(gr.Error): + helpers.special_args( + get_select_index, + inputs=[], + event_data=helpers.EventData( + None, + { + "index": {"path": "foo", "meta": {"_type": "gradio.FileData"}}, + "value": "whatever", + }, + ), + ) diff --git a/test/test_processing_utils.py b/test/test_processing_utils.py index 1fa5c03ab5..2369fd9a77 100644 --- a/test/test_processing_utils.py +++ b/test/test_processing_utils.py @@ -11,7 +11,7 @@ import pytest from gradio_client import media_data from PIL import Image, ImageCms -from gradio import processing_utils, utils +from gradio import components, data_classes, processing_utils, utils class TestTempFileManagement: @@ -382,3 +382,32 @@ def test_hash_url_encodes_url(): assert processing_utils.hash_url( "https://www.gradio.app/image 1.jpg" ) == processing_utils.hash_bytes(b"https://www.gradio.app/image 1.jpg") + + +@pytest.mark.asyncio +async def test_json_data_not_moved_to_cache(): + data = data_classes.JsonData( + root={ + "file": { + "path": "path", + "url": "/file=path", + "meta": {"_type": "gradio.FileData"}, + } + } + ) + assert ( + processing_utils.move_files_to_cache(data, components.Number(), False) == data + ) + assert processing_utils.move_files_to_cache(data, components.Number(), True) == data + assert ( + await processing_utils.async_move_files_to_cache( + data, components.Number(), False + ) + == data + ) + assert ( + await processing_utils.async_move_files_to_cache( + data, components.Number(), True + ) + == data + )