import os
import uuid
from cis_interface import backwards
from cis_interface.tests import CisTestClassInfo, assert_equal
from cis_interface.communication import new_comm, get_comm, CommBase
def test_registry():
    r"""Test registration/unregistration of comms in the CommBase registry.

    A dummy entry (value ``None``) is registered under ``'CommBase'``. The
    test then checks that ``get_comm_registry``, ``is_registered`` and
    ``unregister_comm`` agree before registration, while registered, and
    after each unregistration. The entry is always removed at the end, even
    if an assertion fails, so the global registry is not left polluted for
    other tests.
    """
    comm_class = 'CommBase'
    key = 'key1'
    value = None
    assert(not CommBase.is_registered(comm_class, key))
    # Unregistering a key that was never registered returns a falsy value
    assert(not CommBase.unregister_comm(comm_class, key))
    assert_equal(CommBase.get_comm_registry(None), {})
    assert_equal(CommBase.get_comm_registry(comm_class), {})
    try:
        CommBase.register_comm(comm_class, key, value)
        assert(key in CommBase.get_comm_registry(comm_class))
        assert(CommBase.is_registered(comm_class, key))
        assert(not CommBase.unregister_comm(comm_class, key, dont_close=True))
        # The entry must actually be gone after unregistering
        assert(not CommBase.is_registered(comm_class, key))
        CommBase.register_comm(comm_class, key, value)
        assert(not CommBase.unregister_comm(comm_class, key))
        assert(not CommBase.is_registered(comm_class, key))
    finally:
        # Best-effort cleanup so a failed assertion above does not leave the
        # dummy entry in the process-wide registry.
        CommBase.unregister_comm(comm_class, key, dont_close=True)
