559 lines
18 KiB
Python
559 lines
18 KiB
Python
# -*- coding: utf-8 -*-
|
|
#
|
|
# Copyright © Spyder Project Contributors
|
|
# Licensed under the terms of the MIT License
|
|
# (see spyder/__init__.py for details)
|
|
|
|
"""
|
|
Class that handles communications between Spyder kernel and frontend.
|
|
|
|
Comms transmit data in a list of buffers, and in a json-able dictionary.
|
|
Here, we only support a buffer list with a single element.
|
|
|
|
The messages exchanged have the following msg_dict:
|
|
|
|
```
|
|
msg_dict = {
|
|
'spyder_msg_type': spyder_msg_type,
|
|
'content': content,
|
|
}
|
|
```
|
|
|
|
The buffer is generated by cloudpickle, by default with `DEFAULT_PICKLE_PROTOCOL = 2`.
|
|
|
|
To simplify the usage of messaging, we use a higher level function calling
|
|
mechanism:
|
|
- The `remote_call` method returns a RemoteCallHandler object
|
|
- By calling an attribute of this object, the call is sent to the other
|
|
side of the comm.
|
|
- If the `_wait_reply` is implemented, remote_call can be called with
|
|
`blocking=True`, which will wait for a reply sent by the other side.
|
|
|
|
The messages exchanged are:
|
|
- Function call (spyder_msg_type = 'remote_call'):
|
|
- The content is a dictionary {
|
|
'call_name': The name of the function to be called,
|
|
'call_id': uuid to match the request to a potential reply,
|
|
'settings': A dictionary of settings,
|
|
}
|
|
- The buffer encodes a dictionary {
|
|
'call_args': The function args,
|
|
'call_kwargs': The function kwargs,
|
|
}
|
|
- If the 'settings' has `'blocking' = True`, a reply is sent.
|
|
(spyder_msg_type = 'remote_call_reply'):
|
|
- The buffer contains the return value of the function.
|
|
- The 'content' is a dict with: {
|
|
'is_error': a boolean indicating if the return value is an
|
|
exception to be raised.
|
|
'call_id': The uuid from above,
|
|
'call_name': The function name (mostly for debugging)
|
|
}
|
|
"""
|
|
from __future__ import print_function
|
|
|
|
import cloudpickle
|
|
import pickle
|
|
import logging
|
|
import sys
|
|
import uuid
|
|
import traceback
|
|
|
|
from spyder_kernels.py3compat import PY2, PY3
|
|
|
|
|
|
logger = logging.getLogger(__name__)

# Pickle protocol assigned to a comm when it is registered. Protocol 2 is
# used so that variables can be exchanged between Python 2 and Python 3.
DEFAULT_PICKLE_PROTOCOL = 2

# Default upper bound, in seconds, on the wait for a blocking call's reply.
TIMEOUT = 3
|
|
|
|
|
|
class CommError(RuntimeError):
    """Error raised by the comms layer itself (e.g. comm not connected)."""
|
|
|
|
|
|
class CommsErrorWrapper:
    """
    Record of the exception being handled while serving a comms call.

    The instance must be created from inside an ``except`` block: the
    exception type, value and traceback are read from ``sys.exc_info()``.
    The traceback is kept as the frame summaries produced by
    ``traceback.extract_tb`` rather than as a traceback object.
    """

    def __init__(self, call_name, call_id):
        exc_type, exc_value, exc_tb = sys.exc_info()
        self.call_name = call_name
        self.call_id = call_id
        self.etype = exc_type
        self.error = exc_value
        self.tb = traceback.extract_tb(exc_tb)

    def raise_error(self):
        """
        Re-raise the recorded exception type with this wrapper attached.
        """
        # The wrapper becomes the single argument of the new exception so
        # that code upstream (e.g. the excepthook) can recover the details.
        raise self.etype(self)

    def format_error(self):
        """
        Return the recorded error as a list of text lines: a header naming
        the comms call, the stack entries, then the exception itself.
        """
        header = 'Exception in comms call {}:\n'.format(self.call_name)
        lines = [header]
        lines.extend(traceback.format_list(self.tb))
        lines.extend(
            traceback.format_exception_only(self.etype, self.error))
        return lines

    def print_error(self, file=None):
        """
        Write the formatted error to `file` (``sys.stderr`` when None).
        """
        target = sys.stderr if file is None else file
        for text in self.format_error():
            print(text, file=target)

    def __str__(self):
        """Return ``str()`` of the recorded exception."""
        return str(self.error)

    def __repr__(self):
        """Return ``repr()`` of the recorded exception."""
        return repr(self.error)
