Custom component in rerender (#10170)

* changes

* add changeset

* changes

* Support gr.Examples in gr.render (#10173)

* changes

* changes

* changes

---------

Co-authored-by: Ali Abid <aliabid94@gmail.com>

* add changeset

---------

Co-authored-by: Ali Abid <aliabid94@gmail.com>
Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
aliabid94 2024-12-12 17:40:28 -08:00 committed by GitHub
parent e525680316
commit 5e6e234cba
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 51 additions and 19 deletions

View File

@ -0,0 +1,6 @@
---
"@gradio/core": patch
"gradio": patch
---
fix:Custom component in rerender

View File

@ -1 +1 @@
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: render_tests"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["from datetime import datetime\n", "\n", "import gradio as gr\n", "\n", "def update_log():\n", " return datetime.now().timestamp()\n", "\n", "def get_target(evt: gr.EventData):\n", " return evt.target\n", "\n", "def get_select_index(evt: gr.SelectData):\n", " return evt.index\n", "\n", "with gr.Blocks() as demo:\n", " gr.Textbox(value=update_log, every=0.2, label=\"Time\")\n", " \n", " slider = gr.Slider(1, 10, step=1)\n", " @gr.render(inputs=[slider])\n", " def show_log(s):\n", " with gr.Row():\n", " for i in range(s):\n", " gr.Textbox(value=update_log, every=0.2, label=f\"Render {i + 1}\")\n", "\n", " with gr.Row():\n", " selected_btn = gr.Textbox(label=\"Selected Button\")\n", " selected_chat = gr.Textbox(label=\"Selected Chat\")\n", " @gr.render(inputs=[slider])\n", " def show_buttons(s):\n", " with gr.Row():\n", " with gr.Column():\n", " for i in range(s):\n", " btn = gr.Button(f\"Button {i + 1}\")\n", " btn.click(get_target, None, selected_btn)\n", " chatbot = gr.Chatbot([[\"Hello\", \"Hi\"], [\"How are you?\", \"I'm good.\"]])\n", " chatbot.select(get_select_index, None, selected_chat)\n", "\n", "if __name__ == '__main__':\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: render_tests"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["from datetime import datetime\n", "\n", "import gradio as gr\n", "\n", "def update_log():\n", " return datetime.now().timestamp()\n", "\n", "def get_target(evt: gr.EventData):\n", " return evt.target\n", "\n", "def get_select_index(evt: gr.SelectData):\n", " return evt.index\n", "\n", "with gr.Blocks() as demo:\n", " gr.Textbox(value=update_log, every=0.2, label=\"Time\")\n", " \n", " slider = gr.Slider(1, 10, step=1)\n", " @gr.render(inputs=[slider])\n", " def show_log(s):\n", " with gr.Row():\n", " for i in range(s):\n", " gr.Textbox(value=update_log, every=0.2, label=f\"Render {i + 1}\")\n", "\n", " with gr.Row():\n", " selected_btn = gr.Textbox(label=\"Selected Button\")\n", " selected_chat = gr.Textbox(label=\"Selected Chat\")\n", " @gr.render(inputs=[slider])\n", " def show_buttons(s):\n", " with gr.Row():\n", " with gr.Column():\n", " for i in range(s):\n", " btn = gr.Button(f\"Button {i + 1}\")\n", " btn.click(get_target, None, selected_btn)\n", " chatbot = gr.Chatbot([[\"Hello\", \"Hi\"], [\"How are you?\", \"I'm good.\"]])\n", " chatbot.select(get_select_index, None, selected_chat)\n", "\n", " @gr.render()\n", " def examples_in_interface():\n", " gr.Interface(lambda x:x, gr.Textbox(label=\"input\"), gr.Textbox(), examples=[[\"test\"]])\n", "\n", " @gr.render()\n", " def examples_in_blocks():\n", " a = gr.Textbox(label=\"little textbox\")\n", " gr.Examples([[\"abc\"], [\"def\"]], [a])\n", "\n", "\n", "if __name__ == '__main__':\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}

View File

@ -34,5 +34,15 @@ with gr.Blocks() as demo:
chatbot = gr.Chatbot([["Hello", "Hi"], ["How are you?", "I'm good."]])
chatbot.select(get_select_index, None, selected_chat)
@gr.render()
def examples_in_interface():
gr.Interface(lambda x:x, gr.Textbox(label="input"), gr.Textbox(), examples=[["test"]])
@gr.render()
def examples_in_blocks():
a = gr.Textbox(label="little textbox")
gr.Examples([["abc"], ["def"]], [a])
if __name__ == '__main__':
demo.launch()

View File

@ -20,7 +20,7 @@ from gradio_client import utils as client_utils
from gradio_client.documentation import document
from gradio import components, oauth, processing_utils, routes, utils, wasm_utils
from gradio.context import Context, LocalContext
from gradio.context import Context, LocalContext, get_blocks_context
from gradio.data_classes import GradioModel, GradioRootModel
from gradio.events import Dependency, EventData
from gradio.exceptions import Error
@ -338,10 +338,13 @@ class Examples:
def create(self) -> None:
"""Creates the Dataset component to hold the examples"""
self.root_block = Context.root_block
if self.root_block:
self.root_block.extra_startup_events.append(self._start_caching)
blocks_config = get_blocks_context()
self.root_block = Context.root_block or (
blocks_config.root_block if blocks_config else None
)
if blocks_config:
if self.root_block:
self.root_block.extra_startup_events.append(self._start_caching)
if self.cache_examples:

View File

@ -613,25 +613,23 @@ class App(FastAPI):
raise HTTPException(
status_code=404, detail="Environment not supported."
)
config = app.get_blocks().config
components = config["components"]
components = utils.get_all_components()
location = next(
(item for item in components if item["component_class_id"] == id), None
(item for item in components if item.get_component_class_id() == id),
None,
)
if location is None:
raise HTTPException(status_code=404, detail="Component not found.")
component_instance = app.get_blocks().get_component(location["id"])
module_name = component_instance.__class__.__module__
module_name = location.__module__
module_path = sys.modules[module_name].__file__
if module_path is None or component_instance is None:
if module_path is None:
raise HTTPException(status_code=404, detail="Component not found.")
try:
requested_path = utils.safe_join(
component_instance.__class__.TEMPLATE_DIR,
location.TEMPLATE_DIR,
UserProvidedPath(f"{type}/{file_name}"),
)
except InvalidPathError:

View File

@ -412,7 +412,7 @@
rerender_layout({
components: _components,
layout: render_layout,
root: root,
root: root + api_prefix,
dependencies: dependencies,
render_id: render_id
});

View File

@ -154,7 +154,7 @@ export function create_components(initial_layout: ComponentMeta | undefined): {
{} as { [id: number]: ComponentMeta }
);
await walk_layout(layout, root);
await walk_layout(layout, root, _components);
layout_store.set(_rootNode);
set_event_specific_args(dependencies);
@ -230,7 +230,12 @@ export function create_components(initial_layout: ComponentMeta | undefined): {
] = instance_map[layout.id];
}
walk_layout(layout, root, current_element.parent).then(() => {
walk_layout(
layout,
root,
_components.concat(components),
current_element.parent
).then(() => {
layout_store.set(_rootNode);
});
@ -240,6 +245,7 @@ export function create_components(initial_layout: ComponentMeta | undefined): {
async function walk_layout(
node: LayoutNode,
root: string,
components: ComponentMeta[],
parent?: ComponentMeta
): Promise<ComponentMeta> {
const instance = instance_map[node.id];
@ -254,7 +260,7 @@ export function create_components(initial_layout: ComponentMeta | undefined): {
instance.type,
instance.component_class_id,
root,
_components,
components,
instance.props.components
).example_components;
}
@ -288,7 +294,7 @@ export function create_components(initial_layout: ComponentMeta | undefined): {
if (node.children) {
instance.children = await Promise.all(
node.children.map((v) => walk_layout(v, root, instance))
node.children.map((v) => walk_layout(v, root, components, instance))
);
}

View File

@ -34,3 +34,12 @@ test("Test event/selection data works in render", async ({ page }) => {
await page.getByText("Hi").click();
await expect(selected_chat).toHaveValue("[0, 1]");
});
test("Test examples work in render", async ({ page }) => {
await page.getByRole("button", { name: "test" }).click();
await expect(page.getByLabel("input", { exact: true })).toHaveValue("test");
await page.getByRole("button", { name: "def", exact: true }).click();
await expect(page.getByLabel("little textbox", { exact: true })).toHaveValue(
"def"
);
});