use session.deserialize to unpack message for rate limiting

rather than hardcoding json.loads

Messages should **never** be deserialized by any means other than the Session API.
This commit is contained in:
Min RK 2016-01-21 11:26:27 +01:00
parent 9e2c95dc07
commit c280b773fb
2 changed files with 17 additions and 9 deletions

View File

@ -218,16 +218,23 @@ class ZMQStreamHandler(WebSocketMixin, WebSocketHandler):
self.stream.close()
def _reserialize_reply(self, msg_list, channel=None):
def _reserialize_reply(self, msg_or_list, channel=None):
"""Reserialize a reply message using JSON.
This takes the msg list from the ZMQ socket, deserializes it using
self.session and then serializes the result using JSON. This method
should be used by self._on_zmq_reply to build messages that can
msg_or_list can be an already-deserialized msg dict or the zmq buffer list.
If it is the zmq list, it will be deserialized with self.session.
This takes the msg list from the ZMQ socket and serializes the result for the websocket.
This method should be used by self._on_zmq_reply to build messages that can
be sent back to the browser.
"""
idents, msg_list = self.session.feed_identities(msg_list)
msg = self.session.deserialize(msg_list)
if isinstance(msg_or_list, dict):
# already unpacked
msg = msg_or_list
else:
idents, msg_list = self.session.feed_identities(msg_or_list)
msg = self.session.deserialize(msg_list)
if channel:
msg['channel'] = channel
if msg['buffers']:

View File

@ -269,9 +269,10 @@ class ZMQChannelsHandler(AuthenticatedZMQStreamHandler):
def _on_zmq_reply(self, stream, msg_list):
idents, fed_msg_list = self.session.feed_identities(msg_list)
msg = self.session.deserialize(fed_msg_list)
parent = msg['parent_header']
def write_stderr(error_message):
self.log.warn(error_message)
parent = json.loads(fed_msg_list[2])
msg = self.session.msg("stream",
content={"text": error_message, "name": "stderr"},
parent=parent
@ -280,7 +281,7 @@ class ZMQChannelsHandler(AuthenticatedZMQStreamHandler):
self.write_message(json.dumps(msg, default=date_default))
channel = getattr(stream, 'channel', None)
msg_type = json.loads(fed_msg_list[1])['msg_type']
msg_type = msg['header']['msg_type']
if channel == 'iopub' and msg_type not in {'status', 'comm_open', 'execute_input'}:
# Remove the counts queued for removal.
@ -345,7 +346,7 @@ class ZMQChannelsHandler(AuthenticatedZMQStreamHandler):
# If either of the limit flags are set, do not send the message.
if self._iopub_msgs_exceeded or self._iopub_data_exceeded:
return
super(ZMQChannelsHandler, self)._on_zmq_reply(stream, msg_list)
super(ZMQChannelsHandler, self)._on_zmq_reply(stream, msg)
def on_close(self):