Do not reload code inside gr.NO_RELOAD context (#7684)

* Add code

* Copy files

* I think its working

* Tidy up

* add changeset

* do not change demos

* test

* Don't copy files

* Add code

* lint

* Add reload mode e2e test

* Reload mode test

* add changeset

* add changeset

* Use NO_RELOAD

* add no reload to docs

* add changeset

* Fix docs

* handle else statements. No need to edit string

* Fix typos

* Use compile

* Do not use unparse

* notebook

* Documentation comments

* Fix top-package import without having to delete all modules

* Revert demo calculator

* Typo guides

* Fix website

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
Co-authored-by: aliabd <ali.si3luwa@gmail.com>
Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
This commit is contained in:
Freddy Boulton 2024-03-21 12:59:53 -07:00 committed by GitHub
parent 43ae23f092
commit 755157f99c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
19 changed files with 853 additions and 294 deletions

View File

@ -0,0 +1,7 @@
---
"@gradio/app": minor
"gradio": minor
"website": minor
---
feat:Do not reload code inside gr.NO_RELOAD context

View File

@ -16,7 +16,8 @@ const test_files = readdirSync(TEST_FILES_PATH)
(f) =>
f.endsWith("spec.ts") &&
!f.endsWith(".skip.spec.ts") &&
!f.endsWith(".component.spec.ts")
!f.endsWith(".component.spec.ts") &&
!f.endsWith(".reload.spec.ts")
)
.map((f) => basename(f, ".spec.ts"));

View File

@ -24,7 +24,7 @@ const base = defineConfig({
});
const normal = defineConfig(base, {
globalSetup: "./playwright-setup.js"
globalSetup: process.env.CUSTOM_TEST ? undefined : "./playwright-setup.js"
});
normal.projects = undefined; // Explicitly unset this field due to https://github.com/microsoft/playwright/issues/28795

View File

@ -74,6 +74,10 @@ jobs:
run: |
. venv/bin/activate
pnpm run test:ct
- name: run reload mode test
run: |
. venv/bin/activate
pnpm test:browser:reload
- name: Run Lite E2E tests
run: |
. venv/bin/activate

View File

@ -98,6 +98,6 @@ from gradio.templates import (
TextArea,
)
from gradio.themes import Base as Theme
from gradio.utils import get_package_version, set_static_paths
from gradio.utils import NO_RELOAD, get_package_version, set_static_paths
__version__ = get_package_version()

View File

