move common websocket methods to WebSocketMixin

- origin check
- ws ping

used by both kernels and terminals
This commit is contained in:
Min RK 2015-10-12 14:46:07 +02:00
parent afdbf3942c
commit c2c39a7c9d
2 changed files with 84 additions and 85 deletions

View File

@ -95,21 +95,31 @@ if os.environ.get('IPYTHON_ALLOW_DRAFT_WEBSOCKETS_FOR_PHANTOMJS', False):
# draft 76 doesn't support ping
WS_PING_INTERVAL = 0
class ZMQStreamHandler(WebSocketHandler):
if tornado.version_info < (4,1):
"""Backport send_error from tornado 4.1 to 4.0"""
def send_error(self, *args, **kwargs):
if self.stream is None:
super(WebSocketHandler, self).send_error(*args, **kwargs)
else:
# If we get an uncaught exception during the handshake,
# we have no choice but to abruptly close the connection.
# TODO: for uncaught exceptions after the handshake,
# we can close the connection more gracefully.
self.stream.close()
class WebSocketMixin(object):
"""Mixin for common websocket options"""
ping_callback = None
last_ping = 0
last_pong = 0
@property
def ping_interval(self):
"""The interval for websocket keep-alive pings.
Set ws_ping_interval = 0 to disable pings.
"""
return self.settings.get('ws_ping_interval', WS_PING_INTERVAL)
@property
def ping_timeout(self):
"""If no ping is received in this many milliseconds,
close the websocket connection (VPNs, etc. can fail to cleanly close ws connections).
Default is max of 3 pings or 30 seconds.
"""
return self.settings.get('ws_ping_timeout',
max(3 * self.ping_interval, WS_PING_INTERVAL)
)
def check_origin(self, origin):
"""Check Origin == Host or Access-Control-Allow-Origin.
@ -153,6 +163,58 @@ class ZMQStreamHandler(WebSocketHandler):
"""meaningless for websockets"""
pass
def open(self, *args, **kwargs):
self.log.debug("Opening websocket %s", self.request.path)
# start the pinging
if self.ping_interval > 0:
loop = ioloop.IOLoop.current()
self.last_ping = loop.time() # Remember time of last ping
self.last_pong = self.last_ping
self.ping_callback = ioloop.PeriodicCallback(
self.send_ping, self.ping_interval, io_loop=loop,
)
self.ping_callback.start()
return super(WebSocketMixin, self).open(*args, **kwargs)
def send_ping(self):
"""send a ping to keep the websocket alive"""
if self.stream.closed() and self.ping_callback is not None:
self.ping_callback.stop()
return
# check for timeout on pong. Make sure that we really have sent a recent ping in
# case the machine with both server and client has been suspended since the last ping.
now = ioloop.IOLoop.current().time()
since_last_pong = 1e3 * (now - self.last_pong)
since_last_ping = 1e3 * (now - self.last_ping)
if since_last_ping < 2*self.ping_interval and since_last_pong > self.ping_timeout:
self.log.warn("WebSocket ping timeout after %i ms.", since_last_pong)
self.close()
return
self.ping(b'')
self.last_ping = now
def on_pong(self, data):
self.last_pong = ioloop.IOLoop.current().time()
class ZMQStreamHandler(WebSocketMixin, WebSocketHandler):
if tornado.version_info < (4,1):
"""Backport send_error from tornado 4.1 to 4.0"""
def send_error(self, *args, **kwargs):
if self.stream is None:
super(WebSocketHandler, self).send_error(*args, **kwargs)
else:
# If we get an uncaught exception during the handshake,
# we have no choice but to abruptly close the connection.
# TODO: for uncaught exceptions after the handshake,
# we can close the connection more gracefully.
self.stream.close()
def _reserialize_reply(self, msg_list, channel=None):
"""Reserialize a reply message using JSON.
@ -187,29 +249,9 @@ class ZMQStreamHandler(WebSocketHandler):
else:
self.write_message(msg, binary=isinstance(msg, bytes))
class AuthenticatedZMQStreamHandler(ZMQStreamHandler, IPythonHandler):
ping_callback = None
last_ping = 0
last_pong = 0
@property
def ping_interval(self):
"""The interval for websocket keep-alive pings.
Set ws_ping_interval = 0 to disable pings.
"""
return self.settings.get('ws_ping_interval', WS_PING_INTERVAL)
@property
def ping_timeout(self):
"""If no ping is received in this many milliseconds,
close the websocket connection (VPNs, etc. can fail to cleanly close ws connections).
Default is max of 3 pings or 30 seconds.
"""
return self.settings.get('ws_ping_timeout',
max(3 * self.ping_interval, WS_PING_INTERVAL)
)
class AuthenticatedZMQStreamHandler(ZMQStreamHandler, IPythonHandler):
def set_default_headers(self):
"""Undo the set_default_headers in IPythonHandler
@ -245,37 +287,3 @@ class AuthenticatedZMQStreamHandler(ZMQStreamHandler, IPythonHandler):
self.log.debug("Initializing websocket connection %s", self.request.path)
self.session = Session(config=self.config)
def open(self, *args, **kwargs):
self.log.debug("Opening websocket %s", self.request.path)
# start the pinging
if self.ping_interval > 0:
loop = ioloop.IOLoop.current()
self.last_ping = loop.time() # Remember time of last ping
self.last_pong = self.last_ping
self.ping_callback = ioloop.PeriodicCallback(
self.send_ping, self.ping_interval, io_loop=loop,
)
self.ping_callback.start()
def send_ping(self):
"""send a ping to keep the websocket alive"""
if self.stream.closed() and self.ping_callback is not None:
self.ping_callback.stop()
return
# check for timeout on pong. Make sure that we really have sent a recent ping in
# case the machine with both server and client has been suspended since the last ping.
now = ioloop.IOLoop.current().time()
since_last_pong = 1e3 * (now - self.last_pong)
since_last_ping = 1e3 * (now - self.last_ping)
if since_last_ping < 2*self.ping_interval and since_last_pong > self.ping_timeout:
self.log.warn("WebSocket ping timeout after %i ms.", since_last_pong)
self.close()
return
self.ping(b'')
self.last_ping = now
def on_pong(self, data):
self.last_pong = ioloop.IOLoop.current().time()

View File

@ -4,15 +4,10 @@
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import tornado
from tornado import web
import terminado
from ..base.handlers import IPythonHandler
try:
from urllib.parse import urlparse # Py 3
except ImportError:
from urlparse import urlparse # Py 2
from ..base.zmqhandlers import WebSocketMixin
class TerminalHandler(IPythonHandler):
@ -22,22 +17,18 @@ class TerminalHandler(IPythonHandler):
self.write(self.render_template('terminal.html',
ws_path="terminals/websocket/%s" % term_name))
class TermSocket(IPythonHandler, terminado.TermSocket):
def set_default_headers(self):
pass
class TermSocket(WebSocketMixin, IPythonHandler, terminado.TermSocket):
def origin_check(self):
"""Override Terminado's origin_check with our own check_origin, confusingly"""
return self.check_origin()
"""Terminado adds redundant origin_check
Tornado already calls check_origin, so don't do anything here.
"""
return True
def get(self, *args, **kwargs):
if not self.get_current_user():
raise web.HTTPError(403)
return super(TermSocket, self).get(*args, **kwargs)
def clear_cookie(self, *args, **kwargs):
"""meaningless for websockets"""
pass