More fixes for gr.load() as well as a tweaking the __str__ and __repr__ methods of components (#7712)

* more fixes

* add changeset

* format

* add changeset

* add changeset

* address review

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
Abubakar Abid 2024-03-15 13:59:46 -07:00 committed by GitHub
parent 6390d0bf6c
commit aca4892ea5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 33 additions and 23 deletions

View File

@ -0,0 +1,5 @@
---
"gradio": patch
---
feat:More fixes for `gr.load()` as well as a tweaking the `__str__` and `__repr__` methods of components

View File

@ -185,14 +185,25 @@ class Block:
def get_block_name(self) -> str:
"""
Gets block's class name.
If it is template component it gets the parent's class name.
@return: class name
Gets block's class name. If it is template component it gets the parent's class name.
This is used to identify the Svelte file to use in the frontend. Override this method
if a component should use a different Svelte file than the default naming convention.
"""
return (
self.__class__.__base__.__name__.lower()
self.__class__.__base__.__name__.lower() # type: ignore
if hasattr(self, "is_template")
else self.__class__.__name__.lower()
)
def get_block_class(self) -> str:
"""
Gets block's class name. If it is template component it gets the parent's class name.
Very similar to the get_block_name method, but this method is used to reconstruct a
Gradio app that is loaded from a Space using gr.load(). This should generally
NOT be overridden.
"""
return (
self.__class__.__base__.__name__.lower() # type: ignore
if hasattr(self, "is_template")
else self.__class__.__name__.lower()
)
@ -212,7 +223,7 @@ class Block:
if to_add:
config = {**to_add, **config}
config.pop("render", None)
config = {**config, "proxy_url": self.proxy_url, "name": self.get_block_name()}
config = {**config, "proxy_url": self.proxy_url, "name": self.get_block_class()}
if (_selectable := getattr(self, "_selectable", None)) is not None:
config["_selectable"] = _selectable
return config
@ -701,7 +712,7 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
break
else:
raise ValueError(f"Cannot find block with id {id}")
cls = component_or_layout_class(block_config["type"])
cls = component_or_layout_class(block_config["props"]["name"])
# If a Gradio app B is loaded into a Gradio app A, and B itself loads a
# Gradio app C, then the proxy_urls of the components in A need to be the
@ -2481,7 +2492,7 @@ Received outputs:
else:
skip_endpoint = True # if component not found, skip endpoint
break
type = component["type"]
type = component["props"]["name"]
if self.blocks[component["id"]].skip_api:
continue
label = component["props"].get("label", f"parameter_{i}")
@ -2512,7 +2523,7 @@ Received outputs:
else:
skip_endpoint = True # if component not found, skip endpoint
break
type = component["type"]
type = component["props"]["name"]
if self.blocks[component["id"]].skip_api:
continue
label = component["props"].get("label", f"value_{o}")

View File

@ -236,12 +236,6 @@ class Component(ComponentBase, Block):
load_fn = None
return load_fn, initial_value
def __str__(self):
return self.__repr__()
def __repr__(self):
return f"{self.get_block_name()}"
def attach_load_event(self, callable: Callable, every: float | None):
"""Add a load event that runs `callable`, optionally every `every` seconds."""
self.load_event_to_attach = (callable, every)

View File

@ -678,7 +678,7 @@ class TestBlocksPostprocessing:
button.click(lambda x: x, textbox1, [textbox1, textbox2])
with pytest.raises(
ValueError,
match=r'An event handler didn\'t receive enough output values \(needed: 2, received: 1\)\.\nWanted outputs:\n \[textbox, textbox\]\nReceived outputs:\n \["test"\]',
match=r"^An event handler didn\'t receive enough output values \(needed: 2, received: 1\)\.\nWanted outputs:",
):
demo.postprocess_data(fn_index=0, predictions=["test"], state=None)
@ -693,7 +693,7 @@ class TestBlocksPostprocessing:
button.click(infer, textbox1, [textbox1, textbox2])
with pytest.raises(
ValueError,
match=r'An event handler \(infer\) didn\'t receive enough output values \(needed: 2, received: 1\)\.\nWanted outputs:\n \[textbox, textbox\]\nReceived outputs:\n \["test"\]',
match=r"^An event handler \(infer\) didn\'t receive enough output values \(needed: 2, received: 1\)\.\nWanted outputs:",
):
demo.postprocess_data(fn_index=0, predictions=["test"], state=None)
@ -705,7 +705,7 @@ class TestBlocksPostprocessing:
btn.click(lambda a: a, num1, [num1, num2])
with pytest.raises(
ValueError,
match=r"An event handler didn\'t receive enough output values \(needed: 2, received: 1\)\.\nWanted outputs:\n \[number, number\]\nReceived outputs:\n \[1\]",
match=r"^An event handler didn\'t receive enough output values \(needed: 2, received: 1\)\.\nWanted outputs:",
):
demo.postprocess_data(fn_index=0, predictions=1, state=None)
@ -721,7 +721,7 @@ class TestBlocksPostprocessing:
btn.click(infer, num1, [num1, num2, num3])
with pytest.raises(
ValueError,
match=r"An event handler \(infer\) didn\'t receive enough output values \(needed: 3, received: 2\)\.\nWanted outputs:\n \[number, number, number\]\nReceived outputs:\n \[1, 2\]",
match=r"^An event handler \(infer\) didn\'t receive enough output values \(needed: 3, received: 2\)\.\nWanted outputs:",
):
demo.postprocess_data(fn_index=0, predictions=(1, 2), state=None)

View File

@ -2470,7 +2470,7 @@ class TestScatterPlot:
"elem_classes": [],
"interactive": None,
"label": None,
"name": "plot",
"name": "scatterplot",
"bokeh_version": "3.0.3",
"show_actions_button": False,
"proxy_url": None,
@ -2630,7 +2630,7 @@ class TestLinePlot:
"elem_classes": [],
"interactive": None,
"label": None,
"name": "plot",
"name": "lineplot",
"bokeh_version": "3.0.3",
"show_actions_button": False,
"proxy_url": None,
@ -2735,7 +2735,7 @@ class TestBarPlot:
"elem_classes": [],
"interactive": None,
"label": None,
"name": "plot",
"name": "barplot",
"bokeh_version": "3.0.3",
"show_actions_button": False,
"proxy_url": None,