mirror of
https://github.com/gradio-app/gradio.git
synced 2025-01-30 11:00:11 +08:00
Fix state changes within a gr.render (#10095)
* changes * add changeset --------- Co-authored-by: Ali Abid <aliabid94@gmail.com> Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com> Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
This commit is contained in:
parent
de42c85661
commit
97d647ecfd
5
.changeset/evil-streets-hunt.md
Normal file
5
.changeset/evil-streets-hunt.md
Normal file
@ -0,0 +1,5 @@
|
||||
---
|
||||
"gradio": patch
|
||||
---
|
||||
|
||||
fix:Fix state changes within a gr.render
|
@ -1 +1 @@
|
||||
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: state_change"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "\n", "with gr.Blocks() as demo:\n", "\n", " with gr.Row():\n", " state_a = gr.State(0)\n", " btn_a = gr.Button(\"Increment A\")\n", " value_a = gr.Number(label=\"Number A\")\n", " btn_a.click(lambda x: x + 1, state_a, state_a)\n", " state_a.change(lambda x: x, state_a, value_a)\n", " with gr.Row():\n", " state_b = gr.State(0)\n", " btn_b = gr.Button(\"Increment B\")\n", " value_b = gr.Number(label=\"Number B\")\n", " btn_b.click(lambda x: x + 1, state_b, state_b)\n", "\n", " @gr.on(inputs=state_b, outputs=value_b)\n", " def identity(x):\n", " return x\n", "\n", " @gr.render(inputs=[state_a, state_b])\n", " def render(a, b):\n", " for x in range(a):\n", " with gr.Row():\n", " for y in range(b):\n", " gr.Button(f\"Button {x}, {y}\")\n", "\n", " list_state = gr.State([])\n", " dict_state = gr.State(dict())\n", " nested_list_state = gr.State([])\n", " set_state = gr.State(set())\n", "\n", " def transform_list(x):\n", " return {n: n for n in x}, [x[:] for _ in range(len(x))], set(x)\n", "\n", " list_state.change(\n", " transform_list,\n", " inputs=list_state,\n", " outputs=[dict_state, nested_list_state, set_state],\n", " )\n", "\n", " all_textbox = gr.Textbox(label=\"Output\")\n", " click_count = gr.Number(label=\"Clicks\")\n", " change_count = gr.Number(label=\"Changes\")\n", " gr.on(\n", " inputs=[change_count, dict_state, nested_list_state, set_state],\n", " triggers=[dict_state.change, nested_list_state.change, set_state.change],\n", " fn=lambda x, *args: (x + 1, \"\\n\".join(str(arg) for arg in args)),\n", " outputs=[change_count, all_textbox],\n", " )\n", "\n", " count_to_3_btn = gr.Button(\"Count to 3\")\n", " count_to_3_btn.click(lambda: [1, 2, 3], outputs=list_state)\n", " zero_all_btn = gr.Button(\"Zero All\")\n", " zero_all_btn.click(lambda x: [0] * len(x), inputs=list_state, outputs=list_state)\n", "\n", " gr.on(\n", " [count_to_3_btn.click, zero_all_btn.click],\n", " lambda x: x + 1,\n", " click_count,\n", " click_count,\n", " )\n", "\n", " async def increment(x):\n", " yield x + 1\n", "\n", " n_text = gr.State(0)\n", " add_btn = gr.Button(\"Iterator State Change\")\n", " add_btn.click(increment, n_text, n_text)\n", "\n", " @gr.render(inputs=n_text)\n", " def render_count(count):\n", " for i in range(int(count)):\n", " gr.Markdown(value = f\"Success Box {i} added\", key=i)\n", " \n", " class CustomState():\n", " def __init__(self, val):\n", " self.val = val\n", "\n", " def __hash__(self) -> int:\n", " return self.val\n", "\n", " custom_state = gr.State(CustomState(5))\n", " with gr.Row():\n", " btn_10 = gr.Button(\"Set State to 10\")\n", " custom_changes = gr.Number(0, label=\"Custom State Changes\")\n", " custom_clicks = gr.Number(0, label=\"Custom State Clicks\")\n", "\n", " custom_state.change(increment, custom_changes, custom_changes)\n", " def set_to_10(cs: CustomState):\n", " cs.val = 10\n", " return cs\n", "\n", " btn_10.click(set_to_10, custom_state, custom_state).then(\n", " increment, custom_clicks, custom_clicks\n", " )\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
|
||||
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: state_change"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "\n", "with gr.Blocks() as demo:\n", "\n", " with gr.Row():\n", " state_a = gr.State(0)\n", " btn_a = gr.Button(\"Increment A\")\n", " value_a = gr.Number(label=\"Number A\")\n", " btn_a.click(lambda x: x + 1, state_a, state_a)\n", " state_a.change(lambda x: x, state_a, value_a)\n", " with gr.Row():\n", " state_b = gr.State(0)\n", " btn_b = gr.Button(\"Increment B\")\n", " value_b = gr.Number(label=\"Number B\")\n", " btn_b.click(lambda x: x + 1, state_b, state_b)\n", "\n", " @gr.on(inputs=state_b, outputs=value_b)\n", " def identity(x):\n", " return x\n", "\n", " @gr.render(inputs=[state_a, state_b])\n", " def render(a, b):\n", " for x in range(a):\n", " with gr.Row():\n", " for y in range(b):\n", " gr.Button(f\"Button {x}, {y}\")\n", "\n", " list_state = gr.State([])\n", " dict_state = gr.State(dict())\n", " nested_list_state = gr.State([])\n", " set_state = gr.State(set())\n", "\n", " def transform_list(x):\n", " return {n: n for n in x}, [x[:] for _ in range(len(x))], set(x)\n", "\n", " list_state.change(\n", " transform_list,\n", " inputs=list_state,\n", " outputs=[dict_state, nested_list_state, set_state],\n", " )\n", "\n", " all_textbox = gr.Textbox(label=\"Output\")\n", " click_count = gr.Number(label=\"Clicks\")\n", " change_count = gr.Number(label=\"Changes\")\n", " gr.on(\n", " inputs=[change_count, dict_state, nested_list_state, set_state],\n", " triggers=[dict_state.change, nested_list_state.change, set_state.change],\n", " fn=lambda x, *args: (x + 1, \"\\n\".join(str(arg) for arg in args)),\n", " outputs=[change_count, all_textbox],\n", " )\n", "\n", " count_to_3_btn = gr.Button(\"Count to 3\")\n", " count_to_3_btn.click(lambda: [1, 2, 3], outputs=list_state)\n", " zero_all_btn = gr.Button(\"Zero All\")\n", " zero_all_btn.click(lambda x: [0] * len(x), inputs=list_state, outputs=list_state)\n", "\n", " gr.on(\n", " [count_to_3_btn.click, zero_all_btn.click],\n", " lambda x: x + 1,\n", " click_count,\n", " click_count,\n", " )\n", "\n", " async def increment(x):\n", " yield x + 1\n", "\n", " n_text = gr.State(0)\n", " add_btn = gr.Button(\"Iterator State Change\")\n", " add_btn.click(increment, n_text, n_text)\n", "\n", " @gr.render(inputs=n_text)\n", " def render_count(count):\n", " for i in range(int(count)):\n", " gr.Markdown(value = f\"Success Box {i} added\", key=i)\n", " \n", " class CustomState():\n", " def __init__(self, val):\n", " self.val = val\n", "\n", " def __hash__(self) -> int:\n", " return self.val\n", "\n", " custom_state = gr.State(CustomState(5))\n", " with gr.Row():\n", " btn_10 = gr.Button(\"Set State to 10\")\n", " custom_changes = gr.Number(0, label=\"Custom State Changes\")\n", " custom_clicks = gr.Number(0, label=\"Custom State Clicks\")\n", "\n", " custom_state.change(increment, custom_changes, custom_changes)\n", " def set_to_10(cs: CustomState):\n", " cs.val = 10\n", " return cs\n", "\n", " btn_10.click(set_to_10, custom_state, custom_state).then(\n", " increment, custom_clicks, custom_clicks\n", " )\n", "\n", " @gr.render()\n", " def render_state_changes():\n", " with gr.Row():\n", " box1 = gr.Textbox(label=\"Start State\")\n", " state1 = gr.State()\n", " box2 = gr.Textbox()\n", " state2 = gr.State()\n", " box3 = gr.Textbox(label=\"End State\")\n", "\n", " iden = lambda x: x\n", " box1.change(iden, box1, state1)\n", " state1.change(iden, state1, box2)\n", " box2.change(iden, box2, state2)\n", " state2.change(iden, state2, box3)\n", " \n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
|
@ -95,5 +95,20 @@ with gr.Blocks() as demo:
|
||||
increment, custom_clicks, custom_clicks
|
||||
)
|
||||
|
||||
@gr.render()
|
||||
def render_state_changes():
|
||||
with gr.Row():
|
||||
box1 = gr.Textbox(label="Start State")
|
||||
state1 = gr.State()
|
||||
box2 = gr.Textbox()
|
||||
state2 = gr.State()
|
||||
box3 = gr.Textbox(label="End State")
|
||||
|
||||
iden = lambda x: x
|
||||
box1.change(iden, box1, state1)
|
||||
state1.change(iden, state1, box2)
|
||||
box2.change(iden, box2, state2)
|
||||
state2.change(iden, state2, box3)
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo.launch()
|
||||
|
@ -2114,7 +2114,8 @@ Received inputs:
|
||||
hashed_values = []
|
||||
for block in block_fn.outputs:
|
||||
if block.stateful and any(
|
||||
(block._id, "change") in fn.targets for fn in self.fns.values()
|
||||
(block._id, "change") in fn.targets
|
||||
for fn in state.blocks_config.fns.values()
|
||||
):
|
||||
value = state[block._id]
|
||||
state_ids_to_track.append(block._id)
|
||||
|
@ -83,3 +83,9 @@ test("test state change for custom hashes", async ({ page }) => {
|
||||
"1"
|
||||
);
|
||||
});
|
||||
|
||||
test("test state changes work within gr.render", async ({ page }) => {
|
||||
const textbox = await page.getByLabel("Start State");
|
||||
await textbox.fill("test");
|
||||
await expect(page.getByLabel("End State").first()).toHaveValue("test");
|
||||
});
|
||||
|
Loading…
Reference in New Issue
Block a user