mirror of
https://github.com/jupyter/notebook.git
synced 2025-03-19 13:20:36 +08:00
DEV: Add full support for non-notebook checkpoints.
This commit is contained in:
parent
23837e9ad4
commit
7030a8717a
@ -60,7 +60,10 @@ class FileManagerMixin(object):
|
||||
"""
|
||||
Mixin for ContentsAPI classes that interact with the filesystem.
|
||||
|
||||
Shared by both FileContentsManager and FileCheckpointManager.
|
||||
Provides facilities for reading, writing, and copying both notebooks and
|
||||
generic files.
|
||||
|
||||
Shared by FileContentsManager and FileCheckpointManager.
|
||||
|
||||
Note
|
||||
----
|
||||
@ -114,17 +117,6 @@ class FileManagerMixin(object):
|
||||
except OSError:
|
||||
self.log.debug("copystat on %s failed", dest, exc_info=True)
|
||||
|
||||
def _read_notebook(self, os_path, as_version=4):
|
||||
"""Read a notebook from an os path."""
|
||||
with self.open(os_path, 'r', encoding='utf-8') as f:
|
||||
try:
|
||||
return nbformat.read(f, as_version=as_version)
|
||||
except Exception as e:
|
||||
raise web.HTTPError(
|
||||
400,
|
||||
u"Unreadable Notebook: %s %r" % (os_path, e),
|
||||
)
|
||||
|
||||
def _get_os_path(self, path):
|
||||
"""Given an API path, return its file system path.
|
||||
|
||||
@ -140,6 +132,70 @@ class FileManagerMixin(object):
|
||||
"""
|
||||
return to_os_path(path, self.root_dir)
|
||||
|
||||
def _read_notebook(self, os_path, as_version=4):
    """Load and parse the notebook stored at ``os_path``.

    The file is opened through ``self.open`` as UTF-8 text and handed to
    ``nbformat.read``, which is asked for notebook format ``as_version``.

    Any exception raised while parsing is re-raised as an HTTP 400
    ("Unreadable Notebook") error whose message carries the path and the
    repr of the underlying exception.
    """
    with self.open(os_path, 'r', encoding='utf-8') as handle:
        try:
            notebook = nbformat.read(handle, as_version=as_version)
        except Exception as err:
            message = u"Unreadable Notebook: %s %r" % (os_path, err)
            raise web.HTTPError(400, message)
        else:
            return notebook
|
||||
|
||||
def _save_notebook(self, os_path, nb):
    """Write the notebook ``nb`` to ``os_path``.

    The target is opened through ``self.atomic_writing`` with UTF-8
    encoding, and the notebook is serialized with ``nbformat.write``
    without any format-version conversion (``nbformat.NO_CONVERT``).
    """
    writer = self.atomic_writing(os_path, encoding='utf-8')
    with writer as stream:
        nbformat.write(nb, stream, version=nbformat.NO_CONVERT)
|
||||
|
||||
def _read_file(self, os_path, format):
|
||||
"""Read a non-notebook file.
|
||||
|
||||
os_path: The path to be read.
|
||||
format:
|
||||
If 'text', the contents will be decoded as UTF-8.
|
||||
If 'base64', the raw bytes contents will be encoded as base64.
|
||||
If not specified, try to decode as UTF-8, and fall back to base64
|
||||
"""
|
||||
if not os.path.isfile(os_path):
|
||||
raise web.HTTPError(400, "Cannot read non-file %s" % os_path)
|
||||
|
||||
with self.open(os_path, 'rb') as f:
|
||||
bcontent = f.read()
|
||||
|
||||
if format is None or format == 'text':
|
||||
# Try to interpret as unicode if format is unknown or if unicode
|
||||
# was explicitly requested.
|
||||
try:
|
||||
return bcontent.decode('utf8'), 'text'
|
||||
except UnicodeError as e:
|
||||
if format == 'text':
|
||||
raise web.HTTPError(
|
||||
400,
|
||||
"%s is not UTF-8 encoded" % os_path,
|
||||
reason='bad format',
|
||||
)
|
||||
return base64.encodestring(bcontent).decode('ascii'), 'base64'
|
||||
|
||||
def _save_file(self, os_path, content, format):
|
||||
"""Save content of a generic file."""
|
||||
if format not in {'text', 'base64'}:
|
||||
raise web.HTTPError(
|
||||
400,
|
||||
"Must specify format of file contents as 'text' or 'base64'",
|
||||
)
|
||||
try:
|
||||
if format == 'text':
|
||||
bcontent = content.encode('utf8')
|
||||
else:
|
||||
b64_bytes = content.encode('ascii')
|
||||
bcontent = base64.decodestring(b64_bytes)
|
||||
except Exception as e:
|
||||
raise web.HTTPError(400, u'Encoding error saving %s: %s' % (os_path, e))
|
||||
|
||||
with self.atomic_writing(os_path, text=False) as f:
|
||||
f.write(bcontent)
|
||||
|
||||
|
||||
class FileCheckpointManager(FileManagerMixin, CheckpointManager):
|
||||
"""
|
||||
@ -167,39 +223,51 @@ class FileCheckpointManager(FileManagerMixin, CheckpointManager):
|
||||
return getcwd()
|
||||
|
||||
# public checkpoint API
|
||||
def create_checkpoint(self, nb, path):
|
||||
def create_file_checkpoint(self, content, format, path):
|
||||
"""Create a checkpoint from the current content of a notebook."""
|
||||
path = path.strip('/')
|
||||
# only the one checkpoint ID:
|
||||
checkpoint_id = u"checkpoint"
|
||||
os_checkpoint_path = self.get_checkpoint_path(checkpoint_id, path)
|
||||
os_checkpoint_path = self.checkpoint_path(checkpoint_id, path)
|
||||
self.log.debug("creating checkpoint for %s", path)
|
||||
with self.perm_to_403():
|
||||
self._save_file(os_checkpoint_path, content, format=format)
|
||||
|
||||
# return the checkpoint info
|
||||
return self.checkpoint_model(checkpoint_id, os_checkpoint_path)
|
||||
|
||||
def create_notebook_checkpoint(self, nb, path):
|
||||
"""Create a checkpoint from the current content of a notebook."""
|
||||
path = path.strip('/')
|
||||
# only the one checkpoint ID:
|
||||
checkpoint_id = u"checkpoint"
|
||||
os_checkpoint_path = self.checkpoint_path(checkpoint_id, path)
|
||||
self.log.debug("creating checkpoint for %s", path)
|
||||
with self.perm_to_403():
|
||||
self._save_notebook(os_checkpoint_path, nb)
|
||||
|
||||
# return the checkpoint info
|
||||
return self.get_checkpoint_model(checkpoint_id, path)
|
||||
return self.checkpoint_model(checkpoint_id, os_checkpoint_path)
|
||||
|
||||
def get_checkpoint_content(self, checkpoint_id, path):
|
||||
def get_checkpoint(self, checkpoint_id, path, type):
|
||||
"""Get the content of a checkpoint.
|
||||
|
||||
Returns an unvalidated model with the same structure as
|
||||
the return value of ContentsManager.get
|
||||
Returns a pair of (content, type).
|
||||
"""
|
||||
path = path.strip('/')
|
||||
self.log.info("restoring %s from checkpoint %s", path, checkpoint_id)
|
||||
os_checkpoint_path = self.get_checkpoint_path(checkpoint_id, path)
|
||||
return self._read_notebook(os_checkpoint_path, as_version=4)
|
||||
|
||||
def _save_notebook(self, os_path, nb):
|
||||
"""Save a notebook file."""
|
||||
with self.atomic_writing(os_path, encoding='utf-8') as f:
|
||||
nbformat.write(nb, f, version=nbformat.NO_CONVERT)
|
||||
os_checkpoint_path = self.checkpoint_path(checkpoint_id, path)
|
||||
if not os.path.isfile(os_checkpoint_path):
|
||||
self.no_such_checkpoint(path, checkpoint_id)
|
||||
if type == 'notebook':
|
||||
return self._read_notebook(os_checkpoint_path, as_version=4), None
|
||||
else:
|
||||
return self._read_file(os_checkpoint_path, format=None)
|
||||
|
||||
def rename_checkpoint(self, checkpoint_id, old_path, new_path):
|
||||
"""Rename a checkpoint from old_path to new_path."""
|
||||
old_cp_path = self.get_checkpoint_path(checkpoint_id, old_path)
|
||||
new_cp_path = self.get_checkpoint_path(checkpoint_id, new_path)
|
||||
old_cp_path = self.checkpoint_path(checkpoint_id, old_path)
|
||||
new_cp_path = self.checkpoint_path(checkpoint_id, new_path)
|
||||
if os.path.isfile(old_cp_path):
|
||||
self.log.debug(
|
||||
"Renaming checkpoint %s -> %s",
|
||||
@ -212,7 +280,7 @@ class FileCheckpointManager(FileManagerMixin, CheckpointManager):
|
||||
def delete_checkpoint(self, checkpoint_id, path):
|
||||
"""delete a file's checkpoint"""
|
||||
path = path.strip('/')
|
||||
cp_path = self.get_checkpoint_path(checkpoint_id, path)
|
||||
cp_path = self.checkpoint_path(checkpoint_id, path)
|
||||
if not os.path.isfile(cp_path):
|
||||
self.no_such_checkpoint(path, checkpoint_id)
|
||||
|
||||
@ -227,14 +295,14 @@ class FileCheckpointManager(FileManagerMixin, CheckpointManager):
|
||||
"""
|
||||
path = path.strip('/')
|
||||
checkpoint_id = "checkpoint"
|
||||
os_path = self.get_checkpoint_path(checkpoint_id, path)
|
||||
if not os.path.exists(os_path):
|
||||
os_path = self.checkpoint_path(checkpoint_id, path)
|
||||
if not os.path.isfile(os_path):
|
||||
return []
|
||||
else:
|
||||
return [self.get_checkpoint_model(checkpoint_id, path)]
|
||||
return [self.checkpoint_model(checkpoint_id, os_path)]
|
||||
|
||||
# Checkpoint-related utilities
|
||||
def get_checkpoint_path(self, checkpoint_id, path):
|
||||
def checkpoint_path(self, checkpoint_id, path):
|
||||
"""find the path to a checkpoint"""
|
||||
path = path.strip('/')
|
||||
parent, name = ('/' + path).rsplit('/', 1)
|
||||
@ -252,11 +320,9 @@ class FileCheckpointManager(FileManagerMixin, CheckpointManager):
|
||||
cp_path = os.path.join(cp_dir, filename)
|
||||
return cp_path
|
||||
|
||||
def get_checkpoint_model(self, checkpoint_id, path):
|
||||
def checkpoint_model(self, checkpoint_id, os_path):
|
||||
"""construct the info dict for a given checkpoint"""
|
||||
path = path.strip('/')
|
||||
cp_path = self.get_checkpoint_path(checkpoint_id, path)
|
||||
stats = os.stat(cp_path)
|
||||
stats = os.stat(os_path)
|
||||
last_modified = tz.utcfromtimestamp(stats.st_mtime)
|
||||
info = dict(
|
||||
id=checkpoint_id,
|
||||
@ -499,29 +565,17 @@ class FileContentsManager(FileManagerMixin, ContentsManager):
|
||||
os_path = self._get_os_path(path)
|
||||
|
||||
if content:
|
||||
if not os.path.isfile(os_path):
|
||||
# could be FIFO
|
||||
raise web.HTTPError(400, "Cannot get content of non-file %s" % os_path)
|
||||
with self.open(os_path, 'rb') as f:
|
||||
bcontent = f.read()
|
||||
content, format = self._read_file(os_path, format)
|
||||
default_mime = {
|
||||
'text': 'text/plain',
|
||||
'base64': 'application/octet-stream'
|
||||
}[format]
|
||||
|
||||
if format != 'base64':
|
||||
try:
|
||||
model['content'] = bcontent.decode('utf8')
|
||||
except UnicodeError as e:
|
||||
if format == 'text':
|
||||
raise web.HTTPError(400, "%s is not UTF-8 encoded" % path, reason='bad format')
|
||||
else:
|
||||
model['format'] = 'text'
|
||||
default_mime = 'text/plain'
|
||||
|
||||
if model['content'] is None:
|
||||
model['content'] = base64.encodestring(bcontent).decode('ascii')
|
||||
model['format'] = 'base64'
|
||||
if model['format'] == 'base64':
|
||||
default_mime = 'application/octet-stream'
|
||||
|
||||
model['mimetype'] = mimetypes.guess_type(os_path)[0] or default_mime
|
||||
model.update(
|
||||
content=content,
|
||||
format=format,
|
||||
mimetype=mimetypes.guess_type(os_path)[0] or default_mime,
|
||||
)
|
||||
|
||||
return model
|
||||
|
||||
@ -584,35 +638,6 @@ class FileContentsManager(FileManagerMixin, ContentsManager):
|
||||
model = self._file_model(path, content=content, format=format)
|
||||
return model
|
||||
|
||||
def _save_notebook(self, os_path, model, path):
|
||||
"""save a notebook file"""
|
||||
nb = nbformat.from_dict(model['content'])
|
||||
self.check_and_sign(nb, path)
|
||||
|
||||
# One checkpoint should always exist for notebooks.
|
||||
if not self.checkpoint_manager.list_checkpoints(path):
|
||||
self.checkpoint_manager.create_checkpoint(nb, path)
|
||||
|
||||
with self.atomic_writing(os_path, encoding='utf-8') as f:
|
||||
nbformat.write(nb, f, version=nbformat.NO_CONVERT)
|
||||
|
||||
def _save_file(self, os_path, model, path=''):
|
||||
"""save a non-notebook file"""
|
||||
fmt = model.get('format', None)
|
||||
if fmt not in {'text', 'base64'}:
|
||||
raise web.HTTPError(400, "Must specify format of file contents as 'text' or 'base64'")
|
||||
try:
|
||||
content = model['content']
|
||||
if fmt == 'text':
|
||||
bcontent = content.encode('utf8')
|
||||
else:
|
||||
b64_bytes = content.encode('ascii')
|
||||
bcontent = base64.decodestring(b64_bytes)
|
||||
except Exception as e:
|
||||
raise web.HTTPError(400, u'Encoding error saving %s: %s' % (os_path, e))
|
||||
with self.atomic_writing(os_path, text=False) as f:
|
||||
f.write(bcontent)
|
||||
|
||||
def _save_directory(self, os_path, model, path=''):
|
||||
"""create a directory"""
|
||||
if is_hidden(os_path, self.root_dir):
|
||||
@ -640,9 +665,18 @@ class FileContentsManager(FileManagerMixin, ContentsManager):
|
||||
self.log.debug("Saving %s", os_path)
|
||||
try:
|
||||
if model['type'] == 'notebook':
|
||||
self._save_notebook(os_path, model, path)
|
||||
nb = nbformat.from_dict(model['content'])
|
||||
self.check_and_sign(nb, path)
|
||||
self._save_notebook(os_path, nb)
|
||||
# One checkpoint should always exist for notebooks.
|
||||
if not self.checkpoint_manager.list_checkpoints(path):
|
||||
self.checkpoint_manager.create_notebook_checkpoint(
|
||||
nb,
|
||||
path,
|
||||
)
|
||||
elif model['type'] == 'file':
|
||||
self._save_file(os_path, model, path)
|
||||
# Missing format will be handled internally by _save_file.
|
||||
self._save_file(os_path, model['content'], model.get('format'))
|
||||
elif model['type'] == 'directory':
|
||||
self._save_directory(os_path, model, path)
|
||||
else:
|
||||
|
@ -34,15 +34,21 @@ class CheckpointManager(LoggingConfigurable):
|
||||
"""
|
||||
Base class for managing checkpoints for a ContentsManager.
|
||||
"""
|
||||
|
||||
def create_checkpoint(self, nb, path):
|
||||
def create_file_checkpoint(self, content, format, path):
|
||||
"""Create a checkpoint of the current state of a file
|
||||
|
||||
Returns a checkpoint_id for the new checkpoint.
|
||||
Returns a checkpoint model for the new checkpoint.
|
||||
"""
|
||||
raise NotImplementedError("must be implemented in a subclass")
|
||||
|
||||
def get_checkpoint_content(self, checkpoint_id, path):
|
||||
def create_notebook_checkpoint(self, nb, path):
|
||||
"""Create a checkpoint of the current state of a file
|
||||
|
||||
Returns a checkpoint model for the new checkpoint.
|
||||
"""
|
||||
raise NotImplementedError("must be implemented in a subclass")
|
||||
|
||||
def get_checkpoint(self, checkpoint_id, path, type):
|
||||
"""Get the content of a checkpoint.
|
||||
|
||||
Returns an unvalidated model with the same structure as
|
||||
@ -496,9 +502,19 @@ class ContentsManager(LoggingConfigurable):
|
||||
# Part 3: Checkpoints API
|
||||
def create_checkpoint(self, path):
|
||||
"""Create a checkpoint."""
|
||||
|
||||
nb = nbformat.from_dict(self.get(path, content=True)['content'])
|
||||
return self.checkpoint_manager.create_checkpoint(nb, path)
|
||||
model = self.get(path, content=True)
|
||||
type = model['type']
|
||||
if type == 'notebook':
|
||||
return self.checkpoint_manager.create_notebook_checkpoint(
|
||||
model['content'],
|
||||
path,
|
||||
)
|
||||
elif type == 'file':
|
||||
return self.checkpoint_manager.create_file_checkpoint(
|
||||
model['content'],
|
||||
model['format'],
|
||||
path,
|
||||
)
|
||||
|
||||
def list_checkpoints(self, path):
|
||||
return self.checkpoint_manager.list_checkpoints(path)
|
||||
@ -507,17 +523,18 @@ class ContentsManager(LoggingConfigurable):
|
||||
"""
|
||||
Restore a checkpoint.
|
||||
"""
|
||||
nb = self.checkpoint_manager.get_checkpoint_content(
|
||||
type = self.get(path, content=False)['type']
|
||||
content, format = self.checkpoint_manager.get_checkpoint(
|
||||
checkpoint_id,
|
||||
path,
|
||||
type,
|
||||
)
|
||||
|
||||
model = {
|
||||
'content': nb,
|
||||
'type': 'notebook',
|
||||
'type': type,
|
||||
'content': content,
|
||||
'format': format,
|
||||
}
|
||||
|
||||
self.validate_notebook_model(model)
|
||||
return self.save(model, path)
|
||||
|
||||
def delete_checkpoint(self, checkpoint_id, path):
|
||||
|
@ -542,6 +542,49 @@ class APITest(NotebookTestBase):
|
||||
cps = self.api.get_checkpoints('foo/a.ipynb').json()
|
||||
self.assertEqual(cps, [])
|
||||
|
||||
def test_file_checkpoints(self):
    """
    Exercise the checkpoint life cycle for a plain (non-notebook) file:
    create a checkpoint, modify the file, list, restore, then delete.
    """
    filename = 'foo/a.txt'
    api = self.api
    orig_content = json.loads(api.read(filename).text)['content']

    # Creating a checkpoint answers 201 with an id/last_modified model
    # and a Location header whose last segment is the checkpoint id.
    created = api.new_checkpoint(filename)
    self.assertEqual(created.status_code, 201)
    cp1 = created.json()
    self.assertEqual(set(cp1), {'id', 'last_modified'})
    location_tail = created.headers['Location'].split('/')[-1]
    self.assertEqual(location_tail, cp1['id'])

    # Overwrite the file with extended content.
    new_content = orig_content + '\nsecond line'
    model = {
        'content': new_content,
        'type': 'file',
        'format': 'text',
    }
    api.save(filename, body=json.dumps(model))

    # Saving must leave the existing checkpoint listed unchanged ...
    listed = api.get_checkpoints(filename).json()
    self.assertEqual(listed, [cp1])

    # ... while the file itself now holds the new content.
    current = api.read(filename).json()['content']
    self.assertEqual(current, new_content)

    # Restoring cp1 brings back the original content.
    restored = api.restore_checkpoint(filename, cp1['id'])
    self.assertEqual(restored.status_code, 204)
    restored_content = api.read(filename).json()['content']
    self.assertEqual(restored_content, orig_content)

    # Deleting cp1 leaves no checkpoints behind.
    deleted = api.delete_checkpoint(filename, cp1['id'])
    self.assertEqual(deleted.status_code, 204)
    remaining = api.get_checkpoints(filename).json()
    self.assertEqual(remaining, [])
|
||||
|
||||
@contextmanager
|
||||
def patch_cp_root(self, dirname):
|
||||
"""
|
||||
@ -561,8 +604,13 @@ class APITest(NotebookTestBase):
|
||||
using a different root dir from FileContentsManager. This also keeps
|
||||
the implementation honest for use with ContentsManagers that don't map
|
||||
models to the filesystem
|
||||
"""
|
||||
|
||||
Override this method to a no-op when testing other managers.
|
||||
"""
|
||||
with TemporaryDirectory() as td:
|
||||
with self.patch_cp_root(td):
|
||||
self.test_checkpoints()
|
||||
|
||||
with TemporaryDirectory() as td:
|
||||
with self.patch_cp_root(td):
|
||||
self.test_file_checkpoints()
|
||||
|
@ -85,10 +85,10 @@ class TestFileContentsManager(TestCase):
|
||||
os.mkdir(os.path.join(td, subd))
|
||||
fm = FileContentsManager(root_dir=root)
|
||||
cpm = fm.checkpoint_manager
|
||||
cp_dir = cpm.get_checkpoint_path(
|
||||
cp_dir = cpm.checkpoint_path(
|
||||
'cp', 'test.ipynb'
|
||||
)
|
||||
cp_subdir = cpm.get_checkpoint_path(
|
||||
cp_subdir = cpm.checkpoint_path(
|
||||
'cp', '/%s/test.ipynb' % subd
|
||||
)
|
||||
self.assertNotEqual(cp_dir, cp_subdir)
|
||||
|
Loading…
x
Reference in New Issue
Block a user