mirror of
https://github.com/jupyter/notebook.git
synced 2024-12-15 04:00:34 +08:00
move common websocket methods to WebSocketMixin
- origin check - ws ping used by both kernels and terminals
This commit is contained in:
parent
afdbf3942c
commit
c2c39a7c9d
@ -95,21 +95,31 @@ if os.environ.get('IPYTHON_ALLOW_DRAFT_WEBSOCKETS_FOR_PHANTOMJS', False):
|
||||
# draft 76 doesn't support ping
|
||||
WS_PING_INTERVAL = 0
|
||||
|
||||
class ZMQStreamHandler(WebSocketHandler):
|
||||
|
||||
if tornado.version_info < (4,1):
|
||||
"""Backport send_error from tornado 4.1 to 4.0"""
|
||||
def send_error(self, *args, **kwargs):
|
||||
if self.stream is None:
|
||||
super(WebSocketHandler, self).send_error(*args, **kwargs)
|
||||
else:
|
||||
# If we get an uncaught exception during the handshake,
|
||||
# we have no choice but to abruptly close the connection.
|
||||
# TODO: for uncaught exceptions after the handshake,
|
||||
# we can close the connection more gracefully.
|
||||
self.stream.close()
|
||||
|
||||
class WebSocketMixin(object):
|
||||
"""Mixin for common websocket options"""
|
||||
ping_callback = None
|
||||
last_ping = 0
|
||||
last_pong = 0
|
||||
|
||||
@property
|
||||
def ping_interval(self):
|
||||
"""The interval for websocket keep-alive pings.
|
||||
|
||||
Set ws_ping_interval = 0 to disable pings.
|
||||
"""
|
||||
return self.settings.get('ws_ping_interval', WS_PING_INTERVAL)
|
||||
|
||||
@property
|
||||
def ping_timeout(self):
|
||||
"""If no ping is received in this many milliseconds,
|
||||
close the websocket connection (VPNs, etc. can fail to cleanly close ws connections).
|
||||
Default is max of 3 pings or 30 seconds.
|
||||
"""
|
||||
return self.settings.get('ws_ping_timeout',
|
||||
max(3 * self.ping_interval, WS_PING_INTERVAL)
|
||||
)
|
||||
|
||||
def check_origin(self, origin):
|
||||
"""Check Origin == Host or Access-Control-Allow-Origin.
|
||||
|
||||
@ -153,6 +163,58 @@ class ZMQStreamHandler(WebSocketHandler):
|
||||
"""meaningless for websockets"""
|
||||
pass
|
||||
|
||||
def open(self, *args, **kwargs):
|
||||
self.log.debug("Opening websocket %s", self.request.path)
|
||||
|
||||
# start the pinging
|
||||
if self.ping_interval > 0:
|
||||
loop = ioloop.IOLoop.current()
|
||||
self.last_ping = loop.time() # Remember time of last ping
|
||||
self.last_pong = self.last_ping
|
||||
self.ping_callback = ioloop.PeriodicCallback(
|
||||
self.send_ping, self.ping_interval, io_loop=loop,
|
||||
)
|
||||
self.ping_callback.start()
|
||||
return super(WebSocketMixin, self).open(*args, **kwargs)
|
||||
|
||||
def send_ping(self):
|
||||
"""send a ping to keep the websocket alive"""
|
||||
if self.stream.closed() and self.ping_callback is not None:
|
||||
self.ping_callback.stop()
|
||||
return
|
||||
|
||||
# check for timeout on pong. Make sure that we really have sent a recent ping in
|
||||
# case the machine with both server and client has been suspended since the last ping.
|
||||
now = ioloop.IOLoop.current().time()
|
||||
since_last_pong = 1e3 * (now - self.last_pong)
|
||||
since_last_ping = 1e3 * (now - self.last_ping)
|
||||
if since_last_ping < 2*self.ping_interval and since_last_pong > self.ping_timeout:
|
||||
self.log.warn("WebSocket ping timeout after %i ms.", since_last_pong)
|
||||
self.close()
|
||||
return
|
||||
|
||||
self.ping(b'')
|
||||
self.last_ping = now
|
||||
|
||||
def on_pong(self, data):
|
||||
self.last_pong = ioloop.IOLoop.current().time()
|
||||
|
||||
|
||||
class ZMQStreamHandler(WebSocketMixin, WebSocketHandler):
|
||||
|
||||
if tornado.version_info < (4,1):
|
||||
"""Backport send_error from tornado 4.1 to 4.0"""
|
||||
def send_error(self, *args, **kwargs):
|
||||
if self.stream is None:
|
||||
super(WebSocketHandler, self).send_error(*args, **kwargs)
|
||||
else:
|
||||
# If we get an uncaught exception during the handshake,
|
||||
# we have no choice but to abruptly close the connection.
|
||||
# TODO: for uncaught exceptions after the handshake,
|
||||
# we can close the connection more gracefully.
|
||||
self.stream.close()
|
||||
|
||||
|
||||
def _reserialize_reply(self, msg_list, channel=None):
|
||||
"""Reserialize a reply message using JSON.
|
||||
|
||||
@ -187,29 +249,9 @@ class ZMQStreamHandler(WebSocketHandler):
|
||||
else:
|
||||
self.write_message(msg, binary=isinstance(msg, bytes))
|
||||
|
||||
class AuthenticatedZMQStreamHandler(ZMQStreamHandler, IPythonHandler):
|
||||
ping_callback = None
|
||||
last_ping = 0
|
||||
last_pong = 0
|
||||
|
||||
@property
|
||||
def ping_interval(self):
|
||||
"""The interval for websocket keep-alive pings.
|
||||
|
||||
Set ws_ping_interval = 0 to disable pings.
|
||||
"""
|
||||
return self.settings.get('ws_ping_interval', WS_PING_INTERVAL)
|
||||
|
||||
@property
|
||||
def ping_timeout(self):
|
||||
"""If no ping is received in this many milliseconds,
|
||||
close the websocket connection (VPNs, etc. can fail to cleanly close ws connections).
|
||||
Default is max of 3 pings or 30 seconds.
|
||||
"""
|
||||
return self.settings.get('ws_ping_timeout',
|
||||
max(3 * self.ping_interval, WS_PING_INTERVAL)
|
||||
)
|
||||
|
||||
class AuthenticatedZMQStreamHandler(ZMQStreamHandler, IPythonHandler):
|
||||
|
||||
def set_default_headers(self):
|
||||
"""Undo the set_default_headers in IPythonHandler
|
||||
|
||||
@ -245,37 +287,3 @@ class AuthenticatedZMQStreamHandler(ZMQStreamHandler, IPythonHandler):
|
||||
self.log.debug("Initializing websocket connection %s", self.request.path)
|
||||
self.session = Session(config=self.config)
|
||||
|
||||
def open(self, *args, **kwargs):
|
||||
self.log.debug("Opening websocket %s", self.request.path)
|
||||
|
||||
# start the pinging
|
||||
if self.ping_interval > 0:
|
||||
loop = ioloop.IOLoop.current()
|
||||
self.last_ping = loop.time() # Remember time of last ping
|
||||
self.last_pong = self.last_ping
|
||||
self.ping_callback = ioloop.PeriodicCallback(
|
||||
self.send_ping, self.ping_interval, io_loop=loop,
|
||||
)
|
||||
self.ping_callback.start()
|
||||
|
||||
def send_ping(self):
|
||||
"""send a ping to keep the websocket alive"""
|
||||
if self.stream.closed() and self.ping_callback is not None:
|
||||
self.ping_callback.stop()
|
||||
return
|
||||
|
||||
# check for timeout on pong. Make sure that we really have sent a recent ping in
|
||||
# case the machine with both server and client has been suspended since the last ping.
|
||||
now = ioloop.IOLoop.current().time()
|
||||
since_last_pong = 1e3 * (now - self.last_pong)
|
||||
since_last_ping = 1e3 * (now - self.last_ping)
|
||||
if since_last_ping < 2*self.ping_interval and since_last_pong > self.ping_timeout:
|
||||
self.log.warn("WebSocket ping timeout after %i ms.", since_last_pong)
|
||||
self.close()
|
||||
return
|
||||
|
||||
self.ping(b'')
|
||||
self.last_ping = now
|
||||
|
||||
def on_pong(self, data):
|
||||
self.last_pong = ioloop.IOLoop.current().time()
|
||||
|
@ -4,15 +4,10 @@
|
||||
# Copyright (c) Jupyter Development Team.
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
import tornado
|
||||
from tornado import web
|
||||
import terminado
|
||||
from ..base.handlers import IPythonHandler
|
||||
|
||||
try:
|
||||
from urllib.parse import urlparse # Py 3
|
||||
except ImportError:
|
||||
from urlparse import urlparse # Py 2
|
||||
from ..base.zmqhandlers import WebSocketMixin
|
||||
|
||||
|
||||
class TerminalHandler(IPythonHandler):
|
||||
@ -22,22 +17,18 @@ class TerminalHandler(IPythonHandler):
|
||||
self.write(self.render_template('terminal.html',
|
||||
ws_path="terminals/websocket/%s" % term_name))
|
||||
|
||||
class TermSocket(IPythonHandler, terminado.TermSocket):
|
||||
|
||||
def set_default_headers(self):
|
||||
pass
|
||||
class TermSocket(WebSocketMixin, IPythonHandler, terminado.TermSocket):
|
||||
|
||||
def origin_check(self):
|
||||
"""Override Terminado's origin_check with our own check_origin, confusingly"""
|
||||
return self.check_origin()
|
||||
"""Terminado adds redundant origin_check
|
||||
|
||||
Tornado already calls check_origin, so don't do anything here.
|
||||
"""
|
||||
return True
|
||||
|
||||
|
||||
def get(self, *args, **kwargs):
|
||||
if not self.get_current_user():
|
||||
raise web.HTTPError(403)
|
||||
return super(TermSocket, self).get(*args, **kwargs)
|
||||
|
||||
def clear_cookie(self, *args, **kwargs):
|
||||
"""meaningless for websockets"""
|
||||
pass
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user