Store configs per session in the backend (#8030)

* changes

* add changeset

* changes

* changes

* changes

* changes

* changes

* changes

* changeas

* add changeset

* unrelated fix

* Update gradio/blocks.py

Co-authored-by: Abubakar Abid <abubakar@huggingface.co>

---------

Co-authored-by: Ali Abid <aliabid94@gmail.com>
Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
This commit is contained in:
aliabid94 2024-04-22 11:20:05 -07:00 committed by GitHub
parent 659d3c51ae
commit 91a7a31cd1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 318 additions and 283 deletions

View File

@ -0,0 +1,5 @@
---
"gradio": patch
---
feat:Store configs per session in the backend

View File

@ -184,21 +184,21 @@ def launched_analytics(blocks: gradio.Blocks, data: dict[str, Any]) -> None:
x, BlockContext
) else blocks_telemetry.append(str(x))
for x in blocks.dependencies:
for x in blocks.fns:
targets_telemetry = targets_telemetry + [
# Sometimes the target can be the Blocks object itself, so we need to check if its in blocks.blocks
str(blocks.blocks[y[0]])
for y in x["targets"]
for y in x.targets
if y[0] in blocks.blocks
]
events_telemetry = events_telemetry + [
y[1] for y in x["targets"] if y[0] in blocks.blocks
y[1] for y in x.targets if y[0] in blocks.blocks
]
inputs_telemetry = inputs_telemetry + [
str(blocks.blocks[y]) for y in x["inputs"] if y in blocks.blocks
str(blocks.blocks[y]) for y in x.inputs if y in blocks.blocks
]
outputs_telemetry = outputs_telemetry + [
str(blocks.blocks[y]) for y in x["outputs"] if y in blocks.blocks
str(blocks.blocks[y]) for y in x.outputs if y in blocks.blocks
]
additional_data = {
"version": get_package_version(),

View File

@ -49,7 +49,6 @@ from gradio.events import (
from gradio.exceptions import (
DuplicateBlockError,
InvalidApiNameError,
InvalidBlockError,
InvalidComponentError,
)
from gradio.helpers import create_tracker, skip, special_args
@ -458,11 +457,24 @@ class BlockFunction:
preprocess: bool,
postprocess: bool,
inputs_as_dict: bool,
targets: list[tuple[int | None, str]],
batch: bool = False,
max_batch_size: int = 4,
concurrency_limit: int | None | Literal["default"] = "default",
concurrency_id: str | None = None,
tracks_progress: bool = False,
api_name: str | Literal[False] = False,
js: str | None = None,
show_progress: Literal["full", "minimal", "hidden"] = "full",
every: float | None = None,
cancels: list[int] | None = None,
collects_event_data: bool = False,
trigger_after: int | None = None,
trigger_only_on_success: bool = False,
trigger_mode: Literal["always_last", "once", "multiple"] = "once",
queue: bool | None = None,
scroll_to_output: bool = False,
show_api: bool = True,
):
self.fn = fn
self.inputs = inputs
@ -477,7 +489,27 @@ class BlockFunction:
self.total_runtime = 0
self.total_runs = 0
self.inputs_as_dict = inputs_as_dict
self.targets = targets
self.name = getattr(fn, "__name__", "fn") if fn is not None else None
self.api_name = api_name
self.js = js
self.show_progress = show_progress
self.every = every
self.cancels = cancels or []
self.collects_event_data = collects_event_data
self.trigger_after = trigger_after
self.trigger_only_on_success = trigger_only_on_success
self.trigger_mode = trigger_mode
self.queue = False if fn is None else queue
self.scroll_to_output = False if utils.get_space() else scroll_to_output
self.show_api = show_api
self.zero_gpu = hasattr(self.fn, "zerogpu")
self.types_continuous = bool(self.every)
self.types_generator = (
inspect.isgeneratorfunction(self.fn)
or inspect.isasyncgenfunction(self.fn)
or bool(self.every)
)
self.spaces_auto_wrap()
def spaces_auto_wrap(self):
@ -499,6 +531,33 @@ class BlockFunction:
def __repr__(self):
return str(self)
def get_config(self):
return {
"targets": self.targets,
"inputs": [block._id for block in self.inputs],
"outputs": [block._id for block in self.outputs],
"backend_fn": self.fn is not None,
"js": self.js,
"queue": self.queue,
"api_name": self.api_name,
"scroll_to_output": self.scroll_to_output,
"show_progress": self.show_progress,
"every": self.every,
"batch": self.batch,
"max_batch_size": self.max_batch_size,
"cancels": self.cancels,
"types": {
"continuous": self.types_continuous,
"generator": self.types_generator,
},
"collects_event_data": self.collects_event_data,
"trigger_after": self.trigger_after,
"trigger_only_on_success": self.trigger_only_on_success,
"trigger_mode": self.trigger_mode,
"show_api": self.show_api,
"zerogpu": self.zero_gpu,
}
def postprocess_update_dict(
block: Component | BlockContext, update_dict: dict, postprocess: bool = True
@ -552,6 +611,55 @@ def convert_component_dict_to_list(
return predictions
class BlocksConfig:
def __init__(self, root_block: Blocks):
self._id: int = 0
self.root_block = root_block
self.blocks: dict[int, Component | Block] = {}
self.fns: list[BlockFunction] = []
def get_config(self):
config = {}
def get_layout(block):
if not isinstance(block, BlockContext):
return {"id": block._id}
children_layout = []
for child in block.children:
children_layout.append(get_layout(child))
return {"id": block._id, "children": children_layout}
config["layout"] = get_layout(self.root_block)
config["components"] = []
for _id, block in self.blocks.items():
props = block.get_config() if hasattr(block, "get_config") else {}
block_config = {
"id": _id,
"type": block.get_block_name(),
"props": utils.delete_none(props),
"skip_api": block.skip_api,
"component_class_id": getattr(block, "component_class_id", None),
}
if not block.skip_api:
block_config["api_info"] = block.api_info() # type: ignore
# .example_inputs() has been renamed .example_payload() but
# we use the old name for backwards compatibility with custom components
# created on Gradio 4.20.0 or earlier
block_config["example_inputs"] = block.example_inputs() # type: ignore
config["components"].append(block_config)
config["dependencies"] = [fn.get_config() for fn in self.fns]
return config
def __copy__(self):
new = BlocksConfig(self.root_block)
new.blocks = copy.copy(self.blocks)
new.fns = copy.copy(self.fns)
return new
@document("launch", "queue", "integrate", "load", "unload")
class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
"""
@ -665,12 +773,11 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
t.start()
else:
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "True"
super().__init__(render=False, **kwargs)
self.blocks: dict[int, Component | Block] = {}
self.fns: list[BlockFunction] = []
self.dependencies = []
self.mode = mode
self.default_config = BlocksConfig(self)
super().__init__(render=False, **kwargs)
self.mode = mode
self.is_running = False
self.local_url = None
self.share_url = None
@ -718,6 +825,22 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
self.queue()
@property
def blocks(self) -> dict[int, Component | Block]:
return self.default_config.blocks
@blocks.setter
def blocks(self, value: dict[int, Component | Block]):
self.default_config.blocks = value
@property
def fns(self) -> list[BlockFunction]:
return self.default_config.fns
@fns.setter
def fns(self, value: list[BlockFunction]):
self.default_config.fns = value
def get_component(self, id: int) -> Component | BlockContext:
comp = self.blocks[id]
if not isinstance(comp, (components.Component, BlockContext)):
@ -850,7 +973,7 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
original_mapping[o] for o in dependency["outputs"]
]
dependency.pop("status_tracker", None)
dependency.pop("zerogpu")
dependency.pop("zerogpu", None)
dependency["preprocess"] = False
dependency["postprocess"] = False
if is_then_event:
@ -886,12 +1009,8 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
# Allows some use of Interface-specific methods with loaded Spaces
if first_dependency and Context.root_block:
blocks.predict = [fns[0]]
blocks.input_components = [
Context.root_block.blocks[i] for i in first_dependency["inputs"]
]
blocks.output_components = [
Context.root_block.blocks[o] for o in first_dependency["outputs"]
]
blocks.input_components = first_dependency.inputs
blocks.output_components = first_dependency.outputs
blocks.__name__ = "Interface"
blocks.api_mode = True
blocks.proxy_urls = proxy_urls
@ -901,19 +1020,19 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
return self.__repr__()
def __repr__(self):
num_backend_fns = len([d for d in self.dependencies if d["backend_fn"]])
num_backend_fns = len([d for d in self.fns if d.fn])
repr = f"Gradio Blocks instance: {num_backend_fns} backend functions"
repr += f"\n{'-' * len(repr)}"
for d, dependency in enumerate(self.dependencies):
if dependency["backend_fn"]:
for d, dependency in enumerate(self.fns):
if dependency.fn:
repr += f"\nfn_index={d}"
repr += "\n inputs:"
for input_id in dependency["inputs"]:
block = self.blocks[input_id]
for block in dependency.inputs:
block = self.blocks[block._id]
repr += f"\n |-{block}"
repr += "\n outputs:"
for output_id in dependency["outputs"]:
block = self.blocks[output_id]
for block in dependency.outputs:
block = self.blocks[block._id]
repr += f"\n |-{block}"
return repr
@ -987,7 +1106,7 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
concurrency_limit: int | None | Literal["default"] = "default",
concurrency_id: str | None = None,
show_api: bool = True,
) -> tuple[dict[str, Any], int]:
) -> tuple[BlockFunction, int]:
"""
Adds an event to the component's dependencies.
Parameters:
@ -1075,21 +1194,6 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
_, progress_index, event_data_index = (
special_args(fn) if fn else (None, None, None)
)
self.fns.append(
BlockFunction(
fn,
inputs,
outputs,
preprocess,
postprocess,
inputs_as_dict=inputs_as_dict,
concurrency_limit=concurrency_limit,
concurrency_id=concurrency_id,
batch=batch,
max_batch_size=max_batch_size,
tracks_progress=progress_index is not None,
)
)
# If api_name is None or empty string, use the function name
if api_name is None or isinstance(api_name, str) and api_name.strip() == "":
@ -1113,7 +1217,8 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
if api_name is not False:
api_name = utils.append_unique_suffix(
api_name, [dep["api_name"] for dep in self.dependencies]
api_name,
[fn.api_name for fn in self.fns if isinstance(fn.api_name, str)],
)
else:
show_api = False
@ -1124,35 +1229,35 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
if collects_event_data is None:
collects_event_data = event_data_index is not None
dependency = {
"targets": _targets,
"inputs": [block._id for block in inputs],
"outputs": [block._id for block in outputs],
"backend_fn": fn is not None,
"js": js,
"queue": False if fn is None else queue,
"api_name": api_name,
"scroll_to_output": False if utils.get_space() else scroll_to_output,
"show_progress": show_progress,
"every": every,
"batch": batch,
"max_batch_size": max_batch_size,
"cancels": cancels or [],
"types": {
"continuous": bool(every),
"generator": inspect.isgeneratorfunction(fn)
or inspect.isasyncgenfunction(fn)
or bool(every),
},
"collects_event_data": collects_event_data,
"trigger_after": trigger_after,
"trigger_only_on_success": trigger_only_on_success,
"trigger_mode": trigger_mode,
"show_api": show_api,
"zerogpu": hasattr(fn, "zerogpu"),
}
self.dependencies.append(dependency)
return dependency, len(self.dependencies) - 1
block_fn = BlockFunction(
fn,
inputs,
outputs,
preprocess,
postprocess,
inputs_as_dict=inputs_as_dict,
targets=_targets,
batch=batch,
max_batch_size=max_batch_size,
concurrency_limit=concurrency_limit,
concurrency_id=concurrency_id,
tracks_progress=progress_index is not None,
api_name=api_name,
js=js,
show_progress=show_progress,
every=every,
cancels=cancels,
collects_event_data=collects_event_data,
trigger_after=trigger_after,
trigger_only_on_success=trigger_only_on_success,
trigger_mode=trigger_mode,
queue=queue,
scroll_to_output=scroll_to_output,
show_api=show_api,
)
self.fns.append(block_fn)
return block_fn, len(self.fns) - 1
def render(self):
if Context.root_block is not None:
@ -1169,40 +1274,34 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
)
Context.root_block.blocks.update(self.blocks)
Context.root_block.fns.extend(self.fns)
dependency_offset = len(Context.root_block.dependencies)
for i, dependency in enumerate(self.dependencies):
api_name = dependency["api_name"]
if api_name is not None and api_name is not False:
dependency_offset = len(Context.root_block.fns)
existing_api_names = [
dep.api_name
for dep in Context.root_block.fns
if isinstance(dep.api_name, str)
]
for dependency in self.fns:
api_name = dependency.api_name
if isinstance(api_name, str):
api_name_ = utils.append_unique_suffix(
api_name,
[dep["api_name"] for dep in Context.root_block.dependencies],
existing_api_names,
)
if api_name != api_name_:
dependency["api_name"] = api_name_
dependency["cancels"] = [
c + dependency_offset for c in dependency["cancels"]
]
if dependency.get("trigger_after") is not None:
dependency["trigger_after"] += dependency_offset
dependency.api_name = api_name_
dependency.cancels = [c + dependency_offset for c in dependency.cancels]
if dependency.trigger_after is not None:
dependency.trigger_after += dependency_offset
# Recreate the cancel function so that it has the latest
# dependency fn indices. This is necessary to properly cancel
# events in the backend
if dependency["cancels"]:
if dependency.cancels:
updated_cancels = [
Context.root_block.dependencies[i]
for i in dependency["cancels"]
Context.root_block.fns[i].get_config()
for i in dependency.cancels
]
new_fn = BlockFunction(
get_cancel_function(updated_cancels)[0],
[],
[],
False,
True,
False,
)
Context.root_block.fns[dependency_offset + i] = new_fn
Context.root_block.dependencies.append(dependency)
dependency.fn = get_cancel_function(updated_cancels)[0]
Context.root_block.fns.append(dependency)
Context.root_block.temp_file_sets.extend(self.temp_file_sets)
Context.root_block.proxy_urls.update(self.proxy_urls)
@ -1213,20 +1312,16 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
def is_callable(self, fn_index: int = 0) -> bool:
"""Checks if a particular Blocks function is callable (i.e. not stateful or a generator)."""
block_fn = self.fns[fn_index]
dependency = self.dependencies[fn_index]
dependency = self.fns[fn_index]
if inspect.isasyncgenfunction(block_fn.fn):
return False
if inspect.isgeneratorfunction(block_fn.fn):
return False
for input_id in dependency["inputs"]:
block = self.blocks[input_id]
if getattr(block, "stateful", False):
return False
for output_id in dependency["outputs"]:
block = self.blocks[output_id]
if getattr(block, "stateful", False):
return False
if any(block.stateful for block in dependency.inputs):
return False
if any(block.stateful for block in dependency.outputs):
return False
return True
@ -1243,11 +1338,7 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
"""
if api_name is not None:
inferred_fn_index = next(
(
i
for i, d in enumerate(self.dependencies)
if d.get("api_name") == api_name
),
(i for i, d in enumerate(self.fns) if d.api_name == api_name),
None,
)
if inferred_fn_index is None:
@ -1262,7 +1353,7 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
inputs = list(inputs)
processed_inputs = self.serialize_data(fn_index, inputs)
batch = self.dependencies[fn_index]["batch"]
batch = self.fns[fn_index].batch
if batch:
processed_inputs = [[inp] for inp in processed_inputs]
@ -1352,7 +1443,7 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
prediction = await utils.async_iteration(iterator)
is_generating = True
except StopAsyncIteration:
n_outputs = len(self.dependencies[fn_index].get("outputs"))
n_outputs = len(self.fns[fn_index].outputs)
prediction = (
components._Keywords.FINISHED_ITERATING
if n_outputs == 1
@ -1370,22 +1461,16 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
}
def serialize_data(self, fn_index: int, inputs: list[Any]) -> list[Any]:
dependency = self.dependencies[fn_index]
dependency = self.fns[fn_index]
processed_input = []
def format_file(s):
return FileData(path=s).model_dump()
for i, input_id in enumerate(dependency["inputs"]):
try:
block = self.blocks[input_id]
except KeyError as e:
raise InvalidBlockError(
f"Input component with id {input_id} used in {dependency['trigger']}() event is not defined in this gr.Blocks context. You are allowed to nest gr.Blocks contexts, but there must be a gr.Blocks context that contains all components and events."
) from e
for i, block in enumerate(dependency.inputs):
if not isinstance(block, components.Component):
raise InvalidComponentError(
f"{block.__class__} Component with id {input_id} not a valid input component."
f"{block.__class__} Component not a valid input component."
)
api_info = block.api_info()
if client_utils.value_is_file(api_info):
@ -1402,19 +1487,13 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
return processed_input
def deserialize_data(self, fn_index: int, outputs: list[Any]) -> list[Any]:
dependency = self.dependencies[fn_index]
dependency = self.fns[fn_index]
predictions = []
for o, output_id in enumerate(dependency["outputs"]):
try:
block = self.blocks[output_id]
except KeyError as e:
raise InvalidBlockError(
f"Output component with id {output_id} used in {dependency['trigger']}() event not found in this gr.Blocks context. You are allowed to nest gr.Blocks contexts, but there must be a gr.Blocks context that contains all components and events."
) from e
for o, block in enumerate(dependency.outputs):
if not isinstance(block, components.Component):
raise InvalidComponentError(
f"{block.__class__} Component with id {output_id} not a valid output component."
f"{block.__class__} Component not a valid output component."
)
deserialized = client_utils.traverse(
@ -1426,9 +1505,8 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
def validate_inputs(self, fn_index: int, inputs: list[Any]):
block_fn = self.fns[fn_index]
dependency = self.dependencies[fn_index]
dep_inputs = dependency["inputs"]
dep_inputs = block_fn.inputs
# This handles incorrect inputs when args are changed by a JS function
# Only check not enough args case, ignore extra arguments (for now)
@ -1442,8 +1520,7 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
wanted_args = []
received_args = []
for input_id in dep_inputs:
block = self.blocks[input_id]
for block in dep_inputs:
wanted_args.append(str(block))
for inp in inputs:
v = f'"{inp}"' if isinstance(inp, str) else str(inp)
@ -1471,28 +1548,21 @@ Received inputs:
):
state = state or SessionState(self)
block_fn = self.fns[fn_index]
dependency = self.dependencies[fn_index]
self.validate_inputs(fn_index, inputs)
if block_fn.preprocess:
processed_input = []
for i, input_id in enumerate(dependency["inputs"]):
try:
block = self.blocks[input_id]
except KeyError as e:
raise InvalidBlockError(
f"Input component with id {input_id} used in {dependency['trigger']}() event not found in this gr.Blocks context. You are allowed to nest gr.Blocks contexts, but there must be a gr.Blocks context that contains all components and events."
) from e
for i, block in enumerate(block_fn.inputs):
if not isinstance(block, components.Component):
raise InvalidComponentError(
f"{block.__class__} Component with id {input_id} not a valid input component."
f"{block.__class__} Component not a valid input component."
)
if getattr(block, "stateful", False):
processed_input.append(state[input_id])
if block.stateful:
processed_input.append(state[block._id])
else:
if input_id in state:
block = state[input_id]
if block._id in state:
block = state[block._id]
inputs_cached = await processing_utils.async_move_files_to_cache(
inputs[i],
block,
@ -1510,9 +1580,9 @@ Received inputs:
def validate_outputs(self, fn_index: int, predictions: Any | list[Any]):
block_fn = self.fns[fn_index]
dependency = self.dependencies[fn_index]
dependency = self.fns[fn_index]
dep_outputs = dependency["outputs"]
dep_outputs = dependency.outputs
if not isinstance(predictions, (list, tuple)):
predictions = [predictions]
@ -1526,8 +1596,7 @@ Received inputs:
wanted_args = []
received_args = []
for output_id in dep_outputs:
block = self.blocks[output_id]
for block in dep_outputs:
wanted_args.append(str(block))
for pred in predictions:
v = f'"{pred}"' if isinstance(pred, str) else str(pred)
@ -1549,15 +1618,13 @@ Received outputs:
):
state = state or SessionState(self)
block_fn = self.fns[fn_index]
dependency = self.dependencies[fn_index]
batch = dependency["batch"]
if isinstance(predictions, dict) and len(predictions) > 0:
predictions = convert_component_dict_to_list(
dependency["outputs"], predictions
[block._id for block in block_fn.outputs], predictions
)
if len(dependency["outputs"]) == 1 and not (batch):
if len(block_fn.outputs) == 1 and not block_fn.batch:
predictions = [
predictions,
]
@ -1565,7 +1632,7 @@ Received outputs:
self.validate_outputs(fn_index, predictions) # type: ignore
output = []
for i, output_id in enumerate(dependency["outputs"]):
for i, block in enumerate(block_fn.outputs):
try:
if predictions[i] is components._Keywords.FINISHED_ITERATING:
output.append(None)
@ -1576,16 +1643,9 @@ Received outputs:
f"of values returned from from function {block_fn.name}"
) from err
try:
block = self.blocks[output_id]
except KeyError as e:
raise InvalidBlockError(
f"Output component with id {output_id} used in {dependency['trigger']}() event not found in this gr.Blocks context. You are allowed to nest gr.Blocks contexts, but there must be a gr.Blocks context that contains all components and events."
) from e
if block.stateful:
if not utils.is_update(predictions[i]):
state[output_id] = predictions[i]
state[block._id] = predictions[i]
output.append(None)
else:
prediction_value = predictions[i]
@ -1600,28 +1660,25 @@ Received outputs:
prediction_value = prediction_value.constructor_args.copy()
prediction_value["__type__"] = "update"
if utils.is_update(prediction_value):
if output_id in state:
kwargs = state[output_id].constructor_args.copy()
else:
kwargs = self.blocks[output_id].constructor_args.copy()
kwargs = state[block._id].constructor_args.copy()
kwargs.update(prediction_value)
kwargs.pop("value", None)
kwargs.pop("__type__")
kwargs["render"] = False
state[output_id] = self.blocks[output_id].__class__(**kwargs)
state[block._id] = block.__class__(**kwargs)
prediction_value = postprocess_update_dict(
block=state[output_id],
block=state[block._id],
update_dict=prediction_value,
postprocess=block_fn.postprocess,
)
elif block_fn.postprocess:
if not isinstance(block, components.Component):
raise InvalidComponentError(
f"{block.__class__} Component with id {output_id} not a valid output component."
f"{block.__class__} Component not a valid output component."
)
if output_id in state:
block = state[output_id]
if block._id in state:
block = state[block._id]
prediction_value = block.postprocess(prediction_value)
outputs_cached = await processing_utils.async_move_files_to_cache(
@ -1647,8 +1704,8 @@ Received outputs:
self.pending_streams[session_hash][run] = {}
stream_run = self.pending_streams[session_hash][run]
for i, output_id in enumerate(self.dependencies[fn_index]["outputs"]):
block = self.blocks[output_id]
for i, block in enumerate(self.fns[fn_index].outputs):
output_id = block._id
if isinstance(block, components.StreamingOutput) and block.streaming:
first_chunk = output_id not in stream_run
binary_data, output_data = block.stream_output(
@ -1686,7 +1743,7 @@ Received outputs:
self.pending_diff_streams[session_hash][run] = [None] * len(data)
last_diffs = self.pending_diff_streams[session_hash][run]
for i in range(len(self.dependencies[fn_index]["outputs"])):
for i in range(len(self.fns[fn_index].outputs)):
if final:
data[i] = last_diffs[i]
continue
@ -1736,10 +1793,10 @@ Received outputs:
Returns: None
"""
block_fn = self.fns[fn_index]
batch = self.dependencies[fn_index]["batch"]
batch = self.fns[fn_index].batch
if batch:
max_batch_size = self.dependencies[fn_index]["max_batch_size"]
max_batch_size = self.fns[fn_index].max_batch_size
batch_sizes = [len(inp) for inp in inputs]
batch_size = batch_sizes[0]
if inspect.isasyncgenfunction(block_fn.fn) or inspect.isgeneratorfunction(
@ -1871,37 +1928,8 @@ Received outputs:
},
"fill_height": self.fill_height,
}
config.update(self.default_config.get_config())
def get_layout(block):
if not isinstance(block, BlockContext):
return {"id": block._id}
children_layout = []
for child in block.children:
children_layout.append(get_layout(child))
return {"id": block._id, "children": children_layout}
config["layout"] = get_layout(self)
for _id, block in self.blocks.items():
props = block.get_config() if hasattr(block, "get_config") else {}
block_config = {
"id": _id,
"type": block.get_block_name(),
"props": utils.delete_none(props),
}
block_config["skip_api"] = block.skip_api
block_config["component_class_id"] = getattr(
block, "component_class_id", None
)
if not block.skip_api:
block_config["api_info"] = block.api_info() # type: ignore
# .example_inputs() has been renamed .example_payload() but
# we use the old name for backwards compatibility with custom components
# created on Gradio 4.20.0 or earlier
block_config["example_inputs"] = block.example_inputs() # type: ignore
config["components"].append(block_config)
config["dependencies"] = self.dependencies
return config
def __enter__(self):
@ -1932,9 +1960,8 @@ Received outputs:
def clear(self):
"""Resets the layout of the Blocks object."""
self.blocks = {}
self.fns = []
self.dependencies = []
self.default_config.blocks = {}
self.default_config.fns = []
self.children = []
return self
@ -1988,8 +2015,8 @@ Received outputs:
return self
def validate_queue_settings(self):
for dep in self.dependencies:
for i in dep["cancels"]:
for dep in self.fns:
for i in dep.cancels:
if not self.queue_enabled_for_fn(i):
raise ValueError(
"Queue needs to be enabled! "
@ -1998,7 +2025,7 @@ Received outputs:
"another event without enabling the queue. Both can be solved by calling .queue() "
"before .launch()"
)
if dep["batch"] and dep["queue"] is False:
if dep.batch and dep.queue is False:
raise ValueError("In order to use batching, the queue must be enabled.")
def launch(
@ -2540,7 +2567,7 @@ Received outputs:
queue=False if every is None else None,
every=every,
)[0]
component.load_event = dep
component.load_event = dep.get_config()
def startup_events(self):
"""Events that should be run when the app containing this block starts up."""
@ -2551,7 +2578,7 @@ Received outputs:
self.create_limiter()
def queue_enabled_for_fn(self, fn_index: int):
return self.dependencies[fn_index]["queue"] is not False
return self.fns[fn_index].queue is not False
def get_api_info(self):
"""
@ -2560,22 +2587,18 @@ Received outputs:
config = self.config
api_info = {"named_endpoints": {}, "unnamed_endpoints": {}}
for dependency, fn in zip(config["dependencies"], self.fns):
if (
not dependency["backend_fn"]
or not dependency["show_api"]
or dependency["api_name"] is False
):
for fn in self.fns:
if not fn.fn or not fn.show_api or fn.api_name is False:
continue
dependency_info = {"parameters": [], "returns": []}
fn_info = utils.get_function_params(fn.fn) # type: ignore
skip_endpoint = False
inputs = dependency["inputs"]
for index, input_id in enumerate(inputs):
inputs = fn.inputs
for index, input_block in enumerate(inputs):
for component in config["components"]:
if component["id"] == input_id:
if component["id"] == input_block._id:
break
else:
skip_endpoint = True # if component not found, skip endpoint
@ -2583,7 +2606,7 @@ Received outputs:
type = component["props"]["name"]
if self.blocks[component["id"]].skip_api:
continue
label = component["props"].get("label", f"parameter_{input_id}")
label = component["props"].get("label", f"parameter_{input_block._id}")
comp = self.get_component(component["id"])
if not isinstance(comp, components.Component):
raise TypeError(f"{comp!r} is not a Component")
@ -2595,7 +2618,7 @@ Received outputs:
# "result_callbacks" to specify the callbacks, we need to make sure that no parameters
# have those names. Hence the final checks.
if (
dependency["backend_fn"]
fn.fn
and index < len(fn_info)
and fn_info[index][0]
not in ["api_name", "fn_index", "result_callbacks"]
@ -2612,7 +2635,7 @@ Received outputs:
parameter_has_default = True
parameter_default = component["props"]["value"]
elif (
dependency["backend_fn"]
fn.fn
and index < len(fn_info)
and fn_info[index][1]
and fn_info[index][2] is None
@ -2639,10 +2662,10 @@ Received outputs:
}
)
outputs = dependency["outputs"]
outputs = fn.outputs
for o in outputs:
for component in config["components"]:
if component["id"] == o:
if component["id"] == o._id:
break
else:
skip_endpoint = True # if component not found, skip endpoint
@ -2650,7 +2673,7 @@ Received outputs:
type = component["props"]["name"]
if self.blocks[component["id"]].skip_api:
continue
label = component["props"].get("label", f"value_{o}")
label = component["props"].get("label", f"value_{o._id}")
comp = self.get_component(component["id"])
if not isinstance(comp, components.Component):
raise TypeError(f"{comp!r} is not a Component")
@ -2670,8 +2693,6 @@ Received outputs:
)
if not skip_endpoint:
api_info["named_endpoints"][
f"/{dependency['api_name']}"
] = dependency_info
api_info["named_endpoints"][f"/{fn.api_name}"] = dependency_info
return api_info

View File

@ -320,7 +320,7 @@ class EventListener(str):
)
if _callback:
_callback(block)
return Dependency(block, dep, dep_index, fn)
return Dependency(block, dep.get_config(), dep_index, fn)
event_trigger.event_name = _event_name
event_trigger.has_trigger = _has_trigger
@ -445,7 +445,7 @@ def on(
trigger_mode=trigger_mode,
)
set_cancel_events(triggers, cancels)
return Dependency(None, dep, dep_index, fn)
return Dependency(None, dep.get_config(), dep_index, fn)
class Events:

View File

@ -300,6 +300,7 @@ class Examples:
api_name=self.api_name,
show_api=False,
)
self.load_input_event_id = len(Context.root_block.fns) - 1
if self.run_on_click and not self.cache_examples:
if self.fn is None:
raise ValueError("Cannot run_on_click if no function is provided")
@ -496,15 +497,12 @@ class Examples:
output = [value[0] for value in output]
self.cache_logger.flag(output)
# Remove the "fake_event" to prevent bugs in loading interfaces from spaces
Context.root_block.dependencies.remove(dependency)
Context.root_block.fns.pop(fn_index)
# Remove the original load_input_event and replace it with one that
# also populates the input. We do it this way to to allow the cache()
# method to be called independently of the create() method
index = Context.root_block.dependencies.index(self.load_input_event)
Context.root_block.dependencies.pop(index)
Context.root_block.fns.pop(index)
Context.root_block.fns.pop(self.load_input_event_id)
def load_example(example_id):
processed_example = self.non_none_processed_examples[
@ -522,6 +520,7 @@ class Examples:
api_name=self.api_name,
show_api=False,
)
self.load_input_event_id = len(Context.root_block.fns) - 1
def load_from_cache(self, example_id: int) -> list[Any]:
"""Loads a particular cached example for the interface.

View File

@ -167,8 +167,8 @@ class FnIndexInferError(Exception):
def infer_fn_index(app: App, api_name: str, body: PredictBody) -> int:
if body.fn_index is None:
for i, fn in enumerate(app.get_blocks().dependencies):
if fn["api_name"] == api_name:
for i, fn in enumerate(app.get_blocks().fns):
if fn.api_name == api_name:
return i
raise FnIndexInferError(f"Could not infer fn_index for api_name {api_name}.")
@ -185,7 +185,7 @@ def compile_gr_request(
):
# If this fn_index cancels jobs, then the only input we need is the
# current session hash
if app.get_blocks().dependencies[fn_index_inferred]["cancels"]:
if app.get_blocks().fns[fn_index_inferred].cancels:
body.data = [body.session_hash]
if body.request:
if body.batched:
@ -245,14 +245,14 @@ async def call_process_api(
):
session_state, iterator = restore_session_state(app=app, body=body)
dependency = app.get_blocks().dependencies[fn_index_inferred]
dependency = app.get_blocks().fns[fn_index_inferred]
event_data = prepare_event_data(app.get_blocks(), body)
event_id = body.event_id
session_hash = getattr(body, "session_hash", None)
inputs = body.data
batch_in_single_out = not body.batched and dependency["batch"]
batch_in_single_out = not body.batched and dependency.batch
if batch_in_single_out:
inputs = [inputs]

View File

@ -648,8 +648,8 @@ class App(FastAPI):
)
unload_fn_indices = [
i
for i, dep in enumerate(app.get_blocks().dependencies)
if any(t for t in dep["targets"] if t[1] == "unload")
for i, dep in enumerate(app.get_blocks().fns)
if any(t for t in dep.targets if t[1] == "unload")
]
for fn_index in unload_fn_indices:
# The task runnning this loop has been cancelled

