Source code for cis_interface.tests

"""Testing things."""
import os
import shutil
import uuid
import difflib
import importlib
import contextlib
import warnings
import unittest
import numpy as np
import pandas as pd
import threading
import psutil
import copy
from cis_interface.config import cis_cfg, cfg_logging
from cis_interface.tools import get_default_comm, CisClass
from cis_interface import backwards, platform, units
from cis_interface.communication import cleanup_comms, get_comm_class


# Test data: map short keys to the full paths of files in the 'data' directory.
data_dir = os.path.join(os.path.dirname(__file__), 'data')
data_list = [
    ('txt', 'ascii_file.txt'),
    ('table', 'ascii_table.txt'),
]
data = dict((key, os.path.join(data_dir, fname)) for key, fname in data_list)

# Test scripts
script_dir = os.path.join(os.path.dirname(__file__), 'scripts')
script_list = [
    ('c', ['gcc_model.c', 'hellofunc.c']),
    ('cpp', ['gcc_model.cpp', 'hellofunc.c']),
    ('make', 'gcc_model'),
    ('cmake', 'gcc_model'),
    ('matlab', 'matlab_model.m'),
    ('matlab_error', 'matlab_error_model.m'),
    ('python', 'python_model.py'),
    ('error', 'error_model.py'),
    ('lpy', 'lpy_model.lpy')]
# Entries given as a list of source files map to a list of full paths; all
# other entries map to a single full path. A comprehension is used so the
# loop variables do not leak into the module namespace.
scripts = {
    k: ([os.path.join(script_dir, iv) for iv in v] if isinstance(v, list)
        else os.path.join(script_dir, v))
    for k, v in script_list}
    
# Test yamls: map model language keys to full paths of the matching YAML
# specification files in the 'yamls' directory.
yaml_dir = os.path.join(os.path.dirname(__file__), 'yamls')
yaml_list = [
    ('c', 'gcc_model.yml'),
    ('cpp', 'gpp_model.yml'),
    ('make', 'make_model.yml'),
    ('cmake', 'cmake_model.yml'),
    ('matlab', 'matlab_model.yml'),
    ('python', 'python_model.yml'),
    ('error', 'error_model.yml'),
    ('lpy', 'lpy_model.yml'),
]
yamls = dict((lang, os.path.join(yaml_dir, fname)) for lang, fname in yaml_list)

# Makefile
# Select the platform-specific Makefile template and copy it to
# scripts/Makefile (presumably for the 'make' test model -- confirm).
# NOTE: this runs at import time and shutil.copy overwrites any existing
# scripts/Makefile.
if platform._is_win:  # pragma: windows
    makefile0 = os.path.join(script_dir, "Makefile_windows")
else:
    makefile0 = os.path.join(script_dir, "Makefile_linux")
shutil.copy(makefile0, os.path.join(script_dir, "Makefile"))


# Flag for enabling tests that take a long time. The environment variable
# CIS_ENABLE_LONG_TESTS enables them when set to any non-empty value other
# than a conventional "false" spelling ('0', 'false', 'no', 'off'; case
# insensitive), so that e.g. CIS_ENABLE_LONG_TESTS=0 disables them.
enable_long_tests = (
    os.environ.get("CIS_ENABLE_LONG_TESTS", "").strip().lower()
    not in ("", "0", "false", "no", "off"))


# ``ut`` gives module-level access to the unittest.TestCase assert* methods
# (e.g. ``ut.assertEqual``) so the helper functions below can use them
# outside of a TestCase.
if backwards.PY2:  # pragma: Python 2
    # Dummy TestCase instance, so we can initialize an instance
    # and access the assert instance methods
    class DummyTestCase(unittest.TestCase):  # pragma: no cover
        def __init__(self):
            # Python 2's TestCase requires the name of an existing test
            # method, hence the no-op '_dummy' method below.
            super(DummyTestCase, self).__init__('_dummy')

        def _dummy(self):
            pass

    # A metaclass that makes __getattr__ static
    class AssertsAccessorType(type):  # pragma: no cover
        # Single shared TestCase instance whose bound assert methods are
        # handed out.
        dummy = DummyTestCase()

        def __getattr__(cls, key):
            # Invoked for attributes not found on the class itself, so
            # ``AssertsAccessor.assertEqual`` forwards to the dummy instance.
            return getattr(AssertsAccessor.dummy, key)

    # The actual accessor, a static class, that redirect the asserts
    class AssertsAccessor(object):  # pragma: no cover
        __metaclass__ = AssertsAccessorType

    ut = AssertsAccessor

