Lite: unload local modules before executing Python scripts (#4826)

* Unload the local modules before re-executing a Python script

* Add changeset
This commit is contained in:
Yuichiro Tachibana (Tsuchiya) 2023-07-10 18:03:09 +09:00 committed by GitHub
parent 9ff0fe924e
commit f0150c6260
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 181 additions and 2 deletions

View File

@ -0,0 +1,5 @@
---
"@gradio/lite": patch
---
Unload the local modules before re-executing a Python script so that edits to those modules are reflected.

View File

@ -12,6 +12,7 @@ import { writeFileWithParents, renameWithParents } from "./file";
import { verifyRequirements } from "./requirements";
import { makeHttpRequest } from "./http";
import scriptRunnerPySource from "./py/script_runner.py?raw";
import unloadModulesPySource from "./py/unload_modules.py?raw"
importScripts("https://cdn.jsdelivr.net/pyodide/v0.23.2/full/pyodide.js");
@ -25,6 +26,7 @@ let call_asgi_app_from_js: (
send: (event: any) => Promise<void>
) => Promise<void>;
let run_script: (path: string) => void;
let unload_local_modules: (target_dir_path?: string) => void;
async function loadPyodideAndPackages(
options: InMessageInit["data"]
@ -161,10 +163,12 @@ matplotlib.use("agg")
`);
console.debug("matplotlib backend is set.");
console.debug("Set up a script runner");
console.debug("Set up Python utility functions.");
await pyodide.runPythonAsync(scriptRunnerPySource);
run_script = pyodide.globals.get("_run_script");
console.debug("A script runner is set up.");
await pyodide.runPythonAsync(unloadModulesPySource);
unload_local_modules = pyodide.globals.get("unload_local_modules");
console.debug("Python utility functions are set up.");
}
self.onmessage = async (event: MessageEvent<InMessage>): Promise<void> => {
@ -201,7 +205,10 @@ self.onmessage = async (event: MessageEvent<InMessage>): Promise<void> => {
break;
}
case "run-python-code": {
unload_local_modules()
await pyodide.runPythonAsync(msg.data.code);
const replyMessage: ReplyMessageSuccess = {
type: "reply:success",
data: null // We don't send back the execution result because it's not needed for our purpose, and sometimes the result is of type `pyodide.ffi.PyProxy` which cannot be cloned across threads and causes an error.
@ -210,6 +217,8 @@ self.onmessage = async (event: MessageEvent<InMessage>): Promise<void> => {
break;
}
case "run-python-file": {
unload_local_modules()
run_script(msg.data.path);
const replyMessage: ReplyMessageSuccess = {

View File

@ -0,0 +1,165 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
# Copyright (c) Yuichiro Tachibana (2023)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import fnmatch
import os
import sys
import types
from typing import Optional, Set
LOGGER = logging.getLogger(__name__)
#
# Copied from https://github.com/streamlit/streamlit/blob/1.24.0/lib/streamlit/file_util.py
#
def file_is_in_folder_glob(filepath, folderpath_glob) -> bool:
    """Report whether ``filepath`` lives under a folder described by a glob.

    Parameters
    ----------
    filepath : str
        Path of the file to test.
    folderpath_glob : str
        Folder path, optionally containing ``fnmatch``-style wildcards.

    Returns
    -------
    bool
        True when the directory part of ``filepath`` (with a trailing "/")
        matches the normalized glob, False otherwise.
    """
    # Normalize the pattern so that it always ends with "*". A bare folder
    # path becomes "folder/*", which lets files in nested subfolders match.
    if folderpath_glob.endswith("*"):
        pattern = folderpath_glob
    elif folderpath_glob.endswith("/"):
        pattern = folderpath_glob + "*"
    else:
        pattern = folderpath_glob + "/*"

    # Compare the containing directory (not the file itself) to the pattern.
    containing_dir = os.path.dirname(filepath) + "/"
    return fnmatch.fnmatch(containing_dir, pattern)
def get_directory_size(directory: str) -> int:
    """Sum the sizes, in bytes, of every file found beneath ``directory``.

    The tree is traversed with ``os.walk`` and each file is measured with
    ``os.path.getsize``.
    """
    return sum(
        os.path.getsize(os.path.join(root, filename))
        for root, _subdirs, filenames in os.walk(directory)
        for filename in filenames
    )
def file_in_pythonpath(filepath) -> bool:
    """Check whether ``filepath`` sits under any folder listed in PYTHONPATH.

    Parameters
    ----------
    filepath : str
        An absolute file path.

    Returns
    -------
    boolean
        True if the file is inside one of the PYTHONPATH entries, False
        otherwise. Always False when PYTHONPATH is unset or empty.
    """
    raw_pythonpath = os.environ.get("PYTHONPATH", "")
    if not raw_pythonpath:
        return False

    normalized_file = os.path.normpath(filepath)
    # Each PYTHONPATH entry is made absolute before it is used as the folder
    # glob; the first entry that contains the file settles the answer.
    for entry in raw_pythonpath.split(os.pathsep):
        if file_is_in_folder_glob(normalized_file, os.path.abspath(entry)):
            return True
    return False
#
# Copied from https://github.com/streamlit/streamlit/blob/1.24.0/lib/streamlit/watcher/local_sources_watcher.py
#
def get_module_paths(module: types.ModuleType) -> Set[str]:
    """Collect the absolute filesystem paths associated with ``module``.

    Three sources are consulted in turn: ``__file__``, ``__spec__.origin``
    and ``__path__._path``. Candidates that are not existing files or
    directories are discarded; the rest are returned as absolute paths.
    """

    def _from_file(mod):
        # https://docs.python.org/3/reference/datamodel.html
        # __file__ is the pathname of the file the module was loaded from,
        # and may be missing for certain types of modules.
        return [mod.__file__]

    def _from_spec_origin(mod):
        # https://docs.python.org/3/reference/import.html#__spec__
        # __spec__ is the module spec used at import time; for __main__ it
        # can be None. "origin" is the location the module originates from,
        # and its interpretation is up to the loader (PEP 451).
        return [mod.__spec__.origin]

    def _from_namespace_path(mod):
        # https://www.python.org/dev/peps/pep-0420/
        # For namespace packages, __path__ is a _NamespacePath object whose
        # _path attribute holds the package's various directories.
        return list(mod.__path__._path)

    collected: Set[str] = set()
    for extractor in (_from_file, _from_spec_origin, _from_namespace_path):
        try:
            candidates = extractor(module)
        except AttributeError:
            # Some modules might not have __file__ or __spec__ attributes.
            candidates = []
        except Exception as e:
            LOGGER.warning(f"Examining the path of {module.__name__} raised: {e}")
            candidates = []

        for candidate in candidates:
            if _is_valid_path(candidate):
                collected.add(os.path.abspath(str(candidate)))
    return collected
def _is_valid_path(path: Optional[str]) -> bool:
return isinstance(path, str) and (os.path.isfile(path) or os.path.isdir(path))
#
# Original code
#
def unload_local_modules(target_dir_path: str = "."):
    """Unload all modules that are in the target directory or in a subdirectory of it.

    It is necessary to unload modules before re-executing a script that imports
    the modules, so that the new version of the modules is loaded.
    A module is also unloaded when one of its paths lies under a PYTHONPATH entry.

    The module unloading feature is extracted from Streamlit's LocalSourcesWatcher
    (https://github.com/streamlit/streamlit/blob/1.24.0/lib/streamlit/watcher/local_sources_watcher.py)
    and packaged as a standalone function.

    Parameters
    ----------
    target_dir_path : str
        Directory whose modules should be unloaded. It is made absolute
        before matching; defaults to the current directory.
    """
    target_dir_path = os.path.abspath(target_dir_path)

    # Track module *names*, not a filepath -> name mapping. Several entries in
    # sys.modules can resolve to the same file (aliases, or one file imported
    # under two different names); a path-keyed dict keeps only the last of
    # them, which would leave the others loaded and stale.
    local_module_names = set()

    # Iterate over a copy so that sys.modules may change while the modules
    # are being inspected without breaking the iteration.
    for name, module in dict(sys.modules).items():
        for path in get_module_paths(module):
            if file_is_in_folder_glob(path, target_dir_path) or file_in_pythonpath(path):
                local_module_names.add(name)
                break  # One matching path is enough to mark the module.

    for module_name in local_module_names:
        # pop() with a default tolerates entries that have already vanished.
        if module_name is not None:
            sys.modules.pop(module_name, None)