mirror of
https://github.com/gradio-app/gradio.git
synced 2025-03-13 11:57:29 +08:00
Store configs per session in the backend (#8030)
* changes * add changeset * changes * changes * changes * changes * changes * changes * changeas * add changeset * unrelated fix * Update gradio/blocks.py Co-authored-by: Abubakar Abid <abubakar@huggingface.co> --------- Co-authored-by: Ali Abid <aliabid94@gmail.com> Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com> Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
This commit is contained in:
parent
659d3c51ae
commit
91a7a31cd1
5
.changeset/twenty-corners-peel.md
Normal file
5
.changeset/twenty-corners-peel.md
Normal file
@ -0,0 +1,5 @@
|
||||
---
|
||||
"gradio": patch
|
||||
---
|
||||
|
||||
feat:Store configs per session in the backend
|
@ -184,21 +184,21 @@ def launched_analytics(blocks: gradio.Blocks, data: dict[str, Any]) -> None:
|
||||
x, BlockContext
|
||||
) else blocks_telemetry.append(str(x))
|
||||
|
||||
for x in blocks.dependencies:
|
||||
for x in blocks.fns:
|
||||
targets_telemetry = targets_telemetry + [
|
||||
# Sometimes the target can be the Blocks object itself, so we need to check if its in blocks.blocks
|
||||
str(blocks.blocks[y[0]])
|
||||
for y in x["targets"]
|
||||
for y in x.targets
|
||||
if y[0] in blocks.blocks
|
||||
]
|
||||
events_telemetry = events_telemetry + [
|
||||
y[1] for y in x["targets"] if y[0] in blocks.blocks
|
||||
y[1] for y in x.targets if y[0] in blocks.blocks
|
||||
]
|
||||
inputs_telemetry = inputs_telemetry + [
|
||||
str(blocks.blocks[y]) for y in x["inputs"] if y in blocks.blocks
|
||||
str(blocks.blocks[y]) for y in x.inputs if y in blocks.blocks
|
||||
]
|
||||
outputs_telemetry = outputs_telemetry + [
|
||||
str(blocks.blocks[y]) for y in x["outputs"] if y in blocks.blocks
|
||||
str(blocks.blocks[y]) for y in x.outputs if y in blocks.blocks
|
||||
]
|
||||
additional_data = {
|
||||
"version": get_package_version(),
|
||||
|
479
gradio/blocks.py
479
gradio/blocks.py
@ -49,7 +49,6 @@ from gradio.events import (
|
||||
from gradio.exceptions import (
|
||||
DuplicateBlockError,
|
||||
InvalidApiNameError,
|
||||
InvalidBlockError,
|
||||
InvalidComponentError,
|
||||
)
|
||||
from gradio.helpers import create_tracker, skip, special_args
|
||||
@ -458,11 +457,24 @@ class BlockFunction:
|
||||
preprocess: bool,
|
||||
postprocess: bool,
|
||||
inputs_as_dict: bool,
|
||||
targets: list[tuple[int | None, str]],
|
||||
batch: bool = False,
|
||||
max_batch_size: int = 4,
|
||||
concurrency_limit: int | None | Literal["default"] = "default",
|
||||
concurrency_id: str | None = None,
|
||||
tracks_progress: bool = False,
|
||||
api_name: str | Literal[False] = False,
|
||||
js: str | None = None,
|
||||
show_progress: Literal["full", "minimal", "hidden"] = "full",
|
||||
every: float | None = None,
|
||||
cancels: list[int] | None = None,
|
||||
collects_event_data: bool = False,
|
||||
trigger_after: int | None = None,
|
||||
trigger_only_on_success: bool = False,
|
||||
trigger_mode: Literal["always_last", "once", "multiple"] = "once",
|
||||
queue: bool | None = None,
|
||||
scroll_to_output: bool = False,
|
||||
show_api: bool = True,
|
||||
):
|
||||
self.fn = fn
|
||||
self.inputs = inputs
|
||||
@ -477,7 +489,27 @@ class BlockFunction:
|
||||
self.total_runtime = 0
|
||||
self.total_runs = 0
|
||||
self.inputs_as_dict = inputs_as_dict
|
||||
self.targets = targets
|
||||
self.name = getattr(fn, "__name__", "fn") if fn is not None else None
|
||||
self.api_name = api_name
|
||||
self.js = js
|
||||
self.show_progress = show_progress
|
||||
self.every = every
|
||||
self.cancels = cancels or []
|
||||
self.collects_event_data = collects_event_data
|
||||
self.trigger_after = trigger_after
|
||||
self.trigger_only_on_success = trigger_only_on_success
|
||||
self.trigger_mode = trigger_mode
|
||||
self.queue = False if fn is None else queue
|
||||
self.scroll_to_output = False if utils.get_space() else scroll_to_output
|
||||
self.show_api = show_api
|
||||
self.zero_gpu = hasattr(self.fn, "zerogpu")
|
||||
self.types_continuous = bool(self.every)
|
||||
self.types_generator = (
|
||||
inspect.isgeneratorfunction(self.fn)
|
||||
or inspect.isasyncgenfunction(self.fn)
|
||||
or bool(self.every)
|
||||
)
|
||||
self.spaces_auto_wrap()
|
||||
|
||||
def spaces_auto_wrap(self):
|
||||
@ -499,6 +531,33 @@ class BlockFunction:
|
||||
def __repr__(self):
|
||||
return str(self)
|
||||
|
||||
def get_config(self):
|
||||
return {
|
||||
"targets": self.targets,
|
||||
"inputs": [block._id for block in self.inputs],
|
||||
"outputs": [block._id for block in self.outputs],
|
||||
"backend_fn": self.fn is not None,
|
||||
"js": self.js,
|
||||
"queue": self.queue,
|
||||
"api_name": self.api_name,
|
||||
"scroll_to_output": self.scroll_to_output,
|
||||
"show_progress": self.show_progress,
|
||||
"every": self.every,
|
||||
"batch": self.batch,
|
||||
"max_batch_size": self.max_batch_size,
|
||||
"cancels": self.cancels,
|
||||
"types": {
|
||||
"continuous": self.types_continuous,
|
||||
"generator": self.types_generator,
|
||||
},
|
||||
"collects_event_data": self.collects_event_data,
|
||||
"trigger_after": self.trigger_after,
|
||||
"trigger_only_on_success": self.trigger_only_on_success,
|
||||
"trigger_mode": self.trigger_mode,
|
||||
"show_api": self.show_api,
|
||||
"zerogpu": self.zero_gpu,
|
||||
}
|
||||
|
||||
|
||||
def postprocess_update_dict(
|
||||
block: Component | BlockContext, update_dict: dict, postprocess: bool = True
|
||||
@ -552,6 +611,55 @@ def convert_component_dict_to_list(
|
||||
return predictions
|
||||
|
||||
|
||||
class BlocksConfig:
|
||||
def __init__(self, root_block: Blocks):
|
||||
self._id: int = 0
|
||||
self.root_block = root_block
|
||||
self.blocks: dict[int, Component | Block] = {}
|
||||
self.fns: list[BlockFunction] = []
|
||||
|
||||
def get_config(self):
|
||||
config = {}
|
||||
|
||||
def get_layout(block):
|
||||
if not isinstance(block, BlockContext):
|
||||
return {"id": block._id}
|
||||
children_layout = []
|
||||
for child in block.children:
|
||||
children_layout.append(get_layout(child))
|
||||
return {"id": block._id, "children": children_layout}
|
||||
|
||||
config["layout"] = get_layout(self.root_block)
|
||||
|
||||
config["components"] = []
|
||||
for _id, block in self.blocks.items():
|
||||
props = block.get_config() if hasattr(block, "get_config") else {}
|
||||
block_config = {
|
||||
"id": _id,
|
||||
"type": block.get_block_name(),
|
||||
"props": utils.delete_none(props),
|
||||
"skip_api": block.skip_api,
|
||||
"component_class_id": getattr(block, "component_class_id", None),
|
||||
}
|
||||
if not block.skip_api:
|
||||
block_config["api_info"] = block.api_info() # type: ignore
|
||||
# .example_inputs() has been renamed .example_payload() but
|
||||
# we use the old name for backwards compatibility with custom components
|
||||
# created on Gradio 4.20.0 or earlier
|
||||
block_config["example_inputs"] = block.example_inputs() # type: ignore
|
||||
config["components"].append(block_config)
|
||||
|
||||
config["dependencies"] = [fn.get_config() for fn in self.fns]
|
||||
|
||||
return config
|
||||
|
||||
def __copy__(self):
|
||||
new = BlocksConfig(self.root_block)
|
||||
new.blocks = copy.copy(self.blocks)
|
||||
new.fns = copy.copy(self.fns)
|
||||
return new
|
||||
|
||||
|
||||
@document("launch", "queue", "integrate", "load", "unload")
|
||||
class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
|
||||
"""
|
||||
@ -665,12 +773,11 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
|
||||
t.start()
|
||||
else:
|
||||
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "True"
|
||||
super().__init__(render=False, **kwargs)
|
||||
self.blocks: dict[int, Component | Block] = {}
|
||||
self.fns: list[BlockFunction] = []
|
||||
self.dependencies = []
|
||||
self.mode = mode
|
||||
|
||||
self.default_config = BlocksConfig(self)
|
||||
super().__init__(render=False, **kwargs)
|
||||
|
||||
self.mode = mode
|
||||
self.is_running = False
|
||||
self.local_url = None
|
||||
self.share_url = None
|
||||
@ -718,6 +825,22 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
|
||||
|
||||
self.queue()
|
||||
|
||||
@property
|
||||
def blocks(self) -> dict[int, Component | Block]:
|
||||
return self.default_config.blocks
|
||||
|
||||
@blocks.setter
|
||||
def blocks(self, value: dict[int, Component | Block]):
|
||||
self.default_config.blocks = value
|
||||
|
||||
@property
|
||||
def fns(self) -> list[BlockFunction]:
|
||||
return self.default_config.fns
|
||||
|
||||
@fns.setter
|
||||
def fns(self, value: list[BlockFunction]):
|
||||
self.default_config.fns = value
|
||||
|
||||
def get_component(self, id: int) -> Component | BlockContext:
|
||||
comp = self.blocks[id]
|
||||
if not isinstance(comp, (components.Component, BlockContext)):
|
||||
@ -850,7 +973,7 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
|
||||
original_mapping[o] for o in dependency["outputs"]
|
||||
]
|
||||
dependency.pop("status_tracker", None)
|
||||
dependency.pop("zerogpu")
|
||||
dependency.pop("zerogpu", None)
|
||||
dependency["preprocess"] = False
|
||||
dependency["postprocess"] = False
|
||||
if is_then_event:
|
||||
@ -886,12 +1009,8 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
|
||||
# Allows some use of Interface-specific methods with loaded Spaces
|
||||
if first_dependency and Context.root_block:
|
||||
blocks.predict = [fns[0]]
|
||||
blocks.input_components = [
|
||||
Context.root_block.blocks[i] for i in first_dependency["inputs"]
|
||||
]
|
||||
blocks.output_components = [
|
||||
Context.root_block.blocks[o] for o in first_dependency["outputs"]
|
||||
]
|
||||
blocks.input_components = first_dependency.inputs
|
||||
blocks.output_components = first_dependency.outputs
|
||||
blocks.__name__ = "Interface"
|
||||
blocks.api_mode = True
|
||||
blocks.proxy_urls = proxy_urls
|
||||
@ -901,19 +1020,19 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
|
||||
return self.__repr__()
|
||||
|
||||
def __repr__(self):
|
||||
num_backend_fns = len([d for d in self.dependencies if d["backend_fn"]])
|
||||
num_backend_fns = len([d for d in self.fns if d.fn])
|
||||
repr = f"Gradio Blocks instance: {num_backend_fns} backend functions"
|
||||
repr += f"\n{'-' * len(repr)}"
|
||||
for d, dependency in enumerate(self.dependencies):
|
||||
if dependency["backend_fn"]:
|
||||
for d, dependency in enumerate(self.fns):
|
||||
if dependency.fn:
|
||||
repr += f"\nfn_index={d}"
|
||||
repr += "\n inputs:"
|
||||
for input_id in dependency["inputs"]:
|
||||
block = self.blocks[input_id]
|
||||
for block in dependency.inputs:
|
||||
block = self.blocks[block._id]
|
||||
repr += f"\n |-{block}"
|
||||
repr += "\n outputs:"
|
||||
for output_id in dependency["outputs"]:
|
||||
block = self.blocks[output_id]
|
||||
for block in dependency.outputs:
|
||||
block = self.blocks[block._id]
|
||||
repr += f"\n |-{block}"
|
||||
return repr
|
||||
|
||||
@ -987,7 +1106,7 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
|
||||
concurrency_limit: int | None | Literal["default"] = "default",
|
||||
concurrency_id: str | None = None,
|
||||
show_api: bool = True,
|
||||
) -> tuple[dict[str, Any], int]:
|
||||
) -> tuple[BlockFunction, int]:
|
||||
"""
|
||||
Adds an event to the component's dependencies.
|
||||
Parameters:
|
||||
@ -1075,21 +1194,6 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
|
||||
_, progress_index, event_data_index = (
|
||||
special_args(fn) if fn else (None, None, None)
|
||||
)
|
||||
self.fns.append(
|
||||
BlockFunction(
|
||||
fn,
|
||||
inputs,
|
||||
outputs,
|
||||
preprocess,
|
||||
postprocess,
|
||||
inputs_as_dict=inputs_as_dict,
|
||||
concurrency_limit=concurrency_limit,
|
||||
concurrency_id=concurrency_id,
|
||||
batch=batch,
|
||||
max_batch_size=max_batch_size,
|
||||
tracks_progress=progress_index is not None,
|
||||
)
|
||||
)
|
||||
|
||||
# If api_name is None or empty string, use the function name
|
||||
if api_name is None or isinstance(api_name, str) and api_name.strip() == "":
|
||||
@ -1113,7 +1217,8 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
|
||||
|
||||
if api_name is not False:
|
||||
api_name = utils.append_unique_suffix(
|
||||
api_name, [dep["api_name"] for dep in self.dependencies]
|
||||
api_name,
|
||||
[fn.api_name for fn in self.fns if isinstance(fn.api_name, str)],
|
||||
)
|
||||
else:
|
||||
show_api = False
|
||||
@ -1124,35 +1229,35 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
|
||||
if collects_event_data is None:
|
||||
collects_event_data = event_data_index is not None
|
||||
|
||||
dependency = {
|
||||
"targets": _targets,
|
||||
"inputs": [block._id for block in inputs],
|
||||
"outputs": [block._id for block in outputs],
|
||||
"backend_fn": fn is not None,
|
||||
"js": js,
|
||||
"queue": False if fn is None else queue,
|
||||
"api_name": api_name,
|
||||
"scroll_to_output": False if utils.get_space() else scroll_to_output,
|
||||
"show_progress": show_progress,
|
||||
"every": every,
|
||||
"batch": batch,
|
||||
"max_batch_size": max_batch_size,
|
||||
"cancels": cancels or [],
|
||||
"types": {
|
||||
"continuous": bool(every),
|
||||
"generator": inspect.isgeneratorfunction(fn)
|
||||
or inspect.isasyncgenfunction(fn)
|
||||
or bool(every),
|
||||
},
|
||||
"collects_event_data": collects_event_data,
|
||||
"trigger_after": trigger_after,
|
||||
"trigger_only_on_success": trigger_only_on_success,
|
||||
"trigger_mode": trigger_mode,
|
||||
"show_api": show_api,
|
||||
"zerogpu": hasattr(fn, "zerogpu"),
|
||||
}
|
||||
self.dependencies.append(dependency)
|
||||
return dependency, len(self.dependencies) - 1
|
||||
block_fn = BlockFunction(
|
||||
fn,
|
||||
inputs,
|
||||
outputs,
|
||||
preprocess,
|
||||
postprocess,
|
||||
inputs_as_dict=inputs_as_dict,
|
||||
targets=_targets,
|
||||
batch=batch,
|
||||
max_batch_size=max_batch_size,
|
||||
concurrency_limit=concurrency_limit,
|
||||
concurrency_id=concurrency_id,
|
||||
tracks_progress=progress_index is not None,
|
||||
api_name=api_name,
|
||||
js=js,
|
||||
show_progress=show_progress,
|
||||
every=every,
|
||||
cancels=cancels,
|
||||
collects_event_data=collects_event_data,
|
||||
trigger_after=trigger_after,
|
||||
trigger_only_on_success=trigger_only_on_success,
|
||||
trigger_mode=trigger_mode,
|
||||
queue=queue,
|
||||
scroll_to_output=scroll_to_output,
|
||||
show_api=show_api,
|
||||
)
|
||||
|
||||
self.fns.append(block_fn)
|
||||
return block_fn, len(self.fns) - 1
|
||||
|
||||
def render(self):
|
||||
if Context.root_block is not None:
|
||||
@ -1169,40 +1274,34 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
|
||||
)
|
||||
|
||||
Context.root_block.blocks.update(self.blocks)
|
||||
Context.root_block.fns.extend(self.fns)
|
||||
dependency_offset = len(Context.root_block.dependencies)
|
||||
for i, dependency in enumerate(self.dependencies):
|
||||
api_name = dependency["api_name"]
|
||||
if api_name is not None and api_name is not False:
|
||||
dependency_offset = len(Context.root_block.fns)
|
||||
existing_api_names = [
|
||||
dep.api_name
|
||||
for dep in Context.root_block.fns
|
||||
if isinstance(dep.api_name, str)
|
||||
]
|
||||
for dependency in self.fns:
|
||||
api_name = dependency.api_name
|
||||
if isinstance(api_name, str):
|
||||
api_name_ = utils.append_unique_suffix(
|
||||
api_name,
|
||||
[dep["api_name"] for dep in Context.root_block.dependencies],
|
||||
existing_api_names,
|
||||
)
|
||||
if api_name != api_name_:
|
||||
dependency["api_name"] = api_name_
|
||||
dependency["cancels"] = [
|
||||
c + dependency_offset for c in dependency["cancels"]
|
||||
]
|
||||
if dependency.get("trigger_after") is not None:
|
||||
dependency["trigger_after"] += dependency_offset
|
||||
dependency.api_name = api_name_
|
||||
dependency.cancels = [c + dependency_offset for c in dependency.cancels]
|
||||
if dependency.trigger_after is not None:
|
||||
dependency.trigger_after += dependency_offset
|
||||
# Recreate the cancel function so that it has the latest
|
||||
# dependency fn indices. This is necessary to properly cancel
|
||||
# events in the backend
|
||||
if dependency["cancels"]:
|
||||
if dependency.cancels:
|
||||
updated_cancels = [
|
||||
Context.root_block.dependencies[i]
|
||||
for i in dependency["cancels"]
|
||||
Context.root_block.fns[i].get_config()
|
||||
for i in dependency.cancels
|
||||
]
|
||||
new_fn = BlockFunction(
|
||||
get_cancel_function(updated_cancels)[0],
|
||||
[],
|
||||
[],
|
||||
False,
|
||||
True,
|
||||
False,
|
||||
)
|
||||
Context.root_block.fns[dependency_offset + i] = new_fn
|
||||
Context.root_block.dependencies.append(dependency)
|
||||
dependency.fn = get_cancel_function(updated_cancels)[0]
|
||||
Context.root_block.fns.append(dependency)
|
||||
Context.root_block.temp_file_sets.extend(self.temp_file_sets)
|
||||
Context.root_block.proxy_urls.update(self.proxy_urls)
|
||||
|
||||
@ -1213,20 +1312,16 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
|
||||
def is_callable(self, fn_index: int = 0) -> bool:
|
||||
"""Checks if a particular Blocks function is callable (i.e. not stateful or a generator)."""
|
||||
block_fn = self.fns[fn_index]
|
||||
dependency = self.dependencies[fn_index]
|
||||
dependency = self.fns[fn_index]
|
||||
|
||||
if inspect.isasyncgenfunction(block_fn.fn):
|
||||
return False
|
||||
if inspect.isgeneratorfunction(block_fn.fn):
|
||||
return False
|
||||
for input_id in dependency["inputs"]:
|
||||
block = self.blocks[input_id]
|
||||
if getattr(block, "stateful", False):
|
||||
return False
|
||||
for output_id in dependency["outputs"]:
|
||||
block = self.blocks[output_id]
|
||||
if getattr(block, "stateful", False):
|
||||
return False
|
||||
if any(block.stateful for block in dependency.inputs):
|
||||
return False
|
||||
if any(block.stateful for block in dependency.outputs):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@ -1243,11 +1338,7 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
|
||||
"""
|
||||
if api_name is not None:
|
||||
inferred_fn_index = next(
|
||||
(
|
||||
i
|
||||
for i, d in enumerate(self.dependencies)
|
||||
if d.get("api_name") == api_name
|
||||
),
|
||||
(i for i, d in enumerate(self.fns) if d.api_name == api_name),
|
||||
None,
|
||||
)
|
||||
if inferred_fn_index is None:
|
||||
@ -1262,7 +1353,7 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
|
||||
|
||||
inputs = list(inputs)
|
||||
processed_inputs = self.serialize_data(fn_index, inputs)
|
||||
batch = self.dependencies[fn_index]["batch"]
|
||||
batch = self.fns[fn_index].batch
|
||||
if batch:
|
||||
processed_inputs = [[inp] for inp in processed_inputs]
|
||||
|
||||
@ -1352,7 +1443,7 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
|
||||
prediction = await utils.async_iteration(iterator)
|
||||
is_generating = True
|
||||
except StopAsyncIteration:
|
||||
n_outputs = len(self.dependencies[fn_index].get("outputs"))
|
||||
n_outputs = len(self.fns[fn_index].outputs)
|
||||
prediction = (
|
||||
components._Keywords.FINISHED_ITERATING
|
||||
if n_outputs == 1
|
||||
@ -1370,22 +1461,16 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
|
||||
}
|
||||
|
||||
def serialize_data(self, fn_index: int, inputs: list[Any]) -> list[Any]:
|
||||
dependency = self.dependencies[fn_index]
|
||||
dependency = self.fns[fn_index]
|
||||
processed_input = []
|
||||
|
||||
def format_file(s):
|
||||
return FileData(path=s).model_dump()
|
||||
|
||||
for i, input_id in enumerate(dependency["inputs"]):
|
||||
try:
|
||||
block = self.blocks[input_id]
|
||||
except KeyError as e:
|
||||
raise InvalidBlockError(
|
||||
f"Input component with id {input_id} used in {dependency['trigger']}() event is not defined in this gr.Blocks context. You are allowed to nest gr.Blocks contexts, but there must be a gr.Blocks context that contains all components and events."
|
||||
) from e
|
||||
for i, block in enumerate(dependency.inputs):
|
||||
if not isinstance(block, components.Component):
|
||||
raise InvalidComponentError(
|
||||
f"{block.__class__} Component with id {input_id} not a valid input component."
|
||||
f"{block.__class__} Component not a valid input component."
|
||||
)
|
||||
api_info = block.api_info()
|
||||
if client_utils.value_is_file(api_info):
|
||||
@ -1402,19 +1487,13 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
|
||||
return processed_input
|
||||
|
||||
def deserialize_data(self, fn_index: int, outputs: list[Any]) -> list[Any]:
|
||||
dependency = self.dependencies[fn_index]
|
||||
dependency = self.fns[fn_index]
|
||||
predictions = []
|
||||
|
||||
for o, output_id in enumerate(dependency["outputs"]):
|
||||
try:
|
||||
block = self.blocks[output_id]
|
||||
except KeyError as e:
|
||||
raise InvalidBlockError(
|
||||
f"Output component with id {output_id} used in {dependency['trigger']}() event not found in this gr.Blocks context. You are allowed to nest gr.Blocks contexts, but there must be a gr.Blocks context that contains all components and events."
|
||||
) from e
|
||||
for o, block in enumerate(dependency.outputs):
|
||||
if not isinstance(block, components.Component):
|
||||
raise InvalidComponentError(
|
||||
f"{block.__class__} Component with id {output_id} not a valid output component."
|
||||
f"{block.__class__} Component not a valid output component."
|
||||
)
|
||||
|
||||
deserialized = client_utils.traverse(
|
||||
@ -1426,9 +1505,8 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
|
||||
|
||||
def validate_inputs(self, fn_index: int, inputs: list[Any]):
|
||||
block_fn = self.fns[fn_index]
|
||||
dependency = self.dependencies[fn_index]
|
||||
|
||||
dep_inputs = dependency["inputs"]
|
||||
dep_inputs = block_fn.inputs
|
||||
|
||||
# This handles incorrect inputs when args are changed by a JS function
|
||||
# Only check not enough args case, ignore extra arguments (for now)
|
||||
@ -1442,8 +1520,7 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
|
||||
|
||||
wanted_args = []
|
||||
received_args = []
|
||||
for input_id in dep_inputs:
|
||||
block = self.blocks[input_id]
|
||||
for block in dep_inputs:
|
||||
wanted_args.append(str(block))
|
||||
for inp in inputs:
|
||||
v = f'"{inp}"' if isinstance(inp, str) else str(inp)
|
||||
@ -1471,28 +1548,21 @@ Received inputs:
|
||||
):
|
||||
state = state or SessionState(self)
|
||||
block_fn = self.fns[fn_index]
|
||||
dependency = self.dependencies[fn_index]
|
||||
|
||||
self.validate_inputs(fn_index, inputs)
|
||||
|
||||
if block_fn.preprocess:
|
||||
processed_input = []
|
||||
for i, input_id in enumerate(dependency["inputs"]):
|
||||
try:
|
||||
block = self.blocks[input_id]
|
||||
except KeyError as e:
|
||||
raise InvalidBlockError(
|
||||
f"Input component with id {input_id} used in {dependency['trigger']}() event not found in this gr.Blocks context. You are allowed to nest gr.Blocks contexts, but there must be a gr.Blocks context that contains all components and events."
|
||||
) from e
|
||||
for i, block in enumerate(block_fn.inputs):
|
||||
if not isinstance(block, components.Component):
|
||||
raise InvalidComponentError(
|
||||
f"{block.__class__} Component with id {input_id} not a valid input component."
|
||||
f"{block.__class__} Component not a valid input component."
|
||||
)
|
||||
if getattr(block, "stateful", False):
|
||||
processed_input.append(state[input_id])
|
||||
if block.stateful:
|
||||
processed_input.append(state[block._id])
|
||||
else:
|
||||
if input_id in state:
|
||||
block = state[input_id]
|
||||
if block._id in state:
|
||||
block = state[block._id]
|
||||
inputs_cached = await processing_utils.async_move_files_to_cache(
|
||||
inputs[i],
|
||||
block,
|
||||
@ -1510,9 +1580,9 @@ Received inputs:
|
||||
|
||||
def validate_outputs(self, fn_index: int, predictions: Any | list[Any]):
|
||||
block_fn = self.fns[fn_index]
|
||||
dependency = self.dependencies[fn_index]
|
||||
dependency = self.fns[fn_index]
|
||||
|
||||
dep_outputs = dependency["outputs"]
|
||||
dep_outputs = dependency.outputs
|
||||
|
||||
if not isinstance(predictions, (list, tuple)):
|
||||
predictions = [predictions]
|
||||
@ -1526,8 +1596,7 @@ Received inputs:
|
||||
|
||||
wanted_args = []
|
||||
received_args = []
|
||||
for output_id in dep_outputs:
|
||||
block = self.blocks[output_id]
|
||||
for block in dep_outputs:
|
||||
wanted_args.append(str(block))
|
||||
for pred in predictions:
|
||||
v = f'"{pred}"' if isinstance(pred, str) else str(pred)
|
||||
@ -1549,15 +1618,13 @@ Received outputs:
|
||||
):
|
||||
state = state or SessionState(self)
|
||||
block_fn = self.fns[fn_index]
|
||||
dependency = self.dependencies[fn_index]
|
||||
batch = dependency["batch"]
|
||||
|
||||
if isinstance(predictions, dict) and len(predictions) > 0:
|
||||
predictions = convert_component_dict_to_list(
|
||||
dependency["outputs"], predictions
|
||||
[block._id for block in block_fn.outputs], predictions
|
||||
)
|
||||
|
||||
if len(dependency["outputs"]) == 1 and not (batch):
|
||||
if len(block_fn.outputs) == 1 and not block_fn.batch:
|
||||
predictions = [
|
||||
predictions,
|
||||
]
|
||||
@ -1565,7 +1632,7 @@ Received outputs:
|
||||
self.validate_outputs(fn_index, predictions) # type: ignore
|
||||
|
||||
output = []
|
||||
for i, output_id in enumerate(dependency["outputs"]):
|
||||
for i, block in enumerate(block_fn.outputs):
|
||||
try:
|
||||
if predictions[i] is components._Keywords.FINISHED_ITERATING:
|
||||
output.append(None)
|
||||
@ -1576,16 +1643,9 @@ Received outputs:
|
||||
f"of values returned from from function {block_fn.name}"
|
||||
) from err
|
||||
|
||||
try:
|
||||
block = self.blocks[output_id]
|
||||
except KeyError as e:
|
||||
raise InvalidBlockError(
|
||||
f"Output component with id {output_id} used in {dependency['trigger']}() event not found in this gr.Blocks context. You are allowed to nest gr.Blocks contexts, but there must be a gr.Blocks context that contains all components and events."
|
||||
) from e
|
||||
|
||||
if block.stateful:
|
||||
if not utils.is_update(predictions[i]):
|
||||
state[output_id] = predictions[i]
|
||||
state[block._id] = predictions[i]
|
||||
output.append(None)
|
||||
else:
|
||||
prediction_value = predictions[i]
|
||||
@ -1600,28 +1660,25 @@ Received outputs:
|
||||
prediction_value = prediction_value.constructor_args.copy()
|
||||
prediction_value["__type__"] = "update"
|
||||
if utils.is_update(prediction_value):
|
||||
if output_id in state:
|
||||
kwargs = state[output_id].constructor_args.copy()
|
||||
else:
|
||||
kwargs = self.blocks[output_id].constructor_args.copy()
|
||||
kwargs = state[block._id].constructor_args.copy()
|
||||
kwargs.update(prediction_value)
|
||||
kwargs.pop("value", None)
|
||||
kwargs.pop("__type__")
|
||||
kwargs["render"] = False
|
||||
state[output_id] = self.blocks[output_id].__class__(**kwargs)
|
||||
state[block._id] = block.__class__(**kwargs)
|
||||
|
||||
prediction_value = postprocess_update_dict(
|
||||
block=state[output_id],
|
||||
block=state[block._id],
|
||||
update_dict=prediction_value,
|
||||
postprocess=block_fn.postprocess,
|
||||
)
|
||||
elif block_fn.postprocess:
|
||||
if not isinstance(block, components.Component):
|
||||
raise InvalidComponentError(
|
||||
f"{block.__class__} Component with id {output_id} not a valid output component."
|
||||
f"{block.__class__} Component not a valid output component."
|
||||
)
|
||||
if output_id in state:
|
||||
block = state[output_id]
|
||||
if block._id in state:
|
||||
block = state[block._id]
|
||||
prediction_value = block.postprocess(prediction_value)
|
||||
|
||||
outputs_cached = await processing_utils.async_move_files_to_cache(
|
||||
@ -1647,8 +1704,8 @@ Received outputs:
|
||||
self.pending_streams[session_hash][run] = {}
|
||||
stream_run = self.pending_streams[session_hash][run]
|
||||
|
||||
for i, output_id in enumerate(self.dependencies[fn_index]["outputs"]):
|
||||
block = self.blocks[output_id]
|
||||
for i, block in enumerate(self.fns[fn_index].outputs):
|
||||
output_id = block._id
|
||||
if isinstance(block, components.StreamingOutput) and block.streaming:
|
||||
first_chunk = output_id not in stream_run
|
||||
binary_data, output_data = block.stream_output(
|
||||
@ -1686,7 +1743,7 @@ Received outputs:
|
||||
self.pending_diff_streams[session_hash][run] = [None] * len(data)
|
||||
last_diffs = self.pending_diff_streams[session_hash][run]
|
||||
|
||||
for i in range(len(self.dependencies[fn_index]["outputs"])):
|
||||
for i in range(len(self.fns[fn_index].outputs)):
|
||||
if final:
|
||||
data[i] = last_diffs[i]
|
||||
continue
|
||||
@ -1736,10 +1793,10 @@ Received outputs:
|
||||
Returns: None
|
||||
"""
|
||||
block_fn = self.fns[fn_index]
|
||||
batch = self.dependencies[fn_index]["batch"]
|
||||
batch = self.fns[fn_index].batch
|
||||
|
||||
if batch:
|
||||
max_batch_size = self.dependencies[fn_index]["max_batch_size"]
|
||||
max_batch_size = self.fns[fn_index].max_batch_size
|
||||
batch_sizes = [len(inp) for inp in inputs]
|
||||
batch_size = batch_sizes[0]
|
||||
if inspect.isasyncgenfunction(block_fn.fn) or inspect.isgeneratorfunction(
|
||||
@ -1871,37 +1928,8 @@ Received outputs:
|
||||
},
|
||||
"fill_height": self.fill_height,
|
||||
}
|
||||
config.update(self.default_config.get_config())
|
||||
|
||||
def get_layout(block):
|
||||
if not isinstance(block, BlockContext):
|
||||
return {"id": block._id}
|
||||
children_layout = []
|
||||
for child in block.children:
|
||||
children_layout.append(get_layout(child))
|
||||
return {"id": block._id, "children": children_layout}
|
||||
|
||||
config["layout"] = get_layout(self)
|
||||
|
||||
for _id, block in self.blocks.items():
|
||||
props = block.get_config() if hasattr(block, "get_config") else {}
|
||||
block_config = {
|
||||
"id": _id,
|
||||
"type": block.get_block_name(),
|
||||
"props": utils.delete_none(props),
|
||||
}
|
||||
block_config["skip_api"] = block.skip_api
|
||||
block_config["component_class_id"] = getattr(
|
||||
block, "component_class_id", None
|
||||
)
|
||||
|
||||
if not block.skip_api:
|
||||
block_config["api_info"] = block.api_info() # type: ignore
|
||||
# .example_inputs() has been renamed .example_payload() but
|
||||
# we use the old name for backwards compatibility with custom components
|
||||
# created on Gradio 4.20.0 or earlier
|
||||
block_config["example_inputs"] = block.example_inputs() # type: ignore
|
||||
config["components"].append(block_config)
|
||||
config["dependencies"] = self.dependencies
|
||||
return config
|
||||
|
||||
def __enter__(self):
|
||||
@ -1932,9 +1960,8 @@ Received outputs:
|
||||
|
||||
def clear(self):
|
||||
"""Resets the layout of the Blocks object."""
|
||||
self.blocks = {}
|
||||
self.fns = []
|
||||
self.dependencies = []
|
||||
self.default_config.blocks = {}
|
||||
self.default_config.fns = []
|
||||
self.children = []
|
||||
return self
|
||||
|
||||
@ -1988,8 +2015,8 @@ Received outputs:
|
||||
return self
|
||||
|
||||
def validate_queue_settings(self):
|
||||
for dep in self.dependencies:
|
||||
for i in dep["cancels"]:
|
||||
for dep in self.fns:
|
||||
for i in dep.cancels:
|
||||
if not self.queue_enabled_for_fn(i):
|
||||
raise ValueError(
|
||||
"Queue needs to be enabled! "
|
||||
@ -1998,7 +2025,7 @@ Received outputs:
|
||||
"another event without enabling the queue. Both can be solved by calling .queue() "
|
||||
"before .launch()"
|
||||
)
|
||||
if dep["batch"] and dep["queue"] is False:
|
||||
if dep.batch and dep.queue is False:
|
||||
raise ValueError("In order to use batching, the queue must be enabled.")
|
||||
|
||||
def launch(
|
||||
@ -2540,7 +2567,7 @@ Received outputs:
|
||||
queue=False if every is None else None,
|
||||
every=every,
|
||||
)[0]
|
||||
component.load_event = dep
|
||||
component.load_event = dep.get_config()
|
||||
|
||||
def startup_events(self):
|
||||
"""Events that should be run when the app containing this block starts up."""
|
||||
@ -2551,7 +2578,7 @@ Received outputs:
|
||||
self.create_limiter()
|
||||
|
||||
def queue_enabled_for_fn(self, fn_index: int):
|
||||
return self.dependencies[fn_index]["queue"] is not False
|
||||
return self.fns[fn_index].queue is not False
|
||||
|
||||
def get_api_info(self):
|
||||
"""
|
||||
@ -2560,22 +2587,18 @@ Received outputs:
|
||||
config = self.config
|
||||
api_info = {"named_endpoints": {}, "unnamed_endpoints": {}}
|
||||
|
||||
for dependency, fn in zip(config["dependencies"], self.fns):
|
||||
if (
|
||||
not dependency["backend_fn"]
|
||||
or not dependency["show_api"]
|
||||
or dependency["api_name"] is False
|
||||
):
|
||||
for fn in self.fns:
|
||||
if not fn.fn or not fn.show_api or fn.api_name is False:
|
||||
continue
|
||||
|
||||
dependency_info = {"parameters": [], "returns": []}
|
||||
fn_info = utils.get_function_params(fn.fn) # type: ignore
|
||||
skip_endpoint = False
|
||||
|
||||
inputs = dependency["inputs"]
|
||||
for index, input_id in enumerate(inputs):
|
||||
inputs = fn.inputs
|
||||
for index, input_block in enumerate(inputs):
|
||||
for component in config["components"]:
|
||||
if component["id"] == input_id:
|
||||
if component["id"] == input_block._id:
|
||||
break
|
||||
else:
|
||||
skip_endpoint = True # if component not found, skip endpoint
|
||||
@ -2583,7 +2606,7 @@ Received outputs:
|
||||
type = component["props"]["name"]
|
||||
if self.blocks[component["id"]].skip_api:
|
||||
continue
|
||||
label = component["props"].get("label", f"parameter_{input_id}")
|
||||
label = component["props"].get("label", f"parameter_{input_block._id}")
|
||||
comp = self.get_component(component["id"])
|
||||
if not isinstance(comp, components.Component):
|
||||
raise TypeError(f"{comp!r} is not a Component")
|
||||
@ -2595,7 +2618,7 @@ Received outputs:
|
||||
# "result_callbacks" to specify the callbacks, we need to make sure that no parameters
|
||||
# have those names. Hence the final checks.
|
||||
if (
|
||||
dependency["backend_fn"]
|
||||
fn.fn
|
||||
and index < len(fn_info)
|
||||
and fn_info[index][0]
|
||||
not in ["api_name", "fn_index", "result_callbacks"]
|
||||
@ -2612,7 +2635,7 @@ Received outputs:
|
||||
parameter_has_default = True
|
||||
parameter_default = component["props"]["value"]
|
||||
elif (
|
||||
dependency["backend_fn"]
|
||||
fn.fn
|
||||
and index < len(fn_info)
|
||||
and fn_info[index][1]
|
||||
and fn_info[index][2] is None
|
||||
@ -2639,10 +2662,10 @@ Received outputs:
|
||||
}
|
||||
)
|
||||
|
||||
outputs = dependency["outputs"]
|
||||
outputs = fn.outputs
|
||||
for o in outputs:
|
||||
for component in config["components"]:
|
||||
if component["id"] == o:
|
||||
if component["id"] == o._id:
|
||||
break
|
||||
else:
|
||||
skip_endpoint = True # if component not found, skip endpoint
|
||||
@ -2650,7 +2673,7 @@ Received outputs:
|
||||
type = component["props"]["name"]
|
||||
if self.blocks[component["id"]].skip_api:
|
||||
continue
|
||||
label = component["props"].get("label", f"value_{o}")
|
||||
label = component["props"].get("label", f"value_{o._id}")
|
||||
comp = self.get_component(component["id"])
|
||||
if not isinstance(comp, components.Component):
|
||||
raise TypeError(f"{comp!r} is not a Component")
|
||||
@ -2670,8 +2693,6 @@ Received outputs:
|
||||
)
|
||||
|
||||
if not skip_endpoint:
|
||||
api_info["named_endpoints"][
|
||||
f"/{dependency['api_name']}"
|
||||
] = dependency_info
|
||||
api_info["named_endpoints"][f"/{fn.api_name}"] = dependency_info
|
||||
|
||||
return api_info
|
||||
|
@ -320,7 +320,7 @@ class EventListener(str):
|
||||
)
|
||||
if _callback:
|
||||
_callback(block)
|
||||
return Dependency(block, dep, dep_index, fn)
|
||||
return Dependency(block, dep.get_config(), dep_index, fn)
|
||||
|
||||
event_trigger.event_name = _event_name
|
||||
event_trigger.has_trigger = _has_trigger
|
||||
@ -445,7 +445,7 @@ def on(
|
||||
trigger_mode=trigger_mode,
|
||||
)
|
||||
set_cancel_events(triggers, cancels)
|
||||
return Dependency(None, dep, dep_index, fn)
|
||||
return Dependency(None, dep.get_config(), dep_index, fn)
|
||||
|
||||
|
||||
class Events:
|
||||
|
@ -300,6 +300,7 @@ class Examples:
|
||||
api_name=self.api_name,
|
||||
show_api=False,
|
||||
)
|
||||
self.load_input_event_id = len(Context.root_block.fns) - 1
|
||||
if self.run_on_click and not self.cache_examples:
|
||||
if self.fn is None:
|
||||
raise ValueError("Cannot run_on_click if no function is provided")
|
||||
@ -496,15 +497,12 @@ class Examples:
|
||||
output = [value[0] for value in output]
|
||||
self.cache_logger.flag(output)
|
||||
# Remove the "fake_event" to prevent bugs in loading interfaces from spaces
|
||||
Context.root_block.dependencies.remove(dependency)
|
||||
Context.root_block.fns.pop(fn_index)
|
||||
|
||||
# Remove the original load_input_event and replace it with one that
|
||||
# also populates the input. We do it this way to to allow the cache()
|
||||
# method to be called independently of the create() method
|
||||
index = Context.root_block.dependencies.index(self.load_input_event)
|
||||
Context.root_block.dependencies.pop(index)
|
||||
Context.root_block.fns.pop(index)
|
||||
Context.root_block.fns.pop(self.load_input_event_id)
|
||||
|
||||
def load_example(example_id):
|
||||
processed_example = self.non_none_processed_examples[
|
||||
@ -522,6 +520,7 @@ class Examples:
|
||||
api_name=self.api_name,
|
||||
show_api=False,
|
||||
)
|
||||
self.load_input_event_id = len(Context.root_block.fns) - 1
|
||||
|
||||
def load_from_cache(self, example_id: int) -> list[Any]:
|
||||
"""Loads a particular cached example for the interface.
|
||||
|
@ -167,8 +167,8 @@ class FnIndexInferError(Exception):
|
||||
|
||||
def infer_fn_index(app: App, api_name: str, body: PredictBody) -> int:
|
||||
if body.fn_index is None:
|
||||
for i, fn in enumerate(app.get_blocks().dependencies):
|
||||
if fn["api_name"] == api_name:
|
||||
for i, fn in enumerate(app.get_blocks().fns):
|
||||
if fn.api_name == api_name:
|
||||
return i
|
||||
|
||||
raise FnIndexInferError(f"Could not infer fn_index for api_name {api_name}.")
|
||||
@ -185,7 +185,7 @@ def compile_gr_request(
|
||||
):
|
||||
# If this fn_index cancels jobs, then the only input we need is the
|
||||
# current session hash
|
||||
if app.get_blocks().dependencies[fn_index_inferred]["cancels"]:
|
||||
if app.get_blocks().fns[fn_index_inferred].cancels:
|
||||
body.data = [body.session_hash]
|
||||
if body.request:
|
||||
if body.batched:
|
||||
@ -245,14 +245,14 @@ async def call_process_api(
|
||||
):
|
||||
session_state, iterator = restore_session_state(app=app, body=body)
|
||||
|
||||
dependency = app.get_blocks().dependencies[fn_index_inferred]
|
||||
dependency = app.get_blocks().fns[fn_index_inferred]
|
||||
event_data = prepare_event_data(app.get_blocks(), body)
|
||||
event_id = body.event_id
|
||||
|
||||
session_hash = getattr(body, "session_hash", None)
|
||||
inputs = body.data
|
||||
|
||||
batch_in_single_out = not body.batched and dependency["batch"]
|
||||
batch_in_single_out = not body.batched and dependency.batch
|
||||
if batch_in_single_out:
|
||||
inputs = [inputs]
|
||||
|
||||
|
@ -648,8 +648,8 @@ class App(FastAPI):
|
||||
)
|
||||
unload_fn_indices = [
|
||||
i
|
||||
for i, dep in enumerate(app.get_blocks().dependencies)
|
||||
if any(t for t in dep["targets"] if t[1] == "unload")
|
||||
for i, dep in enumerate(app.get_blocks().fns)
|
||||
if any(t for t in dep.targets if t[1] == "unload")
|
||||
]
|
||||
for fn_index in unload_fn_indices:
|
||||
# The task runnning this loop has been cancelled
|
||||
|
@ -4,7 +4,7 @@ import datetime
|
||||
import os
|
||||
import threading
|
||||
from collections import OrderedDict
|
||||
from copy import deepcopy
|
||||
from copy import copy, deepcopy
|
||||
from typing import TYPE_CHECKING, Any, Iterator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -62,13 +62,13 @@ class StateHolder:
|
||||
component.delete_callback(value)
|
||||
to_delete.append(component._id)
|
||||
for component in to_delete:
|
||||
del session_state._data[component]
|
||||
del session_state.state_data[component]
|
||||
|
||||
|
||||
class SessionState:
|
||||
def __init__(self, blocks: Blocks):
|
||||
self.blocks = blocks
|
||||
self._data = {}
|
||||
self.blocks_config = copy(blocks.default_config)
|
||||
self.state_data: dict[int, Any] = {}
|
||||
self._state_ttl = {}
|
||||
self.is_closed = False
|
||||
# When a session is closed, the state is stored for an hour to give the user time to reopen the session.
|
||||
@ -78,39 +78,45 @@ class SessionState:
|
||||
)
|
||||
|
||||
def __getitem__(self, key: int) -> Any:
|
||||
if key not in self._data:
|
||||
block = self.blocks.blocks[key]
|
||||
if getattr(block, "stateful", False):
|
||||
self._data[key] = deepcopy(getattr(block, "value", None))
|
||||
else:
|
||||
self._data[key] = None
|
||||
return self._data[key]
|
||||
block = self.blocks_config.blocks[key]
|
||||
if block.stateful:
|
||||
if key not in self.state_data:
|
||||
self.state_data[key] = deepcopy(getattr(block, "value", None))
|
||||
return self.state_data[key]
|
||||
else:
|
||||
return block
|
||||
|
||||
def __setitem__(self, key: int, value: Any):
|
||||
from gradio.components import State
|
||||
|
||||
block = self.blocks.blocks[key]
|
||||
block = self.blocks_config.blocks[key]
|
||||
if isinstance(block, State):
|
||||
self._state_ttl[key] = (
|
||||
block.time_to_live,
|
||||
datetime.datetime.now(),
|
||||
)
|
||||
self._data[key] = value
|
||||
self.state_data[key] = value
|
||||
else:
|
||||
self.blocks_config.blocks[key] = value
|
||||
|
||||
def __contains__(self, key: int):
|
||||
return key in self._data
|
||||
block = self.blocks_config.blocks[key]
|
||||
if block.stateful:
|
||||
return key in self.state_data
|
||||
else:
|
||||
return key in self.blocks_config.blocks
|
||||
|
||||
@property
|
||||
def state_components(self) -> Iterator[tuple[State, Any, bool]]:
|
||||
from gradio.components import State
|
||||
|
||||
for id in self._data:
|
||||
block = self.blocks.blocks[id]
|
||||
for id in self.state_data:
|
||||
block = self.blocks_config.blocks[id]
|
||||
if isinstance(block, State) and id in self._state_ttl:
|
||||
time_to_live, created_at = self._state_ttl[id]
|
||||
if self.is_closed:
|
||||
time_to_live = self.STATE_TTL_WHEN_CLOSED
|
||||
value = self._data[id]
|
||||
value = self.state_data[id]
|
||||
yield (
|
||||
block,
|
||||
value,
|
||||
|
@ -814,7 +814,7 @@ def get_cancel_function(
|
||||
for dep in dependencies:
|
||||
if Context.root_block:
|
||||
fn_index = next(
|
||||
i for i, d in enumerate(Context.root_block.dependencies) if d == dep
|
||||
i for i, d in enumerate(Context.root_block.fns) if d.get_config() == dep
|
||||
)
|
||||
fn_to_comp[fn_index] = [
|
||||
Context.root_block.blocks[o] for o in dep["outputs"]
|
||||
|
@ -492,7 +492,9 @@ class TestComponentsInBlocks:
|
||||
with gr.Blocks() as demo:
|
||||
for component in io_components:
|
||||
components.append(component(value=lambda: None, every=1))
|
||||
assert all(comp.load_event in demo.dependencies for comp in components)
|
||||
assert all(
|
||||
comp.load_event in demo.config["dependencies"] for comp in components
|
||||
)
|
||||
|
||||
|
||||
class TestBlocksPostprocessing:
|
||||
@ -1670,7 +1672,7 @@ def test_emptry_string_api_name_gets_set_as_fn_name():
|
||||
t2 = gr.Textbox()
|
||||
t1.change(test_fn, t1, t2, api_name="")
|
||||
|
||||
assert demo.dependencies[0]["api_name"] == "test_fn"
|
||||
assert demo.fns[0].api_name == "test_fn"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -1775,6 +1777,8 @@ def test_time_to_live_and_delete_callback_for_state(capsys, monkeypatch):
|
||||
assert "deleted 2" in captured.out
|
||||
assert "deleted 3" in captured.out
|
||||
for client in [client_1, client_2]:
|
||||
assert len(app.state_holder.session_data[client.session_hash]._data) == 0
|
||||
assert (
|
||||
len(app.state_holder.session_data[client.session_hash].state_data) == 0
|
||||
)
|
||||
finally:
|
||||
demo.close()
|
||||
|
@ -13,10 +13,10 @@ class TestClearButton:
|
||||
textbox = gr.Textbox(scale=3, interactive=True)
|
||||
gr.ClearButton([textbox, chatbot], scale=1)
|
||||
|
||||
clear_event_trigger = demo.dependencies.pop()
|
||||
assert not clear_event_trigger["backend_fn"]
|
||||
assert clear_event_trigger["js"]
|
||||
assert clear_event_trigger["outputs"] == [textbox._id, chatbot._id]
|
||||
clear_event_trigger = demo.fns.pop()
|
||||
assert not clear_event_trigger.fn
|
||||
assert clear_event_trigger.js
|
||||
assert clear_event_trigger.outputs == [textbox, chatbot]
|
||||
|
||||
def test_clear_event_setup_correctly_with_state(self):
|
||||
with gr.Blocks() as demo:
|
||||
@ -24,8 +24,8 @@ class TestClearButton:
|
||||
state = gr.State("")
|
||||
gr.ClearButton([state, chatbot], scale=1)
|
||||
|
||||
clear_event_trigger_state = demo.dependencies.pop()
|
||||
assert clear_event_trigger_state["backend_fn"]
|
||||
clear_event_trigger_state = demo.fns.pop()
|
||||
assert clear_event_trigger_state.fn
|
||||
|
||||
|
||||
class TestOAuthButtons:
|
||||
@ -47,9 +47,9 @@ class TestOAuthButtons:
|
||||
with gr.Blocks() as demo:
|
||||
button = gr.LoginButton()
|
||||
|
||||
login_event = demo.dependencies[0]
|
||||
assert login_event["targets"][0][1] == "click"
|
||||
assert not login_event["backend_fn"] # No Python code
|
||||
assert login_event["js"] # But JS code instead
|
||||
assert login_event["inputs"] == [button._id]
|
||||
assert login_event["outputs"] == []
|
||||
login_event = demo.fns[0]
|
||||
assert login_event.targets[0][1] == "click"
|
||||
assert not login_event.fn # No Python code
|
||||
assert login_event.js # But JS code instead
|
||||
assert login_event.inputs == [button]
|
||||
assert login_event.outputs == []
|
||||
|
@ -77,14 +77,14 @@ class TestInit:
|
||||
|
||||
def test_events_attached(self):
|
||||
chatbot = gr.ChatInterface(double)
|
||||
dependencies = chatbot.dependencies
|
||||
dependencies = chatbot.fns
|
||||
textbox = chatbot.textbox._id
|
||||
submit_btn = chatbot.submit_btn._id
|
||||
assert next(
|
||||
(
|
||||
d
|
||||
for d in dependencies
|
||||
if d["targets"] == [(textbox, "submit"), (submit_btn, "click")]
|
||||
if d.targets == [(textbox, "submit"), (submit_btn, "click")]
|
||||
),
|
||||
None,
|
||||
)
|
||||
@ -94,7 +94,7 @@ class TestInit:
|
||||
chatbot.undo_btn._id,
|
||||
]:
|
||||
assert next(
|
||||
(d for d in dependencies if d["targets"][0] == (btn_id, "click")),
|
||||
(d for d in dependencies if d.targets[0] == (btn_id, "click")),
|
||||
None,
|
||||
)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user