From 6b8c8afd981fea984da568e9a0bd8bfc2a9c06c4 Mon Sep 17 00:00:00 2001 From: Abubakar Abid Date: Tue, 26 Sep 2023 12:20:34 -0700 Subject: [PATCH] Fix incorrect behavior of `gr.load()` with `gr.Examples` (#5690) * testing * fixes * chat fix * lint * add changeset * fix * simplify * simplify * spacing * remove print * docstring * dataset * lint * add changeset * fix test * add test * added test --------- Co-authored-by: gradio-pr-bot --- .changeset/soft-trees-travel.md | 5 +++++ gradio/blocks.py | 15 +++++++++++++++ gradio/chat_interface.py | 2 +- gradio/components/base.py | 15 ++++++++------- gradio/components/dataset.py | 1 + gradio/interface.py | 4 ++-- test/test_components.py | 34 +++++++++++++++++++++++++++------ 7 files changed, 60 insertions(+), 16 deletions(-) create mode 100644 .changeset/soft-trees-travel.md diff --git a/.changeset/soft-trees-travel.md b/.changeset/soft-trees-travel.md new file mode 100644 index 0000000000..af7402e77a --- /dev/null +++ b/.changeset/soft-trees-travel.md @@ -0,0 +1,5 @@ +--- +"gradio": patch +--- + +fix:Fix incorrect behavior of `gr.load()` with `gr.Examples` diff --git a/gradio/blocks.py b/gradio/blocks.py index 14f78bb866..b6dc9828b6 100644 --- a/gradio/blocks.py +++ b/gradio/blocks.py @@ -730,6 +730,8 @@ class Blocks(BlockContext): cls = component_or_layout_class(block_config["type"]) block_config["props"].pop("type", None) block_config["props"].pop("name", None) + block_config["props"].pop("selectable", None) + # If a Gradio app B is loaded into a Gradio app A, and B itself loads a # Gradio app C, then the root_urls of the components in A need to be the # URL of C, not B. The else clause below handles this case. @@ -737,6 +739,19 @@ class Blocks(BlockContext): block_config["props"]["root_url"] = f"{root_url}/" else: root_urls.add(block_config["props"]["root_url"]) + + # We treat dataset components as a special case because they reference other components + # in the config. Instead of using the component string names, we use the component ids. + if ( + block_config["type"] == "dataset" + and "component_ids" in block_config["props"] + ): + block_config["props"].pop("components", None) + block_config["props"]["components"] = [ + original_mapping[c] for c in block_config["props"]["component_ids"] + ] + block_config["props"].pop("component_ids", None) + # Any component has already processed its initial value, so we skip that step here block = cls(**block_config["props"], _skip_init_processing=True) return block diff --git a/gradio/chat_interface.py b/gradio/chat_interface.py index b5e24d891f..a27d92c92d 100644 --- a/gradio/chat_interface.py +++ b/gradio/chat_interface.py @@ -119,7 +119,7 @@ class ChatInterface(Blocks): if not isinstance(additional_inputs, list): additional_inputs = [additional_inputs] self.additional_inputs = [ - get_component_instance(i, render=False) for i in additional_inputs # type: ignore + get_component_instance(i) for i in additional_inputs # type: ignore ] else: self.additional_inputs = [] diff --git a/gradio/components/base.py b/gradio/components/base.py index 9acab862e3..8bf266ca96 100644 --- a/gradio/components/base.py +++ b/gradio/components/base.py @@ -352,28 +352,29 @@ class FormComponent: return Form -def component(cls_name: str) -> Component: - obj = utils.component_or_layout_class(cls_name)() +def component(cls_name: str, render: bool) -> Component: + obj = utils.component_or_layout_class(cls_name)(render=render) if isinstance(obj, BlockContext): raise ValueError(f"Invalid component: {obj.__class__}") return obj def get_component_instance( - comp: str | dict | Component, render: bool | None = None + comp: str | dict | Component, render: bool = False, unrender: bool = False ) -> Component: """ Returns a component instance from a string, dict, or Component object. Parameters: comp: the component to instantiate. If a string, must be the name of a component, e.g. "dropdown". If a dict, must have a "name" key, e.g. {"name": "dropdown", "choices": ["a", "b"]}. If a Component object, will be returned as is. - render: whether to render the component. If True, renders the component (if not already rendered). If False, *unrenders* the component (if already rendered) -- this is useful when constructing an Interface or ChatInterface inside of a Blocks. If None, does not render or unrender the component. + render: whether to render the component. If True, renders the component (if not already rendered). If False, does not do anything. + unrender: whether to unrender the component. If True, unrenders the the component (if already rendered) -- this is useful when constructing an Interface or ChatInterface inside of a Blocks. If False, does not do anything. """ if isinstance(comp, str): - component_obj = component(comp) + component_obj = component(comp, render=render) elif isinstance(comp, dict): name = comp.pop("name") component_cls = utils.component_or_layout_class(name) - component_obj = component_cls(**comp) + component_obj = component_cls(**comp, render=render) if isinstance(component_obj, BlockContext): raise ValueError(f"Invalid component: {name}") elif isinstance(comp, Component): @@ -384,6 +385,6 @@ def get_component_instance( ) if render and not component_obj.is_rendered: component_obj.render() - elif render is False and component_obj.is_rendered: + elif unrender and component_obj.is_rendered: component_obj.unrender() return component_obj diff --git a/gradio/components/dataset.py b/gradio/components/dataset.py index 8cb422778a..62c2b8b09c 100644 --- a/gradio/components/dataset.py +++ b/gradio/components/dataset.py @@ -116,6 +116,7 @@ class Dataset(Clickable, Selectable, Component, StringSerializable): config["components"] = [ component.get_block_name() for component in self._components ] + config["component_ids"] = [component._id for component in self._components] return config def preprocess(self, x: Any) -> Any: diff --git a/gradio/interface.py b/gradio/interface.py index c2acc2a3c2..d2db6b78c6 100644 --- a/gradio/interface.py +++ b/gradio/interface.py @@ -250,10 +250,10 @@ class Interface(Blocks): self.cache_examples = False self.input_components = [ - get_component_instance(i, render=False) for i in inputs # type: ignore + get_component_instance(i, unrender=True) for i in inputs # type: ignore ] self.output_components = [ - get_component_instance(o, render=False) for o in outputs # type: ignore + get_component_instance(o, unrender=True) for o in outputs # type: ignore ] for component in self.input_components + self.output_components: diff --git a/test/test_components.py b/test/test_components.py index c958706f95..73f7203a97 100644 --- a/test/test_components.py +++ b/test/test_components.py @@ -34,12 +34,30 @@ from gradio.deprecation import ( os.environ["GRADIO_ANALYTICS_ENABLED"] = "False" -class TestComponent: - def test_component_functions(self): - """ - component - """ - assert isinstance(gr.components.component("textarea"), gr.templates.TextArea) +class TestGettingComponents: + def test_component_function(self): + assert isinstance( + gr.components.component("textarea", render=False), gr.templates.TextArea + ) + + @pytest.mark.parametrize( + "component, render, unrender, should_be_rendered", + [ + (gr.Textbox(render=True), False, True, False), + (gr.Textbox(render=False), False, False, False), + (gr.Textbox(render=False), True, False, True), + ("textbox", False, False, False), + ("textbox", True, False, True), + ], + ) + def test_get_component_instance_rendering( + self, component, render, unrender, should_be_rendered + ): + with gr.Blocks(): + textbox = gr.components.get_component_instance( + component, render=render, unrender=unrender + ) + assert textbox.is_rendered == should_be_rendered def test_raise_warnings(): @@ -1387,6 +1405,10 @@ class TestDataset: assert dataset.preprocess(1) == 1 + radio = gr.Radio(choices=[("name 1", "value 1"), ("name 2", "value 2")]) + dataset = gr.Dataset(samples=[["value 1"], ["value 2"]], components=[radio]) + assert dataset.samples == [["name 1"], ["name 2"]] + def test_postprocessing(self): test_file_dir = Path(Path(__file__).parent, "test_files") bus = Path(test_file_dir, "bus.png")