mirror of
https://github.com/gradio-app/gradio.git
synced 2024-11-27 01:40:20 +08:00
Fix bug in reload mode equality check. Better equality conversion for state variables (#8385)
* Add code * Add deep equality * add changeset * Add code * add changeset * Update gradio/utils.py Co-authored-by: Abubakar Abid <abubakar@huggingface.co> * Add code * Add code * add code --------- Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com> Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
This commit is contained in:
parent
e738e26a5d
commit
97ac79bf56
5
.changeset/ripe-tools-jam.md
Normal file
5
.changeset/ripe-tools-jam.md
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
---
|
||||||
|
"gradio": patch
|
||||||
|
---
|
||||||
|
|
||||||
|
fix:Fix bug in reload mode equality check. Better equality conversion for state variables
|
@ -1726,7 +1726,15 @@ Received outputs:
|
|||||||
|
|
||||||
if block.stateful:
|
if block.stateful:
|
||||||
if not utils.is_update(predictions[i]):
|
if not utils.is_update(predictions[i]):
|
||||||
if block._id not in state or state[block._id] != predictions[i]:
|
has_change_event = False
|
||||||
|
for dep in state.blocks_config.fns.values():
|
||||||
|
if block._id in [t[0] for t in dep.targets if t[1] == "change"]:
|
||||||
|
has_change_event = True
|
||||||
|
break
|
||||||
|
if has_change_event and (
|
||||||
|
block._id not in state
|
||||||
|
or not utils.deep_equal(state[block._id], predictions[i])
|
||||||
|
):
|
||||||
changed_state_ids.append(block._id)
|
changed_state_ids.append(block._id)
|
||||||
state[block._id] = predictions[i]
|
state[block._id] = predictions[i]
|
||||||
output.append(None)
|
output.append(None)
|
||||||
|
@ -46,6 +46,7 @@ from typing import (
|
|||||||
import anyio
|
import anyio
|
||||||
import gradio_client.utils as client_utils
|
import gradio_client.utils as client_utils
|
||||||
import httpx
|
import httpx
|
||||||
|
import orjson
|
||||||
from gradio_client.documentation import document
|
from gradio_client.documentation import document
|
||||||
from typing_extensions import ParamSpec
|
from typing_extensions import ParamSpec
|
||||||
|
|
||||||
@ -290,6 +291,29 @@ def watchfn(reloader: SourceFileReloader):
|
|||||||
time.sleep(0.05)
|
time.sleep(0.05)
|
||||||
|
|
||||||
|
|
||||||
|
def deep_equal(a: Any, b: Any) -> bool:
|
||||||
|
"""
|
||||||
|
Deep equality check for component values.
|
||||||
|
|
||||||
|
Prefer orjson for performance and compatibility with numpy arrays/dataframes/torch tensors.
|
||||||
|
If objects are not serializable by orjson, fall back to regular equality check.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _serialize(a: Any) -> bytes:
|
||||||
|
return orjson.dumps(
|
||||||
|
a,
|
||||||
|
option=orjson.OPT_SERIALIZE_NUMPY | orjson.OPT_PASSTHROUGH_DATETIME,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
return _serialize(a) == _serialize(b)
|
||||||
|
except TypeError:
|
||||||
|
try:
|
||||||
|
return a == b
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def reassign_keys(old_blocks: Blocks, new_blocks: Blocks):
|
def reassign_keys(old_blocks: Blocks, new_blocks: Blocks):
|
||||||
from gradio.blocks import BlockContext
|
from gradio.blocks import BlockContext
|
||||||
|
|
||||||
@ -310,8 +334,10 @@ def reassign_keys(old_blocks: Blocks, new_blocks: Blocks):
|
|||||||
old_block.__class__ == new_block.__class__
|
old_block.__class__ == new_block.__class__
|
||||||
and old_block is not None
|
and old_block is not None
|
||||||
and old_block.key not in assigned_keys
|
and old_block.key not in assigned_keys
|
||||||
and json.dumps(getattr(old_block, "value", None))
|
and deep_equal(
|
||||||
== json.dumps(getattr(new_block, "value", None))
|
getattr(old_block, "value", None),
|
||||||
|
getattr(new_block, "value", None),
|
||||||
|
)
|
||||||
):
|
):
|
||||||
new_block.key = old_block.key
|
new_block.key = old_block.key
|
||||||
else:
|
else:
|
||||||
|
Loading…
Reference in New Issue
Block a user