Source code for cis_interface.drivers.tests.test_ConnectionDriver

import uuid
import unittest
from cis_interface import tools, backwards
from cis_interface.tests import MagicTestError, assert_raises
from cis_interface.schema import get_schema
from cis_interface.drivers import import_driver
from cis_interface.drivers.tests import test_Driver as parent
from cis_interface.drivers.ConnectionDriver import ConnectionDriver
from cis_interface.communication import (
    new_comm, ZMQComm, IPCComm, RMQComm, get_comm_class)


_default_comm = tools.get_default_comm()
_zmq_installed = ZMQComm.ZMQComm.is_installed(language='python')
_ipc_installed = IPCComm.IPCComm.is_installed(language='python')
_rmq_server_running = RMQComm._rmq_server_running

            
[docs]class TestConnectionParam(parent.TestParam): r"""Test parameters for the ConnectionDriver class.""" comm_name = _default_comm icomm_name = _default_comm ocomm_name = _default_comm testing_option_kws = {} driver = 'ConnectionDriver' def __init__(self, *args, **kwargs): super(TestConnectionParam, self).__init__(*args, **kwargs) self.attr_list += ['icomm_kws', 'ocomm_kws', 'icomm', 'ocomm', 'nrecv', 'nproc', 'nsent', 'state', 'translator'] # self.timeout = 1.0 self._extra_instances = [] @property def description_prefix(self): r"""String prefix to prepend docstr test message with.""" out = super(TestConnectionParam, self).description_prefix return '%s(%s, %s)' % (out, self.icomm_name, self.ocomm_name) @property def maxMsgSize(self): r"""int: Maximum message size.""" return min(self.instance.icomm.maxMsgSize, self.instance.ocomm.maxMsgSize) @property def is_input(self): r"""bool: True if the connection is for input.""" return (self.icomm_name != self.comm_name) @property def is_output(self): r"""bool: True if the connection is for output.""" return (self.ocomm_name != self.comm_name)
[docs] def assert_msg_equal(self, x, y): r"""Assert that two messages are equivalent.""" self.assert_equal(x, y)
[docs] def get_options(self): r"""Get testing options.""" if self.is_output: out = self.ocomm_import_cls.get_testing_options( **self.testing_option_kws) else: out = self.icomm_import_cls.get_testing_options( **self.testing_option_kws) return out
@property def cleanup_comm_classes(self): r"""list: Comm classes that should be cleaned up following the test.""" comms = set([self.comm_name, self.icomm_name, self.ocomm_name]) return comms @property def icomm_kws(self): r"""dict: Keyword arguments for connection input comm.""" out = {'name': self.icomm_name, 'comm': self.icomm_name} if self.is_input: out.update(self.testing_options['kwargs']) return out @property def ocomm_kws(self): r"""dict: Keyword arguments for connection output comm.""" out = {'name': self.ocomm_name, 'comm': self.ocomm_name} if self.is_output: out.update(self.testing_options['kwargs']) return out @property def icomm_import_cls(self): r"""class: Class used for connection input comm.""" return get_comm_class(self.icomm_name) @property def ocomm_import_cls(self): r"""class: Class used for connection output comm.""" return get_comm_class(self.ocomm_name) @property def send_comm_kwargs(self): r"""dict: Keyword arguments for send comm.""" return self.instance.icomm.opp_comm_kwargs() @property def recv_comm_kwargs(self): r"""dict: Keyword arguments for recv comm.""" return self.instance.ocomm.opp_comm_kwargs() @property def inst_kwargs(self): r"""dict: Keyword arguments for tested class.""" out = super(TestConnectionParam, self).inst_kwargs out['icomm_kws'] = self.icomm_kws out['ocomm_kws'] = self.ocomm_kws return out @property def test_msg(self): r"""str: Test message that should be used for any send/recv tests.""" return self.testing_options['msg'] @property def msg_long(self): r"""str: Small test message for sending.""" msg_short = self.test_msg if isinstance(msg_short, backwards.bytes_type): out = msg_short + (self.maxMsgSize * b'0') else: # pragma: debug out = msg_short # return self.testing_options['msg_long'] return out
[docs] def setup(self, *args, **kwargs): r"""Initialize comm object pair.""" super(TestConnectionParam, self).setup(*args, **kwargs) send_kws = self.send_comm_kwargs recv_kws = self.recv_comm_kwargs if self.skip_start: send_kws['dont_open'] = True recv_kws['dont_open'] = True self.send_comm = new_comm(self.name, **send_kws) self.recv_comm = new_comm(self.name, **recv_kws)
[docs] def teardown(self, *args, **kwargs): r"""Destroy comm object pair.""" self.send_comm.close() self.recv_comm.close() assert(self.send_comm.is_closed) assert(self.recv_comm.is_closed) super(TestConnectionParam, self).teardown(*args, **kwargs) for inst in self._extra_instances: inst.terminate()
[docs]class TestConnectionDriverNoStart(TestConnectionParam, parent.TestDriverNoStart): r"""Test class for the ConnectionDriver class without start."""
[docs] def test_send_recv(self): r"""Test sending/receiving with queues closed.""" self.instance.close_comm() self.send_comm.close() self.recv_comm.close() assert(self.instance.is_comm_closed) assert(self.send_comm.is_closed) assert(self.recv_comm.is_closed) flag = self.instance.send_message() assert(not flag) flag = self.instance.recv_message() assert(not flag) # Short flag = self.send_comm.send(self.test_msg) assert(not flag) flag, ret = self.recv_comm.recv() assert(not flag) self.assert_equal(ret, None) # Long flag = self.send_comm.send_nolimit(self.test_msg) assert(not flag) flag, ret = self.recv_comm.recv_nolimit() assert(not flag) self.assert_equal(ret, None)
[docs] def get_fresh_name(self): r"""Get a fresh name for a new instance that won't overlap with the base.""" return 'Test%s_%s' % (self.cls, str(uuid.uuid4()))
[docs] def get_fresh_error_instance(self, comm, error_on_init=False): r"""Get a driver instance with ErrorComm class for one or both comms.""" args = [self.get_fresh_name()] if self.args is not None: args.append(self.args) # args = self.inst_args kwargs = self.inst_kwargs if 'comm_address' in kwargs: del kwargs['comm_address'] if comm in ['ocomm', 'both']: kwargs['ocomm_kws'].update( base_comm=self.ocomm_name, new_comm_class='ErrorComm', error_on_init=error_on_init) if comm in ['icomm', 'both']: kwargs['icomm_kws'].update( base_comm=self.icomm_name, new_comm_class='ErrorComm', error_on_init=error_on_init) driver_class = import_driver(self.driver) if error_on_init: self.assert_raises(MagicTestError, driver_class, *args, **kwargs) else: inst = driver_class(*args, **kwargs) inst.icomm._first_send_done = True self._extra_instances.append(inst) return inst
[docs] def test_error_init_ocomm(self): r"""Test forwarding of error from init of ocomm.""" self.get_fresh_error_instance('ocomm', error_on_init=True)
[docs] def test_error_open_icomm(self): r"""Test fowarding of error from open of icomm.""" inst = self.get_fresh_error_instance('icomm') inst.icomm.error_replace('open') self.assert_raises(MagicTestError, inst.open_comm) assert(inst.icomm.is_closed) inst.icomm.restore_all()
[docs] def test_error_close_icomm(self): r"""Test forwarding of error from close of icomm.""" inst = self.get_fresh_error_instance('icomm') inst.open_comm() inst.icomm.error_replace('close') self.assert_raises(MagicTestError, inst.close_comm) assert(inst.ocomm.is_closed) inst.icomm.restore_all() inst.icomm.close() assert(inst.icomm.is_closed)
[docs] def test_error_close_ocomm(self): r"""Test forwarding of error from close of ocomm.""" inst = self.get_fresh_error_instance('ocomm') inst.open_comm() inst.ocomm.error_replace('close') self.assert_raises(MagicTestError, inst.close_comm) assert(inst.icomm.is_closed) inst.ocomm.restore_all() inst.ocomm.close() assert(inst.ocomm.is_closed)
[docs] def test_error_open_fails(self): r"""Test error raised when comms fail to open.""" inst = self.get_fresh_error_instance('both') old_timeout = inst.timeout inst.icomm.empty_replace('open') inst.ocomm.empty_replace('open') inst.timeout = inst.sleeptime / 2.0 self.assert_raises(Exception, inst.start) inst.timeout = old_timeout inst.icomm.restore_all() inst.ocomm.restore_all() inst.close_comm() assert(inst.is_comm_closed)
[docs]class TestConnectionDriver(TestConnectionParam, parent.TestDriver): r"""Test class for the ConnectionDriver class."""
[docs] def setup(self, *args, **kwargs): r"""Initialize comm object pair.""" super(TestConnectionDriver, self).setup(*args, **kwargs) # CommBase is dummy class that never opens if (self.send_comm.comm_class != 'CommBase'): assert(self.send_comm.is_open) if (self.recv_comm.comm_class != 'CommBase'): assert(self.recv_comm.is_open) self.nmsg_recv = 1
[docs] def test_early_close(self): r"""Test early deletion of message queue.""" self.instance.close_comm() self.instance.open_comm() assert(self.instance.is_comm_closed)
[docs] def test_send_recv(self): r"""Test sending/receiving small message.""" flag = self.send_comm.send(self.test_msg) if self.comm_name != 'CommBase': assert(flag) # self.instance.sleep() # if self.comm_name != 'CommBase': # self.assert_equal(self.recv_comm.n_msg, 1) for i in range(self.nmsg_recv): flag, msg_recv = self.recv_comm.recv(self.timeout) if self.comm_name != 'CommBase': assert(flag) self.assert_msg_equal(msg_recv, self.test_msg) if self.comm_name != 'CommBase': self.assert_equal(self.instance.n_msg, 0)
[docs] def test_send_recv_nolimit(self): r"""Test sending/receiving large message.""" assert(len(self.msg_long) > self.maxMsgSize) flag = self.send_comm.send_nolimit(self.msg_long) if self.comm_name != 'CommBase': assert(flag) for i in range(self.nmsg_recv): flag, msg_recv = self.recv_comm.recv_nolimit(self.timeout) if self.comm_name != 'CommBase': assert(flag) self.assert_msg_equal(msg_recv, self.msg_long)
[docs] def assert_before_stop(self, check_open=True): r"""Assertions to make before stopping the driver instance.""" super(TestConnectionDriver, self).assert_before_stop() if self.comm_name != 'CommBase' and check_open: assert(self.instance.is_comm_open)
[docs] def run_before_terminate(self): r"""Commands to run while the instance is running, before terminate.""" super(TestConnectionDriver, self).run_before_terminate()
# TODO: This fails with ZMQ # self.send_comm.send(self.test_msg)
[docs] def assert_after_terminate(self): r"""Assertions to make after terminating the driver instance.""" super(TestConnectionDriver, self).assert_after_terminate() assert(self.instance.is_comm_closed)
[docs]class TestConnectionDriverFork(TestConnectionDriver): r"""Test class for the ConnectionDriver class between fork comms."""
[docs] def setup(self, *args, **kwargs): r"""Initialize comm object pair.""" self.ncomm_input = 2 self.ncomm_output = 1 super(TestConnectionDriverFork, self).setup(*args, **kwargs) self.nmsg_recv = self.ncomm_input * self.ncomm_output
@property def inst_kwargs(self): r"""dict: Keyword arguments for tested class.""" out = super(TestConnectionDriverFork, self).inst_kwargs out['icomm_kws']['comm'] = [None for i in range(self.ncomm_input)] return out
[docs]def direct_translate(msg): r"""Test translator that just returns passed message.""" return msg
invalid_translate = True
[docs]class TestConnectionDriverTranslate(TestConnectionDriver): r"""Test class for the ConnectionDriver class with translator.""" @property def inst_kwargs(self): r"""dict: Keyword arguments for tested class.""" out = super(TestConnectionDriverTranslate, self).inst_kwargs out['translator'] = direct_translate out['onexit'] = 'printStatus' return out
[docs]def test_ConnectionDriverOnexit_errors(): r"""Test that errors are raised for invalid onexit.""" assert_raises(ValueError, ConnectionDriver, 'test', onexit='invalid')
[docs]def test_ConnectionDriverTranslate_errors(): r"""Test that errors are raised for invalid translators.""" assert(not hasattr(invalid_translate, '__call__')) assert_raises(ValueError, ConnectionDriver, 'test', translator=invalid_translate)
# Dynamically create tests based on registered file classes s = get_schema() comm_types = list(s['comm'].schema_subtypes.keys()) for k in comm_types: if k == _default_comm: continue # Output ocls = type('Test%sOutputDriver' % k, (TestConnectionDriver, ), {'ocomm_name': k, 'driver': 'OutputDriver', 'args': 'test'}) # Input icls = type('Test%sInputDriver' % k, (TestConnectionDriver, ), {'icomm_name': k, 'driver': 'InputDriver', 'args': 'test'}) # Flags flag_func = None if k in ['RMQComm', 'RMQAsyncComm']: flag_func = unittest.skipIf(not _rmq_server_running, "RMQ Server not running") elif k in ['ZMQComm']: flag_func = unittest.skipIf(not _zmq_installed, "ZMQ library not installed") elif k in ['IPCComm']: flag_func = unittest.skipIf(not _ipc_installed, "IPC library not installed") if flag_func is not None: ocls = flag_func(ocls) icls = flag_func(icls) # Add class to globals globals()[ocls.__name__] = ocls globals()[icls.__name__] = icls del ocls, icls