Faster reload mode (#5267)

* This works

* Add code

* Final touches

* Lint

* Fix bug in other dirs

* add changeset

* Reload

* lint + test

* Load from frontend

* add changeset

* Use key

* tweak frontend config generation

* tweak

* WIP ipython

* Fix robust

* fix

* Fix for jupyter notebook

* Add checks

* Lint frontend

* Undo demo changes

* add changeset

* Use is_in_or_equal

* python 3.11 changes and no if __name__

* Forward sys.argv + guide

* lint

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
Co-authored-by: pngwn <hello@pngwn.io>
Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
This commit is contained in:
Freddy Boulton 2023-08-29 17:29:15 -04:00 committed by GitHub
parent 63b7a3c85e
commit 119c834331
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 582 additions and 195 deletions

View File

@ -0,0 +1,7 @@
---
"@gradio/app": minor
"@gradio/client": minor
"gradio": minor
---
feat:Faster reload mode

View File

@ -9,6 +9,7 @@
"svelte.plugin.svelte.diagnostics.enable": false,
"prettier.configPath": ".config/.prettierrc.json",
"prettier.ignorePath": ".config/.prettierignore",
"python.analysis.typeCheckingMode": "basic",
"python.testing.pytestArgs": ["."],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true,

View File

@ -1133,7 +1133,8 @@ async function resolve_config(
if (
typeof window !== "undefined" &&
window.gradio_config &&
location.origin !== "http://localhost:9876"
location.origin !== "http://localhost:9876" &&
!window.gradio_config.dev_mode
) {
const path = window.gradio_config.root;
const config = window.gradio_config;

View File

@ -741,7 +741,7 @@ class Blocks(BlockContext):
self.space_id = utils.get_space()
self.favicon_path = None
self.auth = None
self.dev_mode = True
self.dev_mode = bool(os.getenv("GRADIO_WATCH_DIRS", False))
self.app_id = random.getrandbits(64)
self.temp_file_sets = []
self.title = title
@ -775,6 +775,12 @@ class Blocks(BlockContext):
}
analytics.initiated_analytics(data)
@property
def _is_running_in_reload_thread(self):
from gradio.reload import reload_thread
return getattr(reload_thread, "running_reload", False)
@classmethod
def from_config(
cls,
@ -1465,6 +1471,7 @@ Received outputs:
config = {
"version": routes.VERSION,
"mode": self.mode,
"app_id": self.app_id,
"dev_mode": self.dev_mode,
"analytics_enabled": self.analytics_enabled,
"components": [],
@ -1796,10 +1803,13 @@ Received outputs:
demo = gr.Interface(reverse, "text", "text")
demo.launch(share=True, auth=("username", "password"))
"""
if self._is_running_in_reload_thread:
# We have already launched the demo
return None, None, None # type: ignore
if not self.exited:
self.__exit__()
self.dev_mode = False
if (
auth
and not callable(auth)
@ -2033,11 +2043,10 @@ Received outputs:
if self.share and self.share_url:
while not networking.url_ok(self.share_url):
time.sleep(0.25)
display(
HTML(
f'<div><iframe src="{self.share_url}" width="{self.width}" height="{self.height}" allow="autoplay; camera; microphone; clipboard-read; clipboard-write;" frameborder="0" allowfullscreen></iframe></div>'
)
artifact = HTML(
f'<div><iframe src="{self.share_url}" width="{self.width}" height="{self.height}" allow="autoplay; camera; microphone; clipboard-read; clipboard-write;" frameborder="0" allowfullscreen></iframe></div>'
)
elif self.is_colab:
# modified from /usr/local/lib/python3.7/dist-packages/google/colab/output/_util.py within Colab environment
code = """(async (port, path, width, height, cache, element) => {
@ -2072,13 +2081,13 @@ Received outputs:
cache=json.dumps(False),
)
display(Javascript(code))
artifact = Javascript(code)
else:
display(
HTML(
f'<div><iframe src="{self.local_url}" width="{self.width}" height="{self.height}" allow="autoplay; camera; microphone; clipboard-read; clipboard-write;" frameborder="0" allowfullscreen></iframe></div>'
)
artifact = HTML(
f'<div><iframe src="{self.local_url}" width="{self.width}" height="{self.height}" allow="autoplay; camera; microphone; clipboard-read; clipboard-write;" frameborder="0" allowfullscreen></iframe></div>'
)
self.artifact = artifact
display(artifact)
except ImportError:
pass

View File

@ -29,6 +29,12 @@ class InvalidBlockError(ValueError):
pass
class ReloadError(ValueError):
"""Raised when something goes wrong when reloading the gradio app."""
pass
InvalidApiName = InvalidApiNameError # backwards compatibility

View File

@ -1,23 +1,89 @@
try:
from IPython.core.magic import needs_local_scope, register_cell_magic
from IPython.core.magic import (
needs_local_scope,
register_cell_magic,
)
from IPython.core.magic_arguments import argument, magic_arguments, parse_argstring
except ImportError:
pass
import warnings
import gradio as gr
from gradio.networking import App
from gradio.utils import BaseReloader
class CellIdTracker:
"""Determines the most recently run cell in the notebook.
Needed to keep track of which demo the user is updating.
"""
def __init__(self, ipython):
ipython.events.register("pre_run_cell", self.pre_run_cell)
self.shell = ipython
self.current_cell: str = ""
def pre_run_cell(self, info):
self._current_cell = info.cell_id
class JupyterReloader(BaseReloader):
"""Swap a running blocks class in a notebook with the latest cell contents."""
def __init__(self, ipython) -> None:
super().__init__()
self._cell_tracker = CellIdTracker(ipython)
self._running: dict[str, gr.Blocks] = {}
@property
def current_cell(self):
return self._cell_tracker.current_cell
@property
def running_app(self) -> App:
assert self.running_demo.server
return self.running_demo.server.running_app
@property
def running_demo(self):
return self._running[self.current_cell]
def demo_tracked(self) -> bool:
return self.current_cell in self._running
def track(self, demo: gr.Blocks):
self._running[self.current_cell] = demo
def load_ipython_extension(ipython):
__demo = gr.Blocks()
reloader = JupyterReloader(ipython)
@magic_arguments()
@argument("--demo-name", default="demo", help="Name of gradio blocks instance.")
@argument(
"--share",
default=False,
const=True,
nargs="?",
help="Whether to launch with sharing. Will slow down reloading.",
)
@register_cell_magic
@needs_local_scope
def blocks(line, cell, local_ns=None):
if "gr.Interface" in cell:
warnings.warn(
"Usage of gradio.Interface with %%blocks may result in errors."
)
with __demo.clear():
exec(cell, None, local_ns)
__demo.launch(quiet=True)
def blocks(line, cell, local_ns):
"""Launch a demo defined in a cell in reload mode."""
args = parse_argstring(blocks, line)
exec(cell, None, local_ns)
demo: gr.Blocks = local_ns[args.demo_name]
if not reloader.demo_tracked():
demo.launch(share=args.share)
reloader.track(demo)
elif reloader.queue_changed(demo):
print("Queue got added or removed. Restarting demo.")
reloader.running_demo.close()
demo.launch()
reloader.track(demo)
else:
reloader.swap_blocks(demo)
return reloader.running_demo.artifact

View File

@ -9,14 +9,17 @@ import socket
import threading
import time
import warnings
from functools import partial
from typing import TYPE_CHECKING
import requests
import uvicorn
from uvicorn.config import Config
from gradio.exceptions import ServerFailedToStartError
from gradio.routes import App
from gradio.tunneling import Tunnel
from gradio.utils import SourceFileReloader, watchfn
if TYPE_CHECKING: # Only import for type checking (to avoid circular imports).
from gradio.blocks import Blocks
@ -28,13 +31,34 @@ TRY_NUM_PORTS = int(os.getenv("GRADIO_NUM_PORTS", "100"))
LOCALHOST_NAME = os.getenv("GRADIO_SERVER_NAME", "127.0.0.1")
GRADIO_API_SERVER = "https://api.gradio.app/v2/tunnel-request"
should_watch = bool(os.getenv("GRADIO_WATCH_DIRS", False))
GRADIO_WATCH_DIRS = (
os.getenv("GRADIO_WATCH_DIRS", "").split(",") if should_watch else []
)
GRADIO_WATCH_FILE = os.getenv("GRADIO_WATCH_FILE", "app")
GRADIO_WATCH_DEMO_NAME = os.getenv("GRADIO_WATCH_DEMO_NAME", "demo")
class Server(uvicorn.Server):
def __init__(
self, config: Config, reloader: SourceFileReloader | None = None
) -> None:
assert isinstance(config.app, App)
self.running_app = config.app
super().__init__(config)
self.reloader = reloader
if self.reloader:
self.event = threading.Event()
self.watch = partial(watchfn, self.reloader)
def install_signal_handlers(self):
pass
def run_in_thread(self):
self.thread = threading.Thread(target=self.run, daemon=True)
if self.reloader:
self.watch_thread = threading.Thread(target=self.watch, daemon=True)
self.watch_thread.start()
self.thread.start()
start = time.time()
while not self.started:
@ -46,6 +70,9 @@ class Server(uvicorn.Server):
def close(self):
self.should_exit = True
if self.reloader:
self.reloader.stop()
self.watch_thread.join()
self.thread.join()
@ -160,7 +187,19 @@ def start_server(
ssl_keyfile_password=ssl_keyfile_password,
ws_max_size=1024 * 1024 * 1024, # Setting max websocket size to be 1 GB
)
server = Server(config=config)
reloader = None
if GRADIO_WATCH_DIRS:
change_event = threading.Event()
app.change_event = change_event
reloader = SourceFileReloader(
app=app,
watch_dirs=GRADIO_WATCH_DIRS,
watch_file=GRADIO_WATCH_FILE,
demo_name=GRADIO_WATCH_DEMO_NAME,
stop_event=threading.Event(),
change_event=change_event,
)
server = Server(config=config, reloader=reloader)
server.run_in_thread()
break
except (OSError, ServerFailedToStartError):

View File

@ -19,7 +19,12 @@ from gradio.data_classes import (
ProgressUnit,
)
from gradio.helpers import TrackedIterable
from gradio.utils import AsyncRequest, run_coro_in_background, set_task_name
from gradio.utils import (
AsyncRequest,
run_coro_in_background,
safe_get_lock,
set_task_name,
)
class Event:
@ -59,7 +64,7 @@ class Queue:
self.max_thread_count = concurrency_count
self.update_intervals = update_intervals
self.active_jobs: list[None | list[Event]] = [None] * concurrency_count
self.delete_lock = asyncio.Lock()
self.delete_lock = safe_get_lock()
self.server_path = None
self.duration_history_total = 0
self.duration_history_count = 0

View File

@ -3,18 +3,20 @@
Contains the functions that run when `gradio` is called from the command line. Specifically, allows
$ gradio app.py, to run app.py in reload mode where any changes in the app.py file or Gradio library reloads the demo.
$ gradio app.py my_demo.app, to use variable names other than "demo"
$ gradio app.py my_demo, to use variable names other than "demo"
"""
import inspect
import os
import re
import subprocess
import sys
import threading
from pathlib import Path
from uvicorn import Config
from uvicorn.supervisors import ChangeReload
import gradio
from gradio import networking, utils
from gradio import utils
reload_thread = threading.local()
def _setup_config():
@ -22,15 +24,34 @@ def _setup_config():
if len(args) == 0:
raise ValueError("No file specified.")
if len(args) == 1 or args[1].startswith("--"):
demo_name = "demo.app"
demo_name = "demo"
else:
demo_name = args[1]
if "." not in demo_name:
if "." in demo_name:
demo_name = demo_name.split(".")[0]
print(
"\nWARNING: As of Gradio 3.31, the parameter after the file path must be the name of the FastAPI app, not the Gradio demo. In most cases, this just means you should add '.app' after the name of your demo, e.g. 'demo' -> 'demo.app'."
"\nWARNING: As of Gradio 3.41.0, the parameter after the file path must be the name of the Gradio demo, not the FastAPI app. In most cases, this just means you should remove '.app' after the name of your demo, e.g. 'demo.app' -> 'demo'."
)
original_path = args[0]
app_text = Path(original_path).read_text()
patterns = [
f"with gr\\.Blocks\\(\\) as {demo_name}",
f"{demo_name} = gr\\.Blocks",
f"{demo_name} = gr\\.Interface",
f"{demo_name} = gr\\.ChatInterface",
f"{demo_name} = gr\\.Series",
f"{demo_name} = gr\\.Paralles",
f"{demo_name} = gr\\.TabbedInterface",
]
if not any(re.search(p, app_text) for p in patterns):
print(
f"\nWarning: Cannot statically find a gradio demo called {demo_name}. "
"Reload work may fail."
)
abs_original_path = utils.abspath(original_path)
path = os.path.normpath(original_path)
path = path.replace("/", ".")
@ -39,15 +60,6 @@ def _setup_config():
gradio_folder = Path(inspect.getfile(gradio)).parent
port = networking.get_first_available_port(
networking.INITIAL_PORT_VALUE,
networking.INITIAL_PORT_VALUE + networking.TRY_NUM_PORTS,
)
print(
f"\nLaunching in *reload mode* on: http://{networking.LOCALHOST_NAME}:{port} (Press CTRL+C to quit)\n"
)
gradio_app = f"{filename}:{demo_name}"
message = "Watching:"
message_change_count = 0
@ -68,23 +80,24 @@ def _setup_config():
# guaranty access to the module of an app
sys.path.insert(0, os.getcwd())
# uvicorn.run blocks the execution (looping) which makes it hard to test
return Config(
gradio_app,
reload=True,
port=port,
log_level="warning",
reload_dirs=watching_dirs,
)
return filename, abs_original_path, [str(s) for s in watching_dirs], demo_name
def main():
# default execution pattern to start the server and watch changes
config = _setup_config()
server = networking.Server(config)
sock = config.bind_socket()
ChangeReload(config, target=server.run, sockets=[sock]).run()
filename, path, watch_dirs, demo_name = _setup_config()
args = sys.argv[1:]
extra_args = args[1:] if len(args) == 1 or args[1].startswith("--") else args[2:]
popen = subprocess.Popen(
["python", path] + extra_args,
env=dict(
os.environ,
GRADIO_WATCH_DIRS=",".join(watch_dirs),
GRADIO_WATCH_FILE=filename,
GRADIO_WATCH_DEMO_NAME=demo_name,
),
)
popen.wait()
if __name__ == "__main__":

View File

@ -17,6 +17,7 @@ import os
import posixpath
import secrets
import tempfile
import threading
import time
import traceback
from asyncio import TimeoutError as AsyncTimeOutError
@ -112,12 +113,13 @@ class App(FastAPI):
self.state_holder = {}
self.iterators = defaultdict(dict)
self.iterators_to_reset = defaultdict(set)
self.lock = asyncio.Lock()
self.lock = utils.safe_get_lock()
self.queue_token = secrets.token_urlsafe(32)
self.startup_events_triggered = False
self.uploaded_file_dir = os.environ.get("GRADIO_TEMP_DIR") or str(
Path(tempfile.gettempdir()) / "gradio"
)
self.change_event: None | threading.Event = None
# Allow user to manually set `docs_url` and `redoc_url`
# when instantiating an App; when they're not set, disable docs and redoc.
kwargs.setdefault("docs_url", None)
@ -216,6 +218,39 @@ class App(FastAPI):
def app_id(request: fastapi.Request) -> dict:
return {"app_id": app.get_blocks().app_id}
async def send_ping_periodically(websocket: WebSocket):
while True:
await websocket.send_text("PING")
await asyncio.sleep(1)
async def listen_for_changes(websocket: WebSocket):
assert app.change_event
while True:
if app.change_event.is_set():
await websocket.send_text("CHANGE")
app.change_event.clear()
await asyncio.sleep(0.1) # Short sleep to not make this a tight loop
@app.websocket("/dev/reload")
async def notify_changes(websocket: WebSocket):
await websocket.accept()
ping = asyncio.create_task(send_ping_periodically(websocket))
notify = asyncio.create_task(listen_for_changes(websocket))
tasks = {ping, notify}
ping.add_done_callback(tasks.remove)
notify.add_done_callback(tasks.remove)
done, pending = await asyncio.wait(
[ping, notify],
return_when=asyncio.FIRST_COMPLETED,
)
for task in pending:
task.cancel()
if any(isinstance(task.exception(), Exception) for task in done):
await websocket.close()
@app.post("/login")
@app.post("/login/")
def login(form_data: OAuth2PasswordRequestForm = Depends()):

View File

@ -5,6 +5,7 @@ from __future__ import annotations
import asyncio
import copy
import functools
import importlib
import inspect
import json
import json.decoder
@ -13,9 +14,12 @@ import pkgutil
import pprint
import random
import re
import threading
import time
import traceback
import typing
import warnings
from abc import ABC, abstractmethod
from contextlib import contextmanager
from enum import Enum
from io import BytesIO
@ -27,6 +31,7 @@ from typing import (
Any,
Callable,
Generator,
Iterator,
Optional,
TypeVar,
Union,
@ -47,6 +52,7 @@ from gradio.strings import en
if TYPE_CHECKING: # Only import for type checking (is False at runtime).
from gradio.blocks import Block, BlockContext, Blocks
from gradio.components import Component
from gradio.routes import App
JSON_PATH = os.path.join(os.path.dirname(gradio.__file__), "launches.json")
GRADIO_VERSION = (
@ -57,6 +63,158 @@ P = ParamSpec("P")
T = TypeVar("T")
def safe_get_lock() -> asyncio.Lock:
"""Get asyncio.Lock() without fear of getting an Exception.
Needed because in reload mode we import the Blocks object outside
the main thread.
"""
try:
asyncio.get_event_loop()
return asyncio.Lock()
except RuntimeError:
return None # type: ignore
class BaseReloader(ABC):
@property
@abstractmethod
def running_app(self) -> App:
pass
def queue_changed(self, demo: Blocks):
return (
hasattr(self.running_app.blocks, "_queue") and not hasattr(demo, "_queue")
) or (
not hasattr(self.running_app.blocks, "_queue") and hasattr(demo, "_queue")
)
def swap_blocks(self, demo: Blocks):
assert self.running_app.blocks
# Copy over the blocks to get new components and events but
# not a new queue
if hasattr(self.running_app.blocks, "_queue"):
self.running_app.blocks._queue.blocks_dependencies = demo.dependencies
demo._queue = self.running_app.blocks._queue
self.running_app.blocks = demo
class SourceFileReloader(BaseReloader):
def __init__(
self,
app: App,
watch_dirs: list[str],
watch_file: str,
stop_event: threading.Event,
change_event: threading.Event,
demo_name: str = "demo",
) -> None:
super().__init__()
self.app = app
self.watch_dirs = watch_dirs
self.watch_file = watch_file
self.stop_event = stop_event
self.change_event = change_event
self.demo_name = demo_name
@property
def running_app(self) -> App:
return self.app
def should_watch(self) -> bool:
return not self.stop_event.is_set()
def stop(self) -> None:
self.stop_event.set()
def alert_change(self):
self.change_event.set()
def swap_blocks(self, demo: Blocks):
super().swap_blocks(demo)
self.alert_change()
def watchfn(reloader: SourceFileReloader):
"""Watch python files in a given module.
get_changes is taken from uvicorn's default file watcher.
"""
# The thread running watchfn will be the thread reloading
# the app. So we need to modify this thread_data attr here
# so that subsequent calls to reload don't launch the app
from gradio.reload import reload_thread
reload_thread.running_reload = True
def get_changes() -> Path | None:
for file in iter_py_files():
try:
mtime = file.stat().st_mtime
except OSError: # pragma: nocover
continue
old_time = mtimes.get(file)
if old_time is None:
mtimes[file] = mtime
continue
elif mtime > old_time:
return file
return None
def iter_py_files() -> Iterator[Path]:
for reload_dir in reload_dirs:
for path in list(reload_dir.rglob("*.py")):
yield path.resolve()
module = None
reload_dirs = [Path(dir_) for dir_ in reloader.watch_dirs]
mtimes = {}
while reloader.should_watch():
import sys
changed = get_changes()
if changed:
print(f"Changes detected in: {changed}")
# To simulate a fresh reload, delete all module references from sys.modules
# for the modules in the package the change came from.
dir_ = next(d for d in reload_dirs if is_in_or_equal(changed, d))
modules = list(sys.modules)
for k in modules:
v = sys.modules[k]
sourcefile = getattr(v, "__file__", None)
# Do not reload `reload.py` to keep thread data
if (
sourcefile
and dir_ == Path(inspect.getfile(gradio)).parent
and sourcefile.endswith("reload.py")
):
continue
if sourcefile and is_in_or_equal(sourcefile, dir_):
del sys.modules[k]
try:
module = importlib.import_module(reloader.watch_file)
module = importlib.reload(module)
except Exception as e:
print(
f"Reloading {reloader.watch_file} failed with the following exception: "
)
traceback.print_exception(None, value=e, tb=None)
mtimes = {}
continue
demo = getattr(module, reloader.demo_name)
if reloader.queue_changed(demo):
print(
"Reloading failed. The new demo has a queue and the old one doesn't (or vice versa). "
"Please launch your demo again"
)
else:
reloader.swap_blocks(demo)
mtimes = {}
def colab_check() -> bool:
"""
Check if interface is launching from Google Colab

View File

@ -41,18 +41,16 @@ In the terminal, run `gradio run.py`. That's it!
Now, you'll see that after you'll see something like this:
```bash
Launching in *reload mode* on: http://127.0.0.1:7860 (Press CTRL+C to quit)
Watching: '/Users/freddy/sources/gradio/gradio', '/Users/freddy/sources/gradio/demo/'
Watching...
WARNING: The --reload flag should not be used in production on Windows.
Running on local URL: http://127.0.0.1:7860
```
The important part here is the line that says `Watching...` What's happening here is that Gradio will be observing the directory where `run.py` file lives, and if the file changes, it will automatically rerun the file for you. So you can focus on writing your code, and your Gradio demo will refresh automatically 🥳
⚠️ Warning: the `gradio` command does not detect the parameters passed to the `launch()` methods because the `launch()` method is never called in reload mode. For example, setting `auth`, or `show_error` in `launch()` will not be reflected in the app.
There is one important thing to keep in mind when using the reload mode: Gradio specifically looks for a Gradio Blocks/Interface demo called `demo` in your code. If you have named your demo something else, you will need to pass in the name of your demo's FastAPI app as the 2nd parameter in your code. For Gradio demos, the FastAPI app can be accessed using the `.app` attribute. So if your `run.py` file looked like this:
There is one important thing to keep in mind when using the reload mode: Gradio specifically looks for a Gradio Blocks/Interface demo called `demo` in your code. If you have named your demo something else, you will need to pass in the name of your demo as the 2nd parameter in your code. So if your `run.py` file looked like this:
```python
import gradio as gr
@ -70,7 +68,7 @@ if __name__ == "__main__":
my_demo.launch()
```
Then you would launch it in reload mode like this: `gradio run.py my_demo.app`.
Then you would launch it in reload mode like this: `gradio run.py my_demo`.
🔥 If your application accepts command line arguments, you can pass them in as well. Here's an example:
@ -112,26 +110,25 @@ Then, in the cell that you are developing your Gradio demo, simply write the mag
import gradio as gr
gr.Markdown("# Greetings from Gradio!")
inp = gr.Textbox(placeholder="What is your name?")
out = gr.Textbox()
with gr.Blocks() as demo:
gr.Markdown(f"# Greetings {args.name}!")
inp = gr.Textbox()
out = gr.Textbox()
inp.change(fn=lambda x: f"Welcome, {x}!",
inputs=inp,
outputs=out)
inp.change(fn=lambda x: x, inputs=inp, outputs=out)
```
Notice that:
- You do not need to put the boiler plate `with gr.Blocks() as demo:` and `demo.launch()` code — Gradio does that for you automatically!
- You do not need to launch your demo — Gradio does that for you automatically!
- Every time you rerun the cell, Gradio will re-launch your app on the same port and using the same underlying web server. This means you'll see your changes _much, much faster_ than if you were rerunning the cell normally.
- Every time you rerun the cell, Gradio will re-render your app on the same port and using the same underlying web server. This means you'll see your changes _much, much faster_ than if you were rerunning the cell normally.
Here's what it looks like in a jupyter notebook:
![](https://i.ibb.co/nrszFws/Blocks.gif)
![](https://gradio-builds.s3.amazonaws.com/demo-files/jupyter_reload.gif)
🪄 This works in colab notebooks too! [Here's a colab notebook](https://colab.research.google.com/drive/1jUlX1w7JqckRHVE-nbDyMPyZ7fYD8488?authuser=1#scrollTo=zxHYjbCTTz_5) where you can see the Blocks magic in action. Try making some changes and re-running the cell with the Gradio code!
🪄 This works in colab notebooks too! [Here's a colab notebook](https://colab.research.google.com/drive/1zAuWoiTIb3O2oitbtVb2_ekv1K6ggtC1?usp=sharing) where you can see the Blocks magic in action. Try making some changes and re-running the cell with the Gradio code!
The Notebook Magic is now the author's preferred way of building Gradio demos. Regardless of how you write Python code, we hope either of these methods will give you a much better development experience using Gradio.

View File

@ -52,12 +52,10 @@
type: "column",
props: { mode: "static" },
has_modes: false,
instance: {} as ComponentMeta["instance"],
component: {} as ComponentMeta["component"]
instance: null as unknown as ComponentMeta["instance"],
component: null as unknown as ComponentMeta["component"]
};
components.push(rootNode);
const AsyncFunction = Object.getPrototypeOf(async function () {}).constructor;
dependencies.forEach((d) => {
if (d.js) {
@ -103,18 +101,7 @@
return false;
}
const dynamic_ids: Set<number> = new Set();
for (const comp of components) {
const { id, props } = comp;
const is_input = is_dep(id, "inputs", dependencies);
if (
is_input ||
(!is_dep(id, "outputs", dependencies) &&
has_no_default_value(props?.value))
) {
dynamic_ids.add(id);
}
}
let dynamic_ids: Set<number> = new Set();
function has_no_default_value(value: any): boolean {
return (
@ -125,13 +112,7 @@
);
}
let instance_map = components.reduce(
(acc, next) => {
acc[next.id] = next;
return acc;
},
{} as { [id: number]: ComponentMeta }
);
let instance_map: { [id: number]: ComponentMeta };
type LoadedComponent = {
default: ComponentMeta["component"];
@ -172,58 +153,124 @@
}
}
const component_set = new Set<
let component_set = new Set<
Promise<{ name: ComponentMeta["type"]; component: LoadedComponent }>
>();
const _component_map = new Map<
let _component_map = new Map<
`${ComponentMeta["type"]}_${ComponentMeta["props"]["mode"]}`,
Promise<{ name: ComponentMeta["type"]; component: LoadedComponent }>
>();
const _type_for_id = new Map<number, ComponentMeta["props"]["mode"]>();
async function walk_layout(node: LayoutNode): Promise<void> {
async function walk_layout(
node: LayoutNode,
type_map: Map<number, ComponentMeta["props"]["mode"]>,
instance_map: { [id: number]: ComponentMeta },
component_map: Map<
`${ComponentMeta["type"]}_${ComponentMeta["props"]["mode"]}`,
Promise<{ name: ComponentMeta["type"]; component: LoadedComponent }>
>
): Promise<void> {
ready = false;
let instance = instance_map[node.id];
const _component = (await _component_map.get(
`${instance.type}_${_type_for_id.get(node.id) || "static"}`
const _component = (await component_map.get(
`${instance.type}_${type_map.get(node.id) || "static"}`
))!.component;
instance.component = _component.default;
if (node.children) {
instance.children = node.children.map((v) => instance_map[v.id]);
await Promise.all(node.children.map((v) => walk_layout(v)));
await Promise.all(
node.children.map((v) =>
walk_layout(v, type_map, instance_map, component_map)
)
);
}
}
components.forEach((c) => {
if ((c.props as any).interactive === false) {
(c.props as any).mode = "static";
} else if ((c.props as any).interactive === true) {
(c.props as any).mode = "interactive";
} else if (dynamic_ids.has(c.id)) {
(c.props as any).mode = "interactive";
} else {
(c.props as any).mode = "static";
}
_type_for_id.set(c.id, c.props.mode);
const _c = load_component(c.type, c.props.mode);
component_set.add(_c);
_component_map.set(`${c.type}_${c.props.mode}`, _c);
});
export let ready = false;
export let render_complete = false;
Promise.all(Array.from(component_set)).then(() => {
walk_layout(layout)
.then(async () => {
ready = true;
})
.catch((e) => {
console.error(e);
});
});
$: components, layout, prepare_components();
function prepare_components(): void {
loading_status = create_loading_status_store();
dependencies.forEach((v, i) => {
loading_status.register(i, v.inputs, v.outputs);
});
const _dynamic_ids = new Set<number>();
for (const comp of components) {
const { id, props } = comp;
const is_input = is_dep(id, "inputs", dependencies);
if (
is_input ||
(!is_dep(id, "outputs", dependencies) &&
has_no_default_value(props?.value))
) {
_dynamic_ids.add(id);
}
}
dynamic_ids = _dynamic_ids;
const _rootNode: typeof rootNode = {
id: layout.id,
type: "column",
props: { mode: "static" },
has_modes: false,
instance: null as unknown as ComponentMeta["instance"],
component: null as unknown as ComponentMeta["component"]
};
components.push(_rootNode);
const _component_set = new Set<
Promise<{ name: ComponentMeta["type"]; component: LoadedComponent }>
>();
const __component_map = new Map<
`${ComponentMeta["type"]}_${ComponentMeta["props"]["mode"]}`,
Promise<{ name: ComponentMeta["type"]; component: LoadedComponent }>
>();
const __type_for_id = new Map<number, ComponentMeta["props"]["mode"]>();
const _instance_map = components.reduce(
(acc, next) => {
acc[next.id] = next;
return acc;
},
{} as { [id: number]: ComponentMeta }
);
components.forEach((c) => {
if ((c.props as any).interactive === false) {
(c.props as any).mode = "static";
} else if ((c.props as any).interactive === true) {
(c.props as any).mode = "interactive";
} else if (dynamic_ids.has(c.id)) {
(c.props as any).mode = "interactive";
} else {
(c.props as any).mode = "static";
}
__type_for_id.set(c.id, c.props.mode);
const _c = load_component(c.type, c.props.mode);
_component_set.add(_c);
__component_map.set(`${c.type}_${c.props.mode}`, _c);
});
Promise.all(Array.from(_component_set)).then(() => {
walk_layout(layout, __type_for_id, _instance_map, __component_map)
.then(async () => {
ready = true;
component_set = _component_set;
_component_map = __component_map;
instance_map = _instance_map;
rootNode = _rootNode;
})
.catch((e) => {
console.error(e);
});
});
}
async function update_interactive_mode(
instance: ComponentMeta,
@ -494,6 +541,26 @@
const is_external_url = (link: string | null): boolean =>
!!(link && new URL(link, location.href).origin !== location.origin);
$: target_map = dependencies.reduce(
(acc, dep, i) => {
let { targets, trigger } = dep;
targets.forEach((id) => {
if (!acc[id]) {
acc[id] = {};
}
if (acc[id]?.[trigger]) {
acc[id][trigger].push(i);
} else {
acc[id][trigger] = [i];
}
});
return acc;
},
{} as Record<number, Record<string, number[]>>
);
async function handle_mount(): Promise<void> {
await tick();
@ -515,25 +582,6 @@
}
});
dependencies.forEach((dep, i) => {
let { targets, trigger, inputs, outputs } = dep;
const target_instances: [number, ComponentMeta][] = targets.map((t) => [
t,
instance_map[t]
]);
target_instances.forEach(([id]) => {
if (!target_map[id]) {
target_map[id] = {};
}
if (target_map[id]?.[trigger]) {
target_map[id][trigger].push(i);
} else {
target_map[id][trigger] = [i];
}
});
});
target.addEventListener("gradio", (e: Event) => {
if (!isCustomEvent(e)) throw new Error("not a custom event");
@ -544,12 +592,12 @@
trigger_share(title, description);
} else if (event === "error") {
messages = [new_message(data, -1, "error"), ...messages];
} else {
const deps = target_map[id]?.[event];
deps?.forEach((dep_id) => {
trigger_api_call(dep_id, data);
});
}
const deps = target_map[id]?.[event];
deps?.forEach((dep_id) => {
trigger_api_call(dep_id, data);
});
});
render_complete = true;
@ -563,10 +611,6 @@
$: set_status($loading_status);
dependencies.forEach((v, i) => {
loading_status.register(i, v.inputs, v.outputs);
});
function set_status(statuses: LoadingStatusCollection): void {
for (const id in statuses) {
let loading_status = statuses[id];
@ -582,8 +626,6 @@
}
}
const target_map: Record<number, Record<string, number[]>> = {};
function isCustomEvent(event: Event): event is CustomEvent {
return "detail" in event;
}

View File

@ -27,6 +27,8 @@
is_colab: boolean;
show_api: boolean;
stylesheets?: string[];
path: string;
app_id?: string;
}
let id = -1;
@ -81,6 +83,7 @@
export let container: boolean;
export let info: boolean;
export let eager: boolean;
let websocket: WebSocket;
// These utilities are exported to be injectable for the Wasm version.
export let mount_css: typeof default_mount_css = default_mount_css;
@ -103,6 +106,10 @@
let loading_text = $_("common.loading") + "...";
let active_theme_mode: ThemeMode;
$: if (config?.app_id) {
app_id = config.app_id;
}
async function mount_custom_css(
target: HTMLElement,
css_string: string | null
@ -127,18 +134,6 @@
);
}
async function reload_check(root: string): Promise<void> {
const result = await (await fetch(root + "/app_id")).text();
if (app_id === null) {
app_id = result;
} else if (app_id != result) {
location.reload();
}
setTimeout(() => reload_check(root), 250);
}
function handle_darkmode(target: HTMLDivElement): "light" | "dark" {
let url = new URL(window.location.toString());
let url_color_mode: ThemeMode | null = url.searchParams.get(
@ -225,7 +220,22 @@
window.__is_colab__ = config.is_colab;
if (config.dev_mode) {
reload_check(config.root);
setTimeout(() => {
const { host } = new URL(api_url);
let url = new URL(`ws://${host}/dev/reload`);
websocket = new WebSocket(url);
websocket.onmessage = async function (event) {
if (event.data === "CHANGE") {
app = await client(api_url, {
status_callback: handle_status,
normalise_files: false
});
app.config.root = app.config.path;
config = app.config;
window.__gradio_space__ = config.space_id;
}
};
}, 200);
}
});

View File

@ -1,4 +1,6 @@
import dataclasses
from pathlib import Path
from typing import List
import pytest
@ -15,15 +17,23 @@ def build_demo():
return demo
@dataclasses.dataclass
class Config:
filename: str
path: Path
watch_dirs: List[str]
demo_name: str
class TestReload:
@pytest.fixture(autouse=True)
def argv(self):
return ["demo/calculator/run.py"]
@pytest.fixture
def config(self, monkeypatch, argv):
def config(self, monkeypatch, argv) -> Config:
monkeypatch.setattr("sys.argv", ["gradio"] + argv)
return _setup_config()
return Config(*_setup_config())
@pytest.fixture(params=[{}])
def reloader(self, config):
@ -33,29 +43,17 @@ class TestReload:
reloader.close()
def test_config_default_app(self, config):
assert config.app == "demo.calculator.run:demo.app"
assert config.filename == "demo.calculator.run"
@pytest.mark.parametrize("argv", [["demo/calculator/run.py", "test.app"]])
@pytest.mark.parametrize("argv", [["demo/calculator/run.py", "test"]])
def test_config_custom_app(self, config):
assert config.app == "demo.calculator.run:test.app"
assert config.filename == "demo.calculator.run"
assert config.demo_name == "test"
def test_config_watch_gradio(self, config):
gradio_dir = Path(gradio.__file__).parent
assert gradio_dir in config.reload_dirs
gradio_dir = str(Path(gradio.__file__).parent)
assert gradio_dir in config.watch_dirs
def test_config_watch_app(self, config):
demo_dir = Path("demo/calculator/run.py").resolve().parent
assert demo_dir in config.reload_dirs
def test_config_load_default(self, config):
config.load()
assert config.loaded is True
@pytest.mark.parametrize("argv", [["test/test_reload.py", "build_demo"]])
def test_config_load_factory(self, config):
config.load()
assert config.loaded is True
def test_reload_run_default(self, reloader):
reloader.run_in_thread()
assert reloader.started is True
demo_dir = str(Path("demo/calculator/run.py").resolve().parent)
assert demo_dir in config.watch_dirs