else:  # pragma: Python 3

    # Python 3's TestCase can be instantiated without naming a test method
    # (the default 'runTest' is not validated), so a bare instance suffices.
    ut = unittest.TestCase()


def long_running(func):
    r"""Decorator for marking long tests that should be skipped unless
    CIS_ENABLE_LONG_TESTS is set (see the module-level enable_long_tests
    flag).

    Args:
        func (callable): Test function or method.

    Returns:
        callable: ``func`` wrapped with :func:`unittest.skipIf` so that it is
            skipped when long tests are not enabled.

    """
    return unittest.skipIf(not enable_long_tests, "Long tests not enabled.")(func)


def assert_raises(exception, *args, **kwargs):
    r"""Check that calling something raises the given exception.

    Args:
        exception (Exception): Exception class that is expected.
        callable (function, class, optional): Object that is called and should
            raise ``exception``. When omitted, a context manager is returned
            instead, which checks the body of the ``with`` statement.
        *args: Further positional arguments forwarded to the callable.
        **kwargs: Further keyword arguments forwarded to the callable.

    Raises:
        AssertionError: If ``exception`` is not raised.

    """
    raises_checker = ut.assertRaises
    return raises_checker(exception, *args, **kwargs)


@contextlib.contextmanager
def assert_warns(warning, *args, **kwargs):
    r"""Assert that a call (or context) raises a warning.

    Must be used as a context manager. If a callable is supplied, it is
    invoked (with the remaining arguments) when the context is entered;
    otherwise the body of the ``with`` statement is checked.

    Args:
        warning (Warning): Warning class that should be raised.
        callable (function, class, optional): Function that should raise
            the warning. If not provided, the body of the ``with`` block is
            checked instead.
        *args: Additional arguments are passed to the callable.
        **kwargs: Additional keyword arguments are passed to the callable
            (or, without a callable, to ``assertWarns``, e.g. ``msg``).

    Yields:
        object: On Python 2, the list of recorded warnings. On Python 3, the
            ``assertWarns`` context object when no callable is given,
            otherwise None.

    Raises:
        AssertionError: If the correct warning is not caught.

    """
    if backwards.PY2:  # pragma: Python 2
        if args and args[0] is None:  # pragma: debug
            warnings.warn("callable is None",
                          DeprecationWarning, 3)
            args = ()
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            try:
                if args:  # pragma: debug
                    callable_obj = args[0]
                    callable_obj(*args[1:], **kwargs)
                # Always yield: a contextmanager generator that never
                # yields makes ``with`` raise "generator didn't yield".
                yield w
            finally:
                assert(len(w) >= 1)
                for iw in w:
                    assert(issubclass(iw.category, warning))
    else:  # pragma: Python 3
        if args:
            # Callable form: assertWarns calls it and checks immediately.
            ut.assertWarns(warning, *args, **kwargs)
            yield None
        else:
            # The assertWarns context must actually be entered around the
            # caller's ``with`` body, otherwise nothing is ever checked.
            with ut.assertWarns(warning, **kwargs) as cm:
                yield cm


