mirror of
https://github.com/jupyter/notebook.git
synced 2024-12-27 04:20:22 +08:00
abee155a11
- Draft76 was added to support QtWebKit 4 (and hence PhantomJS) a little bit longer, but before this change kernel websockets failed to start on that platform.
312 lines
12 KiB
Python
312 lines
12 KiB
Python
"""WebsocketProtocol76 from tornado 3.2.2 for tornado >= 4.0
|
|
|
|
The contents of this file are Copyright (c) Tornado
|
|
Used under the Apache 2.0 license
|
|
"""
|
|
|
|
|
|
from __future__ import absolute_import, division, print_function, with_statement
|
|
# Author: Jacob Kristhammar, 2010
|
|
|
|
import functools
|
|
import hashlib
|
|
import struct
|
|
import time
|
|
import tornado.escape
|
|
import tornado.web
|
|
|
|
from tornado.log import gen_log, app_log
|
|
from tornado.util import bytes_type, unicode_type
|
|
|
|
from tornado.websocket import WebSocketHandler, WebSocketProtocol13
|
|
|
|
class AllowDraftWebSocketHandler(WebSocketHandler):
    """Restore Draft76 support for tornado 4

    Overrides ``get`` so that a client sending no ``Sec-WebSocket-Version``
    header is served by ``WebSocketProtocol76`` when ``allow_draft76()`` is
    true, and re-adds the tornado 3.2 helper methods that protocol needs.

    Remove when we can run tests without phantomjs + qt4
    """

    # get mirrors tornado 4.0's WebSocketHandler.get. Protocol selection is
    # changed only between the BEGIN/END PATCH lines; the remaining local
    # deviations are marked with "Local change" comments.
    @tornado.web.asynchronous
    def get(self, *args, **kwargs):
        """Validate the upgrade request and hand the stream to a protocol.

        Responds 400 when the Upgrade/Connection headers do not ask for a
        WebSocket upgrade, and 403 when ``check_origin`` rejects the Origin.
        Otherwise the stream is detached from the HTTP connection and driven
        by ``WebSocketProtocol13`` (versions 7, 8 and 13) or, when
        ``allow_draft76()`` is true and no version header was sent, by
        ``WebSocketProtocol76``. Any other request receives a raw 426 reply.
        """
        self.open_args = args
        self.open_kwargs = kwargs

        # Upgrade header should be present and should be equal to WebSocket
        if self.request.headers.get("Upgrade", "").lower() != 'websocket':
            self.set_status(400)
            self.finish("Can \"Upgrade\" only to \"WebSocket\".")
            return

        # Connection header should be upgrade. Some proxy servers/load balancers
        # might mess with it.
        headers = self.request.headers
        # Local change: a list instead of map(), which is a one-shot iterator
        # on Python 3.
        connection = [token.strip().lower()
                      for token in headers.get("Connection", "").split(",")]
        if 'upgrade' not in connection:
            self.set_status(400)
            self.finish("\"Connection\" must be \"Upgrade\".")
            return

        # Handle WebSocket Origin naming convention differences
        # The difference between version 8 and 13 is that in 8 the
        # client sends a "Sec-Websocket-Origin" header and in 13 it's
        # simply "Origin".
        if "Origin" in self.request.headers:
            origin = self.request.headers.get("Origin")
        else:
            origin = self.request.headers.get("Sec-Websocket-Origin", None)

        # If there was an origin header, check to make sure it matches
        # according to check_origin. When the origin is None, we assume it
        # did not come from a browser and that it can be passed on.
        if origin is not None and not self.check_origin(origin):
            self.set_status(403)
            self.finish("Cross origin websockets not allowed")
            return

        # From here on the raw stream is used directly, outside of the HTTP
        # request/response machinery.
        self.stream = self.request.connection.detach()
        self.stream.set_close_callback(self.on_connection_close)

        if self.request.headers.get("Sec-WebSocket-Version") in ("7", "8", "13"):
            self.ws_connection = WebSocketProtocol13(self)
            self.ws_connection.accept_connection()
        #--------------- BEGIN PATCH ----------------
        elif (self.allow_draft76() and
              "Sec-WebSocket-Version" not in self.request.headers):
            self.ws_connection = WebSocketProtocol76(self)
            self.ws_connection.accept_connection()
        #--------------- END PATCH ----------------
        else:
            if not self.stream.closed():
                # Local change: advertise every version accepted above
                # (RFC 6455 section 4.4); tornado 4.0 listed only "8".
                self.stream.write(tornado.escape.utf8(
                    "HTTP/1.1 426 Upgrade Required\r\n"
                    "Sec-WebSocket-Version: 7, 8, 13\r\n\r\n"))
                self.stream.close()

    # 3.2 methods removed in 4.0:
    def allow_draft76(self):
        """Using this class allows draft76 connections by default"""
        return True

    def get_websocket_scheme(self):
        """Return the url scheme used for this request, either "ws" or "wss".

        This is normally decided by HTTPServer, but applications
        may wish to override this if they are using an SSL proxy
        that does not provide the X-Scheme header as understood
        by HTTPServer.

        Note that this is only used by the draft76 protocol.
        """
        return "wss" if self.request.protocol == "https" else "ws"
