"""Module for funneling messages from one comm to another."""
import os
import copy
import numpy as np
import functools
import queue
from yggdrasil import multitasking
from yggdrasil.communication import new_comm, CommBase
from yggdrasil.drivers.Driver import Driver
from yggdrasil.components import create_component, isinstance_component
from yggdrasil.drivers.DuplicatedModelDriver import DuplicatedModelDriver
class TaskThreadError(RuntimeError):
    r"""Error raised when a task cannot be completed by the task thread."""
def run_remotely(method):
    r"""Decorator for methods that should be run remotely.

    When the instance reports that it can run remotely (via its
    can_run_remotely attribute), the call is forwarded by name to the
    instance's task thread. If the task thread cannot complete the call,
    or the instance cannot run remotely, the method is called directly.

    Args:
        method (callable): Method to wrap. Its first argument must be the
            instance.

    Returns:
        callable: Wrapped method.

    """
    @functools.wraps(method)
    def modified_method(self, *args, **kwargs):
        if not self.can_run_remotely:
            return method(self, *args, **kwargs)
        task_name = method.__name__
        try:
            result = self.task_thread.run_task_remote(
                task_name, args, kwargs)
        except (TaskThreadError,
                multitasking.AliasDisconnectError):  # pragma: debug
            # Remote execution is best-effort; fall back to a local call
            return method(self, *args, **kwargs)
        return result
    return modified_method
class RemoteTaskLoop(multitasking.YggTaskLoop):
    r"""Class to handle running tasks on the connection loop process.

    Requests are exchanged as (task, args, kwargs) tuples on q_tasks and
    the corresponding results are returned on q_results.

    Args:
        connection (ConnectionDriver): Connection that tasks will be run on.
        **kwargs: Additional keyword arguments are passed to the parent
            class.

    Attributes:
        connection (ConnectionDriver): Connection that tasks are run on.
        q_tasks (multitasking.Queue): Queue of pending task requests.
        q_results (multitasking.Queue): Queue of task results.

    """

    _disconnect_attr = (multitasking.YggTaskLoop._disconnect_attr
                        + ['q_tasks', 'q_results'])
    # Sentinel put on q_results for tasks that were dropped at shutdown
    _terminated_result = 'TERMINATED'

    def __init__(self, connection, **kwargs):
        self.connection = connection
        context = connection.process_instance.context
        self.q_tasks = multitasking.Queue(
            task_method='process', task_context=context)
        self.q_results = multitasking.Queue(
            task_method='process', task_context=context)
        super(RemoteTaskLoop, self).__init__(target=self.target, **kwargs)
        # Overwrite break flag with process safe Event
        self.break_flag.disconnect()
        self.break_flag = multitasking.Event(
            task_method='process', task_context=context)

    def __getstate__(self):
        out = super(RemoteTaskLoop, self).__getstate__()
        # The connection is re-attached by the connection's __setstate__
        out.pop('connection', None)
        return out

    def is_open(self):
        r"""Check if the loop can still accept tasks.

        Returns:
            bool: True if the loop has not been broken, False otherwise.

        """
        return (not self.was_break)

    def close(self):
        r"""Terminate the loop and close the queues."""
        self.terminate()
        self.q_tasks.disconnect()
        self.q_results.disconnect()

    def run_task_local(self, task, args, kwargs):
        r"""Run task on the current process.

        Args:
            task (str): Name of the connection attribute to evaluate.
            args (tuple): Positional arguments for the task.
            kwargs (dict): Keyword arguments for the task.

        Returns:
            object: Result of calling the attribute if it is callable,
                otherwise the attribute value itself (e.g. a property).

        """
        f_task = getattr(self.connection, task)
        if callable(f_task):
            return f_task(*args, **kwargs)
        return f_task

    def run_task_remote(self, task, args, kwargs, timeout=180.0):
        r"""Run task on the connection loop process.

        Args:
            task (str): Name of the connection attribute to evaluate.
            args (tuple): Positional arguments for the task.
            kwargs (dict): Keyword arguments for the task.
            timeout (float, optional): Time in seconds that should be
                waited for the result. Defaults to 180.0.

        Returns:
            object: Result returned by the task.

        Raises:
            TaskThreadError: If the loop was stopped before or while the
                task was pending, or no result arrived within timeout.

        """
        assert (self.connection.as_process
                and (not self.connection.in_process)
                and self.connection.is_alive())
        if self.break_flag.is_set():  # pragma: debug
            raise TaskThreadError("Task thread was stopped.")
        self.q_tasks.put_nowait((task, args, kwargs))
        try:
            out = self.q_results.get(timeout=timeout)
        except queue.Empty as e:  # pragma: debug
            raise TaskThreadError("Task thread was stopped.") from e
        # Only compare strings against the sentinel so that results with
        # element-wise equality (e.g. arrays) cannot raise here
        if (isinstance(out, str)
                and (out == self._terminated_result)):  # pragma: debug
            raise TaskThreadError("Task thread was stopped.")
        return out

    def after_loop(self):
        r"""Actions performed after the loop."""
        super(RemoteTaskLoop, self).after_loop()
        try:
            # Discard pending tasks, notifying each waiting caller
            while not self.q_tasks.empty():  # pragma: debug
                self.q_tasks.get_nowait()
                self.q_results.put(self._terminated_result)
        except multitasking.AliasDisconnectError:  # pragma: debug
            pass

    def target(self):
        r"""Complete all pending tasks."""
        try:
            while not self.q_tasks.empty():
                self.debug("Task waiting")
                args = self.q_tasks.get_nowait()
                self.debug("Task received: %s", args)
                out = self.run_task_local(*args)
                self.debug("Task complete: %s = %s", args, out)
                self.q_results.put(out)
                self.debug("Task returned: %s", args)
                if self.was_break:
                    return
            self.sleep()
        except multitasking.AliasDisconnectError:  # pragma: debug
            self.set_break_flag()
