Source code for yggdrasil.drivers.RPCRequestDriver

from yggdrasil import constants
from yggdrasil.drivers.ConnectionDriver import ConnectionDriver, run_remotely
from yggdrasil.drivers.RPCResponseDriver import RPCResponseDriver
from yggdrasil.communication import CommBase


[docs]class RPCRequestDriver(ConnectionDriver): r"""Class for handling client side RPC type communication. Args: model_request_name (str): The name of the channel used by the client model to send requests. **kwargs: Additional keyword arguments are passed to parent class. Attributes: response_drivers (dict): Response drivers created for each request. """ _connection_type = 'rpc_request' def __init__(self, model_request_name, response_kwargs=None, **kwargs): # Input communicator inputs = kwargs.get('inputs', [{}]) # inputs[0]['name'] = model_request_name + '.client_model_request' kwargs['inputs'] = inputs # Output communicator outputs = kwargs.get('outputs', [{}]) # outputs[0]['name'] = model_request_name + '.server_model_request' outputs[0]['is_client'] = True outputs[0]['close_on_eof_send'] = False kwargs['outputs'] = outputs if response_kwargs is None: response_kwargs = {} self.response_kwargs = response_kwargs # Parent and attributes super(RPCRequestDriver, self).__init__(model_request_name, **kwargs) self.response_drivers = {} self._block_response = False @property def servers_recvd(self): r"""list: Names of server models that have returned responses.""" out = {} for x in self.response_drivers.values(): for k, v in x.models_recvd.items(): out.setdefault(k, 0) out[k] += v return out @property @run_remotely def clients(self): r"""list: Clients that are connected.""" return self.models['input'].copy() @property @run_remotely def nclients(self): r"""int: Number of clients that are connected.""" return len(self.clients) @property def model_env(self): r"""dict: Mapping between model name and opposite comm environment variables that need to be provided to the model.""" out = super(RPCRequestDriver, self).model_env # Add is_rpc flag to output model env variables for k in self.ocomm.model_env.keys(): out[k]['YGG_IS_SERVER'] = 'True' return out
[docs] def close_response_drivers(self): r"""Close response driver.""" with self.lock: self.debug("Closing response drivers.") self._block_response = True for x in self.response_drivers.values(): x.terminate() self.response_drivers = {}
[docs] def close_comm(self): r"""Close response drivers.""" self.close_response_drivers() super(RPCRequestDriver, self).close_comm()
[docs] def printStatus(self, *args, **kwargs): r"""Also print response drivers.""" out = super(RPCRequestDriver, self).printStatus(*args, **kwargs) for x in self.response_drivers.values(): x_out = x.printStatus(*args, **kwargs) if kwargs.get('return_str', False): out += x_out return out
[docs] @run_remotely def remove_model(self, direction, name): r"""Remove a model from the list of models. Args: direction (str): Direction of model. name (str): Name of model exiting. Returns: bool: True if all of the input/output models have signed off; False otherwise. """ with self.lock: clients = self.clients if (direction == "input") and (name in clients) and (len(clients) > 1): super(RPCRequestDriver, self).send_message( CommBase.CommMessage(args=constants.YGG_CLIENT_EOF, flag=CommBase.FLAG_SUCCESS), header_kwargs={'raw': True, '__meta__': {'model': name}}, skip_processing=True) out = super(RPCRequestDriver, self).remove_model( direction, name) if out: self.send_eof(header_kwargs={'__meta__': {'model': name}}) return out
# def send_eof(self): # r"""Send EOF message. # Returns: # bool: Success or failure of send. # """ # if self.ocomm.partner_copies > 1: # self.ocomm.partner_copies = len(self.servers_recvd) # return super(RPCRequestDriver, self).send_eof()
[docs] def on_eof(self, msg): r"""On EOF, decrement number of clients. Only send EOF if the number of clients drops to 0. Args: msg (CommMessage): Message object that provided the EOF. Returns: CommMessage, bool: Value that should be returned by recv_message on EOF. """ with self.lock: self.remove_model('input', msg.header['__meta__'].get('model', '')) if self.nclients == 0: self.debug("All clients have signed off (EOF).") return super(RPCRequestDriver, self).on_eof(msg) return CommBase.CommMessage(flag=CommBase.FLAG_EMPTY, args=self.icomm.empty_obj_recv)
[docs] def before_loop(self): r"""Send client sign on to server response driver.""" super(RPCRequestDriver, self).before_loop() self.ocomm._send_serializer = True
[docs] def send_message(self, msg, **kwargs): r"""Start a response driver for a request message and send message with header. Args: msg (CommMessage): Message being sent. **kwargs: Keyword arguments are passed to parent class send_message. Returns: bool: Success or failure of send. """ if self.ocomm.is_closed: return False # Start response driver if msg.flag != CommBase.FLAG_EOF: # Remove client that signed off if ((msg.header.get('raw', False) and (msg.args == constants.YGG_CLIENT_EOF))): # pragma: intermittent self.remove_model('input', msg.header['__meta__']['model']) return True with self.lock: if (not self.is_comm_open) or self._block_response: # pragma: debug self.debug("Comm closed, not creating response driver.") return False key = msg.header['__meta__']['response_address'] if self.ocomm._commtype == 'fork': key = (msg.header['__meta__']['response_address'], self.ocomm.curr_comm_index % len(self.ocomm)) if key in self.response_drivers: response_driver = self.response_drivers[key] else: response_kwargs = self.response_kwargs.copy() response_kwargs.update( self.ocomm.get_response_comm_kwargs) drv_args = [msg.header['__meta__']['response_address'], msg.header['__meta__']['request_id']] drv_kwargs = dict( request_name=self.name, inputs=[response_kwargs], outputs=[{'commtype': msg.header["commtype"]}]) self.debug("Creating response comm: address = %s, request_id = %s", msg.header['__meta__']['response_address'], msg.header['__meta__']['request_id']) try: response_driver = RPCResponseDriver( *drv_args, **drv_kwargs) self.response_drivers[key] = response_driver response_driver.start() self.debug("Started response comm: address = %s, request_id = %s", msg.header['__meta__']['response_address'], msg.header['__meta__']['request_id']) except BaseException: # pragma: debug self.exception("Could not create/start response driver.") return False # Send response address in header kwargs.setdefault('header_kwargs', {}) kwargs['header_kwargs'].setdefault('__meta__', {}) kwargs['header_kwargs']['__meta__'].setdefault( 'response_address', response_driver.response_address) kwargs['header_kwargs']['__meta__'].setdefault( 'request_id', msg.header['__meta__']['request_id']) kwargs['header_kwargs']['__meta__'].setdefault( 'model', msg.header['__meta__'].get('model', '')) return super(RPCRequestDriver, self).send_message(msg, **kwargs)
[docs] def run_loop(self): r"""Run the driver. Continue looping over messages until there are not any left or the communication channel is closed. """ super(RPCRequestDriver, self).run_loop() if not self.was_break: self.prune_response_drivers()
[docs] def prune_response_drivers(self): r"""Promote errors from response drivers.""" with self.lock: for x in self.response_drivers.values(): self.errors += x.errors