import copy
from yggdrasil.communication import CommBase, get_comm, import_comm
_address_sep = ':YGG_ADD:'
_pattern_pairs = [('scatter', 'gather')]
[docs]class ForkedCommMessage(CommBase.CommMessage):
r"""Class for forked comm messages.
Args:
msg (CommBase.CommMessage): Message being distributed.
comm_list (list): List of communicators that the message is
being distributed to.
**kwargs: Additional keyword arguments are passed to the 'prepare_message'
method for each communicator.
"""
__slots__ = ['orig']
def __init__(self, msg, comm_list, pattern='broadcast', **kwargs):
super(ForkedCommMessage, self).__init__(
msg=msg.msg, length=msg.length, flag=msg.flag,
args=msg.args, header=msg.header)
for k in CommBase.CommMessage.__slots__:
setattr(self, k, getattr(msg, k))
if (pattern in ['broadcast', 'cycle']) or (msg.flag == CommBase.FLAG_EOF):
msg_list = [copy.deepcopy(msg)
for _ in range(len(comm_list))]
elif pattern == 'scatter':
msg_list = [copy.deepcopy(msg.args[i])
for i in range(len(comm_list))]
kwargs.setdefault('flag', msg.flag)
else: # pragma: debug
raise ValueError("Unsupported pattern: '%s'" % pattern)
if msg.header:
kwargs.setdefault('header_kwargs', msg.header)
args = {i: x.prepare_message(msg_list[i], **kwargs)
for i, x in enumerate(comm_list)}
self.orig = msg.args
self.args = args
[docs]def get_comm_name(name, i):
r"""Get the name of the ith comm in the series.
Args:
name (str): Name of the fork comm.
i (int): Index of comm in fork bundle.
Returns:
str: Name of ith comm in fork bundle.
"""
return '%s_%d' % (name, i)
[docs]class ForkComm(CommBase.CommBase):
r"""Class for receiving/sending messages from/to multiple comms.
Args:
name (str): The environment variable where communication address is
stored.
comm_list (list, optional): The list of options for the comms that
should be bundled. If not provided, the bundle will be empty.
pattern (str, optional): The communication pattern that should be
used to handle outgoing/incoming messages. Options include:
'cycle': Receive from or send to the next available comm in
the list (default for receiving comms).
'broadcast': [SEND ONLY] Send the same message to each comm
(default for sending comms).
'scatter': [SEND ONLY] Send part of message (must be a list)
to each comm.
'gather': [RECV ONLY] Receive lists of messages from each
comm where a message is only returned when there is a
message from each.
**kwargs: Additional keyword arguments are passed to the parent class.
Attributes:
comm_list (list): Comms included in this fork.
curr_comm_index (int): Index comm that next receive will be from.
"""
_commtype = 'fork'
_dont_register = True
child_keys = ['serializer_class', 'serializer_kwargs',
'format_str', 'field_names', 'field_units', 'as_array',
'partner_copies']
noprop_keys = ['send_converter', 'recv_converter', 'filter', 'transform']
def __init__(self, name, comm_list=None, is_async=False,
pattern=None, **kwargs):
child_kwargs = {k: kwargs.pop(k) for k in self.child_keys if k in kwargs}
noprop_kwargs = {k: kwargs.pop(k) for k in self.noprop_keys if k in kwargs}
self.comm_list_backlog = {}
self.comm_list = []
self.curr_comm_index = 0
self.eof_recv = []
self.eof_send = []
self.pattern = pattern
if kwargs.get('direction', 'send') == 'recv':
# if self.pattern is None:
# self.pattern = 'cycle'
assert self.pattern in ['cycle', 'gather']
else:
if self.pattern is None:
self.pattern = 'broadcast'
assert self.pattern in ['cycle', 'scatter', 'broadcast']
address = kwargs.pop('address', None)
if comm_list is None:
if isinstance(address, list):
ncomm = len(address)
else:
ncomm = 0
comm_list = [None for i in range(ncomm)]
assert isinstance(comm_list, list)
ncomm = len(comm_list)
for i in range(ncomm):
if comm_list[i] is None:
comm_list[i] = {}
if comm_list[i].get('name', None) is None:
comm_list[i]['name'] = get_comm_name(name, i)
for k in child_kwargs.keys():
if k in comm_list[i]: # pragma: debug
raise ValueError("The keyword '%s' was specified for both the "
"root ForkComm and a child comm, but can only "
"be present for one." % k)
if isinstance(address, list):
assert len(address) == ncomm
for i in range(ncomm):
comm_list[i]['address'] = address[i]
for i in range(ncomm):
ikw = copy.deepcopy(kwargs)
ikw.update(child_kwargs)
ikw.update(comm_list[i])
ikw.setdefault('use_async', is_async)
iname = ikw.pop('name')
self.comm_list.append(get_comm(iname, **ikw))
self.eof_recv.append(0)
self.eof_send.append(0)
self.comm_list_backlog[i] = []
if ncomm > 0:
kwargs['address'] = [x.address for x in self.comm_list]
kwargs.update(noprop_kwargs)
super(ForkComm, self).__init__(name, is_async=is_async, **kwargs)
assert not self.single_use
assert not self.is_server
assert not (self.is_client and (self.pattern != 'cycle'))
[docs] def disconnect(self):
r"""Disconnect attributes that are aliases."""
for x in self.comm_list:
x.disconnect()
super(ForkComm, self).disconnect()
[docs] def get_status_message(self, **kwargs):
r"""Return lines composing a status message.
Args:
**kwargs: Keyword arguments are passed on to the parent class's
method.
Returns:
tuple(list, prefix): Lines composing the status message and the
prefix string used for the last message.
"""
nindent = kwargs.get('nindent', 0)
extra_lines_after = ['%-15s: %s' % ('pattern', self.pattern)]
for x in self.comm_list:
extra_lines_after += x.get_status_message(nindent=nindent + 1)[0]
extra_lines_after += kwargs.get('extra_lines_after', [])
kwargs['extra_lines_after'] = extra_lines_after
return super(ForkComm, self).get_status_message(**kwargs)
def __len__(self):
return len(self.comm_list)
@property
def any_files(self):
r"""bool: True if the comm interfaces with any files."""
return any(x.is_file for x in self.comm_list)
@property
def last_comm(self):
r"""CommBase: Last comm that was used."""
idx = self.curr_comm_index
if idx > 0:
idx -= 1
return self.comm_list[idx % len(self)]
@property
def curr_comm(self):
r"""CommBase: Current comm."""
return self.comm_list[self.curr_comm_index % len(self)]
@property
def maxMsgSize(self):
r"""int: Maximum size of a single message that should be sent."""
return min([x.maxMsgSize for x in self.comm_list])
[docs] @classmethod
def new_comm_kwargs(cls, name, *args, **kwargs):
r"""Get keyword arguments for new comm."""
if 'address' not in kwargs:
addresses = []
comm_list = kwargs.get('comm_list', None)
ncomm = kwargs.pop('ncomm', 0)
if comm_list is None:
comm_list = [None for i in range(ncomm)]
assert isinstance(comm_list, list)
ncomm = len(comm_list)
for i in range(ncomm):
x = comm_list[i]
if x is None:
x = {}
iname = x.pop('name', get_comm_name(name, i))
icls = import_comm(x.get('commtype', None))
_, ickw = icls.new_comm_kwargs(iname, **x)
ickw['name'] = iname
comm_list[i] = ickw
addresses.append(ickw['address'])
kwargs['comm_list'] = comm_list
kwargs['address'] = addresses
args = tuple([name] + list(args))
return args, kwargs
@property
def model_env(self):
r"""dict: Mapping between model name and opposite comm
environment variables that need to be provided to the model."""
out = {}
for x in self.comm_list:
iout = x.model_env
for k, v in iout.items():
out.setdefault(k, {})
out[k].update(v)
return out
# @property
# def mpi_model_kws(self):
# r"""dict: Mapping between model name and opposite comm keyword
# arguments that need to be provided to the model for the MPI
# connection."""
# out = {}
# for x in self.comm_list:
# iout = x.mpi_model_kws
# for k, v in iout.items():
# out.setdefault(k, [])
# out[k] += v
# return out
@property
def opp_comms(self):
r"""dict: Name/address pairs for opposite comms."""
out = super(ForkComm, self).opp_comms
out.pop(self.name)
for x in self.comm_list:
out.update(**x.opp_comms)
return out
[docs] def opp_comm_kwargs(self, for_yaml=False):
r"""Get keyword arguments to initialize communication with opposite
comm object.
Args:
for_yaml (bool, optional): If True, the returned dict will only
contain values that can be specified in a YAML file. Defaults
to False.
Returns:
dict: Keyword arguments for opposite comm object.
"""
kwargs = super(ForkComm, self).opp_comm_kwargs(for_yaml=for_yaml)
kwargs['comm_list'] = [x.opp_comm_kwargs(for_yaml=for_yaml)
for x in self.comm_list]
for pair in _pattern_pairs:
if self.pattern in pair:
kwargs['pattern'] = pair[(pair.index(self.pattern) + 1) % 2]
return kwargs
@property
def get_response_comm_kwargs(self):
r"""dict: Keyword arguments to use for a response comm."""
assert self.pattern == 'cycle'
return self.curr_comm.get_response_comm_kwargs
[docs] def bind(self):
r"""Bind in place of open."""
for x in self.comm_list:
x.bind()
[docs] def open(self):
r"""Open the connection."""
for x in self.comm_list:
x.open()
[docs] def close(self, *args, **kwargs):
r"""Close the connection."""
for x in self.comm_list:
x.close(*args, **kwargs)
[docs] def close_in_thread(self, *args, **kwargs): # pragma: debug
r"""In a new thread, close the comm when it is empty."""
raise Exception("ForkComm should not be closed in thread.")
@property
def is_open(self):
r"""bool: True if the connection is open."""
for x in self.comm_list:
if x.is_open:
return True
return False
@property
def is_confirmed_send(self):
r"""bool: True if all sent messages have been confirmed."""
for x in self.comm_list:
if not x.is_confirmed_send: # pragma: debug
return False
return True
@property
def is_confirmed_recv(self):
r"""bool: True if all received messages have been confirmed."""
for x in self.comm_list:
if not x.is_confirmed_recv: # pragma: debug
return False
return True
[docs] def confirm_send(self, noblock=False):
r"""Confirm that sent message was received."""
for x in self.comm_list:
if not x.confirm_send(noblock=noblock): # pragma: debug
return False
return True
[docs] def confirm_recv(self, noblock=False):
r"""Confirm that message was received."""
for x in self.comm_list:
if not x.confirm_recv(noblock=noblock): # pragma: debug
return False
return True
@property
def n_msg_direct_recv(self):
r"""int: Number of messages currently being routed in recv."""
if self.pattern == 'gather':
return min([x.n_msg_direct_recv for x in self.comm_list])
return sum([x.n_msg_direct_recv for x in self.comm_list])
@property
def n_msg_direct_send(self):
r"""int: Number of messages currently being routed in send."""
if self.pattern in ['broadcast', 'scatter']:
return max([x.n_msg_direct_send for x in self.comm_list])
return sum([x.n_msg_direct_send for x in self.comm_list])
@property
def n_msg_direct(self):
r"""int: Number of messages currently being routed."""
if self.direction == 'send':
return self.n_msg_direct_send
else:
return self.n_msg_direct_recv
@property
def n_msg_recv(self):
r"""int: The number of incoming messages in the connection."""
if self.pattern == 'gather':
return min([x.n_msg_recv for x in self.comm_list])
return sum([x.n_msg_recv for x in self.comm_list])
@property
def n_msg_send(self):
r"""int: The number of outgoing messages in the connection."""
if self.pattern in ['broadcast', 'scatter']:
return max([x.n_msg_send for x in self.comm_list])
return sum([x.n_msg_send for x in self.comm_list])
@property
def n_msg_recv_drain(self):
r"""int: The number of incoming messages in the connection to drain."""
return sum([x.n_msg_recv_drain for x in self.comm_list])
@property
def n_msg_send_drain(self):
r"""int: The number of outgoing messages in the connection to drain."""
return sum([x.n_msg_send_drain for x in self.comm_list])
@property
def empty_obj_recv(self):
r"""obj: Empty message object."""
if self.pattern in ['gather']:
return []
return self.last_comm.empty_obj_recv
[docs] def update_serializer_from_message(self, msg):
r"""Update the serializer based on information stored in a message.
Args:
msg (CommMessage): Outgoing message.
"""
if msg.stype is not None:
msg.stype = self.apply_transform_to_type(msg.stype)
if (self.direction == 'send') and (self.pattern == 'scatter'):
if msg.stype['type'] != 'array': # pragma: debug
raise RuntimeError("Only 'array' type messages can be "
"scattered.")
for i, x in enumerate(self.comm_list):
imsg = copy.deepcopy(msg)
imsg.header = {'__meta__': {}}
imsg.stype = msg.stype['items'][i]
imsg.args = msg.args[i]
x.update_serializer_from_message(imsg)
return
for x in self.comm_list:
x.update_serializer_from_message(msg)
[docs] def prepare_message(self, *args, **kwargs):
r"""Perform actions preparing to send a message.
Args:
*args: Components of the outgoing message.
**kwargs: Additional keyword arguments are passed to the prepare_message
methods for the forked comms.
Returns:
CommMessage: Serialized and annotated message.
"""
kws_root = {'skip_serialization': True}
for k in ['header_kwargs']:
if k in kwargs:
kws_root[k] = kwargs.pop(k)
msg = super(ForkComm, self).prepare_message(*args, **kws_root)
if not isinstance(msg, ForkedCommMessage):
msg = ForkedCommMessage(msg, self.comm_list,
pattern=self.pattern, **kwargs)
return msg
[docs] def send_message(self, msg, **kwargs):
r"""Send a message encapsulated in a CommMessage object.
Args:
msg (CommMessage): Message to be sent.
**kwargs: Additional keyword arguments are passed to _safe_send.
Returns:
bool: Success or failure of send.
"""
assert isinstance(msg.args, dict)
for idx in range(len(self)):
i = self.curr_comm_index % len(self)
x = self.curr_comm
out = x.send_message(msg.args[i], **kwargs)
self.errors += x.errors
if msg.flag == CommBase.FLAG_EOF:
self.eof_send[i] = 1
self.curr_comm_index += 1
if not out:
return out
elif (self.pattern == 'cycle') and (msg.flag != CommBase.FLAG_EOF):
break
msg.args = msg.orig
msg.additional_messages = []
kwargs['skip_safe_send'] = True
return super(ForkComm, self).send_message(msg, **kwargs)
[docs] def recv_message(self, *args, **kwargs):
r"""Receive a message.
Args:
*args: Arguments are passed to the forked comm's recv_message method.
**kwargs: Keyword arguments are passed to the forked comm's recv_message
method.
Returns:
CommMessage: Received message.
"""
timeout = kwargs.pop('timeout', None)
if timeout is None:
timeout = self.recv_timeout
kwargs['timeout'] = 0
first_comm = True
T = self.start_timeout(timeout, key_suffix='recv:forkd')
out = None
out_gather = {}
idx = None
if self.pattern == 'gather':
def complete():
return (len(out_gather) == len(self))
else:
def complete():
return bool(out_gather)
while ((not T.is_out) or first_comm) and self.is_open and (not complete()):
for i in range(len(self)):
if complete():
break
idx = self.curr_comm_index % len(self)
x = self.curr_comm
if idx not in out_gather:
if self.comm_list_backlog[idx]:
out_gather[idx] = self.comm_list_backlog[idx].pop(0)
elif x.is_open:
msg = x.recv_message(*args, **kwargs)
self.errors += x.errors
if msg.flag == CommBase.FLAG_EOF:
self.eof_recv[idx] = 1
if self.pattern == 'gather':
assert all((v.flag == CommBase.FLAG_EOF)
for v in out_gather.values())
out_gather[idx] = msg
elif sum(self.eof_recv) == len(self):
out_gather[idx] = msg
else:
x.finalize_message(msg)
elif msg.flag == CommBase.FLAG_SUCCESS:
out_gather[idx] = msg
self.curr_comm_index += 1
first_comm = False
if not complete():
self.sleep()
self.stop_timeout(key_suffix='recv:forkd')
if complete():
if self.pattern == 'cycle':
idx, out = next(iter(out_gather.items()))
args_copy = copy.deepcopy(out)
out.args = {idx: args_copy}
elif self.pattern == 'gather':
out = copy.deepcopy(next(iter(out_gather.values())))
out.args = {idx: v for idx, v in out_gather.items()}
# TODO: Gather header/type etc?
else:
for idx, v in out_gather.items():
self.comm_list_backlog[idx].append(v)
if self.is_closed:
self.debug('Comm closed')
out = CommBase.CommMessage(flag=CommBase.FLAG_FAILURE)
else:
out = CommBase.CommMessage(flag=CommBase.FLAG_EMPTY)
if self.pattern == 'cycle':
out.args = self.last_comm.empty_obj_recv
else:
out.args = []
return out
[docs] def finalize_message(self, msg, **kwargs):
r"""Perform actions to decipher a message.
Args:
msg (CommMessage): Initial message object to be finalized.
**kwargs: Keyword arguments are passed to the forked comm's
finalize_message method.
Returns:
CommMessage: Deserialized and annotated message.
"""
if msg.flag in [CommBase.FLAG_EOF, CommBase.FLAG_SUCCESS]:
msg.args = {
idx: self.comm_list[idx].finalize_message(
v, skip_python2language=True)
for idx, v in msg.args.items()}
if self.pattern == 'cycle':
assert len(msg.args) == 1
msg = next(iter(msg.args.values()))
elif msg.flag == CommBase.FLAG_EOF:
msg.args = msg.args[0].args
else:
msg.args = [msg.args[idx].args for idx in range(len(self))]
msg.finalized = False
return super(ForkComm, self).finalize_message(msg, **kwargs)
@property
def _multiple_first_send(self): # pragma: debug
return self.last_comm._multiple_first_send
@_multiple_first_send.setter
def _multiple_first_send(self, value):
for x in self.comm_list:
x._multiple_first_send = value
@property
def suppress_special_debug(self):
return self.last_comm.suppress_special_debug
@suppress_special_debug.setter
def suppress_special_debug(self, value):
for x in self.comm_list:
x.suppress_special_debug = value
@property
def _type_errors(self): # pragma: debug
return self.last_comm._type_errors
@_type_errors.setter
def _type_errors(self, value):
for x in self.comm_list:
x._type_errors = value
[docs] def purge(self):
r"""Purge all messages from the comm."""
super(ForkComm, self).purge()
for x in self.comm_list:
x.purge()
[docs] def drain_server_signon_messages(self, **kwargs):
r"""Drain server signon messages. This should only be used
for testing purposes."""
for x in self.comm_list:
x.drain_server_signon_messages(**kwargs)
[docs] def coerce_to_dict(self, msg, key_order, metadata):
r"""Convert a message to a dictionary.
Args:
msg (object): Message to convert to a dictionary.
key_order (list): Key order to use for the output dictionary.
metadata (dict): Header data to accompany the message.
Returns:
dict: Converted message.
"""
if self.pattern in ['scatter', 'gather']:
assert isinstance(msg, (list, tuple)) and (len(msg) == len(self))
out = [x.coerce_to_dict(msg[i], key_order, metadata)
for i, x in enumerate(self.comm_list)]
return out
return super(ForkComm, self).coerce_to_dict(msg, key_order, metadata)