|
|
|
|
|
|
|
|
# Code below this line is copied from tornado-3.2.2; any local deviations are marked inline.
|
|
|
|
class WebSocketProtocol(object):
    """Shared state and helpers for every WebSocket protocol version.

    Subclasses must supply ``close()``, which ``_abort`` invokes after the
    socket has been torn down.
    """

    def __init__(self, handler):
        # Keep the handler plus quick references to its request and stream.
        self.handler, self.request, self.stream = (
            handler, handler.request, handler.stream)
        # Neither side has finished the connection yet.
        self.client_terminated = self.server_terminated = False

    def async_callback(self, callback, *args, **kwargs):
        """Return a wrapper around *callback* that aborts on uncaught errors.

        Any extra positional/keyword arguments are bound to *callback* up
        front. If the wrapped call raises an ``Exception``, it is logged
        through ``app_log`` together with the request path. The connection
        is then aborted and the wrapper returns ``None``.
        """
        target = (functools.partial(callback, *args, **kwargs)
                  if (args or kwargs) else callback)

        def guarded(*call_args, **call_kwargs):
            try:
                return target(*call_args, **call_kwargs)
            except Exception:
                app_log.error("Uncaught exception in %s",
                              self.request.path, exc_info=True)
                self._abort()
        return guarded

    def on_connection_close(self):
        """React to the underlying stream closing by aborting."""
        self._abort()

    def _abort(self):
        """Instantly aborts the WebSocket connection by closing the socket"""
        # Both flags are set before close() runs, so a subclass close()
        # sees the connection as already terminated in both directions.
        self.client_terminated = self.server_terminated = True
        self.stream.close()  # forcibly tear down the connection
        self.close()  # let the subclass clean up