View File

@ -4,7 +4,7 @@ import datetime
import os
import threading
from collections import OrderedDict
from copy import deepcopy
from copy import copy, deepcopy
from typing import TYPE_CHECKING, Any, Iterator
if TYPE_CHECKING:
@ -62,13 +62,13 @@ class StateHolder:
component.delete_callback(value)
to_delete.append(component._id)
for component in to_delete:
del session_state._data[component]
del session_state.state_data[component]
class SessionState:
def __init__(self, blocks: Blocks):
self.blocks = blocks
self._data = {}
self.blocks_config = copy(blocks.default_config)
self.state_data: dict[int, Any] = {}
self._state_ttl = {}
self.is_closed = False
# When a session is closed, the state is stored for an hour to give the user time to reopen the session.
@ -78,39 +78,45 @@ class SessionState:
)
def __getitem__(self, key: int) -> Any:
if key not in self._data:
block = self.blocks.blocks[key]
if getattr(block, "stateful", False):
self._data[key] = deepcopy(getattr(block, "value", None))
else:
self._data[key] = None
return self._data[key]
block = self.blocks_config.blocks[key]
if block.stateful:
if key not in self.state_data:
self.state_data[key] = deepcopy(getattr(block, "value", None))
return self.state_data[key]
else:
return block
def __setitem__(self, key: int, value: Any):
from gradio.components import State
block = self.blocks.blocks[key]
block = self.blocks_config.blocks[key]
if isinstance(block, State):
self._state_ttl[key] = (
block.time_to_live,
datetime.datetime.now(),
)
self._data[key] = value
self.state_data[key] = value
else:
self.blocks_config.blocks[key] = value
def __contains__(self, key: int):
return key in self._data
block = self.blocks_config.blocks[key]
if block.stateful:
return key in self.state_data
else:
return key in self.blocks_config.blocks
@property
def state_components(self) -> Iterator[tuple[State, Any, bool]]:
from gradio.components import State
for id in self._data:
block = self.blocks.blocks[id]
for id in self.state_data:
block = self.blocks_config.blocks[id]
if isinstance(block, State) and id in self._state_ttl:
time_to_live, created_at = self._state_ttl[id]
if self.is_closed:
time_to_live = self.STATE_TTL_WHEN_CLOSED
value = self._data[id]
value = self.state_data[id]
yield (
block,
value,

View File

@ -814,7 +814,7 @@ def get_cancel_function(
for dep in dependencies:
if Context.root_block:
fn_index = next(
i for i, d in enumerate(Context.root_block.dependencies) if d == dep
i for i, d in enumerate(Context.root_block.fns) if d.get_config() == dep
)
fn_to_comp[fn_index] = [
Context.root_block.blocks[o] for o in dep["outputs"]

View File

@ -492,7 +492,9 @@ class TestComponentsInBlocks:
with gr.Blocks() as demo:
for component in io_components:
components.append(component(value=lambda: None, every=1))
assert all(comp.load_event in demo.dependencies for comp in components)
assert all(
comp.load_event in demo.config["dependencies"] for comp in components
)
class TestBlocksPostprocessing:
@ -1670,7 +1672,7 @@ def test_emptry_string_api_name_gets_set_as_fn_name():
t2 = gr.Textbox()
t1.change(test_fn, t1, t2, api_name="")
assert demo.dependencies[0]["api_name"] == "test_fn"
assert demo.fns[0].api_name == "test_fn"
@pytest.mark.asyncio
@ -1775,6 +1777,8 @@ def test_time_to_live_and_delete_callback_for_state(capsys, monkeypatch):
assert "deleted 2" in captured.out
assert "deleted 3" in captured.out
for client in [client_1, client_2]:
assert len(app.state_holder.session_data[client.session_hash]._data) == 0
assert (
len(app.state_holder.session_data[client.session_hash].state_data) == 0
)
finally:
demo.close()

View File

@ -13,10 +13,10 @@ class TestClearButton:
textbox = gr.Textbox(scale=3, interactive=True)
gr.ClearButton([textbox, chatbot], scale=1)
clear_event_trigger = demo.dependencies.pop()
assert not clear_event_trigger["backend_fn"]
assert clear_event_trigger["js"]
assert clear_event_trigger["outputs"] == [textbox._id, chatbot._id]
clear_event_trigger = demo.fns.pop()
assert not clear_event_trigger.fn
assert clear_event_trigger.js
assert clear_event_trigger.outputs == [textbox, chatbot]
def test_clear_event_setup_correctly_with_state(self):
with gr.Blocks() as demo:
@ -24,8 +24,8 @@ class TestClearButton:
state = gr.State("")
gr.ClearButton([state, chatbot], scale=1)
clear_event_trigger_state = demo.dependencies.pop()
assert clear_event_trigger_state["backend_fn"]
clear_event_trigger_state = demo.fns.pop()
assert clear_event_trigger_state.fn
class TestOAuthButtons:
@ -47,9 +47,9 @@ class TestOAuthButtons:
with gr.Blocks() as demo:
button = gr.LoginButton()
login_event = demo.dependencies[0]
assert login_event["targets"][0][1] == "click"
assert not login_event["backend_fn"] # No Python code
assert login_event["js"] # But JS code instead
assert login_event["inputs"] == [button._id]
assert login_event["outputs"] == []
login_event = demo.fns[0]
assert login_event.targets[0][1] == "click"
assert not login_event.fn # No Python code
assert login_event.js # But JS code instead
assert login_event.inputs == [button]
assert login_event.outputs == []

View File

@ -77,14 +77,14 @@ class TestInit:
def test_events_attached(self):
chatbot = gr.ChatInterface(double)
dependencies = chatbot.dependencies
dependencies = chatbot.fns
textbox = chatbot.textbox._id
submit_btn = chatbot.submit_btn._id
assert next(
(
d
for d in dependencies
if d["targets"] == [(textbox, "submit"), (submit_btn, "click")]
if d.targets == [(textbox, "submit"), (submit_btn, "click")]
),
None,
)
@ -94,7 +94,7 @@ class TestInit:
chatbot.undo_btn._id,
]:
assert next(
(d for d in dependencies if d["targets"][0] == (btn_id, "click")),
(d for d in dependencies if d.targets[0] == (btn_id, "click")),
None,
)