diff --git a/.changeset/crazy-dancers-allow.md b/.changeset/crazy-dancers-allow.md new file mode 100644 index 0000000000..62dfe1aad4 --- /dev/null +++ b/.changeset/crazy-dancers-allow.md @@ -0,0 +1,5 @@ +--- +"@gradio/wasm": minor +--- + +feat:Make the HTTP requests for the Wasm worker wait for the initial `run_code()` or `run_file()` to finish diff --git a/js/wasm/src/promise-delegate.ts b/js/wasm/src/promise-delegate.ts new file mode 100644 index 0000000000..ab5c1c7bf1 --- /dev/null +++ b/js/wasm/src/promise-delegate.ts @@ -0,0 +1,26 @@ +type PromiseImplFn = ConstructorParameters>[0]; + +export class PromiseDelegate { + private promiseInternal: Promise; + private resolveInternal!: Parameters>[0]; + private rejectInternal!: Parameters>[1]; + + constructor() { + this.promiseInternal = new Promise((resolve, reject) => { + this.resolveInternal = resolve; + this.rejectInternal = reject; + }); + } + + get promise(): Promise { + return this.promiseInternal; + } + + public resolve(value: T): void { + this.resolveInternal(value); + } + + public reject(reason: unknown): void { + this.rejectInternal(reason); + } +} diff --git a/js/wasm/src/webworker/index.ts b/js/wasm/src/webworker/index.ts index 8a4a8faeb4..09ed0ee16b 100644 --- a/js/wasm/src/webworker/index.ts +++ b/js/wasm/src/webworker/index.ts @@ -26,7 +26,7 @@ let call_asgi_app_from_js: ( receive: () => Promise, send: (event: any) => Promise ) => Promise; -let run_script: (path: string) => void; +let run_script: (path: string) => Promise; let unload_local_modules: (target_dir_path?: string) => void; async function loadPyodideAndPackages( @@ -218,7 +218,7 @@ self.onmessage = async (event: MessageEvent): Promise => { case "run-python-file": { unload_local_modules(); - run_script(msg.data.path); + await run_script(msg.data.path); const replyMessage: ReplyMessageSuccess = { type: "reply:success", diff --git a/js/wasm/src/webworker/py/script_runner.py b/js/wasm/src/webworker/py/script_runner.py index 9bf014c627..c0b479e238 100644 --- a/js/wasm/src/webworker/py/script_runner.py +++ b/js/wasm/src/webworker/py/script_runner.py @@ -1,6 +1,8 @@ +import ast import tokenize import types import sys +from inspect import CO_COROUTINE # BSD 3-Clause License # @@ -63,6 +65,7 @@ class modified_sys_path: # Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022) +# Copyright (c) Yuichiro Tachibana (2023) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -80,9 +83,10 @@ def _new_module(name: str) -> types.ModuleType: return types.ModuleType(name) -def _run_script(script_path: str) -> None: +async def _run_script(script_path: str) -> None: # This function is based on the following code from Streamlit: # https://github.com/streamlit/streamlit/blob/1.24.0/lib/streamlit/runtime/scriptrunner/script_runner.py#L519-L554 + # with modifications to support top-level await. with tokenize.open(script_path) as f: filebody = f.read() @@ -98,7 +102,7 @@ def _run_script(script_path: str) -> None: # mode (as opposed to "eval" or "single"). mode="exec", # Don't inherit any flags or "future" statements. - flags=0, + flags=ast.PyCF_ALLOW_TOP_LEVEL_AWAIT, # Allow top-level await. Ref: https://github.com/whitphx/streamlit/commit/277dc580efb315a3e9296c9a0078c602a0904384 dont_inherit=1, # Use the default optimization options. optimize=-1, @@ -117,4 +121,9 @@ def _run_script(script_path: str) -> None: module.__dict__["__file__"] = script_path with modified_sys_path(script_path): - exec(bytecode, module.__dict__) + # Allow top-level await. Ref: https://github.com/whitphx/streamlit/commit/277dc580efb315a3e9296c9a0078c602a0904384 + if bytecode.co_flags & CO_COROUTINE: + # The source code includes top-level awaits, so the compiled code object is a coroutine. + await eval(bytecode, module.__dict__) + else: + exec(bytecode, module.__dict__) diff --git a/js/wasm/src/worker-proxy.ts b/js/wasm/src/worker-proxy.ts index f9ce655d30..8f8a8f68af 100644 --- a/js/wasm/src/worker-proxy.ts +++ b/js/wasm/src/worker-proxy.ts @@ -9,6 +9,7 @@ import type { ReplyMessage } from "./message-types"; import { MessagePortWebSocket } from "./messageportwebsocket"; +import { PromiseDelegate } from "./promise-delegate"; export interface WorkerProxyOptions { gradioWheelUrl: string; @@ -20,6 +21,8 @@ export interface WorkerProxyOptions { export class WorkerProxy { private worker: globalThis.Worker; + private firstRunPromiseDelegate = new PromiseDelegate(); + constructor(options: WorkerProxyOptions) { console.debug("WorkerProxy.constructor(): Create a new worker."); // Loading a worker here relies on Vite's support for WebWorkers (https://vitejs.dev/guide/features.html#web-workers), @@ -49,6 +52,7 @@ export class WorkerProxy { code } }); + this.firstRunPromiseDelegate.resolve(); } public async runPythonFile(path: string): Promise { @@ -58,6 +62,7 @@ export class WorkerProxy { path } }); + this.firstRunPromiseDelegate.resolve(); } // A wrapper for this.worker.postMessage(). Unlike that function, which @@ -84,6 +89,12 @@ export class WorkerProxy { } public async httpRequest(request: HttpRequest): Promise { + // Wait for the first run to be done + // to avoid the "Gradio app has not been launched." error + // in case running the code takes long time. + // Ref: https://github.com/gradio-app/gradio/issues/5957 + await this.firstRunPromiseDelegate.promise; + console.debug("WorkerProxy.httpRequest()", request); const result = await this.postMessageAsync({ type: "http-request",