Connect heartbeat if state created in render. Also fix config cleanup bug #8407 (#8408)

* Add code

* add changeset

* add changeset

* lint

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
Freddy Boulton 2024-05-29 17:49:03 -04:00 committed by GitHub
parent 8028c33bbc
commit e86dd01b6e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 75 additions and 47 deletions

View File

@ -0,0 +1,6 @@
---
"@gradio/client": patch
"gradio": patch
---
fix:Connect heartbeat if state created in render. Also fix config cleanup bug #8407

View File

@ -150,38 +150,9 @@ export class Client {
await this.resolve_cookies();
}
await this._resolve_config().then(async ({ config }) => {
if (config) {
this.config = config;
if (this.config && this.config.connect_heartbeat) {
if (this.config.space_id && this.options.hf_token) {
this.jwt = await get_jwt(
this.config.space_id,
this.options.hf_token,
this.cookies
);
}
}
}
if (config.space_id && this.options.hf_token) {
this.jwt = await get_jwt(config.space_id, this.options.hf_token);
}
if (this.config && this.config.connect_heartbeat) {
// connect to the heartbeat endpoint via GET request
const heartbeat_url = new URL(
`${this.config.root}/heartbeat/${this.session_hash}`
);
// if the jwt is available, add it to the query params
if (this.jwt) {
heartbeat_url.searchParams.set("__sign", this.jwt);
}
this.heartbeat_event = await this.stream(heartbeat_url); // Just connect to the endpoint without parsing the response. Ref: https://github.com/gradio-app/gradio/pull/7974#discussion_r1557717540
}
});
await this._resolve_config().then(({ config }) =>
this._resolve_hearbeat(config)
);
} catch (e: any) {
throw Error(e);
}
@ -190,6 +161,43 @@ export class Client {
this.api_map = map_names_to_ids(this.config?.dependencies || []);
}
async _resolve_hearbeat(_config: Config): Promise<void> {
if (_config) {
this.config = _config;
if (this.config && this.config.connect_heartbeat) {
if (this.config.space_id && this.options.hf_token) {
this.jwt = await get_jwt(
this.config.space_id,
this.options.hf_token,
this.cookies
);
}
}
}
if (_config.space_id && this.options.hf_token) {
this.jwt = await get_jwt(_config.space_id, this.options.hf_token);
}
if (this.config && this.config.connect_heartbeat) {
// connect to the heartbeat endpoint via GET request
const heartbeat_url = new URL(
`${this.config.root}/heartbeat/${this.session_hash}`
);
// if the jwt is available, add it to the query params
if (this.jwt) {
heartbeat_url.searchParams.set("__sign", this.jwt);
}
// Just connect to the endpoint without parsing the response. Ref: https://github.com/gradio-app/gradio/pull/7974#discussion_r1557717540
if (!this.heartbeat_event)
this.heartbeat_event = await this.stream(heartbeat_url);
} else {
this.heartbeat_event?.close();
}
}
static async connect(
app_reference: string,
options: ClientOptions = {}

View File

@ -162,17 +162,27 @@ export function submit(
}
}
function handle_render_config(render_config: any): void {
const resolve_heartbeat = async (config: Config): Promise<void> => {
await this._resolve_hearbeat(config);
};
async function handle_render_config(render_config: any): Promise<void> {
if (!config) return;
let render_id: number = render_config.render_id;
config.components = [
...config.components.filter((c) => c.rendered_in !== render_id),
...config.components.filter((c) => c.props.rendered_in !== render_id),
...render_config.components
];
config.dependencies = [
...config.dependencies.filter((d) => d.rendered_in !== render_id),
...render_config.dependencies
];
const any_state = config.components.some((c) => c.type === "state");
const any_unload = config.dependencies.some((d) =>
d.targets.some((t) => t[1] === "unload")
);
config.connect_heartbeat = any_state || any_unload;
await resolve_heartbeat(config);
fire_event({
type: "render",
data: render_config,
@ -628,7 +638,7 @@ export function submit(
fn_index
});
if (data.render_config) {
handle_render_config(data.render_config);
await handle_render_config(data.render_config);
}
if (complete) {

View File

@ -903,7 +903,6 @@ class BlocksConfig:
if renderable is None or fn.rendered_in == renderable:
dependencies.append(fn.get_config())
config["dependencies"] = dependencies
return config
def __copy__(self):
@ -2043,18 +2042,9 @@ Received outputs:
"fill_height": self.fill_height,
}
config.update(self.default_config.get_config())
any_state = any(
isinstance(block, components.State) for block in self.blocks.values()
config["connect_heartbeat"] = utils.connect_heartbeat(
config, self.blocks.values()
)
any_unload = False
for dep in config["dependencies"]:
for target in dep["targets"]:
if isinstance(target, (list, tuple)) and len(target) == 2:
any_unload = target[1] == "unload"
if any_unload:
break
config["connect_heartbeat"] = any_state or any_unload
return config
def __enter__(self):

View File

@ -1381,3 +1381,17 @@ def _parse_file_size(size: str | int | None) -> int | None:
if not multiple:
raise ValueError(f"Invalid file size unit: {unit}")
return multiple * size_int
def connect_heartbeat(config: dict[str, Any], blocks) -> bool:
from gradio.components import State
any_state = any(isinstance(block, State) for block in blocks)
any_unload = False
for dep in config["dependencies"]:
for target in dep["targets"]:
if isinstance(target, (list, tuple)) and len(target) == 2:
any_unload = target[1] == "unload"
if any_unload:
break
return any_state or any_unload