class TestCommBase(CisTestClassInfo):
    r"""Tests for CommBase communication class.

    Attributes:
        comm (str): Name of the comm class under test.
        attr_list (list): Attributes that both halves of the comm pair must
            expose.
        send_inst_kwargs (dict): Keyword arguments for send half of the comm
            pair.

    """

    comm = 'CommBase'
    attr_list = ['name', 'address', 'direction',
                 'serializer', 'recv_timeout',
                 'close_on_eof_recv', 'opp_address', 'opp_comms',
                 'maxMsgSize']

    @property
    def cleanup_comm_classes(self):
        r"""set: Comm classes that should be cleaned up following the test."""
        return {self.comm, self.send_inst_kwargs['comm']}

    @property
    def name(self):
        r"""str: Name of the test connection."""
        return 'Test%s_%s' % (self.cls, self.uuid)

    @property
    def cls(self):
        r"""str: Communication class."""
        return self.comm

    @property
    def mod(self):
        r"""str: Absolute module import."""
        return 'cis_interface.communication.%s' % self.cls

    @property
    def is_installed(self):
        r"""bool: Is the communication class installed (for Python)."""
        return self.import_cls.is_installed(language='python')

    @property
    def send_inst_kwargs(self):
        r"""dict: Keyword arguments for send instance.

        A new dictionary is built on every access, so callers may add or
        change keys without affecting later accesses.
        """
        out = {'comm': self.comm, 'reverse_names': True, 'direction': 'send'}
        out.update(self.testing_options['kwargs'])
        return out

    @property
    def inst_args(self):
        r"""list: Arguments for tested class."""
        return [self.name]

    @property
    def inst_kwargs(self):
        r"""dict: Keyword arguments for tested class (the receive half),
        derived from the send instance's ``opp_comm_kwargs``."""
        return self.send_instance.opp_comm_kwargs()

    @property
    def recv_instance(self):
        r"""Alias for instance."""
        return self.instance

    @property
    def maxMsgSize(self):
        r"""int: Maximum message size."""
        return self.instance.maxMsgSize

    @property
    def test_msg_array(self):
        r"""object: Message for array send/recv tests, or None if
        ``msg_array`` is not present in the testing options."""
        return self.testing_options.get('msg_array', None)

    @property
    def test_msg(self):
        r"""object: Test message that should be used for any send/recv
        tests."""
        return self.testing_options['msg']

    @property
    def msg_long(self):
        r"""object: Long test message for nolimit send/recv tests.

        Byte messages are padded with ``maxMsgSize`` zero bytes so they
        exceed the maximum message size; messages of any other type are
        returned unchanged.
        """
        out = self.test_msg
        if isinstance(out, backwards.bytes_type):
            out += (self.maxMsgSize * b'0')
        return out

    def setup(self, *args, **kwargs):
        r"""Initialize comm object pair.

        The send comm is created here with ``new_comm``. The parent class's
        setup is then called, which presumably creates ``self.instance``
        (aliased as ``recv_instance``) via ``create_instance``.

        Args:
            *args: Passed to the parent class's setup.
            sleep_after_connect (bool, optional): If True, call
                ``send_instance.sleep()`` once both comms exist. Defaults to
                False. Removed from kwargs before they are forwarded.
            **kwargs: Passed to the parent class's setup. ``nprev_comm`` and
                ``nprev_fd`` default to the comm and file descriptor counts
                taken before the send comm is created.

        """
        assert(self.is_installed)
        sleep_after_connect = kwargs.pop('sleep_after_connect', False)
        send_inst_kwargs = self.send_inst_kwargs
        kwargs.setdefault('nprev_comm', self.comm_count)
        kwargs.setdefault('nprev_fd', self.fd_count)
        self.send_instance = new_comm(self.name, **send_inst_kwargs)
        super(TestCommBase, self).setup(*args, **kwargs)
        if sleep_after_connect:
            self.send_instance.sleep()
        # CommBase is dummy class that never opens; AsyncComm is likewise
        # expected to still be closed at this point.
        if self.comm in ['CommBase', 'AsyncComm']:
            assert(not self.send_instance.is_open)
            assert(not self.recv_instance.is_open)
        else:
            assert(self.send_instance.is_open)
            assert(self.recv_instance.is_open)

    def teardown(self, *args, **kwargs):
        r"""Destroy comm object pair."""
        self.remove_instance(self.send_instance)
        super(TestCommBase, self).teardown(*args, **kwargs)

    def create_instance(self):
        r"""Create a new instance of the class (the receive comm).

        Returns:
            The comm returned by ``get_comm``, checked to be an instance of
            ``import_cls``.

        """
        inst = get_comm(*self.inst_args, **self.inst_kwargs)
        assert(isinstance(inst, self.import_cls))
        return inst

    def remove_instance(self, inst):
        r"""Close an instance, check it is closed, then remove it."""
        inst.close()
        assert(inst.is_closed)
        super(TestCommBase, self).remove_instance(inst)

    def get_fresh_error_instance(self, recv=False):
        r"""Get a new comm pair where one half is built with ErrorComm.

        Args:
            recv (bool, optional): If True, the receive comm is created with
                ``new_comm_class='ErrorComm'`` and ``base_comm`` set to the
                tested comm; otherwise the send comm is. Defaults to False.

        Returns:
            tuple: (send comm, recv comm). The caller must close both.

        """
        send_kwargs = self.send_inst_kwargs
        err_kwargs = dict(base_comm=send_kwargs['comm'],
                          new_comm_class='ErrorComm')
        err_name = self.name + '_' + self.uuid
        if not recv:
            send_kwargs.update(**err_kwargs)
        send_inst = new_comm(err_name, **send_kwargs)
        try:
            recv_kwargs = send_inst.opp_comm_kwargs()
            recv_kwargs['comm'] = send_kwargs['comm']
            if recv:
                recv_kwargs.update(**err_kwargs)
            recv_inst = new_comm(err_name, **recv_kwargs)
        except BaseException:
            # Don't leak the already-created send comm if the recv half
            # cannot be created.
            send_inst.close()
            raise
        return send_inst, recv_inst

    def test_empty_obj_recv(self):
        r"""Test identification of empty message."""
        msg = self.instance.empty_obj_recv
        assert(self.instance.is_empty_recv(msg))
        assert(not self.instance.is_empty_recv(self.instance.eof_msg))
        if self.recv_instance.recv_converter is None:
            # Repeat the checks with an identity converter installed
            self.recv_instance.recv_converter = lambda x: x
            msg = self.instance.empty_obj_recv
            assert(self.instance.is_empty_recv(msg))
            assert(not self.instance.is_empty_recv(self.instance.eof_msg))

    def test_error_name(self):
        r"""Test error on missing address."""
        self.assert_raises(RuntimeError, self.import_cls, 'test%s' % uuid.uuid4())

    def test_error_send(self):
        r"""Test that send returns False when the underlying send errors.

        Both comms are restored and closed even if the assertion fails.
        """
        send_inst, recv_inst = self.get_fresh_error_instance()
        try:
            # NOTE(review): presumably marks the first send as done so the
            # replaced 'send_multipart' is reached directly -- confirm.
            send_inst._first_send_done = True
            send_inst.error_replace('send_multipart')
            flag = send_inst.send(self.test_msg)
            assert(not flag)
        finally:
            send_inst.restore_all()
            send_inst.close()
            recv_inst.close()

    def test_error_recv(self):
        r"""Test that recv returns a False flag when the underlying recv
        errors.

        Both comms are restored and closed even if the assertion fails.
        """
        send_inst, recv_inst = self.get_fresh_error_instance(recv=True)
        try:
            recv_inst.error_replace('recv_multipart')
            flag, _ = recv_inst.recv()
            assert(not flag)
        finally:
            recv_inst.restore_all()
            send_inst.close()
            recv_inst.close()

    def test_send_recv_after_close(self):
        r"""Test that opening twice doesn't cause errors and that send/recv
        after close returns False."""
        self.send_instance.open()
        self.recv_instance.open()
        if self.comm in ['RMQComm', 'RMQAsyncComm']:
            self.send_instance.bind()
            self.recv_instance.bind()
        self.send_instance.close()
        self.recv_instance.close()
        assert(self.send_instance.is_closed)
        assert(self.recv_instance.is_closed)
        flag = self.send_instance.send(self.test_msg)
        assert(not flag)
        flag, msg_recv = self.recv_instance.recv()
        assert(not flag)

    def test_attributes(self):
        r"""Assert that the instance has all of the required attributes."""
        for a in self.attr_list:
            if not hasattr(self.send_instance, a):  # pragma: debug
                raise AttributeError("Send comm does not have attribute %s" % a)
            if not hasattr(self.recv_instance, a):  # pragma: debug
                raise AttributeError("Recv comm does not have attribute %s" % a)
            # Access each attribute so property errors surface here
            getattr(self.send_instance, a)
            getattr(self.recv_instance, a)
        self.instance.debug('maxMsgSize: %d, %d, %d', self.maxMsgSize,
                            self.send_instance.maxMsgSize,
                            self.recv_instance.maxMsgSize)
        self.instance.opp_comm_kwargs()
        # File-based comms report installed for any language; others do not
        if self.import_cls.is_file:
            assert(self.import_cls.is_installed(language='invalid'))
        else:
            assert(not self.import_cls.is_installed(language='invalid'))

    def test_invalid_direction(self):
        r"""Check that error raised for invalid direction."""
        kwargs = self.send_inst_kwargs
        kwargs['direction'] = 'invalid'
        self.assert_raises(ValueError, new_comm, self.name + "_" + self.uuid,
                           **kwargs)

    def test_work_comm(self):
        r"""Test creating/removing a work comm.

        A send work comm is created from the instance. A matching recv work
        comm is then created through ``create_work_comm`` with an explicit
        ``work_comm_name``; the send comm's ``opp_address`` is stored in an
        environment variable of that name (presumably read during creation).
        The environment variable is removed again so it does not leak into
        other tests.
        """
        wc_send = self.instance.create_work_comm()
        # add_work_comm on a comm the instance already created raises
        self.assert_raises(KeyError, self.instance.add_work_comm, wc_send)
        # Create recv instance in way that tests new_comm
        header_recv = dict(id=self.uuid + '1', address=wc_send.address)
        recv_kwargs = self.instance.get_work_comm_kwargs
        recv_kwargs['work_comm_name'] = 'test_worker_%s' % header_recv['id']
        recv_kwargs['new_comm_class'] = wc_send.comm_class
        env_name = recv_kwargs['work_comm_name']
        os.environ[env_name] = wc_send.opp_address
        try:
            wc_recv = self.instance.create_work_comm(**recv_kwargs)
            if self.comm in ['CommBase', 'AsyncComm']:
                flag = wc_send.send(self.test_msg)
                assert(not flag)
                flag, msg_recv = wc_recv.recv()
                assert(not flag)
            else:
                flag = wc_send.send(self.test_msg)
                assert(flag)
                flag, msg_recv = wc_recv.recv(self.timeout)
                assert(flag)
                self.assert_equal(msg_recv, self.test_msg)
            # A further recv on the work comm is expected to raise
            self.assert_raises(RuntimeError, wc_recv.recv)
            self.instance.remove_work_comm(wc_send.uuid)
            self.instance.remove_work_comm(wc_recv.uuid)
            # Removing an already-removed work comm must not raise
            self.instance.remove_work_comm(wc_recv.uuid)
        finally:
            os.environ.pop(env_name, None)
        # Create work comm that should be cleaned up on teardown
        self.instance.create_work_comm()

    def map_sent2recv(self, obj):
        r"""Convert a sent object into a received one (identity here)."""
        return obj

    def assert_msg_equal(self, x, y):
        r"""Assert that two messages are equivalent.

        Args:
            x: Received message.
            y: Sent/expected message; mapped through ``map_sent2recv``
                unless it is the send comm's EOF message.

        """
        if not (isinstance(y, type(self.send_instance.eof_msg))
                and (y == self.send_instance.eof_msg)):
            y = self.map_sent2recv(y)
        self.assert_equal(x, y)

    def do_send_recv(self, send_meth='send', recv_meth='recv',
                     msg_send=None, msg_recv=None,
                     n_msg_send_meth='n_msg_send', n_msg_recv_meth='n_msg_recv',
                     reverse_comms=False, send_kwargs=None, recv_kwargs=None,
                     n_send=1, n_recv=1,
                     close_on_send_eof=None, close_on_recv_eof=None):
        r"""Generic send/recv of a message.

        Args:
            send_meth (str, optional): Name of the send method. If it
                contains 'eof', the EOF message is sent and no positional
                argument is passed. Defaults to 'send'.
            recv_meth (str, optional): Name of the recv method. Defaults to
                'recv'.
            msg_send (object, optional): Message to send. Defaults to the
                send comm's eof_msg for EOF, otherwise ``test_msg``.
            msg_recv (object, optional): Expected received message. Defaults
                to ``msg_send``.
            n_msg_send_meth (str, optional): Attribute giving the count of
                outgoing messages. Defaults to 'n_msg_send'.
            n_msg_recv_meth (str, optional): Attribute giving the count of
                incoming messages. Defaults to 'n_msg_recv'.
            reverse_comms (bool, optional): If True, send from the recv comm
                and receive on the send comm. Defaults to False.
            send_kwargs (dict, optional): Extra keyword arguments for send.
            recv_kwargs (dict, optional): Extra keyword arguments for recv.
            n_send (int, optional): Number of sends. Defaults to 1.
            n_recv (int, optional): Number of receives. Defaults to 1.
            close_on_send_eof (bool, optional): Override for the sending
                comm's close_on_eof_send. None keeps the current value.
            close_on_recv_eof (bool, optional): Override for the receiving
                comm's close_on_eof_recv. None keeps the current value.

        For CommBase/AsyncComm, a single send and recv are expected to fail
        and ``n_send``/``n_recv`` are not used.
        """
        tkey = 'do_send_recv'
        is_eof = ('eof' in send_meth)
        if msg_send is None:
            if is_eof:
                msg_send = self.send_instance.eof_msg
            else:
                msg_send = self.test_msg
        if msg_recv is None:
            msg_recv = msg_send
        if send_kwargs is None:
            send_kwargs = dict()
        if recv_kwargs is None:
            recv_kwargs = dict()
        if is_eof:
            send_args = tuple()
        else:
            send_args = (msg_send,)
        self.assert_equal(getattr(self.send_instance, n_msg_send_meth), 0)
        self.assert_equal(getattr(self.recv_instance, n_msg_recv_meth), 0)
        if reverse_comms:
            send_instance = self.recv_instance
            recv_instance = self.send_instance
        else:
            send_instance = self.send_instance
            recv_instance = self.recv_instance
        if close_on_recv_eof is None:
            close_on_recv_eof = recv_instance.close_on_eof_recv
        if close_on_send_eof is None:
            close_on_send_eof = send_instance.close_on_eof_send
        recv_instance.close_on_eof_recv = close_on_recv_eof
        send_instance.close_on_eof_send = close_on_send_eof
        # ForkComm wraps several comms; propagate the EOF settings to each
        if self.comm == 'ForkComm':
            for x in recv_instance.comm_list:
                x.close_on_eof_recv = close_on_recv_eof
            for x in send_instance.comm_list:
                x.close_on_eof_send = close_on_send_eof
        fsend_meth = getattr(send_instance, send_meth)
        frecv_meth = getattr(recv_instance, recv_meth)
        if self.comm in ['CommBase', 'AsyncComm']:
            flag = fsend_meth(*send_args, **send_kwargs)
            assert(not flag)
            flag, msg_recv0 = frecv_meth(**recv_kwargs)
            assert(not flag)
            if self.comm == 'CommBase':
                self.assert_raises(NotImplementedError, self.recv_instance._send,
                                   self.test_msg)
                self.assert_raises(NotImplementedError, self.recv_instance._recv)
        else:
            for i in range(n_send):
                flag = fsend_meth(*send_args, **send_kwargs)
                assert(flag)
            # Wait for messages to be received
            for i in range(n_recv):
                if not is_eof:
                    T = recv_instance.start_timeout(self.timeout, key_suffix=tkey)
                    while ((not T.is_out) and (not recv_instance.is_closed)
                           and (getattr(recv_instance,
                                        n_msg_recv_meth) == 0)):  # pragma: debug
                        recv_instance.sleep()
                    recv_instance.stop_timeout(key_suffix=tkey)
                    # Only require at least one waiting message: IPC nolimit
                    # sends multiple messages
                    assert(getattr(recv_instance, n_msg_recv_meth) >= 1)
                flag, msg_recv0 = frecv_meth(timeout=self.timeout, **recv_kwargs)
                if is_eof and close_on_recv_eof:
                    assert(not flag)
                    assert(recv_instance.is_closed)
                else:
                    assert(flag)
                    self.assert_msg_equal(msg_recv0, msg_recv)
            # Wait for send to close
            if is_eof and close_on_send_eof:
                T = send_instance.start_timeout(self.timeout, key_suffix=tkey)
                while (not T.is_out) and (not send_instance.is_closed):  # pragma: debug
                    send_instance.sleep()
                send_instance.stop_timeout(key_suffix=tkey)
                assert(send_instance.is_closed)
        # Make sure no messages outgoing
        T = send_instance.start_timeout(self.timeout, key_suffix=tkey)
        while ((not T.is_out) and (getattr(send_instance,
                                           n_msg_send_meth) != 0)):  # pragma: debug
            send_instance.sleep()
        send_instance.stop_timeout(key_suffix=tkey)
        # Print status of comms
        send_instance.printStatus()
        recv_instance.printStatus()
        # Confirm receipt of messages
        if not (is_eof or reverse_comms):
            send_instance.wait_for_confirm(timeout=self.timeout)
            recv_instance.wait_for_confirm(timeout=self.timeout)
            assert(send_instance.is_confirmed)
            assert(recv_instance.is_confirmed)
            send_instance.confirm(noblock=True)
            recv_instance.confirm(noblock=True)
        self.assert_equal(getattr(send_instance, n_msg_send_meth), 0)
        self.assert_equal(getattr(recv_instance, n_msg_recv_meth), 0)

    def test_drain_messages(self):
        r"""Test waiting for messages to drain."""
        self.send_instance.drain_messages(timeout=self.timeout)
        self.assert_equal(self.send_instance.n_msg_send_drain, 0)
        if not self.recv_instance.is_file:
            self.recv_instance.drain_messages(timeout=self.timeout)
            self.assert_equal(self.recv_instance.n_msg_recv_drain, 0)
        self.assert_raises(ValueError, self.send_instance.drain_messages,
                           variable='n_msg_invalid')
        self.assert_raises(ValueError, self.recv_instance.drain_messages,
                           variable='n_msg_invalid')

    def test_recv_nomsg(self):
        r"""Test receive when there is no waiting message."""
        flag, msg_recv = self.recv_instance.recv(timeout=self.sleeptime)
        if self.comm in ['CommBase', 'AsyncComm']:
            assert(not flag)
        else:
            assert(flag)
        assert(not msg_recv)

    def test_send_recv(self):
        r"""Test send/recv of a small message."""
        self.do_send_recv()

    def test_send_recv_nolimit(self):
        r"""Test send/recv of a large message."""
        assert(len(self.msg_long) > self.maxMsgSize)
        self.do_send_recv('send_nolimit', 'recv_nolimit', self.msg_long)

    def test_send_recv_array(self):
        r"""Test send/recv of a array message."""
        msg_send = getattr(self, 'test_msg_array', None)
        self.do_send_recv('send_array', 'recv_array', msg_send=msg_send)

    def test_eof(self):
        r"""Test send/recv of EOF message."""
        self.do_send_recv(send_meth='send_eof')

    def test_eof_no_close(self):
        r"""Test send/recv of EOF message with no close."""
        self.do_send_recv(send_meth='send_eof', close_on_recv_eof=False)

    def test_purge(self, nrecv=1):
        r"""Test purging messages from the comm.

        Args:
            nrecv (int, optional): Number of messages the recv comm is
                expected to hold after one send. Defaults to 1.

        """
        self.assert_equal(self.send_instance.n_msg, 0)
        self.assert_equal(self.recv_instance.n_msg, 0)
        # Purge recv while open
        if self.comm not in ['CommBase', 'AsyncComm']:
            flag = self.send_instance.send(self.test_msg)
            assert(flag)
            T = self.recv_instance.start_timeout()
            while ((not T.is_out) and (self.recv_instance.n_msg
                                       != nrecv)):  # pragma: debug
                self.recv_instance.sleep()
            self.recv_instance.stop_timeout()
            self.assert_greater(self.recv_instance.n_msg, 0)
        self.recv_instance.purge()
        # send_instance.n_msg is not checked: uni-directional comms can't
        # know about messages sent
        self.assert_equal(self.recv_instance.n_msg, 0)
        # Purge recv while closed
        self.recv_instance.close()
        self.recv_instance.purge()

    def test_send_recv_dict(self):
        r"""Test send/recv message as dict, with and without field order."""
        msg_send = self.testing_options['dict']
        self.do_send_recv(send_meth='send_dict', recv_meth='recv_dict',
                          msg_send=msg_send)
        field_order = self.testing_options.get('field_names', None)
        if field_order is not None:
            self.do_send_recv(send_meth='send_dict', recv_meth='recv_dict',
                              msg_send=msg_send,
                              send_kwargs={'field_order': field_order},
                              recv_kwargs={'field_order': field_order})