Fix incorrect behavior of gr.load() with gr.Examples (#5690)

* testing

* fixes

* chat fix

* lint

* add changeset

* fix

* simplify

* simplify

* spacing

* remove print

* docstring

* dataset

* lint

* add changeset

* fix test

* add test

* added test

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
Abubakar Abid 2023-09-26 12:20:34 -07:00 committed by GitHub
parent e51fcd5d54
commit 6b8c8afd98
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 60 additions and 16 deletions

View File

@ -0,0 +1,5 @@
---
"gradio": patch
---
fix:Fix incorrect behavior of `gr.load()` with `gr.Examples`

View File

@ -730,6 +730,8 @@ class Blocks(BlockContext):
cls = component_or_layout_class(block_config["type"])
block_config["props"].pop("type", None)
block_config["props"].pop("name", None)
block_config["props"].pop("selectable", None)
# If a Gradio app B is loaded into a Gradio app A, and B itself loads a
# Gradio app C, then the root_urls of the components in A need to be the
# URL of C, not B. The else clause below handles this case.
@ -737,6 +739,19 @@ class Blocks(BlockContext):
block_config["props"]["root_url"] = f"{root_url}/"
else:
root_urls.add(block_config["props"]["root_url"])
# We treat dataset components as a special case because they reference other components
# in the config. Instead of using the component string names, we use the component ids.
if (
block_config["type"] == "dataset"
and "component_ids" in block_config["props"]
):
block_config["props"].pop("components", None)
block_config["props"]["components"] = [
original_mapping[c] for c in block_config["props"]["component_ids"]
]
block_config["props"].pop("component_ids", None)
# Any component has already processed its initial value, so we skip that step here
block = cls(**block_config["props"], _skip_init_processing=True)
return block

View File

@ -119,7 +119,7 @@ class ChatInterface(Blocks):
if not isinstance(additional_inputs, list):
additional_inputs = [additional_inputs]
self.additional_inputs = [
get_component_instance(i, render=False) for i in additional_inputs # type: ignore
get_component_instance(i) for i in additional_inputs # type: ignore
]
else:
self.additional_inputs = []

View File

@ -352,28 +352,29 @@ class FormComponent:
return Form
def component(cls_name: str) -> Component:
obj = utils.component_or_layout_class(cls_name)()
def component(cls_name: str, render: bool) -> Component:
obj = utils.component_or_layout_class(cls_name)(render=render)
if isinstance(obj, BlockContext):
raise ValueError(f"Invalid component: {obj.__class__}")
return obj
def get_component_instance(
comp: str | dict | Component, render: bool | None = None
comp: str | dict | Component, render: bool = False, unrender: bool = False
) -> Component:
"""
Returns a component instance from a string, dict, or Component object.
Parameters:
comp: the component to instantiate. If a string, must be the name of a component, e.g. "dropdown". If a dict, must have a "name" key, e.g. {"name": "dropdown", "choices": ["a", "b"]}. If a Component object, will be returned as is.
render: whether to render the component. If True, renders the component (if not already rendered). If False, *unrenders* the component (if already rendered) -- this is useful when constructing an Interface or ChatInterface inside of a Blocks. If None, does not render or unrender the component.
render: whether to render the component. If True, renders the component (if not already rendered). If False, does not do anything.
unrender: whether to unrender the component. If True, unrenders the the component (if already rendered) -- this is useful when constructing an Interface or ChatInterface inside of a Blocks. If False, does not do anything.
"""
if isinstance(comp, str):
component_obj = component(comp)
component_obj = component(comp, render=render)
elif isinstance(comp, dict):
name = comp.pop("name")
component_cls = utils.component_or_layout_class(name)
component_obj = component_cls(**comp)
component_obj = component_cls(**comp, render=render)
if isinstance(component_obj, BlockContext):
raise ValueError(f"Invalid component: {name}")
elif isinstance(comp, Component):
@ -384,6 +385,6 @@ def get_component_instance(
)
if render and not component_obj.is_rendered:
component_obj.render()
elif render is False and component_obj.is_rendered:
elif unrender and component_obj.is_rendered:
component_obj.unrender()
return component_obj

View File

@ -116,6 +116,7 @@ class Dataset(Clickable, Selectable, Component, StringSerializable):
config["components"] = [
component.get_block_name() for component in self._components
]
config["component_ids"] = [component._id for component in self._components]
return config
def preprocess(self, x: Any) -> Any:

View File

@ -250,10 +250,10 @@ class Interface(Blocks):
self.cache_examples = False
self.input_components = [
get_component_instance(i, render=False) for i in inputs # type: ignore
get_component_instance(i, unrender=True) for i in inputs # type: ignore
]
self.output_components = [
get_component_instance(o, render=False) for o in outputs # type: ignore
get_component_instance(o, unrender=True) for o in outputs # type: ignore
]
for component in self.input_components + self.output_components:

View File

@ -34,12 +34,30 @@ from gradio.deprecation import (
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
class TestComponent:
def test_component_functions(self):
"""
component
"""
assert isinstance(gr.components.component("textarea"), gr.templates.TextArea)
class TestGettingComponents:
def test_component_function(self):
assert isinstance(
gr.components.component("textarea", render=False), gr.templates.TextArea
)
@pytest.mark.parametrize(
"component, render, unrender, should_be_rendered",
[
(gr.Textbox(render=True), False, True, False),
(gr.Textbox(render=False), False, False, False),
(gr.Textbox(render=False), True, False, True),
("textbox", False, False, False),
("textbox", True, False, True),
],
)
def test_get_component_instance_rendering(
self, component, render, unrender, should_be_rendered
):
with gr.Blocks():
textbox = gr.components.get_component_instance(
component, render=render, unrender=unrender
)
assert textbox.is_rendered == should_be_rendered
def test_raise_warnings():
@ -1387,6 +1405,10 @@ class TestDataset:
assert dataset.preprocess(1) == 1
radio = gr.Radio(choices=[("name 1", "value 1"), ("name 2", "value 2")])
dataset = gr.Dataset(samples=[["value 1"], ["value 2"]], components=[radio])
assert dataset.samples == [["name 1"], ["name 2"]]
def test_postprocessing(self):
test_file_dir = Path(Path(__file__).parent, "test_files")
bus = Path(test_file_dir, "bus.png")