diff --git a/.changeset/flat-sites-judge.md b/.changeset/flat-sites-judge.md new file mode 100644 index 0000000000..e04acd53c5 --- /dev/null +++ b/.changeset/flat-sites-judge.md @@ -0,0 +1,7 @@ +--- +"@gradio/app": minor +"gradio": minor +"website": minor +--- + +feat:Do not reload code inside gr.NO_RELOAD context diff --git a/.config/playwright-setup.js b/.config/playwright-setup.js index 3fc61854f8..f8ebf18378 100644 --- a/.config/playwright-setup.js +++ b/.config/playwright-setup.js @@ -16,7 +16,8 @@ const test_files = readdirSync(TEST_FILES_PATH) (f) => f.endsWith("spec.ts") && !f.endsWith(".skip.spec.ts") && - !f.endsWith(".component.spec.ts") + !f.endsWith(".component.spec.ts") && + !f.endsWith(".reload.spec.ts") ) .map((f) => basename(f, ".spec.ts")); diff --git a/.config/playwright.config.js b/.config/playwright.config.js index caaac5375b..2bfe61bdb7 100644 --- a/.config/playwright.config.js +++ b/.config/playwright.config.js @@ -24,7 +24,7 @@ const base = defineConfig({ }); const normal = defineConfig(base, { - globalSetup: "./playwright-setup.js" + globalSetup: process.env.CUSTOM_TEST ? undefined : "./playwright-setup.js" }); normal.projects = undefined; // Explicitly unset this field due to https://github.com/microsoft/playwright/issues/28795 diff --git a/.github/workflows/test-functional.yml b/.github/workflows/test-functional.yml index d7e6b415c3..3d20f26ef6 100644 --- a/.github/workflows/test-functional.yml +++ b/.github/workflows/test-functional.yml @@ -74,6 +74,10 @@ jobs: run: | . venv/bin/activate pnpm run test:ct + - name: run reload mode test + run: | + . venv/bin/activate + pnpm test:browser:reload - name: Run Lite E2E tests run: | . venv/bin/activate diff --git a/gradio/__init__.py b/gradio/__init__.py index 4397ae4b97..d2d8494d11 100644 --- a/gradio/__init__.py +++ b/gradio/__init__.py @@ -98,6 +98,6 @@ from gradio.templates import ( TextArea, ) from gradio.themes import Base as Theme -from gradio.utils import get_package_version, set_static_paths +from gradio.utils import NO_RELOAD, get_package_version, set_static_paths __version__ = get_package_version() diff --git a/gradio/cli/commands/reload.py b/gradio/cli/commands/reload.py index cc75bbbba0..ba8fd961a2 100644 --- a/gradio/cli/commands/reload.py +++ b/gradio/cli/commands/reload.py @@ -97,7 +97,7 @@ def _setup_config( print(message + "\n") - # guaranty access to the module of an app + # guarantee access to the module of an app sys.path.insert(0, os.getcwd()) return module_name, abs_original_path, [str(s) for s in watching_dirs], demo_name @@ -109,17 +109,20 @@ def main( encoding: str = "utf-8", ): # default execution pattern to start the server and watch changes - module_name, path, watch_dirs, demo_name = _setup_config( + module_name, path, watch_sources, demo_name = _setup_config( demo_path, demo_name, watch_dirs, encoding ) - # extra_args = args[1:] if len(args) == 1 or args[1].startswith("--") else args[2:] + + # Pass the following data as environment variables + # so that we can set up reload mode correctly in the networking.py module popen = subprocess.Popen( [sys.executable, "-u", path], env=dict( os.environ, - GRADIO_WATCH_DIRS=",".join(watch_dirs), + GRADIO_WATCH_DIRS=",".join(watch_sources), GRADIO_WATCH_MODULE_NAME=module_name, GRADIO_WATCH_DEMO_NAME=demo_name, + GRADIO_WATCH_DEMO_PATH=str(path), ), ) popen.wait() diff --git a/gradio/networking.py b/gradio/networking.py index 8b586bfbe3..56cfd30790 100644 --- a/gradio/networking.py +++ b/gradio/networking.py @@ -1,258 +1,260 @@ -""" -Defines helper methods useful for setting up ports, launching servers, and -creating tunnels. -""" -from __future__ import annotations - -import os -import socket -import threading -import time -import warnings -from functools import partial -from typing import TYPE_CHECKING - -import httpx -import uvicorn -from uvicorn.config import Config - -from gradio.exceptions import ServerFailedToStartError -from gradio.routes import App -from gradio.tunneling import Tunnel -from gradio.utils import SourceFileReloader, watchfn - -if TYPE_CHECKING: # Only import for type checking (to avoid circular imports). - from gradio.blocks import Blocks - -# By default, the local server will try to open on localhost, port 7860. -# If that is not available, then it will try 7861, 7862, ... 7959. -INITIAL_PORT_VALUE = int(os.getenv("GRADIO_SERVER_PORT", "7860")) -TRY_NUM_PORTS = int(os.getenv("GRADIO_NUM_PORTS", "100")) -LOCALHOST_NAME = os.getenv("GRADIO_SERVER_NAME", "127.0.0.1") -GRADIO_API_SERVER = "https://api.gradio.app/v2/tunnel-request" -GRADIO_SHARE_SERVER_ADDRESS = os.getenv("GRADIO_SHARE_SERVER_ADDRESS") - -should_watch = bool(os.getenv("GRADIO_WATCH_DIRS", "")) -GRADIO_WATCH_DIRS = ( - os.getenv("GRADIO_WATCH_DIRS", "").split(",") if should_watch else [] -) -GRADIO_WATCH_MODULE_NAME = os.getenv("GRADIO_WATCH_MODULE_NAME", "app") -GRADIO_WATCH_DEMO_NAME = os.getenv("GRADIO_WATCH_DEMO_NAME", "demo") - - -class Server(uvicorn.Server): - def __init__( - self, config: Config, reloader: SourceFileReloader | None = None - ) -> None: - self.running_app = config.app - super().__init__(config) - self.reloader = reloader - if self.reloader: - self.event = threading.Event() - self.watch = partial(watchfn, self.reloader) - - def install_signal_handlers(self): - pass - - def run_in_thread(self): - self.thread = threading.Thread(target=self.run, daemon=True) - if self.reloader: - self.watch_thread = threading.Thread(target=self.watch, daemon=True) - self.watch_thread.start() - self.thread.start() - start = time.time() - while not self.started: - time.sleep(1e-3) - if time.time() - start > 5: - raise ServerFailedToStartError( - "Server failed to start. Please check that the port is available." - ) - - def close(self): - self.should_exit = True - if self.reloader: - self.reloader.stop() - self.watch_thread.join() - self.thread.join() - - -def get_first_available_port(initial: int, final: int) -> int: - """ - Gets the first open port in a specified range of port numbers - Parameters: - initial: the initial value in the range of port numbers - final: final (exclusive) value in the range of port numbers, should be greater than `initial` - Returns: - port: the first open port in the range - """ - for port in range(initial, final): - try: - s = socket.socket() # create a socket object - s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - s.bind((LOCALHOST_NAME, port)) # Bind to the port - s.close() - return port - except OSError: - pass - raise OSError( - f"All ports from {initial} to {final - 1} are in use. Please close a port." - ) - - -def configure_app(app: App, blocks: Blocks) -> App: - auth = blocks.auth - if auth is not None: - if not callable(auth): - app.auth = {account[0]: account[1] for account in auth} - else: - app.auth = auth - else: - app.auth = None - app.blocks = blocks - app.cwd = os.getcwd() - app.favicon_path = blocks.favicon_path - app.tokens = {} - return app - - -def start_server( - blocks: Blocks, - server_name: str | None = None, - server_port: int | None = None, - ssl_keyfile: str | None = None, - ssl_certfile: str | None = None, - ssl_keyfile_password: str | None = None, - app_kwargs: dict | None = None, -) -> tuple[str, int, str, App, Server]: - """Launches a local server running the provided Interface - Parameters: - blocks: The Blocks object to run on the server - server_name: to make app accessible on local network, set this to "0.0.0.0". Can be set by environment variable GRADIO_SERVER_NAME. - server_port: will start gradio app on this port (if available). Can be set by environment variable GRADIO_SERVER_PORT. - auth: If provided, username and password (or list of username-password tuples) required to access the Blocks. Can also provide function that takes username and password and returns True if valid login. - ssl_keyfile: If a path to a file is provided, will use this as the private key file to create a local server running on https. - ssl_certfile: If a path to a file is provided, will use this as the signed certificate for https. Needs to be provided if ssl_keyfile is provided. - ssl_keyfile_password: If a password is provided, will use this with the ssl certificate for https. - app_kwargs: Additional keyword arguments to pass to the gradio.routes.App constructor. - - Returns: - port: the port number the server is running on - path_to_local_server: the complete address that the local server can be accessed at - app: the FastAPI app object - server: the server object that is a subclass of uvicorn.Server (used to close the server) - """ - if ssl_keyfile is not None and ssl_certfile is None: - raise ValueError("ssl_certfile must be provided if ssl_keyfile is provided.") - - server_name = server_name or LOCALHOST_NAME - url_host_name = "localhost" if server_name == "0.0.0.0" else server_name - - # Strip IPv6 brackets from the address if they exist. - # This is needed as http://[::1]:port/ is a valid browser address, - # but not a valid IPv6 address, so asyncio will throw an exception. - if server_name.startswith("[") and server_name.endswith("]"): - host = server_name[1:-1] - else: - host = server_name - - app = App.create_app(blocks, app_kwargs=app_kwargs) - - server_ports = ( - [server_port] - if server_port is not None - else range(INITIAL_PORT_VALUE, INITIAL_PORT_VALUE + TRY_NUM_PORTS) - ) - - for port in server_ports: - try: - # The fastest way to check if a port is available is to try to bind to it with socket. - # If the port is not available, socket will throw an OSError. - s = socket.socket() - s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - # Really, we should be checking if (server_name, server_port) is available, but - # socket.bind() doesn't seem to throw an OSError with ipv6 addresses, based on my testing. - # Instead, we just check if the port is available on localhost. - s.bind((LOCALHOST_NAME, port)) - s.close() - - # To avoid race conditions, so we also check if the port by trying to start the uvicorn server. - # If the port is not available, this will throw a ServerFailedToStartError. - config = uvicorn.Config( - app=app, - port=port, - host=host, - log_level="warning", - ssl_keyfile=ssl_keyfile, - ssl_certfile=ssl_certfile, - ssl_keyfile_password=ssl_keyfile_password, - ) - reloader = None - if GRADIO_WATCH_DIRS: - change_event = threading.Event() - app.change_event = change_event - reloader = SourceFileReloader( - app=app, - watch_dirs=GRADIO_WATCH_DIRS, - watch_module_name=GRADIO_WATCH_MODULE_NAME, - demo_name=GRADIO_WATCH_DEMO_NAME, - stop_event=threading.Event(), - change_event=change_event, - ) - server = Server(config=config, reloader=reloader) - server.run_in_thread() - break - except (OSError, ServerFailedToStartError): - pass - else: - raise OSError( - f"Cannot find empty port in range: {min(server_ports)}-{max(server_ports)}. You can specify a different port by setting the GRADIO_SERVER_PORT environment variable or passing the `server_port` parameter to `launch()`." - ) - - if ssl_keyfile is not None: - path_to_local_server = f"https://{url_host_name}:{port}/" - else: - path_to_local_server = f"http://{url_host_name}:{port}/" - - return server_name, port, path_to_local_server, app, server - - -def setup_tunnel( - local_host: str, local_port: int, share_token: str, share_server_address: str | None -) -> str: - share_server_address = ( - GRADIO_SHARE_SERVER_ADDRESS - if share_server_address is None - else share_server_address - ) - if share_server_address is None: - try: - response = httpx.get(GRADIO_API_SERVER, timeout=30) - payload = response.json()[0] - remote_host, remote_port = payload["host"], int(payload["port"]) - except Exception as e: - raise RuntimeError( - "Could not get share link from Gradio API Server." - ) from e - else: - remote_host, remote_port = share_server_address.split(":") - remote_port = int(remote_port) - try: - tunnel = Tunnel(remote_host, remote_port, local_host, local_port, share_token) - address = tunnel.start_tunnel() - return address - except Exception as e: - raise RuntimeError(str(e)) from e - - -def url_ok(url: str) -> bool: - try: - for _ in range(5): - with warnings.catch_warnings(): - warnings.filterwarnings("ignore") - r = httpx.head(url, timeout=3, verify=False) - if r.status_code in (200, 401, 302): # 401 or 302 if auth is set - return True - time.sleep(0.500) - except (ConnectionError, httpx.ConnectError, httpx.TimeoutException): - return False - return False +""" +Defines helper methods useful for setting up ports, launching servers, and +creating tunnels. +""" +from __future__ import annotations + +import os +import socket +import threading +import time +import warnings +from functools import partial +from typing import TYPE_CHECKING + +import httpx +import uvicorn +from uvicorn.config import Config + +from gradio.exceptions import ServerFailedToStartError +from gradio.routes import App +from gradio.tunneling import Tunnel +from gradio.utils import SourceFileReloader, watchfn + +if TYPE_CHECKING: # Only import for type checking (to avoid circular imports). + from gradio.blocks import Blocks + +# By default, the local server will try to open on localhost, port 7860. +# If that is not available, then it will try 7861, 7862, ... 7959. +INITIAL_PORT_VALUE = int(os.getenv("GRADIO_SERVER_PORT", "7860")) +TRY_NUM_PORTS = int(os.getenv("GRADIO_NUM_PORTS", "100")) +LOCALHOST_NAME = os.getenv("GRADIO_SERVER_NAME", "127.0.0.1") +GRADIO_API_SERVER = "https://api.gradio.app/v2/tunnel-request" +GRADIO_SHARE_SERVER_ADDRESS = os.getenv("GRADIO_SHARE_SERVER_ADDRESS") + +should_watch = bool(os.getenv("GRADIO_WATCH_DIRS", "")) +GRADIO_WATCH_DIRS = ( + os.getenv("GRADIO_WATCH_DIRS", "").split(",") if should_watch else [] +) +GRADIO_WATCH_MODULE_NAME = os.getenv("GRADIO_WATCH_MODULE_NAME", "app") +GRADIO_WATCH_DEMO_NAME = os.getenv("GRADIO_WATCH_DEMO_NAME", "demo") +GRADIO_WATCH_DEMO_PATH = os.getenv("GRADIO_WATCH_DEMO_PATH", "") + + +class Server(uvicorn.Server): + def __init__( + self, config: Config, reloader: SourceFileReloader | None = None + ) -> None: + self.running_app = config.app + super().__init__(config) + self.reloader = reloader + if self.reloader: + self.event = threading.Event() + self.watch = partial(watchfn, self.reloader) + + def install_signal_handlers(self): + pass + + def run_in_thread(self): + self.thread = threading.Thread(target=self.run, daemon=True) + if self.reloader: + self.watch_thread = threading.Thread(target=self.watch, daemon=True) + self.watch_thread.start() + self.thread.start() + start = time.time() + while not self.started: + time.sleep(1e-3) + if time.time() - start > 5: + raise ServerFailedToStartError( + "Server failed to start. Please check that the port is available." + ) + + def close(self): + self.should_exit = True + if self.reloader: + self.reloader.stop() + self.watch_thread.join() + self.thread.join() + + +def get_first_available_port(initial: int, final: int) -> int: + """ + Gets the first open port in a specified range of port numbers + Parameters: + initial: the initial value in the range of port numbers + final: final (exclusive) value in the range of port numbers, should be greater than `initial` + Returns: + port: the first open port in the range + """ + for port in range(initial, final): + try: + s = socket.socket() # create a socket object + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + s.bind((LOCALHOST_NAME, port)) # Bind to the port + s.close() + return port + except OSError: + pass + raise OSError( + f"All ports from {initial} to {final - 1} are in use. Please close a port." + ) + + +def configure_app(app: App, blocks: Blocks) -> App: + auth = blocks.auth + if auth is not None: + if not callable(auth): + app.auth = {account[0]: account[1] for account in auth} + else: + app.auth = auth + else: + app.auth = None + app.blocks = blocks + app.cwd = os.getcwd() + app.favicon_path = blocks.favicon_path + app.tokens = {} + return app + + +def start_server( + blocks: Blocks, + server_name: str | None = None, + server_port: int | None = None, + ssl_keyfile: str | None = None, + ssl_certfile: str | None = None, + ssl_keyfile_password: str | None = None, + app_kwargs: dict | None = None, +) -> tuple[str, int, str, App, Server]: + """Launches a local server running the provided Interface + Parameters: + blocks: The Blocks object to run on the server + server_name: to make app accessible on local network, set this to "0.0.0.0". Can be set by environment variable GRADIO_SERVER_NAME. + server_port: will start gradio app on this port (if available). Can be set by environment variable GRADIO_SERVER_PORT. + auth: If provided, username and password (or list of username-password tuples) required to access the Blocks. Can also provide function that takes username and password and returns True if valid login. + ssl_keyfile: If a path to a file is provided, will use this as the private key file to create a local server running on https. + ssl_certfile: If a path to a file is provided, will use this as the signed certificate for https. Needs to be provided if ssl_keyfile is provided. + ssl_keyfile_password: If a password is provided, will use this with the ssl certificate for https. + app_kwargs: Additional keyword arguments to pass to the gradio.routes.App constructor. + + Returns: + port: the port number the server is running on + path_to_local_server: the complete address that the local server can be accessed at + app: the FastAPI app object + server: the server object that is a subclass of uvicorn.Server (used to close the server) + """ + if ssl_keyfile is not None and ssl_certfile is None: + raise ValueError("ssl_certfile must be provided if ssl_keyfile is provided.") + + server_name = server_name or LOCALHOST_NAME + url_host_name = "localhost" if server_name == "0.0.0.0" else server_name + + # Strip IPv6 brackets from the address if they exist. + # This is needed as http://[::1]:port/ is a valid browser address, + # but not a valid IPv6 address, so asyncio will throw an exception. + if server_name.startswith("[") and server_name.endswith("]"): + host = server_name[1:-1] + else: + host = server_name + + app = App.create_app(blocks, app_kwargs=app_kwargs) + + server_ports = ( + [server_port] + if server_port is not None + else range(INITIAL_PORT_VALUE, INITIAL_PORT_VALUE + TRY_NUM_PORTS) + ) + + for port in server_ports: + try: + # The fastest way to check if a port is available is to try to bind to it with socket. + # If the port is not available, socket will throw an OSError. + s = socket.socket() + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + # Really, we should be checking if (server_name, server_port) is available, but + # socket.bind() doesn't seem to throw an OSError with ipv6 addresses, based on my testing. + # Instead, we just check if the port is available on localhost. + s.bind((LOCALHOST_NAME, port)) + s.close() + + # To avoid race conditions, so we also check if the port by trying to start the uvicorn server. + # If the port is not available, this will throw a ServerFailedToStartError. + config = uvicorn.Config( + app=app, + port=port, + host=host, + log_level="warning", + ssl_keyfile=ssl_keyfile, + ssl_certfile=ssl_certfile, + ssl_keyfile_password=ssl_keyfile_password, + ) + reloader = None + if GRADIO_WATCH_DIRS: + change_event = threading.Event() + app.change_event = change_event + reloader = SourceFileReloader( + app=app, + watch_dirs=GRADIO_WATCH_DIRS, + watch_module_name=GRADIO_WATCH_MODULE_NAME, + demo_name=GRADIO_WATCH_DEMO_NAME, + stop_event=threading.Event(), + change_event=change_event, + demo_file=GRADIO_WATCH_DEMO_PATH, + ) + server = Server(config=config, reloader=reloader) + server.run_in_thread() + break + except (OSError, ServerFailedToStartError): + pass + else: + raise OSError( + f"Cannot find empty port in range: {min(server_ports)}-{max(server_ports)}. You can specify a different port by setting the GRADIO_SERVER_PORT environment variable or passing the `server_port` parameter to `launch()`." + ) + + if ssl_keyfile is not None: + path_to_local_server = f"https://{url_host_name}:{port}/" + else: + path_to_local_server = f"http://{url_host_name}:{port}/" + + return server_name, port, path_to_local_server, app, server + + +def setup_tunnel( + local_host: str, local_port: int, share_token: str, share_server_address: str | None +) -> str: + share_server_address = ( + GRADIO_SHARE_SERVER_ADDRESS + if share_server_address is None + else share_server_address + ) + if share_server_address is None: + try: + response = httpx.get(GRADIO_API_SERVER, timeout=30) + payload = response.json()[0] + remote_host, remote_port = payload["host"], int(payload["port"]) + except Exception as e: + raise RuntimeError( + "Could not get share link from Gradio API Server." + ) from e + else: + remote_host, remote_port = share_server_address.split(":") + remote_port = int(remote_port) + try: + tunnel = Tunnel(remote_host, remote_port, local_host, local_port, share_token) + address = tunnel.start_tunnel() + return address + except Exception as e: + raise RuntimeError(str(e)) from e + + +def url_ok(url: str) -> bool: + try: + for _ in range(5): + with warnings.catch_warnings(): + warnings.filterwarnings("ignore") + r = httpx.head(url, timeout=3, verify=False) + if r.status_code in (200, 401, 302): # 401 or 302 if auth is set + return True + time.sleep(0.500) + except (ConnectionError, httpx.ConnectError): + return False + return False diff --git a/gradio/state_holder.py b/gradio/state_holder.py index a0c4a95dfc..27ff774594 100644 --- a/gradio/state_holder.py +++ b/gradio/state_holder.py @@ -19,6 +19,12 @@ class StateHolder: self.blocks = blocks self.capacity = blocks.state_session_capacity + def reset(self, blocks: Blocks): + """Reset the state holder with new blocks. Used during reload mode.""" + self.session_data = OrderedDict() + # Call set blocks again to set new ids + self.set_blocks(blocks) + def __getitem__(self, session_id: str) -> SessionState: if session_id not in self.session_data: self.session_data[session_id] = SessionState(self.blocks) diff --git a/gradio/utils.py b/gradio/utils.py index f775f187e0..3cba9431e9 100644 --- a/gradio/utils.py +++ b/gradio/utils.py @@ -2,17 +2,20 @@ from __future__ import annotations +import ast import asyncio import copy import dataclasses import functools import importlib +import importlib.util import inspect import json import json.decoder import os import pkgutil import re +import sys import tempfile import threading import time @@ -26,7 +29,7 @@ from contextlib import contextmanager from io import BytesIO from numbers import Number from pathlib import Path -from types import AsyncGeneratorType, GeneratorType +from types import AsyncGeneratorType, GeneratorType, ModuleType from typing import ( TYPE_CHECKING, Any, @@ -104,6 +107,7 @@ class BaseReloader(ABC): # not a new queue self.running_app.blocks._queue.block_fns = demo.fns demo._queue = self.running_app.blocks._queue + self.running_app.state_holder.reset(demo) self.running_app.blocks = demo demo._queue.reload() @@ -114,6 +118,7 @@ class SourceFileReloader(BaseReloader): app: App, watch_dirs: list[str], watch_module_name: str, + demo_file: str, stop_event: threading.Event, change_event: threading.Event, demo_name: str = "demo", @@ -125,6 +130,7 @@ class SourceFileReloader(BaseReloader): self.stop_event = stop_event self.change_event = change_event self.demo_name = demo_name + self.demo_file = Path(demo_file) @property def running_app(self) -> App: @@ -144,6 +150,51 @@ class SourceFileReloader(BaseReloader): self.alert_change() +NO_RELOAD = True + + +def _remove_no_reload_codeblocks(file_path: str): + """Parse the file, remove the gr.no_reload code blocks, and write the file back to disk. + + Parameters: + file_path (str): The path to the file to remove the no_reload code blocks from. + """ + + with open(file_path) as file: + code = file.read() + + tree = ast.parse(code) + + def _is_gr_no_reload(expr: ast.AST) -> bool: + """Find with gr.no_reload context managers.""" + return ( + isinstance(expr, ast.If) + and isinstance(expr.test, ast.Attribute) + and isinstance(expr.test.value, ast.Name) + and expr.test.value.id == "gr" + and expr.test.attr == "NO_RELOAD" + ) + + # Find the positions of the code blocks to load + for node in ast.walk(tree): + if _is_gr_no_reload(node): + assert isinstance(node, ast.If) # noqa: S101 + node.body = [ast.Pass(lineno=node.lineno, col_offset=node.col_offset)] + + # convert tree to string + code_removed = compile(tree, filename=file_path, mode="exec") + return code_removed + + +def _find_module(source_file: Path) -> ModuleType: + for s, v in sys.modules.items(): + if s not in {"__main__", "__mp_main__"} and getattr(v, "__file__", None) == str( + source_file + ): + return v + raise ValueError(f"Cannot find module for source file: {source_file}") + + def watchfn(reloader: SourceFileReloader): """Watch python files in a given module. @@ -179,7 +230,6 @@ def watchfn(reloader: SourceFileReloader): for path in list(reload_dir.rglob("*.css")): yield path.resolve() - module = None reload_dirs = [Path(dir_) for dir_ in reloader.watch_dirs] import sys @@ -187,29 +237,37 @@ def watchfn(reloader: SourceFileReloader): sys.path.insert(0, str(dir_)) mtimes = {} + # Need to import the module in this thread so that the + # module is available in the namespace of this thread + module = importlib.import_module(reloader.watch_module_name) while reloader.should_watch(): changed = get_changes() if changed: print(f"Changes detected in: {changed}") - # To simulate a fresh reload, delete all module references from sys.modules - # for the modules in the package the change came from. - dir_ = next(d for d in reload_dirs if is_in_or_equal(changed, d)) - modules = list(sys.modules) - for k in modules: - v = sys.modules[k] - sourcefile = getattr(v, "__file__", None) - # Do not reload `reload.py` to keep thread data - if ( - sourcefile - and dir_ == Path(inspect.getfile(gradio)).parent - and sourcefile.endswith("reload.py") - ): - continue - if sourcefile and is_in_or_equal(sourcefile, dir_): - del sys.modules[k] try: - module = importlib.import_module(reloader.watch_module_name) - module = importlib.reload(module) + # How source file reloading works + # 1. Remove the gr.no_reload code blocks from the temp file + # 2. Execute the changed source code in the original module's namespac + # 3. Delete the package the module is in from sys.modules. + # This is so that the updated module is available in the entire package + # 4. Do 1-2 for the main demo file even if it did not change. + # This is because the main demo file may import the changed file and we need the + # changes to be reflected in the main demo file. + + changed_in_copy = _remove_no_reload_codeblocks(str(changed)) + if changed != reloader.demo_file: + changed_module = _find_module(changed) + exec(changed_in_copy, changed_module.__dict__) + top_level_parent = sys.modules[ + changed_module.__name__.split(".")[0] + ] + if top_level_parent != changed_module: + importlib.reload(top_level_parent) + + changed_demo_file = _remove_no_reload_codeblocks( + str(reloader.demo_file) + ) + exec(changed_demo_file, module.__dict__) except Exception: print( f"Reloading {reloader.watch_module_name} failed with the following exception: " @@ -217,7 +275,6 @@ def watchfn(reloader: SourceFileReloader): traceback.print_exc() mtimes = {} continue - demo = getattr(module, reloader.demo_name) if reloader.queue_changed(demo): print( diff --git a/guides/09_other-tutorials/developing-faster-with-reload-mode.md b/guides/09_other-tutorials/developing-faster-with-reload-mode.md index 7cf5db3bb6..eaf545014d 100644 --- a/guides/09_other-tutorials/developing-faster-with-reload-mode.md +++ b/guides/09_other-tutorials/developing-faster-with-reload-mode.md @@ -48,7 +48,7 @@ Running on local URL: http://127.0.0.1:7860 The important part here is the line that says `Watching...` What's happening here is that Gradio will be observing the directory where `run.py` file lives, and if the file changes, it will automatically rerun the file for you. So you can focus on writing your code, and your Gradio demo will refresh automatically 🥳 -⚠️ Warning: the `gradio` command does not detect the parameters passed to the `launch()` methods because the `launch()` method is never called in reload mode. For example, setting `auth`, or `show_error` in `launch()` will not be reflected in the app. +Tip: the `gradio` command does not detect the parameters passed to the `launch()` methods because the `launch()` method is never called in reload mode. For example, setting `auth`, or `show_error` in `launch()` will not be reflected in the app. There is one important thing to keep in mind when using the reload mode: Gradio specifically looks for a Gradio Blocks/Interface demo called `demo` in your code. If you have named your demo something else, you will need to pass in the name of your demo as the 2nd parameter in your code. So if your `run.py` file looked like this: @@ -101,6 +101,31 @@ Which you could run like this: `gradio run.py --name Gretel` As a small aside, this auto-reloading happens if you change your `run.py` source code or the Gradio source code. Meaning that this can be useful if you decide to [contribute to Gradio itself](https://github.com/gradio-app/gradio/blob/main/CONTRIBUTING.md) ✅ + +## Controlling the Reload 🎛️ + +By default, reload mode will re-run your entire script for every change you make. +But there are some cases where this is not desirable. +For example, loading a machine learning model should probably only happen once to save time. There are also some Python libraries that use C or Rust extensions that throw errors when they are reloaded, like `numpy` and `tiktoken`. + +In these situations, you can place code that you do not want to be re-run inside an `if gr.NO_RELOAD:` codeblock. Here's an example of how you can use it to only load a transformers model once during the development process. + +Tip: The value of `gr.NO_RELOAD` is `True`. So you don't have to change your script when you are done developing and want to run it in production. Simply run the file with `python` instead of `gradio`. + +```python +import gradio as gr + +if gr.NO_RELOAD: + from transformers import pipeline + pipe = pipeline("text-classification", model="cardiffnlp/twitter-roberta-base-sentiment-latest") + +demo = gr.Interface(lambda s: pipe(s), gr.Textbox(), gr.Label()) + +if __name__ == "__main__": + demo.launch() +``` + + ## Jupyter Notebook Magic 🔮 What about if you use Jupyter Notebooks (or Colab Notebooks, etc.) to develop code? We got something for you too! diff --git a/js/_website/src/lib/components/DocsNav.svelte b/js/_website/src/lib/components/DocsNav.svelte index 7aa5172866..2e5a12bb65 100644 --- a/js/_website/src/lib/components/DocsNav.svelte +++ b/js/_website/src/lib/components/DocsNav.svelte @@ -221,6 +221,11 @@ class="thin-link px-4 block" href="./themes/">Themes + NO_RELOAD + import Demos from "$lib/components/Demos.svelte"; + import DocsNav from "$lib/components/DocsNav.svelte"; + import FunctionDoc from "$lib/components/FunctionDoc.svelte"; + import MetaTags from "$lib/components/MetaTags.svelte"; + import anchor from "$lib/assets/img/anchor.svg"; + import { onDestroy } from "svelte"; + import { page } from "$app/stores"; + + export let data: any; + + let obj = data.obj; + let components = data.components; + let helpers = data.helpers; + let modals = data.modals; + let routes = data.routes; + let headers = data.headers; + let method_headers = data.method_headers; + let py_client = data.py_client; + + let current_selection = 0; + + let y: number; + let header_targets: { [key: string]: HTMLElement } = {}; + let target_elem: HTMLElement; + + onDestroy(() => { + header_targets = {}; + }); + + $: for (const target in header_targets) { + target_elem = document.querySelector(`#${target}`) as HTMLElement; + if ( + y > target_elem?.offsetTop - 50 && + y < target_elem?.offsetTop + target_elem?.offsetHeight + ) { + header_targets[target]?.classList.add("current-nav-link"); + } else { + header_targets[target]?.classList.remove("current-nav-link"); + } + } + + let on_main: boolean; + let wheel: string = data.wheel; + + $: on_main = data.on_main; + $: components = data.components; + $: helpers = data.helpers; + $: modals = data.modals; + $: routes = data.routes; + $: py_client = data.py_client; + + + + + + +
+
+ + +
+
+

