import copy
from yggdrasil.drivers import create_driver
from yggdrasil.drivers.Driver import Driver
class DuplicatedModelDriver(Driver):
    r"""Driver that manages a set of duplicate copies of a single model.

    One driver is created per copy (via ``create_driver``) and lifecycle
    calls made on this driver (start, stop, terminate, cleanup, ...) are
    forwarded to every copy before being passed on to the base class.

    Args:
        yml (dict): Input parameters describing the model being duplicated.
            These are merged into the keyword arguments used for each copy
            and for the base class. Copy specific environment entries
            (``env_<copy name>``) may be removed from this dictionary.
        duplicates (list, optional): Dictionaries describing each copy. Each
            entry must contain a 'name' and may contain an 'env' that takes
            precedence over the environment in yml. If not provided, the
            copies are generated from yml by get_yaml_copies.
        **kwargs: Additional keyword arguments are passed to the models in
            the set and to the base class.

    Attributes:
        copies (list): Drivers for the individual model copies.

    """

    # Format used to build a copy's name from (base model name, copy index).
    name_format = "%s_copy%d"

    def __init__(self, yml, duplicates=None, **kwargs):
        kwargs.update(yml)
        self.copies = []
        if duplicates is not None:
            for x in duplicates:
                # Environment precedence (lowest to highest): shared 'env'
                # from yml, 'env_<copy name>' from yml, the copy's own 'env'.
                ienv = copy.deepcopy(yml.get('env', {}))
                ienv.update(yml.pop('env_%s' % x['name'], {}))
                ienv.update(x.pop('env', {}))
                x['env'] = ienv
                ikws = copy.deepcopy(kwargs)
                ikws.update(x)
                self.copies.append(create_driver(yml=x, **ikws))
        else:
            for iyml in self.get_yaml_copies(yml):
                ikws = copy.deepcopy(kwargs)
                ikws.update(iyml)
                self.copies.append(create_driver(yml=iyml, **ikws))
        super(DuplicatedModelDriver, self).__init__(**kwargs)

    @classmethod
    def get_base_name(cls, name):
        r"""Get the name of the base model.

        Args:
            name (str): Name of a model copy (base name followed by a
                '_copy' suffix).

        Returns:
            str: Base model name.

        Raises:
            AssertionError: If name does not contain '_copy'.

        """
        if '_copy' not in name:
            # Raised explicitly (rather than via an assert statement) so
            # that the check is not stripped when Python runs with -O.
            raise AssertionError(
                "Model name %r does not contain '_copy'" % (name,))
        # The copy suffix is appended at the end of the name, so split on
        # the last occurrence; base names that contain '_copy' themselves
        # are then preserved intact.
        return name.rsplit('_copy', 1)[0]

    @classmethod
    def get_yaml_copies(cls, yml):
        r"""Get a list of yamls for creating duplicate models for the model
        described by the provided yaml.

        Copy specific environment entries (keys of the form
        ``env_<copy name>``) are removed from yml in place and merged into
        the 'env' entry of the corresponding copy.

        Args:
            yml (dict): Input parameters for creating a model driver. Must
                contain 'name', 'copies', 'input_drivers' and
                'output_drivers'.

        Returns:
            list: Copies of input parameters for creating duplicate models.

        """
        names = [cls.name_format % (yml['name'], i)
                 for i in range(yml['copies'])]
        # Pull the copy specific environments out of yml before it is
        # copied so that they do not end up in every duplicate.
        env_copy_specific = {}
        for iname in names:
            env_copy_specific[iname] = yml.pop('env_%s' % iname, {})
        copies = []
        for i, iname in enumerate(names):
            iyml = copy.deepcopy(yml)
            iyml['name'] = iname
            iyml['copy_index'] = i
            # The driver lists are shared between the copies (the original
            # objects are re-assigned in place of the deep copies).
            iyml['input_drivers'] = yml['input_drivers']
            iyml['output_drivers'] = yml['output_drivers']
            # Update environment to reflect addition of suffix
            iyml['env'] = yml.get('env', {}).copy()
            iyml['env'].update(env_copy_specific.get(iname, {}))
            copies.append(iyml)
        return copies

    def cleanup(self, *args, **kwargs):
        r"""Clean up each copy and then this driver.

        Args:
            *args: Arguments are passed to the cleanup method of each copy
                and of the base class.
            **kwargs: Keyword arguments are passed to the cleanup method of
                each copy and of the base class.

        """
        for x in self.copies:
            x.cleanup(*args, **kwargs)
        super(DuplicatedModelDriver, self).cleanup(*args, **kwargs)

    def start(self, *args, **kwargs):
        r"""Start each copy and then this driver.

        Before a copy is started, its environment is updated with the
        result of its get_io_env method for the input/output drivers
        recorded in this driver's yml.

        Args:
            *args: Arguments are passed to the start method of each copy
                and of the base class.
            **kwargs: Keyword arguments are passed to the start method of
                each copy and of the base class.

        """
        input_drivers = self.yml.get('input_drivers', [])
        output_drivers = self.yml.get('output_drivers', [])
        for x in self.copies:
            x.env.update(x.get_io_env(input_drivers=input_drivers,
                                      output_drivers=output_drivers))
            x.start(*args, **kwargs)
        super(DuplicatedModelDriver, self).start(*args, **kwargs)

    def stop(self, *args, **kwargs):
        r"""Stop each copy and then this driver.

        Args:
            *args: Arguments are passed to the stop method of each copy and
                of the base class.
            **kwargs: Keyword arguments are passed to the stop method of
                each copy and of the base class.

        """
        for x in self.copies:
            x.stop(*args, **kwargs)
        super(DuplicatedModelDriver, self).stop(*args, **kwargs)

    def graceful_stop(self, *args, **kwargs):
        r"""Gracefully stop each copy and then this driver.

        Args:
            *args: Arguments are passed to the graceful_stop method of each
                copy and of the base class.
            **kwargs: Keyword arguments are passed to the graceful_stop
                method of each copy and of the base class.

        """
        for x in self.copies:
            x.graceful_stop(*args, **kwargs)
        super(DuplicatedModelDriver, self).graceful_stop(*args, **kwargs)

    def terminate(self, *args, **kwargs):
        r"""Terminate each copy and then this driver.

        Args:
            *args: Arguments are passed to the terminate method of each
                copy and of the base class.
            **kwargs: Keyword arguments are passed to the terminate method
                of each copy and of the base class.

        """
        for x in self.copies:
            x.terminate(*args, **kwargs)
        super(DuplicatedModelDriver, self).terminate(*args, **kwargs)

    def run_loop(self):
        r"""Sleep while any copy is still alive; once none are alive, set
        the break flag."""
        # TODO: Stop if there is an error on one?
        # Every copy is polled (list rather than generator), matching the
        # original call pattern.
        if any([x.is_alive() for x in self.copies]):
            self.sleep()
            return
        self.set_break_flag()

    def after_loop(self, *args, **kwargs):
        r"""Terminate every copy, then run the base class after_loop.

        Args:
            *args: Arguments are passed to the base class's method. They
                are not passed to the copies' terminate methods.
            **kwargs: Keyword arguments are passed to the base class's
                method.

        """
        for x in self.copies:
            x.terminate()
        super(DuplicatedModelDriver, self).after_loop(*args, **kwargs)

    def printStatus(self, *args, **kwargs):
        r"""Print the status of each copy followed by this driver.

        Args:
            *args: Arguments are passed to the printStatus method of each
                copy and of the base class.
            **kwargs: Keyword arguments are passed to the printStatus
                method of each copy and of the base class. If 'return_str'
                is present and true, the statuses of the copies and of this
                driver are joined by newlines (copies first) and returned.

        Returns:
            object: Result of the base class's printStatus unless
                return_str is true, in which case the joined string.

        """
        out_copies = [x.printStatus(*args, **kwargs) for x in self.copies]
        out = super(DuplicatedModelDriver, self).printStatus(*args, **kwargs)
        if kwargs.get('return_str', False):
            out = '\n'.join(out_copies + [out])
        return out

    @property
    def io_errors(self):
        r"""list: Errors produced by input/output drivers to this model,
        collected from every copy."""
        errors = []
        for x in self.copies:
            errors += x.io_errors
        return errors

    @property
    def errors(self):
        r"""list: Errors returned by model copies."""
        out = []
        for x in self.copies:
            out += x.errors
        return out

    @errors.setter
    def errors(self, val):
        # Assigned values are discarded; errors is always computed from the
        # copies. NOTE(review): presumably this exists so that code which
        # assigns to self.errors (e.g. in the base class) does not fail on
        # the read-only property -- confirm against Driver.
        pass