def assert_equal(x, y):
    r"""Assert that two messages are equivalent.

    Lists/tuples and dictionaries are compared recursively element by
    element. Numpy arrays and pandas data frames are compared with
    ``numpy.testing.assert_array_equal`` after stripping units from whichever
    side alone carries them. For other values, if both carry units, ``x`` is
    converted to the units of ``y`` before the underlying data are compared.
    If only one side carries units, they are stripped before ``assertEqual``.

    Args:
        x (object): Python object to compare against y.
        y (object): Python object to compare against x.

    Raises:
        AssertionError: If the two messages are not equivalent.

    """
    if isinstance(y, (list, tuple)):
        ut.assertIsInstance(x, (list, tuple))
        ut.assertEqual(len(x), len(y))
        for ix, iy in zip(x, y):
            assert_equal(ix, iy)
    elif isinstance(y, dict):
        # Check x (y is already known to be a dict) so that a type mismatch
        # is reported as an AssertionError rather than failing on x[k] below.
        ut.assertIsInstance(x, dict)
        ut.assertEqual(len(x), len(y))
        for k, iy in y.items():
            ut.assertIn(k, x)
            assert_equal(x[k], iy)
    elif isinstance(y, (np.ndarray, pd.DataFrame)):
        if units.has_units(y) and (not units.has_units(x)):  # pragma: debug
            y = units.get_data(y)
        elif (not units.has_units(y)) and units.has_units(x):
            x = units.get_data(x)
        np.testing.assert_array_equal(x, y)
    else:
        if units.has_units(y) and units.has_units(x):
            # Compare in y's units so equivalent quantities match.
            x = units.convert_to(x, units.get_units(y))
            assert_equal(units.get_data(x), units.get_data(y))
        else:
            if units.has_units(y) and (not units.has_units(x)):  # pragma: debug
                y = units.get_data(y)
            elif (not units.has_units(y)) and units.has_units(x):
                x = units.get_data(x)
            ut.assertEqual(x, y)


def assert_not_equal(x, y):
    r"""Check that two objects differ.

    Args:
        x (object): First object; compared against y.
        y (object): Second object; compared against x.

    Raises:
        AssertionError: If x and y compare equal.

    """
    not_equal_check = ut.assertNotEqual
    not_equal_check(x, y)
    
        
