from yggdrasil import constants
from yggdrasil.drivers.ConnectionDriver import ConnectionDriver, run_remotely
from yggdrasil.drivers.RPCResponseDriver import RPCResponseDriver
from yggdrasil.communication import CommBase
class RPCRequestDriver(ConnectionDriver):
    r"""Class for handling client side RPC type communication.

    Args:
        model_request_name (str): The name of the channel used by the client
            model to send requests.
        response_kwargs (dict, optional): Keyword arguments used as the base
            of the input comm kwargs for each response driver created by
            send_message (they are updated with
            ``self.ocomm.get_response_comm_kwargs``). Defaults to None, which
            is treated as an empty dict.
        **kwargs: Additional keyword arguments are passed to parent class.

    Attributes:
        response_drivers (dict): Response drivers created for each request,
            keyed by response address (or by a (response address, comm index)
            tuple when the output comm is a 'fork' comm).

    """

    _connection_type = 'rpc_request'

    def __init__(self, model_request_name, response_kwargs=None, **kwargs):
        # Input communicator
        inputs = kwargs.get('inputs', [{}])
        kwargs['inputs'] = inputs
        # Output communicator. The first entry is indexed below, so fall back
        # to a single default comm when outputs is missing, None, or empty.
        outputs = kwargs.get('outputs') or [{}]
        outputs[0]['is_client'] = True
        outputs[0]['close_on_eof_send'] = False
        kwargs['outputs'] = outputs
        if response_kwargs is None:
            response_kwargs = {}
        self.response_kwargs = response_kwargs
        # Parent and attributes
        super(RPCRequestDriver, self).__init__(model_request_name, **kwargs)
        self.response_drivers = {}
        # Set by close_response_drivers to stop new response drivers from
        # being created in send_message.
        self._block_response = False

    @property
    def servers_recvd(self):
        r"""dict: Mapping from server model name to the total count reported
        in ``models_recvd`` across all response drivers."""
        out = {}
        for x in self.response_drivers.values():
            for k, v in x.models_recvd.items():
                out[k] = out.get(k, 0) + v
        return out

    @property
    @run_remotely
    def clients(self):
        r"""list: Copy of the input models (clients) that are connected."""
        return self.models['input'].copy()

    @property
    @run_remotely
    def nclients(self):
        r"""int: Number of clients that are connected."""
        return len(self.clients)

    @property
    def model_env(self):
        r"""dict: Mapping between model name and opposite comm
        environment variables that need to be provided to the model."""
        out = super(RPCRequestDriver, self).model_env
        # Flag every model attached to the output comm as a server.
        for k in self.ocomm.model_env.keys():
            out[k]['YGG_IS_SERVER'] = 'True'
        return out

    def close_response_drivers(self):
        r"""Terminate all response drivers and block creation of new ones."""
        with self.lock:
            self.debug("Closing response drivers.")
            self._block_response = True
            for x in self.response_drivers.values():
                x.terminate()
            self.response_drivers = {}

    def close_comm(self):
        r"""Close response drivers, then close the parent's comms."""
        self.close_response_drivers()
        super(RPCRequestDriver, self).close_comm()

    def printStatus(self, *args, **kwargs):
        r"""Print the status of this driver and of every response driver.

        Args:
            *args: Passed to the parent and response driver printStatus.
            **kwargs: Passed to the parent and response driver printStatus.
                If ``return_str`` is True, the response drivers' output is
                appended to the returned value.

        Returns:
            The value returned by the parent printStatus, extended with the
            response drivers' output when ``return_str`` is True.

        """
        out = super(RPCRequestDriver, self).printStatus(*args, **kwargs)
        for x in self.response_drivers.values():
            x_out = x.printStatus(*args, **kwargs)
            if kwargs.get('return_str', False):
                out += x_out
        return out

    @run_remotely
    def remove_model(self, direction, name):
        r"""Remove a model from the list of models.

        When an input (client) model signs off while other clients are still
        connected, a raw YGG_CLIENT_EOF message tagged with the client's name
        is sent on the output comm. If removing the model leaves no models in
        that direction, an EOF tagged with the model name is sent.

        Args:
            direction (str): Direction of model.
            name (str): Name of model exiting.

        Returns:
            bool: True if all of the input/output models have signed
                off; False otherwise.

        """
        with self.lock:
            clients = self.clients
            if (direction == "input") and (name in clients) and (len(clients) > 1):
                super(RPCRequestDriver, self).send_message(
                    CommBase.CommMessage(args=constants.YGG_CLIENT_EOF,
                                         flag=CommBase.FLAG_SUCCESS),
                    header_kwargs={'raw': True, '__meta__': {'model': name}},
                    skip_processing=True)
            out = super(RPCRequestDriver, self).remove_model(
                direction, name)
            if out:
                self.send_eof(header_kwargs={'__meta__': {'model': name}})
            return out

    def on_eof(self, msg):
        r"""On EOF, decrement number of clients. Only send EOF if the number
        of clients drops to 0.

        Args:
            msg (CommMessage): Message object that provided the EOF.

        Returns:
            CommMessage, bool: Value that should be returned by recv_message on EOF.

        """
        with self.lock:
            # Tolerate EOF headers without '__meta__' rather than raising.
            meta = msg.header.get('__meta__', {})
            self.remove_model('input', meta.get('model', ''))
            if self.nclients == 0:
                self.debug("All clients have signed off (EOF).")
                return super(RPCRequestDriver, self).on_eof(msg)
        return CommBase.CommMessage(flag=CommBase.FLAG_EMPTY,
                                    args=self.icomm.empty_obj_recv)

    def before_loop(self):
        r"""Run the parent's pre-loop setup, then set the output comm's
        ``_send_serializer`` flag.

        NOTE(review): presumably this makes the output comm send its
        serializer information with the next message; confirm against the
        comm implementation.
        """
        super(RPCRequestDriver, self).before_loop()
        self.ocomm._send_serializer = True

    def send_message(self, msg, **kwargs):
        r"""Start a response driver for a request message and send message with
        header.

        Args:
            msg (CommMessage): Message being sent.
            **kwargs: Keyword arguments are passed to parent class send_message.

        Returns:
            bool: Success or failure of send.

        """
        if self.ocomm.is_closed:
            return False
        # Start response driver
        if msg.flag != CommBase.FLAG_EOF:
            # A raw YGG_CLIENT_EOF message means a client signed off: remove
            # that client and report success without forwarding the message.
            if ((msg.header.get('raw', False)
                 and (msg.args == constants.YGG_CLIENT_EOF))):  # pragma: intermittent
                self.remove_model('input', msg.header['__meta__']['model'])
                return True
            with self.lock:
                if (not self.is_comm_open) or self._block_response:  # pragma: debug
                    self.debug("Comm closed, not creating response driver.")
                    return False
                meta = msg.header['__meta__']
                key = meta['response_address']
                if self.ocomm._commtype == 'fork':
                    # Key fork comms by the current comm index as well, so a
                    # separate response driver exists per (address, index).
                    key = (meta['response_address'],
                           self.ocomm.curr_comm_index % len(self.ocomm))
                if key in self.response_drivers:
                    response_driver = self.response_drivers[key]
                else:
                    response_kwargs = self.response_kwargs.copy()
                    response_kwargs.update(
                        self.ocomm.get_response_comm_kwargs)
                    drv_args = [meta['response_address'],
                                meta['request_id']]
                    drv_kwargs = dict(
                        request_name=self.name,
                        inputs=[response_kwargs],
                        outputs=[{'commtype': msg.header["commtype"]}])
                    self.debug("Creating response comm: address = %s, request_id = %s",
                               meta['response_address'],
                               meta['request_id'])
                    try:
                        response_driver = RPCResponseDriver(
                            *drv_args, **drv_kwargs)
                        # Registered before start() so close_response_drivers
                        # can still terminate it if start() fails.
                        self.response_drivers[key] = response_driver
                        response_driver.start()
                        self.debug("Started response comm: address = %s, request_id = %s",
                                   meta['response_address'],
                                   meta['request_id'])
                    except BaseException:  # pragma: debug
                        self.exception("Could not create/start response driver.")
                        return False
                # Send response address in header
                header_kwargs = kwargs.setdefault('header_kwargs', {})
                header_meta = header_kwargs.setdefault('__meta__', {})
                header_meta.setdefault(
                    'response_address', response_driver.response_address)
                header_meta.setdefault(
                    'request_id', meta['request_id'])
                header_meta.setdefault(
                    'model', meta.get('model', ''))
        return super(RPCRequestDriver, self).send_message(msg, **kwargs)

    def run_loop(self):
        r"""Run the driver. Continue looping over messages until there are not
        any left or the communication channel is closed.

        After each pass that did not break the loop, errors from the response
        drivers are promoted to this driver.
        """
        super(RPCRequestDriver, self).run_loop()
        if not self.was_break:
            self.prune_response_drivers()

    def prune_response_drivers(self):
        r"""Promote errors from response drivers.

        This runs after every run_loop pass, so errors that are already
        present in ``self.errors`` are skipped. Otherwise the same error
        objects would be appended again on every pass and the list would grow
        without bound.
        """
        with self.lock:
            for x in self.response_drivers.values():
                new_errors = [e for e in x.errors if e not in self.errors]
                self.errors += new_errors