mirror of
https://github.com/jupyter/notebook.git
synced 2025-01-18 11:55:46 +08:00
commit
7af9f5f1f8
@ -152,6 +152,48 @@ class IPythonHandler(AuthenticatedHandler):
|
||||
def project_dir(self):
|
||||
return self.notebook_manager.notebook_dir
|
||||
|
||||
#---------------------------------------------------------------
|
||||
# CORS
|
||||
#---------------------------------------------------------------
|
||||
|
||||
@property
|
||||
def allow_origin(self):
|
||||
"""Normal Access-Control-Allow-Origin"""
|
||||
return self.settings.get('allow_origin', '')
|
||||
|
||||
@property
|
||||
def allow_origin_pat(self):
|
||||
"""Regular expression version of allow_origin"""
|
||||
return self.settings.get('allow_origin_pat', None)
|
||||
|
||||
@property
|
||||
def allow_credentials(self):
|
||||
"""Whether to set Access-Control-Allow-Credentials"""
|
||||
return self.settings.get('allow_credentials', False)
|
||||
|
||||
def set_default_headers(self):
|
||||
"""Add CORS headers, if defined"""
|
||||
super(IPythonHandler, self).set_default_headers()
|
||||
if self.allow_origin:
|
||||
self.set_header("Access-Control-Allow-Origin", self.allow_origin)
|
||||
elif self.allow_origin_pat:
|
||||
origin = self.get_origin()
|
||||
if origin and self.allow_origin_pat.match(origin):
|
||||
self.set_header("Access-Control-Allow-Origin", origin)
|
||||
if self.allow_credentials:
|
||||
self.set_header("Access-Control-Allow-Credentials", 'true')
|
||||
|
||||
def get_origin(self):
|
||||
# Handle WebSocket Origin naming convention differences
|
||||
# The difference between version 8 and 13 is that in 8 the
|
||||
# client sends a "Sec-Websocket-Origin" header and in 13 it's
|
||||
# simply "Origin".
|
||||
if "Origin" in self.request.headers:
|
||||
origin = self.request.headers.get("Origin")
|
||||
else:
|
||||
origin = self.request.headers.get("Sec-Websocket-Origin", None)
|
||||
return origin
|
||||
|
||||
#---------------------------------------------------------------
|
||||
# template rendering
|
||||
#---------------------------------------------------------------
|
||||
|
@ -15,6 +15,8 @@ try:
|
||||
except ImportError:
|
||||
from Cookie import SimpleCookie # Py 2
|
||||
import logging
|
||||
|
||||
import tornado
|
||||
from tornado import web
|
||||
from tornado import websocket
|
||||
|
||||
@ -26,29 +28,36 @@ from .handlers import IPythonHandler
|
||||
|
||||
|
||||
class ZMQStreamHandler(websocket.WebSocketHandler):
|
||||
|
||||
def same_origin(self):
|
||||
"""Check to see that origin and host match in the headers."""
|
||||
|
||||
# The difference between version 8 and 13 is that in 8 the
|
||||
# client sends a "Sec-Websocket-Origin" header and in 13 it's
|
||||
# simply "Origin".
|
||||
if self.request.headers.get("Sec-WebSocket-Version") in ("7", "8"):
|
||||
origin_header = self.request.headers.get("Sec-Websocket-Origin")
|
||||
else:
|
||||
origin_header = self.request.headers.get("Origin")
|
||||
|
||||
def check_origin(self, origin):
|
||||
"""Check Origin == Host or Access-Control-Allow-Origin.
|
||||
|
||||
Tornado >= 4 calls this method automatically, raising 403 if it returns False.
|
||||
We call it explicitly in `open` on Tornado < 4.
|
||||
"""
|
||||
if self.allow_origin == '*':
|
||||
return True
|
||||
|
||||
host = self.request.headers.get("Host")
|
||||
|
||||
# If no header is provided, assume we can't verify origin
|
||||
if(origin_header is None or host is None):
|
||||
if(origin is None or host is None):
|
||||
return False
|
||||
|
||||
host_origin = "{0}://{1}".format(self.request.protocol, host)
|
||||
|
||||
# OK if origin matches host
|
||||
if origin == host_origin:
|
||||
return True
|
||||
|
||||
# Check CORS headers
|
||||
if self.allow_origin:
|
||||
return self.allow_origin == origin
|
||||
elif self.allow_origin_pat:
|
||||
return bool(self.allow_origin_pat.match(origin))
|
||||
else:
|
||||
# No CORS headers deny the request
|
||||
return False
|
||||
|
||||
parsed_origin = urlparse(origin_header)
|
||||
origin = parsed_origin.netloc
|
||||
|
||||
# Check to see that origin matches host directly, including ports
|
||||
return origin == host
|
||||
|
||||
def clear_cookie(self, *args, **kwargs):
|
||||
"""meaningless for websockets"""
|
||||
@ -96,13 +105,21 @@ class ZMQStreamHandler(websocket.WebSocketHandler):
|
||||
|
||||
|
||||
class AuthenticatedZMQStreamHandler(ZMQStreamHandler, IPythonHandler):
|
||||
def set_default_headers(self):
|
||||
"""Undo the set_default_headers in IPythonHandler
|
||||
|
||||
which doesn't make sense for websockets
|
||||
"""
|
||||
pass
|
||||
|
||||
def open(self, kernel_id):
|
||||
self.kernel_id = cast_unicode(kernel_id, 'ascii')
|
||||
# Check to see that origin matches host directly, including ports
|
||||
if not self.same_origin():
|
||||
self.log.warn("Cross Origin WebSocket Attempt.")
|
||||
raise web.HTTPError(404)
|
||||
# Tornado 4 already does CORS checking
|
||||
if tornado.version_info[0] < 4:
|
||||
if not self.check_origin(self.get_origin()):
|
||||
self.log.warn("Cross Origin WebSocket Attempt from %s", self.get_origin())
|
||||
raise web.HTTPError(403)
|
||||
|
||||
self.session = Session(config=self.config)
|
||||
self.save_on_message = self.on_message
|
||||
|
@ -13,6 +13,7 @@ import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import select
|
||||
import signal
|
||||
import socket
|
||||
@ -334,8 +335,34 @@ class NotebookApp(BaseIPythonApplication):
|
||||
self.file_to_run = base
|
||||
self.notebook_dir = path
|
||||
|
||||
# Network related information.
|
||||
|
||||
# Network related information
|
||||
|
||||
allow_origin = Unicode('', config=True,
|
||||
help="""Set the Access-Control-Allow-Origin header
|
||||
|
||||
Use '*' to allow any origin to access your server.
|
||||
|
||||
Takes precedence over allow_origin_pat.
|
||||
"""
|
||||
)
|
||||
|
||||
allow_origin_pat = Unicode('', config=True,
|
||||
help="""Use a regular expression for the Access-Control-Allow-Origin header
|
||||
|
||||
Requests from an origin matching the expression will get replies with:
|
||||
|
||||
Access-Control-Allow-Origin: origin
|
||||
|
||||
where `origin` is the origin of the request.
|
||||
|
||||
Ignored if allow_origin is set.
|
||||
"""
|
||||
)
|
||||
|
||||
allow_credentials = Bool(False, config=True,
|
||||
help="Set the Access-Control-Allow-Credentials: true header"
|
||||
)
|
||||
|
||||
ip = Unicode('localhost', config=True,
|
||||
help="The IP address the notebook server will listen on."
|
||||
)
|
||||
@ -650,6 +677,10 @@ class NotebookApp(BaseIPythonApplication):
|
||||
|
||||
def init_webapp(self):
|
||||
"""initialize tornado webapp and httpserver"""
|
||||
self.webapp_settings['allow_origin'] = self.allow_origin
|
||||
self.webapp_settings['allow_origin_pat'] = re.compile(self.allow_origin_pat)
|
||||
self.webapp_settings['allow_credentials'] = self.allow_credentials
|
||||
|
||||
self.web_app = NotebookWebApplication(
|
||||
self, self.kernel_manager, self.notebook_manager,
|
||||
self.cluster_manager, self.session_manager, self.kernel_spec_manager,
|
||||
|
Loading…
Reference in New Issue
Block a user