From aca4892ea591d089e3121ffd144ad6767e0bad71 Mon Sep 17 00:00:00 2001 From: Abubakar Abid Date: Fri, 15 Mar 2024 13:59:46 -0700 Subject: [PATCH] More fixes for `gr.load()` as well as a tweaking the `__str__` and `__repr__` methods of components (#7712) * more fixes * add changeset * format * add changeset * add changeset * address review --------- Co-authored-by: gradio-pr-bot --- .changeset/famous-years-itch.md | 5 +++++ gradio/blocks.py | 31 +++++++++++++++++++++---------- gradio/components/base.py | 6 ------ test/test_blocks.py | 8 ++++---- test/test_components.py | 6 +++--- 5 files changed, 33 insertions(+), 23 deletions(-) create mode 100644 .changeset/famous-years-itch.md diff --git a/.changeset/famous-years-itch.md b/.changeset/famous-years-itch.md new file mode 100644 index 0000000000..978760f4eb --- /dev/null +++ b/.changeset/famous-years-itch.md @@ -0,0 +1,5 @@ +--- +"gradio": patch +--- + +feat:More fixes for `gr.load()` as well as a tweaking the `__str__` and `__repr__` methods of components diff --git a/gradio/blocks.py b/gradio/blocks.py index 06643e3f20..cc23078f3d 100644 --- a/gradio/blocks.py +++ b/gradio/blocks.py @@ -185,14 +185,25 @@ class Block: def get_block_name(self) -> str: """ - Gets block's class name. - - If it is template component it gets the parent's class name. - - @return: class name + Gets block's class name. If it is template component it gets the parent's class name. + This is used to identify the Svelte file to use in the frontend. Override this method + if a component should use a different Svelte file than the default naming convention. """ return ( - self.__class__.__base__.__name__.lower() + self.__class__.__base__.__name__.lower() # type: ignore + if hasattr(self, "is_template") + else self.__class__.__name__.lower() + ) + + def get_block_class(self) -> str: + """ + Gets block's class name. If it is template component it gets the parent's class name. + Very similar to the get_block_name method, but this method is used to reconstruct a + Gradio app that is loaded from a Space using gr.load(). This should generally + NOT be overridden. + """ + return ( + self.__class__.__base__.__name__.lower() # type: ignore if hasattr(self, "is_template") else self.__class__.__name__.lower() ) @@ -212,7 +223,7 @@ class Block: if to_add: config = {**to_add, **config} config.pop("render", None) - config = {**config, "proxy_url": self.proxy_url, "name": self.get_block_name()} + config = {**config, "proxy_url": self.proxy_url, "name": self.get_block_class()} if (_selectable := getattr(self, "_selectable", None)) is not None: config["_selectable"] = _selectable return config @@ -701,7 +712,7 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta): break else: raise ValueError(f"Cannot find block with id {id}") - cls = component_or_layout_class(block_config["type"]) + cls = component_or_layout_class(block_config["props"]["name"]) # If a Gradio app B is loaded into a Gradio app A, and B itself loads a # Gradio app C, then the proxy_urls of the components in A need to be the @@ -2481,7 +2492,7 @@ Received outputs: else: skip_endpoint = True # if component not found, skip endpoint break - type = component["type"] + type = component["props"]["name"] if self.blocks[component["id"]].skip_api: continue label = component["props"].get("label", f"parameter_{i}") @@ -2512,7 +2523,7 @@ Received outputs: else: skip_endpoint = True # if component not found, skip endpoint break - type = component["type"] + type = component["props"]["name"] if self.blocks[component["id"]].skip_api: continue label = component["props"].get("label", f"value_{o}") diff --git a/gradio/components/base.py b/gradio/components/base.py index f57fea5ad3..e23dd3a0f5 100644 --- a/gradio/components/base.py +++ b/gradio/components/base.py @@ -236,12 +236,6 @@ class Component(ComponentBase, Block): load_fn = None return load_fn, initial_value - def __str__(self): - return self.__repr__() - - def __repr__(self): - return f"{self.get_block_name()}" - def attach_load_event(self, callable: Callable, every: float | None): """Add a load event that runs `callable`, optionally every `every` seconds.""" self.load_event_to_attach = (callable, every) diff --git a/test/test_blocks.py b/test/test_blocks.py index c8ac81043f..efcf2a69b2 100644 --- a/test/test_blocks.py +++ b/test/test_blocks.py @@ -678,7 +678,7 @@ class TestBlocksPostprocessing: button.click(lambda x: x, textbox1, [textbox1, textbox2]) with pytest.raises( ValueError, - match=r'An event handler didn\'t receive enough output values \(needed: 2, received: 1\)\.\nWanted outputs:\n \[textbox, textbox\]\nReceived outputs:\n \["test"\]', + match=r"^An event handler didn\'t receive enough output values \(needed: 2, received: 1\)\.\nWanted outputs:", ): demo.postprocess_data(fn_index=0, predictions=["test"], state=None) @@ -693,7 +693,7 @@ class TestBlocksPostprocessing: button.click(infer, textbox1, [textbox1, textbox2]) with pytest.raises( ValueError, - match=r'An event handler \(infer\) didn\'t receive enough output values \(needed: 2, received: 1\)\.\nWanted outputs:\n \[textbox, textbox\]\nReceived outputs:\n \["test"\]', + match=r"^An event handler \(infer\) didn\'t receive enough output values \(needed: 2, received: 1\)\.\nWanted outputs:", ): demo.postprocess_data(fn_index=0, predictions=["test"], state=None) @@ -705,7 +705,7 @@ class TestBlocksPostprocessing: btn.click(lambda a: a, num1, [num1, num2]) with pytest.raises( ValueError, - match=r"An event handler didn\'t receive enough output values \(needed: 2, received: 1\)\.\nWanted outputs:\n \[number, number\]\nReceived outputs:\n \[1\]", + match=r"^An event handler didn\'t receive enough output values \(needed: 2, received: 1\)\.\nWanted outputs:", ): demo.postprocess_data(fn_index=0, predictions=1, state=None) @@ -721,7 +721,7 @@ class TestBlocksPostprocessing: btn.click(infer, num1, [num1, num2, num3]) with pytest.raises( ValueError, - match=r"An event handler \(infer\) didn\'t receive enough output values \(needed: 3, received: 2\)\.\nWanted outputs:\n \[number, number, number\]\nReceived outputs:\n \[1, 2\]", + match=r"^An event handler \(infer\) didn\'t receive enough output values \(needed: 3, received: 2\)\.\nWanted outputs:", ): demo.postprocess_data(fn_index=0, predictions=(1, 2), state=None) diff --git a/test/test_components.py b/test/test_components.py index bd896976f9..55f027d794 100644 --- a/test/test_components.py +++ b/test/test_components.py @@ -2470,7 +2470,7 @@ class TestScatterPlot: "elem_classes": [], "interactive": None, "label": None, - "name": "plot", + "name": "scatterplot", "bokeh_version": "3.0.3", "show_actions_button": False, "proxy_url": None, @@ -2630,7 +2630,7 @@ class TestLinePlot: "elem_classes": [], "interactive": None, "label": None, - "name": "plot", + "name": "lineplot", "bokeh_version": "3.0.3", "show_actions_button": False, "proxy_url": None, @@ -2735,7 +2735,7 @@ class TestBarPlot: "elem_classes": [], "interactive": None, "label": None, - "name": "plot", + "name": "barplot", "bokeh_version": "3.0.3", "show_actions_button": False, "proxy_url": None,