Source code for yggdrasil.runner

"""This module provides tools for running models using yggdrasil."""
import sys
import os
import time
import copy
import signal
import atexit
from pprint import pformat
from itertools import chain
import socket
from collections import OrderedDict
from yggdrasil.tools import YggClass
from yggdrasil.config import ygg_cfg, cfg_environment, temp_config
from yggdrasil import platform, yamlfile
from yggdrasil.drivers import create_driver
from yggdrasil.components import import_component
from yggdrasil.multitasking import MPI
from yggdrasil.drivers.DuplicatedModelDriver import DuplicatedModelDriver
from yggdrasil.drivers.ModelDriver import ModelDriver


COLOR_TRACE = '\033[30;43;22m'
COLOR_NORMAL = '\033[0m'


class IntegrationError(BaseException):
    r"""Error raised when there is an error in an integration."""
    pass

class YggFunction(YggClass):
    r"""This class wraps function-like behavior around a model.

    Args:
        model_yaml (str, list): Full path to one or more YAML specification
            files containing information defining a partial integration. If
            service_address is set, this should be the name of a service
            registered with the service manager running at the provided
            address.
        service_address (str, optional): Address for service manager that is
            capable of running the specified integration. Defaults to None
            and is ignored.
        **kwargs: Additional keyword arguments are passed to the YggRunner
            constructor.

    Attributes:
        outputs (dict): Input channels providing access to model output.
        inputs (dict): Output channels providing access to model input.
        runner (YggRunner): Runner for model.

    """

    def __init__(self, model_yaml, service_address=None, **kwargs):
        import uuid
        from yggdrasil.languages.Python.YggInterface import (
            YggInput, YggOutput, YggRpcClient)
        super(YggFunction, self).__init__()
        # Create and start runner in another process
        self.dummy_name = 'func' + str(uuid.uuid4()).split('-')[0]
        kwargs['complete_partial'] = self.dummy_name
        if service_address:
            # Temporary YAML describing the service
            contents = (f'service:\n'
                        f'    name: {model_yaml}\n'
                        f'    address: {service_address}\n')
            model_yaml = os.path.join(os.getcwd(), self.dummy_name + '.yml')
            with open(model_yaml, 'w') as fd:
                fd.write(contents)
        self.runner = YggRunner(model_yaml, **kwargs)
        # Start the drivers
        self.runner.run()
        self.model_driver = self.runner.modeldrivers[self.dummy_name]
        for k in self.runner.modeldrivers.keys():
            if k != self.dummy_name:
                self.__name__ = k
                break
        self.debug("run started")
        # Create input/output channels
        self.inputs = {}
        self.outputs = {}
        # import zmq; ctx = zmq.Context()
        self.old_environ = os.environ.copy()
        for drv in self.model_driver['input_drivers']:
            for env in drv['instance'].model_env.values():
                os.environ.update(env)
            channel_name = drv['instance'].ocomm.name
            var_name = drv['name'].split('function_')[-1]
            self.outputs[var_name] = drv.copy()
            self.outputs[var_name]['comm'] = YggInput(
                channel_name, no_suffix=True)  # context=ctx)
            if 'vars' in drv['inputs'][0]:
                self.outputs[var_name]['vars'] = drv['inputs'][0]['vars']
            else:
                self.outputs[var_name]['vars'] = [var_name]
        for drv in self.model_driver['output_drivers']:
            for env in drv['instance'].model_env.values():
                os.environ.update(env)
            channel_name = drv['instance'].icomm.name
            var_name = drv['name'].split('function_')[-1]
            self.inputs[var_name] = drv.copy()
            if drv['instance']._connection_type == 'rpc_request':
                self.inputs[var_name]['comm'] = YggRpcClient(
                    channel_name, no_suffix=True)
                self.outputs[var_name] = drv.copy()
                self.outputs[var_name]['comm'] = self.inputs[var_name]['comm']
                if drv['outputs'][0].get('server_replaces', False):
                    srv = drv['outputs'][0]['server_replaces']
                    self.inputs[var_name]['vars'] = srv['input']['vars']
                    self.outputs[var_name]['vars'] = srv['output']['vars']
            else:
                self.inputs[var_name]['comm'] = YggOutput(
                    channel_name, no_suffix=True)  # context=ctx)
                if 'vars' in drv['outputs'][0]:
                    self.inputs[var_name]['vars'] = drv['outputs'][0]['vars']
                else:
                    self.inputs[var_name]['vars'] = [var_name]
        self.debug('inputs: %s, outputs: %s',
                   list(self.inputs.keys()), list(self.outputs.keys()))
        self._stop_called = False
        atexit.register(self.stop)
        # Ensure that vars are strings
        for k, v in chain(self.inputs.items(), self.outputs.items()):
            v_vars = []
            for iv in v['vars']:
                if isinstance(iv, dict):
                    if not iv.get('is_length_var', False):
                        v_vars.append(iv['name'])
                else:
                    v_vars.append(iv)
            v['vars'] = v_vars
        # Get arguments
        self.arguments = []
        for k, v in self.inputs.items():
            self.arguments += v['vars']
        self.returns = []
        for k, v in self.outputs.items():
            self.returns += v['vars']
        self.debug("arguments: %s, returns: %s",
                   self.arguments, self.returns)
        self.runner.pause()
        if service_address:
            os.remove(model_yaml)

    # def widget_function(self, *args, **kwargs):
    #     # import matplotlib.pyplot as plt
    #     # ncols = min(3, len(arguments))
    #     # nrows = int(ceil(float(len(arguments))/float(ncols)))
    #     # plt.show()
    #     out = self(*args, **kwargs)
    #     return out

    # def widget(self, *args, **kwargs):
    #     from ipywidgets import interact_manual
    #     return interact_manual(self.widget_function, *args, **kwargs)

    def __call__(self, *args, **kwargs):
        r"""Call the model as a function by sending variables.

        Args:
            *args: Any positional arguments are expected to be input
                variables in the correct order.
            **kwargs: Any keyword arguments are expected to be named input
                variables for the model.

        Raises:
            RuntimeError: If an input argument is missing.
            RuntimeError: If sending an input argument to a model fails.
            RuntimeError: If receiving an output value from a model fails.

        Returns:
            dict: Returned values for each return variable.

        """
        self.runner.resume()
        # Check for arguments
        for a, arg in zip(self.arguments, args):
            assert a not in kwargs
            kwargs[a] = arg
        for a in self.arguments:
            if a not in kwargs:  # pragma: debug
                raise RuntimeError("Required argument %s not provided." % a)
        # Send
        for k, v in self.inputs.items():
            flag = v['comm'].send([kwargs[a] for a in v['vars']])
            if not flag:  # pragma: debug
                raise RuntimeError("Failed to send %s" % k)
        # Receive
        out = {}
        for k, v in self.outputs.items():
            flag, data = v['comm'].recv(timeout=60.0)
            if not flag:  # pragma: debug
                raise RuntimeError("Failed to receive variable %s" % v)
            ivars = v['vars']
            if isinstance(data, (list, tuple)):
                assert len(data) == len(ivars)
                for a, d in zip(ivars, data):
                    out[a] = d
            else:
                assert len(ivars) == 1
                out[ivars[0]] = data
        self.runner.pause()
        return out

    def stop(self):
        r"""Stop the model(s) from running."""
        self.runner.resume()
        if self._stop_called:
            return
        self._stop_called = True
        for x in self.inputs.values():
            x['comm'].send_eof()
        self.model_driver['instance'].set_break_flag()
        self.runner.waitModels(timeout=10)
        for x in self.inputs.values():
            x['comm'].close()
        for x in self.outputs.values():
            x['comm'].close()
        self.runner.terminate()
        self.runner.atexit()
        os.environ.clear()
        os.environ.update(self.old_environ)

    def model_info(self):
        r"""Display information about the wrapped model(s)."""
        print("Models: %s\nInputs:\n%s\nOutputs:\n%s\n"
              % (', '.join([x['name'] for x
                            in self.runner.modeldrivers.values()
                            if x['name'] != self.dummy_name]),
                 '\n'.join(['\t%s (vars=%s)' % (k, v['vars'])
                            for k, v in self.inputs.items()]),
                 '\n'.join(['\t%s (vars=%s)' % (k, v['vars'])
                            for k, v in self.outputs.items()])))

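# Usage sketch (illustrative, not part of the module): wrapping a model
# defined in a YAML specification as a Python function. The YAML path and
# the variable names ('x', 'y') are hypothetical and depend on the
# integration being wrapped.
#
#     from yggdrasil.runner import YggFunction
#     f = YggFunction('model.yml')
#     f.model_info()      # list wrapped models and their input/output vars
#     out = f(x=1.0)      # send inputs, receive e.g. {'y': ...}
#     f.stop()            # shut down the wrapped integration

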
class YggRunner(YggClass):
    r"""This class handles the orchestration of starting the model and
    IO drivers, monitoring their progress, and cleaning up on exit.

    Arguments:
        modelYmls (list): List of paths to yaml files specifying the models
            that should be run.
        namespace (str, optional): Name that should be used to uniquely
            identify any RMQ exchange. Defaults to the value in the config
            file.
        host (str, optional): Name of the host that the models will be
            launched from. Defaults to None.
        rank (int, optional): Rank of this set of models if run in parallel.
            Defaults to 0.
        ygg_debug_level (str, optional): Level for Ygg debug messages.
            Defaults to environment variable 'YGG_DEBUG'.
        rmq_debug_level (str, optional): Level for RabbitMQ debug messages.
            Defaults to environment variable 'RMQ_DEBUG'.
        ygg_debug_prefix (str, optional): Prefix for Ygg debug messages.
            Defaults to namespace.
        as_service (bool, optional): If True, the integration is running as
            a service and complete_partial is set to True. Defaults to
            False.
        complete_partial (bool, optional): If True, unpaired input/output
            channels are allowed and reserved for use (e.g. for calling the
            model as a function). Defaults to False.
        partial_commtype (dict, optional): Communicator kwargs that should
            be used for the connections to the unpaired channels when
            complete_partial is True. Defaults to None and will be ignored.
        yaml_param (dict, optional): Parameters that should be used in
            mustache formatting of YAML files. Defaults to None and is
            ignored.
        validate (bool, optional): If True, the validation scripts for each
            model (if present) will be run after the integration finishes
            running. Defaults to False.
        with_debugger (str, optional): Tool (and any flags for the tool)
            that should be used to run models.
        disable_python_c_api (bool, optional): If True, the Python C API
            will be disabled. Defaults to False.
        with_asan (bool, optional): Compile and run all models with the
            address sanitizer. Defaults to False.

    Attributes:
        namespace (str): Name that should be used to uniquely identify any
            RMQ exchange.
        host (str): Name of the host that the models will be launched from.
        rank (int): Rank of this set of models if run in parallel.
        modeldrivers (dict): Model drivers associated with this run.
        connectiondrivers (dict): Connection drivers for this run.
        interrupt_time (float): Time of last interrupt signal.
        error_flag (bool): True if one or more models raises an error.

    .. todo:: namespace, host, and rank do not seem strictly necessary.

    """

    def __init__(self, modelYmls, namespace=None, host=None, rank=0,
                 ygg_debug_level=None, rmq_debug_level=None,
                 ygg_debug_prefix=None, connection_task_method='thread',
                 as_service=False, complete_partial=False,
                 partial_commtype=None, production_run=False,
                 mpi_tag_start=None, yaml_param=None, validate=False,
                 with_debugger=None, disable_python_c_api=False,
                 with_asan=False):
        kwargs_models = {'with_debugger': with_debugger,
                         'disable_python_c_api': disable_python_c_api,
                         'with_asan': with_asan}
        self.mpi_comm = None
        name = 'runner'
        if MPI is not None:
            comm = MPI.COMM_WORLD
            if comm.Get_size() > 1:
                self.mpi_comm = comm
                rank = comm.Get_rank()
                name += str(rank)
        super(YggRunner, self).__init__(name)
        if namespace is None:
            namespace = ygg_cfg.get('rmq', 'namespace', False)
        if not namespace:  # pragma: debug
            raise Exception('rmq:namespace not set in config file')
        if as_service:
            complete_partial = True
        self.namespace = namespace
        self.host = host
        self.rank = rank
        self.connection_task_method = connection_task_method
        self.base_dup = {}
        self.modelcopies = {}
        self.modeldrivers = {}
        self.connectiondrivers = {}
        self.interrupt_time = 0
        self._old_handlers = {}
        self.production_run = production_run
        self.error_flag = False
        self.complete_partial = complete_partial
        self.partial_commtype = partial_commtype
        self.validate = validate
        self.debug("Running in %s with path %s namespace %s rank %d",
                   os.getcwd(), sys.path, namespace, rank)
        # Update environment based on config
        cfg_environment()
        # Parse yamls
        self.mpi_tag_start = mpi_tag_start
        if self.mpi_comm and (self.rank > 0):
            pass
        else:
            self.drivers = yamlfile.parse_yaml(
                modelYmls, complete_partial=complete_partial,
                partial_commtype=partial_commtype,
                yaml_param=yaml_param)
            self.connectiondrivers = self.drivers['connection']
            self.modeldrivers = self.drivers['model']
            for k, v in kwargs_models.items():
                if not v:
                    continue
                for x in self.modeldrivers.values():
                    x[k] = v
            for x in self.modeldrivers.values():
                if x['driver'] == 'DummyModelDriver':
                    x['runner'] = self
                    if as_service:
                        for io in x['output_drivers']:
                            for comm in io['inputs']:
                                comm['for_service'] = True
                        for io in x['input_drivers']:
                            for comm in io['outputs']:
                                comm['for_service'] = True

    def pprint(self, *args):
        r"""Print with color."""
        s = ''.join(str(i) for i in args)
        print((COLOR_TRACE + '{}' + COLOR_NORMAL).format(s))

    def atexit(self, *args, **kwargs):
        r"""At exit ensure that the runner has stopped and cleaned up."""
        self.debug('')
        self.reset_signal_handler()
        self.closeChannels()
        self.cleanup()

    def signal_handler(self, sig, frame):
        r"""Terminate all drivers on interrupt."""
        self.debug("Interrupt with signal %d", sig)
        now = time.perf_counter()
        elapsed = now - self.interrupt_time
        self.debug('Elapsed time since last interrupt: %d s', elapsed)
        self.interrupt_time = now
        self.pprint(' ')
        self.pprint(80 * '*')
        if elapsed < 5:
            self.pprint('* %76s *' % 'Interrupted twice within 5 seconds: shutting down')
            self.pprint(80 * '*')
            self.debug("Terminating models and closing all channels")
            self.terminate()
            self.pprint(80 * '*')
            return 1
        else:
            self.pprint('* %76s *' % 'Interrupted: Displaying channel summary')
            self.pprint('* %76s *' % 'interrupt again (within 5 seconds) to exit')
            self.pprint(80 * '*')
            self.printStatus()
            self.pprint(80 * '*')
        self.debug('%d returns', sig)

    def _swap_handler(self, signum, signal_handler):
        self._old_handlers[signum] = signal.getsignal(signum)
        signal.signal(signum, signal_handler)
        if not platform._is_win:
            signal.siginterrupt(signum, False)

    def set_signal_handler(self, signal_handler=None):
        r"""Set the signal handler.

        Args:
            signal_handler (function, optional): Function that should handle
                received SIGINT and SIGTERM signals. Defaults to
                self.signal_handler.

        """
        if signal_handler is None:
            signal_handler = self.signal_handler
        if signal_handler:
            self._swap_handler(signal.SIGINT, signal_handler)
            if not platform._is_win:
                self._swap_handler(signal.SIGTERM, signal_handler)
            else:  # pragma: windows
                self._swap_handler(signal.SIGBREAK, signal_handler)

    def reset_signal_handler(self):
        r"""Reset signal handlers to old ones."""
        for k, v in self._old_handlers.items():
            signal.signal(k, v)

    def run(self, signal_handler=None, timer=None, t0=None):
        r"""Run all of the models and wait for them to exit.

        Args:
            signal_handler (function, optional): Function that should be
                used as a signal handler. Defaults to None and is set by
                set_signal_handler.
            timer (function, optional): Function that should be called to
                get intermediate timing statistics. Defaults to time.time
                if not provided.
            t0 (float, optional): Zero point for timing statistics. Is set
                using the provided timer if not provided.

        Returns:
            dict: Intermediate times from the run.

        """
        with temp_config(production_run=self.production_run):
            if timer is None:
                timer = time.time
            if t0 is None:
                t0 = timer()
            times = OrderedDict()
            times['init'] = timer()
            self.loadDrivers()
            times['load drivers'] = timer()
            self.startDrivers()
            times['start drivers'] = timer()
            self.set_signal_handler(signal_handler)
            if not self.complete_partial:
                self.waitModels()
                times['run models'] = timer()
                self.atexit()
                times['at exit'] = timer()
            tprev = t0
            for k, t in times.items():
                self.info('%20s\t%f', k, t - tprev)
                tprev = t
            self.info(40 * '=')
            self.info('%20s\t%f', "Total", tprev - t0)
            if self.error_flag:
                raise IntegrationError("Error running the integration.")
            if self.validate:
                for v in self.modeldrivers.values():
                    v['instance'].run_validation()
        return times

    @property
    def all_drivers(self):
        r"""iterator: For all drivers."""
        return chain(self.connectiondrivers.values(),
                     self.modeldrivers.values())

    def io_drivers(self):
        r"""Return the input and output drivers for all models.

        Returns:
            iterator: Access to list of I/O drivers.

        """
        return self.connectiondrivers.values()

    def create_driver(self, yml, **kwargs):
        r"""Create a driver instance from the yaml information.

        Args:
            yml (yaml): Yaml object containing driver information.

        Returns:
            object: An instance of the specified driver.

        """
        self.debug('Creating %s, a %s', yml['name'], yml['driver'])
        curpath = os.getcwd()
        if 'working_dir' in yml:
            os.chdir(yml['working_dir'])
        try:
            if (yml.get('copies', 1) > 1) and ('copy_index' not in yml):
                instance = DuplicatedModelDriver(
                    yml, namespace=self.namespace, rank=self.rank,
                    duplicates=yml.pop('duplicates', None), **kwargs)
            else:
                kwargs = dict(yml, **kwargs)
                instance = create_driver(yml=yml, namespace=self.namespace,
                                         rank=self.rank, **kwargs)
            yml['instance'] = instance
        finally:
            os.chdir(curpath)
        return instance

    def get_models(self, name, rank=None):
        r"""Get the set of drivers referenced by a model name.

        Args:
            name (str, list): Name of model(s).
            rank (int, optional): If provided, only models that will run on
                MPI processes with this rank will be returned. Defaults to
                None and is ignored.

        Returns:
            list: Set of drivers for a model.

        """
        if isinstance(name, list):
            models = []
            for x in name:
                models += self.get_models(x, rank=rank)
        elif name in self.modelcopies:
            models = [self.modeldrivers[cpy]
                      for cpy in self.modelcopies[name]]
        elif name in self.modeldrivers:
            models = [self.modeldrivers[name]]
        else:
            models = [self.modeldrivers[
                DuplicatedModelDriver.get_base_name(name)]]
            assert models[0].get('copies', 0) > 1
        if rank is not None:  # pragma: debug
            # models = [x for x in models if (x['mpi_rank'] == rank)]
            raise NotImplementedError
        return models

    def bridge_mpi_connections(self, yml):
        r"""Bridge connections over MPI processes."""
        from yggdrasil.communication.MPIComm import MPIComm
        io_map = {'inputs': 'outputs', 'outputs': 'inputs'}
        models = {}
        for io in io_map.keys():
            models[io[:-1]] = [x['name'] for x in self.get_models(
                [x.get('partner_model', None) for x in yml[io]
                 if x.get('partner_model', None)])]
        for io, io_opp in io_map.items():
            for x in yml[io]:
                model = x.get('partner_model', None)
                if not model:
                    continue
                rank_map = {}
                for m in self.get_models(model):
                    rank_map.setdefault(m['mpi_rank'], [])
                    rank_map[m['mpi_rank']].append(m)
                if not any(rank > 0 for rank in rank_map.keys()):
                    continue
                if 'models' not in yml:
                    yml['models'] = models
                comms = []
                for rank in rank_map.keys():
                    x_copy = dict(copy.deepcopy(x),
                                  partner_copies=len(rank_map[rank]))
                    if rank == 0:
                        icomm = x_copy
                    else:
                        icomm = dict(
                            commtype='mpi', daemon=True, ranks=[rank],
                            mpi_index=len(self._mpi_comms),
                            mpi_direction=io_opp,
                            mpi_stride=1,
                            mpi_driver={
                                io_opp: [{'commtype': 'mpi', 'ranks': [0],
                                          'daemon': True}],
                                io: [x_copy],
                                'driver': yml['driver'],
                                'name': ('%s_mpi%s_%s'
                                         % (yml['name'], rank, io)),
                                'models': {
                                    io_opp[:-1]: models[io_opp[:-1]],
                                    io[:-1]: [m['name']
                                              for m in rank_map[rank]]}})
                        if yml['driver'].startswith('RPC'):
                            icomm['mpi_stride'] += MPIComm._max_response
                        self._mpi_comms.append(icomm)
                        for m in rank_map[rank]:
                            drv_key = 'mpi_%s_drivers' % io_opp[:-1]
                            m.setdefault(drv_key, [])
                            m[drv_key].append(icomm['mpi_driver']['name'])
                    comms.append(icomm)
                if len(comms) == 1:
                    x.update(comms[0])
                    self._mpi_comms[comms[0]['mpi_index']] = x
                else:
                    # TODO: Move to connection level?
                    x.clear()
                    x['commtype'] = comms
                    if yml['driver'].startswith('RPC'):
                        x['pattern'] = 'cycle'

    def create_connection_driver(self, yml):
        r"""Create a connection driver instance from the yaml information.

        Args:
            yml (yaml): Yaml object containing driver information.

        Returns:
            object: An instance of the specified driver.

        """
        yml['task_method'] = self.connection_task_method
        drv = self.create_driver(yml)
        # Transfer connection addresses to model via env
        # TODO: Change to server that tracks connections
        for model, env in drv.model_env.items():
            env_key = 'env'
            if ((model not in self.modelcopies)
                    and (model not in self.modeldrivers)):
                env_key = 'env_%s' % model
            for x in self.get_models(model):
                x.setdefault(env_key, {})
                x[env_key].update(env)
        return drv

    def distribute_mpi(self):
        r"""Distribute models between MPI processes."""
        size = self.mpi_comm.Get_size()
        if self.rank == 0:
            from yggdrasil.communication.MPIComm import MPIComm
            self.expand_duplicates()
            # Set the rank and index for each model
            for i, v in enumerate(self.modeldrivers.values()):
                v['mpi_rank'] = (i + 1) % size
                v['model_index'] = i
                v['mpi_tag_start'] = self.mpi_tag_start
            # Split the connections bridging MPI processes
            self.debug("Splitting connection drivers over MPI")
            self.all_connectiondrivers = self.connectiondrivers
            self._mpi_comms = []
            for driver in self.connectiondrivers.values():
                self.bridge_mpi_connections(driver)
            tag_start = len(ModelDriver._mpi_tags) * len(self.modeldrivers) * 5
            if self.mpi_tag_start is not None:
                tag_start += self.mpi_tag_start
            tag_stride = sum([x.pop('mpi_stride') for x in self._mpi_comms])
            connections = [[] for _ in range(size)]
            for x in self._mpi_comms:
                x['tag_start'] = (tag_start
                                  + x.pop('mpi_index') * MPIComm._spacer_tags)
                x['tag_stride'] = tag_stride
                io = x.pop('mpi_direction')
                drv = x.pop('mpi_driver')
                drv[io][0]['tag_start'] = x['tag_start']
                drv[io][0]['tag_stride'] = x['tag_stride']
                connections[x['ranks'][0]].append((drv['name'], drv))
            max_len = len(max(connections, key=len))
            for x in connections:
                while len(x) < max_len:
                    x.append(None)
            # Sort models
            self.all_modeldrivers = self.modeldrivers
            models = [[] for _ in range(size)]
            for i, (k, v) in enumerate(self.modeldrivers.items()):
                x_cp = copy.deepcopy(v)
                for k2 in ['input_drivers', 'output_drivers', 'mpi_rank']:
                    x_cp.pop(k2, None)
                for k2 in ['input_drivers', 'output_drivers']:
                    x_cp[k2] = x_cp.get('mpi_%s' % k2, [])
                # Skew models away from root process so that
                # connection threading might not share process
                models[v['mpi_rank']].append((k, x_cp))
            max_len = len(max(models, key=len))
            for x in models:
                while len(x) < max_len:
                    x.append(None)
        else:
            models = None
            connections = None
        self.modeldrivers = dict(
            [x for x in self.mpi_comm.scatter(models, root=0)
             if (x is not None)])
        self.connectiondrivers = dict(
            [x for x in self.mpi_comm.scatter(connections, root=0)
             if (x is not None)])
        self.modelcopies = self.mpi_comm.bcast(self.modelcopies, root=0)
        self.info("Models on MPI process %d: %s",
                  self.rank, list(self.modeldrivers.keys()))
        # Add dummy drivers on root process to monitor remote ones
        # and re-group copies into duplicate model w/ duplicate models
        # before non-duplicate to allow them to start before starting
        # local models
        if self.rank == 0:
            for i, (k, v) in enumerate(self.all_modeldrivers.items()):
                if k not in self.modeldrivers:
                    v['partner_driver'] = v['driver']
                    v['language'] = 'mpi'
                    v['driver'] = 'MPIPartnerModel'
                    self.modeldrivers[k] = v
            self.connectiondrivers = self.all_connectiondrivers
        else:
            for v in self.modeldrivers.values():
                for k in ['input_drivers', 'output_drivers']:
                    v[k] = [self.connectiondrivers[x]
                            for x in v.get(k, [])]
        self.reduce_duplicates()

    def expand_duplicates(self):
        r"""Expand model copies so they can be split across MPI processes."""
        self.debug("Expanding duplicated models")
        remove_dup = []
        add_dup = {}
        for k, v in self.modeldrivers.items():
            if v.get('copies', 1) > 1:
                self.modelcopies[v['name']] = []
                for x in DuplicatedModelDriver.get_yaml_copies(v):
                    add_dup[x['name']] = x
                    self.modelcopies[v['name']].append(x['name'])
                remove_dup.append(k)
        for k in remove_dup:
            self.base_dup[k] = self.modeldrivers.pop(k)
        self.modeldrivers.update(add_dup)

    def reduce_duplicates(self):
        r"""Join model duplicates after they were split between processes."""
        self.debug("Reducing duplicated models")
        for k in list(self.modelcopies.keys()):
            duplicates = [self.modeldrivers.pop(cpy)
                          for cpy in self.modelcopies.pop(k)
                          if cpy in self.modeldrivers]
            if duplicates:
                if k in self.base_dup:
                    base = self.base_dup[k]
                else:
                    base = dict(copy.deepcopy(duplicates[0]), name=k,
                                input_drivers=duplicates[0].get(
                                    'input_drivers', []),
                                output_drivers=duplicates[0].get(
                                    'output_drivers', []))
                    base.pop('copy_index', None)
                for x in duplicates:
                    for k2 in ['input_drivers', 'output_drivers']:
                        x[k2] = base.get(k2, [])
                base['duplicates'] = duplicates
                self.modeldrivers[k] = base

    def loadDrivers(self):
        r"""Load all of the necessary drivers, doing the IO drivers first
        and adding IO driver environmental variables back to the models."""
        self.debug('')
        driver = dict(name='name')
        try:
            # Preparse model drivers first so that the input/output
            # channels are updated for wrapped functions
            self.debug("Preparsing model functions")
            for driver in self.modeldrivers.values():
                driver_cls = import_component('model', driver['driver'],
                                              without_schema=True)
                driver_cls.preparse_function(driver)
            if self.mpi_comm:
                self.distribute_mpi()
            # Create I/O drivers
            self.debug("Loading connection drivers")
            for driver in self.connectiondrivers.values():
                driver['task_method'] = self.connection_task_method
                self.create_connection_driver(driver)
            # Create model drivers
            self.debug("Loading model drivers")
            for driver in self.modeldrivers.values():
                self.create_driver(driver)
                self.debug("Model %s, env: %s",
                           driver['name'], pformat(driver['instance'].env))
        except BaseException as e:  # pragma: debug
            self.error("%s could not be created: %s", driver['name'], e)
            self.terminate()
            raise

    def start_server(self, name):
        r"""Start a server driver."""
        if self.mpi_comm and (self.rank != 0):
            return
        # This is required if modelcopies are not joined before drivers
        # are started
        # if name in self.modelcopies:
        #     assert name not in self.modeldrivers
        #     for cpy in self.modelcopies[name]:
        #         self.start_server(cpy)
        #     return
        x = self.modeldrivers[name]['instance']
        if not x.was_started:
            self.debug("Starting server '%s' before client", x.name)
            x.start()

    def stop_server(self, name):
        r"""Stop a server driver."""
        # This is required if modelcopies are not joined before drivers
        # are started
        # if name in self.modelcopies:
        #     assert name not in self.modeldrivers
        #     for cpy in self.modelcopies[name]:
        #         self.stop_server(cpy)
        #     return
        x = self.modeldrivers[name]['instance']
        x.stop()

    def startDrivers(self):
        r"""Start drivers, starting with the IO drivers."""
        if not self.mpi_comm or (self.rank == 0):
            assert not self.modelcopies
        self.info('Starting I/O drivers and models on system '
                  + '{} in namespace {} with rank {}'.format(
                      self.host, self.namespace, self.rank))
        driver = dict(name='name')
        try:
            # Start connections
            for driver in self.io_drivers():
                self.debug("Starting driver %s", driver['name'])
                d = driver['instance']
                if not d.was_started:
                    d.start()
            # Ensure connections in loop
            for driver in self.io_drivers():
                self.debug("Checking driver %s", driver['name'])
                d = driver['instance']
                d.wait_for_loop()
                assert d.was_loop
                assert not d.errors
            # Start models
            for driver in self.modeldrivers.values():
                self.debug("Starting driver %s", driver['name'])
                d = driver['instance']
                for n2 in driver.get('client_of', []):
                    self.start_server(n2)
                if not d.was_started:
                    d.start()
        except BaseException as e:  # pragma: debug
            self.error("%s did not start: %s(%s)",
                       driver['name'], type(e), e)
            self.terminate()
            raise
        if self.mpi_comm:
            self.mpi_comm.barrier()
        self.debug('ALL DRIVERS STARTED')

    @property
    def is_alive(self):
        r"""bool: True if all of the models are still running, False
        otherwise."""
        for drv in self.modeldrivers.values():
            if (not drv['instance'].is_alive()) or drv['instance'].errors:
                return False
        return True

    def waitModels(self, timeout=False):
        r"""Wait for all model drivers to finish. When a model finishes,
        join the thread and perform exits for associated IO drivers."""
        self.debug('')
        running = [d for d in self.modeldrivers.values()]
        dead = []
        Tout = self.start_timeout(t=timeout, key_suffix='.waitModels')
        while ((len(running) > 0) and (not self.error_flag)
               and (not Tout.is_out)):
            for drv in running:
                d = drv['instance']
                if d.errors:  # pragma: debug
                    self.error('Error in model %s', drv['name'])
                    self.error_flag = True
                    break
                elif d.io_errors:  # pragma: debug
                    self.error('Error in input/output driver for model %s'
                               % drv['name'])
                    self.error_flag = True
                    break
                d.join(1)
                if not d.is_alive():
                    if not d.errors:
                        self.info("%s finished running.", drv['name'])
                        # self.do_model_exits(drv)
                        # self.debug("%s completed model exits.", drv['name'])
                        self.do_client_exits(drv)
                        self.debug("%s completed client exits.", drv['name'])
                        running.remove(drv)
                        self.info("%s finished exiting.", drv['name'])
                else:
                    self.debug('%s still running', drv['name'])
            dead = []
            for drv in self.all_drivers:
                d = drv['instance']
                d.join(0.1)
                if not d.is_alive():
                    dead.append(drv['name'])
        self.stop_timeout(key_suffix='.waitModels')
        for d in self.modeldrivers.values():
            if d['instance'].errors:
                self.error_flag = True
        if not self.error_flag:
            self.info('All models completed')
        else:
            self.error('One or more models generated errors.')
            self.printStatus()
            self.terminate()
        if self.mpi_comm:
            allcode = self.mpi_comm.allreduce(self.error_flag, op=MPI.SUM)
            if not self.error_flag:
                self.error_flag = allcode
        self.debug('Returning')

    # def do_model_exits(self, model):
    #     r"""Perform exits for IO drivers associated with a model.
    #
    #     Args:
    #         model (dict): Dictionary of model parameters including any
    #             associated IO drivers.
    #
    #     """
    #     for drv in model['input_drivers']:
    #         # if model['name'] in drv['models']:
    #         #     drv['models'].remove(model['name'])
    #         if not drv['instance'].is_alive():
    #             continue
    #         # if (len(drv['models']) == 0):
    #         self.debug('on_model_exit %s', drv['name'])
    #         drv['instance'].on_model_exit('output', model['name'])
    #     for drv in model['output_drivers']:
    #         # if model['name'] in drv['models']:
    #         #     drv['models'].remove(model['name'])
    #         if not drv['instance'].is_alive():
    #             continue
    #         # if (len(drv['models']) == 0):
    #         self.debug('on_model_exit %s', drv['name'])
    #         drv['instance'].on_model_exit('input', model['name'])

    def do_client_exits(self, model):
        r"""Perform exits for IO drivers associated with a client model.

        Args:
            model (dict): Dictionary of model parameters including any
                associated IO drivers.

        """
        if self.mpi_comm and (self.rank != 0):
            return
        # TODO: Exit upstream models that no longer have any open
        # output connections when a connection is closed.
        for srv_name in model.get('client_of', []):
            iod = self.connectiondrivers[srv_name]
            iod['instance'].remove_model('input', model['name'])
            if iod['instance'].nclients == 0:
                self.stop_server(srv_name)

    def pause(self):
        r"""Pause all drivers."""
        self.debug('')
        for driver in self.all_drivers:
            if 'instance' in driver:
                driver['instance'].pause()

    def resume(self):
        r"""Resume all paused drivers."""
        self.debug('')
        for driver in self.all_drivers:
            if 'instance' in driver:
                driver['instance'].resume()

    def terminate(self):
        r"""Immediately stop all drivers, beginning with IO drivers."""
        self.debug('')
        self.resume()
        for driver in self.all_drivers:
            if 'instance' in driver:
                self.debug('Stop %s', driver['name'])
                driver['instance'].terminate()
                # Terminate should ensure instance not alive
                assert not driver['instance'].is_alive()
        self.debug('Returning')

    def cleanup(self):
        r"""Perform cleanup operations for all drivers."""
        self.debug('')
        for driver in self.all_drivers:
            if 'instance' in driver:
                driver['instance'].cleanup()
        # self.inputdrivers = {}
        # self.outputdrivers = {}
        # self.modeldrivers = {}

    def printStatus(self, return_str=False):
        r"""Print the status of all drivers, starting with the IO drivers."""
        self.debug('')
        out = []
        for driver in self.all_drivers:
            if 'instance' in driver:
                out.append(
                    driver['instance'].printStatus(return_str=return_str))
        if return_str:
            return '\n'.join(out)

    def closeChannels(self, force_stop=False):
        r"""Stop IO drivers and join the threads.

        Args:
            force_stop (bool, optional): If True, the terminate method is
                used to stop the drivers. Otherwise, the stop method is
                used. The stop method will try to exit gracefully while
                terminate will exit as quickly as possible. Defaults to
                False.

        """
        self.debug('')
        drivers = [i for i in self.io_drivers()]
        for drv in drivers:
            if 'instance' in drv:
                driver = drv['instance']
                if driver.is_alive():  # pragma: debug
                    self.debug("Stopping %s", drv['name'])
                    if force_stop or self.error_flag:
                        driver.terminate()
                    else:
                        driver.stop()
                    self.debug("Stop(%s) returns", drv['name'])
        self.debug('Channel Stops DONE')
        for drv in drivers:
            if 'instance' in drv:
                driver = drv['instance']
                assert not driver.is_alive()
        self.debug('Returning')

def get_runner(models, **kwargs):
    r"""Get runner for a set of models, getting run information from the
    environment.

    Args:
        models (list): List of yaml files containing information on the
            models that should be run.
        **kwargs: Additional keyword arguments are passed to YggRunner.

    Returns:
        YggRunner: Runner for the provided models.

    Raises:
        Exception: If config option 'namespace' in 'rmq' section not set.

    """
    # Get environment variables
    rank = os.environ.get('PARALLEL_SEQ', '0')
    host = socket.gethostname()
    os.environ['YGG_RANK'] = rank
    os.environ['YGG_HOST'] = host
    rank = int(rank)
    kwargs.update(rank=rank, host=host)
    # Run
    yggRunner = YggRunner(models, **kwargs)
    return yggRunner

def run(*args, **kwargs):
    r"""Run an integration.

    Args:
        *args: Arguments are passed to get_runner.
        **kwargs: Keyword arguments are passed to get_runner, with the
            exception of run_kwargs which is passed to the runner's run
            method.

    """
    run_kwargs = kwargs.pop('run_kwargs', {})
    yggRunner = get_runner(*args, **kwargs)
    yggRunner.run(**run_kwargs)
    yggRunner.debug("runner returns, exiting")
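

# Usage sketch (illustrative, not part of the module): running a complete
# integration from a set of YAML specification files. The file names below
# are hypothetical.
#
#     from yggdrasil import runner
#     runner.run(['modelA.yml', 'modelB.yml'])
#
# or, keeping a handle on the runner for finer control:
#
#     r = runner.get_runner(['modelA.yml', 'modelB.yml'])
#     times = r.run()   # returns an OrderedDict of timing checkpoints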