Co-authored-by: Ali Abid <aliabid94@gmail.com>
This commit is contained in:
aliabid94 2024-06-06 03:08:14 -07:00 committed by GitHub
parent a9e6595817
commit 8c18114495
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 17 additions and 3 deletions

View File

@ -1 +1 @@
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: state_change"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "\n", "with gr.Blocks() as demo:\n", "\n", " with gr.Row():\n", " state_a = gr.State(0)\n", " btn_a = gr.Button(\"Increment A\")\n", " value_a = gr.Number(label=\"A\")\n", " btn_a.click(lambda x: x+1, state_a, state_a)\n", " state_a.change(lambda x: x, state_a, value_a)\n", " with gr.Row():\n", " state_b = gr.State(0)\n", " btn_b = gr.Button(\"Increment B\")\n", " value_b = gr.Number(label=\"num\")\n", " btn_b.click(lambda x: x+1, state_b, state_b)\n", "\n", " @gr.on(inputs=state_b, outputs=value_b)\n", " def identity(x):\n", " return x\n", "\n", " @gr.render(inputs=[state_a, state_b])\n", " def render(a, b):\n", " for x in range(a):\n", " with gr.Row():\n", " for y in range(b):\n", " gr.Button(f\"Button {x}, {y}\")\n", "\n", " list_state = gr.State([])\n", " dict_state = gr.State(dict())\n", " nested_list_state = gr.State([])\n", " set_state = gr.State(set())\n", "\n", " def transform_list(x):\n", " return {n: n for n in x}, [x[:] for _ in range(len(x))], set(x)\n", " \n", " list_state.change(\n", " transform_list,\n", " inputs=list_state,\n", " outputs=[dict_state, nested_list_state, set_state],\n", " )\n", "\n", " all_textbox = gr.Textbox(label=\"Output\")\n", " change_count = gr.Number(label=\"Changes\")\n", " gr.on(\n", " inputs=[change_count, dict_state, nested_list_state, set_state],\n", " triggers=[dict_state.change, nested_list_state.change, set_state.change],\n", " fn=lambda x, *args: (x+1, \"\\n\".join(str(arg) for arg in args)),\n", " outputs=[change_count, all_textbox],\n", " )\n", "\n", " count_to_3_btn = gr.Button(\"Count to 3\")\n", " count_to_3_btn.click(lambda: [1, 2, 3], outputs=list_state)\n", " zero_all_btn = gr.Button(\"Zero All\")\n", " zero_all_btn.click(\n", " lambda x: [0] * len(x), inputs=list_state, outputs=list_state\n", " )\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: state_change"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "\n", "with gr.Blocks() as demo:\n", "\n", " with gr.Row():\n", " state_a = gr.State(0)\n", " btn_a = gr.Button(\"Increment A\")\n", " value_a = gr.Number(label=\"A\")\n", " btn_a.click(lambda x: x+1, state_a, state_a)\n", " state_a.change(lambda x: x, state_a, value_a)\n", " with gr.Row():\n", " state_b = gr.State(0)\n", " btn_b = gr.Button(\"Increment B\")\n", " value_b = gr.Number(label=\"num\")\n", " btn_b.click(lambda x: x+1, state_b, state_b)\n", "\n", " @gr.on(inputs=state_b, outputs=value_b)\n", " def identity(x):\n", " return x\n", "\n", " @gr.render(inputs=[state_a, state_b])\n", " def render(a, b):\n", " for x in range(a):\n", " with gr.Row():\n", " for y in range(b):\n", " gr.Button(f\"Button {x}, {y}\")\n", "\n", " list_state = gr.State([])\n", " dict_state = gr.State(dict())\n", " nested_list_state = gr.State([])\n", " set_state = gr.State(set())\n", "\n", " def transform_list(x):\n", " return {n: n for n in x}, [x[:] for _ in range(len(x))], set(x)\n", " \n", " list_state.change(\n", " transform_list,\n", " inputs=list_state,\n", " outputs=[dict_state, nested_list_state, set_state],\n", " )\n", "\n", " all_textbox = gr.Textbox(label=\"Output\")\n", " click_count = gr.Number(label=\"Clicks\")\n", " change_count = gr.Number(label=\"Changes\")\n", " gr.on(\n", " inputs=[change_count, dict_state, nested_list_state, set_state],\n", " triggers=[dict_state.change, nested_list_state.change, set_state.change],\n", " fn=lambda x, *args: (x+1, \"\\n\".join(str(arg) for arg in args)),\n", " outputs=[change_count, all_textbox],\n", " )\n", "\n", " count_to_3_btn = gr.Button(\"Count to 3\")\n", " count_to_3_btn.click(lambda: [1, 2, 3], outputs=list_state)\n", " zero_all_btn = gr.Button(\"Zero All\")\n", " zero_all_btn.click(\n", " lambda x: [0] * len(x), inputs=list_state, outputs=list_state\n", " )\n", "\n", " gr.on([count_to_3_btn.click, zero_all_btn.click], lambda x: x + 1, click_count, click_count)\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}

View File

@ -40,6 +40,7 @@ with gr.Blocks() as demo:
)
all_textbox = gr.Textbox(label="Output")
click_count = gr.Number(label="Clicks")
change_count = gr.Number(label="Changes")
gr.on(
inputs=[change_count, dict_state, nested_list_state, set_state],
@ -55,5 +56,7 @@ with gr.Blocks() as demo:
lambda x: [0] * len(x), inputs=list_state, outputs=list_state
)
gr.on([count_to_3_btn.click, zero_all_btn.click], lambda x: x + 1, click_count, click_count)
if __name__ == "__main__":
demo.launch()

View File

@ -24,16 +24,27 @@ test("test datastructure-based state changes", async ({ page }) => {
`{1: 1, 2: 2, 3: 3}\n[[1, 2, 3], [1, 2, 3], [1, 2, 3]]\n{1, 2, 3}`
);
await expect(page.getByLabel("Changes")).toHaveValue("1");
await expect(page.getByLabel("Clicks")).toHaveValue("1");
await page.getByRole("button", { name: "Count to" }).click();
await expect(page.getByLabel("Changes")).toHaveValue("1");
await expect(page.getByLabel("Clicks")).toHaveValue("2");
await page.getByRole("button", { name: "Count to" }).click();
await page.getByRole("button", { name: "Count to" }).click();
await expect(page.getByLabel("Changes")).toHaveValue("1");
await expect(page.getByLabel("Clicks")).toHaveValue("3");
await expect(page.getByLabel("Output")).toHaveValue(
`{1: 1, 2: 2, 3: 3}\n[[1, 2, 3], [1, 2, 3], [1, 2, 3]]\n{1, 2, 3}`
);
await expect(page.getByLabel("Changes")).toHaveValue("1");
await page.getByRole("button", { name: "Zero All" }).click();
await expect(page.getByLabel("Output")).toHaveValue(
`{0: 0}\n[[0, 0, 0], [0, 0, 0], [0, 0, 0]]\n{0}`
);
await expect(page.getByLabel("Changes")).toHaveValue("2");
await expect(page.getByLabel("Clicks")).toHaveValue("4");
await page.getByRole("button", { name: "Zero All" }).click();
await expect(page.getByLabel("Changes")).toHaveValue("2");
await expect(page.getByLabel("Clicks")).toHaveValue("5");
});