"""Module for funneling messages from one comm to another."""
import os
import numpy as np
import threading
from cis_interface import backwards
from cis_interface.communication import new_comm, get_comm_class
from cis_interface.drivers.Driver import Driver
from cis_interface.schema import get_schema
def _translate_list2element(arr):
if isinstance(arr, (list, tuple)):
arr = arr[0]
return arr
class ConnectionDriver(Driver):
    r"""Class that continuously passes messages from one comm to another.

    Args:
        name (str): Name that should be used to set names of input/output comms.
        icomm_kws (dict, optional): Keyword arguments for the input
            communicator. The dictionary is copied before it is completed, so
            the caller's object is left unmodified.
        ocomm_kws (dict, optional): Keyword arguments for the output
            communicator. Copied in the same way as icomm_kws.
        translator (func, list, tuple, optional): Function, or sequence of
            functions, that should be used to translate messages from the
            input communicator before passing them to the output
            communicator. The functions are called on each message in the
            order they are provided. Defaults to None and messages are passed
            directly. Every entry must already be callable when it reaches
            this constructor (a ValueError is raised otherwise); string
            specifications of the form "<package.module>:<function>" are
            presumably resolved to functions before this point (e.g. during
            schema normalization) -- NOTE(review): confirm.
        timeout_send_1st (float, optional): Time in seconds that should be
            waited before giving up on the first send. Defaults to
            self.timeout, or to 60 if the (non-model) input comm is a file.
        single_use (bool, optional): If True, the driver will be stopped after
            one loop. Defaults to False.
        onexit (str, optional): Class method that should be called when the
            corresponding model exits, but before the driver is shut down.
            Defaults to None.
        **kwargs: Additional keyword arguments are passed to the parent class.

    Attributes:
        icomm_kws (dict): Keyword arguments for the input communicator.
        ocomm_kws (dict): Keyword arguments for the output communicator.
        icomm (CommBase): Input communicator.
        ocomm (CommBase): Output communicator.
        nrecv (int): Number of messages received.
        nproc (int): Number of messages processed.
        nsent (int): Number of messages sent.
        nskip (int): Number of messages skipped.
        state (str): Descriptor of last action taken.
        translator (list): Functions that will be used to translate messages
            from the input communicator before passing them to the output
            communicator.
        timeout_send_1st (float): Time in seconds that should be waited before
            giving up on the first send.
        single_use (bool): If True, the driver will be stopped after one
            loop.
        onexit (str): Class method that should be called when the corresponding
            model exits, but before the driver is shut down.

    """

    _connection_type = 'default'
    _icomm_type = 'DefaultComm'
    _ocomm_type = 'DefaultComm'
    # Compared against 'input'/'output' by _is_input/_is_output.
    _direction = 'any'
    _schema_type = 'connection'
    _schema_required = ['inputs', 'outputs']
    _schema_properties = {
        'inputs': {'type': 'array', 'minItems': 1,
                   'items': {'anyOf': [{'$ref': '#/definitions/comm'},
                                       {'$ref': '#/definitions/file'}]}},
        'outputs': {'type': 'array', 'minItems': 1,
                    'items': {'anyOf': [{'$ref': '#/definitions/comm'},
                                        {'$ref': '#/definitions/file'}]}},
        'translator': {'type': 'array', 'items': {'type': 'function'}},
        'onexit': {'type': 'string'}}

    @property
    def _is_input(self):
        r"""bool: True if the connection is providing input to a model."""
        return (self._direction == 'input')

    @property
    def _is_output(self):
        r"""bool: True if the connection is retrieving output from a model."""
        return (self._direction == 'output')

    def __init__(self, name, translator=None, single_use=False, onexit=None,
                 **kwargs):
        super(ConnectionDriver, self).__init__(name, **kwargs)
        # Translator(s): normalize to a list of callables
        if translator is None:
            translator = []
        elif isinstance(translator, tuple):
            translator = list(translator)
        elif not isinstance(translator, list):
            translator = [translator]
        self.translator = []
        for t in translator:
            if not callable(t):
                # Wrap in a tuple so tuple-valued entries format correctly
                raise ValueError("Translator %s not callable." % (t,))
            self.translator.append(t)
        # Checked eagerly so a bad name fails at construction rather than
        # when the model exits.
        if (onexit is not None) and (not hasattr(self, onexit)):
            raise ValueError("onexit '%s' is not a class method." % onexit)
        self.onexit = onexit
        # Attributes
        self._eof_sent = False  # set by send_eof so EOF is only sent once
        self.single_use = single_use
        self._first_send_done = False  # first send goes via _send_1st_message
        self._comm_opened = threading.Event()  # set by open_comm on success
        self._comm_closed = False  # set by close_comm; open_comm then aborts
        self._used = False  # set by send_message; invalidates single-use
        self._skip_after_loop = False  # when True, after_loop returns early
        self._model_exited = False
        self.nrecv = 0
        self.nproc = 0
        self.nsent = 0
        self.nskip = 0
        self.state = 'started'
        self.close_state = ''
        # Add comms and print debug info
        self._init_comms(name, **kwargs)
        # self.debug('    env: %s', str(self.env))
        self.debug(('\n' + 80 * '=' + '\n'
                    + 'class = %s\n'
                    + ' input: name = %s, address = %s\n'
                    + ' output: name = %s, address = %s\n'
                    + (80 * '=')), self.__class__,
                   self.icomm.name, self.icomm.address,
                   self.ocomm.name, self.ocomm.address)

    def _init_single_comm(self, name, io, comm_kws, **kwargs):
        r"""Parse keyword arguments for and create the input/output comm.

        Args:
            name (str): Driver name; must equal self.name.
            io (str): 'input' to create self.icomm; any other value creates
                self.ocomm.
            comm_kws (dict, None): Keyword arguments for the comm. The dict
                (and any dicts in its 'comm' entry) are copied, not modified.
            **kwargs: Remaining driver keyword arguments. For comms that do
                not touch the model, matching comm/file properties and
                'comm_env' are consumed from here.

        Returns:
            dict: kwargs with consumed entries removed. 'timeout_send_1st'
                defaults to 60 if a non-model input comm is a file.

        """
        self.debug("Creating %s comm", io)
        s = get_schema()
        # Copy so the caller's dictionary is not modified in place
        comm_kws = dict() if comm_kws is None else dict(comm_kws)
        if io == 'input':
            direction = 'recv'
            comm_type = self._icomm_type
            touches_model = self._is_output
            attr_comm = 'icomm'
            comm_kws['close_on_eof_recv'] = False
        else:
            direction = 'send'
            comm_type = self._ocomm_type
            touches_model = self._is_input
            attr_comm = 'ocomm'
        comm_kws['direction'] = direction
        comm_kws['dont_open'] = True
        comm_kws['reverse_names'] = True
        comm_kws.setdefault('comm', {'comm': comm_type})
        assert(name == self.name)
        comm_kws.setdefault('name', name)
        # Normalize 'comm' to a list of (copied) dicts that each name a type
        comm_list = comm_kws['comm']
        if not isinstance(comm_list, list):
            comm_list = [comm_list]
        normalized = []
        for x in comm_list:
            if x is None:
                x = dict()
            elif isinstance(x, dict):
                x = dict(x)
            else:
                x = dict(comm=x)
            x.setdefault('comm', comm_type)
            normalized.append(x)
        comm_kws['comm'] = normalized
        any_files = False
        all_files = True
        if not touches_model:
            comm_kws['no_suffix'] = True
            # Collect properties valid for the comm (sub)types and move any
            # matching driver keywords onto the comm.
            ikws = []
            for x in comm_kws['comm']:
                if get_comm_class(x['comm']).is_file:
                    any_files = True
                    ikws += s['file'].get_subtype_properties(x['comm'])
                else:
                    all_files = False
                    ikws += s['comm'].get_subtype_properties(x['comm'])
            ikws = list(set(ikws))
            for k in ikws:
                if (k not in comm_kws) and (k in kwargs):
                    comm_kws[k] = kwargs.pop(k)
            # Only fill 'env' when the caller has not already supplied one
            if ('comm_env' in kwargs) and ('env' not in comm_kws):
                comm_kws['env'] = kwargs.pop('comm_env')
        if any_files and (io == 'input'):
            kwargs.setdefault('timeout_send_1st', 60)
        self.debug('%s comm_kws:\n%s', attr_comm, self.pprint(comm_kws, 1))
        setattr(self, attr_comm, new_comm(comm_kws.pop('name'), **comm_kws))
        setattr(self, '%s_kws' % attr_comm, comm_kws)
        if touches_model:
            self.env.update(getattr(self, attr_comm).opp_comms)
        elif not all_files:
            self.comm_env.update(getattr(self, attr_comm).opp_comms)
        return kwargs

    def _init_comms(self, name, icomm_kws=None, ocomm_kws=None, **kwargs):
        r"""Create the input and output comms.

        If creating the output comm fails, the already-created input comm is
        closed before the error is re-raised.

        Args:
            name (str): Driver name.
            icomm_kws (dict, optional): Keyword arguments for the input comm.
            ocomm_kws (dict, optional): Keyword arguments for the output comm.
            **kwargs: Remaining driver keyword arguments.

        """
        kwargs = self._init_single_comm(name, 'input', icomm_kws, **kwargs)
        try:
            kwargs = self._init_single_comm(name, 'output', ocomm_kws, **kwargs)
        except BaseException:
            self.icomm.close()
            raise
        # Apply keywords dependent on comms
        self.timeout_send_1st = kwargs.pop('timeout_send_1st', self.timeout)
        self.debug('Final env:\n%s', self.pprint(self.env, 1))

    def wait_for_route(self, timeout=None):
        r"""Wait until every received message has been sent or skipped.

        Args:
            timeout (float, optional): Max time to wait.

        Returns:
            bool: True if nrecv == nsent + nskip when the wait ends.

        """
        T = self.start_timeout(timeout)
        while ((not T.is_out)
               and (self.nrecv != (self.nsent + self.nskip))):  # pragma: debug
            self.sleep()
        self.stop_timeout()
        return (self.nrecv == (self.nsent + self.nskip))

    @property
    def is_valid(self):
        r"""bool: True if the parent class is valid, both comms are open and,
        for single-use drivers, no message has been sent yet."""
        with self.lock:
            return (super(ConnectionDriver, self).is_valid
                    and self.is_comm_open and not (self.single_use and self._used))

    @property
    def is_comm_open(self):
        r"""bool: Returns True if both communicators are open and close_comm
        has not been called."""
        with self.lock:
            return (self.icomm.is_open and self.ocomm.is_open
                    and not self._comm_closed)

    @property
    def is_comm_closed(self):
        r"""bool: Returns True if both communicators are closed."""
        with self.lock:
            return self.icomm.is_closed and self.ocomm.is_closed

    @property
    def n_msg(self):
        r"""int: Number of messages waiting in input communicator."""
        with self.lock:
            return self.icomm.n_msg_recv

    def open_comm(self):
        r"""Open the communicators.

        Does nothing if close_comm was already called. If opening fails, both
        comms are closed and the error is re-raised.
        """
        self.debug('')
        with self.lock:
            if self._comm_closed:
                self.debug('Aborted as comm closed')
                return
            try:
                self.icomm.open()
                self.ocomm.open()
            except BaseException:
                self.close_comm()
                raise
            self._comm_opened.set()
        self.debug('Returning')

    def close_comm(self):
        r"""Close the communicators.

        Both comms are always attempted. Afterwards, the first error raised
        (input comm first) is re-raised.
        """
        self.debug('')
        with self.lock:
            self._comm_closed = True
            self._skip_after_loop = True
            # Capture errors for both comms
            ie = None
            oe = None
            try:
                if getattr(self, 'icomm', None) is not None:
                    self.icomm.close()
            except BaseException as e:
                ie = e
            try:
                if getattr(self, 'ocomm', None) is not None:
                    self.ocomm.close()
            except BaseException as e:
                oe = e
            if ie is not None:
                raise ie
            if oe is not None:
                raise oe
        self.debug('Returning')

    def start(self):
        r"""Open connection before running.

        Raises:
            Exception: If the comms do not report open before the timeout.

        """
        self.open_comm()
        Tout = self.start_timeout()
        while (not self.is_comm_open) and (not Tout.is_out):
            self.sleep()
        self.stop_timeout()
        if not self.is_comm_open:
            raise Exception("Connection never finished opening.")
        super(ConnectionDriver, self).start()

    def graceful_stop(self, timeout=None, **kwargs):
        r"""Stop the driver, first waiting for the input comm to be empty.

        Args:
            timeout (float, optional): Max time that should be waited. Defaults
                to None and is set to attribute timeout.
            **kwargs: Additional keyword arguments are passed to the parent
                class's graceful_stop method.

        """
        self.debug('')
        with self.lock:
            self.set_close_state('stop')
            self._skip_after_loop = True
        self.drain_input(timeout=timeout)
        self.wait_for_route(timeout=timeout)
        self.drain_output(timeout=timeout)
        # Forward kwargs to the parent as documented above.
        super(ConnectionDriver, self).graceful_stop(**kwargs)
        self.debug('Returning')

    def on_model_exit(self):
        r"""Drain input and then close it.

        If onexit names a method other than 'on_model_exit'/'pass', that
        method is called first. Input connections close both comms; output
        connections wait for outstanding messages to be routed and then
        close the input comm.
        """
        self.debug('')
        if (self.onexit not in [None, 'on_model_exit', 'pass']):
            self.debug("Calling onexit = '%s'", self.onexit)
            getattr(self, self.onexit)()
        self.drain_input(timeout=self.timeout)
        self.set_close_state('model exit')
        self.debug('Model exit triggered close')
        if self._is_input:
            with self.lock:
                self.icomm.close()
                self.ocomm.close()
        if self._is_output:
            self.wait_for_route(timeout=self.timeout)
            with self.lock:
                self.icomm.close()
        self.set_break_flag()
        self.wait()
        self.debug('Finished')
        super(ConnectionDriver, self).on_model_exit()

    def do_terminate(self):
        r"""Stop the driver by closing the communicators."""
        self.debug('')
        self.set_close_state('terminate')
        self.close_comm()
        super(ConnectionDriver, self).do_terminate()

    def cleanup(self):
        r"""Ensure that the communicators are closed."""
        self.debug('')
        self.close_comm()
        super(ConnectionDriver, self).cleanup()

    def printStatus(self, beg_msg='', end_msg=''):
        r"""Print information on the status of the ConnectionDriver.

        Args:
            beg_msg (str, optional): Additional message to print at beginning.
            end_msg (str, optional): Additional message to print at end.

        """
        msg = beg_msg
        msg += '%-50s' % (self.__module__.split('.')[-1] + '(' + self.name + '): ')
        msg += '%-30s' % ('last action: ' + self.state)
        msg += '%-25s' % ('is_open(%s, %s), ' % (self.icomm.is_open,
                                                self.ocomm.is_open))
        msg += '%-15s' % (str(self.nrecv) + ' received, ')
        msg += '%-15s' % (str(self.nproc) + ' processed, ')
        msg += '%-15s' % (str(self.nskip) + ' skipped, ')
        msg += '%-15s' % (str(self.nsent) + ' sent, ')
        msg += '%-20s' % (str(self.icomm.n_msg) + ' ready to recv')
        msg += '%-20s' % (str(self.ocomm.n_msg) + ' ready to send')
        with self.lock:
            if self.close_state:
                msg += '%-30s' % ('close state: ' + self.close_state)
        msg += end_msg
        print(msg)

    def confirm_output(self, timeout=None):
        r"""Confirm receipt of messages from output comm.

        Polls until the output comm is closed, reports a confirmed send, or
        the timeout expires.
        """
        T = self.start_timeout(timeout)
        while not T.is_out:  # pragma: debug
            with self.lock:
                if (not self.ocomm.is_open):
                    break
                elif self.ocomm.is_confirmed_send:
                    break
            self.sleep(10 * self.sleeptime)
        self.stop_timeout()

    def drain_output(self, timeout=None):
        r"""Drain messages from output comm.

        Polls until the output comm is closed, or has no messages left to
        drain and reports a confirmed send, or the timeout expires.
        """
        T = self.start_timeout(timeout)
        while not T.is_out:
            with self.lock:
                if (not self.ocomm.is_open):
                    break
                elif ((self.ocomm.n_msg_send_drain == 0)
                      and self.ocomm.is_confirmed_send):
                    break
            self.sleep()
        self.stop_timeout()

    def before_loop(self):
        r"""Actions to perform prior to sending messages.

        Opens the comms and checks the driver is valid. On any failure the
        status is printed, the exception is logged, the comms are closed and
        the break flag is set.
        """
        self.state = 'before loop'
        try:
            self.open_comm()
            self.sleep()  # Help ensure senders/receivers connected before messages
            self.debug('Running in %s, is_valid = %s', os.getcwd(), str(self.is_valid))
            # Explicit check (unlike assert, not stripped under python -O)
            if not self.is_valid:
                raise AssertionError("Connection is not valid.")
        except BaseException:  # pragma: debug
            self.printStatus()
            self.exception('Could not prep for loop (is_open = (%s, %s)).' % (
                self.icomm.is_open, self.ocomm.is_open))
            self.close_comm()
            self.set_break_flag()

    def after_loop(self):
        r"""Actions to perform after sending messages.

        Unless skipped (close_comm/graceful_stop set _skip_after_loop), this
        closes the input comm and, for non-single-use drivers, sends EOF. The
        output comm is intentionally left open.
        """
        self.state = 'after loop'
        self.debug('')
        # Close input comm in case loop did not
        self.confirm_input(timeout=False)
        self.debug('Confirmed input')
        with self.lock:
            self.debug('Acquired lock')
            if self._skip_after_loop:
                self.debug("After loop skipped.")
                return
            self.icomm.close()
            # Send EOF in case the model didn't
            if not self.single_use:
                self.send_eof()
        # Do not close output comm in case model/connection still receiving
        self.debug('Finished')

    def recv_message(self, **kwargs):
        r"""Get a new message to send.

        Args:
            **kwargs: Additional keyword arguments are passed to the appropriate
                recv method. 'timeout' defaults to 0.

        Returns:
            str, bool: False if no more messages, message otherwise.

        """
        kwargs.setdefault('timeout', 0)
        with self.lock:
            if self.icomm.is_closed:
                return False
            flag, msg = self.icomm.recv(**kwargs)
        if isinstance(msg, backwards.bytes_type) and (msg == self.icomm.eof_msg):
            return self.on_eof()
        if flag:
            return msg
        else:
            return flag

    def on_eof(self):
        r"""Actions to take when EOF received.

        Returns:
            str, bool: Value that should be returned by recv_message on EOF.

        """
        with self.lock:
            self.debug('EOF received')
            self.state = 'eof'
            self.set_close_state('eof')
            self.set_break_flag()
        self.debug('After EOF')
        return False

    def on_message(self, msg):
        r"""Process a message.

        Args:
            msg (bytes, str): Message to be processed.

        Returns:
            bytes, str: Processed message.

        """
        if (self.ocomm._send_serializer) and self.icomm.serializer._initialized:
            self.update_serializer(msg)
        for t in self.translator:
            msg = t(msg)
        return msg

    def update_serializer(self, msg):
        r"""Update the serializer for the output comm based on input.

        The output serializer is initialized from a copy of the input
        serializer's type definition merged with its serializer info (minus
        'seritype'). It is then updated from the header of the last message
        received by the input comm. If the input type is a single-item array
        and the output type is not, a translator that unwraps one-element
        lists/tuples is prepended to self.translator, at most once.

        Args:
            msg (object): Message that triggered the update (unused here).

        """
        # Copy so merging the serializer info does not modify the input
        # comm's own type definition.
        sinfo = dict(self.icomm.serializer.typedef)
        sinfo.update(self.icomm.serializer.serializer_info)
        sinfo.pop('seritype', None)
        self.debug('Before update:\n'
                   + ' icomm:\n sinfo:\n%s\n typedef:\n%s\n'
                   + ' ocomm:\n sinfo:\n%s\n typedef:\n%s',
                   self.pprint(self.icomm.serializer.serializer_info, 2),
                   self.pprint(self.icomm.serializer.typedef, 2),
                   self.pprint(self.ocomm.serializer.serializer_info, 2),
                   self.pprint(self.ocomm.serializer.typedef, 2))
        self.ocomm.serializer.initialize_serializer(sinfo)
        self.ocomm.serializer.update_serializer(skip_type=True,
                                                **self.icomm._last_header)
        if (((self.icomm.serializer.typedef['type'] == 'array')
             and (self.ocomm.serializer.typedef['type'] != 'array')
             and (len(self.icomm.serializer.typedef['items']) == 1))):
            # This method can run again before the first send, e.g. if a
            # message is skipped. Prepending the unwrapper twice would strip
            # more than one level of nesting.
            if _translate_list2element not in self.translator:
                self.translator.insert(0, _translate_list2element)
        self.debug('After update:\n'
                   + ' icomm:\n sinfo:\n%s\n typedef:\n%s\n'
                   + ' ocomm:\n sinfo:\n%s\n typedef:\n%s',
                   self.pprint(self.icomm.serializer.serializer_info, 2),
                   self.pprint(self.icomm.serializer.typedef, 2),
                   self.pprint(self.ocomm.serializer.serializer_info, 2),
                   self.pprint(self.ocomm.serializer.typedef, 2))

    def _send_message(self, *args, **kwargs):
        r"""Send a single message.

        Args:
            *args: Arguments are passed to the output comm send method.
            **kwargs: Keyword arguments are passed to the output comm send method.

        Returns:
            bool: Success or failure of send. False if the output comm is
                closed.

        """
        with self.lock:
            if self.ocomm.is_closed:
                return False
            flag = self.ocomm.send(*args, **kwargs)
        return flag

    def _send_1st_message(self, *args, **kwargs):
        r"""Send the first message, retrying until timeout_send_1st elapses.

        Args:
            *args: Arguments are passed to the output comm send method.
            **kwargs: Keyword arguments are passed to the output comm send method.

        Returns:
            bool: Success or failure of send.

        """
        self.ocomm._multiple_first_send = False
        T = self.start_timeout(self.timeout_send_1st)
        try:
            flag = self._send_message(*args, **kwargs)
            # Presumably quiets the comm's debug output during retries
            # (NOTE(review): confirm suppress_special_debug semantics).
            self.ocomm.suppress_special_debug = True
            if not flag:
                self.debug("1st send failed, will keep trying for %f s in silence.",
                           float(self.timeout_send_1st))
            while ((not T.is_out) and (not flag)
                   and self.ocomm.is_open):  # pragma: debug
                flag = self._send_message(*args, **kwargs)
                if not flag:
                    self.sleep()
        finally:
            # Release the timeout and restore debug output even if a send
            # raises.
            self.stop_timeout()
            self.ocomm.suppress_special_debug = False
        self._first_send_done = True
        if not flag:
            self.error("1st send failed.")
        else:
            self.debug("1st send succeded")
        return flag

    def send_eof(self):
        r"""Send EOF message (at most once per driver).

        Returns:
            bool: Success or failure of send. False if EOF was already sent.

        """
        with self.lock:
            if self._eof_sent:  # pragma: debug
                self.debug('Already sent EOF')
                return False
            self._eof_sent = True
        self.debug('Sent EOF')
        return self.send_message(self.ocomm.eof_msg, is_eof=True)

    def send_message(self, *args, **kwargs):
        r"""Send a single message.

        The first message is sent with retries (_send_1st_message); later
        messages are sent once.

        Args:
            *args: Arguments are passed to the output comm send method.
            **kwargs: Keyword arguments are passed to the output comm send
                method ('is_eof' is removed first).

        Returns:
            bool: Success or failure of send.

        """
        self.debug('')
        kwargs.pop('is_eof', False)
        with self.lock:
            self._used = True
        if self._first_send_done:
            flag = self._send_message(*args, **kwargs)
        else:
            flag = self._send_1st_message(*args, **kwargs)
        return flag

    def set_close_state(self, state):
        r"""Set the close state if its not already set.

        Args:
            state (str): Close state to record.

        Returns:
            bool: True if the state was set by this call.

        """
        out = False
        with self.lock:
            if not self.close_state:
                self.debug("Setting close state to %s", state)
                self.close_state = state
                out = True
        return out

    def wait_close_state(self, state, timeout=None):
        r"""Set the close state after waiting for specified time for the
        close state to be set by another method.

        Args:
            state (str): Close state that should be set after timeout.
            timeout (float, optional): Time that should be waited before
                setting the timeout. Defaults to self.timeout.

        """
        T = self.start_timeout(timeout)
        while (not T.is_out):  # pragma: debug
            with self.lock:
                if self.close_state:
                    break
            self.sleep(2 * self.sleeptime)
        self.stop_timeout()
        self.set_close_state(state)

    def run_loop(self):
        r"""Run the driver. Continue looping over messages until there are not
        any left or the communication channel is closed.
        """
        self.state = 'in loop'
        if not self.is_valid:
            self.debug("Breaking loop")
            self.set_close_state('invalid')
            self.set_break_flag()
            return
        # Receive a message
        self.state = 'receiving'
        msg = self.recv_message()
        if msg is False:
            self.debug('No more messages')
            self.set_break_flag()
            self.set_close_state('receiving')
            return
        if self.icomm.is_empty_recv(msg):
            self.state = 'waiting'
            self.verbose_debug(':run: Waiting for next message.')
            self.sleep()
            return
        self.nrecv += 1
        self.state = 'received'
        if isinstance(msg, backwards.bytes_type):
            self.debug('Received message that is %d bytes from %s.',
                       len(msg), self.icomm.address)
        elif isinstance(msg, np.ndarray):
            self.debug('Received array with shape %s and data type %s from %s',
                       msg.shape, msg.dtype, self.icomm.address)
        else:
            self.debug('Received message of type %s from %s',
                       type(msg), self.icomm.address)
        # Process message
        self.state = 'processing'
        msg = self.on_message(msg)
        if msg is False:  # pragma: debug
            self.error('Could not process message.')
            self.set_break_flag()
            self.set_close_state('processing')
            return
        elif self.ocomm.is_empty_send(msg):
            self.debug('Message skipped.')
            self.nskip += 1
            return
        self.nproc += 1
        self.state = 'processed'
        self.debug('Processed message.')
        # Send a message
        self.state = 'sending'
        ret = self.send_message(msg)
        if ret is False:
            self.error('Could not send message.')
            self.set_break_flag()
            self.set_close_state('sending')
            return
        self.nsent += 1
        self.state = 'sent'
        self.debug('Sent message to %s.', self.ocomm.address)