mirror of
https://github.com/gradio-app/gradio.git
synced 2025-01-30 11:00:11 +08:00
Faster reload mode (#5267)
* This works * Add code * Final touches * Lint * Fix bug in other dirs * add changeset * Reload * lint + test * Load from frontend * add changeset * Use key * tweak frontend config generation * tweak * WIP ipython * Fix robust * fix * Fix for jupyter notebook * Add checks * Lint frontend * Undo demo changes * add changeset * Use is_in_or_equal * python 3.11 changes and no if __name__ * Forward sys.argv + guide * lint --------- Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com> Co-authored-by: pngwn <hello@pngwn.io> Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
This commit is contained in:
parent
63b7a3c85e
commit
119c834331
7
.changeset/curvy-signs-pump.md
Normal file
7
.changeset/curvy-signs-pump.md
Normal file
@ -0,0 +1,7 @@
|
||||
---
|
||||
"@gradio/app": minor
|
||||
"@gradio/client": minor
|
||||
"gradio": minor
|
||||
---
|
||||
|
||||
feat:Faster reload mode
|
1
.vscode/settings.json
vendored
1
.vscode/settings.json
vendored
@ -9,6 +9,7 @@
|
||||
"svelte.plugin.svelte.diagnostics.enable": false,
|
||||
"prettier.configPath": ".config/.prettierrc.json",
|
||||
"prettier.ignorePath": ".config/.prettierignore",
|
||||
"python.analysis.typeCheckingMode": "basic",
|
||||
"python.testing.pytestArgs": ["."],
|
||||
"python.testing.unittestEnabled": false,
|
||||
"python.testing.pytestEnabled": true,
|
||||
|
@ -1133,7 +1133,8 @@ async function resolve_config(
|
||||
if (
|
||||
typeof window !== "undefined" &&
|
||||
window.gradio_config &&
|
||||
location.origin !== "http://localhost:9876"
|
||||
location.origin !== "http://localhost:9876" &&
|
||||
!window.gradio_config.dev_mode
|
||||
) {
|
||||
const path = window.gradio_config.root;
|
||||
const config = window.gradio_config;
|
||||
|
@ -741,7 +741,7 @@ class Blocks(BlockContext):
|
||||
self.space_id = utils.get_space()
|
||||
self.favicon_path = None
|
||||
self.auth = None
|
||||
self.dev_mode = True
|
||||
self.dev_mode = bool(os.getenv("GRADIO_WATCH_DIRS", False))
|
||||
self.app_id = random.getrandbits(64)
|
||||
self.temp_file_sets = []
|
||||
self.title = title
|
||||
@ -775,6 +775,12 @@ class Blocks(BlockContext):
|
||||
}
|
||||
analytics.initiated_analytics(data)
|
||||
|
||||
@property
|
||||
def _is_running_in_reload_thread(self):
|
||||
from gradio.reload import reload_thread
|
||||
|
||||
return getattr(reload_thread, "running_reload", False)
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls,
|
||||
@ -1465,6 +1471,7 @@ Received outputs:
|
||||
config = {
|
||||
"version": routes.VERSION,
|
||||
"mode": self.mode,
|
||||
"app_id": self.app_id,
|
||||
"dev_mode": self.dev_mode,
|
||||
"analytics_enabled": self.analytics_enabled,
|
||||
"components": [],
|
||||
@ -1796,10 +1803,13 @@ Received outputs:
|
||||
demo = gr.Interface(reverse, "text", "text")
|
||||
demo.launch(share=True, auth=("username", "password"))
|
||||
"""
|
||||
if self._is_running_in_reload_thread:
|
||||
# We have already launched the demo
|
||||
return None, None, None # type: ignore
|
||||
|
||||
if not self.exited:
|
||||
self.__exit__()
|
||||
|
||||
self.dev_mode = False
|
||||
if (
|
||||
auth
|
||||
and not callable(auth)
|
||||
@ -2033,11 +2043,10 @@ Received outputs:
|
||||
if self.share and self.share_url:
|
||||
while not networking.url_ok(self.share_url):
|
||||
time.sleep(0.25)
|
||||
display(
|
||||
HTML(
|
||||
f'<div><iframe src="{self.share_url}" width="{self.width}" height="{self.height}" allow="autoplay; camera; microphone; clipboard-read; clipboard-write;" frameborder="0" allowfullscreen></iframe></div>'
|
||||
)
|
||||
artifact = HTML(
|
||||
f'<div><iframe src="{self.share_url}" width="{self.width}" height="{self.height}" allow="autoplay; camera; microphone; clipboard-read; clipboard-write;" frameborder="0" allowfullscreen></iframe></div>'
|
||||
)
|
||||
|
||||
elif self.is_colab:
|
||||
# modified from /usr/local/lib/python3.7/dist-packages/google/colab/output/_util.py within Colab environment
|
||||
code = """(async (port, path, width, height, cache, element) => {
|
||||
@ -2072,13 +2081,13 @@ Received outputs:
|
||||
cache=json.dumps(False),
|
||||
)
|
||||
|
||||
display(Javascript(code))
|
||||
artifact = Javascript(code)
|
||||
else:
|
||||
display(
|
||||
HTML(
|
||||
f'<div><iframe src="{self.local_url}" width="{self.width}" height="{self.height}" allow="autoplay; camera; microphone; clipboard-read; clipboard-write;" frameborder="0" allowfullscreen></iframe></div>'
|
||||
)
|
||||
artifact = HTML(
|
||||
f'<div><iframe src="{self.local_url}" width="{self.width}" height="{self.height}" allow="autoplay; camera; microphone; clipboard-read; clipboard-write;" frameborder="0" allowfullscreen></iframe></div>'
|
||||
)
|
||||
self.artifact = artifact
|
||||
display(artifact)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
@ -29,6 +29,12 @@ class InvalidBlockError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class ReloadError(ValueError):
|
||||
"""Raised when something goes wrong when reloading the gradio app."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
InvalidApiName = InvalidApiNameError # backwards compatibility
|
||||
|
||||
|
||||
|
@ -1,23 +1,89 @@
|
||||
try:
|
||||
from IPython.core.magic import needs_local_scope, register_cell_magic
|
||||
from IPython.core.magic import (
|
||||
needs_local_scope,
|
||||
register_cell_magic,
|
||||
)
|
||||
from IPython.core.magic_arguments import argument, magic_arguments, parse_argstring
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
import warnings
|
||||
|
||||
import gradio as gr
|
||||
from gradio.networking import App
|
||||
from gradio.utils import BaseReloader
|
||||
|
||||
|
||||
class CellIdTracker:
|
||||
"""Determines the most recently run cell in the notebook.
|
||||
|
||||
Needed to keep track of which demo the user is updating.
|
||||
"""
|
||||
|
||||
def __init__(self, ipython):
|
||||
ipython.events.register("pre_run_cell", self.pre_run_cell)
|
||||
self.shell = ipython
|
||||
self.current_cell: str = ""
|
||||
|
||||
def pre_run_cell(self, info):
|
||||
self._current_cell = info.cell_id
|
||||
|
||||
|
||||
class JupyterReloader(BaseReloader):
|
||||
"""Swap a running blocks class in a notebook with the latest cell contents."""
|
||||
|
||||
def __init__(self, ipython) -> None:
|
||||
super().__init__()
|
||||
self._cell_tracker = CellIdTracker(ipython)
|
||||
self._running: dict[str, gr.Blocks] = {}
|
||||
|
||||
@property
|
||||
def current_cell(self):
|
||||
return self._cell_tracker.current_cell
|
||||
|
||||
@property
|
||||
def running_app(self) -> App:
|
||||
assert self.running_demo.server
|
||||
return self.running_demo.server.running_app
|
||||
|
||||
@property
|
||||
def running_demo(self):
|
||||
return self._running[self.current_cell]
|
||||
|
||||
def demo_tracked(self) -> bool:
|
||||
return self.current_cell in self._running
|
||||
|
||||
def track(self, demo: gr.Blocks):
|
||||
self._running[self.current_cell] = demo
|
||||
|
||||
|
||||
def load_ipython_extension(ipython):
|
||||
__demo = gr.Blocks()
|
||||
reloader = JupyterReloader(ipython)
|
||||
|
||||
@magic_arguments()
|
||||
@argument("--demo-name", default="demo", help="Name of gradio blocks instance.")
|
||||
@argument(
|
||||
"--share",
|
||||
default=False,
|
||||
const=True,
|
||||
nargs="?",
|
||||
help="Whether to launch with sharing. Will slow down reloading.",
|
||||
)
|
||||
@register_cell_magic
|
||||
@needs_local_scope
|
||||
def blocks(line, cell, local_ns=None):
|
||||
if "gr.Interface" in cell:
|
||||
warnings.warn(
|
||||
"Usage of gradio.Interface with %%blocks may result in errors."
|
||||
)
|
||||
with __demo.clear():
|
||||
exec(cell, None, local_ns)
|
||||
__demo.launch(quiet=True)
|
||||
def blocks(line, cell, local_ns):
|
||||
"""Launch a demo defined in a cell in reload mode."""
|
||||
|
||||
args = parse_argstring(blocks, line)
|
||||
|
||||
exec(cell, None, local_ns)
|
||||
demo: gr.Blocks = local_ns[args.demo_name]
|
||||
if not reloader.demo_tracked():
|
||||
demo.launch(share=args.share)
|
||||
reloader.track(demo)
|
||||
elif reloader.queue_changed(demo):
|
||||
print("Queue got added or removed. Restarting demo.")
|
||||
reloader.running_demo.close()
|
||||
demo.launch()
|
||||
reloader.track(demo)
|
||||
else:
|
||||
reloader.swap_blocks(demo)
|
||||
return reloader.running_demo.artifact
|
||||
|
@ -9,14 +9,17 @@ import socket
|
||||
import threading
|
||||
import time
|
||||
import warnings
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import requests
|
||||
import uvicorn
|
||||
from uvicorn.config import Config
|
||||
|
||||
from gradio.exceptions import ServerFailedToStartError
|
||||
from gradio.routes import App
|
||||
from gradio.tunneling import Tunnel
|
||||
from gradio.utils import SourceFileReloader, watchfn
|
||||
|
||||
if TYPE_CHECKING: # Only import for type checking (to avoid circular imports).
|
||||
from gradio.blocks import Blocks
|
||||
@ -28,13 +31,34 @@ TRY_NUM_PORTS = int(os.getenv("GRADIO_NUM_PORTS", "100"))
|
||||
LOCALHOST_NAME = os.getenv("GRADIO_SERVER_NAME", "127.0.0.1")
|
||||
GRADIO_API_SERVER = "https://api.gradio.app/v2/tunnel-request"
|
||||
|
||||
should_watch = bool(os.getenv("GRADIO_WATCH_DIRS", False))
|
||||
GRADIO_WATCH_DIRS = (
|
||||
os.getenv("GRADIO_WATCH_DIRS", "").split(",") if should_watch else []
|
||||
)
|
||||
GRADIO_WATCH_FILE = os.getenv("GRADIO_WATCH_FILE", "app")
|
||||
GRADIO_WATCH_DEMO_NAME = os.getenv("GRADIO_WATCH_DEMO_NAME", "demo")
|
||||
|
||||
|
||||
class Server(uvicorn.Server):
|
||||
def __init__(
|
||||
self, config: Config, reloader: SourceFileReloader | None = None
|
||||
) -> None:
|
||||
assert isinstance(config.app, App)
|
||||
self.running_app = config.app
|
||||
super().__init__(config)
|
||||
self.reloader = reloader
|
||||
if self.reloader:
|
||||
self.event = threading.Event()
|
||||
self.watch = partial(watchfn, self.reloader)
|
||||
|
||||
def install_signal_handlers(self):
|
||||
pass
|
||||
|
||||
def run_in_thread(self):
|
||||
self.thread = threading.Thread(target=self.run, daemon=True)
|
||||
if self.reloader:
|
||||
self.watch_thread = threading.Thread(target=self.watch, daemon=True)
|
||||
self.watch_thread.start()
|
||||
self.thread.start()
|
||||
start = time.time()
|
||||
while not self.started:
|
||||
@ -46,6 +70,9 @@ class Server(uvicorn.Server):
|
||||
|
||||
def close(self):
|
||||
self.should_exit = True
|
||||
if self.reloader:
|
||||
self.reloader.stop()
|
||||
self.watch_thread.join()
|
||||
self.thread.join()
|
||||
|
||||
|
||||
@ -160,7 +187,19 @@ def start_server(
|
||||
ssl_keyfile_password=ssl_keyfile_password,
|
||||
ws_max_size=1024 * 1024 * 1024, # Setting max websocket size to be 1 GB
|
||||
)
|
||||
server = Server(config=config)
|
||||
reloader = None
|
||||
if GRADIO_WATCH_DIRS:
|
||||
change_event = threading.Event()
|
||||
app.change_event = change_event
|
||||
reloader = SourceFileReloader(
|
||||
app=app,
|
||||
watch_dirs=GRADIO_WATCH_DIRS,
|
||||
watch_file=GRADIO_WATCH_FILE,
|
||||
demo_name=GRADIO_WATCH_DEMO_NAME,
|
||||
stop_event=threading.Event(),
|
||||
change_event=change_event,
|
||||
)
|
||||
server = Server(config=config, reloader=reloader)
|
||||
server.run_in_thread()
|
||||
break
|
||||
except (OSError, ServerFailedToStartError):
|
||||
|
@ -19,7 +19,12 @@ from gradio.data_classes import (
|
||||
ProgressUnit,
|
||||
)
|
||||
from gradio.helpers import TrackedIterable
|
||||
from gradio.utils import AsyncRequest, run_coro_in_background, set_task_name
|
||||
from gradio.utils import (
|
||||
AsyncRequest,
|
||||
run_coro_in_background,
|
||||
safe_get_lock,
|
||||
set_task_name,
|
||||
)
|
||||
|
||||
|
||||
class Event:
|
||||
@ -59,7 +64,7 @@ class Queue:
|
||||
self.max_thread_count = concurrency_count
|
||||
self.update_intervals = update_intervals
|
||||
self.active_jobs: list[None | list[Event]] = [None] * concurrency_count
|
||||
self.delete_lock = asyncio.Lock()
|
||||
self.delete_lock = safe_get_lock()
|
||||
self.server_path = None
|
||||
self.duration_history_total = 0
|
||||
self.duration_history_count = 0
|
||||
|
@ -3,18 +3,20 @@
|
||||
Contains the functions that run when `gradio` is called from the command line. Specifically, allows
|
||||
|
||||
$ gradio app.py, to run app.py in reload mode where any changes in the app.py file or Gradio library reloads the demo.
|
||||
$ gradio app.py my_demo.app, to use variable names other than "demo"
|
||||
$ gradio app.py my_demo, to use variable names other than "demo"
|
||||
"""
|
||||
import inspect
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
import threading
|
||||
from pathlib import Path
|
||||
|
||||
from uvicorn import Config
|
||||
from uvicorn.supervisors import ChangeReload
|
||||
|
||||
import gradio
|
||||
from gradio import networking, utils
|
||||
from gradio import utils
|
||||
|
||||
reload_thread = threading.local()
|
||||
|
||||
|
||||
def _setup_config():
|
||||
@ -22,15 +24,34 @@ def _setup_config():
|
||||
if len(args) == 0:
|
||||
raise ValueError("No file specified.")
|
||||
if len(args) == 1 or args[1].startswith("--"):
|
||||
demo_name = "demo.app"
|
||||
demo_name = "demo"
|
||||
else:
|
||||
demo_name = args[1]
|
||||
if "." not in demo_name:
|
||||
if "." in demo_name:
|
||||
demo_name = demo_name.split(".")[0]
|
||||
print(
|
||||
"\nWARNING: As of Gradio 3.31, the parameter after the file path must be the name of the FastAPI app, not the Gradio demo. In most cases, this just means you should add '.app' after the name of your demo, e.g. 'demo' -> 'demo.app'."
|
||||
"\nWARNING: As of Gradio 3.41.0, the parameter after the file path must be the name of the Gradio demo, not the FastAPI app. In most cases, this just means you should remove '.app' after the name of your demo, e.g. 'demo.app' -> 'demo'."
|
||||
)
|
||||
|
||||
original_path = args[0]
|
||||
app_text = Path(original_path).read_text()
|
||||
|
||||
patterns = [
|
||||
f"with gr\\.Blocks\\(\\) as {demo_name}",
|
||||
f"{demo_name} = gr\\.Blocks",
|
||||
f"{demo_name} = gr\\.Interface",
|
||||
f"{demo_name} = gr\\.ChatInterface",
|
||||
f"{demo_name} = gr\\.Series",
|
||||
f"{demo_name} = gr\\.Paralles",
|
||||
f"{demo_name} = gr\\.TabbedInterface",
|
||||
]
|
||||
|
||||
if not any(re.search(p, app_text) for p in patterns):
|
||||
print(
|
||||
f"\nWarning: Cannot statically find a gradio demo called {demo_name}. "
|
||||
"Reload work may fail."
|
||||
)
|
||||
|
||||
abs_original_path = utils.abspath(original_path)
|
||||
path = os.path.normpath(original_path)
|
||||
path = path.replace("/", ".")
|
||||
@ -39,15 +60,6 @@ def _setup_config():
|
||||
|
||||
gradio_folder = Path(inspect.getfile(gradio)).parent
|
||||
|
||||
port = networking.get_first_available_port(
|
||||
networking.INITIAL_PORT_VALUE,
|
||||
networking.INITIAL_PORT_VALUE + networking.TRY_NUM_PORTS,
|
||||
)
|
||||
print(
|
||||
f"\nLaunching in *reload mode* on: http://{networking.LOCALHOST_NAME}:{port} (Press CTRL+C to quit)\n"
|
||||
)
|
||||
|
||||
gradio_app = f"{filename}:{demo_name}"
|
||||
message = "Watching:"
|
||||
message_change_count = 0
|
||||
|
||||
@ -68,23 +80,24 @@ def _setup_config():
|
||||
|
||||
# guaranty access to the module of an app
|
||||
sys.path.insert(0, os.getcwd())
|
||||
|
||||
# uvicorn.run blocks the execution (looping) which makes it hard to test
|
||||
return Config(
|
||||
gradio_app,
|
||||
reload=True,
|
||||
port=port,
|
||||
log_level="warning",
|
||||
reload_dirs=watching_dirs,
|
||||
)
|
||||
return filename, abs_original_path, [str(s) for s in watching_dirs], demo_name
|
||||
|
||||
|
||||
def main():
|
||||
# default execution pattern to start the server and watch changes
|
||||
config = _setup_config()
|
||||
server = networking.Server(config)
|
||||
sock = config.bind_socket()
|
||||
ChangeReload(config, target=server.run, sockets=[sock]).run()
|
||||
filename, path, watch_dirs, demo_name = _setup_config()
|
||||
args = sys.argv[1:]
|
||||
extra_args = args[1:] if len(args) == 1 or args[1].startswith("--") else args[2:]
|
||||
popen = subprocess.Popen(
|
||||
["python", path] + extra_args,
|
||||
env=dict(
|
||||
os.environ,
|
||||
GRADIO_WATCH_DIRS=",".join(watch_dirs),
|
||||
GRADIO_WATCH_FILE=filename,
|
||||
GRADIO_WATCH_DEMO_NAME=demo_name,
|
||||
),
|
||||
)
|
||||
popen.wait()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -17,6 +17,7 @@ import os
|
||||
import posixpath
|
||||
import secrets
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
from asyncio import TimeoutError as AsyncTimeOutError
|
||||
@ -112,12 +113,13 @@ class App(FastAPI):
|
||||
self.state_holder = {}
|
||||
self.iterators = defaultdict(dict)
|
||||
self.iterators_to_reset = defaultdict(set)
|
||||
self.lock = asyncio.Lock()
|
||||
self.lock = utils.safe_get_lock()
|
||||
self.queue_token = secrets.token_urlsafe(32)
|
||||
self.startup_events_triggered = False
|
||||
self.uploaded_file_dir = os.environ.get("GRADIO_TEMP_DIR") or str(
|
||||
Path(tempfile.gettempdir()) / "gradio"
|
||||
)
|
||||
self.change_event: None | threading.Event = None
|
||||
# Allow user to manually set `docs_url` and `redoc_url`
|
||||
# when instantiating an App; when they're not set, disable docs and redoc.
|
||||
kwargs.setdefault("docs_url", None)
|
||||
@ -216,6 +218,39 @@ class App(FastAPI):
|
||||
def app_id(request: fastapi.Request) -> dict:
|
||||
return {"app_id": app.get_blocks().app_id}
|
||||
|
||||
async def send_ping_periodically(websocket: WebSocket):
|
||||
while True:
|
||||
await websocket.send_text("PING")
|
||||
await asyncio.sleep(1)
|
||||
|
||||
async def listen_for_changes(websocket: WebSocket):
|
||||
assert app.change_event
|
||||
while True:
|
||||
if app.change_event.is_set():
|
||||
await websocket.send_text("CHANGE")
|
||||
app.change_event.clear()
|
||||
await asyncio.sleep(0.1) # Short sleep to not make this a tight loop
|
||||
|
||||
@app.websocket("/dev/reload")
|
||||
async def notify_changes(websocket: WebSocket):
|
||||
await websocket.accept()
|
||||
|
||||
ping = asyncio.create_task(send_ping_periodically(websocket))
|
||||
notify = asyncio.create_task(listen_for_changes(websocket))
|
||||
tasks = {ping, notify}
|
||||
ping.add_done_callback(tasks.remove)
|
||||
notify.add_done_callback(tasks.remove)
|
||||
done, pending = await asyncio.wait(
|
||||
[ping, notify],
|
||||
return_when=asyncio.FIRST_COMPLETED,
|
||||
)
|
||||
|
||||
for task in pending:
|
||||
task.cancel()
|
||||
|
||||
if any(isinstance(task.exception(), Exception) for task in done):
|
||||
await websocket.close()
|
||||
|
||||
@app.post("/login")
|
||||
@app.post("/login/")
|
||||
def login(form_data: OAuth2PasswordRequestForm = Depends()):
|
||||
|
158
gradio/utils.py
158
gradio/utils.py
@ -5,6 +5,7 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import copy
|
||||
import functools
|
||||
import importlib
|
||||
import inspect
|
||||
import json
|
||||
import json.decoder
|
||||
@ -13,9 +14,12 @@ import pkgutil
|
||||
import pprint
|
||||
import random
|
||||
import re
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
import typing
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from contextlib import contextmanager
|
||||
from enum import Enum
|
||||
from io import BytesIO
|
||||
@ -27,6 +31,7 @@ from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Generator,
|
||||
Iterator,
|
||||
Optional,
|
||||
TypeVar,
|
||||
Union,
|
||||
@ -47,6 +52,7 @@ from gradio.strings import en
|
||||
if TYPE_CHECKING: # Only import for type checking (is False at runtime).
|
||||
from gradio.blocks import Block, BlockContext, Blocks
|
||||
from gradio.components import Component
|
||||
from gradio.routes import App
|
||||
|
||||
JSON_PATH = os.path.join(os.path.dirname(gradio.__file__), "launches.json")
|
||||
GRADIO_VERSION = (
|
||||
@ -57,6 +63,158 @@ P = ParamSpec("P")
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def safe_get_lock() -> asyncio.Lock:
|
||||
"""Get asyncio.Lock() without fear of getting an Exception.
|
||||
|
||||
Needed because in reload mode we import the Blocks object outside
|
||||
the main thread.
|
||||
"""
|
||||
try:
|
||||
asyncio.get_event_loop()
|
||||
return asyncio.Lock()
|
||||
except RuntimeError:
|
||||
return None # type: ignore
|
||||
|
||||
|
||||
class BaseReloader(ABC):
|
||||
@property
|
||||
@abstractmethod
|
||||
def running_app(self) -> App:
|
||||
pass
|
||||
|
||||
def queue_changed(self, demo: Blocks):
|
||||
return (
|
||||
hasattr(self.running_app.blocks, "_queue") and not hasattr(demo, "_queue")
|
||||
) or (
|
||||
not hasattr(self.running_app.blocks, "_queue") and hasattr(demo, "_queue")
|
||||
)
|
||||
|
||||
def swap_blocks(self, demo: Blocks):
|
||||
assert self.running_app.blocks
|
||||
# Copy over the blocks to get new components and events but
|
||||
# not a new queue
|
||||
if hasattr(self.running_app.blocks, "_queue"):
|
||||
self.running_app.blocks._queue.blocks_dependencies = demo.dependencies
|
||||
demo._queue = self.running_app.blocks._queue
|
||||
self.running_app.blocks = demo
|
||||
|
||||
|
||||
class SourceFileReloader(BaseReloader):
|
||||
def __init__(
|
||||
self,
|
||||
app: App,
|
||||
watch_dirs: list[str],
|
||||
watch_file: str,
|
||||
stop_event: threading.Event,
|
||||
change_event: threading.Event,
|
||||
demo_name: str = "demo",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.app = app
|
||||
self.watch_dirs = watch_dirs
|
||||
self.watch_file = watch_file
|
||||
self.stop_event = stop_event
|
||||
self.change_event = change_event
|
||||
self.demo_name = demo_name
|
||||
|
||||
@property
|
||||
def running_app(self) -> App:
|
||||
return self.app
|
||||
|
||||
def should_watch(self) -> bool:
|
||||
return not self.stop_event.is_set()
|
||||
|
||||
def stop(self) -> None:
|
||||
self.stop_event.set()
|
||||
|
||||
def alert_change(self):
|
||||
self.change_event.set()
|
||||
|
||||
def swap_blocks(self, demo: Blocks):
|
||||
super().swap_blocks(demo)
|
||||
self.alert_change()
|
||||
|
||||
|
||||
def watchfn(reloader: SourceFileReloader):
|
||||
"""Watch python files in a given module.
|
||||
|
||||
get_changes is taken from uvicorn's default file watcher.
|
||||
"""
|
||||
|
||||
# The thread running watchfn will be the thread reloading
|
||||
# the app. So we need to modify this thread_data attr here
|
||||
# so that subsequent calls to reload don't launch the app
|
||||
from gradio.reload import reload_thread
|
||||
|
||||
reload_thread.running_reload = True
|
||||
|
||||
def get_changes() -> Path | None:
|
||||
for file in iter_py_files():
|
||||
try:
|
||||
mtime = file.stat().st_mtime
|
||||
except OSError: # pragma: nocover
|
||||
continue
|
||||
|
||||
old_time = mtimes.get(file)
|
||||
if old_time is None:
|
||||
mtimes[file] = mtime
|
||||
continue
|
||||
elif mtime > old_time:
|
||||
return file
|
||||
return None
|
||||
|
||||
def iter_py_files() -> Iterator[Path]:
|
||||
for reload_dir in reload_dirs:
|
||||
for path in list(reload_dir.rglob("*.py")):
|
||||
yield path.resolve()
|
||||
|
||||
module = None
|
||||
reload_dirs = [Path(dir_) for dir_ in reloader.watch_dirs]
|
||||
mtimes = {}
|
||||
while reloader.should_watch():
|
||||
import sys
|
||||
|
||||
changed = get_changes()
|
||||
if changed:
|
||||
print(f"Changes detected in: {changed}")
|
||||
# To simulate a fresh reload, delete all module references from sys.modules
|
||||
# for the modules in the package the change came from.
|
||||
dir_ = next(d for d in reload_dirs if is_in_or_equal(changed, d))
|
||||
modules = list(sys.modules)
|
||||
for k in modules:
|
||||
v = sys.modules[k]
|
||||
sourcefile = getattr(v, "__file__", None)
|
||||
# Do not reload `reload.py` to keep thread data
|
||||
if (
|
||||
sourcefile
|
||||
and dir_ == Path(inspect.getfile(gradio)).parent
|
||||
and sourcefile.endswith("reload.py")
|
||||
):
|
||||
continue
|
||||
if sourcefile and is_in_or_equal(sourcefile, dir_):
|
||||
del sys.modules[k]
|
||||
try:
|
||||
module = importlib.import_module(reloader.watch_file)
|
||||
module = importlib.reload(module)
|
||||
except Exception as e:
|
||||
print(
|
||||
f"Reloading {reloader.watch_file} failed with the following exception: "
|
||||
)
|
||||
traceback.print_exception(None, value=e, tb=None)
|
||||
mtimes = {}
|
||||
continue
|
||||
|
||||
demo = getattr(module, reloader.demo_name)
|
||||
if reloader.queue_changed(demo):
|
||||
print(
|
||||
"Reloading failed. The new demo has a queue and the old one doesn't (or vice versa). "
|
||||
"Please launch your demo again"
|
||||
)
|
||||
else:
|
||||
reloader.swap_blocks(demo)
|
||||
mtimes = {}
|
||||
|
||||
|
||||
def colab_check() -> bool:
|
||||
"""
|
||||
Check if interface is launching from Google Colab
|
||||
|
@ -41,18 +41,16 @@ In the terminal, run `gradio run.py`. That's it!
|
||||
Now, you'll see that after you'll see something like this:
|
||||
|
||||
```bash
|
||||
Launching in *reload mode* on: http://127.0.0.1:7860 (Press CTRL+C to quit)
|
||||
Watching: '/Users/freddy/sources/gradio/gradio', '/Users/freddy/sources/gradio/demo/'
|
||||
|
||||
Watching...
|
||||
|
||||
WARNING: The --reload flag should not be used in production on Windows.
|
||||
Running on local URL: http://127.0.0.1:7860
|
||||
```
|
||||
|
||||
The important part here is the line that says `Watching...` What's happening here is that Gradio will be observing the directory where `run.py` file lives, and if the file changes, it will automatically rerun the file for you. So you can focus on writing your code, and your Gradio demo will refresh automatically 🥳
|
||||
|
||||
⚠️ Warning: the `gradio` command does not detect the parameters passed to the `launch()` methods because the `launch()` method is never called in reload mode. For example, setting `auth`, or `show_error` in `launch()` will not be reflected in the app.
|
||||
|
||||
There is one important thing to keep in mind when using the reload mode: Gradio specifically looks for a Gradio Blocks/Interface demo called `demo` in your code. If you have named your demo something else, you will need to pass in the name of your demo's FastAPI app as the 2nd parameter in your code. For Gradio demos, the FastAPI app can be accessed using the `.app` attribute. So if your `run.py` file looked like this:
|
||||
There is one important thing to keep in mind when using the reload mode: Gradio specifically looks for a Gradio Blocks/Interface demo called `demo` in your code. If you have named your demo something else, you will need to pass in the name of your demo as the 2nd parameter in your code. So if your `run.py` file looked like this:
|
||||
|
||||
```python
|
||||
import gradio as gr
|
||||
@ -70,7 +68,7 @@ if __name__ == "__main__":
|
||||
my_demo.launch()
|
||||
```
|
||||
|
||||
Then you would launch it in reload mode like this: `gradio run.py my_demo.app`.
|
||||
Then you would launch it in reload mode like this: `gradio run.py my_demo`.
|
||||
|
||||
🔥 If your application accepts command line arguments, you can pass them in as well. Here's an example:
|
||||
|
||||
@ -112,26 +110,25 @@ Then, in the cell that you are developing your Gradio demo, simply write the mag
|
||||
|
||||
import gradio as gr
|
||||
|
||||
gr.Markdown("# Greetings from Gradio!")
|
||||
inp = gr.Textbox(placeholder="What is your name?")
|
||||
out = gr.Textbox()
|
||||
with gr.Blocks() as demo:
|
||||
gr.Markdown(f"# Greetings {args.name}!")
|
||||
inp = gr.Textbox()
|
||||
out = gr.Textbox()
|
||||
|
||||
inp.change(fn=lambda x: f"Welcome, {x}!",
|
||||
inputs=inp,
|
||||
outputs=out)
|
||||
inp.change(fn=lambda x: x, inputs=inp, outputs=out)
|
||||
```
|
||||
|
||||
Notice that:
|
||||
|
||||
- You do not need to put the boiler plate `with gr.Blocks() as demo:` and `demo.launch()` code — Gradio does that for you automatically!
|
||||
- You do not need to launch your demo — Gradio does that for you automatically!
|
||||
|
||||
- Every time you rerun the cell, Gradio will re-launch your app on the same port and using the same underlying web server. This means you'll see your changes _much, much faster_ than if you were rerunning the cell normally.
|
||||
- Every time you rerun the cell, Gradio will re-render your app on the same port and using the same underlying web server. This means you'll see your changes _much, much faster_ than if you were rerunning the cell normally.
|
||||
|
||||
Here's what it looks like in a jupyter notebook:
|
||||
|
||||
![](https://i.ibb.co/nrszFws/Blocks.gif)
|
||||
![](https://gradio-builds.s3.amazonaws.com/demo-files/jupyter_reload.gif)
|
||||
|
||||
🪄 This works in colab notebooks too! [Here's a colab notebook](https://colab.research.google.com/drive/1jUlX1w7JqckRHVE-nbDyMPyZ7fYD8488?authuser=1#scrollTo=zxHYjbCTTz_5) where you can see the Blocks magic in action. Try making some changes and re-running the cell with the Gradio code!
|
||||
🪄 This works in colab notebooks too! [Here's a colab notebook](https://colab.research.google.com/drive/1zAuWoiTIb3O2oitbtVb2_ekv1K6ggtC1?usp=sharing) where you can see the Blocks magic in action. Try making some changes and re-running the cell with the Gradio code!
|
||||
|
||||
The Notebook Magic is now the author's preferred way of building Gradio demos. Regardless of how you write Python code, we hope either of these methods will give you a much better development experience using Gradio.
|
||||
|
||||
|
@ -52,12 +52,10 @@
|
||||
type: "column",
|
||||
props: { mode: "static" },
|
||||
has_modes: false,
|
||||
instance: {} as ComponentMeta["instance"],
|
||||
component: {} as ComponentMeta["component"]
|
||||
instance: null as unknown as ComponentMeta["instance"],
|
||||
component: null as unknown as ComponentMeta["component"]
|
||||
};
|
||||
|
||||
components.push(rootNode);
|
||||
|
||||
const AsyncFunction = Object.getPrototypeOf(async function () {}).constructor;
|
||||
dependencies.forEach((d) => {
|
||||
if (d.js) {
|
||||
@ -103,18 +101,7 @@
|
||||
return false;
|
||||
}
|
||||
|
||||
const dynamic_ids: Set<number> = new Set();
|
||||
for (const comp of components) {
|
||||
const { id, props } = comp;
|
||||
const is_input = is_dep(id, "inputs", dependencies);
|
||||
if (
|
||||
is_input ||
|
||||
(!is_dep(id, "outputs", dependencies) &&
|
||||
has_no_default_value(props?.value))
|
||||
) {
|
||||
dynamic_ids.add(id);
|
||||
}
|
||||
}
|
||||
let dynamic_ids: Set<number> = new Set();
|
||||
|
||||
function has_no_default_value(value: any): boolean {
|
||||
return (
|
||||
@ -125,13 +112,7 @@
|
||||
);
|
||||
}
|
||||
|
||||
let instance_map = components.reduce(
|
||||
(acc, next) => {
|
||||
acc[next.id] = next;
|
||||
return acc;
|
||||
},
|
||||
{} as { [id: number]: ComponentMeta }
|
||||
);
|
||||
let instance_map: { [id: number]: ComponentMeta };
|
||||
|
||||
type LoadedComponent = {
|
||||
default: ComponentMeta["component"];
|
||||
@ -172,58 +153,124 @@
|
||||
}
|
||||
}
|
||||
|
||||
const component_set = new Set<
|
||||
let component_set = new Set<
|
||||
Promise<{ name: ComponentMeta["type"]; component: LoadedComponent }>
|
||||
>();
|
||||
|
||||
const _component_map = new Map<
|
||||
let _component_map = new Map<
|
||||
`${ComponentMeta["type"]}_${ComponentMeta["props"]["mode"]}`,
|
||||
Promise<{ name: ComponentMeta["type"]; component: LoadedComponent }>
|
||||
>();
|
||||
const _type_for_id = new Map<number, ComponentMeta["props"]["mode"]>();
|
||||
|
||||
async function walk_layout(node: LayoutNode): Promise<void> {
|
||||
async function walk_layout(
|
||||
node: LayoutNode,
|
||||
type_map: Map<number, ComponentMeta["props"]["mode"]>,
|
||||
instance_map: { [id: number]: ComponentMeta },
|
||||
component_map: Map<
|
||||
`${ComponentMeta["type"]}_${ComponentMeta["props"]["mode"]}`,
|
||||
Promise<{ name: ComponentMeta["type"]; component: LoadedComponent }>
|
||||
>
|
||||
): Promise<void> {
|
||||
ready = false;
|
||||
let instance = instance_map[node.id];
|
||||
|
||||
const _component = (await _component_map.get(
|
||||
`${instance.type}_${_type_for_id.get(node.id) || "static"}`
|
||||
const _component = (await component_map.get(
|
||||
`${instance.type}_${type_map.get(node.id) || "static"}`
|
||||
))!.component;
|
||||
instance.component = _component.default;
|
||||
|
||||
if (node.children) {
|
||||
instance.children = node.children.map((v) => instance_map[v.id]);
|
||||
await Promise.all(node.children.map((v) => walk_layout(v)));
|
||||
await Promise.all(
|
||||
node.children.map((v) =>
|
||||
walk_layout(v, type_map, instance_map, component_map)
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
components.forEach((c) => {
|
||||
if ((c.props as any).interactive === false) {
|
||||
(c.props as any).mode = "static";
|
||||
} else if ((c.props as any).interactive === true) {
|
||||
(c.props as any).mode = "interactive";
|
||||
} else if (dynamic_ids.has(c.id)) {
|
||||
(c.props as any).mode = "interactive";
|
||||
} else {
|
||||
(c.props as any).mode = "static";
|
||||
}
|
||||
_type_for_id.set(c.id, c.props.mode);
|
||||
|
||||
const _c = load_component(c.type, c.props.mode);
|
||||
component_set.add(_c);
|
||||
_component_map.set(`${c.type}_${c.props.mode}`, _c);
|
||||
});
|
||||
|
||||
export let ready = false;
|
||||
export let render_complete = false;
|
||||
Promise.all(Array.from(component_set)).then(() => {
|
||||
walk_layout(layout)
|
||||
.then(async () => {
|
||||
ready = true;
|
||||
})
|
||||
.catch((e) => {
|
||||
console.error(e);
|
||||
});
|
||||
});
|
||||
|
||||
$: components, layout, prepare_components();
|
||||
|
||||
function prepare_components(): void {
|
||||
loading_status = create_loading_status_store();
|
||||
|
||||
dependencies.forEach((v, i) => {
|
||||
loading_status.register(i, v.inputs, v.outputs);
|
||||
});
|
||||
|
||||
const _dynamic_ids = new Set<number>();
|
||||
for (const comp of components) {
|
||||
const { id, props } = comp;
|
||||
const is_input = is_dep(id, "inputs", dependencies);
|
||||
if (
|
||||
is_input ||
|
||||
(!is_dep(id, "outputs", dependencies) &&
|
||||
has_no_default_value(props?.value))
|
||||
) {
|
||||
_dynamic_ids.add(id);
|
||||
}
|
||||
}
|
||||
|
||||
dynamic_ids = _dynamic_ids;
|
||||
|
||||
const _rootNode: typeof rootNode = {
|
||||
id: layout.id,
|
||||
type: "column",
|
||||
props: { mode: "static" },
|
||||
has_modes: false,
|
||||
instance: null as unknown as ComponentMeta["instance"],
|
||||
component: null as unknown as ComponentMeta["component"]
|
||||
};
|
||||
components.push(_rootNode);
|
||||
const _component_set = new Set<
|
||||
Promise<{ name: ComponentMeta["type"]; component: LoadedComponent }>
|
||||
>();
|
||||
const __component_map = new Map<
|
||||
`${ComponentMeta["type"]}_${ComponentMeta["props"]["mode"]}`,
|
||||
Promise<{ name: ComponentMeta["type"]; component: LoadedComponent }>
|
||||
>();
|
||||
const __type_for_id = new Map<number, ComponentMeta["props"]["mode"]>();
|
||||
const _instance_map = components.reduce(
|
||||
(acc, next) => {
|
||||
acc[next.id] = next;
|
||||
return acc;
|
||||
},
|
||||
{} as { [id: number]: ComponentMeta }
|
||||
);
|
||||
components.forEach((c) => {
|
||||
if ((c.props as any).interactive === false) {
|
||||
(c.props as any).mode = "static";
|
||||
} else if ((c.props as any).interactive === true) {
|
||||
(c.props as any).mode = "interactive";
|
||||
} else if (dynamic_ids.has(c.id)) {
|
||||
(c.props as any).mode = "interactive";
|
||||
} else {
|
||||
(c.props as any).mode = "static";
|
||||
}
|
||||
__type_for_id.set(c.id, c.props.mode);
|
||||
|
||||
const _c = load_component(c.type, c.props.mode);
|
||||
_component_set.add(_c);
|
||||
__component_map.set(`${c.type}_${c.props.mode}`, _c);
|
||||
});
|
||||
|
||||
Promise.all(Array.from(_component_set)).then(() => {
|
||||
walk_layout(layout, __type_for_id, _instance_map, __component_map)
|
||||
.then(async () => {
|
||||
ready = true;
|
||||
component_set = _component_set;
|
||||
_component_map = __component_map;
|
||||
instance_map = _instance_map;
|
||||
rootNode = _rootNode;
|
||||
})
|
||||
.catch((e) => {
|
||||
console.error(e);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
async function update_interactive_mode(
|
||||
instance: ComponentMeta,
|
||||
@ -494,6 +541,26 @@
|
||||
const is_external_url = (link: string | null): boolean =>
|
||||
!!(link && new URL(link, location.href).origin !== location.origin);
|
||||
|
||||
$: target_map = dependencies.reduce(
|
||||
(acc, dep, i) => {
|
||||
let { targets, trigger } = dep;
|
||||
|
||||
targets.forEach((id) => {
|
||||
if (!acc[id]) {
|
||||
acc[id] = {};
|
||||
}
|
||||
if (acc[id]?.[trigger]) {
|
||||
acc[id][trigger].push(i);
|
||||
} else {
|
||||
acc[id][trigger] = [i];
|
||||
}
|
||||
});
|
||||
|
||||
return acc;
|
||||
},
|
||||
{} as Record<number, Record<string, number[]>>
|
||||
);
|
||||
|
||||
async function handle_mount(): Promise<void> {
|
||||
await tick();
|
||||
|
||||
@ -515,25 +582,6 @@
|
||||
}
|
||||
});
|
||||
|
||||
dependencies.forEach((dep, i) => {
|
||||
let { targets, trigger, inputs, outputs } = dep;
|
||||
const target_instances: [number, ComponentMeta][] = targets.map((t) => [
|
||||
t,
|
||||
instance_map[t]
|
||||
]);
|
||||
|
||||
target_instances.forEach(([id]) => {
|
||||
if (!target_map[id]) {
|
||||
target_map[id] = {};
|
||||
}
|
||||
if (target_map[id]?.[trigger]) {
|
||||
target_map[id][trigger].push(i);
|
||||
} else {
|
||||
target_map[id][trigger] = [i];
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
target.addEventListener("gradio", (e: Event) => {
|
||||
if (!isCustomEvent(e)) throw new Error("not a custom event");
|
||||
|
||||
@ -544,12 +592,12 @@
|
||||
trigger_share(title, description);
|
||||
} else if (event === "error") {
|
||||
messages = [new_message(data, -1, "error"), ...messages];
|
||||
} else {
|
||||
const deps = target_map[id]?.[event];
|
||||
deps?.forEach((dep_id) => {
|
||||
trigger_api_call(dep_id, data);
|
||||
});
|
||||
}
|
||||
|
||||
const deps = target_map[id]?.[event];
|
||||
deps?.forEach((dep_id) => {
|
||||
trigger_api_call(dep_id, data);
|
||||
});
|
||||
});
|
||||
|
||||
render_complete = true;
|
||||
@ -563,10 +611,6 @@
|
||||
|
||||
$: set_status($loading_status);
|
||||
|
||||
dependencies.forEach((v, i) => {
|
||||
loading_status.register(i, v.inputs, v.outputs);
|
||||
});
|
||||
|
||||
function set_status(statuses: LoadingStatusCollection): void {
|
||||
for (const id in statuses) {
|
||||
let loading_status = statuses[id];
|
||||
@ -582,8 +626,6 @@
|
||||
}
|
||||
}
|
||||
|
||||
const target_map: Record<number, Record<string, number[]>> = {};
|
||||
|
||||
function isCustomEvent(event: Event): event is CustomEvent {
|
||||
return "detail" in event;
|
||||
}
|
||||
|
@ -27,6 +27,8 @@
|
||||
is_colab: boolean;
|
||||
show_api: boolean;
|
||||
stylesheets?: string[];
|
||||
path: string;
|
||||
app_id?: string;
|
||||
}
|
||||
|
||||
let id = -1;
|
||||
@ -81,6 +83,7 @@
|
||||
export let container: boolean;
|
||||
export let info: boolean;
|
||||
export let eager: boolean;
|
||||
let websocket: WebSocket;
|
||||
|
||||
// These utilities are exported to be injectable for the Wasm version.
|
||||
export let mount_css: typeof default_mount_css = default_mount_css;
|
||||
@ -103,6 +106,10 @@
|
||||
let loading_text = $_("common.loading") + "...";
|
||||
let active_theme_mode: ThemeMode;
|
||||
|
||||
$: if (config?.app_id) {
|
||||
app_id = config.app_id;
|
||||
}
|
||||
|
||||
async function mount_custom_css(
|
||||
target: HTMLElement,
|
||||
css_string: string | null
|
||||
@ -127,18 +134,6 @@
|
||||
);
|
||||
}
|
||||
|
||||
async function reload_check(root: string): Promise<void> {
|
||||
const result = await (await fetch(root + "/app_id")).text();
|
||||
|
||||
if (app_id === null) {
|
||||
app_id = result;
|
||||
} else if (app_id != result) {
|
||||
location.reload();
|
||||
}
|
||||
|
||||
setTimeout(() => reload_check(root), 250);
|
||||
}
|
||||
|
||||
function handle_darkmode(target: HTMLDivElement): "light" | "dark" {
|
||||
let url = new URL(window.location.toString());
|
||||
let url_color_mode: ThemeMode | null = url.searchParams.get(
|
||||
@ -225,7 +220,22 @@
|
||||
window.__is_colab__ = config.is_colab;
|
||||
|
||||
if (config.dev_mode) {
|
||||
reload_check(config.root);
|
||||
setTimeout(() => {
|
||||
const { host } = new URL(api_url);
|
||||
let url = new URL(`ws://${host}/dev/reload`);
|
||||
websocket = new WebSocket(url);
|
||||
websocket.onmessage = async function (event) {
|
||||
if (event.data === "CHANGE") {
|
||||
app = await client(api_url, {
|
||||
status_callback: handle_status,
|
||||
normalise_files: false
|
||||
});
|
||||
app.config.root = app.config.path;
|
||||
config = app.config;
|
||||
window.__gradio_space__ = config.space_id;
|
||||
}
|
||||
};
|
||||
}, 200);
|
||||
}
|
||||
});
|
||||
|
||||
|
@ -1,4 +1,6 @@
|
||||
import dataclasses
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
|
||||
@ -15,15 +17,23 @@ def build_demo():
|
||||
return demo
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Config:
|
||||
filename: str
|
||||
path: Path
|
||||
watch_dirs: List[str]
|
||||
demo_name: str
|
||||
|
||||
|
||||
class TestReload:
|
||||
@pytest.fixture(autouse=True)
|
||||
def argv(self):
|
||||
return ["demo/calculator/run.py"]
|
||||
|
||||
@pytest.fixture
|
||||
def config(self, monkeypatch, argv):
|
||||
def config(self, monkeypatch, argv) -> Config:
|
||||
monkeypatch.setattr("sys.argv", ["gradio"] + argv)
|
||||
return _setup_config()
|
||||
return Config(*_setup_config())
|
||||
|
||||
@pytest.fixture(params=[{}])
|
||||
def reloader(self, config):
|
||||
@ -33,29 +43,17 @@ class TestReload:
|
||||
reloader.close()
|
||||
|
||||
def test_config_default_app(self, config):
|
||||
assert config.app == "demo.calculator.run:demo.app"
|
||||
assert config.filename == "demo.calculator.run"
|
||||
|
||||
@pytest.mark.parametrize("argv", [["demo/calculator/run.py", "test.app"]])
|
||||
@pytest.mark.parametrize("argv", [["demo/calculator/run.py", "test"]])
|
||||
def test_config_custom_app(self, config):
|
||||
assert config.app == "demo.calculator.run:test.app"
|
||||
assert config.filename == "demo.calculator.run"
|
||||
assert config.demo_name == "test"
|
||||
|
||||
def test_config_watch_gradio(self, config):
|
||||
gradio_dir = Path(gradio.__file__).parent
|
||||
assert gradio_dir in config.reload_dirs
|
||||
gradio_dir = str(Path(gradio.__file__).parent)
|
||||
assert gradio_dir in config.watch_dirs
|
||||
|
||||
def test_config_watch_app(self, config):
|
||||
demo_dir = Path("demo/calculator/run.py").resolve().parent
|
||||
assert demo_dir in config.reload_dirs
|
||||
|
||||
def test_config_load_default(self, config):
|
||||
config.load()
|
||||
assert config.loaded is True
|
||||
|
||||
@pytest.mark.parametrize("argv", [["test/test_reload.py", "build_demo"]])
|
||||
def test_config_load_factory(self, config):
|
||||
config.load()
|
||||
assert config.loaded is True
|
||||
|
||||
def test_reload_run_default(self, reloader):
|
||||
reloader.run_in_thread()
|
||||
assert reloader.started is True
|
||||
demo_dir = str(Path("demo/calculator/run.py").resolve().parent)
|
||||
assert demo_dir in config.watch_dirs
|
||||
|
Loading…
Reference in New Issue
Block a user