class CisTestBase(unittest.TestCase):
    r"""Wrapper for unittest.TestCase that allows use of setup and
    teardown methods along with description prefix.

    Args:
        description_prefix (str, optional): String to prepend docstring
            test message with. Defaults to the fully qualified name of the
            test class.
        skip_unittest (bool, optional): If True, the unittest parent class
            will not be initialized. Defaults to False.

    Attributes:
        uuid (str): Random unique identifier.
        attr_list (list): List of attributes that should be checked for after
            initialization.
        timeout (float): Maximum time in seconds for timeouts.
        sleeptime (float): Time in seconds that should be waited for sleeps.

    """

    attr_list = list()

    def __init__(self, *args, **kwargs):
        # str(cls) looks like "<class 'pkg.mod.Name'>"; keep 'pkg.mod.Name'.
        self._description_prefix = kwargs.pop(
            'description_prefix', str(self.__class__).split("'")[1])
        self.uuid = str(uuid.uuid4())
        self.timeout = 10.0
        self.sleeptime = 0.01
        # Copy so per-instance changes do not mutate the class attribute.
        self.attr_list = copy.deepcopy(self.__class__.attr_list)
        self._teardown_complete = False
        self._new_default_comm = None
        self._old_default_comm = None
        self._old_loglevel = None
        self._old_encoding = None
        self.debug_flag = False
        self._first_test = True
        skip_unittest = kwargs.pop('skip_unittest', False)
        if not skip_unittest:
            super(CisTestBase, self).__init__(*args, **kwargs)

    def assert_equal(self, x, y):
        r"""Assert that two values are equal."""
        return assert_equal(x, y)

    def assert_less_equal(self, x, y):
        r"""Assert that one value is less than or equal to another."""
        return self.assertLessEqual(x, y)

    def assert_greater(self, x, y):
        r"""Assert that one value is greater than another."""
        return self.assertGreater(x, y)

    def assert_raises(self, *args, **kwargs):
        r"""Assert that a function raises an error."""
        return self.assertRaises(*args, **kwargs)

    @property
    def comm_count(self):
        r"""int: The number of comms."""
        return sum(get_comm_class(k).comm_count()
                   for k in self.cleanup_comm_classes)

    @property
    def fd_count(self):
        r"""int: The number of open file descriptors."""
        proc = psutil.Process()
        if platform._is_win:  # pragma: windows
            out = proc.num_handles()
        else:
            out = proc.num_fds()
        return out

    @property
    def thread_count(self):
        r"""int: The number of active threads."""
        return threading.active_count()

    def set_utf8_encoding(self):
        r"""Set the encoding to utf-8 if it is not already."""
        old_lang = os.environ.get('LANG', '')
        if 'UTF-8' not in old_lang:  # pragma: debug
            self._old_encoding = old_lang
            os.environ['LANG'] = 'en_US.UTF-8'

    def reset_encoding(self):
        r"""Reset the encoding to the original value before the test."""
        if self._old_encoding is not None:  # pragma: debug
            os.environ['LANG'] = self._old_encoding
            self._old_encoding = None

    def debug_log(self):  # pragma: debug
        r"""Turn on debugging."""
        self._old_loglevel = cis_cfg.get('debug', 'cis')
        cis_cfg.set('debug', 'cis', 'DEBUG')
        cfg_logging()

    def reset_log(self):  # pragma: debug
        r"""Resetting logging to prior value."""
        if self._old_loglevel is not None:
            cis_cfg.set('debug', 'cis', self._old_loglevel)
            cfg_logging()
            self._old_loglevel = None

    def set_default_comm(self, default_comm=None):
        r"""Set the default comm."""
        self._old_default_comm = os.environ.get('CIS_DEFAULT_COMM', None)
        if default_comm is None:
            default_comm = self._new_default_comm
        if default_comm is not None:
            os.environ['CIS_DEFAULT_COMM'] = default_comm

    def reset_default_comm(self):
        r"""Reset the default comm to the original value."""
        if self._old_default_comm is None:
            if 'CIS_DEFAULT_COMM' in os.environ:
                del os.environ['CIS_DEFAULT_COMM']
        else:  # pragma: debug
            os.environ['CIS_DEFAULT_COMM'] = self._old_default_comm

    def setUp(self, *args, **kwargs):
        r"""unittest hook; delegates to setup."""
        self.setup(*args, **kwargs)

    def tearDown(self, *args, **kwargs):
        r"""unittest hook; delegates to teardown."""
        self.teardown(*args, **kwargs)

    def setup(self, nprev_comm=None, nprev_thread=None, nprev_fd=None):
        r"""Record the number of open comms, threads, and file descriptors.

        Args:
            nprev_comm (int, optional): Number of previous comm channels.
                If not provided, it is determined to be the present number
                of default comms.
            nprev_thread (int, optional): Number of previous threads.
                If not provided, it is determined to be the present number
                of threads.
            nprev_fd (int, optional): Number of previous open file
                descriptors. If not provided, it is determined to be the
                present number of open file descriptors.

        """
        # The same instance may run several tests (see _first_test), so the
        # flag set by the previous teardown must be cleared here.
        self._teardown_complete = False
        self.set_default_comm()
        self.set_utf8_encoding()
        if self.debug_flag:  # pragma: debug
            self.debug_log()
        if nprev_comm is None:
            nprev_comm = self.comm_count
        if nprev_thread is None:
            nprev_thread = self.thread_count
        if nprev_fd is None:
            nprev_fd = self.fd_count
        self.nprev_comm = nprev_comm
        self.nprev_thread = nprev_thread
        self.nprev_fd = nprev_fd

    def teardown(self, ncurr_comm=None, ncurr_thread=None, ncurr_fd=None):
        r"""Check the number of open comms, threads, and file descriptors.

        Args:
            ncurr_comm (int, optional): Number of current comms. If not
                provided, it is determined to be the present number of comms.
            ncurr_thread (int, optional): Number of current threads. If not
                provided, it is determined to be the present number of
                threads.
            ncurr_fd (int, optional): Number of current open file
                descriptors. If not provided, it is determined to be the
                present number of open file descriptors.

        """
        self._teardown_complete = True
        x = CisClass('dummy', timeout=self.timeout, sleeptime=self.sleeptime)
        # Give comms time to close
        if ncurr_comm is None:
            Tout = x.start_timeout()
            while ((not Tout.is_out)
                   and (self.comm_count > self.nprev_comm)):  # pragma: debug
                x.sleep()
            x.stop_timeout()
            ncurr_comm = self.comm_count
        self.assert_less_equal(ncurr_comm, self.nprev_comm)
        # Give threads time to close
        if ncurr_thread is None:
            Tout = x.start_timeout()
            while ((not Tout.is_out)
                   and (self.thread_count > self.nprev_thread)):  # pragma: debug
                x.sleep()
            x.stop_timeout()
            ncurr_thread = self.thread_count
        self.assert_less_equal(ncurr_thread, self.nprev_thread)
        # Give files time to close
        self.cleanup_comms()
        if ncurr_fd is None:
            if not self._first_test:
                Tout = x.start_timeout()
                while ((not Tout.is_out)
                       and (self.fd_count > self.nprev_fd)):  # pragma: debug
                    x.sleep()
                x.stop_timeout()
            ncurr_fd = self.fd_count
        fds_created = ncurr_fd - self.nprev_fd
        # File descriptors are only checked after the first test, which may
        # legitimately open long-lived descriptors.
        if not self._first_test:
            self.assert_equal(fds_created, 0)
        # Reset the log, encoding, and default comm
        self.reset_log()
        self.reset_encoding()
        self.reset_default_comm()
        self._first_test = False

    @property
    def cleanup_comm_classes(self):
        r"""list: Comm classes that should be cleaned up following the test."""
        return [get_default_comm()]

    def cleanup_comms(self):
        r"""Cleanup all comms."""
        for k in self.cleanup_comm_classes:
            cleanup_comms(k)

    @property
    def description_prefix(self):
        r"""String prefix to prepend docstr test message with."""
        return self._description_prefix

    def shortDescription(self):
        r"""Prefix first line of doc string."""
        out = super(CisTestBase, self).shortDescription()
        if self.description_prefix:
            out = '%s: %s' % (self.description_prefix, out)
        return out

    def _timeout_helper(self):
        r"""Return an object providing start_timeout/sleep/stop_timeout.

        unittest.TestCase does not define these methods, so unless a
        subclass supplies them a CisClass configured with this test's
        timeout and sleeptime is used (as in teardown).

        """
        if hasattr(self, 'start_timeout'):
            return self
        return CisClass('dummy', timeout=self.timeout,
                        sleeptime=self.sleeptime)

    def check_file_exists(self, fname):
        r"""Check that a file exists.

        Args:
            fname (str): Full path to the file that should be checked.

        Raises:
            AssertionError: If the file does not appear within 2 seconds.

        """
        timer = self._timeout_helper()
        Tout = timer.start_timeout(2)
        while (not Tout.is_out) and (not os.path.isfile(fname)):  # pragma: debug
            timer.sleep()
        timer.stop_timeout()
        if not os.path.isfile(fname):  # pragma: debug
            raise AssertionError("File '%s' doesn't exist." % fname)

    def check_file_size(self, fname, fsize):
        r"""Check that file is the correct size.

        Args:
            fname (str): Full path to the file that should be checked.
            fsize (int): Size that the file should be in bytes.

        Raises:
            AssertionError: If the size does not match within 2 seconds.

        """
        timer = self._timeout_helper()
        Tout = timer.start_timeout(2)
        if (os.stat(fname).st_size != fsize):  # pragma: debug
            print('file sizes not equal', os.stat(fname).st_size, fsize)
        while ((not Tout.is_out)
               and (os.stat(fname).st_size != fsize)):  # pragma: debug
            timer.sleep()
        timer.stop_timeout()
        if os.stat(fname).st_size != fsize:  # pragma: debug
            raise AssertionError(
                "File size (%d), doesn't match expected size (%d)."
                % (os.stat(fname).st_size, fsize))

    def check_file_contents(self, fname, result):
        r"""Check that the contents of a file are correct.

        Args:
            fname (str): Full path to the file that should be checked.
            result (str): Contents of the file.

        Raises:
            AssertionError: If the contents differ; the message includes a
                line-by-line diff.

        """
        with open(fname, 'r') as fd:
            ocont = fd.read()
        if ocont != result:  # pragma: debug
            # Diff line by line (keeping line endings) rather than
            # character by character.
            odiff = ''.join(difflib.Differ().compare(
                ocont.splitlines(True), result.splitlines(True)))
            raise AssertionError(('File contents do not match expected result. '
                                  'Diff:\n%s') % odiff)

    def check_file(self, fname, result):
        r"""Check that a file exists, is the correct size, and has the
        correct contents.

        Args:
            fname (str): Full path to the file that should be checked.
            result (str): Contents of the file.

        """
        self.check_file_exists(fname)
        self.check_file_size(fname, len(result))
        self.check_file_contents(fname, result)
class CisTestClass(CisTestBase):
    r"""Test case that exercises one CisClass.

    Subclasses set ``_mod`` (absolute module name) and ``_cls`` (class name)
    to select the class under test. An instance of it is created in setup
    and removed again in teardown.

    """

    testing_option_kws = {}
    _mod = None
    _cls = None

    def __init__(self, *args, **kwargs):
        self._inst_args = []
        self._inst_kwargs = {}
        super(CisTestClass, self).__init__(*args, **kwargs)

    def setup(self, *args, **kwargs):
        r"""Run the base setup, then build the instance under test."""
        super(CisTestClass, self).setup(*args, **kwargs)
        self._instance = self.create_instance()

    def teardown(self, *args, **kwargs):
        r"""Drop the instance under test, then run the base teardown."""
        self.clear_instance()
        super(CisTestClass, self).teardown(*args, **kwargs)

    @property
    def description_prefix(self):
        r"""String prefix to prepend docstr test message with."""
        if self.cls is not None:
            return self.cls
        return super(CisTestClass, self).description_prefix

    @property
    def cls(self):
        r"""str: Class to be tested."""
        return self._cls

    @property
    def mod(self):
        r"""str: Absolute name of module containing class to be tested."""
        return self._mod

    @property
    def inst_args(self):
        r"""list: Arguments for creating a class instance."""
        return self._inst_args

    @property
    def inst_kwargs(self):
        r"""dict: Keyword arguments for creating a class instance."""
        return self._inst_kwargs

    @property
    def import_cls(self):
        r"""Import the tested class from its module"""
        if self.mod is None:
            raise Exception("No module registered.")
        if self.cls is None:
            raise Exception("No class registered.")
        return getattr(importlib.import_module(self.mod), self.cls)

    def get_options(self):
        r"""Get testing options."""
        if self.mod is None:  # pragma: debug
            return {}
        tested_class = self.import_cls
        return tested_class.get_testing_options(**self.testing_option_kws)

    @property
    def testing_options(self):
        r"""dict: Testing options (computed once, then cached)."""
        cached = getattr(self, '_testing_options', None)
        if cached is None:
            cached = self.get_options()
            self._testing_options = cached
        return cached

    @property
    def instance(self):
        r"""object: Instance of the test driver."""
        if self._teardown_complete:
            raise RuntimeError("Instance referenced after teardown.")
        if not hasattr(self, '_instance'):  # pragma: debug
            self._instance = self.create_instance()
        return self._instance

    def create_instance(self):
        r"""Build a new instance from inst_args and inst_kwargs."""
        return self.import_cls(*self.inst_args, **self.inst_kwargs)

    def remove_instance(self, inst):
        r"""Hook for disposing of an instance; a no-op by default."""
        pass

    def clear_instance(self):
        r"""Detach the current instance (if any) and pass it to
        remove_instance."""
        if not hasattr(self, '_instance'):
            return
        inst = self._instance
        self._instance = None
        self.remove_instance(inst)
        delattr(self, '_instance')
