mirror of
https://github.com/gradio-app/gradio.git
synced 2024-11-21 01:01:05 +08:00
Fix bug in reload mode equality check. Better equality conversion for state variables (#8385)
* Add code * Add deep equality * add changeset * Add code * add changeset * Update gradio/utils.py Co-authored-by: Abubakar Abid <abubakar@huggingface.co> * Add code * Add code * add code --------- Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com> Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
This commit is contained in:
parent
e738e26a5d
commit
97ac79bf56
5
.changeset/ripe-tools-jam.md
Normal file
5
.changeset/ripe-tools-jam.md
Normal file
@ -0,0 +1,5 @@
|
||||
---
|
||||
"gradio": patch
|
||||
---
|
||||
|
||||
fix:Fix bug in reload mode equality check. Better equality conversion for state variables
|
@ -1726,7 +1726,15 @@ Received outputs:
|
||||
|
||||
if block.stateful:
|
||||
if not utils.is_update(predictions[i]):
|
||||
if block._id not in state or state[block._id] != predictions[i]:
|
||||
has_change_event = False
|
||||
for dep in state.blocks_config.fns.values():
|
||||
if block._id in [t[0] for t in dep.targets if t[1] == "change"]:
|
||||
has_change_event = True
|
||||
break
|
||||
if has_change_event and (
|
||||
block._id not in state
|
||||
or not utils.deep_equal(state[block._id], predictions[i])
|
||||
):
|
||||
changed_state_ids.append(block._id)
|
||||
state[block._id] = predictions[i]
|
||||
output.append(None)
|
||||
|
@ -46,6 +46,7 @@ from typing import (
|
||||
import anyio
|
||||
import gradio_client.utils as client_utils
|
||||
import httpx
|
||||
import orjson
|
||||
from gradio_client.documentation import document
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
@ -290,6 +291,29 @@ def watchfn(reloader: SourceFileReloader):
|
||||
time.sleep(0.05)
|
||||
|
||||
|
||||
def deep_equal(a: Any, b: Any) -> bool:
|
||||
"""
|
||||
Deep equality check for component values.
|
||||
|
||||
Prefer orjson for performance and compatibility with numpy arrays/dataframes/torch tensors.
|
||||
If objects are not serializable by orjson, fall back to regular equality check.
|
||||
"""
|
||||
|
||||
def _serialize(a: Any) -> bytes:
|
||||
return orjson.dumps(
|
||||
a,
|
||||
option=orjson.OPT_SERIALIZE_NUMPY | orjson.OPT_PASSTHROUGH_DATETIME,
|
||||
)
|
||||
|
||||
try:
|
||||
return _serialize(a) == _serialize(b)
|
||||
except TypeError:
|
||||
try:
|
||||
return a == b
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def reassign_keys(old_blocks: Blocks, new_blocks: Blocks):
|
||||
from gradio.blocks import BlockContext
|
||||
|
||||
@ -310,8 +334,10 @@ def reassign_keys(old_blocks: Blocks, new_blocks: Blocks):
|
||||
old_block.__class__ == new_block.__class__
|
||||
and old_block is not None
|
||||
and old_block.key not in assigned_keys
|
||||
and json.dumps(getattr(old_block, "value", None))
|
||||
== json.dumps(getattr(new_block, "value", None))
|
||||
and deep_equal(
|
||||
getattr(old_block, "value", None),
|
||||
getattr(new_block, "value", None),
|
||||
)
|
||||
):
|
||||
new_block.key = old_block.key
|
||||
else:
|
||||
|
Loading…
Reference in New Issue
Block a user