import logging
import numpy as np
from yggdrasil.multitasking import _on_mpi, MPI, RLock, MPIRequestWrapper
from yggdrasil.communication import (
CommBase, NoMessages)
logger = logging.getLogger(__name__)
class MPIClosedError(BaseException):
    r"""Signal that an operation was attempted on a closed MPI connection.

    The class derives directly from BaseException, so handlers written as
    ``except Exception`` will not intercept it; it must be caught by name.
    """
class MPIRequest(object):
    r"""Bookkeeping for one non-blocking MPI message exchange.

    Every message is transferred as two MPI operations that share a tag:
    first a single integer holding the payload length, then the payload
    itself as characters.

    Args:
        comm: MPI communicator the requests are issued on.
        direction (str): 'send' issues send requests; any other value is
            handled as a receive.
        address (int): Rank of the partner process.
        tag (int): Tag used for both the size and payload requests.
        **kwargs: Passed on to make_request (e.g. payload for a send).

    """

    __slots__ = ['comm', 'address', 'direction', 'tag', 'size_req',
                 'req', 'size', '_data']

    def __init__(self, comm, direction, address, tag, **kwargs):
        self.comm = comm
        self.address = address
        self.direction = direction
        self.tag = tag
        # One-element integer buffer used to exchange the payload length.
        self.size = np.zeros(1, dtype='i')
        self._data = None
        self.size_req = None
        self.req = self.make_request(**kwargs)

    @property
    def data(self):
        r"""object: Data associated with the request. A received numpy
        buffer is converted to bytes on first access and cached."""
        current = self._data
        if self.direction == 'recv' and isinstance(current, np.ndarray):
            current = current.tobytes()
            self._data = current
        return current

    def make_request(self, payload=None):
        r"""Issue the initial non-blocking MPI request(s).

        Args:
            payload (bytes, optional): Message to send. Only used when the
                direction is 'send'.

        Returns:
            MPIRequestWrapper: Wrapped payload request for a send, None for
                a receive (the payload receive is posted later by the
                complete property once the size is known).

        """
        kwargs = {'tag': self.tag}
        if self.direction == 'send':
            method = 'Isend'
            args = ([payload, MPI.CHAR], )
            kwargs['dest'] = self.address
            # The message length goes out first under the same tag.
            self.size[0] = len(payload)
            size_request = self.comm.Isend([self.size, MPI.INT], **kwargs)
            self.size_req = MPIRequestWrapper(size_request)
            out = MPIRequestWrapper(self.comm.Isend(*args, **kwargs))
        else:
            method = 'Irecv'
            args = ([self.size, MPI.INT], )
            kwargs['source'] = self.address
            size_request = self.comm.Irecv(*args, **kwargs)
            self.size_req = MPIRequestWrapper(size_request)
            out = None
        logger.debug("rank = %d, method = %s, args = %.100s, kwargs = %.100s",
                     self.comm.Get_rank(), method, args, kwargs)
        return out

    @property
    def complete(self):
        r"""bool: True if the request has been completed, False otherwise.
        For a receive, checking this also posts the payload receive as soon
        as the size request has finished."""
        if self.direction == 'recv':
            # The size request is tested first, exactly once per check.
            if self.size_req.test()[0] and self.req is None:
                self._data = np.zeros(self.size[0], dtype='c')
                payload_request = self.comm.Irecv(
                    [self._data, MPI.CHAR],
                    source=self.address, tag=self.tag)
                self.req = MPIRequestWrapper(payload_request)
        payload_done = (
            self.req is not None
            and (self.req.test()[0] or self.req.canceled))
        return payload_done or self.size_req.canceled

    def cancel(self):
        r"""Cancel the size request and, if one was posted, the payload
        request."""
        self.size_req.cancel()
        if self.req is not None:  # pragma: intermittent
            self.req.cancel()
class MPIMultiRequest(MPIRequest):
    r"""Request that listens on several partner ranks at once.

    One MPIRequest is kept per partner rank in the req mapping. When the
    first of them completes, its data and rank are adopted by this object
    and the other requests are stored in remainder.

    """

    __slots__ = ["remainder"]

    def __init__(self, *args, **kwargs):
        # Must exist before the parent constructor runs.
        self.remainder = {}
        super(MPIMultiRequest, self).__init__(*args, **kwargs)

    def make_request(self, previous_requests=None):
        r"""Build the mapping from partner rank to MPIRequest.

        Args:
            previous_requests (dict, optional): Existing requests keyed by
                rank. Entries are reused and the mapping is extended in
                place with requests for any ranks it lacks.

        Returns:
            dict: Mapping from each rank in address to its request.

        """
        requests = {} if previous_requests is None else previous_requests
        for rank in self.address:
            if rank not in requests:
                requests[rank] = MPIRequest(
                    self.comm, self.direction, rank, self.tag[rank])
        return requests

    @property
    def complete(self):
        r"""bool: True once one per-rank request has completed and the
        others have been recorded in remainder, False otherwise."""
        if self.remainder:
            return True
        for rank, request in self.req.items():
            if not request.complete:
                continue
            # Adopt the finished request's payload and source rank.
            self._data = request.data
            self.address = rank
            self.remainder.update(
                {other: pending for other, pending in self.req.items()
                 if other != rank})
            break
        return bool(self.remainder)

    def cancel(self):
        r"""Cancel every per-rank request."""
        for request in self.req.values():
            request.cancel()
