diff --git a/IPython/html/base/handlers.py b/IPython/html/base/handlers.py
index d8d107cf1..e8e60297b 100644
--- a/IPython/html/base/handlers.py
+++ b/IPython/html/base/handlers.py
@@ -152,6 +152,48 @@ class IPythonHandler(AuthenticatedHandler):
def project_dir(self):
return self.notebook_manager.notebook_dir
+ #---------------------------------------------------------------
+ # CORS
+ #---------------------------------------------------------------
+
+ @property
+ def allow_origin(self):
+ """Normal Access-Control-Allow-Origin"""
+ return self.settings.get('allow_origin', '')
+
+ @property
+ def allow_origin_pat(self):
+ """Regular expression version of allow_origin"""
+ return self.settings.get('allow_origin_pat', None)
+
+ @property
+ def allow_credentials(self):
+ """Whether to set Access-Control-Allow-Credentials"""
+ return self.settings.get('allow_credentials', False)
+
+ def set_default_headers(self):
+ """Add CORS headers, if defined"""
+ super(IPythonHandler, self).set_default_headers()
+ if self.allow_origin:
+ self.set_header("Access-Control-Allow-Origin", self.allow_origin)
+ elif self.allow_origin_pat:
+ origin = self.get_origin()
+ if origin and self.allow_origin_pat.match(origin):
+ self.set_header("Access-Control-Allow-Origin", origin)
+ if self.allow_credentials:
+ self.set_header("Access-Control-Allow-Credentials", 'true')
+
+ def get_origin(self):
+ # Handle WebSocket Origin naming convention differences
+ # The difference between version 8 and 13 is that in 8 the
+ # client sends a "Sec-Websocket-Origin" header and in 13 it's
+ # simply "Origin".
+ if "Origin" in self.request.headers:
+ origin = self.request.headers.get("Origin")
+ else:
+ origin = self.request.headers.get("Sec-Websocket-Origin", None)
+ return origin
+
#---------------------------------------------------------------
# template rendering
#---------------------------------------------------------------
diff --git a/IPython/html/base/zmqhandlers.py b/IPython/html/base/zmqhandlers.py
index 8999b2672..3e3f45123 100644
--- a/IPython/html/base/zmqhandlers.py
+++ b/IPython/html/base/zmqhandlers.py
@@ -15,6 +15,8 @@ try:
except ImportError:
from Cookie import SimpleCookie # Py 2
import logging
+
+import tornado
from tornado import web
from tornado import websocket
@@ -26,29 +28,36 @@ from .handlers import IPythonHandler
class ZMQStreamHandler(websocket.WebSocketHandler):
-
- def same_origin(self):
- """Check to see that origin and host match in the headers."""
-
- # The difference between version 8 and 13 is that in 8 the
- # client sends a "Sec-Websocket-Origin" header and in 13 it's
- # simply "Origin".
- if self.request.headers.get("Sec-WebSocket-Version") in ("7", "8"):
- origin_header = self.request.headers.get("Sec-Websocket-Origin")
- else:
- origin_header = self.request.headers.get("Origin")
+
+ def check_origin(self, origin):
+ """Check Origin == Host or Access-Control-Allow-Origin.
+
+ Tornado >= 4 calls this method automatically, raising 403 if it returns False.
+ We call it explicitly in `open` on Tornado < 4.
+ """
+ if self.allow_origin == '*':
+ return True
host = self.request.headers.get("Host")
# If no header is provided, assume we can't verify origin
- if(origin_header is None or host is None):
+ if(origin is None or host is None):
+ return False
+
+ host_origin = "{0}://{1}".format(self.request.protocol, host)
+
+ # OK if origin matches host
+ if origin == host_origin:
+ return True
+
+ # Check CORS headers
+ if self.allow_origin:
+ return self.allow_origin == origin
+ elif self.allow_origin_pat:
+ return bool(self.allow_origin_pat.match(origin))
+ else:
+ # No CORS headers deny the request
return False
-
- parsed_origin = urlparse(origin_header)
- origin = parsed_origin.netloc
-
- # Check to see that origin matches host directly, including ports
- return origin == host
def clear_cookie(self, *args, **kwargs):
"""meaningless for websockets"""
@@ -96,13 +105,21 @@ class ZMQStreamHandler(websocket.WebSocketHandler):
class AuthenticatedZMQStreamHandler(ZMQStreamHandler, IPythonHandler):
+ def set_default_headers(self):
+ """Undo the set_default_headers in IPythonHandler
+
+ which doesn't make sense for websockets
+ """
+ pass
def open(self, kernel_id):
self.kernel_id = cast_unicode(kernel_id, 'ascii')
# Check to see that origin matches host directly, including ports
- if not self.same_origin():
- self.log.warn("Cross Origin WebSocket Attempt.")
- raise web.HTTPError(404)
+ # Tornado 4 already does CORS checking
+ if tornado.version_info[0] < 4:
+ if not self.check_origin(self.get_origin()):
+ self.log.warn("Cross Origin WebSocket Attempt from %s", self.get_origin())
+ raise web.HTTPError(403)
self.session = Session(config=self.config)
self.save_on_message = self.on_message
diff --git a/IPython/html/notebookapp.py b/IPython/html/notebookapp.py
index fad0dc740..2dc404c5a 100644
--- a/IPython/html/notebookapp.py
+++ b/IPython/html/notebookapp.py
@@ -13,6 +13,7 @@ import json
import logging
import os
import random
+import re
import select
import signal
import socket
@@ -334,8 +335,34 @@ class NotebookApp(BaseIPythonApplication):
self.file_to_run = base
self.notebook_dir = path
- # Network related information.
-
+ # Network related information
+
+ allow_origin = Unicode('', config=True,
+ help="""Set the Access-Control-Allow-Origin header
+
+ Use '*' to allow any origin to access your server.
+
+ Takes precedence over allow_origin_pat.
+ """
+ )
+
+ allow_origin_pat = Unicode('', config=True,
+ help="""Use a regular expression for the Access-Control-Allow-Origin header
+
+ Requests from an origin matching the expression will get replies with:
+
+ Access-Control-Allow-Origin: origin
+
+ where `origin` is the origin of the request.
+
+ Ignored if allow_origin is set.
+ """
+ )
+
+ allow_credentials = Bool(False, config=True,
+ help="Set the Access-Control-Allow-Credentials: true header"
+ )
+
ip = Unicode('localhost', config=True,
help="The IP address the notebook server will listen on."
)
@@ -650,6 +677,10 @@ class NotebookApp(BaseIPythonApplication):
def init_webapp(self):
"""initialize tornado webapp and httpserver"""
+ self.webapp_settings['allow_origin'] = self.allow_origin
+ self.webapp_settings['allow_origin_pat'] = re.compile(self.allow_origin_pat)
+ self.webapp_settings['allow_credentials'] = self.allow_credentials
+
self.web_app = NotebookWebApplication(
self, self.kernel_manager, self.notebook_manager,
self.cluster_manager, self.session_manager, self.kernel_spec_manager,