# Copyright 2018-2020 Drexel University
# Author: Geoffrey Mainland <mainland@drexel.edu>
"""Support for protobuf over UDP, TCP, and ZMQ"""
import asyncio
from functools import partial, wraps
import inspect
import logging
from pprint import pformat
import re
import struct
import zmq.asyncio
# Module-wide logger shared by every protobuf transport defined below.
logger = logging.getLogger('protobuf')
def setTimestamp(self, ts):
    """Set timestamp (sec).

    Splits ``ts`` into whole ``seconds`` and fractional ``picoseconds`` and
    stores both on ``self``. Presumably attached as a method to a message
    class with ``seconds``/``picoseconds`` fields -- confirm at call sites.

    Args:
        ts (float): Time in seconds. May be negative; ``picoseconds`` is
            always non-negative, so ``seconds + picoseconds * 1e-12``
            reconstructs ``ts``.
    """
    # divmod floors both parts consistently. Mixing int() (truncates toward
    # zero) with % (floors) is off by one second for negative timestamps.
    whole, frac = divmod(ts, 1)
    self.seconds = int(whole)
    self.picoseconds = int(frac * 1e12)
def getTimestamp(self):
    """Get timestamp (sec).

    Returns:
        float: ``self.seconds`` plus ``self.picoseconds`` converted to seconds.
    """
    fraction = self.picoseconds * 1e-12
    return self.seconds + fraction
class HandlerTable:
    """Registry of one protobuf message class's fields and their handlers."""
    # pylint: disable=too-few-public-methods
    def __init__(self):
        # field name -> protobuf field number
        self.message_types = dict()
        # field name -> handler function
        self.message_handlers = dict()
def handler(message_type):
    """Add automatic support for handling protobuf messages with a payload
    structure. Should be used to decorate a class handling protobuf messages.

    Methods of the decorated class marked with ``@handle('<Msg>.<field>')``,
    where ``<Msg>`` equals ``message_type.__name__``, are registered as the
    handler for that field. The registry is stored as a :class:`HandlerTable`
    in ``cls.handlers[message_type.__name__]``; several ``@handler``
    decorators may be stacked on one class.

    Args:
        message_type (class): The protobuf message class handled.

    Raises:
        ValueError: If a marked method names a field ``message_type`` does not
            have, or its name does not contain exactly one ``'.'``.
    """
    def decorator(cls):
        # Give each class its own registry. hasattr() would also find a dict
        # inherited from a decorated base class, and registering into that
        # dict would leak this class's handlers into the base (and siblings).
        if 'handlers' not in cls.__dict__:
            cls.handlers = dict(getattr(cls, 'handlers', {}))
        table = HandlerTable()
        cls.handlers[message_type.__name__] = table
        for field in message_type.DESCRIPTOR.fields:
            table.message_types[field.name] = field.number
        for (_, f) in inspect.getmembers(cls, predicate=inspect.isfunction):
            if 'message_name' in f.__dict__:
                # Unpacking raises ValueError unless there is exactly one '.'.
                (cname, fname) = f.message_name.split('.')
                if cname == message_type.__name__:
                    if fname not in table.message_types:
                        raise ValueError("Illegal message type '{}' for class {}".format(
                            fname, message_type.__name__))
                    table.message_handlers[fname] = f
        return cls
    return decorator
def findHandler(obj, cls, msg):
    """Find the handler associated with a protobuf message.

    Args:
        obj: Object whose ``handler_obj`` carries the ``handlers`` registry
            built by :func:`handler`.
        cls (class): Message class; its ``__name__`` selects the registry.
        msg: Received message; its populated ``payload`` oneof field selects
            the handler.

    Raises:
        KeyError: If ``cls`` has no registry or the payload field (or ``None``
            when no payload is set) has no registered handler.
    """
    table = obj.handler_obj.handlers[cls.__name__]
    by_field = table.message_handlers
    field_name = msg.WhichOneof('payload')
    return by_field[field_name]
def handle(name):
    """
    Indicate that a method handles a specific message. Should be used to
    decorate a method of a class that handles protobuf messages.

    Args:
        name (str): The message the function handles, written as
            ``'<MessageClass>.<field>'``.

    Returns:
        A decorator that tags the function and returns it unchanged.
    """
    def mark(method):
        # Stored on the function object; @handler reads it back when it
        # builds the class's handler table.
        setattr(method, 'message_name', name)
        return method
    return mark
class ZMQProtoServer:
    """Protobuf-over-ZMQ server.

    Receives messages on a ZMQ PULL socket and dispatches each one to the
    handler registered on ``handler_obj`` for the message's ``payload`` field.
    Handlers are scheduled with ``create_task``, so they must be coroutine
    functions.
    """
    # pylint: disable=too-few-public-methods
    def __init__(self, handler_obj=None, loop=None):
        self.handler_obj = handler_obj
        """Protobuf message handler object"""
        self.loop = loop
        """asyncio loop"""
        # The event loop keeps only weak references to tasks; hold strong
        # ones so in-flight handler tasks cannot be collected mid-run.
        self._handler_tasks = set()

    def startServer(self, cls, listen_ip, listen_port):
        """Start server.

        Args:
            cls (class): Protobuf message class of incoming messages.
            listen_ip (str): Address to bind.
            listen_port (int): Port to bind.

        Returns:
            asyncio.Task: The server task; cancel it to stop the server.
        """
        return self.loop.create_task(self._serverLoop(cls, listen_ip, listen_port),
                                     name='ZMQ {}'.format(cls.__name__))

    async def _serverLoop(self, cls, listen_ip, listen_port):
        """Receive and dispatch messages until cancelled.

        The socket and context are released however the loop ends
        (cancellation, bind failure, ...). Cancellation itself is absorbed,
        so the task finishes normally.
        """
        ctx = zmq.asyncio.Context()
        listen_sock = None
        try:
            listen_sock = ctx.socket(zmq.PULL)
            listen_sock.bind('tcp://{}:{}'.format(listen_ip, listen_port))
            while True:
                raw = await listen_sock.recv()
                try:
                    msg = cls.FromString(raw)
                except Exception:  # pylint: disable=broad-except
                    # Malformed input from a peer must not stop the server.
                    logger.exception('Failed to decode %s message', cls.__name__)
                    continue
                logger.debug('Received message: %s', pformat(msg))
                try:
                    f = findHandler(self, cls, msg)
                    task = self.loop.create_task(f(self.handler_obj, msg))
                except KeyError as err:
                    logger.error('Received unsupported message type: %s', err)
                    continue
                self._handler_tasks.add(task)
                task.add_done_callback(self._handler_tasks.discard)
        except asyncio.CancelledError:
            pass
        finally:
            if listen_sock is not None:
                listen_sock.close()
            ctx.term()
class ZMQProtoClient:
    """Protobuf-over-ZMQ client.

    Pushes serialized messages over a ZMQ PUSH socket. Usable as a context
    manager: the connection is opened on entry and closed on exit.
    """
    def __init__(self, loop=None, server_host=None, server_port=None):
        self.loop = loop
        """asyncio loop"""
        self.server_host = server_host
        """Server hostname"""
        self.server_port = server_port
        """Server port"""
        self.ctx = None
        """ZMQ context"""
        self.server_sock = None
        """ZMQ server socket"""

    def __enter__(self):
        self.open()
        # Return the client so ``with ZMQProtoClient(...) as client:`` binds it.
        return self

    def __exit__(self, _type, _value, _traceback):
        self.close()

    def open(self):
        """Open connection to server.

        Creates a new ZMQ context and a PUSH socket connected to
        ``tcp://<server_host>:<server_port>``. If setup fails, whatever was
        already created is released before the error propagates.
        """
        self.ctx = zmq.asyncio.Context()
        try:
            self.server_sock = self.ctx.socket(zmq.PUSH)
            self.server_sock.connect('tcp://{}:{}'.format(self.server_host, self.server_port))
            # See:
            # https://github.com/zeromq/pyzmq/issues/102
            # http://api.zeromq.org/2-1:zmq-setsockopt
            self.server_sock.setsockopt(zmq.LINGER, 2500)
        except Exception:
            self.close()
            raise

    def close(self):
        """Close connection to server.

        Safe to call repeatedly or after a failed :meth:`open`. Terminating
        the context may block while unsent messages linger (LINGER is set to
        2500 ms in :meth:`open`).
        """
        if self.server_sock is not None:
            self.server_sock.close()
            self.server_sock = None
        if self.ctx is not None:
            self.ctx.term()
            self.ctx = None

    async def send(self, msg):
        """Serialize ``msg`` and send it to the server (requires :meth:`open`)."""
        await self.server_sock.send(msg.SerializeToString())
def send(cls):
    """
    Automatically add support to a function for constructing and sending a
    protobuf message. Should be used to decorate the methods of a
    ZMQProtoClient subclass.

    The decorated coroutine receives a fresh ``cls()`` instance right after
    ``self`` and fills it in. The wrapper then logs the message, sends it
    with ``await self.send(...)``, and returns ``None``.

    Args:
        cls (class): The message class to send.
    """
    def decorate(build):
        @wraps(build)
        async def sender(self, *args, **kwargs):
            message = cls()
            await build(self, message, *args, **kwargs)
            logger.debug('Sending message %s', pformat(message))
            await self.send(message)
        return sender
    return decorate
class ProtobufProtocol(asyncio.Protocol):
    """Abstract base class for protobuf-over-TCP clients.

    Each message on the wire is a 2-byte big-endian length followed by the
    serialized message. When ``cls`` is set, a server loop started on
    connection decodes incoming messages as ``cls`` and dispatches them to the
    handlers registered on ``handler_obj``. A truthy handler result is sent
    back as the response.
    """
    # pylint: disable=too-many-instance-attributes
    def __init__(self, cls=None, handler_obj=None, loop=None, **kwargs):
        super().__init__(**kwargs)
        self.cls = cls
        """Protobuf message class of messages received by server"""
        self.handler_obj = handler_obj
        """Protobuf message handler object"""
        self.loop = loop
        """asyncio loop"""
        self.connected_event = asyncio.Event()
        """Event set when connection is made"""
        self.server_task = None
        """Server loop task"""
        self.transport = None
        """Transport associated with protocol"""
        self.buffer = bytearray()
        """Received bytes"""
        self.buffer_lock = asyncio.Lock()
        """Lock for buffer"""
        self.buffer_cond = asyncio.Condition(lock=self.buffer_lock)
        """Condition variable for buffer"""
        # Strong references to reader-notification tasks; the event loop only
        # keeps weak references, so unreferenced tasks could be collected.
        self._notify_tasks = set()

    def connection_made(self, transport):
        self.transport = transport
        self.connected_event.set()
        if self.cls:
            self.server_task = self.loop.create_task(self._serverLoop())

    def connection_lost(self, exc):
        self.connected_event.clear()
        self.transport = None
        if self.server_task is not None:
            self.server_task.cancel()

    def data_received(self, data):
        # Append immediately so chunks keep their arrival order. This runs on
        # the event loop, and _recvBytes never awaits between checking the
        # length and consuming bytes, so no reader sees a partial update.
        # Waking waiters requires holding the condition's lock, hence the task.
        self.buffer.extend(data)
        task = self.loop.create_task(self._notifyReaders())
        self._notify_tasks.add(task)
        task.add_done_callback(self._notify_tasks.discard)

    async def _notifyReaders(self):
        """Wake coroutines blocked in :meth:`_recvBytes`."""
        async with self.buffer_cond:
            self.buffer_cond.notify_all()

    async def send(self, msg):
        """Serialize and send a protobuf message with its length prepended.

        Waits until a connection has been made. The length is packed as a
        2-byte big-endian unsigned integer (``'!H'``), so a serialized message
        longer than 65535 bytes raises :class:`struct.error`.
        """
        # Wait until we are connected
        await self.connected_event.wait()
        logger.debug('Sending message %s', pformat(msg))
        data = msg.SerializeToString()
        self.transport.write(struct.pack('!H', len(data)))
        self.transport.write(data)

    async def recv(self, cls):
        """Receive a length-prefixed protobuf message of the given class."""
        # Get message length
        datalen = await self._recvBytes(2)
        datalen, = struct.unpack('!H', datalen)
        # Get message
        data = await self._recvBytes(datalen)
        # Decode message
        msg = cls.FromString(data)
        logger.debug('Received message: %s', pformat(msg))
        return msg

    async def _recvBytes(self, count):
        """Wait until ``count`` bytes are buffered, then remove and return them."""
        # ``async with`` replaces ``with await lock``, which Python 3.9 removed.
        async with self.buffer_cond:
            await self.buffer_cond.wait_for(lambda: len(self.buffer) >= count)
            data = self.buffer[:count]
            # Delete in place rather than rebinding to a copy of the remainder.
            del self.buffer[:count]
            return data

    async def _serverLoop(self):
        """Decode, dispatch, and answer incoming messages until cancelled."""
        while True:
            try:
                req = await self.recv(self.cls)
                f = findHandler(self, self.cls, req)
                resp = f(self.handler_obj, req)
                if resp:
                    await self.send(resp)
            except KeyError as exc:
                logger.error('Received unsupported message type: %s', exc)
            except asyncio.CancelledError:
                return
            except Exception:  # pylint: disable=broad-except
                # Messages are length-delimited, so the stream stays in sync
                # after an undecodable message or a failing handler.
                logger.exception('Error handling %s message', self.cls.__name__)
class TCPProtoServer:
    """Server for protobuf-over-TCP.

    Each accepted connection gets its own :class:`ProtobufProtocol`, which
    decodes messages of the given class and dispatches them to
    ``handler_obj``.
    """
    # pylint: disable=too-few-public-methods

    # Seconds to wait before re-creating the listener after it stops or fails.
    restart_delay = 1.0

    def __init__(self, handler_obj, loop=None):
        self.handler_obj = handler_obj
        """Protobuf message handler object"""
        self.loop = loop
        """asyncio loop"""

    def startServer(self, cls, listen_ip, listen_port):
        """Start a protobuf TCP server.

        Returns:
            asyncio.Task: The server task; cancelling it stops listening.
        """
        return self.loop.create_task(self._serverLoop(cls, listen_ip, listen_port),
                                     name='TCP {}'.format(cls.__name__))

    async def _serverLoop(self, cls, listen_ip, listen_port):
        """Run the listener, restarting it after failures, until cancelled."""
        while True:
            server = None
            try:
                server = await self.loop.create_server(partial(ProtobufProtocol,
                                                               cls=cls,
                                                               handler_obj=self.handler_obj,
                                                               loop=self.loop),
                                                       host=listen_ip,
                                                       port=listen_port,
                                                       reuse_port=True)
                await server.wait_closed()
            except asyncio.CancelledError:
                return
            except Exception:  # pylint: disable=broad-except
                logger.exception('Restarting TCP proto server')
            finally:
                # Stop listening; otherwise cancelling this task would leave
                # the listening socket open.
                if server is not None:
                    server.close()
            try:
                # Pause so a persistent failure (e.g. address in use) does not
                # become a busy loop of restarts and log spam.
                await asyncio.sleep(self.restart_delay)
            except asyncio.CancelledError:
                return
class TCPProtoClient(ProtobufProtocol):
    """Client for protobuf-over-TCP.

    The instance is its own protocol factory (``__call__`` returns ``self``),
    so one client object represents one connection. Usable as a context
    manager: the connection is opened on entry and closed on exit.
    """
    def __init__(self, server_host=None, server_port=None, **kwargs):
        super().__init__(**kwargs)
        self.server_host = server_host
        """Server hostname"""
        self.server_port = server_port
        """Server port"""
        # Pending connection task. Holding it prevents garbage collection and
        # lets close() abort a connection that has not completed yet.
        self._connect_task = None

    def __call__(self):
        return self

    def __enter__(self):
        self.open()
        # Return the client so ``with TCPProtoClient(...) as client:`` binds it.
        return self

    def __exit__(self, _type, _value, _traceback):
        self.close()

    def open(self):
        """Open connection to server.

        The connection is made asynchronously on ``self.loop``; :meth:`send`
        waits for it to complete. Connection failures are logged.

        Returns:
            asyncio.Task: The connection task, which callers may await.
        """
        async def f():
            self.transport, _protocol = await self.loop.create_connection(self,
                                                                          host=self.server_host,
                                                                          port=self.server_port)
        task = self.loop.create_task(f())
        self._connect_task = task
        task.add_done_callback(self._connectDone)
        return task

    def _connectDone(self, task):
        """Drop the finished connection task and log a failed attempt."""
        if task is self._connect_task:
            self._connect_task = None
        if not task.cancelled() and task.exception() is not None:
            logger.error('Could not connect to %s:%s: %s',
                         self.server_host, self.server_port, task.exception())

    def close(self):
        """Close connection to server, aborting a pending connection attempt."""
        if self._connect_task is not None and not self._connect_task.done():
            self._connect_task.cancel()
        if self.transport:
            self.transport.close()
            self.transport = None
def rpc(req_cls, resp_cls):
    """Automatically add support for synchronously waiting on the results of an
    RPC call. Should be used to decorate the methods of a TCPProtoClient
    subclass.

    The decorated method is a plain (non-async) function that receives a
    fresh ``req_cls`` instance right after ``self`` and fills it in. The
    wrapper sends the request, waits for a ``resp_cls`` reply and returns it.
    It drives ``self.loop`` with ``run_until_complete``, so it must not be
    called while that loop is already running.

    The wrapper takes a keyword-only ``timeout`` in seconds (default 1). If
    no reply arrives in time, ``asyncio.wait_for`` raises its timeout error.

    Args:
        req_cls (class): The message class for RPC requests.
        resp_cls (class): The message class for RPC responses.
    """
    def decorator(f):
        @wraps(f)
        def wrapper(self, *args, timeout=1, **kwargs):
            async def run():
                req = req_cls()
                f(self, req, *args, **kwargs)
                logger.debug('Sending RPC request %s', req)
                await self.send(req)
                resp = await self.recv(resp_cls)
                logger.debug('Received RPC response message %s', resp)
                return resp
            return self.loop.run_until_complete(asyncio.wait_for(run(), timeout))
        return wrapper
    return decorator
class ProtobufDatagramProtocol(asyncio.DatagramProtocol):
    """Abstract base class of client for protobuf-over-UDP.

    Each datagram carries exactly one serialized message, with no length
    prefix. When ``cls`` is set, received datagrams are decoded as ``cls``
    and passed synchronously to the handler registered on ``handler_obj``.
    Without ``cls``, received datagrams are ignored.
    """
    def __init__(self, cls=None, handler_obj=None, loop=None, **kwargs):
        super().__init__(**kwargs)
        self.cls = cls
        """Protobuf message class of messages received by server"""
        self.handler_obj = handler_obj
        """Protobuf message handler object"""
        self.loop = loop
        """asyncio loop"""
        self.connected_event = asyncio.Event()
        """Event set when connection is made"""
        self.transport = None
        """Transport associated with protocol"""

    def connection_made(self, transport):
        self.transport = transport
        self.connected_event.set()

    def connection_lost(self, exc):
        self.connected_event.clear()
        self.transport = None

    def datagram_received(self, data, addr):
        # With no message class (e.g. a send-only client) there is nothing to
        # decode into; ignore the datagram instead of raising AttributeError
        # on None inside the event loop.
        if self.cls is None:
            return
        try:
            msg = self.cls.FromString(data)
            logger.debug('Received message: %s', pformat(msg))
            f = findHandler(self, self.cls, msg)
            f(self.handler_obj, msg)
        except KeyError as err:
            logger.error('Received unsupported message type: %s', err)

    async def send(self, msg):
        """Serialize and send a protobuf message as a single datagram.

        Waits until the endpoint is connected, then sends to the transport's
        default remote address (``addr=None``).
        """
        # Wait until we are connected
        await self.connected_event.wait()
        logger.debug('Sending message %s', pformat(msg))
        self.transport.sendto(msg.SerializeToString(), addr=None)