+ New to Gradio? Start here: Getting Started +

+

+ See the Release History +

+
+ + {#if on_main} +
+

+ To install Gradio from main, run the following command: +

+
+
pip install {wheel}
+
+

+ *Note: Setting share=True in + launch() will not work. +

+
+ {/if} + + + +
+
+
+
+

+ {obj.name} + +

+
+
+
{obj.override_signature}
+
+ +

+ Description + +

+

{@html obj.description}

+ + {#if obj.example} +

+ Example Usage + +

+
+
{@html obj.highlighted_example}
+
+ {/if} +
+
+
+ +
+ +
+
diff --git a/js/_website/src/routes/[[version]]/docs/python-client/+page.svelte b/js/_website/src/routes/[[version]]/docs/python-client/+page.svelte index 4d1fdd72c8..6ab71f226d 100644 --- a/js/_website/src/routes/[[version]]/docs/python-client/+page.svelte +++ b/js/_website/src/routes/[[version]]/docs/python-client/+page.svelte @@ -76,11 +76,11 @@
- Themes + NO_RELOAD
- Python Client + NO RELOAD
@@ -469,11 +469,11 @@
- Python Client + NO RELOAD
diff --git a/js/app/package.json b/js/app/package.json index cddf993637..255c58f4c7 100644 --- a/js/app/package.json +++ b/js/app/package.json @@ -16,8 +16,9 @@ "build:lite": "pnpm pybuild && pnpm cssbuild && pnpm --filter @gradio/client build && pnpm --filter @gradio/wasm build && vite build --mode production:lite", "preview": "vite preview", "test:snapshot": "pnpm exec playwright test snapshots/ --config=../../.config/playwright.config.js", - "test:browser": "pnpm exec playwright test test/ --config=../../.config/playwright.config.js", + "test:browser": "pnpm exec playwright test test/ --grep-invert 'reload.spec.ts' --config=../../.config/playwright.config.js", "test:browser:dev": "pnpm exec playwright test test/ --ui --config=../../.config/playwright.config.js", + "test:browser:reload": "pnpm exec playwright test test/ --grep 'reload.spec.ts' --config=../../.config/playwright.config.js", "test:browser:lite": "GRADIO_E2E_TEST_LITE=1 pnpm test:browser", "test:browser:lite:dev": "GRADIO_E2E_TEST_LITE=1 pnpm test:browser:dev", "build:css": "pollen -c pollen.config.cjs -o src/pollen-dev.css" diff --git a/js/app/test/hello_world.reload.spec.ts b/js/app/test/hello_world.reload.spec.ts new file mode 100644 index 0000000000..92955f8bba --- /dev/null +++ b/js/app/test/hello_world.reload.spec.ts @@ -0,0 +1,84 @@ +import { test, expect } from "@playwright/test"; +import { spawnSync } from "node:child_process"; +import { launch_app_background, kill_process } from "./utils"; +import { join } from "path"; + +let _process; + +test.beforeAll(() => { + const demo = ` +import gradio as gr + +def greet(name): + return "Hello " + name + "!" + +demo = gr.Interface(fn=greet, inputs="text", outputs="text") + +if __name__ == "__main__": + demo.launch() +`; + // write contents of demo to a local 'run.py' file + spawnSync(`echo '${demo}' > ${join(process.cwd(), "run.py")}`, { + shell: true, + stdio: "pipe", + env: { + ...process.env, + PYTHONUNBUFFERED: "true" + } + }); +}); + +test.afterAll(() => { + if (_process) kill_process(_process); + spawnSync(`rm ${join(process.cwd(), "run.py")}`, { + shell: true, + stdio: "pipe", + env: { + ...process.env, + PYTHONUNBUFFERED: "true" + } + }); +}); + +test("gradio dev mode correctly reloads the page", async ({ page }) => { + test.setTimeout(20 * 1000); + + try { + const port = 7880; + const { _process: server_process } = await launch_app_background( + `GRADIO_SERVER_PORT=${port} gradio ${join(process.cwd(), "run.py")}`, + process.cwd() + ); + _process = server_process; + console.log("Connected to port", port); + const demo = ` +import gradio as gr + +def greet(name): + return "Hello " + name + "!" + +demo = gr.Interface(fn=greet, inputs=gr.Textbox(label="x"), outputs=gr.Textbox(label="foo")) + +if __name__ == "__main__": + demo.launch() + `; + // write contents of demo to a local 'run.py' file + await page.goto(`http://localhost:${port}`); + spawnSync(`echo '${demo}' > ${join(process.cwd(), "run.py")}`, { + shell: true, + stdio: "pipe", + env: { + ...process.env, + PYTHONUNBUFFERED: "true" + } + }); + //await page.reload(); + + await page.getByLabel("x").fill("Maria"); + await page.getByRole("button", { name: "Submit" }).click(); + + await expect(page.getByLabel("foo")).toHaveValue("Hello Maria!"); + } finally { + if (_process) kill_process(_process); + } +}); diff --git a/js/app/test/utils.ts b/js/app/test/utils.ts new file mode 100644 index 0000000000..162e7a2aa2 --- /dev/null +++ b/js/app/test/utils.ts @@ -0,0 +1,61 @@ +import { spawn } from "node:child_process"; + +import type { ChildProcess } from "child_process"; + +export function kill_process(process: ChildProcess) { + process.kill("SIGKILL"); +} + +type LaunchAppBackgroundReturn = { + port: number; + _process: ChildProcess; +}; + +export const launch_app_background = async ( + command: string, + cwd?: string +): Promise => { + const _process = spawn(command, { + shell: true, + stdio: "pipe", + cwd: cwd || process.cwd(), + env: { + ...process.env, + PYTHONUNBUFFERED: "true" + } + }); + + _process.stdout.setEncoding("utf8"); + _process.stderr.setEncoding("utf8"); + + _process.on("exit", () => kill_process(_process)); + _process.on("close", () => kill_process(_process)); + _process.on("disconnect", () => kill_process(_process)); + + let port; + + function std_out(data: any) { + const _data: string = data.toString(); + console.log(_data); + + const portRegExp = /:(\d+)/; + const match = portRegExp.exec(_data); + + if (match && match[1] && _data.includes("Running on local URL:")) { + port = parseInt(match[1], 10); + } + } + + function std_err(data: any) { + const _data: string = data.toString(); + console.log(_data); + } + + _process.stdout.on("data", std_out); + _process.stderr.on("data", std_err); + + while (!port) { + await new Promise((r) => setTimeout(r, 1000)); + } + return { port: port, _process: _process }; +}; diff --git a/package.json b/package.json index 30c18c97a9..24b055ebbb 100644 --- a/package.json +++ b/package.json @@ -17,6 +17,7 @@ "test:run": "pnpm --filter @gradio/client build && vitest run --config .config/vitest.config.ts --reporter=verbose", "test:node": "TEST_MODE=node pnpm vitest run --config .config/vitest.config.ts", "test:browser": "pnpm --filter @gradio/app test:browser", + "test:browser:reload": "CUSTOM_TEST=1 pnpm --filter @gradio/app test:browser:reload", "test:browser:full": "run-s build test:browser", "test:browser:verbose": "pnpm test:browser", "test:browser:dev": "pnpm --filter @gradio/app test:browser:dev",