Merge pull request #6412 from takluyver/sessions-rest-api-fix

Sessions rest api fix
This commit is contained in:
Matthias Bussonnier 2014-09-06 15:04:13 -07:00
commit 97e99d0661
3 changed files with 64 additions and 8 deletions

View File

@ -53,7 +53,7 @@ class SessionManager(LoggingConfigurable):
"""Start a database connection"""
if self._connection is None:
self._connection = sqlite3.connect(':memory:')
self._connection.row_factory = self.row_factory
self._connection.row_factory = sqlite3.Row
return self._connection
def __del__(self):
@ -141,14 +141,20 @@ class SessionManager(LoggingConfigurable):
query = "SELECT * FROM session WHERE %s" % (' AND '.join(conditions))
self.cursor.execute(query, list(kwargs.values()))
model = self.cursor.fetchone()
if model is None:
try:
row = self.cursor.fetchone()
except KeyError:
# The kernel is missing, so the session just got deleted.
row = None
if row is None:
q = []
for key, value in kwargs.items():
q.append("%s=%r" % (key, value))
raise web.HTTPError(404, u'Session not found: %s' % (', '.join(q)))
return model
return self.row_to_model(row)
def update_session(self, session_id, **kwargs):
"""Updates the values in the session database.
@ -179,9 +185,16 @@ class SessionManager(LoggingConfigurable):
query = "UPDATE session SET %s WHERE session_id=?" % (', '.join(sets))
self.cursor.execute(query, list(kwargs.values()) + [session_id])
def row_factory(self, cursor, row):
def row_to_model(self, row):
"""Takes sqlite database session row and turns it into a dictionary"""
row = sqlite3.Row(cursor, row)
if row['kernel_id'] not in self.kernel_manager:
# The kernel was killed or died without deleting the session.
# We can't use delete_session here because that tries to find
# and shut down the kernel.
self.cursor.execute("DELETE FROM session WHERE session_id=?",
(row['session_id'],))
raise KeyError
model = {
'id': row['session_id'],
'notebook': {
@ -196,7 +209,15 @@ class SessionManager(LoggingConfigurable):
"""Returns a list of dictionaries containing all the information from
the session database"""
c = self.cursor.execute("SELECT * FROM session")
return list(c.fetchall())
result = []
# We need to use fetchall() here, because row_to_model can delete rows,
# which messes up the cursor if we're iterating over rows.
for row in c.fetchall():
try:
result.append(self.row_to_model(row))
except KeyError:
pass
return result
def delete_session(self, session_id):
"""Deletes the row in the session database with given session_id"""

View File

@ -47,6 +47,17 @@ class TestSessionManager(TestCase):
kernel_name='foo')['id']
self.assertRaises(TypeError, sm.get_session, bad_id=session_id) # Bad keyword
def test_get_session_dead_kernel(self):
sm = SessionManager(kernel_manager=DummyMKM())
session = sm.create_session(name='test1.ipynb', path='/path/to/1/', kernel_name='python')
# kill the kernel
sm.kernel_manager.shutdown_kernel(session['kernel']['id'])
with self.assertRaises(KeyError):
sm.get_session(session_id=session['id'])
# no sessions left
listed = sm.list_sessions()
self.assertEqual(listed, [])
def test_list_sessions(self):
sm = SessionManager(kernel_manager=DummyMKM())
sessions = [
@ -63,6 +74,30 @@ class TestSessionManager(TestCase):
'path': u'/path/to/3/'}, 'kernel':{'id':u'C', 'name':'python'}}]
self.assertEqual(sessions, expected)
def test_list_sessions_dead_kernel(self):
sm = SessionManager(kernel_manager=DummyMKM())
sessions = [
sm.create_session(name='test1.ipynb', path='/path/to/1/', kernel_name='python'),
sm.create_session(name='test2.ipynb', path='/path/to/2/', kernel_name='python'),
]
# kill one of the kernels
sm.kernel_manager.shutdown_kernel(sessions[0]['kernel']['id'])
listed = sm.list_sessions()
expected = [
{
'id': sessions[1]['id'],
'notebook': {
'name': u'test2.ipynb',
'path': u'/path/to/2/',
},
'kernel': {
'id': u'B',
'name':'python',
}
}
]
self.assertEqual(listed, expected)
def test_update_session(self):
sm = SessionManager(kernel_manager=DummyMKM())
session_id = sm.create_session(name='test.ipynb', path='/path/to/',

View File

@ -95,7 +95,7 @@ def assert_http_error(status, msg=None):
except requests.HTTPError as e:
real_status = e.response.status_code
assert real_status == status, \
"Expected status %d, got %d" % (real_status, status)
"Expected status %d, got %d" % (status, real_status)
if msg:
assert msg in str(e), e
else: