Source code for cis_interface.runner

"""This module provides tools for running models using cis_interface."""
import sys
import logging
# import atexit
import os
import time
import signal
from pprint import pformat
from itertools import chain
import socket
from cis_interface.tools import CisClass
from cis_interface.config import cis_cfg, cfg_environment
from cis_interface import platform, backwards, yamlfile
from cis_interface.drivers import create_driver


# ANSI SGR escape sequence used to highlight text printed by the runner:
# black foreground (30) on a yellow background (43) at normal intensity (22).
COLOR_TRACE = '\033[30;43;22m'
# ANSI SGR escape sequence that resets all text attributes to their defaults.
COLOR_NORMAL = '\033[0m'


# def setup_cis_logging(prog, level=None):
#     r"""Set the log level based on environment variable 'CIS_DEBUG'. If the
#     variable is not set, the log level is set to 'NOTSET'.

#     Args:
#         prog (str): Name to prepend log messages with.
#         level (str, optional): String specifying the logging level. Defaults
#             to None and the environment variable 'CIS_DEBUG' is used.

#     """
#     if level is None:
#         level = cis_cfg.get('debug', 'cis', 'NOTSET')
#     logLevel = eval('logging.' + level)
#     logging.basicConfig(level=logLevel, stream=sys.stdout, format=COLOR_TRACE +
#                         prog + ': %(message)s' + COLOR_NORMAL)


class CisRunner(CisClass):
    r"""This class handles the orchestration of starting the model and
    IO drivers, monitoring their progress, and cleaning up on exit.

    Args:
        modelYmls (list): List of paths to yaml files specifying the models
            that should be run.
        namespace (str): Name that should be used to uniquely identify any
            RMQ exchange.
        host (str, optional): Name of the host that the models will be
            launched from. Defaults to None.
        rank (int, optional): Rank of this set of models if run in parallel.
            Defaults to 0.
        cis_debug_level (str, optional): Level for CiS debug messages.
            Defaults to environment variable 'CIS_DEBUG'. Currently unused
            because the logging setup in the constructor is commented out.
        rmq_debug_level (str, optional): Level for RabbitMQ debug messages.
            Defaults to environment variable 'RMQ_DEBUG'. Currently unused.
        cis_debug_prefix (str, optional): Prefix for CiS debug messages.
            Defaults to namespace. Currently unused.

    Attributes:
        namespace (str): Name that should be used to uniquely identify any
            RMQ exchange.
        host (str): Name of the host that the models will be launched from.
        rank (int): Rank of this set of models if run in parallel.
        modeldrivers (dict): Model drivers associated with this run.
        inputdrivers (dict): Input drivers associated with this run.
        outputdrivers (dict): Output drivers associated with this run.
        serverdrivers (dict): The addresses associated with different server
            drivers.
        interrupt_time (float): Time of last interrupt signal.
        error_flag (bool): True if one or more models raises an error.

    ..todo:: namespace, host, and rank do not seem strictly necessary.

    """

    def __init__(self, modelYmls, namespace, host=None, rank=0,
                 cis_debug_level=None, rmq_debug_level=None,
                 cis_debug_prefix=None):
        super(CisRunner, self).__init__('runner')
        self.namespace = namespace
        self.host = host
        self.rank = rank
        self.modeldrivers = {}
        self.inputdrivers = {}
        self.outputdrivers = {}
        self.serverdrivers = {}
        self.interrupt_time = 0
        # Driver yaml entries indexed by their 'args' value so that input
        # and output drivers can be matched against each other.
        self._inputchannels = {}
        self._outputchannels = {}
        # Handlers that were installed before set_signal_handler replaced
        # them, keyed by signal number.
        self._old_handlers = {}
        self.error_flag = False
        # Setup logging
        # if cis_debug_prefix is None:
        #     cis_debug_prefix = namespace
        # setup_cis_logging(cis_debug_prefix, level=cis_debug_level)
        # Update environment based on config
        cfg_environment()
        # Parse yamls
        drivers = yamlfile.parse_yaml(modelYmls)
        self.inputdrivers = drivers['input']
        self.outputdrivers = drivers['output']
        self.modeldrivers = drivers['model']
        for x in self.outputdrivers.values():
            self._outputchannels[x['args']] = x
        for x in self.inputdrivers.values():
            self._inputchannels[x['args']] = x
        # print(pformat(self.inputdrivers), pformat(self.outputdrivers),
        #       pformat(self.modeldrivers))
        # atexit.register(self.cleanup)

    def pprint(self, *args):
        r"""Print with color.

        Args:
            *args: Objects that are converted to strings, concatenated
                without a separator and printed between COLOR_TRACE and
                COLOR_NORMAL.

        """
        s = ''.join(str(i) for i in args)
        print((COLOR_TRACE + '{}' + COLOR_NORMAL).format(s))

    def signal_handler(self, sig, frame):
        r"""Terminate all drivers on interrupt.

        The first interrupt only prints a status summary. A second
        interrupt received less than 5 seconds after the previous one
        terminates all drivers.

        Args:
            sig (int): Number of the signal that was received.
            frame (frame): Stack frame that was interrupted (unused).

        Returns:
            int: 1 if the drivers were terminated, None otherwise.

        """
        self.debug("Interrupt with signal %d", sig)
        now = backwards.clock_time()
        elapsed = now - self.interrupt_time
        self.debug('Elapsed time since last interrupt: %d s', elapsed)
        self.interrupt_time = now
        self.pprint(' ')
        self.pprint(80 * '*')
        if elapsed < 5:
            self.pprint('* %76s *'
                        % 'Interrupted twice within 5 seconds: shutting down')
            self.pprint(80 * '*')
            # signal.siginterrupt(signal.SIGTERM, True)
            # signal.siginterrupt(signal.SIGINT, True)
            self.debug("Terminating models and closing all channels")
            self.terminate()
            self.pprint(80 * '*')
            # self.sleep(5)
            return 1
        else:
            self.pprint('* %76s *' % 'Interrupted: Displaying channel summary')
            self.pprint('* %76s *'
                        % 'interrupt again (within 5 seconds) to exit')
            self.pprint(80 * '*')
            self.printStatus()
            self.pprint(80 * '*')
        self.debug('%d returns', sig)

    def _swap_handler(self, signum, signal_handler):
        r"""Install a handler for one signal, remembering the previous one.

        Args:
            signum (int): Number of the signal to handle.
            signal_handler (function): Handler to install for the signal.

        """
        self._old_handlers[signum] = signal.getsignal(signum)
        signal.signal(signum, signal_handler)
        if not platform._is_win:
            # Restart interrupted system calls rather than failing them.
            signal.siginterrupt(signum, False)

    def set_signal_handler(self, signal_handler=None):
        r"""Set the signal handler.

        SIGINT is always handled. SIGTERM is handled on non-Windows
        platforms and SIGBREAK on Windows.

        Args:
            signal_handler (function, optional): Function that should handle
                received SIGINT and SIGTERM signals. Defaults to
                self.signal_handler.

        """
        if signal_handler is None:
            signal_handler = self.signal_handler
        self._swap_handler(signal.SIGINT, signal_handler)
        if not platform._is_win:
            self._swap_handler(signal.SIGTERM, signal_handler)
        else:  # pragma: windows
            self._swap_handler(signal.SIGBREAK, signal_handler)

    def reset_signal_handler(self):
        r"""Reset signal handlers to old ones."""
        for k, v in self._old_handlers.items():
            signal.signal(k, v)

    def run(self, signal_handler=None, timer=None, t0=None):
        r"""Run all of the models and wait for them to exit.

        Args:
            signal_handler (function, optional): Function that should be used
                as a signal handler. Defaults to None and is set by
                set_signal_handler.
            timer (function, optional): Function that should be called to get
                intermediate timing statistics. Defaults to time.time if not
                provided.
            t0 (float, optional): Zero point for timing statistics. Is set
                using the provided timer if not provided.

        Returns:
            dict: Intermediate times from the run.

        """
        if timer is None:
            timer = time.time
        if t0 is None:
            t0 = timer()
        times = {}
        times['init'] = timer()
        self.loadDrivers()
        times['load drivers'] = timer()
        self.startDrivers()
        times['start drivers'] = timer()
        self.set_signal_handler(signal_handler)
        try:
            self.waitModels()
            times['run models'] = timer()
        finally:
            # Always give the process its previous handlers back, even if
            # waiting on the models raised.
            self.reset_signal_handler()
        self.closeChannels()
        times['close channels'] = timer()
        self.cleanup()
        times['clean up'] = timer()
        tprev = t0
        key_order = ['init', 'load drivers', 'start drivers', 'run models',
                     'close channels', 'clean up']
        for k in key_order:
            self.info('%20s\t%f', k, times[k] - tprev)
            tprev = times[k]
        self.info(40 * '=')
        self.info('%20s\t%f', "Total", tprev - t0)
        return times

    @property
    def all_drivers(self):
        r"""iterator: For all drivers (input, then output, then model)."""
        return chain(self.inputdrivers.values(), self.outputdrivers.values(),
                     self.modeldrivers.values())

    def io_drivers(self, model=None):
        r"""Return the input and output drivers for one or all models.

        Args:
            model (str, optional): Name of a model that I/O drivers should be
                returned for. Defaults to None and all I/O drivers are
                returned.

        Returns:
            iterator: Access to list of I/O drivers.

        """
        if model is None:
            out = chain(self.inputdrivers.values(),
                        self.outputdrivers.values())
        else:
            driver = self.modeldrivers[model]
            out = chain(driver.get('input_drivers', dict()),
                        driver.get('output_drivers', dict()))
        return out

    def createDriver(self, yml):
        r"""Create a driver instance from the yaml information.

        If the yaml specifies 'working_dir', the driver is created with that
        directory as the current working directory; the previous working
        directory is restored afterwards, even if creation fails.

        Args:
            yml (yaml): Yaml object containing driver information.

        Returns:
            object: An instance of the specified driver.

        """
        self.debug('Creating %s, a %s', yml['name'], yml['driver'])
        curpath = os.getcwd()
        if 'ClientDriver' in yml['driver']:
            # Clients connect to the address recorded when the matching
            # server driver was created.
            yml.setdefault('comm_address', self.serverdrivers[yml['args']])
        if 'working_dir' in yml:
            os.chdir(yml['working_dir'])
        try:
            instance = create_driver(yml=yml, namespace=self.namespace,
                                     rank=self.rank, **yml)
        finally:
            os.chdir(curpath)
        yml['instance'] = instance
        if 'ServerDriver' in yml['driver']:
            self.serverdrivers[yml['args']] = instance.comm_address
        return instance

    def createModelDriver(self, yml):
        r"""Create a model driver instance from the yaml information.

        The environment of every I/O driver associated with the model is
        merged into the model's 'env' entry, and the model is registered
        with those I/O drivers and with any servers it is a client of.

        Args:
            yml (yaml): Yaml object containing driver information.

        Returns:
            object: An instance of the specified driver.

        """
        yml['env'] = {}
        for iod in self.io_drivers(yml['name']):
            yml['env'].update(iod['instance'].env)
            iod['models'].append(yml['name'])
        drv = self.createDriver(yml)
        if 'client_of' in yml:
            for srv in yml['client_of']:
                self.modeldrivers[srv]['clients'].append(yml['name'])
        self.debug("Model %s:, env: %s",
                   yml['name'], pformat(yml['instance'].env))
        return drv

    def createInputDriver(self, yml):
        r"""Create an input driver instance from the yaml information.

        Args:
            yml (yaml): Yaml object containing driver information.

        Returns:
            object: An instance of the specified driver.

        Raises:
            ValueError: If there is no output channel matching the driver's
                'args' and one of its input comms does not have a 'filetype'.

        """
        yml['models'] = []
        if yml['args'] not in self._outputchannels:
            for x in yml['icomm_kws']['comm']:
                if 'filetype' not in x:
                    raise ValueError(
                        ("Input driver %s could not locate a "
                         + "corresponding file or output channel %s") % (
                             x["name"], yml["args"]))
        drv = self.createDriver(yml)
        return drv

    def createOutputDriver(self, yml):
        r"""Create an output driver instance from the yaml information.

        If an input driver exists for the same channel, its comm environment
        is passed on to the output driver.

        Args:
            yml (yaml): Yaml object containing driver information.

        Returns:
            object: An instance of the specified driver.

        Raises:
            ValueError: If there is no input channel matching the driver's
                'args' and one of its output comms does not have a
                'filetype'.

        """
        yml['models'] = []
        if yml['args'] in self._inputchannels:
            yml['comm_env'] = (
                self._inputchannels[yml['args']]['instance'].comm_env)
        else:
            for x in yml['ocomm_kws']['comm']:
                if 'filetype' not in x:
                    raise ValueError(
                        ("Output driver %s could not locate a "
                         + "corresponding file or input channel %s") % (
                             x["name"], yml["args"]))
        drv = self.createDriver(yml)
        return drv

    def loadDrivers(self):
        r"""Load all of the necessary drivers, doing the IO drivers first
        and adding IO driver environmental variables back to the models.

        If any driver cannot be created, all drivers are terminated and the
        exception is re-raised.

        """
        self.debug('')
        # Placeholder so the error message below works even if the failure
        # happens before the first driver is reached.
        driver = dict(name='name')
        try:
            # Create input drivers
            self.debug("Loading input drivers")
            for driver in self.inputdrivers.values():
                self.createInputDriver(driver)
            # Create output drivers
            self.debug("Loading output drivers")
            for driver in self.outputdrivers.values():
                self.createOutputDriver(driver)
            # Create model drivers
            self.debug("Loading model drivers")
            for driver in self.modeldrivers.values():
                self.createModelDriver(driver)
        except BaseException:  # pragma: debug
            self.error("%s could not be created.", driver['name'])
            self.terminate()
            raise

    def startDrivers(self):
        r"""Start drivers, starting with the IO drivers.

        If any driver fails to start, all drivers are terminated and the
        exception is re-raised.

        """
        self.info('Starting I/O drivers and models on system '
                  + '{} in namespace {} with rank {}'.format(
                      self.host, self.namespace, self.rank))
        # Placeholder so the error message below works even if the failure
        # happens before the first driver is reached.
        driver = dict(name='name')
        try:
            # Start connections
            for driver in self.io_drivers():
                self.debug("Starting driver %s", driver['name'])
                d = driver['instance']
                if not d.was_started:
                    d.start()
            # Ensure connections in loop
            for driver in self.io_drivers():
                self.debug("Checking driver %s", driver['name'])
                d = driver['instance']
                d.wait_for_loop()
                assert(d.was_loop)
                assert(not d.errors)
            # Start models
            # self.sleep(1)  # on windows comms can take a while start
            for driver in self.modeldrivers.values():
                self.debug("Starting driver %s", driver['name'])
                d = driver['instance']
                # Servers must be running before their clients start.
                for n2 in driver.get('client_of', []):
                    d2 = self.modeldrivers[n2]['instance']
                    if not d2.was_started:
                        self.debug("Starting server '%s' before client",
                                   d2.name)
                        d2.start()
                if not d.was_started:
                    d.start()
        except BaseException:  # pragma: debug
            self.error("%s did not start", driver['name'])
            self.terminate()
            raise
        self.debug('ALL DRIVERS STARTED')

    def waitModels(self):
        r"""Wait for all model drivers to finish. When a model finishes,
        join the thread and perform exits for associated IO drivers.

        If any model reports errors, error_flag is set and all drivers are
        terminated.

        """
        self.debug('')
        running = list(self.modeldrivers.values())
        while running and (not self.error_flag):
            # Iterate over a snapshot: finished models are removed from
            # ``running`` below and removing from the list being iterated
            # would skip the following model in this pass.
            for drv in list(running):
                d = drv['instance']
                if d.errors:  # pragma: debug
                    self.error('Error in model %s', drv['name'])
                    self.error_flag = True
                    break
                d.join(1)
                if not d.is_alive():
                    if not d.errors:
                        self.info("%s finished running.", drv['name'])
                        self.do_model_exits(drv)
                        self.debug("%s completed model exits.", drv['name'])
                        self.do_client_exits(drv)
                        self.debug("%s completed client exits.", drv['name'])
                        running.remove(drv)
                        self.info("%s finished exiting.", drv['name'])
                else:
                    self.info('%s still running', drv['name'])
            # Briefly join every driver so finished threads get collected.
            for drv in self.all_drivers:
                drv['instance'].join(0.1)
        for d in self.modeldrivers.values():
            if d['instance'].errors:
                self.error_flag = True
        if not self.error_flag:
            self.info('All models completed')
        else:
            self.error('One or more models generated errors.')
            self.terminate()
        self.debug('Returning')

    def do_model_exits(self, model):
        r"""Perform exits for IO drivers associated with a model.

        The model is removed from each of its I/O drivers; a driver that is
        still alive and has no remaining models is told the model exited.

        Args:
            model (dict): Dictionary of model parameters including any
                associated IO drivers.

        """
        for drv in self.io_drivers(model['name']):
            drv['models'].remove(model['name'])
            if not drv['instance'].is_alive():
                continue
            if (len(drv['models']) == 0):
                self.debug('on_model_exit %s', drv['name'])
                drv['instance'].on_model_exit()

    def do_client_exits(self, model):
        r"""Perform exits for IO drivers associated with a client model.

        Args:
            model (dict): Dictionary of model parameters including any
                associated IO drivers.

        """
        for srv_name in model.get('client_of', []):
            # Remove this client from list for server
            srv = self.modeldrivers[srv_name]
            srv['clients'].remove(model['name'])
            # Stop server if there are not any more clients
            if len(srv['clients']) == 0:
                iod = self.inputdrivers[srv_name]
                iod['instance'].on_client_exit()
                srv['instance'].stop()

    def terminate(self):
        r"""Immediately stop all drivers, beginning with IO drivers."""
        self.debug('')
        # self.closeChannels(force_stop=True)
        # self.debug('Stop models')
        for driver in self.all_drivers:
            # Drivers that were never created have no instance to stop.
            if 'instance' in driver:
                self.debug('Stop %s', driver['name'])
                driver['instance'].terminate()
                # Terminate should ensure instance not alive
                assert(not driver['instance'].is_alive())
                # if driver['instance'].is_alive():
                #     driver['instance'].join()
        self.debug('Returning')

    def cleanup(self):
        r"""Perform cleanup operations for all drivers."""
        self.debug('')
        for driver in self.all_drivers:
            if 'instance' in driver:
                driver['instance'].cleanup()

    def printStatus(self):
        r"""Print the status of all drivers, starting with the IO drivers."""
        self.debug('')
        for driver in self.all_drivers:
            if 'instance' in driver:
                driver['instance'].printStatus()

    def closeChannels(self, force_stop=False):
        r"""Stop IO drivers and join the threads.

        Args:
            force_stop (bool, optional): If True, the terminate method is
                used to stop the drivers. Otherwise, the stop method is used.
                The stop method will try to exit gracefully while terminate
                will exit as quickly as possible. Defaults to False.

        """
        self.debug('')
        drivers = list(self.io_drivers())
        for drv in drivers:
            if 'instance' in drv:
                driver = drv['instance']
                if driver.is_alive():  # pragma: debug
                    self.debug("Stopping %s", drv['name'])
                    # A failed run is shut down as fast as possible.
                    if force_stop or self.error_flag:
                        driver.terminate()
                    else:
                        driver.stop()
                    self.debug("Stop(%s) returns", drv['name'])
        self.debug('Channel Stops DONE')
        for drv in drivers:
            if 'instance' in drv:
                driver = drv['instance']
                assert(not driver.is_alive())
                # self.debug("Join %s", drv['name'])
                # if driver.is_alive():
                #     driver.join()
                # self.debug("Join %s done", drv['name'])
        self.debug('Returning')
