Files
nand2tetris/asm/venv/lib/python3.11/site-packages/spyder_kernels/comms/commbase.py
Sven Riwoldt c7bc862c6f asm
2024-04-01 20:30:24 +02:00

559 lines
18 KiB
Python

# -*- coding: utf-8 -*-
#
# Copyright © Spyder Project Contributors
# Licensed under the terms of the MIT License
# (see spyder/__init__.py for details)
"""
Class that handles communications between Spyder kernel and frontend.
Comms transmit data in a list of buffers, and in a json-able dictionnary.
Here, we only support a buffer list with a single element.
The messages exchanged have the following msg_dict:
```
msg_dict = {
'spyder_msg_type': spyder_msg_type,
'content': content,
}
```
The buffer is generated by cloudpickle using `PICKLE_PROTOCOL = 2`.
To simplify the usage of messaging, we use a higher level function calling
mechanism:
- The `remote_call` method returns a RemoteCallHandler object
- By calling an attribute of this object, the call is sent to the other
side of the comm.
- If the `_wait_reply` is implemented, remote_call can be called with
`blocking=True`, which will wait for a reply sent by the other side.
The messages exchanged are:
- Function call (spyder_msg_type = 'remote_call'):
- The content is a dictionnary {
'call_name': The name of the function to be called,
'call_id': uuid to match the request to a potential reply,
'settings': A dictionnary of settings,
}
- The buffer encodes a dictionnary {
'call_args': The function args,
'call_kwargs': The function kwargs,
}
- If the 'settings' has `'blocking' = True`, a reply is sent.
(spyder_msg_type = 'remote_call_reply'):
- The buffer contains the return value of the function.
- The 'content' is a dict with: {
'is_error': a boolean indicating if the return value is an
exception to be raised.
'call_id': The uuid from above,
'call_name': The function name (mostly for debugging)
}
"""
from __future__ import print_function
import cloudpickle
import pickle
import logging
import sys
import uuid
import traceback
from spyder_kernels.py3compat import PY2, PY3
logger = logging.getLogger(__name__)
# To be able to get and set variables between Python 2 and 3
DEFAULT_PICKLE_PROTOCOL = 2
# Max timeout (in secs) for blocking calls
TIMEOUT = 3
class CommError(RuntimeError):
pass
class CommsErrorWrapper():
def __init__(self, call_name, call_id):
self.call_name = call_name
self.call_id = call_id
self.etype, self.error, tb = sys.exc_info()
self.tb = traceback.extract_tb(tb)
def raise_error(self):
"""
Raise the error while adding informations on the callback.
"""
# Add the traceback in the error, so it can be handled upstream
raise self.etype(self)
def format_error(self):
"""
Format the error received from the other side and returns a list of
strings.
"""
lines = (['Exception in comms call {}:\n'.format(self.call_name)]
+ traceback.format_list(self.tb)
+ traceback.format_exception_only(self.etype, self.error))
return lines
def print_error(self, file=None):
"""
Print the error to file or to sys.stderr if file is None.
"""
if file is None:
file = sys.stderr
for line in self.format_error():
print(line, file=file)
def __str__(self):
"""Get string representation."""
return str(self.error)
def __repr__(self):
"""Get repr."""
return repr(self.error)
# Replace sys.excepthook to handle CommsErrorWrapper
sys_excepthook = sys.excepthook
def comm_excepthook(type, value, tb):
if len(value.args) == 1 and isinstance(value.args[0], CommsErrorWrapper):
traceback.print_tb(tb)
value.args[0].print_error()
return
sys_excepthook(type, value, tb)
sys.excepthook = comm_excepthook
class CommBase(object):
"""
Class with the necessary attributes and methods to handle
communications between a kernel and a frontend.
Subclasses must open a comm and register it with `self._register_comm`.
"""
def __init__(self):
super(CommBase, self).__init__()
self.calling_comm_id = None
self._comms = {}
# Handlers
self._message_handlers = {}
self._remote_call_handlers = {}
# Lists of reply numbers
self._reply_inbox = {}
self._reply_waitlist = {}
self._register_message_handler(
'remote_call', self._handle_remote_call)
self._register_message_handler(
'remote_call_reply', self._handle_remote_call_reply)
self.register_call_handler('_set_pickle_protocol',
self._set_pickle_protocol)
def get_comm_id_list(self, comm_id=None):
"""Get a list of comms id."""
if comm_id is None:
id_list = list(self._comms.keys())
else:
id_list = [comm_id]
return id_list
def close(self, comm_id=None):
"""Close the comm and notify the other side."""
id_list = self.get_comm_id_list(comm_id)
for comm_id in id_list:
try:
self._comms[comm_id]['comm'].close()
del self._comms[comm_id]
except KeyError:
pass
def is_open(self, comm_id=None):
"""Check to see if the comm is open."""
if comm_id is None:
return len(self._comms) > 0
return comm_id in self._comms
def is_ready(self, comm_id=None):
"""
Check to see if the other side replied.
The check is made with _set_pickle_protocol as this is the first call
made. If comm_id is not specified, check all comms.
"""
id_list = self.get_comm_id_list(comm_id)
if len(id_list) == 0:
return False
return all([self._comms[cid]['status'] == 'ready' for cid in id_list])
def register_call_handler(self, call_name, handler):
"""
Register a remote call handler.
Parameters
----------
call_name : str
The name of the called function.
handler : callback
A function to handle the request, or `None` to unregister
`call_name`.
"""
if not handler:
self._remote_call_handlers.pop(call_name, None)
return
self._remote_call_handlers[call_name] = handler
def remote_call(self, comm_id=None, callback=None, **settings):
"""Get a handler for remote calls."""
return RemoteCallFactory(self, comm_id, callback, **settings)
# ---- Private -----
def _send_message(self, spyder_msg_type, content=None, data=None,
comm_id=None):
"""
Publish custom messages to the other side.
Parameters
----------
spyder_msg_type: str
The spyder message type
content: dict
The (JSONable) content of the message
data: any
Any object that is serializable by cloudpickle (should be most
things). Will arrive as cloudpickled bytes in `.buffers[0]`.
comm_id: int
the comm to send to. If None sends to all comms.
"""
if not self.is_open(comm_id):
raise CommError("The comm is not connected.")
id_list = self.get_comm_id_list(comm_id)
for comm_id in id_list:
msg_dict = {
'spyder_msg_type': spyder_msg_type,
'content': content,
'pickle_protocol': self._comms[comm_id]['pickle_protocol'],
'python_version': sys.version,
}
buffers = [cloudpickle.dumps(
data, protocol=self._comms[comm_id]['pickle_protocol'])]
self._comms[comm_id]['comm'].send(msg_dict, buffers=buffers)
def _set_pickle_protocol(self, protocol):
"""Set the pickle protocol used to send data."""
protocol = min(protocol, pickle.HIGHEST_PROTOCOL)
self._comms[self.calling_comm_id]['pickle_protocol'] = protocol
self._comms[self.calling_comm_id]['status'] = 'ready'
@property
def _comm_name(self):
"""
Get the name used for the underlying comms.
"""
return 'spyder_api'
def _register_message_handler(self, message_id, handler):
"""
Register a message handler.
Parameters
----------
message_id : str
The identifier for the message
handler : callback
A function to handle the message. This is called with 3 arguments:
- msg_dict: A dictionary with message information.
- buffer: The data transmitted in the buffer
Pass None to unregister the message_id
"""
if handler is None:
self._message_handlers.pop(message_id, None)
return
self._message_handlers[message_id] = handler
def _register_comm(self, comm):
"""
Open a new comm to the kernel.
"""
comm.on_msg(self._comm_message)
comm.on_close(self._comm_close)
self._comms[comm.comm_id] = {
'comm': comm,
'pickle_protocol': DEFAULT_PICKLE_PROTOCOL,
'status': 'opening',
}
def _comm_close(self, msg):
"""Close comm."""
comm_id = msg['content']['comm_id']
del self._comms[comm_id]
def _comm_message(self, msg):
"""
Handle internal spyder messages.
"""
self.calling_comm_id = msg['content']['comm_id']
# Get message dict
msg_dict = msg['content']['data']
# Load the buffer. Only one is supported.
try:
if PY3:
# https://docs.python.org/3/library/pickle.html#pickle.loads
# Using encoding='latin1' is required for unpickling
# NumPy arrays and instances of datetime, date and time
# pickled by Python 2.
buffer = cloudpickle.loads(msg['buffers'][0],
encoding='latin-1')
else:
buffer = cloudpickle.loads(msg['buffers'][0])
except Exception as e:
logger.debug(
"Exception in cloudpickle.loads : %s" % str(e))
buffer = CommsErrorWrapper(
msg_dict['content']['call_name'],
msg_dict['content']['call_id'])
msg_dict['content']['is_error'] = True
spyder_msg_type = msg_dict['spyder_msg_type']
if spyder_msg_type in self._message_handlers:
self._message_handlers[spyder_msg_type](
msg_dict, buffer)
else:
logger.debug("No such spyder message type: %s" % spyder_msg_type)
def _handle_remote_call(self, msg, buffer):
"""Handle a remote call."""
msg_dict = msg['content']
self.on_incoming_call(msg_dict)
try:
return_value = self._remote_callback(
msg_dict['call_name'],
buffer['call_args'],
buffer['call_kwargs'])
self._set_call_return_value(msg_dict, return_value)
except Exception:
exc_infos = CommsErrorWrapper(
msg_dict['call_name'], msg_dict['call_id'])
self._set_call_return_value(msg_dict, exc_infos, is_error=True)
def _remote_callback(self, call_name, call_args, call_kwargs):
"""Call the callback function for the remote call."""
if call_name in self._remote_call_handlers:
return self._remote_call_handlers[call_name](
*call_args, **call_kwargs)
raise CommError("No such spyder call type: %s" % call_name)
def _set_call_return_value(self, call_dict, data, is_error=False):
"""
A remote call has just been processed.
This will reply if settings['blocking'] == True
"""
settings = call_dict['settings']
display_error = ('display_error' in settings and
settings['display_error'])
if is_error and display_error:
data.print_error()
send_reply = 'send_reply' in settings and settings['send_reply']
if not send_reply:
# Nothing to send back
return
content = {
'is_error': is_error,
'call_id': call_dict['call_id'],
'call_name': call_dict['call_name']
}
self._send_message('remote_call_reply', content=content, data=data,
comm_id=self.calling_comm_id)
def _register_call(self, call_dict, callback=None):
"""
Register the call so the reply can be properly treated.
"""
settings = call_dict['settings']
blocking = 'blocking' in settings and settings['blocking']
call_id = call_dict['call_id']
if blocking or callback is not None:
self._reply_waitlist[call_id] = blocking, callback
def on_outgoing_call(self, call_dict):
"""A message is about to be sent"""
call_dict["pickle_highest_protocol"] = pickle.HIGHEST_PROTOCOL
return call_dict
def on_incoming_call(self, call_dict):
"""A call was received"""
if "pickle_highest_protocol" in call_dict:
self._set_pickle_protocol(call_dict["pickle_highest_protocol"])
def _get_call_return_value(self, call_dict, call_data, comm_id):
"""
Send a remote call and return the reply.
If settings['blocking'] == True, this will wait for a reply and return
the replied value.
"""
call_dict = self.on_outgoing_call(call_dict)
self._send_message(
'remote_call', content=call_dict, data=call_data,
comm_id=comm_id)
settings = call_dict['settings']
blocking = 'blocking' in settings and settings['blocking']
if not blocking:
return
call_id = call_dict['call_id']
call_name = call_dict['call_name']
# Wait for the blocking call
if 'timeout' in settings and settings['timeout'] is not None:
timeout = settings['timeout']
else:
timeout = TIMEOUT
self._wait_reply(call_id, call_name, timeout)
reply = self._reply_inbox.pop(call_id)
if reply['is_error']:
return self._sync_error(reply['value'])
return reply['value']
def _wait_reply(self, call_id, call_name, timeout):
"""
Wait for the other side reply.
"""
raise NotImplementedError
def _handle_remote_call_reply(self, msg_dict, buffer):
"""
A blocking call received a reply.
"""
content = msg_dict['content']
call_id = content['call_id']
call_name = content['call_name']
is_error = content['is_error']
# Unexpected reply
if call_id not in self._reply_waitlist:
if is_error:
return self._async_error(buffer)
else:
logger.debug('Got an unexpected reply {}, id:{}'.format(
call_name, call_id))
return
blocking, callback = self._reply_waitlist.pop(call_id)
# Async error
if is_error and not blocking:
return self._async_error(buffer)
# Callback
if callback is not None and not is_error:
callback(buffer)
# Blocking inbox
if blocking:
self._reply_inbox[call_id] = {
'is_error': is_error,
'value': buffer,
'content': content
}
def _async_error(self, error_wrapper):
"""
Handle an error that was raised on the other side asyncronously.
"""
error_wrapper.print_error()
def _sync_error(self, error_wrapper):
"""
Handle an error that was raised on the other side syncronously.
"""
error_wrapper.raise_error()
class RemoteCallFactory(object):
"""Class to create `RemoteCall`s."""
def __init__(self, comms_wrapper, comm_id, callback, **settings):
# Avoid setting attributes
super(RemoteCallFactory, self).__setattr__(
'_comms_wrapper', comms_wrapper)
super(RemoteCallFactory, self).__setattr__('_comm_id', comm_id)
super(RemoteCallFactory, self).__setattr__('_callback', callback)
super(RemoteCallFactory, self).__setattr__('_settings', settings)
def __getattr__(self, name):
"""Get a call for a function named 'name'."""
return RemoteCall(name, self._comms_wrapper, self._comm_id,
self._callback, self._settings)
def __setattr__(self, name, value):
"""Set an attribute to the other side."""
raise NotImplementedError
class RemoteCall():
"""Class to call the other side of the comms like a function."""
def __init__(self, name, comms_wrapper, comm_id, callback, settings):
self._name = name
self._comms_wrapper = comms_wrapper
self._comm_id = comm_id
self._settings = settings
self._callback = callback
def __call__(self, *args, **kwargs):
"""
Transmit the call to the other side of the tunnel.
The args and kwargs have to be picklable.
"""
blocking = 'blocking' in self._settings and self._settings['blocking']
self._settings['send_reply'] = blocking or self._callback is not None
call_id = uuid.uuid4().hex
call_dict = {
'call_name': self._name,
'call_id': call_id,
'settings': self._settings,
}
call_data = {
'call_args': args,
'call_kwargs': kwargs,
}
if not self._comms_wrapper.is_open(self._comm_id):
# Only an error if the call is blocking.
if blocking:
raise CommError("The comm is not connected.")
logger.debug("Call to unconnected comm: %s" % self._name)
return
self._comms_wrapper._register_call(call_dict, self._callback)
return self._comms_wrapper._get_call_return_value(
call_dict, call_data, self._comm_id)