From f0150c6260d657b150b73f0eecabd10b19d297c8 Mon Sep 17 00:00:00 2001 From: "Yuichiro Tachibana (Tsuchiya)" Date: Mon, 10 Jul 2023 18:03:09 +0900 Subject: [PATCH] Lite: unload local modules before executing Python scripts (#4826) * Unload the local modules before re-executing a Python script * Add changeset --- .changeset/mean-birds-cheer.md | 5 + js/wasm/src/webworker/index.ts | 13 +- js/wasm/src/webworker/py/unload_modules.py | 165 +++++++++++++++++++++ 3 files changed, 181 insertions(+), 2 deletions(-) create mode 100644 .changeset/mean-birds-cheer.md create mode 100644 js/wasm/src/webworker/py/unload_modules.py diff --git a/.changeset/mean-birds-cheer.md b/.changeset/mean-birds-cheer.md new file mode 100644 index 0000000000..9c2cb2c09e --- /dev/null +++ b/.changeset/mean-birds-cheer.md @@ -0,0 +1,5 @@ +--- +"@gradio/lite": patch +--- + +Unload the local modules before re-executing a Python script so the edits on the modules are reflected diff --git a/js/wasm/src/webworker/index.ts b/js/wasm/src/webworker/index.ts index e5b773b46d..ebb33bc3f1 100644 --- a/js/wasm/src/webworker/index.ts +++ b/js/wasm/src/webworker/index.ts @@ -12,6 +12,7 @@ import { writeFileWithParents, renameWithParents } from "./file"; import { verifyRequirements } from "./requirements"; import { makeHttpRequest } from "./http"; import scriptRunnerPySource from "./py/script_runner.py?raw"; +import unloadModulesPySource from "./py/unload_modules.py?raw" importScripts("https://cdn.jsdelivr.net/pyodide/v0.23.2/full/pyodide.js"); @@ -25,6 +26,7 @@ let call_asgi_app_from_js: ( send: (event: any) => Promise ) => Promise; let run_script: (path: string) => void; +let unload_local_modules: (target_dir_path?: string) => void; async function loadPyodideAndPackages( options: InMessageInit["data"] @@ -161,10 +163,12 @@ matplotlib.use("agg") `); console.debug("matplotlib backend is set."); - console.debug("Set up a script runner"); + console.debug("Set up Python utility functions."); await 
pyodide.runPythonAsync(scriptRunnerPySource); run_script = pyodide.globals.get("_run_script"); - console.debug("A script runner is set up."); + await pyodide.runPythonAsync(unloadModulesPySource); + unload_local_modules = pyodide.globals.get("unload_local_modules"); + console.debug("Python utility functions are set up."); } self.onmessage = async (event: MessageEvent): Promise => { @@ -201,7 +205,10 @@ self.onmessage = async (event: MessageEvent): Promise => { break; } case "run-python-code": { + unload_local_modules() + await pyodide.runPythonAsync(msg.data.code); + const replyMessage: ReplyMessageSuccess = { type: "reply:success", data: null // We don't send back the execution result because it's not needed for our purpose, and sometimes the result is of type `pyodide.ffi.PyProxy` which cannot be cloned across threads and causes an error. @@ -210,6 +217,8 @@ self.onmessage = async (event: MessageEvent): Promise => { break; } case "run-python-file": { + unload_local_modules() + run_script(msg.data.path); const replyMessage: ReplyMessageSuccess = { diff --git a/js/wasm/src/webworker/py/unload_modules.py b/js/wasm/src/webworker/py/unload_modules.py new file mode 100644 index 0000000000..569a685e48 --- /dev/null +++ b/js/wasm/src/webworker/py/unload_modules.py @@ -0,0 +1,165 @@ +# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022) +# Copyright (c) Yuichiro Tachibana (2023) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import fnmatch
import logging
import os
import sys
import types
from typing import Optional, Set

LOGGER = logging.getLogger(__name__)

#
# Copied from https://github.com/streamlit/streamlit/blob/1.24.0/lib/streamlit/file_util.py
#


def file_is_in_folder_glob(filepath: str, folderpath_glob: str) -> bool:
    """Test whether a file is in some folder with globbing support.

    Parameters
    ----------
    filepath : str
        A file path.
    folderpath_glob : str
        A path to a folder that may include globbing.

    Returns
    -------
    bool
        True if the directory of ``filepath`` matches ``folderpath_glob``
        (or any subfolder of it), False otherwise.

    """
    # Make the glob always end with "/*" so we match files inside subfolders of
    # folderpath_glob.
    if not folderpath_glob.endswith("*"):
        if folderpath_glob.endswith("/"):
            folderpath_glob += "*"
        else:
            folderpath_glob += "/*"

    # The trailing "/" lets a file located directly in the folder match the
    # "<folder>/*" pattern ("*" also matches the empty string).
    file_dir = os.path.dirname(filepath) + "/"
    return fnmatch.fnmatch(file_dir, folderpath_glob)


def get_directory_size(directory: str) -> int:
    """Return the size of a directory in bytes.

    The size is the sum of ``os.path.getsize`` over every file found by
    walking ``directory`` recursively.
    """
    return sum(
        os.path.getsize(os.path.join(dirpath, filename))
        for dirpath, _, filenames in os.walk(directory)
        for filename in filenames
    )


def file_in_pythonpath(filepath: str) -> bool:
    """Test whether a filepath is in the same folder of a path specified in the PYTHONPATH env variable.

    Parameters
    ----------
    filepath : str
        An absolute file path.

    Returns
    -------
    bool
        True if contained in PYTHONPATH, False otherwise. False if PYTHONPATH
        is not defined or empty.

    """
    pythonpath = os.environ.get("PYTHONPATH", "")
    if not pythonpath:
        return False

    # Normalize once instead of once per PYTHONPATH entry.
    normalized_filepath = os.path.normpath(filepath)
    return any(
        file_is_in_folder_glob(normalized_filepath, os.path.abspath(entry))
        for entry in pythonpath.split(os.pathsep)
    )


#
# Copied from https://github.com/streamlit/streamlit/blob/1.24.0/lib/streamlit/watcher/local_sources_watcher.py
#


def get_module_paths(module: types.ModuleType) -> Set[str]:
    """Return the absolute paths of the existing files/directories backing ``module``.

    Several attributes are probed because none of them is guaranteed to be
    present; attributes that are missing are silently skipped, and any other
    error raised while reading them is logged as a warning.
    """
    paths_extractors = [
        # https://docs.python.org/3/reference/datamodel.html
        # __file__ is the pathname of the file from which the module was loaded
        # if it was loaded from a file.
        # The __file__ attribute may be missing for certain types of modules
        lambda m: [m.__file__],
        # https://docs.python.org/3/reference/import.html#__spec__
        # The __spec__ attribute is set to the module spec that was used
        # when importing the module. one exception is __main__,
        # where __spec__ is set to None in some cases.
        # https://www.python.org/dev/peps/pep-0451/#id16
        # "origin" in an import context means the system
        # (or resource within a system) from which a module originates
        # ... It is up to the loader to decide on how to interpret
        # and use a module's origin, if at all.
        lambda m: [m.__spec__.origin],
        # https://www.python.org/dev/peps/pep-0420/
        # Handling of "namespace packages" in which the __path__ attribute
        # is a _NamespacePath object with a _path attribute containing
        # the various paths of the package.
        lambda m: list(m.__path__._path),
    ]

    all_paths: Set[str] = set()
    for extract_paths in paths_extractors:
        potential_paths = []
        try:
            potential_paths = extract_paths(module)
        except AttributeError:
            # Some modules might not have __file__ or __spec__ attributes.
            pass
        except Exception as e:
            # Reading __name__ can fail for the same reason the extractor did,
            # so fall back to the module object itself for the message.
            module_name = getattr(module, "__name__", module)
            LOGGER.warning(f"Examining the path of {module_name} raised: {e}")

        all_paths.update(
            os.path.abspath(str(p)) for p in potential_paths if _is_valid_path(p)
        )
    return all_paths


def _is_valid_path(path: Optional[str]) -> bool:
    """Return True if ``path`` is a string naming an existing file or directory."""
    return isinstance(path, str) and (os.path.isfile(path) or os.path.isdir(path))


#
# Original code
#


def unload_local_modules(target_dir_path: str = "."):
    """Unload all modules that are in the target directory or in a subdirectory of it.

    It is necessary to unload modules before re-executing a script that imports the modules,
    so that the new version of the modules is loaded.
    The module unloading feature is extracted from Streamlit's LocalSourcesWatcher
    (https://github.com/streamlit/streamlit/blob/1.24.0/lib/streamlit/watcher/local_sources_watcher.py)
    and packaged as a standalone function.

    Modules whose files live under a PYTHONPATH entry are unloaded as well.

    Parameters
    ----------
    target_dir_path : str
        Directory whose modules should be removed from ``sys.modules``.
        Relative paths are resolved against the current working directory.
    """
    target_dir_path = os.path.abspath(target_dir_path)

    # Collect module *names* rather than a path -> name mapping: several
    # entries of sys.modules can point to the same file (aliases), and every
    # one of them must be dropped or a stale module stays cached.
    local_module_names = set()

    # Copied from `LocalSourcesWatcher.update_watched_modules()`
    # Iterate over a snapshot because inspecting module attributes may import
    # further modules and mutate sys.modules during the loop.
    for name, module in dict(sys.modules).items():
        # Copied from `LocalSourcesWatcher._register_necessary_watchers()`
        if any(
            file_is_in_folder_glob(path, target_dir_path) or file_in_pythonpath(path)
            for path in get_module_paths(module)
        ):
            local_module_names.add(name)

    # Copied from `LocalSourcesWatcher.on_file_changed()`
    for name in local_module_names:
        sys.modules.pop(name, None)