class IOInfo(object):
    r"""Bundle of field names/units and separators used by IO tests.

    All text attributes are stored as bytes.

    """

    def __init__(self):
        self.field_names = [backwards.as_bytes(s)
                            for s in ('name', 'count', 'size')]
        self.field_units = [backwards.as_bytes(s)
                            for s in ('n/a', 'umol', 'cm')]
        self.nfields = len(self.field_names)
        self.comment = b'# '
        self.delimiter = b'\t'
        self.newline = b'\n'
class CisTestClassInfo(CisTestClass, IOInfo):
    r"""Test class for a CisClass with IOInfo available."""

    def __init__(self, *args, **kwargs):
        # Initialize the CisTestClass/CisTestBase chain first.
        super(CisTestClassInfo, self).__init__(*args, **kwargs)
        # CisTestBase.__init__ may skip its parent initializer entirely
        # (skip_unittest) and does not forward to IOInfo, so the IOInfo
        # attributes are initialized explicitly here.
        IOInfo.__init__(self)
class MagicTestError(Exception):
    r"""Special exception for testing (raised by ErrorClass.error_method)."""
def ErrorClass(base_class, *args, **kwargs):
    r"""Wrapper to return errored version of a class.

    Args:
        base_class (class): Base class to use.
        *args: Additional arguments are passed to the class constructor.
        **kwargs: Additional keyword arguments are passed to the class
            constructor.

    Returns:
        object: An instance of a freshly created subclass of ``base_class``
            whose methods can be temporarily replaced with empty or
            error-raising versions and later restored.

    """

    class ErrorClass(base_class):
        r"""Dummy class whose methods can be swapped for error-raising ones.

        Args:
            error_on_init (bool, optional): If True, an error will be raised
                in place of the base class's __init__ method. Defaults to
                False.
            *args: Additional arguments are passed to the parent class.
            **kwargs: Additional keyword arguments are passed to the parent
                class.

        Attributes:
            _replaced_methods (dict): Original attributes of replaced
                methods, keyed by method name.

        """

        def __init__(self, *args, **kwargs):
            error_on_init = kwargs.pop('error_on_init', False)
            if error_on_init:
                self.error_method()
            self._replaced_methods = dict()
            super(ErrorClass, self).__init__(*args, **kwargs)

        def empty_method(self, *args, **kwargs):
            r"""Method that won't do anything."""
            pass

        def error_method(self, *args, **kwargs):
            r"""Method that will raise a MagicTestError."""
            raise MagicTestError("This is a test error.")

        def getattr(self, attr):
            r"""Get the underlying object for an attribute name."""
            # Look in the instance dict first, then each class in the MRO,
            # so the raw (unbound) attribute is returned.
            for obj in [self] + self.__class__.mro():
                if attr in obj.__dict__:
                    return obj.__dict__[attr]
            raise AttributeError  # pragma: debug

        def setattr(self, attr, value):
            r"""Set the attribute at the class level."""
            setattr(self.__class__, attr, value)

        def replace_method(self, method_name, replacement):
            r"""Temporarily replace method with another.

            Only the first replacement of a method records its original, so
            repeated replacements still restore the true original.

            """
            if method_name not in self._replaced_methods:
                self._replaced_methods[method_name] = self.getattr(method_name)
            self.setattr(method_name, replacement)

        def restore_method(self, method_name):
            r"""Restore the original method."""
            self.setattr(method_name, self._replaced_methods.pop(method_name))

        def restore_all(self):
            r"""Restored all replaced methods."""
            for k in list(self._replaced_methods.keys()):
                self.restore_method(k)

        def empty_replace(self, method_name, **kwargs):
            r"""Replace a method with an empty method."""
            # NOTE(review): replace_method accepts no keyword arguments, so
            # passing any **kwargs here raises TypeError.
            self.replace_method(method_name, self.empty_method, **kwargs)

        def error_replace(self, method_name, **kwargs):
            r"""Replace a method with an errored method."""
            # NOTE(review): replace_method accepts no keyword arguments, so
            # passing any **kwargs here raises TypeError.
            self.replace_method(method_name, self.error_method, **kwargs)

    return ErrorClass(*args, **kwargs)
# Public names exported by ``from cis_interface.tests import *``; every
# entry must be defined in this module or the star import fails.
__all__ = ['data', 'scripts', 'yamls', 'IOInfo', 'ErrorClass', 'CisTestBase',
           'CisTestClass', 'CisTestClassInfo']