|
|
|
|
|
|
# Keep the hook that was installed before this module took over, so that
# exceptions unrelated to comms are still reported the usual way.
sys_excepthook = sys.excepthook


def comm_excepthook(type, value, tb):
    """
    Excepthook that knows how to report errors carrying a CommsErrorWrapper.

    An exception whose only argument is a `CommsErrorWrapper` gets the local
    traceback printed followed by the wrapper's formatted error; anything
    else is delegated to the previously installed hook.
    """
    exc_args = value.args
    carries_wrapper = (
        len(exc_args) == 1 and isinstance(exc_args[0], CommsErrorWrapper))
    if not carries_wrapper:
        sys_excepthook(type, value, tb)
        return
    traceback.print_tb(tb)
    exc_args[0].print_error()


sys.excepthook = comm_excepthook
|
|
|
|
|
|
class CommBase(object):
    """
    Class with the necessary attributes and methods to handle
    communications between a kernel and a frontend.

    Subclasses must open a comm and register it with `self._register_comm`.
    Blocking calls additionally need `_wait_reply` to be implemented, as
    the version defined here raises `NotImplementedError`.
    """

    def __init__(self):
        super(CommBase, self).__init__()
        # Id of the comm that delivered the message currently being handled
        self.calling_comm_id = None
        # comm_id -> {'comm': comm, 'pickle_protocol': int, 'status': str}
        self._comms = {}
        # Handlers
        self._message_handlers = {}
        self._remote_call_handlers = {}
        # call_id -> reply received for a blocking call
        self._reply_inbox = {}
        # call_id -> (blocking, callback) for calls that expect a reply
        self._reply_waitlist = {}

        self._register_message_handler(
            'remote_call', self._handle_remote_call)
        self._register_message_handler(
            'remote_call_reply', self._handle_remote_call_reply)
        self.register_call_handler('_set_pickle_protocol',
                                   self._set_pickle_protocol)

    def get_comm_id_list(self, comm_id=None):
        """
        Get a list of comm ids.

        Returns every registered id when `comm_id` is None, otherwise a
        one-element list holding `comm_id` (registered or not).
        """
        if comm_id is None:
            return list(self._comms)
        return [comm_id]

    def close(self, comm_id=None):
        """
        Close the comm and notify the other side.

        Closes every registered comm when `comm_id` is None. Ids that are
        not (or no longer) registered are ignored.
        """
        id_list = self.get_comm_id_list(comm_id)

        for comm_id in id_list:
            try:
                self._comms[comm_id]['comm'].close()
                del self._comms[comm_id]
            except KeyError:
                # Unknown comm, or entry already removed while closing.
                pass

    def is_open(self, comm_id=None):
        """
        Check to see if the comm is open.

        With `comm_id` set to None, tell whether any comm is registered.
        """
        if comm_id is None:
            return bool(self._comms)
        return comm_id in self._comms

    def is_ready(self, comm_id=None):
        """
        Check to see if the other side replied.

        The check is made with _set_pickle_protocol as this is the first call
        made. If comm_id is not specified, check all comms.

        Returns False when there is no comm to check or when `comm_id` is
        not registered.
        """
        id_list = self.get_comm_id_list(comm_id)
        if not id_list:
            return False
        # An unregistered id is reported as not ready instead of raising,
        # which matches what `is_open` does for unknown ids.
        return all(
            cid in self._comms and self._comms[cid]['status'] == 'ready'
            for cid in id_list)

    def register_call_handler(self, call_name, handler):
        """
        Register a remote call handler.

        Parameters
        ----------
        call_name : str
            The name of the called function.
        handler : callback
            A function to handle the request, or `None` to unregister
            `call_name`.
        """
        if not handler:
            self._remote_call_handlers.pop(call_name, None)
            return

        self._remote_call_handlers[call_name] = handler

    def remote_call(self, comm_id=None, callback=None, **settings):
        """
        Get a handler for remote calls.

        The returned `RemoteCallFactory` turns attribute access into a
        call sent to the other side of the comm.
        """
        return RemoteCallFactory(self, comm_id, callback, **settings)

    # ---- Private -----
    def _send_message(self, spyder_msg_type, content=None, data=None,
                      comm_id=None):
        """
        Publish custom messages to the other side.

        Parameters
        ----------
        spyder_msg_type: str
            The spyder message type
        content: dict
            The (JSONable) content of the message
        data: any
            Any object that is serializable by cloudpickle (should be most
            things). Will arrive as cloudpickled bytes in `.buffers[0]`.
        comm_id: key of a registered comm, or None
            The comm to send to. If None sends to all comms.

        Raises
        ------
        CommError
            If the requested comm (or, for None, any comm) is not open.
        """
        if not self.is_open(comm_id):
            raise CommError("The comm is not connected.")
        for cid in self.get_comm_id_list(comm_id):
            comm_info = self._comms[cid]
            protocol = comm_info['pickle_protocol']
            msg_dict = {
                'spyder_msg_type': spyder_msg_type,
                'content': content,
                'pickle_protocol': protocol,
                'python_version': sys.version,
            }
            buffers = [cloudpickle.dumps(data, protocol=protocol)]
            comm_info['comm'].send(msg_dict, buffers=buffers)

    def _set_pickle_protocol(self, protocol):
        """
        Set the pickle protocol used to send data.

        The value is capped at this interpreter's highest protocol and is
        stored for the comm that sent the current message, which is also
        marked as 'ready'.
        """
        protocol = min(protocol, pickle.HIGHEST_PROTOCOL)
        comm_info = self._comms[self.calling_comm_id]
        comm_info['pickle_protocol'] = protocol
        comm_info['status'] = 'ready'

    @property
    def _comm_name(self):
        """
        Get the name used for the underlying comms.
        """
        return 'spyder_api'

    def _register_message_handler(self, message_id, handler):
        """
        Register a message handler.

        Parameters
        ----------
        message_id : str
            The identifier for the message
        handler : callback
            A function to handle the message. This is called with 2 arguments:
                - msg_dict: A dictionary with message information.
                - buffer: The data transmitted in the buffer
            Pass None to unregister the message_id
        """
        if handler is None:
            self._message_handlers.pop(message_id, None)
            return

        self._message_handlers[message_id] = handler

    def _register_comm(self, comm):
        """
        Register an opened comm.

        Hooks the message and close callbacks and stores the comm with the
        default pickle protocol and the 'opening' status.
        """
        comm.on_msg(self._comm_message)
        comm.on_close(self._comm_close)
        self._comms[comm.comm_id] = {
            'comm': comm,
            'pickle_protocol': DEFAULT_PICKLE_PROTOCOL,
            'status': 'opening',
        }

    def _comm_close(self, msg):
        """
        Forget a comm that has been closed.

        Nothing happens if the comm is not registered (for instance when
        `close` already removed it).
        """
        comm_id = msg['content']['comm_id']
        self._comms.pop(comm_id, None)

    def _comm_message(self, msg):
        """
        Handle internal spyder messages.

        Unpickles the single supported buffer and dispatches the message to
        the handler registered for its 'spyder_msg_type'. If the buffer
        cannot be loaded, the handler receives a `CommsErrorWrapper`
        instead and the content is flagged with 'is_error'.
        """
        self.calling_comm_id = msg['content']['comm_id']

        # Get message dict
        msg_dict = msg['content']['data']

        # Load the buffer. Only one is supported.
        # NOTE(review): the buffer is unpickled as received. Unpickling can
        # run arbitrary code, so this is only safe with a trusted peer.
        try:
            if PY3:
                # https://docs.python.org/3/library/pickle.html#pickle.loads
                # Using encoding='latin1' is required for unpickling
                # NumPy arrays and instances of datetime, date and time
                # pickled by Python 2.
                buffer = cloudpickle.loads(msg['buffers'][0],
                                           encoding='latin-1')
            else:
                buffer = cloudpickle.loads(msg['buffers'][0])
        except Exception as e:
            logger.debug("Exception in cloudpickle.loads : %s", e)
            buffer = CommsErrorWrapper(
                msg_dict['content']['call_name'],
                msg_dict['content']['call_id'])

            msg_dict['content']['is_error'] = True

        spyder_msg_type = msg_dict['spyder_msg_type']

        if spyder_msg_type in self._message_handlers:
            self._message_handlers[spyder_msg_type](
                msg_dict, buffer)
        else:
            logger.debug("No such spyder message type: %s", spyder_msg_type)

    def _handle_remote_call(self, msg, buffer):
        """
        Handle a remote call.

        Runs the registered handler with the transmitted args and kwargs
        and passes the result, or the wrapped exception, to
        `_set_call_return_value`.
        """
        msg_dict = msg['content']
        self.on_incoming_call(msg_dict)
        try:
            return_value = self._remote_callback(
                msg_dict['call_name'],
                buffer['call_args'],
                buffer['call_kwargs'])
            # Kept inside the try: a failure while replying is reported to
            # the caller as an error as well.
            self._set_call_return_value(msg_dict, return_value)
        except Exception:
            exc_infos = CommsErrorWrapper(
                msg_dict['call_name'], msg_dict['call_id'])
            self._set_call_return_value(msg_dict, exc_infos, is_error=True)

    def _remote_callback(self, call_name, call_args, call_kwargs):
        """
        Call the callback function for the remote call.

        Raises
        ------
        CommError
            If no handler is registered under `call_name`.
        """
        if call_name in self._remote_call_handlers:
            return self._remote_call_handlers[call_name](
                *call_args, **call_kwargs)

        raise CommError("No such spyder call type: %s" % call_name)

    def _set_call_return_value(self, call_dict, data, is_error=False):
        """
        A remote call has just been processed.

        This will reply if settings['send_reply'] is true. Errors are
        printed locally first when settings['display_error'] is true.
        """
        settings = call_dict['settings']

        display_error = settings.get('display_error', False)
        if is_error and display_error:
            data.print_error()

        send_reply = settings.get('send_reply', False)
        if not send_reply:
            # Nothing to send back
            return
        content = {
            'is_error': is_error,
            'call_id': call_dict['call_id'],
            'call_name': call_dict['call_name']
        }

        self._send_message('remote_call_reply', content=content, data=data,
                           comm_id=self.calling_comm_id)

    def _register_call(self, call_dict, callback=None):
        """
        Register the call so the reply can be properly treated.

        Only blocking calls and calls with a callback are recorded.
        """
        settings = call_dict['settings']
        blocking = settings.get('blocking', False)
        call_id = call_dict['call_id']
        if blocking or callback is not None:
            self._reply_waitlist[call_id] = blocking, callback

    def on_outgoing_call(self, call_dict):
        """A message is about to be sent"""
        # Advertise our highest pickle protocol to the other side.
        call_dict["pickle_highest_protocol"] = pickle.HIGHEST_PROTOCOL
        return call_dict

    def on_incoming_call(self, call_dict):
        """A call was received"""
        if "pickle_highest_protocol" in call_dict:
            self._set_pickle_protocol(call_dict["pickle_highest_protocol"])

    def _get_call_return_value(self, call_dict, call_data, comm_id):
        """
        Send a remote call and return the reply.

        If settings['blocking'] == True, this will wait for a reply and return
        the replied value. Otherwise None is returned right after sending.
        An error reply to a blocking call is re-raised through `_sync_error`.
        """
        call_dict = self.on_outgoing_call(call_dict)
        self._send_message(
            'remote_call', content=call_dict, data=call_data,
            comm_id=comm_id)

        settings = call_dict['settings']

        blocking = settings.get('blocking', False)

        if not blocking:
            return

        call_id = call_dict['call_id']
        call_name = call_dict['call_name']

        # Wait for the blocking call
        timeout = settings.get('timeout')
        if timeout is None:
            timeout = TIMEOUT

        self._wait_reply(call_id, call_name, timeout)

        reply = self._reply_inbox.pop(call_id)

        if reply['is_error']:
            return self._sync_error(reply['value'])

        return reply['value']

    def _wait_reply(self, call_id, call_name, timeout):
        """
        Wait for the other side reply.

        Not implemented here.
        """
        raise NotImplementedError

    def _handle_remote_call_reply(self, msg_dict, buffer):
        """
        A call received a reply.

        Errors of non-blocking or unexpected calls go to `_async_error`.
        Otherwise the callback, if any, is run on successful replies and
        replies to blocking calls are stored in the inbox.
        """
        content = msg_dict['content']
        call_id = content['call_id']
        call_name = content['call_name']
        is_error = content['is_error']

        # Unexpected reply
        if call_id not in self._reply_waitlist:
            if is_error:
                return self._async_error(buffer)
            else:
                logger.debug('Got an unexpected reply {}, id:{}'.format(
                    call_name, call_id))
            return

        blocking, callback = self._reply_waitlist.pop(call_id)

        # Async error
        if is_error and not blocking:
            return self._async_error(buffer)

        # Callback
        if callback is not None and not is_error:
            callback(buffer)

        # Blocking inbox
        if blocking:
            self._reply_inbox[call_id] = {
                'is_error': is_error,
                'value': buffer,
                'content': content
            }

    def _async_error(self, error_wrapper):
        """
        Handle an error that was raised on the other side asynchronously.
        """
        error_wrapper.print_error()

    def _sync_error(self, error_wrapper):
        """
        Handle an error that was raised on the other side synchronously.
        """
        error_wrapper.raise_error()