class UDPProtoServer:
    """Server for protobuf-over-UDP."""
    # pylint: disable=too-few-public-methods
    def __init__(self, handler_obj, loop=None):
        self.handler_obj = handler_obj
        """Protobuf message handler object"""
        self.loop = loop
        """asyncio loop"""

    def startServer(self, cls, listen_ip, listen_port):
        """Start a protobuf UDP server.

        Returns:
            asyncio.Task: Completes once the endpoint is bound. Its result is
            the ``(transport, protocol)`` pair, so callers can later stop the
            server with ``task.result()[0].close()``.
        """
        return self.loop.create_task(self._serverLoop(cls, listen_ip, listen_port),
                                     name='UDP {}'.format(cls.__name__))

    async def _serverLoop(self, cls, listen_ip, listen_port):
        """Create server endpoint and return its ``(transport, protocol)``.

        Returning the pair (instead of discarding it) is the only way callers
        can get hold of the transport to close the endpoint.
        """
        return await self.loop.create_datagram_endpoint(partial(ProtobufDatagramProtocol,
                                                                cls=cls,
                                                                handler_obj=self.handler_obj,
                                                                loop=self.loop),
                                                        local_addr=(listen_ip, listen_port),
                                                        allow_broadcast=True)
class UDPProtoClient(ProtobufDatagramProtocol):
    """Client for protobuf-over-UDP.

    The instance is its own protocol (``__call__`` returns ``self``). Usable
    as a context manager: the endpoint is opened on entry and closed on exit.
    """
    def __init__(self, server_host=None, server_port=None, **kwargs):
        super().__init__(**kwargs)
        self.server_host = server_host
        """Server hostname"""
        self.server_port = server_port
        """Server port"""
        # Pending endpoint-creation task. Holding it prevents garbage
        # collection and lets close() abort an endpoint not yet created.
        self._connect_task = None

    def __call__(self):
        return self

    def __enter__(self):
        self.open()
        # Return the client so ``with UDPProtoClient(...) as client:`` binds it.
        return self

    def __exit__(self, _type, _value, _traceback):
        self.close()

    def open(self):
        """Open connection to server.

        The endpoint is created asynchronously on ``self.loop``; :meth:`send`
        waits for it. Failures are logged.

        Returns:
            asyncio.Task: The endpoint-creation task, which callers may await.
        """
        async def f():
            await self.loop.create_datagram_endpoint(lambda: self,
                                                     remote_addr=(self.server_host, self.server_port),
                                                     allow_broadcast=True)
        task = self.loop.create_task(f())
        self._connect_task = task
        task.add_done_callback(self._connectDone)
        return task

    def _connectDone(self, task):
        """Drop the finished endpoint task and log a failed attempt."""
        if task is self._connect_task:
            self._connect_task = None
        if not task.cancelled() and task.exception() is not None:
            logger.error('Could not open UDP endpoint to %s:%s: %s',
                         self.server_host, self.server_port, task.exception())

    def close(self):
        """Close connection to server, aborting a pending endpoint creation."""
        if self._connect_task is not None and not self._connect_task.done():
            self._connect_task.cancel()
        if self.transport:
            self.transport.close()
            self.transport = None