def get_runner(models, **kwargs):
    r"""Build a runner for a set of models using run information taken
    from the configuration and the environment.

    The rank is read from the environment variable 'PARALLEL_SEQ'
    (defaulting to '0') and the host from the local hostname. Both are
    exported as 'CIS_RANK' and 'CIS_HOST' and passed on to the runner.

    Args:
        models (list): List of yaml files containing information on the
            models that should be run.
        **kwargs: Additional keyword arguments are passed to CisRunner. A
            'namespace' keyword overrides the 'namespace' option from the
            'rmq' section of the config.

    Returns:
        CisRunner: Runner for the provided models.

    Raises:
        Exception: If config option 'namespace' in 'rmq' section not set.

    """
    log = logging.getLogger(__name__)
    # Resolve the namespace, preferring an explicit keyword argument.
    configured_namespace = cis_cfg.get('rmq', 'namespace', False)
    namespace = kwargs.pop('namespace', configured_namespace)
    if not namespace:  # pragma: debug
        raise Exception('rmq:namespace not set in config file')
    # Export rank and host so that they are visible through the environment.
    rank_text = os.environ.get('PARALLEL_SEQ', '0')
    hostname = socket.gethostname()
    os.environ['CIS_RANK'] = rank_text
    os.environ['CIS_HOST'] = hostname
    rank_number = int(rank_text)
    kwargs['rank'] = rank_number
    kwargs['host'] = hostname
    # Run
    log.debug("Running in %s with path %s namespace %s rank %d",
              os.getcwd(), sys.path, namespace, rank_number)
    return CisRunner(models, namespace, **kwargs)