|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113 |
- # Copyright (c) 2016 Anki, Inc.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License in the file LICENSE.txt or at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- __all__ = []
-
-
- import asyncio
- import struct
- import sys
- from threading import Lock
-
- from . import logger_protocol
-
# Value for CLADProtocol._clad_log_which that enables debug logging of every
# message, regardless of its tag/type name.
LOG_ALL = 'all'

# The 2 byte length prefixes in this module are packed and unpacked with
# struct's native byte order ('H'), so refuse to load on hosts that are not
# little-endian rather than exchanging mis-framed data.
if sys.byteorder != 'little':
    # Interpolate the byte order into the message; passing it as a second
    # ImportError argument would leave a literal '%s' in the reported text.
    raise ImportError(
        "Cozmo SDK doesn't support byte order '%s' - contact Anki support to request this"
        % sys.byteorder)
-
-
class CLADProtocol(asyncio.Protocol):
    '''Low level CLAD codec.

    Each frame on the transport is a 2 byte unsigned length (struct format
    ``'H'``, native byte order) followed by that many bytes of packed CLAD
    message.  Received bytes are buffered, split into frames, decoded with
    ``clad_decode_union.unpack`` and handed to :meth:`msg_received`.
    Outgoing messages are wrapped in ``clad_encode_union``, packed, and
    written with the same framing.

    ``clad_decode_union`` and ``clad_encode_union`` default to ``None`` and
    must be set (typically by a subclass) before any data is received or sent.
    '''

    # Guards the two-part (length, payload) write in send_msg.  Defined on the
    # class, so the same lock is shared by every instance.
    _send_mutex = Lock()

    # Object whose ``unpack(buf)`` decodes the payload of one received frame.
    clad_decode_union = None
    # Callable invoked as ``clad_encode_union(**{type_name: msg})``; the result
    # must provide ``pack()`` returning the bytes to send.
    clad_encode_union = None
    # Which messages to debug-log: None for none, LOG_ALL for all of them,
    # otherwise a container of message names.
    _clad_log_which = None

    # Codec for the length prefix that precedes every frame.
    _LENGTH_PREFIX = struct.Struct('H')

    def __init__(self):
        super().__init__()

        # Bytes received so far that have not yet been consumed as a frame.
        self._buf = bytearray()
        # abort connection on failed handshake, ignore subsequent messages!
        self._abort_connection = False

    def connection_made(self, transport):
        '''Remember the transport so that send_msg can write to it.'''
        self.transport = transport
        logger_protocol.debug('Connected to transport')

    def connection_lost(self, exc):
        '''Log that the transport went away; ``exc`` is the reported cause.'''
        logger_protocol.debug("Connnection to transport lost: %s", exc)

    def data_received(self, data):
        '''Buffer ``data`` and dispatch every complete message it finishes.

        Each decoded message is passed to :meth:`msg_received`.  Dispatching
        stops as soon as ``_abort_connection`` is set, leaving any remaining
        bytes in the buffer.
        '''
        self._buf.extend(data)

        # pull clad messages out
        while not self._abort_connection:
            msg = self.decode_msg()
            # must compare msg against None, not just "if not msg" as the latter
            # would match against any message with len==0 (which is the case
            # for deliberately empty messages where the tag alone is the signal).
            if msg is None:
                return
            if self._should_log_msg(msg.tag_name):
                logger_protocol.debug('RECV %s', msg._data)
            self.msg_received(msg)

    def decode_msg(self):
        '''Remove and decode the next complete frame from the receive buffer.

        A frame whose payload fails to decode (``ValueError`` from
        ``clad_decode_union.unpack``) is logged and dropped, and decoding
        moves on to the next frame, so that one bad frame does not hold up
        complete frames that are already buffered behind it.

        Returns:
            The decoded message, or None if the buffer does not currently
            hold a complete, decodable frame.
        '''
        prefix_size = self._LENGTH_PREFIX.size

        while len(self._buf) >= prefix_size:
            # messages are prefixed by a 2 byte length
            (msg_size,) = self._LENGTH_PREFIX.unpack_from(self._buf)
            frame_end = prefix_size + msg_size
            if len(self._buf) < frame_end:
                # Payload is still incomplete; wait for more data.
                return None

            payload = self._buf[prefix_size:frame_end]
            # Drop the consumed frame in place instead of copying the tail.
            del self._buf[:frame_end]

            try:
                return self.clad_decode_union.unpack(payload)
            except ValueError as e:
                logger_protocol.warning(
                    "Failed to decode CLAD message for buflen=%d: %s", len(payload), e)

        return None

    def eof_received(self):
        '''Log that the other end signalled EOF.

        Returns None implicitly, which asyncio treats as a request for the
        transport to close itself.
        '''
        logger_protocol.info("EOF received on connection")

    def send_msg(self, msg, **params):
        '''Encode ``msg`` and write it to the transport as one framed message.

        Does nothing if the transport is already closing.

        Args:
            msg: Message instance; its class name selects the member of
                ``clad_encode_union`` it is wrapped in.
            **params: Accepted for compatibility and ignored.

        Raises:
            struct.error: If the packed message does not fit the 2 byte
                length prefix.
        '''
        if self.transport.is_closing():
            return

        name = msg.__class__.__name__
        msg = self.clad_encode_union(**{name: msg})
        msg_buf = msg.pack()
        msg_size = self._LENGTH_PREFIX.pack(len(msg_buf))

        # Hold the lock across both writes so that a prefix and its payload
        # are never interleaved with another sender's frame.
        with self._send_mutex:
            self.transport.write(msg_size)
            self.transport.write(msg_buf)
            if self._should_log_msg(name):
                logger_protocol.debug("SENT %s", msg)

    def send_msg_new(self, msg):
        '''Send ``msg``; equivalent to ``send_msg(msg)``.

        send_msg derives the message name from ``msg`` itself, so only the
        message is forwarded (passing the name as an extra positional
        argument raises TypeError).
        '''
        return self.send_msg(msg)

    def msg_received(self, msg):
        '''Hook called with every decoded message; does nothing by default.'''
        pass

    def _should_log_msg(self, name):
        '''Return True if messages called ``name`` should be debug-logged.'''
        which = self._clad_log_which
        if which is None:
            return False
        return which == LOG_ALL or name in which
-
|