mirror of
https://github.com/gradio-app/gradio.git
synced 2025-01-24 10:54:04 +08:00
Ensure JSON component outputs handled properly in postprocess (#8292)
* Add code * Json postprocess * add changeset * add changeset * Fix json tests * fix flag * Address comments --------- Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
parent
929d216d49
commit
ee1e2942e0
5
.changeset/gentle-socks-pay.md
Normal file
5
.changeset/gentle-socks-pay.md
Normal file
@ -0,0 +1,5 @@
|
||||
---
|
||||
"gradio": patch
|
||||
---
|
||||
|
||||
fix:Ensure JSON component outputs handled properly in postprocess
|
@ -19,7 +19,7 @@ import gradio_client.utils as client_utils
|
||||
from gradio import utils
|
||||
from gradio.blocks import Block, BlockContext
|
||||
from gradio.component_meta import ComponentMeta
|
||||
from gradio.data_classes import GradioDataModel
|
||||
from gradio.data_classes import GradioDataModel, JsonData
|
||||
from gradio.events import EventListener
|
||||
from gradio.layouts import Form
|
||||
from gradio.processing_utils import move_files_to_cache
|
||||
@ -298,6 +298,8 @@ class Component(ComponentBase, Block):
|
||||
payload = self.data_model.from_json(payload)
|
||||
Path(flag_dir).mkdir(exist_ok=True)
|
||||
payload = payload.copy_to_dir(flag_dir).model_dump()
|
||||
if isinstance(payload, JsonData):
|
||||
payload = payload.model_dump()
|
||||
if not isinstance(payload, str):
|
||||
payload = json.dumps(payload)
|
||||
return payload
|
||||
|
@ -9,6 +9,7 @@ import orjson
|
||||
from gradio_client.documentation import document
|
||||
|
||||
from gradio.components.base import Component
|
||||
from gradio.data_classes import JsonData
|
||||
from gradio.events import Events
|
||||
|
||||
|
||||
@ -77,7 +78,7 @@ class JSON(Component):
|
||||
"""
|
||||
return payload
|
||||
|
||||
def postprocess(self, value: dict | list | str | None) -> dict | list | None:
|
||||
def postprocess(self, value: dict | list | str | None) -> JsonData | None:
|
||||
"""
|
||||
Parameters:
|
||||
value: Expects a valid JSON `str` -- or a `list` or `dict` that can be serialized to a JSON string. The `list` or `dict` value can contain numpy arrays.
|
||||
@ -87,16 +88,19 @@ class JSON(Component):
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, str):
|
||||
return orjson.loads(value)
|
||||
return JsonData(orjson.loads(value))
|
||||
else:
|
||||
# Use orjson to convert NumPy arrays and datetime objects to JSON.
|
||||
# This ensures a backward compatibility with the previous behavior.
|
||||
# See https://github.com/gradio-app/gradio/pull/8041
|
||||
return orjson.loads(
|
||||
orjson.dumps(
|
||||
value,
|
||||
option=orjson.OPT_SERIALIZE_NUMPY | orjson.OPT_PASSTHROUGH_DATETIME,
|
||||
default=str,
|
||||
return JsonData(
|
||||
orjson.loads(
|
||||
orjson.dumps(
|
||||
value,
|
||||
option=orjson.OPT_SERIALIZE_NUMPY
|
||||
| orjson.OPT_PASSTHROUGH_DATETIME,
|
||||
default=str,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -16,7 +16,7 @@ from gradio_client.utils import traverse
|
||||
from . import wasm_utils
|
||||
|
||||
if not wasm_utils.IS_WASM or TYPE_CHECKING:
|
||||
from pydantic import BaseModel, RootModel, ValidationError
|
||||
from pydantic import BaseModel, JsonValue, RootModel, ValidationError
|
||||
else:
|
||||
# XXX: Currently Pyodide V2 is not available on Pyodide,
|
||||
# so we install V1 for the Wasm version.
|
||||
@ -25,6 +25,8 @@ else:
|
||||
from pydantic import BaseModel as BaseModelV1
|
||||
from pydantic import ValidationError, schema_of
|
||||
|
||||
JsonValue = Any
|
||||
|
||||
# Map V2 method calls to V1 implementations.
|
||||
# Ref: https://docs.pydantic.dev/latest/migration/#changes-to-pydanticbasemodel
|
||||
class BaseModelMeta(type(BaseModelV1)):
|
||||
@ -161,6 +163,12 @@ class GradioBaseModel(ABC):
|
||||
pass
|
||||
|
||||
|
||||
class JsonData(RootModel):
|
||||
"""JSON data returned from a component that should not be modified further."""
|
||||
|
||||
root: JsonValue
|
||||
|
||||
|
||||
class GradioModel(GradioBaseModel, BaseModel):
|
||||
@classmethod
|
||||
def from_json(cls, x) -> GradioModel:
|
||||
|
@ -933,6 +933,7 @@ def special_args(
|
||||
):
|
||||
event_data_index = i
|
||||
if inputs is not None and event_data is not None:
|
||||
processing_utils.check_all_files_in_cache(event_data._data)
|
||||
inputs.insert(i, type_hint(event_data.target, event_data._data))
|
||||
elif (
|
||||
param.default is not param.empty and inputs is not None and len(inputs) <= i
|
||||
|
@ -20,7 +20,8 @@ from gradio_client import utils as client_utils
|
||||
from PIL import Image, ImageOps, PngImagePlugin
|
||||
|
||||
from gradio import utils, wasm_utils
|
||||
from gradio.data_classes import FileData, GradioModel, GradioRootModel
|
||||
from gradio.data_classes import FileData, GradioModel, GradioRootModel, JsonData
|
||||
from gradio.exceptions import Error
|
||||
from gradio.utils import abspath, get_upload_folder, is_in_or_equal
|
||||
|
||||
with warnings.catch_warnings():
|
||||
@ -341,6 +342,20 @@ def move_resource_to_block_cache(
|
||||
return block.move_resource_to_block_cache(url_or_file_path)
|
||||
|
||||
|
||||
def check_all_files_in_cache(data: JsonData):
|
||||
def _in_cache(d: dict):
|
||||
if (
|
||||
(path := d.get("path", ""))
|
||||
and not client_utils.is_http_url_like(path)
|
||||
and not is_in_or_equal(path, get_upload_folder())
|
||||
):
|
||||
raise Error(
|
||||
f"File {path} is not in the cache folder and cannot be accessed."
|
||||
)
|
||||
|
||||
client_utils.traverse(data, _in_cache, client_utils.is_file_obj)
|
||||
|
||||
|
||||
def move_files_to_cache(
|
||||
data: Any,
|
||||
block: Block,
|
||||
@ -475,7 +490,6 @@ async def async_move_files_to_cache(
|
||||
|
||||
if isinstance(data, (GradioRootModel, GradioModel)):
|
||||
data = data.model_dump()
|
||||
|
||||
return await client_utils.async_traverse(
|
||||
data, _move_to_cache, client_utils.is_file_obj
|
||||
)
|
||||
|
@ -75,7 +75,7 @@ class TestJSON:
|
||||
await iface.process_api(
|
||||
0, [{"data": y_data, "headers": ["gender", "age"]}], state={}
|
||||
)
|
||||
)["data"][0] == {
|
||||
)["data"][0].model_dump() == {
|
||||
"M": 35,
|
||||
"F": 25,
|
||||
"O": 20,
|
||||
@ -95,5 +95,8 @@ class TestJSON:
|
||||
def test_postprocess_returns_json_serializable_value(self, value, expected):
|
||||
json_component = gr.JSON()
|
||||
postprocessed_value = json_component.postprocess(value)
|
||||
assert postprocessed_value == expected
|
||||
assert json.loads(json.dumps(postprocessed_value)) == expected
|
||||
if postprocessed_value is None:
|
||||
assert value is None
|
||||
else:
|
||||
assert postprocessed_value.model_dump() == expected
|
||||
assert json.loads(json.dumps(postprocessed_value.model_dump())) == expected
|
||||
|
@ -17,7 +17,7 @@ from starlette.testclient import TestClient
|
||||
from tqdm import tqdm
|
||||
|
||||
import gradio as gr
|
||||
from gradio import utils
|
||||
from gradio import helpers, utils
|
||||
|
||||
|
||||
@patch("gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp()))
|
||||
@ -879,3 +879,21 @@ async def test_info_isolation(async_handler: bool):
|
||||
|
||||
assert alice_logs == "Hello Alice"
|
||||
assert bob_logs == "Hello Bob"
|
||||
|
||||
|
||||
def test_check_event_data_in_cache():
|
||||
def get_select_index(evt: gr.SelectData):
|
||||
return evt.index
|
||||
|
||||
with pytest.raises(gr.Error):
|
||||
helpers.special_args(
|
||||
get_select_index,
|
||||
inputs=[],
|
||||
event_data=helpers.EventData(
|
||||
None,
|
||||
{
|
||||
"index": {"path": "foo", "meta": {"_type": "gradio.FileData"}},
|
||||
"value": "whatever",
|
||||
},
|
||||
),
|
||||
)
|
||||
|
@ -11,7 +11,7 @@ import pytest
|
||||
from gradio_client import media_data
|
||||
from PIL import Image, ImageCms
|
||||
|
||||
from gradio import processing_utils, utils
|
||||
from gradio import components, data_classes, processing_utils, utils
|
||||
|
||||
|
||||
class TestTempFileManagement:
|
||||
@ -382,3 +382,32 @@ def test_hash_url_encodes_url():
|
||||
assert processing_utils.hash_url(
|
||||
"https://www.gradio.app/image 1.jpg"
|
||||
) == processing_utils.hash_bytes(b"https://www.gradio.app/image 1.jpg")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_json_data_not_moved_to_cache():
|
||||
data = data_classes.JsonData(
|
||||
root={
|
||||
"file": {
|
||||
"path": "path",
|
||||
"url": "/file=path",
|
||||
"meta": {"_type": "gradio.FileData"},
|
||||
}
|
||||
}
|
||||
)
|
||||
assert (
|
||||
processing_utils.move_files_to_cache(data, components.Number(), False) == data
|
||||
)
|
||||
assert processing_utils.move_files_to_cache(data, components.Number(), True) == data
|
||||
assert (
|
||||
await processing_utils.async_move_files_to_cache(
|
||||
data, components.Number(), False
|
||||
)
|
||||
== data
|
||||
)
|
||||
assert (
|
||||
await processing_utils.async_move_files_to_cache(
|
||||
data, components.Number(), True
|
||||
)
|
||||
== data
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user