|
|
|
|
|
|
class WebSocketProtocol76(WebSocketProtocol):
    """Implementation of the WebSockets protocol, version hixie-76.

    This class provides basic functionality to process WebSockets requests as
    specified in
    http://tools.ietf.org/html/draft-hixie-thewebsocketprotocol-76

    Deviations from the tornado 3.2.2 original, needed under tornado >= 4.0
    and/or Python 3, are marked with "Local change" comments.
    """
    def __init__(self, handler):
        WebSocketProtocol.__init__(self, handler)
        self.challenge = None
        # Handle returned by io_loop.add_timeout while close() waits for the
        # client's closing frame; None when no such timeout is pending.
        self._waiting = None

    def accept_connection(self):
        """Validate the handshake, send the 101 headers, then read the challenge.

        Aborts the connection when required handshake headers are missing.
        The 8 challenge bytes are read next and passed to
        ``_handle_challenge``.
        """
        try:
            self._handle_websocket_headers()
        except ValueError:
            gen_log.debug("Malformed WebSocket request received")
            self._abort()
            return

        scheme = self.handler.get_websocket_scheme()

        # draft76 only allows a single subprotocol
        subprotocol_header = ''
        subprotocol = self.request.headers.get("Sec-WebSocket-Protocol", None)
        if subprotocol:
            selected = self.handler.select_subprotocol([subprotocol])
            if selected:
                assert selected == subprotocol
                subprotocol_header = "Sec-WebSocket-Protocol: %s\r\n" % selected

        # Write the initial headers before attempting to read the challenge.
        # This is necessary when using proxies (such as HAProxy), which
        # need to see the Upgrade headers before passing through the
        # non-HTTP traffic that follows.
        self.stream.write(tornado.escape.utf8(
            "HTTP/1.1 101 WebSocket Protocol Handshake\r\n"
            "Upgrade: WebSocket\r\n"
            "Connection: Upgrade\r\n"
            "Server: TornadoServer/%(version)s\r\n"
            "Sec-WebSocket-Origin: %(origin)s\r\n"
            "Sec-WebSocket-Location: %(scheme)s://%(host)s%(uri)s\r\n"
            "%(subprotocol)s"
            "\r\n" % (dict(
                version=tornado.version,
                origin=self.request.headers["Origin"],
                scheme=scheme,
                host=self.request.host,
                uri=self.request.uri,
                subprotocol=subprotocol_header))))
        self.stream.read_bytes(8, self._handle_challenge)

    def challenge_response(self, challenge):
        """Generates the challenge response that's needed in the handshake

        The challenge parameter should be the raw bytes as sent from the
        client.

        Returns the 16-byte MD5 digest of the two packed key numbers
        followed by *challenge*. Raises ValueError if either
        Sec-Websocket-Key header is malformed.
        """
        key_1 = self.request.headers.get("Sec-Websocket-Key1")
        key_2 = self.request.headers.get("Sec-Websocket-Key2")
        try:
            part_1 = self._calculate_part(key_1)
            part_2 = self._calculate_part(key_2)
        except ValueError:
            raise ValueError("Invalid Keys/Challenge")
        return self._generate_challenge_response(part_1, part_2, challenge)

    def _handle_challenge(self, challenge):
        """Answer the client's 8-byte challenge, or abort if the keys are bad."""
        try:
            challenge_response = self.challenge_response(challenge)
        except ValueError:
            gen_log.debug("Malformed key data in WebSocket request")
            self._abort()
            return
        self._write_response(challenge_response)

    def _write_response(self, challenge):
        """Send the challenge answer, call handler.open, and start reading."""
        self.stream.write(challenge)
        self.async_callback(self.handler.open)(*self.handler.open_args, **self.handler.open_kwargs)
        self._receive_message()

    def _handle_websocket_headers(self):
        """Verifies all invariant- and required headers

        If a header is missing or has an empty value ValueError will be
        raised
        """
        fields = ("Origin", "Host", "Sec-Websocket-Key1",
                  "Sec-Websocket-Key2")
        if not all(self.request.headers.get(field) for field in fields):
            raise ValueError("Missing/Invalid WebSocket headers")

    def _calculate_part(self, key):
        """Processes the key headers and calculates their key value.

        The digits of *key* form a number that is divided by the count of
        whitespace characters in *key*. The quotient is returned as 4
        big-endian bytes.

        Raises ValueError when fed an invalid key."""
        # pyflakes complains about variable reuse if both of these lines use 'c'
        # int('') raises ValueError when the key contains no digits at all.
        number = int(''.join(c for c in key if c.isdigit()))
        spaces = len([c2 for c2 in key if c2.isspace()])
        try:
            key_number = number // spaces
        except (ValueError, ZeroDivisionError):
            raise ValueError
        # Local change: a quotient too large for 32 bits used to escape as
        # struct.error; report it as ValueError like every other bad key.
        try:
            return struct.pack(">I", key_number)
        except struct.error:
            raise ValueError

    def _generate_challenge_response(self, part_1, part_2, part_3):
        """Return the MD5 digest of the three parts concatenated in order."""
        m = hashlib.md5()
        m.update(part_1)
        m.update(part_2)
        m.update(part_3)
        return m.digest()

    def _receive_message(self):
        """Read the next frame-type byte."""
        self.stream.read_bytes(1, self._on_frame_type)

    def _on_frame_type(self, byte):
        """Dispatch on the frame-type byte.

        0x00 starts a text frame that runs up to a 0xff delimiter. 0xff must
        be followed by 0x00 (the closing handshake, see
        ``_on_length_indicator``). Any other value aborts the connection.
        """
        frame_type = ord(byte)
        if frame_type == 0x00:
            self.stream.read_until(b"\xff", self._on_end_delimiter)
        elif frame_type == 0xff:
            self.stream.read_bytes(1, self._on_length_indicator)
        else:
            self._abort()

    def _on_end_delimiter(self, frame):
        """Deliver a text frame (minus its trailing 0xff) and keep reading."""
        if not self.client_terminated:
            self.async_callback(self.handler.on_message)(
                frame[:-1].decode("utf-8", "replace"))
        if not self.client_terminated:
            self._receive_message()

    def _on_length_indicator(self, byte):
        """Treat 0xff 0x00 as the client's close; abort on anything else."""
        if ord(byte) != 0x00:
            self._abort()
            return
        self.client_terminated = True
        self.close()

    def write_message(self, message, binary=False):
        """Sends the given message to the client of this Web Socket.

        Unicode text is UTF-8 encoded and framed as 0x00 <payload> 0xff.
        Raises ValueError for binary messages, which hixie-76 cannot carry.
        """
        if binary:
            raise ValueError(
                "Binary messages not supported by this version of websockets")
        if isinstance(message, unicode_type):
            message = message.encode("utf-8")
        assert isinstance(message, bytes_type)
        self.stream.write(b"\x00" + message + b"\xff")

    def write_ping(self, data):
        """Send ping frame."""
        raise ValueError("Ping messages not supported by this version of websockets")

    def close(self, code=None, reason=None):
        """Closes the WebSocket connection.

        Sends the 0xff 0x00 closing frame once. If the client has already
        closed its side, the stream is closed immediately. Otherwise
        ``_abort`` is scheduled to run 5 seconds later unless the client
        answers first.

        Local change: ``code`` and ``reason`` are accepted and ignored.
        tornado >= 4.0's ``WebSocketHandler.close(code, reason)`` forwards
        both to this method, and the hixie-76 closing frame has no room for
        either.
        """
        if not self.server_terminated:
            if not self.stream.closed():
                # Local change: IOStream.write requires bytes; the original
                # str literal fails on Python 3.
                self.stream.write(b"\xff\x00")
            self.server_terminated = True
        if self.client_terminated:
            if self._waiting is not None:
                self.stream.io_loop.remove_timeout(self._waiting)
                self._waiting = None
            self.stream.close()
        elif self._waiting is None:
            # Local change: take the deadline from the IOLoop's own clock,
            # which is the scale add_timeout expects, rather than time.time().
            self._waiting = self.stream.io_loop.add_timeout(
                self.stream.io_loop.time() + 5, self._abort)
|
|
|