class MPIComm(CommBase.CommBase):
    r"""Class for handling I/O via MPI communicators.

    Args:
        *args: Positional arguments are passed to the parent class.
        ranks (list, optional): Ranks of the partner MPI processes. Replaced
            by partner_mpi_ranks if that keyword is provided and non-empty.
            Defaults to an empty list.
        tag_start (int, optional): Tag that MPI messages should start with.
            Defaults to 0.
        tag_stride (int, optional): Amount that tag should be advanced after
            each message to avoid conflicts w/ other MPIComm communicators.
            Defaults to 1.
        partner_mpi_ranks (list, optional): Rank of MPI processes that partner
            models are running on. Defaults to None.
        **kwargs: Additional keyword arguments are passed to the parent
            class.

    Attributes:
        tag (int): Tag that should be used for the next MPI message.
        tag_stride (int): Amount that tag should be advanced after each
            message to avoid conflicts w/ other MPIComm communicators.
        requests (list): Requests that have been issued and not yet
            consumed.
        tags (dict): Mapping from partner rank to the next fresh tag.
        unused_tags (dict): Mapping from partner rank to tags of canceled
            requests that will be reused before fresh tags.

    """

    _commtype = 'mpi'
    _schema_subtype_description = 'MPI communicator.'
    address_description = "The partner communicator ID(s)."
    _spacer_tags = 10
    _max_response = 10
    # Based on limit of 32bit int, this could be 2**30 - 1, but this is
    # too large for stack allocation in C so 2**20 will be used in case a
    # C implementation of the MPIComm is added in the future.
    _maxMsgSize = 2**20

    def __init__(self, *args, ranks=None, tag_start=0, tag_stride=1,
                 **kwargs):
        assert _on_mpi
        # A None sentinel replaces the former mutable default ([]) so that
        # instances never share a single list object as self.ranks.
        if ranks is None:
            ranks = []
        if kwargs.get('partner_mpi_ranks', []):
            assert kwargs.get('address', 'generate') in ['generate',
                                                         'address']
            ranks = kwargs['partner_mpi_ranks']
        self._request_lock = RLock(task_method='thread')
        self.requests = []
        self.unused_tags = {}
        self.tags = {}
        self.ranks = ranks
        self.tag_start = tag_start
        self.tag_stride = tag_stride
        self.requires_disconnect = False
        self.last_request = None
        self.mpi_comm = MPI.COMM_WORLD
        self._is_open = False
        self.response_tags = []
        # Defined before the parent constructor runs; rebuilt afterwards
        # from the ranks the communicator ended up with.
        self.eof_recv = {}
        kwargs['no_suffix'] = True
        super(MPIComm, self).__init__(*args, **kwargs)
        self.eof_recv = {x: 0 for x in self.ranks}

    @classmethod
    def parse_address(cls, address):
        r"""Parse an MPI address for information about the partner process
        ranks and how the tags should be iterated.

        Args:
            address (str): Address to parse. It must consist of three
                fields separated by '_MPI_', the first being the partner
                ranks joined by '-'.

        Returns:
            tuple: The ranks, starting tag, and tag stride contained in the
                address.

        """
        rank_str, tag_start, tag_stride = address.split('_MPI_')
        ranks = tuple(int(x) for x in rank_str.split('-'))
        return ranks, int(tag_start), int(tag_stride)

    @property
    def tag(self):
        r"""int: Next tag of the partner rank whose next tag is largest."""
        return self.get_tag(max(self.ranks, key=self.get_tag))

    def next_rank(self):
        r"""Get the rank that should be used next.

        Returns:
            int, list: The only rank if there is just one; when sending,
                the rank with the smallest next tag; otherwise all ranks.

        """
        if len(self.ranks) == 1:
            return self.ranks[0]
        elif self.direction == 'send':
            return min(self.ranks, key=self.get_tag)
        else:
            return self.ranks

    def get_tag(self, rank=None):
        r"""Get the next tag for a rank.

        Args:
            rank (int, list, optional): Rank(s) to get tag for. Defaults to
                the result of next_rank.

        Returns:
            int, dict: Tag that should be used next for the rank, or a
                mapping from rank to tag if several ranks are provided.

        """
        if rank is None:
            rank = self.next_rank()
        if isinstance(rank, (list, tuple)):
            return {x: self.get_tag(x) for x in rank}
        # Tags cached from canceled requests are reused first.
        if self.unused_tags.get(rank, []):
            return self.unused_tags[rank][0]
        return self.tags[rank]

    def advance_tag(self, request):
        r"""Advance to the next tag.

        Args:
            request (MPIRequest, MPIMultiRequest): Request advancing the tag.

        """
        if isinstance(request, MPIMultiRequest):
            for v in request.req.values():
                self.advance_tag(v)
            return
        if request.tag in self.unused_tags.get(request.address, []):
            # A cached tag was consumed; the fresh-tag counter is unchanged.
            self.unused_tags[request.address].remove(request.tag)
            return
        self.tags[request.address] = max(self.tags[request.address],
                                         request.tag + self.tag_stride)

    def cache_tag(self, request):
        r"""Store a tag for an uncompleted request.

        Args:
            request (MPIRequest, MPIMultiRequest): Request to cache.

        """
        if isinstance(request, MPIMultiRequest):
            for v in request.req.values():
                self.cache_tag(v)
            return
        self.unused_tags.setdefault(request.address, [])
        self.unused_tags[request.address].append(request.tag)

    def bind(self):
        r"""Generate the address from the ranks and tag settings if it is
        'generate' or 'address'; otherwise derive the ranks and tag settings
        from the address."""
        assert isinstance(self.address, str)
        if self.address in ['generate', 'address']:
            assert self.ranks
            self.address = self.format_address(self.ranks, self.tag_start,
                                               self.tag_stride)
        else:
            self.ranks, self.tag_start, self.tag_stride = self.parse_address(
                self.address)
        if not self.tags:
            self.tags = {x: self.tag_start for x in self.ranks}
        super(MPIComm, self).bind()

    @property
    def model_env(self):
        r"""dict: Mapping between model name and opposite comm
        environment variables that need to be provided to the model."""
        return {}

    @property
    def opp_address(self):
        r"""str: Address for opposite comm."""
        return self.format_address([self.mpi_comm.Get_rank()],
                                   self.tag_start, self.tag_stride)

    @property
    def get_response_comm_kwargs(self):
        r"""dict: Keyword arguments to use for a response comm."""
        out = super(MPIComm, self).get_response_comm_kwargs
        # tag = self.tag_start + len(self.response_tags) + 2
        # tag_stride = self.tag_stride
        # if (tag - self.tag_start) >= self._max_response:
        #     raise Exception("Starting tag for next response comm "
        #                     "exceeds the maximum and may conflict with "
        #                     "other messages this comm will send.")
        tag = self.get_tag()
        tag_stride = 0
        assert tag not in self.response_tags
        self.response_tags.append(tag)
        out['address'] = self.format_address(
            [self.next_rank()], tag, tag_stride)
        return out

    @property
    def create_work_comm_kwargs(self):
        r"""dict: Keyword arguments for a new work comm."""
        out = super(MPIComm, self).create_work_comm_kwargs
        out['address'] = self.format_address(
            [self.next_rank()], self.get_tag(), 0)
        return out

    def open(self):
        r"""Open the queue."""
        super(MPIComm, self).open()
        if not self.is_open:
            self._is_open = True
            # A process cannot be its own partner.
            assert self.mpi_comm.Get_rank() not in self.ranks

    def _close(self, linger=False):
        r"""Close the queue.

        Args:
            linger (bool, optional): Passed to the parent class's _close
                method. Defaults to False.

        """
        # Disconnect will only be required if a subset of processes is used,
        # but that is not currently supported. This should be uncommented if
        # support for dynamic processes management is added.
        # if self.requires_disconnect and self.is_open:
        #     self.mpi_comm.Disconnect()
        with self._request_lock:
            self.mpi_comm = None
            self.cancel_requests()
        super(MPIComm, self)._close(linger=linger)

    def cancel_requests(self):
        r"""Cancel requests that have not yet been completed. Tags of the
        canceled requests are cached for reuse and completed requests are
        kept in the queue."""
        with self._request_lock:
            complete_requests = []
            for x in self.requests:
                if x.complete:
                    complete_requests.append(x)
                else:
                    self.cache_tag(x)
                    x.cancel()
            # Cancel uncompleted partial request for multi-receive?
            self.requests = complete_requests

    @property
    def is_open(self):
        r"""bool: True if the MPI communicator is set and the comm has been
        opened."""
        return (self.mpi_comm is not None) and self._is_open

    def confirm_send(self, noblock=False):
        r"""Confirm that sent message was received.

        Args:
            noblock (bool, optional): If True, confirmation is assumed.
                Defaults to False.

        Returns:
            bool: True if noblock is set or no send requests are pending.

        """
        if noblock:
            return True
        return (self.n_msg_send == 0)

    def confirm_recv(self, noblock=False):
        r"""Confirm that message was received.

        Args:
            noblock (bool, optional): If True, confirmation is assumed.
                Defaults to False.

        Returns:
            bool: True if noblock is set or no completed requests are
                waiting to be consumed.

        """
        if noblock:
            return True
        return (self.n_msg_recv == 0)

    @property
    def n_msg_send(self):
        r"""int: Number of messages in the queue to send."""
        if self.is_open and self.requests:
            return sum((not x.complete) for x in self.requests)
        else:
            return 0

    @property
    def n_msg_recv(self):
        r"""int: Number of messages in the queue to recv."""
        if self.is_open:
            try:
                # Make sure a request is posted so a message can arrive.
                self.add_request(on_empty=True)
            except MPIClosedError:  # pragma: intermittent
                return 0
        if self.is_open and self.requests:
            return sum(x.complete for x in self.requests)
        else:
            return 0

    def add_request(self, on_empty=False, **kwargs):
        r"""Add a request to the queue.

        Args:
            on_empty (bool, optional): If True, a request is only added
                when the queue is empty. Defaults to False.
            **kwargs: Additional keyword arguments are passed to the
                request class (e.g. payload for a send).

        Raises:
            MPIClosedError: If the MPI communicator has been cleared.

        """
        with self._request_lock:
            if on_empty and self.requests:
                return
            cls = MPIRequest
            address = self.next_rank()
            tag = self.get_tag(address)
            if isinstance(address, (list, tuple)):
                if self.last_request:
                    # Reuse requests still outstanding from the last
                    # multi-rank receive.
                    kwargs['previous_requests'] = self.last_request.remainder
                cls = MPIMultiRequest
            if self.mpi_comm is None:  # pragma: intermittent
                raise MPIClosedError
            args = (self.mpi_comm, self.direction, address, tag)
            req = cls(*args, **kwargs)
            self.requests.append(req)
            self.advance_tag(req)

    def send_message(self, msg, **kwargs):
        r"""Send a message encapsulated in a CommMessage object.

        Args:
            msg (CommMessage): Message to be sent.
            **kwargs: Additional keyword arguments are passed to the parent
                class's send_message method.

        Returns:
            bool: Success or failure of send.

        """
        if (msg.flag == CommBase.FLAG_EOF) and (msg.args != 'DONT_RECURSE'):
            # Attach one extra copy of the EOF per additional partner rank.
            for _ in range(len(self.ranks) - 1):
                msg.add_message(msg=msg.msg, length=msg.length, flag=msg.flag,
                                args='DONT_RECURSE', header=msg.header)
        return super(MPIComm, self).send_message(msg, **kwargs)

    def recv_message(self, *args, **kwargs):
        r"""Receive a message.

        Args:
            *args: Arguments are passed to the parent class's recv_message
                method.
            **kwargs: Keyword arguments are passed to the parent class's
                recv_message method.

        Returns:
            CommMessage: Received message. An EOF is only returned once one
                has been recorded for every partner rank.

        """
        out = super(MPIComm, self).recv_message(*args, **kwargs)
        self.sleep()
        if out.flag == CommBase.FLAG_EOF:
            self.eof_recv[self.last_request.address] = 1
            if not all(self.eof_recv.values()):
                out = self.recv_message(*args, **kwargs)
        return out

    def _send(self, payload):
        r"""Send a message.

        Args:
            payload (str): Message to send.

        Returns:
            bool: Success or failure of sending the message.

        """
        self.add_request(payload=payload)
        return True

    def _recv(self):
        r"""Receive a message from the MPI communicator.

        Returns:
            tuple (bool, str): The success or failure of receiving a message
                and the message received.

        Raises:
            NoMessages: If the oldest request has not completed.

        """
        with self._request_lock:
            self.add_request(on_empty=True)
            if not self.requests[0].complete:
                raise NoMessages("No messages in communicator.")
            self.last_request = self.requests.pop(0)
        return (True, self.last_request.data)

    def purge(self):
        r"""Purge all messages from the comm."""
        super(MPIComm, self).purge()
        with self._request_lock:
            self.cancel_requests()
            # NOTE(review): n_msg_recv registers a new request whenever the
            # queue is empty; on a 'send' comm that request would be built
            # without a payload. Confirm purge is only reached with an
            # empty queue on receiving comms.
            while self.n_msg_recv > 0:  # pragma: intermittent
                self._recv()