diff --git a/.changeset/easy-trains-obey.md b/.changeset/easy-trains-obey.md new file mode 100644 index 0000000000..0188726245 --- /dev/null +++ b/.changeset/easy-trains-obey.md @@ -0,0 +1,6 @@ +--- +"@gradio/client": patch +"gradio": patch +--- + +fix:Connect heartbeat if state created in render. Also fix config cleanup bug #8407 diff --git a/client/js/src/client.ts b/client/js/src/client.ts index 9efa11b503..3ff5e7f6a7 100644 --- a/client/js/src/client.ts +++ b/client/js/src/client.ts @@ -150,38 +150,9 @@ export class Client { await this.resolve_cookies(); } - await this._resolve_config().then(async ({ config }) => { - if (config) { - this.config = config; - if (this.config && this.config.connect_heartbeat) { - if (this.config.space_id && this.options.hf_token) { - this.jwt = await get_jwt( - this.config.space_id, - this.options.hf_token, - this.cookies - ); - } - } - } - - if (config.space_id && this.options.hf_token) { - this.jwt = await get_jwt(config.space_id, this.options.hf_token); - } - - if (this.config && this.config.connect_heartbeat) { - // connect to the heartbeat endpoint via GET request - const heartbeat_url = new URL( - `${this.config.root}/heartbeat/${this.session_hash}` - ); - - // if the jwt is available, add it to the query params - if (this.jwt) { - heartbeat_url.searchParams.set("__sign", this.jwt); - } - - this.heartbeat_event = await this.stream(heartbeat_url); // Just connect to the endpoint without parsing the response. Ref: https://github.com/gradio-app/gradio/pull/7974#discussion_r1557717540 - } - }); + await this._resolve_config().then(({ config }) => + this._resolve_hearbeat(config) + ); } catch (e: any) { throw Error(e); } @@ -190,6 +161,43 @@ export class Client { this.api_map = map_names_to_ids(this.config?.dependencies || []); } + async _resolve_hearbeat(_config: Config): Promise { + if (_config) { + this.config = _config; + if (this.config && this.config.connect_heartbeat) { + if (this.config.space_id && this.options.hf_token) { + this.jwt = await get_jwt( + this.config.space_id, + this.options.hf_token, + this.cookies + ); + } + } + } + + if (_config.space_id && this.options.hf_token) { + this.jwt = await get_jwt(_config.space_id, this.options.hf_token); + } + + if (this.config && this.config.connect_heartbeat) { + // connect to the heartbeat endpoint via GET request + const heartbeat_url = new URL( + `${this.config.root}/heartbeat/${this.session_hash}` + ); + + // if the jwt is available, add it to the query params + if (this.jwt) { + heartbeat_url.searchParams.set("__sign", this.jwt); + } + + // Just connect to the endpoint without parsing the response. Ref: https://github.com/gradio-app/gradio/pull/7974#discussion_r1557717540 + if (!this.heartbeat_event) + this.heartbeat_event = await this.stream(heartbeat_url); + } else { + this.heartbeat_event?.close(); + } + } + static async connect( app_reference: string, options: ClientOptions = {} diff --git a/client/js/src/utils/submit.ts b/client/js/src/utils/submit.ts index ad650ca35e..80ca94b783 100644 --- a/client/js/src/utils/submit.ts +++ b/client/js/src/utils/submit.ts @@ -162,17 +162,27 @@ export function submit( } } - function handle_render_config(render_config: any): void { + const resolve_heartbeat = async (config: Config): Promise => { + await this._resolve_hearbeat(config); + }; + + async function handle_render_config(render_config: any): Promise { if (!config) return; let render_id: number = render_config.render_id; config.components = [ - ...config.components.filter((c) => c.rendered_in !== render_id), + ...config.components.filter((c) => c.props.rendered_in !== render_id), ...render_config.components ]; config.dependencies = [ ...config.dependencies.filter((d) => d.rendered_in !== render_id), ...render_config.dependencies ]; + const any_state = config.components.some((c) => c.type === "state"); + const any_unload = config.dependencies.some((d) => + d.targets.some((t) => t[1] === "unload") + ); + config.connect_heartbeat = any_state || any_unload; + await resolve_heartbeat(config); fire_event({ type: "render", data: render_config, @@ -628,7 +638,7 @@ export function submit( fn_index }); if (data.render_config) { - handle_render_config(data.render_config); + await handle_render_config(data.render_config); } if (complete) { diff --git a/gradio/blocks.py b/gradio/blocks.py index 2bf9626adb..9675e6bb29 100644 --- a/gradio/blocks.py +++ b/gradio/blocks.py @@ -903,7 +903,6 @@ class BlocksConfig: if renderable is None or fn.rendered_in == renderable: dependencies.append(fn.get_config()) config["dependencies"] = dependencies - return config def __copy__(self): @@ -2043,18 +2042,9 @@ Received outputs: "fill_height": self.fill_height, } config.update(self.default_config.get_config()) - any_state = any( - isinstance(block, components.State) for block in self.blocks.values() + config["connect_heartbeat"] = utils.connect_heartbeat( + config, self.blocks.values() ) - any_unload = False - for dep in config["dependencies"]: - for target in dep["targets"]: - if isinstance(target, (list, tuple)) and len(target) == 2: - any_unload = target[1] == "unload" - if any_unload: - break - config["connect_heartbeat"] = any_state or any_unload - return config def __enter__(self): diff --git a/gradio/utils.py b/gradio/utils.py index 4524a2d599..df6cc313cf 100644 --- a/gradio/utils.py +++ b/gradio/utils.py @@ -1381,3 +1381,17 @@ def _parse_file_size(size: str | int | None) -> int | None: if not multiple: raise ValueError(f"Invalid file size unit: {unit}") return multiple * size_int + + +def connect_heartbeat(config: dict[str, Any], blocks) -> bool: + from gradio.components import State + + any_state = any(isinstance(block, State) for block in blocks) + any_unload = False + for dep in config["dependencies"]: + for target in dep["targets"]: + if isinstance(target, (list, tuple)) and len(target) == 2: + any_unload = target[1] == "unload" + if any_unload: + break + return any_state or any_unload