Ensure JSON component outputs handled properly in postprocess (#8292)

* Add code

* Json postprocess

* add changeset

* add changeset

* Fix json tests

* fix flag

* Address comments

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
Freddy Boulton 2024-05-15 17:38:50 -04:00 committed by GitHub
parent 929d216d49
commit ee1e2942e0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 100 additions and 16 deletions

View File

@ -0,0 +1,5 @@
---
"gradio": patch
---
fix:Ensure JSON component outputs handled properly in postprocess

View File

@ -19,7 +19,7 @@ import gradio_client.utils as client_utils
from gradio import utils
from gradio.blocks import Block, BlockContext
from gradio.component_meta import ComponentMeta
from gradio.data_classes import GradioDataModel
from gradio.data_classes import GradioDataModel, JsonData
from gradio.events import EventListener
from gradio.layouts import Form
from gradio.processing_utils import move_files_to_cache
@ -298,6 +298,8 @@ class Component(ComponentBase, Block):
payload = self.data_model.from_json(payload)
Path(flag_dir).mkdir(exist_ok=True)
payload = payload.copy_to_dir(flag_dir).model_dump()
if isinstance(payload, JsonData):
payload = payload.model_dump()
if not isinstance(payload, str):
payload = json.dumps(payload)
return payload

View File

@ -9,6 +9,7 @@ import orjson
from gradio_client.documentation import document
from gradio.components.base import Component
from gradio.data_classes import JsonData
from gradio.events import Events
@ -77,7 +78,7 @@ class JSON(Component):
"""
return payload
def postprocess(self, value: dict | list | str | None) -> dict | list | None:
def postprocess(self, value: dict | list | str | None) -> JsonData | None:
"""
Parameters:
value: Expects a valid JSON `str` -- or a `list` or `dict` that can be serialized to a JSON string. The `list` or `dict` value can contain numpy arrays.
@ -87,16 +88,19 @@ class JSON(Component):
if value is None:
return None
if isinstance(value, str):
return orjson.loads(value)
return JsonData(orjson.loads(value))
else:
# Use orjson to convert NumPy arrays and datetime objects to JSON.
# This ensures a backward compatibility with the previous behavior.
# See https://github.com/gradio-app/gradio/pull/8041
return orjson.loads(
orjson.dumps(
value,
option=orjson.OPT_SERIALIZE_NUMPY | orjson.OPT_PASSTHROUGH_DATETIME,
default=str,
return JsonData(
orjson.loads(
orjson.dumps(
value,
option=orjson.OPT_SERIALIZE_NUMPY
| orjson.OPT_PASSTHROUGH_DATETIME,
default=str,
)
)
)

View File

@ -16,7 +16,7 @@ from gradio_client.utils import traverse
from . import wasm_utils
if not wasm_utils.IS_WASM or TYPE_CHECKING:
from pydantic import BaseModel, RootModel, ValidationError
from pydantic import BaseModel, JsonValue, RootModel, ValidationError
else:
# XXX: Currently Pyodide V2 is not available on Pyodide,
# so we install V1 for the Wasm version.
@ -25,6 +25,8 @@ else:
from pydantic import BaseModel as BaseModelV1
from pydantic import ValidationError, schema_of
JsonValue = Any
# Map V2 method calls to V1 implementations.
# Ref: https://docs.pydantic.dev/latest/migration/#changes-to-pydanticbasemodel
class BaseModelMeta(type(BaseModelV1)):
@ -161,6 +163,12 @@ class GradioBaseModel(ABC):
pass
class JsonData(RootModel):
"""JSON data returned from a component that should not be modified further."""
root: JsonValue
class GradioModel(GradioBaseModel, BaseModel):
@classmethod
def from_json(cls, x) -> GradioModel:

View File

@ -933,6 +933,7 @@ def special_args(
):
event_data_index = i
if inputs is not None and event_data is not None:
processing_utils.check_all_files_in_cache(event_data._data)
inputs.insert(i, type_hint(event_data.target, event_data._data))
elif (
param.default is not param.empty and inputs is not None and len(inputs) <= i

View File

@ -20,7 +20,8 @@ from gradio_client import utils as client_utils
from PIL import Image, ImageOps, PngImagePlugin
from gradio import utils, wasm_utils
from gradio.data_classes import FileData, GradioModel, GradioRootModel
from gradio.data_classes import FileData, GradioModel, GradioRootModel, JsonData
from gradio.exceptions import Error
from gradio.utils import abspath, get_upload_folder, is_in_or_equal
with warnings.catch_warnings():
@ -341,6 +342,20 @@ def move_resource_to_block_cache(
return block.move_resource_to_block_cache(url_or_file_path)
def check_all_files_in_cache(data: JsonData):
def _in_cache(d: dict):
if (
(path := d.get("path", ""))
and not client_utils.is_http_url_like(path)
and not is_in_or_equal(path, get_upload_folder())
):
raise Error(
f"File {path} is not in the cache folder and cannot be accessed."
)
client_utils.traverse(data, _in_cache, client_utils.is_file_obj)
def move_files_to_cache(
data: Any,
block: Block,
@ -475,7 +490,6 @@ async def async_move_files_to_cache(
if isinstance(data, (GradioRootModel, GradioModel)):
data = data.model_dump()
return await client_utils.async_traverse(
data, _move_to_cache, client_utils.is_file_obj
)

View File

@ -75,7 +75,7 @@ class TestJSON:
await iface.process_api(
0, [{"data": y_data, "headers": ["gender", "age"]}], state={}
)
)["data"][0] == {
)["data"][0].model_dump() == {
"M": 35,
"F": 25,
"O": 20,
@ -95,5 +95,8 @@ class TestJSON:
def test_postprocess_returns_json_serializable_value(self, value, expected):
json_component = gr.JSON()
postprocessed_value = json_component.postprocess(value)
assert postprocessed_value == expected
assert json.loads(json.dumps(postprocessed_value)) == expected
if postprocessed_value is None:
assert value is None
else:
assert postprocessed_value.model_dump() == expected
assert json.loads(json.dumps(postprocessed_value.model_dump())) == expected

View File

@ -17,7 +17,7 @@ from starlette.testclient import TestClient
from tqdm import tqdm
import gradio as gr
from gradio import utils
from gradio import helpers, utils
@patch("gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp()))
@ -879,3 +879,21 @@ async def test_info_isolation(async_handler: bool):
assert alice_logs == "Hello Alice"
assert bob_logs == "Hello Bob"
def test_check_event_data_in_cache():
def get_select_index(evt: gr.SelectData):
return evt.index
with pytest.raises(gr.Error):
helpers.special_args(
get_select_index,
inputs=[],
event_data=helpers.EventData(
None,
{
"index": {"path": "foo", "meta": {"_type": "gradio.FileData"}},
"value": "whatever",
},
),
)

View File

@ -11,7 +11,7 @@ import pytest
from gradio_client import media_data
from PIL import Image, ImageCms
from gradio import processing_utils, utils
from gradio import components, data_classes, processing_utils, utils
class TestTempFileManagement:
@ -382,3 +382,32 @@ def test_hash_url_encodes_url():
assert processing_utils.hash_url(
"https://www.gradio.app/image 1.jpg"
) == processing_utils.hash_bytes(b"https://www.gradio.app/image 1.jpg")
@pytest.mark.asyncio
async def test_json_data_not_moved_to_cache():
data = data_classes.JsonData(
root={
"file": {
"path": "path",
"url": "/file=path",
"meta": {"_type": "gradio.FileData"},
}
}
)
assert (
processing_utils.move_files_to_cache(data, components.Number(), False) == data
)
assert processing_utils.move_files_to_cache(data, components.Number(), True) == data
assert (
await processing_utils.async_move_files_to_cache(
data, components.Number(), False
)
== data
)
assert (
await processing_utils.async_move_files_to_cache(
data, components.Number(), True
)
== data
)