Fix bug in reload mode equality check. Better equality conversion for state variables (#8385)

* Add code

* Add deep equality

* add changeset

* Add code

* add changeset

* Update gradio/utils.py

Co-authored-by: Abubakar Abid <abubakar@huggingface.co>

* Add code

* Add code

* add code

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
This commit is contained in:
Freddy Boulton 2024-05-29 10:46:09 -04:00 committed by GitHub
parent e738e26a5d
commit 97ac79bf56
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 42 additions and 3 deletions

View File

@ -0,0 +1,5 @@
---
"gradio": patch
---
fix:Fix bug in reload mode equality check. Better equality conversion for state variables

View File

@ -1726,7 +1726,15 @@ Received outputs:
if block.stateful: if block.stateful:
if not utils.is_update(predictions[i]): if not utils.is_update(predictions[i]):
if block._id not in state or state[block._id] != predictions[i]: has_change_event = False
for dep in state.blocks_config.fns.values():
if block._id in [t[0] for t in dep.targets if t[1] == "change"]:
has_change_event = True
break
if has_change_event and (
block._id not in state
or not utils.deep_equal(state[block._id], predictions[i])
):
changed_state_ids.append(block._id) changed_state_ids.append(block._id)
state[block._id] = predictions[i] state[block._id] = predictions[i]
output.append(None) output.append(None)

View File

@ -46,6 +46,7 @@ from typing import (
import anyio import anyio
import gradio_client.utils as client_utils import gradio_client.utils as client_utils
import httpx import httpx
import orjson
from gradio_client.documentation import document from gradio_client.documentation import document
from typing_extensions import ParamSpec from typing_extensions import ParamSpec
@ -290,6 +291,29 @@ def watchfn(reloader: SourceFileReloader):
time.sleep(0.05) time.sleep(0.05)
def deep_equal(a: Any, b: Any) -> bool:
"""
Deep equality check for component values.
Prefer orjson for performance and compatibility with numpy arrays/dataframes/torch tensors.
If objects are not serializable by orjson, fall back to regular equality check.
"""
def _serialize(a: Any) -> bytes:
return orjson.dumps(
a,
option=orjson.OPT_SERIALIZE_NUMPY | orjson.OPT_PASSTHROUGH_DATETIME,
)
try:
return _serialize(a) == _serialize(b)
except TypeError:
try:
return a == b
except Exception:
return False
def reassign_keys(old_blocks: Blocks, new_blocks: Blocks): def reassign_keys(old_blocks: Blocks, new_blocks: Blocks):
from gradio.blocks import BlockContext from gradio.blocks import BlockContext
@ -310,8 +334,10 @@ def reassign_keys(old_blocks: Blocks, new_blocks: Blocks):
old_block.__class__ == new_block.__class__ old_block.__class__ == new_block.__class__
and old_block is not None and old_block is not None
and old_block.key not in assigned_keys and old_block.key not in assigned_keys
and json.dumps(getattr(old_block, "value", None)) and deep_equal(
== json.dumps(getattr(new_block, "value", None)) getattr(old_block, "value", None),
getattr(new_block, "value", None),
)
): ):
new_block.key = old_block.key new_block.key = old_block.key
else: else: