Fix SourceFileReloader to watch the module with a qualified name to avoid importing a module with the same name from a different path (#6497)

* Fix SourceFileReloader to watch the module with a qualified name to avoid importing a module with the same name from a different path

* Fix the unit tests

* add changeset

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
Yuichiro Tachibana (Tsuchiya) 2023-11-21 03:54:26 +09:00 committed by GitHub
parent 070f71c933
commit 1baed201b1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 29 additions and 19 deletions

View File

@ -0,0 +1,5 @@
---
"gradio": patch
---
fix:Fix SourceFileReloader to watch the module with a qualified name to avoid importing a module with the same name from a different path

View File

@ -30,8 +30,8 @@ def _setup_config(
demo_name: str = "demo",
additional_watch_dirs: list[str] | None = None,
):
original_path = demo_path
app_text = Path(original_path).read_text()
original_path = Path(demo_path)
app_text = original_path.read_text()
patterns = [
f"with gr\\.Blocks\\(\\) as {demo_name}",
@ -48,7 +48,12 @@ def _setup_config(
)
abs_original_path = utils.abspath(original_path)
filename = Path(original_path).stem
if original_path.is_absolute():
relpath = original_path.relative_to(Path.cwd())
else:
relpath = original_path
module_name = str(relpath.parent / relpath.stem).replace(os.path.sep, ".")
gradio_folder = Path(inspect.getfile(gradio)).parent
@ -68,12 +73,12 @@ def _setup_config(
message += ","
message += f" '{abs_parent}'"
abs_parent = Path(".").resolve()
if str(abs_parent).strip():
watching_dirs.append(abs_parent)
abs_current = Path.cwd().absolute()
if str(abs_current).strip():
watching_dirs.append(abs_current)
if message_change_count == 1:
message += ","
message += f" '{abs_parent}'"
message += f" '{abs_current}'"
for wd in additional_watch_dirs or []:
if Path(wd) not in watching_dirs:
@ -87,14 +92,14 @@ def _setup_config(
# guaranty access to the module of an app
sys.path.insert(0, os.getcwd())
return filename, abs_original_path, [str(s) for s in watching_dirs], demo_name
return module_name, abs_original_path, [str(s) for s in watching_dirs], demo_name
def main(
demo_path: Path, demo_name: str = "demo", watch_dirs: Optional[List[str]] = None
):
# default execution pattern to start the server and watch changes
filename, path, watch_dirs, demo_name = _setup_config(
module_name, path, watch_dirs, demo_name = _setup_config(
demo_path, demo_name, watch_dirs
)
# extra_args = args[1:] if len(args) == 1 or args[1].startswith("--") else args[2:]
@ -103,7 +108,7 @@ def main(
env=dict(
os.environ,
GRADIO_WATCH_DIRS=",".join(watch_dirs),
GRADIO_WATCH_FILE=filename,
GRADIO_WATCH_MODULE_NAME=module_name,
GRADIO_WATCH_DEMO_NAME=demo_name,
),
)

View File

@ -35,7 +35,7 @@ should_watch = bool(os.getenv("GRADIO_WATCH_DIRS", False))
GRADIO_WATCH_DIRS = (
os.getenv("GRADIO_WATCH_DIRS", "").split(",") if should_watch else []
)
GRADIO_WATCH_FILE = os.getenv("GRADIO_WATCH_FILE", "app")
GRADIO_WATCH_MODULE_NAME = os.getenv("GRADIO_WATCH_MODULE_NAME", "app")
GRADIO_WATCH_DEMO_NAME = os.getenv("GRADIO_WATCH_DEMO_NAME", "demo")
@ -192,7 +192,7 @@ def start_server(
reloader = SourceFileReloader(
app=app,
watch_dirs=GRADIO_WATCH_DIRS,
watch_file=GRADIO_WATCH_FILE,
watch_module_name=GRADIO_WATCH_MODULE_NAME,
demo_name=GRADIO_WATCH_DEMO_NAME,
stop_event=threading.Event(),
change_event=change_event,

View File

@ -107,7 +107,7 @@ class SourceFileReloader(BaseReloader):
self,
app: App,
watch_dirs: list[str],
watch_file: str,
watch_module_name: str,
stop_event: threading.Event,
change_event: threading.Event,
demo_name: str = "demo",
@ -115,7 +115,7 @@ class SourceFileReloader(BaseReloader):
super().__init__()
self.app = app
self.watch_dirs = watch_dirs
self.watch_file = watch_file
self.watch_module_name = watch_module_name
self.stop_event = stop_event
self.change_event = change_event
self.demo_name = demo_name
@ -200,11 +200,11 @@ def watchfn(reloader: SourceFileReloader):
if sourcefile and is_in_or_equal(sourcefile, dir_):
del sys.modules[k]
try:
module = importlib.import_module(reloader.watch_file)
module = importlib.import_module(reloader.watch_module_name)
module = importlib.reload(module)
except Exception as e:
print(
f"Reloading {reloader.watch_file} failed with the following exception: "
f"Reloading {reloader.watch_module_name} failed with the following exception: "
)
traceback.print_exception(None, value=e, tb=None)
mtimes = {}

View File

@ -19,7 +19,7 @@ def build_demo():
@dataclasses.dataclass
class Config:
filename: str
module_name: str
path: Path
watch_dirs: List[str]
demo_name: str
@ -44,11 +44,11 @@ class TestReload:
reloader.close()
def test_config_default_app(self, config):
assert config.filename == "run"
assert config.module_name == "demo.calculator.run"
@pytest.mark.parametrize("argv", [["demo/calculator/run.py", "--demo-name test"]])
def test_config_custom_app(self, config):
assert config.filename == "run"
assert config.module_name == "demo.calculator.run"
assert config.demo_name == "test"
def test_config_watch_gradio(self, config):