|
|
|
|
|
|
class RemoteCallFactory(object):
    """Build a `RemoteCall` for whatever attribute name is looked up."""

    def __init__(self, comms_wrapper, comm_id, callback, **settings):
        # __setattr__ is disabled on this class, so the state has to be
        # stored through the parent implementation.
        store = super(RemoteCallFactory, self).__setattr__
        store('_comms_wrapper', comms_wrapper)
        store('_comm_id', comm_id)
        store('_callback', callback)
        store('_settings', settings)

    def __getattr__(self, name):
        """Return the remote call object for the function called `name`."""
        return RemoteCall(
            name,
            self._comms_wrapper,
            self._comm_id,
            self._callback,
            self._settings,
        )

    def __setattr__(self, name, value):
        """Setting attributes on the other side is not supported."""
        raise NotImplementedError
|
|
|
|
|
|
class RemoteCall:
    """Callable that forwards its invocation to the other side of a comm."""

    def __init__(self, name, comms_wrapper, comm_id, callback, settings):
        self._comms_wrapper = comms_wrapper
        self._comm_id = comm_id
        self._name = name
        self._callback = callback
        self._settings = settings

    def __call__(self, *args, **kwargs):
        """
        Send the call, with its arguments, through the comms wrapper.

        The args and kwargs have to be picklable. The wrapper's reply is
        returned. When the comm is not open, a blocking call raises
        `CommError`; any other call is logged and None is returned.
        """
        settings = self._settings
        wrapper = self._comms_wrapper

        blocking = settings.get('blocking', False)
        # A reply is requested as soon as somebody is waiting for it.
        settings['send_reply'] = blocking or self._callback is not None

        call_dict = {
            'call_name': self._name,
            'call_id': uuid.uuid4().hex,
            'settings': settings,
        }
        call_data = {'call_args': args, 'call_kwargs': kwargs}

        if wrapper.is_open(self._comm_id):
            wrapper._register_call(call_dict, self._callback)
            return wrapper._get_call_return_value(
                call_dict, call_data, self._comm_id)

        # Only an error if the call is blocking.
        if blocking:
            raise CommError("The comm is not connected.")
        logger.debug("Call to unconnected comm: %s" % self._name)
        return None
|