class ConnectionDriver(Driver):
    r"""Class that continuously passes messages from one comm to another.

    Args:
        name (str): Name that should be used to set names of input/output comms.
        inputs (list, optional): One or more dictionaries containing keyword
            arguments for constructing input communicators. Defaults to an
            empty dictionary if not provided.
        outputs (list, optional): One or more dictionaries containing keyword
            arguments for constructing output communicators. Defaults to an
            empty dictionary if not provided.
        input_pattern (str, optional): The communication pattern that should
            be used to handle incoming messages when there is more than one
            input communicators present. Defaults to 'cycle'. Options
            include:
                'cycle': Receive from the next available input communicator.
                'gather': Receive lists of messages with one element from each
                    communicator where a message is only returned when there is
                    a message from each.
        output_pattern (str, optional): The communication pattern that should
            be used to handling outgoing messages when there is more than one
            output communicator present. Defaults to 'broadcast'. Options
            include:
                'cycle': Rotate through output comms, sending one message to
                    each.
                'broadcast': Send the same message to each comm.
                'scatter': Send part of message (must be a list) to each comm.
        transform (str, func, optional): Function or string specifying function
            that should be used to translate messages from the input communicator
            before passing them to the output communicator. If a string, the
            format should be "<package.module>:<function>" so that <function>
            can be imported from <package>. Defaults to None and messages are
            passed directly. This can also be a list of functions/strings that
            will be called on the messages in the order they are provided.
        timeout_send_1st (float, optional): Time in seconds that should be
            waited before giving up on the first send. Defaults to self.timeout.
        single_use (bool, optional): If True, the driver will be stopped after
            one loop. Defaults to False.
        onexit (str, optional): Class method that should be called when a
            model that the connection interacts with exits, but before the
            connection driver is shut down. Defaults to None.
        models (dict, optional): Mapping with 'input' and 'output' entries
            listing the names of the models on each side of the connection.
            Defaults to None and is set from the model_env keys of the
            input and output comms.
        **kwargs: Additional keyword arguments are passed to the parent class.

    Attributes:
        icomm_kws (dict): Keyword arguments for the input communicator.
        ocomm_kws (dict): Keyword arguments for the output communicator.
        icomm (CommBase): Input communicator.
        ocomm (CommBase): Output communicator.
        nrecv (int): Number of messages received.
        nproc (int): Number of messages processed.
        nsent (int): Number of messages sent.
        state (str): Descriptor of last action taken.
        transform (func): Function that will be used to translate messages from
            the input communicator before passing them to the output communicator.
        timeout_send_1st (float): Time in seconds that should be waited before
            giving up on the first send.
        single_use (bool): If True, the driver will be stopped after one
            loop.
        onexit (str): Class method that should be called when the corresponding
            model exits, but before the driver is shut down.

    """

    _icomm_type = 'default'
    _ocomm_type = 'default'
    _direction = 'any'
    _schema_type = 'connection'
    _schema_subtype_key = 'connection_type'
    _schema_subtype_description = ('Connection between one or more comms/files '
                                   'and one or more comms/files.')
    _schema_subtype_default = 'default'
    _connection_type = 'connection'
    _schema_required = ['inputs', 'outputs']
    _schema_properties = {
        'connection_type': {'type': 'string'},
        'inputs': {'type': 'array', 'minItems': 1,
                   'items': {'anyOf': [{'$ref': '#/definitions/comm'},
                                       {'$ref': '#/definitions/file'}]},
                   'allowSingular': True,
                   'aliases': ['input', 'from', 'input_file', 'input_files'],
                   'description': (
                       'One or more name(s) of model output channel(s) '
                       'and/or new channel/file objects that the '
                       'connection should receive messages from. '
                       'A full description of file entries and the '
                       'available options can be found :ref:`here<'
                       'yaml_file_options>`.')},
        'outputs': {'type': 'array', 'minItems': 1,
                    'items': {'anyOf': [{'$ref': '#/definitions/comm'},
                                        {'$ref': '#/definitions/file'}]},
                    'allowSingular': True,
                    'aliases': ['output', 'to', 'output_file', 'output_files'],
                    'description': (
                        'One or more name(s) of model input channel(s) '
                        'and/or new channel/file objects that the '
                        'connection should send messages to. '
                        'A full description of file entries and the '
                        'available options can be found :ref:`here<'
                        'yaml_file_options>`.')},
        'input_pattern': {'type': 'string',
                          'enum': ['cycle', 'gather'],
                          'default': 'cycle'},
        'output_pattern': {'type': 'string',
                           'enum': ['cycle', 'broadcast', 'scatter'],
                           'default': 'broadcast'},
        'transform': {'type': 'array',
                      'items': {'anyOf': [
                          {'$ref': '#/definitions/transform'},
                          {'type': ['function', 'string']}]},
                      'allowSingular': True,
                      'aliases': ['transforms', 'translator',
                                  'translators']},
        'read_meth': {'type': 'string', 'deprecated': True,
                      'enum': ['all', 'line', 'table_array', 'ascii',
                               'binary', 'json', 'map', 'mat', 'netcdf',
                               'obj', 'pandas', 'pickle', 'ply', 'table',
                               'wofost', 'yaml']},
        'write_meth': {'type': 'string', 'deprecated': True,
                       'enum': ['all', 'line', 'table_array', 'ascii',
                                'binary', 'json', 'map', 'mat', 'netcdf',
                                'obj', 'pandas', 'pickle', 'ply', 'table',
                                'wofost', 'yaml']},
        'onexit': {'type': 'string'},
        'working_dir': {'type': 'string'}}
    _schema_excluded_from_class_validation = ['inputs', 'outputs']
    _schema_additional_kwargs_base = {
        'pushProperties': {
            '!$properties/inputs/items': ['transform', 'onexit',
                                          'read_meth', 'write_meth'],
            '!$properties/outputs/items/anyOf/1': ['transform', 'onexit',
                                                   'read_meth', 'write_meth'],
            ('$properties/inputs/items/anyOf/1/allOf/1/anyOf/0/'
             'properties/serializer'): True,
            ('$properties/outputs/items/anyOf/1/allOf/1/anyOf/0/'
             'properties/serializer'): True}}
    _disconnect_attr = Driver._disconnect_attr + [
        '_comm_closed', '_skip_after_loop', 'shared', 'task_thread',
        'icomm', 'ocomm']

    def __init__(self, name, single_use=False, onexit=None,
                 models=None, **kwargs):
        super(ConnectionDriver, self).__init__(name, **kwargs)
        # Shared attributes (set once or synced using events)
        self.single_use = single_use
        self.shared = self.context.Dict()
        self.shared.update(nrecv=0, nproc=0, nsent=0,
                           state='started', close_state='',
                           _comm_closed=multitasking.DummyEvent(),
                           _skip_after_loop=multitasking.DummyEvent())
        # Attributes used by process
        self._eof_sent = False
        self._first_send_done = False
        self._used = False
        self.onexit = None
        self.task_thread = None
        if self.as_process:
            self.task_thread = RemoteTaskLoop(
                self, name=('%s.TaskThread' % self.name))
        # Translator
        if self.transform is None:
            self.transform = []
        elif not isinstance(self.transform, list):
            self.transform = [self.transform]
        for i, t in enumerate(self.transform):
            if isinstance(t, dict):
                self.transform[i] = create_component('transform', **t)
            if not callable(self.transform[i]):
                raise ValueError(f"Transform {self.transform[i]} not callable.")
        if (onexit is not None) and (not hasattr(self, onexit)):
            raise ValueError("onexit '%s' is not a class method." % onexit)
        self.onexit = onexit
        # Add comms and print debug info
        self._init_comms(name, **kwargs)
        self.models = models
        if self.models is None:
            self.models = {'input': list(self.icomm.model_env.keys()),
                           'output': list(self.ocomm.model_env.keys())}
        self.models_recvd = {}
        self.debug(('\n' + 80 * '=' + '\n'
                    + 'class = %s\n'
                    + ' input: name = %s, address = %s, models=%s\n'
                    + ' output: name = %s, address = %s, models=%s\n'
                    + (80 * '=')), self.__class__,
                   self.icomm.name, self.icomm.address, self.models['input'],
                   self.ocomm.name, self.ocomm.address, self.models['output'])

    def _init_single_comm(self, io, comm_list):
        r"""Parse keyword arguments for input/output comm.

        The created comm is stored as icomm/ocomm and the keyword
        arguments used to create it as icomm_kws/ocomm_kws.

        Args:
            io (str): 'input' for the input comm; any other value is
                treated as the output comm.
            comm_list (list): Dictionaries of keyword arguments (or None
                for defaults) for each comm. Updated in place.

        """
        self.debug("Creating %s comm", io)
        comm_kws = dict()
        assert isinstance(comm_list, list)
        assert comm_list
        if io == 'input':
            direction = 'recv'
            attr_comm = 'icomm'
            comm_kws['close_on_eof_recv'] = False
            comm_type = self._icomm_type
        else:
            direction = 'send'
            attr_comm = 'ocomm'
            comm_type = self._ocomm_type
        comm_kws['direction'] = direction
        comm_kws['dont_open'] = True
        comm_kws['reverse_names'] = True
        comm_kws['use_async'] = True
        comm_kws['name'] = self.name
        if comm_list:
            comm_kws['pattern'] = getattr(self, f'{io}_pattern')
        for i, x in enumerate(comm_list):
            if x is None:
                comm_list[i] = dict()
            else:
                assert isinstance(x, dict)
            if 'filetype' not in comm_list[i]:
                comm_list[i].setdefault('commtype', comm_type)
            if self.as_process:
                comm_list[i]['buffer_task_method'] = 'process'
            if (((comm_list[i].get('partner_copies', 0) > 1)
                 and (not comm_list[i].get('is_client', False))
                 and (direction == 'send')
                 and (not comm_list[i].get('dont_copy', False)))):
                from yggdrasil.communication import ForkComm
                # TODO: Handle recv?
                # Create one child comm for each copy of the partner model
                comm_list[i]['commtype'] = [
                    dict(comm_list[i],
                         partner_model=DuplicatedModelDriver.name_format % (
                             comm_list[i]['partner_model'], idx))
                    for idx in range(comm_list[i]['partner_copies'])]
                for k in ForkComm.ForkComm.child_keys:
                    comm_list[i].pop(k, None)
        comm_kws['commtype'] = copy.deepcopy(comm_list)
        for x in comm_kws['commtype']:
            if isinstance(x.get('datatype', {}), dict):
                if ((x.get('datatype', {}).get('from_function', False)
                     and (x.get('datatype', {}).get('type', None)
                          in ['any', 'instance']))):
                    x['datatype'] = {'type': 'scalar', 'subtype': 'string'}
                x.get('datatype', {}).pop('from_function', False)
        self.debug('%s comm_kws:\n%s', attr_comm, self.pprint(comm_kws, 1))
        setattr(self, attr_comm, new_comm(**comm_kws))
        setattr(self, '%s_kws' % attr_comm, comm_kws)

    def _init_comms(self, name, **kwargs):
        r"""Parse keyword arguments for input/output comms.

        Args:
            name (str): Name of the connection (currently unused here).
            **kwargs: Keyword arguments from which timeout_send_1st is
                taken.

        """
        self._init_single_comm('input', self.inputs)
        try:
            self._init_single_comm('output', self.outputs)
        except BaseException:
            # Don't leave the input comm dangling if the output fails
            self.icomm.close()
            self.icomm.disconnect()
            raise
        # Apply keywords dependent on comms
        if self.icomm.any_files:
            kwargs.setdefault('timeout_send_1st', 60)
        self.timeout_send_1st = kwargs.pop('timeout_send_1st', self.timeout)
        self.debug('Final env:\n%s', self.pprint(self.env, 1))

    def __setstate__(self, state):
        super(ConnectionDriver, self).__setstate__(state)
        if self.as_process:
            # The task thread drops its connection reference when pickled
            self.task_thread.connection = self

    @property
    def model_env(self):
        r"""dict: Mapping between model name and opposite comm
        environment variables that need to be provided to the model."""
        out = {}
        for x in [self.icomm, self.ocomm]:
            if x._commtype == 'mpi':
                continue
            for k, v in x.model_env.items():
                # Merge into a fresh dict so that the dictionaries owned
                # by the comms are never modified
                out.setdefault(k, {}).update(v)
        return out

    def get_flag_attr(self, attr):
        r"""Return the flag attribute.

        Args:
            attr (str): Name of the flag.

        Returns:
            object: Flag stored in the shared dictionary if present,
                otherwise the result from the parent class.

        """
        if hasattr(self, 'shared') and (attr in self.shared):
            return self.shared[attr]
        return super(ConnectionDriver, self).get_flag_attr(attr)

    def set_flag_attr(self, attr, value=True):
        r"""Set a flag.

        Args:
            attr (str): Name of the flag.
            value (bool, optional): If True, the flag is set. If False,
                the flag is cleared. Defaults to True.

        """
        if hasattr(self, 'shared') and (attr in self.shared):
            exist = self.shared[attr]
            if value:
                exist.set()
            else:
                exist.clear()
            # Store the updated flag back in the shared dictionary
            self.shared[attr] = exist
            return
        super(ConnectionDriver, self).set_flag_attr(attr, value=value)

    @property
    def nrecv(self):
        r"""int: Number of messages received."""
        return self.shared['nrecv']

    @nrecv.setter
    def nrecv(self, x):
        self.shared['nrecv'] = x

    @property
    def nsent(self):
        r"""int: Number of messages sent."""
        return self.shared['nsent']

    @nsent.setter
    def nsent(self, x):
        self.shared['nsent'] = x

    @property
    def nproc(self):
        r"""int: Number of messages processed."""
        return self.shared['nproc']

    @nproc.setter
    def nproc(self, x):
        self.shared['nproc'] = x

    @property
    def state(self):
        r"""str: Current state of the connection."""
        return self.shared['state']

    @state.setter
    def state(self, x):
        # Ignored if set before the shared dictionary exists
        if hasattr(self, 'shared'):
            self.shared['state'] = x

    @property
    def close_state(self):
        r"""str: State of the connection at close."""
        return self.shared['close_state']

    @close_state.setter
    def close_state(self, x):
        self.shared['close_state'] = x

    @property
    def can_run_remotely(self):
        r"""bool: True if process should be run remotely."""
        return (self.as_process and (not self.in_process)
                and self.is_alive()
                and self.task_thread.is_open())

    @run_remotely
    def wait_for_route(self, timeout=None):
        r"""Wait until messages have been routed.

        Args:
            timeout (float, optional): Max time that should be waited.
                Defaults to None.

        Returns:
            bool: True if the number of messages received matches the
                number sent, False otherwise.

        """
        T = self.start_timeout(timeout, key_suffix='.route')
        while ((not T.is_out)
               and (self.icomm.n_msg > 0)
               and (self.nrecv != self.nsent)):  # pragma: debug
            self.sleep()
        self.stop_timeout(key_suffix='.route')
        return (self.nrecv == self.nsent)

    @property
    @run_remotely
    def is_valid(self):
        r"""bool: Returns True if the connection is open and the parent class
        is valid."""
        with self.lock:
            return (super(ConnectionDriver, self).is_valid
                    and self.is_comm_open
                    and not (self.single_use and self._used))

    @property
    @run_remotely
    def is_comm_open(self):
        r"""bool: Returns True if both communicators are open."""
        with self.lock:
            return (self.icomm.is_open and self.ocomm.is_open
                    and not self.check_flag_attr('_comm_closed'))

    @property
    @run_remotely
    def is_comm_closed(self):
        r"""bool: Returns True if both communicators are closed."""
        with self.lock:
            return self.icomm.is_closed and self.ocomm.is_closed

    @property
    @run_remotely
    def n_msg(self):
        r"""int: Number of messages waiting in input communicator."""
        with self.lock:
            return self.icomm.n_msg_recv

    @run_remotely
    def open_comm(self):
        r"""Open the communicators.

        Nothing is done if the comms were already marked as closed. If
        either comm fails to open, both are closed and the error is
        re-raised.

        """
        self.debug('')
        with self.lock:
            if self.check_flag_attr('_comm_closed'):
                self.debug('Aborted as comm closed')
                return
            try:
                self.icomm.open()
                self.ocomm.open()
            except BaseException:
                self.close_comm()
                raise
        self.debug('Returning')

    @run_remotely
    def close_comm(self):
        r"""Close the communicators.

        Both comms are closed even if closing the first raises an error.
        The error from the input comm takes precedence when both fail.

        """
        self.debug('')
        with self.lock:
            self.set_flag_attr('_comm_closed')
            self.set_flag_attr('_skip_after_loop')
            # Capture errors for both comms
            ie = None
            oe = None
            try:
                if getattr(self, 'icomm', None) is not None:
                    self.icomm.close()
                    self.icomm.disconnect()
            except BaseException as e:
                ie = e
            try:
                if getattr(self, 'ocomm', None) is not None:
                    self.ocomm.close()
                    self.ocomm.disconnect()
            except BaseException as e:
                oe = e
        if ie:
            raise ie
        if oe:
            raise oe
        self.debug('Returning')

    def start(self):
        r"""Open connection before running.

        Raises:
            RuntimeError: If the driver is not run as a process and the
                comms do not finish opening before the timeout.

        """
        if not self.as_process:
            self.open_comm()
            Tout = self.start_timeout()
            while (not self.is_comm_open) and (not Tout.is_out):
                self.sleep()
            self.stop_timeout()
            if not self.is_comm_open:
                raise RuntimeError("Connection never finished opening.")
        super(ConnectionDriver, self).start()
        self.debug('Started connection process')
        if self.as_process:
            self.wait_flag_attr('loop_flag', timeout=120.0)
            self.icomm.disconnect()
            self.ocomm.disconnect()

    def graceful_stop(self, timeout=None, **kwargs):
        r"""Stop the driver, first waiting for the input comm to be empty.

        Args:
            timeout (float, optional): Max time that should be waited. Defaults
                to None and is set to attribute timeout.
            **kwargs: Additional keyword arguments are accepted for
                compatibility, but are not currently passed to the parent
                class's graceful_stop method.

        """
        self.debug('')
        with self.lock:
            self.set_close_state('stop')
            self.set_flag_attr('_skip_after_loop')
        self.drain_input(timeout=timeout)
        self.wait_for_route(timeout=timeout)
        self.drain_output(timeout=timeout)
        super(ConnectionDriver, self).graceful_stop()
        self.debug('Returning')

    @run_remotely
    def remove_model(self, direction, name):
        r"""Remove a model from the list of models.

        Args:
            direction (str): Direction of model.
            name (str): Name of model exiting.

        Returns:
            bool: True if all of the input/output models have signed
                off; False otherwise.

        """
        self.debug('')
        with self.lock:
            if name in self.models[direction]:
                self.models[direction].remove(name)
            self.debug(("%s model '%s' signed off."
                        "\n\tInput models: %d"
                        "\n\tOutput models: %d")
                       % (direction.title(), name,
                          len(self.models["input"]),
                          len(self.models["output"])))
            return (len(self.models[direction]) == 0)

    @run_remotely
    def on_model_exit_remote(self, direction, name,
                             errors=False):
        r"""Drain input and then close it (on the remote process).

        Args:
            direction (str): Direction of model.
            name (str): Name of model exiting.
            errors (list, optional): Errors generated by the
                model. Defaults to False.

        Returns:
            bool: True if all of the input/output models have signed
                off; False otherwise.

        """
        if not self.remove_model(direction, name):
            self.debug("%s models remain: %s",
                       direction, self.models[direction])
            return False
        if not self.is_alive():
            return False
        self.debug("All %s models have signed off.", direction)
        if (((self.onexit not in [None, 'on_model_exit', 'pass'])
             and (not errors))):
            self.debug("Calling onexit = '%s'" % self.onexit)
            getattr(self, self.onexit)()
        if not errors:
            if direction == 'output':
                # Give the input models a chance to sign off first
                T = self.start_timeout(60, key_suffix='.model_exit')
                while (not T.is_out) and self.models['input']:
                    self.debug("remaining input models: %s",
                               self.models['input'])
                    self.sleep(10 * self.sleeptime)
                self.stop_timeout(key_suffix='.model_exit')
            self.drain_input(timeout=self.timeout)
        if direction == 'input':
            if not errors:
                self.wait_for_route(timeout=self.timeout)
            with self.lock:
                self.icomm.close()
        elif direction == 'output':
            with self.lock:
                self.ocomm.close()
        self.set_close_state('%s model exit' % direction)
        self.debug('Exit of %s model triggered close', direction)
        self.set_break_flag()
        return True

    def on_model_exit(self, direction, name, errors=False):
        r"""Drain input and then close it.

        Args:
            direction (str): Direction of model.
            name (str): Name of model exiting.
            errors (list, optional): Errors generated by the
                model. Defaults to False.

        """
        self.debug('%s model %s exiting', direction.title(), name)
        if self.on_model_exit_remote(direction, name,
                                     errors=errors):
            self.wait()
        self.debug('Finished')

    def do_terminate(self):
        r"""Stop the driver by closing the communicators."""
        self.debug('')
        self.set_close_state('terminate')
        self.close_comm()
        if self.as_process:
            self.task_thread.terminate()
        super(ConnectionDriver, self).do_terminate()

    def cleanup(self):
        r"""Ensure that the communicators are closed."""
        self.close_comm()
        if self.as_process:
            self.task_thread.close()
        super(ConnectionDriver, self).cleanup()

    @run_remotely
    def printStatus(self, beg_msg='', end_msg='',
                    verbose=False, return_str=False):
        r"""Print information on the status of the ConnectionDriver.

        Arguments:
            beg_msg (str, optional): Additional message to print at beginning.
            end_msg (str, optional): Additional message to print at end.
            verbose (bool, optional): If True, the status of
                individual comms will be displayed. Defaults to
                False.
            return_str (bool, optional): If True, the message string is
                returned. Defaults to False.

        Returns:
            str: Status message. The message is also printed unless
                return_str is True.

        """
        msg = beg_msg
        msg += '%-50s' % (self.__module__.split('.')[-1] + '(' + self.name + '): ')
        msg += '\n\t'
        msg += '%-30s' % ('last action: ' + self.state)
        msg += '%-25s' % ('is_open(%s, %s), ' % (self.icomm.is_open,
                                                  self.ocomm.is_open))
        msg += '%-15s' % (str(self.nrecv) + ' received, ')
        msg += '%-15s' % (str(self.nproc) + ' processed, ')
        msg += '%-15s' % (str(self.nsent) + ' sent, ')
        msg += '%-20s' % (str(self.icomm.n_msg) + ' ready to recv')
        msg += '%-20s' % (str(self.ocomm.n_msg) + ' ready to send')
        with self.lock:
            if self.close_state:
                msg += '%-30s' % ('close state: ' + self.close_state)
        msg += end_msg
        if not return_str:
            print(msg)
        if verbose:
            i_msg = self.icomm.printStatus(return_str=return_str)
            o_msg = self.ocomm.printStatus(return_str=return_str)
            if return_str:
                msg += '\n%s\n%s' % (i_msg, o_msg)
        return msg

    @run_remotely
    def confirm_output(self, timeout=None):
        r"""Confirm receipt of messages from output comm.

        Args:
            timeout (float, optional): Max time that should be waited.
                Defaults to None.

        """
        T = self.start_timeout(timeout, key_suffix='.confirm_output')
        while not T.is_out:  # pragma: debug
            with self.lock:
                if (not self.ocomm.is_open):
                    break
                elif self.ocomm.is_confirmed_send:
                    break
            self.sleep(10 * self.sleeptime)
        self.stop_timeout(key_suffix='.confirm_output')

    @run_remotely
    def drain_output(self, timeout=None, dont_confirm_eof=False):
        r"""Drain messages from output comm.

        Args:
            timeout (float, optional): Max time that should be waited.
                Defaults to None.
            dont_confirm_eof (bool, optional): If True, one message is
                allowed to remain in the output comm. Defaults to False.

        """
        nwait = 0
        if dont_confirm_eof:
            nwait += 1
        T = self.start_timeout(timeout, key_suffix='.drain_output')
        while not T.is_out:
            with self.lock:
                if (not (self.ocomm.is_open
                         or self.was_terminated)):  # pragma: no cover
                    break
                elif ((self.ocomm.n_msg_send_drain <= nwait)
                      and self.ocomm.is_confirmed_send):
                    break
            self.sleep()  # pragma: no cover
        self.stop_timeout(key_suffix='.drain_output')

    def before_loop(self):
        r"""Actions to perform prior to sending messages.

        If the comms cannot be opened or the connection is not valid
        afterwards, the error is logged, the comms are closed and the
        loop is broken instead of raising.

        """
        self.state = 'before loop'
        try:
            if self.as_process:
                self.task_thread.start()
            self.open_comm()
            self.sleep()  # Help ensure senders/receivers connected before messages
            self.debug('Running in %s, is_valid = %s', os.getcwd(), str(self.is_valid))
            # Raise explicitly so the check is not removed when Python is
            # run with optimizations (-O strips assert statements)
            if not self.is_valid:
                raise AssertionError("Connection is not valid.")
        except BaseException:  # pragma: debug
            self.printStatus()
            self.exception('Could not prep for loop (is_open = (%s, %s)).' % (
                self.icomm.is_open, self.ocomm.is_open))
            self.close_comm()
            self.set_break_flag()
            if self.as_process:
                self.task_thread.terminate()

    def after_loop_process(self):
        r"""Actions to perform after loop for process."""
        self.debug("After loop process")
        self.task_thread.set_break_flag()
        self.task_thread.wait()

    def after_loop(self):
        r"""Actions to perform after sending messages."""
        self.state = 'after loop'
        self.debug('')
        # Close input comm in case loop did not
        self.confirm_input(timeout=False)
        self.debug('Confirmed input')
        if self.check_flag_attr('_skip_after_loop') and self.as_process:
            self.after_loop_process()
        with self.lock:
            self.debug('Acquired lock')
            if self.check_flag_attr('_skip_after_loop'):
                self.debug("After loop skipped.")
                return
            self.icomm.close()
            # Send EOF in case the model didn't
            if not self.single_use:
                self.send_eof()
        # Do not close output comm in case model/connection still receiving
        if self.as_process and self.ocomm.touches_model:
            self.drain_output(timeout=False, dont_confirm_eof=True)
        self.debug('Finished')
        if self.as_process:
            self.after_loop_process()

    def recv_message(self, **kwargs):
        r"""Get a new message to send.

        Args:
            **kwargs: Additional keyword arguments are passed to the appropriate
                recv method.

        Returns:
            CommMessage, bool: False if no more messages, message otherwise.

        """
        assert self.in_process
        kwargs.setdefault('timeout', 0)
        with self.lock:
            if self.icomm.is_closed:
                return False
            msg = self.icomm.recv(return_message_object=True, **kwargs)
        self.errors += self.icomm.errors
        meta = msg.header.get('__meta__', {}) if msg.header else {}
        if 'model' in meta:
            model = meta['model']
            self.models_recvd.setdefault(model, 0)
            self.models_recvd[model] += 1
            # Register models the first time a message from them arrives
            if ((self.models_recvd[model] == 1
                 and model not in self.models['input'])):
                self.models['input'].append(model)
        if msg.flag == CommBase.FLAG_EOF:
            return self.on_eof(msg)
        if msg.flag == CommBase.FLAG_SUCCESS:
            return msg
        return bool(msg.flag)

    def on_eof(self, msg):
        r"""Actions to take when EOF received.

        Args:
            msg (CommMessage): Message object that provided the EOF.

        Returns:
            CommMessage, bool: Value that should be returned by recv_message on EOF.

        """
        with self.lock:
            self.debug('EOF received')
            self.state = 'eof'
            self.set_close_state('eof')
            self.set_break_flag()
        self.debug('After EOF')
        return False

    def on_message(self, msg):
        r"""Process a message.

        Args:
            msg (CommMessage): Message to be processed. The transforms
                are applied in order to its args attribute.

        Returns:
            CommMessage: Processed message.

        """
        if (self.ocomm._send_serializer) and self.icomm.serializer.initialized:
            self.update_serializer(msg)
        for t in self.transform:
            msg.args = t(msg.args)
        return msg

    def update_serializer(self, msg):
        r"""Update the serializer for the output comm based on input.

        Args:
            msg (CommMessage): Message whose type information should be
                used to update the output comm's serializer.

        """
        self.debug('Before update:\n icomm:%s\n ocomm:%s\n'
                   % ("\n".join(self.icomm.get_status_message(nindent=1)[0][1:]),
                      "\n".join(self.ocomm.get_status_message(nindent=1)[0][1:])))
        for t in self.transform:
            if isinstance_component(t, 'transform'):
                t.set_original_datatype(msg.stype)
                msg.stype = t.transformed_datatype
        if self.transform:
            msg.sinfo = {}
        # This can be removed if send_message is set up to update and send the
        # received message rather than create a new one by sending msg.args
        self.ocomm.update_serializer_from_message(msg)
        self.debug('After update:\n icomm:\n%s\n ocomm:\n%s\n'
                   % ("\n".join(self.icomm.get_status_message(nindent=1)[0][1:]),
                      "\n".join(self.ocomm.get_status_message(nindent=1)[0][1:])))

    def _send_message(self, *args, **kwargs):
        r"""Send a single message.

        Args:
            *args: Arguments are passed to the output comm send method.
            **kwargs: Keyword arguments are passed to the output comm send method.

        Returns:
            bool: Success or failure of send.

        """
        with self.lock:
            if self.ocomm.is_closed:
                return False
            return self.ocomm.send_message(*args, **kwargs)

    def _send_1st_message(self, *args, **kwargs):
        r"""Send the first message, trying multiple times.

        Args:
            *args: Arguments are passed to the output comm send method.
            **kwargs: Keyword arguments are passed to the output comm send method.

        Returns:
            bool: Success or failure of send.

        """
        self.ocomm._multiple_first_send = False
        T = self.start_timeout(self.timeout_send_1st,
                               key_suffix='.1st_send')
        flag = self._send_message(*args, **kwargs)
        self.ocomm.suppress_special_debug = True
        # Don't retry if the failure was caused by a type error
        if (not flag) and (not self.ocomm._type_errors):
            self.debug("1st send failed, will keep trying for %f s in silence.",
                       float(self.timeout_send_1st))
            while ((not T.is_out) and (not flag)
                   and self.ocomm.is_open):  # pragma: debug
                flag = self._send_message(*args, **kwargs)
                if not flag:
                    self.sleep()
        self.stop_timeout(key_suffix='.1st_send')
        self.ocomm.suppress_special_debug = False
        self._first_send_done = True
        if not flag:
            self.error("1st send failed.")
        else:
            self.debug("1st send succeded")
        return flag

    def send_eof(self, **kwargs):
        r"""Send EOF message.

        Args:
            **kwargs: Keyword arguments are passed to send_message.

        Returns:
            bool: Success or failure of send. False is also returned if
                the EOF was already sent.

        """
        with self.lock:
            if self._eof_sent:  # pragma: debug
                self.debug('Already sent EOF')
                return False
            self._eof_sent = True
        self.debug('Sent EOF')
        msg = CommBase.CommMessage(flag=CommBase.FLAG_EOF,
                                   args=self.ocomm.eof_msg)
        return self.send_message(msg, **kwargs)

    def send_message(self, msg, **kwargs):
        r"""Send a single message.

        Args:
            msg (CommMessage): Message being sent.
            **kwargs: Keyword arguments are passed to the output comm send method.

        Returns:
            bool: Success or failure of send.

        """
        assert self.in_process
        self.debug('')
        with self.lock:
            self._used = True
        # Forward the name of the originating model in the header
        if (msg.header is not None) and ('model' in msg.header.get('__meta__', {})):
            kwargs.setdefault('header_kwargs', {})
            kwargs['header_kwargs'].setdefault('__meta__', {})
            kwargs['header_kwargs']['__meta__'].setdefault(
                'model', msg.header['__meta__']['model'])
        kws_prepare = {k: kwargs.pop(k) for k in self.ocomm._prepare_message_kws
                       if k in kwargs}
        msg_out = self.ocomm.prepare_message(msg.args, **kws_prepare)
        if self._first_send_done:
            flag = self._send_message(msg_out, **kwargs)
        else:
            flag = self._send_1st_message(msg_out, **kwargs)
        self.errors += self.ocomm.errors
        return flag

    def set_close_state(self, state):
        r"""Set the close state if its not already set.

        Args:
            state (str): Description of what triggered the close.

        Returns:
            bool: True if the close state was set, False if a close state
                was already present.

        """
        out = False
        with self.lock:
            if not self.close_state:
                self.debug("Setting close state to %s", state)
                self.close_state = state
                out = True
        return out

    def run_loop(self):
        r"""Run the driver. Continue looping over messages until there are not
        any left or the communication channel is closed.
        """
        self.state = 'in loop'
        if (((self.single_use and self._used)
             or self.check_flag_attr('_comm_closed'))):
            self.debug("Breaking loop")
            self.set_close_state('invalid')
            self.set_break_flag()
            return
        # Receive a message
        self.state = 'receiving'
        msg = self.recv_message()
        if msg is False:
            self.debug('No more messages')
            self.set_break_flag()
            self.set_close_state('receiving')
            return
        if (msg is True) or (isinstance(msg, CommBase.CommMessage)
                             and (msg.flag != CommBase.FLAG_SUCCESS)):
            self.state = 'waiting'
            self.verbose_debug(':run: Waiting for next message.')
            self.sleep()
            return
        self.nrecv += 1
        self.state = 'received'
        if isinstance(msg.args, bytes):
            self.debug('Received message that is %d bytes from %s.',
                       len(msg.args), self.icomm.address)
        elif isinstance(msg.args, np.ndarray):
            self.debug('Received array with shape %s and data type %s from %s',
                       msg.args.shape, msg.args.dtype, self.icomm.address)
        else:
            self.debug('Received message of type %s from %s',
                       type(msg.args), self.icomm.address)
        # Process message
        self.state = 'processing'
        msg = self.on_message(msg)
        if msg is False:  # pragma: debug
            self.error('Could not process message.')
            self.set_break_flag()
            self.set_close_state('processing')
            return
        self.nproc += 1
        self.state = 'processed'
        self.debug('Processed message.')
        # Send a message
        self.state = 'sending'
        ret = self.send_message(msg)
        if ret is False:
            self.error('Could not send message.')
            self.set_break_flag()
            self.set_close_state('sending')
            return
        self.nsent += 1
        self.state = 'sent'
        self.debug('Sent message to %s.', self.ocomm.address)