@ -97,7 +97,7 @@ def _setup_config(
print(message + "\n")
# guaranty access to the module of an app
# guarantee access to the module of an app
sys.path.insert(0, os.getcwd())
return module_name, abs_original_path, [str(s) for s in watching_dirs], demo_name
@ -109,17 +109,20 @@ def main(
encoding: str = "utf-8",
):
# default execution pattern to start the server and watch changes
module_name, path, watch_dirs, demo_name = _setup_config(
module_name, path, watch_sources, demo_name = _setup_config(
demo_path, demo_name, watch_dirs, encoding
)
# extra_args = args[1:] if len(args) == 1 or args[1].startswith("--") else args[2:]
# Pass the following data as environment variables
# so that we can set up reload mode correctly in the networking.py module
popen = subprocess.Popen(
[sys.executable, "-u", path],
env=dict(
os.environ,
GRADIO_WATCH_DIRS=",".join(watch_dirs),
GRADIO_WATCH_DIRS=",".join(watch_sources),
GRADIO_WATCH_MODULE_NAME=module_name,
GRADIO_WATCH_DEMO_NAME=demo_name,
GRADIO_WATCH_DEMO_PATH=str(path),
),
)
popen.wait()

View File

@ -1,258 +1,260 @@
"""
Defines helper methods useful for setting up ports, launching servers, and
creating tunnels.
"""
from __future__ import annotations
import os
import socket
import threading
import time
import warnings
from functools import partial
from typing import TYPE_CHECKING
import httpx
import uvicorn
from uvicorn.config import Config
from gradio.exceptions import ServerFailedToStartError
from gradio.routes import App
from gradio.tunneling import Tunnel
from gradio.utils import SourceFileReloader, watchfn
if TYPE_CHECKING: # Only import for type checking (to avoid circular imports).
from gradio.blocks import Blocks
# By default, the local server will try to open on localhost, port 7860.
# If that is not available, then it will try 7861, 7862, ... 7959.
INITIAL_PORT_VALUE = int(os.getenv("GRADIO_SERVER_PORT", "7860"))
TRY_NUM_PORTS = int(os.getenv("GRADIO_NUM_PORTS", "100"))
LOCALHOST_NAME = os.getenv("GRADIO_SERVER_NAME", "127.0.0.1")
GRADIO_API_SERVER = "https://api.gradio.app/v2/tunnel-request"
GRADIO_SHARE_SERVER_ADDRESS = os.getenv("GRADIO_SHARE_SERVER_ADDRESS")
should_watch = bool(os.getenv("GRADIO_WATCH_DIRS", ""))
GRADIO_WATCH_DIRS = (
os.getenv("GRADIO_WATCH_DIRS", "").split(",") if should_watch else []
)
GRADIO_WATCH_MODULE_NAME = os.getenv("GRADIO_WATCH_MODULE_NAME", "app")
GRADIO_WATCH_DEMO_NAME = os.getenv("GRADIO_WATCH_DEMO_NAME", "demo")
class Server(uvicorn.Server):
def __init__(
self, config: Config, reloader: SourceFileReloader | None = None
) -> None:
self.running_app = config.app
super().__init__(config)
self.reloader = reloader
if self.reloader:
self.event = threading.Event()
self.watch = partial(watchfn, self.reloader)
def install_signal_handlers(self):
pass
def run_in_thread(self):
self.thread = threading.Thread(target=self.run, daemon=True)
if self.reloader:
self.watch_thread = threading.Thread(target=self.watch, daemon=True)
self.watch_thread.start()
self.thread.start()
start = time.time()
while not self.started:
time.sleep(1e-3)
if time.time() - start > 5:
raise ServerFailedToStartError(
"Server failed to start. Please check that the port is available."
)
def close(self):
self.should_exit = True
if self.reloader:
self.reloader.stop()
self.watch_thread.join()
self.thread.join()
def get_first_available_port(initial: int, final: int) -> int:
"""
Gets the first open port in a specified range of port numbers
Parameters:
initial: the initial value in the range of port numbers
final: final (exclusive) value in the range of port numbers, should be greater than `initial`
Returns:
port: the first open port in the range
"""
for port in range(initial, final):
try:
s = socket.socket() # create a socket object
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
s.bind((LOCALHOST_NAME, port)) # Bind to the port
s.close()
return port
except OSError:
pass
raise OSError(
f"All ports from {initial} to {final - 1} are in use. Please close a port."
)
def configure_app(app: App, blocks: Blocks) -> App:
auth = blocks.auth
if auth is not None:
if not callable(auth):
app.auth = {account[0]: account[1] for account in auth}
else:
app.auth = auth
else:
app.auth = None
app.blocks = blocks
app.cwd = os.getcwd()
app.favicon_path = blocks.favicon_path
app.tokens = {}
return app
def start_server(
blocks: Blocks,
server_name: str | None = None,
server_port: int | None = None,
ssl_keyfile: str | None = None,
ssl_certfile: str | None = None,
ssl_keyfile_password: str | None = None,
app_kwargs: dict | None = None,
) -> tuple[str, int, str, App, Server]:
"""Launches a local server running the provided Interface
Parameters:
blocks: The Blocks object to run on the server
server_name: to make app accessible on local network, set this to "0.0.0.0". Can be set by environment variable GRADIO_SERVER_NAME.
server_port: will start gradio app on this port (if available). Can be set by environment variable GRADIO_SERVER_PORT.
auth: If provided, username and password (or list of username-password tuples) required to access the Blocks. Can also provide function that takes username and password and returns True if valid login.
ssl_keyfile: If a path to a file is provided, will use this as the private key file to create a local server running on https.
ssl_certfile: If a path to a file is provided, will use this as the signed certificate for https. Needs to be provided if ssl_keyfile is provided.
ssl_keyfile_password: If a password is provided, will use this with the ssl certificate for https.
app_kwargs: Additional keyword arguments to pass to the gradio.routes.App constructor.
Returns:
port: the port number the server is running on
path_to_local_server: the complete address that the local server can be accessed at
app: the FastAPI app object
server: the server object that is a subclass of uvicorn.Server (used to close the server)
"""
if ssl_keyfile is not None and ssl_certfile is None:
raise ValueError("ssl_certfile must be provided if ssl_keyfile is provided.")
server_name = server_name or LOCALHOST_NAME
url_host_name = "localhost" if server_name == "0.0.0.0" else server_name
# Strip IPv6 brackets from the address if they exist.
# This is needed as http://[::1]:port/ is a valid browser address,
# but not a valid IPv6 address, so asyncio will throw an exception.
if server_name.startswith("[") and server_name.endswith("]"):
host = server_name[1:-1]
else:
host = server_name
app = App.create_app(blocks, app_kwargs=app_kwargs)
server_ports = (
[server_port]
if server_port is not None
else range(INITIAL_PORT_VALUE, INITIAL_PORT_VALUE + TRY_NUM_PORTS)
)
for port in server_ports:
try:
# The fastest way to check if a port is available is to try to bind to it with socket.
# If the port is not available, socket will throw an OSError.
s = socket.socket()
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
# Really, we should be checking if (server_name, server_port) is available, but
# socket.bind() doesn't seem to throw an OSError with ipv6 addresses, based on my testing.
# Instead, we just check if the port is available on localhost.
s.bind((LOCALHOST_NAME, port))
s.close()
# To avoid race conditions, so we also check if the port by trying to start the uvicorn server.
# If the port is not available, this will throw a ServerFailedToStartError.
config = uvicorn.Config(
app=app,
port=port,
host=host,
log_level="warning",
ssl_keyfile=ssl_keyfile,
ssl_certfile=ssl_certfile,
ssl_keyfile_password=ssl_keyfile_password,
)
reloader = None
if GRADIO_WATCH_DIRS:
change_event = threading.Event()
app.change_event = change_event
reloader = SourceFileReloader(
app=app,
watch_dirs=GRADIO_WATCH_DIRS,
watch_module_name=GRADIO_WATCH_MODULE_NAME,
demo_name=GRADIO_WATCH_DEMO_NAME,
stop_event=threading.Event(),
change_event=change_event,
)
server = Server(config=config, reloader=reloader)
server.run_in_thread()
break
except (OSError, ServerFailedToStartError):
pass
else:
raise OSError(
f"Cannot find empty port in range: {min(server_ports)}-{max(server_ports)}. You can specify a different port by setting the GRADIO_SERVER_PORT environment variable or passing the `server_port` parameter to `launch()`."
)
if ssl_keyfile is not None:
path_to_local_server = f"https://{url_host_name}:{port}/"
else:
path_to_local_server = f"http://{url_host_name}:{port}/"
return server_name, port, path_to_local_server, app, server
def setup_tunnel(
local_host: str, local_port: int, share_token: str, share_server_address: str | None
) -> str:
share_server_address = (
GRADIO_SHARE_SERVER_ADDRESS
if share_server_address is None
else share_server_address
)
if share_server_address is None:
try:
response = httpx.get(GRADIO_API_SERVER, timeout=30)
payload = response.json()[0]
remote_host, remote_port = payload["host"], int(payload["port"])
except Exception as e:
raise RuntimeError(
"Could not get share link from Gradio API Server."
) from e
else:
remote_host, remote_port = share_server_address.split(":")
remote_port = int(remote_port)
try:
tunnel = Tunnel(remote_host, remote_port, local_host, local_port, share_token)
address = tunnel.start_tunnel()
return address
except Exception as e:
raise RuntimeError(str(e)) from e
def url_ok(url: str) -> bool:
try:
for _ in range(5):
with warnings.catch_warnings():
warnings.filterwarnings("ignore")
r = httpx.head(url, timeout=3, verify=False)
if r.status_code in (200, 401, 302): # 401 or 302 if auth is set
return True
time.sleep(0.500)
except (ConnectionError, httpx.ConnectError, httpx.TimeoutException):
return False
return False
"""
Defines helper methods useful for setting up ports, launching servers, and
creating tunnels.
"""
from __future__ import annotations
import os
import socket
import threading
import time
import warnings
from functools import partial
from typing import TYPE_CHECKING
import httpx
import uvicorn
from uvicorn.config import Config
from gradio.exceptions import ServerFailedToStartError
from gradio.routes import App
from gradio.tunneling import Tunnel
from gradio.utils import SourceFileReloader, watchfn
if TYPE_CHECKING: # Only import for type checking (to avoid circular imports).
from gradio.blocks import Blocks
# By default, the local server will try to open on localhost, port 7860.
# If that is not available, then it will try 7861, 7862, ... 7959.
INITIAL_PORT_VALUE = int(os.getenv("GRADIO_SERVER_PORT", "7860"))
TRY_NUM_PORTS = int(os.getenv("GRADIO_NUM_PORTS", "100"))
LOCALHOST_NAME = os.getenv("GRADIO_SERVER_NAME", "127.0.0.1")
GRADIO_API_SERVER = "https://api.gradio.app/v2/tunnel-request"
GRADIO_SHARE_SERVER_ADDRESS = os.getenv("GRADIO_SHARE_SERVER_ADDRESS")
should_watch = bool(os.getenv("GRADIO_WATCH_DIRS", ""))
GRADIO_WATCH_DIRS = (
os.getenv("GRADIO_WATCH_DIRS", "").split(",") if should_watch else []
)
GRADIO_WATCH_MODULE_NAME = os.getenv("GRADIO_WATCH_MODULE_NAME", "app")
GRADIO_WATCH_DEMO_NAME = os.getenv("GRADIO_WATCH_DEMO_NAME", "demo")
GRADIO_WATCH_DEMO_PATH = os.getenv("GRADIO_WATCH_DEMO_PATH", "")
class Server(uvicorn.Server):
def __init__(
self, config: Config, reloader: SourceFileReloader | None = None
) -> None:
self.running_app = config.app
super().__init__(config)
self.reloader = reloader
if self.reloader:
self.event = threading.Event()
self.watch = partial(watchfn, self.reloader)
def install_signal_handlers(self):
pass
def run_in_thread(self):
self.thread = threading.Thread(target=self.run, daemon=True)
if self.reloader:
self.watch_thread = threading.Thread(target=self.watch, daemon=True)
self.watch_thread.start()
self.thread.start()
start = time.time()
while not self.started:
time.sleep(1e-3)
if time.time() - start > 5:
raise ServerFailedToStartError(
"Server failed to start. Please check that the port is available."
)
def close(self):
self.should_exit = True
if self.reloader:
self.reloader.stop()
self.watch_thread.join()
self.thread.join()
def get_first_available_port(initial: int, final: int) -> int:
"""
Gets the first open port in a specified range of port numbers
Parameters:
initial: the initial value in the range of port numbers
final: final (exclusive) value in the range of port numbers, should be greater than `initial`
Returns:
port: the first open port in the range
"""
for port in range(initial, final):
try:
s = socket.socket() # create a socket object
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
s.bind((LOCALHOST_NAME, port)) # Bind to the port
s.close()
return port
except OSError:
pass
raise OSError(
f"All ports from {initial} to {final - 1} are in use. Please close a port."
)
def configure_app(app: App, blocks: Blocks) -> App:
auth = blocks.auth
if auth is not None:
if not callable(auth):
app.auth = {account[0]: account[1] for account in auth}
else:
app.auth = auth
else:
app.auth = None
app.blocks = blocks
app.cwd = os.getcwd()
app.favicon_path = blocks.favicon_path
app.tokens = {}
return app
def start_server(
blocks: Blocks,
server_name: str | None = None,
server_port: int | None = None,
ssl_keyfile: str | None = None,
ssl_certfile: str | None = None,
ssl_keyfile_password: str | None = None,
app_kwargs: dict | None = None,
) -> tuple[str, int, str, App, Server]:
"""Launches a local server running the provided Interface
Parameters:
blocks: The Blocks object to run on the server
server_name: to make app accessible on local network, set this to "0.0.0.0". Can be set by environment variable GRADIO_SERVER_NAME.
server_port: will start gradio app on this port (if available). Can be set by environment variable GRADIO_SERVER_PORT.
auth: If provided, username and password (or list of username-password tuples) required to access the Blocks. Can also provide function that takes username and password and returns True if valid login.
ssl_keyfile: If a path to a file is provided, will use this as the private key file to create a local server running on https.
ssl_certfile: If a path to a file is provided, will use this as the signed certificate for https. Needs to be provided if ssl_keyfile is provided.
ssl_keyfile_password: If a password is provided, will use this with the ssl certificate for https.
app_kwargs: Additional keyword arguments to pass to the gradio.routes.App constructor.
Returns:
port: the port number the server is running on
path_to_local_server: the complete address that the local server can be accessed at
app: the FastAPI app object
server: the server object that is a subclass of uvicorn.Server (used to close the server)
"""
if ssl_keyfile is not None and ssl_certfile is None:
raise ValueError("ssl_certfile must be provided if ssl_keyfile is provided.")
server_name = server_name or LOCALHOST_NAME
url_host_name = "localhost" if server_name == "0.0.0.0" else server_name
# Strip IPv6 brackets from the address if they exist.
# This is needed as http://[::1]:port/ is a valid browser address,
# but not a valid IPv6 address, so asyncio will throw an exception.
if server_name.startswith("[") and server_name.endswith("]"):
host = server_name[1:-1]
else:
host = server_name
app = App.create_app(blocks, app_kwargs=app_kwargs)
server_ports = (
[server_port]
if server_port is not None
else range(INITIAL_PORT_VALUE, INITIAL_PORT_VALUE + TRY_NUM_PORTS)
)
for port in server_ports:
try:
# The fastest way to check if a port is available is to try to bind to it with socket.
# If the port is not available, socket will throw an OSError.
s = socket.socket()
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
# Really, we should be checking if (server_name, server_port) is available, but
# socket.bind() doesn't seem to throw an OSError with ipv6 addresses, based on my testing.
# Instead, we just check if the port is available on localhost.
s.bind((LOCALHOST_NAME, port))
s.close()
# To avoid race conditions, so we also check if the port by trying to start the uvicorn server.
# If the port is not available, this will throw a ServerFailedToStartError.
config = uvicorn.Config(
app=app,
port=port,
host=host,
log_level="warning",
ssl_keyfile=ssl_keyfile,
ssl_certfile=ssl_certfile,
ssl_keyfile_password=ssl_keyfile_password,
)
reloader = None
if GRADIO_WATCH_DIRS:
change_event = threading.Event()
app.change_event = change_event
reloader = SourceFileReloader(
app=app,
watch_dirs=GRADIO_WATCH_DIRS,
watch_module_name=GRADIO_WATCH_MODULE_NAME,
demo_name=GRADIO_WATCH_DEMO_NAME,
stop_event=threading.Event(),
change_event=change_event,
demo_file=GRADIO_WATCH_DEMO_PATH,
)
server = Server(config=config, reloader=reloader)
server.run_in_thread()
break
except (OSError, ServerFailedToStartError):
pass
else:
raise OSError(
f"Cannot find empty port in range: {min(server_ports)}-{max(server_ports)}. You can specify a different port by setting the GRADIO_SERVER_PORT environment variable or passing the `server_port` parameter to `launch()`."
)
if ssl_keyfile is not None:
path_to_local_server = f"https://{url_host_name}:{port}/"
else:
path_to_local_server = f"http://{url_host_name}:{port}/"
return server_name, port, path_to_local_server, app, server
def setup_tunnel(
local_host: str, local_port: int, share_token: str, share_server_address: str | None
) -> str:
share_server_address = (
GRADIO_SHARE_SERVER_ADDRESS
if share_server_address is None
else share_server_address
)
if share_server_address is None:
try:
response = httpx.get(GRADIO_API_SERVER, timeout=30)
payload = response.json()[0]
remote_host, remote_port = payload["host"], int(payload["port"])
except Exception as e:
raise RuntimeError(
"Could not get share link from Gradio API Server."
) from e
else:
remote_host, remote_port = share_server_address.split(":")
remote_port = int(remote_port)
try:
tunnel = Tunnel(remote_host, remote_port, local_host, local_port, share_token)
address = tunnel.start_tunnel()
return address
except Exception as e:
raise RuntimeError(str(e)) from e
def url_ok(url: str) -> bool:
try:
for _ in range(5):
with warnings.catch_warnings():
warnings.filterwarnings("ignore")
r = httpx.head(url, timeout=3, verify=False)
if r.status_code in (200, 401, 302): # 401 or 302 if auth is set
return True
time.sleep(0.500)
except (ConnectionError, httpx.ConnectError):
return False
return False

View File

@ -19,6 +19,12 @@ class StateHolder:
self.blocks = blocks
self.capacity = blocks.state_session_capacity
def reset(self, blocks: Blocks):
"""Reset the state holder with new blocks. Used during reload mode."""
self.session_data = OrderedDict()
# Call set blocks again to set new ids
self.set_blocks(blocks)
def __getitem__(self, session_id: str) -> SessionState:
if session_id not in self.session_data:
self.session_data[session_id] = SessionState(self.blocks)

View File

@ -2,17 +2,20 @@
from __future__ import annotations
import ast
import asyncio
import copy
import dataclasses
import functools
import importlib
import importlib.util
import inspect
import json
import json.decoder
import os
import pkgutil
import re
import sys
import tempfile
import threading
import time
@ -26,7 +29,7 @@ from contextlib import contextmanager
from io import BytesIO
from numbers import Number
from pathlib import Path
from types import AsyncGeneratorType, GeneratorType
from types import AsyncGeneratorType, GeneratorType, ModuleType
from typing import (
TYPE_CHECKING,
Any,
@ -104,6 +107,7 @@ class BaseReloader(ABC):
# not a new queue
self.running_app.blocks._queue.block_fns = demo.fns
demo._queue = self.running_app.blocks._queue
self.running_app.state_holder.reset(demo)
self.running_app.blocks = demo
demo._queue.reload()
@ -114,6 +118,7 @@ class SourceFileReloader(BaseReloader):
app: App,
watch_dirs: list[str],
watch_module_name: str,
demo_file: str,
stop_event: threading.Event,
change_event: threading.Event,
demo_name: str = "demo",
@ -125,6 +130,7 @@ class SourceFileReloader(BaseReloader):
self.stop_event = stop_event
self.change_event = change_event
self.demo_name = demo_name
self.demo_file = Path(demo_file)
@property
def running_app(self) -> App:
@ -144,6 +150,51 @@ class SourceFileReloader(BaseReloader):
self.alert_change()
NO_RELOAD = True
def _remove_no_reload_codeblocks(file_path: str):
"""Parse the file, remove the gr.no_reload code blocks, and write the file back to disk.
Parameters:
file_path (str): The path to the file to remove the no_reload code blocks from.
"""
with open(file_path) as file:
code = file.read()
tree = ast.parse(code)
def _is_gr_no_reload(expr: ast.AST) -> bool:
"""Find with gr.no_reload context managers."""
return (
isinstance(expr, ast.If)
and isinstance(expr.test, ast.Attribute)
and isinstance(expr.test.value, ast.Name)
and expr.test.value.id == "gr"
and expr.test.attr == "NO_RELOAD"
)
# Find the positions of the code blocks to load
for node in ast.walk(tree):
if _is_gr_no_reload(node):
assert isinstance(node, ast.If) # noqa: S101
node.body = [ast.Pass(lineno=node.lineno, col_offset=node.col_offset)]
# convert tree to string
code_removed = compile(tree, filename=file_path, mode="exec")
return code_removed
def _find_module(source_file: Path) -> ModuleType:
for s, v in sys.modules.items():
if s not in {"__main__", "__mp_main__"} and getattr(v, "__file__", None) == str(
source_file
):
return v
raise ValueError(f"Cannot find module for source file: {source_file}")
def watchfn(reloader: SourceFileReloader):
"""Watch python files in a given module.
@ -179,7 +230,6 @@ def watchfn(reloader: SourceFileReloader):
for path in list(reload_dir.rglob("*.css")):
yield path.resolve()
module = None
reload_dirs = [Path(dir_) for dir_ in reloader.watch_dirs]
import sys
@ -187,29 +237,37 @@ def watchfn(reloader: SourceFileReloader):
sys.path.insert(0, str(dir_))
mtimes = {}
# Need to import the module in this thread so that the
# module is available in the namespace of this thread
module = importlib.import_module(reloader.watch_module_name)
while reloader.should_watch():
changed = get_changes()
if changed:
print(f"Changes detected in: {changed}")
# To simulate a fresh reload, delete all module references from sys.modules
# for the modules in the package the change came from.
dir_ = next(d for d in reload_dirs if is_in_or_equal(changed, d))
modules = list(sys.modules)
for k in modules:
v = sys.modules[k]
sourcefile = getattr(v, "__file__", None)
# Do not reload `reload.py` to keep thread data
if (
sourcefile
and dir_ == Path(inspect.getfile(gradio)).parent
and sourcefile.endswith("reload.py")
):
continue
if sourcefile and is_in_or_equal(sourcefile, dir_):
del sys.modules[k]
try:
module = importlib.import_module(reloader.watch_module_name)
module = importlib.reload(module)
# How source file reloading works
# 1. Remove the gr.no_reload code blocks from the temp file
# 2. Execute the changed source code in the original module's namespac
# 3. Delete the package the module is in from sys.modules.
# This is so that the updated module is available in the entire package
# 4. Do 1-2 for the main demo file even if it did not change.
# This is because the main demo file may import the changed file and we need the
# changes to be reflected in the main demo file.
changed_in_copy = _remove_no_reload_codeblocks(str(changed))
if changed != reloader.demo_file:
changed_module = _find_module(changed)
exec(changed_in_copy, changed_module.__dict__)
top_level_parent = sys.modules[
changed_module.__name__.split(".")[0]
]
if top_level_parent != changed_module:
importlib.reload(top_level_parent)
changed_demo_file = _remove_no_reload_codeblocks(
str(reloader.demo_file)
)
exec(changed_demo_file, module.__dict__)
except Exception:
print(
f"Reloading {reloader.watch_module_name} failed with the following exception: "
@ -217,7 +275,6 @@ def watchfn(reloader: SourceFileReloader):
traceback.print_exc()
mtimes = {}
continue
demo = getattr(module, reloader.demo_name)
if reloader.queue_changed(demo):
print(

View File

@ -48,7 +48,7 @@ Running on local URL: http://127.0.0.1:7860
The important part here is the line that says `Watching...` What's happening here is that Gradio will be observing the directory where `run.py` file lives, and if the file changes, it will automatically rerun the file for you. So you can focus on writing your code, and your Gradio demo will refresh automatically 🥳
⚠️ Warning: the `gradio` command does not detect the parameters passed to the `launch()` methods because the `launch()` method is never called in reload mode. For example, setting `auth`, or `show_error` in `launch()` will not be reflected in the app.
Tip: the `gradio` command does not detect the parameters passed to the `launch()` methods because the `launch()` method is never called in reload mode. For example, setting `auth`, or `show_error` in `launch()` will not be reflected in the app.
There is one important thing to keep in mind when using the reload mode: Gradio specifically looks for a Gradio Blocks/Interface demo called `demo` in your code. If you have named your demo something else, you will need to pass in the name of your demo as the 2nd parameter in your code. So if your `run.py` file looked like this:
@ -101,6 +101,31 @@ Which you could run like this: `gradio run.py --name Gretel`
As a small aside, this auto-reloading happens if you change your `run.py` source code or the Gradio source code. Meaning that this can be useful if you decide to [contribute to Gradio itself](https://github.com/gradio-app/gradio/blob/main/CONTRIBUTING.md) ✅
## Controlling the Reload 🎛️
By default, reload mode will re-run your entire script for every change you make.
But there are some cases where this is not desirable.
For example, loading a machine learning model should probably only happen once to save time. There are also some Python libraries that use C or Rust extensions that throw errors when they are reloaded, like `numpy` and `tiktoken`.
In these situations, you can place code that you do not want to be re-run inside an `if gr.NO_RELOAD:` codeblock. Here's an example of how you can use it to only load a transformers model once during the development process.
Tip: The value of `gr.NO_RELOAD` is `True`. So you don't have to change your script when you are done developing and want to run it in production. Simply run the file with `python` instead of `gradio`.
```python
import gradio as gr
if gr.NO_RELOAD:
from transformers import pipeline
pipe = pipeline("text-classification", model="cardiffnlp/twitter-roberta-base-sentiment-latest")
demo = gr.Interface(lambda s: pipe(s), gr.Textbox(), gr.Label())
if __name__ == "__main__":
demo.launch()
```
## Jupyter Notebook Magic 🔮
What about if you use Jupyter Notebooks (or Colab Notebooks, etc.) to develop code? We got something for you too!

View File

@ -221,6 +221,11 @@
class="thin-link px-4 block"
href="./themes/">Themes</a
>
<a
class:current-nav-link={current_nav_link == "no-reload"}
class="thin-link px-4 block"
href="./no-reload/">NO_RELOAD</a
>
<a
class:current-nav-link={current_nav_link == "python-client"}

View File

@ -0,0 +1,73 @@
import Prism from "prismjs";
import "prismjs/components/prism-python";
import { make_slug_processor } from "$lib/utils";
let language = "python";
const COLOR_SETS = [
["from-green-100", "to-green-50"],
["from-yellow-100", "to-yellow-50"],
["from-red-100", "to-red-50"],
["from-blue-100", "to-blue-50"],
["from-pink-100", "to-pink-50"],
["from-purple-100", "to-purple-50"]
];
export async function load({ parent }) {
const {
docs,
components,
helpers,
modals,
py_client,
routes,
on_main,
wheel
} = await parent();
let headers = [
["Description", "description"],
["Example Uage", "example-usage"]
];
let method_headers: string[][] = [];
const get_slug = make_slug_processor();
let obj = {
name: "NO_RELOAD",
description:
"Any code in a `if gr.NO_RELOAD` code-block will not be re-evaluated when the source file is reloaded. This is helpful for importing modules that do not like to be reloaded (tiktoken, numpy) as well as database connections and long running set up code.",
example: `import gradio as gr
if gr.NO_RELOAD:
from transformers import pipeline
pipe = pipeline("text-classification", model="cardiffnlp/twitter-roberta-base-sentiment-latest")
gr.Interface.from_pipeline(pipe).launch()
`,
override_signature: "if gr.NO_RELOAD:"
};
if (obj.name) {
obj.slug = get_slug(obj.name);
}
if (obj.example) {
obj.highlighted_example = Prism.highlight(
obj.example,
Prism.languages[language],
"python"
);
}
return {
obj,
components,
helpers,
modals,
routes,
py_client,
COLOR_SETS,
headers,
method_headers,
on_main,
wheel
};
}

View File

@ -0,0 +1,229 @@
<script lang="ts">
import Demos from "$lib/components/Demos.svelte";
import DocsNav from "$lib/components/DocsNav.svelte";
import FunctionDoc from "$lib/components/FunctionDoc.svelte";
import MetaTags from "$lib/components/MetaTags.svelte";
import anchor from "$lib/assets/img/anchor.svg";
import { onDestroy } from "svelte";
import { page } from "$app/stores";
export let data: any;
let obj = data.obj;
let components = data.components;
let helpers = data.helpers;
let modals = data.modals;
let routes = data.routes;
let headers = data.headers;
let method_headers = data.method_headers;
let py_client = data.py_client;
let current_selection = 0;
let y: number;
let header_targets: { [key: string]: HTMLElement } = {};
let target_elem: HTMLElement;
onDestroy(() => {
header_targets = {};
});
$: for (const target in header_targets) {
target_elem = document.querySelector(`#${target}`) as HTMLElement;
if (
y > target_elem?.offsetTop - 50 &&
y < target_elem?.offsetTop + target_elem?.offsetHeight
) {
header_targets[target]?.classList.add("current-nav-link");
} else {
header_targets[target]?.classList.remove("current-nav-link");
}
}
let on_main: boolean;
let wheel: string = data.wheel;
$: on_main = data.on_main;
$: components = data.components;
$: helpers = data.helpers;
$: modals = data.modals;
$: routes = data.routes;
$: py_client = data.py_client;
</script>
<MetaTags
title={"Gradio No Reload Docs"}
url={$page.url.pathname}
canonical={$page.url.pathname}
description={obj.description}
/>
<svelte:window bind:scrollY={y} />
<main class="container mx-auto px-4 flex gap-4">
<div class="flex w-full">
<DocsNav
current_nav_link={"no-reload"}
{components}
{helpers}
{modals}
{routes}
{py_client}
/>
<div class="flex flex-col w-full min-w-full lg:w-8/12 lg:min-w-0">
<div>
<p
class="lg:ml-10 bg-gradient-to-r from-orange-100 to-orange-50 border border-orange-200 px-4 py-1 mr-2 rounded-full text-orange-800 mb-1 w-fit float-left"
>
New to Gradio? Start here: <a class="link" href="/quickstart"
>Getting Started</a
>
</p>
<p
class="bg-gradient-to-r from-green-100 to-green-50 border border-green-200 px-4 py-1 rounded-full text-green-800 mb-1 w-fit float-left sm:float-right"
>
See the <a class="link" href="/changelog">Release History</a>
</p>
</div>
{#if on_main}
<div
class="bg-gray-100 border border-gray-200 text-gray-800 px-3 py-1 mt-4 rounded-lg lg:ml-10"
>
<p class="my-2">
To install Gradio from main, run the following command:
</p>
<div class="codeblock">
<pre class="language-bash" style="padding-right: 50px;"><code
class="language-bash">pip install {wheel}</code
></pre>
</div>
<p class="float-right text-sm">
*Note: Setting <code style="font-size: 0.85rem">share=True</code> in
<code style="font-size: 0.85rem">launch()</code> will not work.
</p>
</div>
{/if}
<div class="lg:ml-10 flex justify-between mt-4">
<a
href="./themes"
class="text-left px-4 py-1 bg-gray-50 rounded-full hover:underline"
>
<div class="text-lg">
<span class="text-orange-500">&#8592;</span> Themes
</div>
</a>
<a
href="./python-client"
class="text-right px-4 py-1 bg-gray-50 rounded-full hover:underline"
>
<div class="text-lg">
Python Client <span class="text-orange-500">&#8594;</span>
</div>
</a>
</div>
<div class="flex flex-row">
<div class="lg:ml-10">
<div class="obj" id={obj.slug}>
<div class="flex flex-row items-center justify-between">
<h3 id="{obj.slug}-header" class="group text-3xl font-light py-4">
{obj.name}
<a
href="#{obj.slug}-header"
class="invisible group-hover-visible"
><img class="anchor-img" src={anchor} /></a
>
</h3>
</div>
<div class="codeblock">
<pre><code class="code language-python"
>{obj.override_signature}</code
></pre>
</div>
<h4
class="mt-8 text-xl text-orange-500 font-light group"
id="description"
>
Description
<a href="#description" class="invisible group-hover-visible"
><img class="anchor-img-small" src={anchor} /></a
>
</h4>
<p class="mb-2 text-lg text-gray-600">{@html obj.description}</p>
{#if obj.example}
<h4
class="mt-4 text-xl text-orange-500 font-light group"
id="example-usage"
>
Example Usage
<a href="#example-usage" class="invisible group-hover-visible"
><img class="anchor-img-small" src={anchor} /></a
>
</h4>
<div class="codeblock">
<pre><code class="code language-python"
>{@html obj.highlighted_example}</code
></pre>
</div>
{/if}
</div>
</div>
</div>
<div class="lg:ml-10 flex justify-between my-4">
<a
href="./themes"
class="text-left px-4 py-1 bg-gray-50 rounded-full hover:underline"
>
<div class="text-lg">
<span class="text-orange-500">&#8592;</span> themes
</div>
</a>
<a
href="./python-client"
class="text-right px-4 py-1 bg-gray-50 rounded-full hover:underline"
>
<div class="text-lg">
Python Client <span class="text-orange-500">&#8594;</span>
</div>
</a>
</div>
</div>
<div
class="float-right top-8 hidden sticky h-screen overflow-y-auto lg:block lg:w-2/12"
>
<div class="mx-8">
<a class="thin-link py-2 block text-lg" href="#no-reload">NO RELOAD</a>
{#if headers.length > 0}
<ul class="text-slate-700 text-lg leading-6">
{#each headers as header}
<li>
<a
bind:this={header_targets[header[1]]}
href="#{header[1]}"
class="thin-link block py-2 font-light second-nav-link"
>{header[0]}</a
>
</li>
{/each}
{#if method_headers.length > 0}
{#each method_headers as method_header}
<li class="ml-4">
<a
href="#{method_header[1]}"
class="thin-link block py-2 font-light second-nav-link"
>{method_header[0]}</a
>
</li>
{/each}
{/if}
</ul>
{/if}
</div>
</div>
</div>
</main>

View File

@ -76,11 +76,11 @@
<div class="lg:ml-10 flex justify-between mt-4">
<a
href="./themes"
href="./no-reload"
class="text-left px-4 py-1 bg-gray-50 rounded-full hover:underline"
>
<div class="text-lg">
<span class="text-orange-500">&#8592;</span> Themes
<span class="text-orange-500">&#8592;</span> NO_RELOAD
</div>
</a>
<a

View File

@ -118,11 +118,11 @@
</div>
</a>
<a
href="./python-client"
href="./no-reload"
class="text-right px-4 py-1 bg-gray-50 rounded-full hover:underline"
>
<div class="text-lg">
Python Client <span class="text-orange-500">&#8594;</span>
NO RELOAD <span class="text-orange-500">&#8594;</span>
</div>
</a>
</div>
@ -469,11 +469,11 @@
</div>
</a>
<a
href="./python-client"
href="./no-reload"
class="text-right px-4 py-1 bg-gray-50 rounded-full hover:underline"
>
<div class="text-lg">
Python Client <span class="text-orange-500">&#8594;</span>
NO RELOAD <span class="text-orange-500">&#8594;</span>
</div>
</a>
</div>

View File

@ -16,8 +16,9 @@
"build:lite": "pnpm pybuild && pnpm cssbuild && pnpm --filter @gradio/client build && pnpm --filter @gradio/wasm build && vite build --mode production:lite",
"preview": "vite preview",
"test:snapshot": "pnpm exec playwright test snapshots/ --config=../../.config/playwright.config.js",
"test:browser": "pnpm exec playwright test test/ --config=../../.config/playwright.config.js",
"test:browser": "pnpm exec playwright test test/ --grep-invert 'reload.spec.ts' --config=../../.config/playwright.config.js",
"test:browser:dev": "pnpm exec playwright test test/ --ui --config=../../.config/playwright.config.js",
"test:browser:reload": "pnpm exec playwright test test/ --grep 'reload.spec.ts' --config=../../.config/playwright.config.js",
"test:browser:lite": "GRADIO_E2E_TEST_LITE=1 pnpm test:browser",
"test:browser:lite:dev": "GRADIO_E2E_TEST_LITE=1 pnpm test:browser:dev",
"build:css": "pollen -c pollen.config.cjs -o src/pollen-dev.css"

View File

@ -0,0 +1,84 @@
import { test, expect } from "@playwright/test";
import { spawnSync } from "node:child_process";
import { launch_app_background, kill_process } from "./utils";
import { join } from "path";
let _process;
test.beforeAll(() => {
const demo = `
import gradio as gr
def greet(name):
return "Hello " + name + "!"
demo = gr.Interface(fn=greet, inputs="text", outputs="text")
if __name__ == "__main__":
demo.launch()
`;
// write contents of demo to a local 'run.py' file
spawnSync(`echo '${demo}' > ${join(process.cwd(), "run.py")}`, {
shell: true,
stdio: "pipe",
env: {
...process.env,
PYTHONUNBUFFERED: "true"
}
});
});
test.afterAll(() => {
if (_process) kill_process(_process);
spawnSync(`rm ${join(process.cwd(), "run.py")}`, {
shell: true,
stdio: "pipe",
env: {
...process.env,
PYTHONUNBUFFERED: "true"
}
});
});
test("gradio dev mode correctly reloads the page", async ({ page }) => {
test.setTimeout(20 * 1000);
try {
const port = 7880;
const { _process: server_process } = await launch_app_background(
`GRADIO_SERVER_PORT=${port} gradio ${join(process.cwd(), "run.py")}`,
process.cwd()
);
_process = server_process;
console.log("Connected to port", port);
const demo = `
import gradio as gr
def greet(name):
return "Hello " + name + "!"
demo = gr.Interface(fn=greet, inputs=gr.Textbox(label="x"), outputs=gr.Textbox(label="foo"))
if __name__ == "__main__":
demo.launch()
`;
// write contents of demo to a local 'run.py' file
await page.goto(`http://localhost:${port}`);
spawnSync(`echo '${demo}' > ${join(process.cwd(), "run.py")}`, {
shell: true,
stdio: "pipe",
env: {
...process.env,
PYTHONUNBUFFERED: "true"
}
});
//await page.reload();
await page.getByLabel("x").fill("Maria");
await page.getByRole("button", { name: "Submit" }).click();
await expect(page.getByLabel("foo")).toHaveValue("Hello Maria!");
} finally {
if (_process) kill_process(_process);
}
});

61
js/app/test/utils.ts Normal file
View File

@ -0,0 +1,61 @@
import { spawn } from "node:child_process";
import type { ChildProcess } from "child_process";
export function kill_process(process: ChildProcess) {
process.kill("SIGKILL");
}
type LaunchAppBackgroundReturn = {
port: number;
_process: ChildProcess;
};
export const launch_app_background = async (
command: string,
cwd?: string
): Promise<LaunchAppBackgroundReturn> => {
const _process = spawn(command, {
shell: true,
stdio: "pipe",
cwd: cwd || process.cwd(),
env: {
...process.env,
PYTHONUNBUFFERED: "true"
}
});
_process.stdout.setEncoding("utf8");
_process.stderr.setEncoding("utf8");
_process.on("exit", () => kill_process(_process));
_process.on("close", () => kill_process(_process));
_process.on("disconnect", () => kill_process(_process));
let port;
function std_out(data: any) {
const _data: string = data.toString();
console.log(_data);
const portRegExp = /:(\d+)/;
const match = portRegExp.exec(_data);
if (match && match[1] && _data.includes("Running on local URL:")) {
port = parseInt(match[1], 10);
}
}
function std_err(data: any) {
const _data: string = data.toString();
console.log(_data);
}
_process.stdout.on("data", std_out);
_process.stderr.on("data", std_err);
while (!port) {
await new Promise((r) => setTimeout(r, 1000));
}
return { port: port, _process: _process };
};

View File

@ -17,6 +17,7 @@
"test:run": "pnpm --filter @gradio/client build && vitest run --config .config/vitest.config.ts --reporter=verbose",
"test:node": "TEST_MODE=node pnpm vitest run --config .config/vitest.config.ts",
"test:browser": "pnpm --filter @gradio/app test:browser",
"test:browser:reload": "CUSTOM_TEST=1 pnpm --filter @gradio/app test:browser:reload",
"test:browser:full": "run-s build test:browser",
"test:browser:verbose": "pnpm test:browser",
"test:browser:dev": "pnpm --filter @gradio/app test:browser:dev",