import os
import sys
import signal
import uuid
import json
import traceback
import yaml
import glob
import pprint
import functools
import threading
import logging
from yggdrasil import runner
from yggdrasil import platform
from yggdrasil.multitasking import (
wait_on_function, ValueEvent, MemoryTracker)
from yggdrasil.tools import YggClass, kill
from yggdrasil.config import ygg_cfg
_service_host_env = 'YGGDRASIL_SERVICE_HOST_URL'
_service_repo_dir = 'YGGDRASIL_SERVICE_REPO_DIR'
_default_service_type = ygg_cfg.get('services', 'default_type', 'flask')
_default_commtype = ygg_cfg.get('services', 'default_comm', None)
_default_address = ygg_cfg.get('services', 'address', None)
_client_id = ygg_cfg.get('services', 'client_id', None)
if platform._is_win: # pragma: windows
_shutdown_signal = signal.SIGBREAK
else:
_shutdown_signal = signal.SIGTERM
[docs]class ClientError(BaseException):
r"""Error raised by errors when calling the server from the client."""
pass
[docs]class ServerError(BaseException):
r"""Error raised when there was an error on the server."""
pass
[docs]class ServiceBase(YggClass):
r"""Base class for sending/responding to service requests.
Args:
name (str): Name that should be used to initialize an address for the
service.
for_request (bool, optional): If True, a client-side connection is
initialized. If False a server-side connection is initialized.
Defaults to False.
address (str, optional): The address that the service can be accessed
from. Defaults to ('services', 'address') configuration option, if
set, and None if not.
*args: Additional arguments are used to initialize the client/server
connection.
**kwargs: Additional keyword arguments are used to initialize the
client/server connection.
"""
default_address = None
default_port = None
def __init__(self, name, *args, **kwargs):
self.for_request = kwargs.pop('for_request', False)
self.address = kwargs.pop('address', None)
self.port = kwargs.pop('port', None)
if self.address is None:
self.address = _default_address
if self.address is None:
self.address = self.default_address
if self.port is None:
self.port = self.default_port
if isinstance(self.address, str) and ('{port}' in self.address):
self.address = self.address.format(port=self.port)
self._args = args
self._kwargs = kwargs
self._is_running = False
super(ServiceBase, self).__init__(name, *args, **kwargs)
if self.for_request:
self.setup_client(*args, **kwargs)
else:
self.setup_server(*args, **kwargs)
[docs] @classmethod
def is_installed(cls):
r"""bool: True if the class is fully installed, False otherwise."""
return False # pragma: no cover
@property
def is_running(self):
r"""bool: True if the server is running."""
if self.for_request:
return True
else:
return self._is_running
[docs] def wait_for_server(self, timeout=15.0):
r"""Wait for a service to start running.
Args:
timeout (float, optional): Time (in seconds) that should be waited
for the server to start. Defaults to 15.
Raises:
RuntimeError: If the time limit is reached and the server still
hasn't started.
"""
wait_on_function(lambda: self.is_running, timeout=timeout,
on_timeout="Server never started")
[docs] def setup_server(self, *args, **kwargs):
r"""Set up the machinery for receiving requests."""
raise NotImplementedError # pragma: no cover
[docs] def setup_client(self, *args, **kwargs):
r"""Set up the machinery for sending requests."""
raise NotImplementedError # pragma: no cover
[docs] def set_log_level(self, log_level):
r"""Set the logging level.
Args:
log_level (int): Logging level.
"""
import logging
logging.basicConfig(level=log_level)
[docs] def start_server(self, remote_url=None, with_coverage=False,
log_level=None, model_repository=None,
track_memory=False):
r"""Start the server.
Args:
remote_url (str optional): Address for the URL that remote
integrations will use to connect to this server. Defaults to
None and is set based on the YGGDRASIL_SERVICE_HOST_URL
environment variable if it is set and is the local address
otherwise.
with_coverage (bool, optional): If True, the server is started
with coverage. Defaults to False.
log_level (int, optional): Level of log messages that should be
printed. Defaults to None and is ignored.
model_repository (str, optional): URL of directory in a Git
repository containing YAMLs that should be added to the model
registry. Defaults to None and is ignored.
track_memory (boolean, optional): If True, the memory used
by the server will be reported at shutdown. Defaults
to False.
"""
if remote_url is None:
remote_url = os.environ.get(_service_host_env, None)
if remote_url is None:
remote_url = self.address
if model_repository is not None:
repo_dir = self.registry.add_from_repository(model_repository)
os.environ.setdefault(_service_repo_dir, repo_dir)
os.environ.setdefault(_service_host_env, remote_url)
if log_level is not None:
self.set_log_level(log_level)
if with_coverage: # pragma: testing
def handle_shutdown(sig, frame):
self.cleanup_server(track_memory=track_memory)
sys.exit()
signal.signal(_shutdown_signal, handle_shutdown)
# try:
# from pytest_cov.embed import cleanup_on_signal
# cleanup_on_signal(_shutdown_signal)
# except ImportError: # pragma: debug
# pass
self._is_running = True
if track_memory:
track_memory = MemoryTracker(os.getpid())
track_memory.start()
try:
self.run_server()
finally:
self.cleanup_server(track_memory=track_memory)
[docs] def cleanup_server(self, track_memory=False):
r"""Cleanup server process after it finishes.
Args:
track_memory (MemoryTracker, optional): If provided,
print information about memory usage.
"""
if track_memory and self._is_running:
track_memory.terminate()
print(f"Max memory usage: {track_memory.max_memory} MB")
self._is_running = False
[docs] def run_server(self):
r"""Begin listening for requests."""
raise NotImplementedError # pragma: no cover
[docs] def respond(self, request, **kwargs):
r"""Create a response to the request."""
raise NotImplementedError # pragma: no cover
[docs] def shutdown(self, *args, **kwargs):
r"""Shutdown the process from the server."""
raise NotImplementedError # pragma: no cover
[docs] def process_request(self, request, **kwargs):
r"""Process a request and return a response.
Args:
request (str): Serialized request that should be processed.
**kwargs: Additional keyword arguments are passed to the respond
method.
Returns:
str: Serialized response.
"""
request = self.deserialize_request(request)
response = self.respond(request, **kwargs)
return self.serialize_response(response)
[docs] def process_response(self, response):
r"""Process a response.
Args:
response (str): Serialized response that should be processed.
Returns:
object: The deserialized, processed response.
"""
return self.deserialize(response)
[docs] def deserialize(self, msg):
r"""Deserialize a message.
Args:
msg (str): Message to deserialize.
Returns:
object: Deserialized message.
"""
return json.loads(msg)
[docs] def serialize(self, msg):
r"""Serialize a message.
Args:
msg (object): Message to serialize.
Returns:
str: Serialized message.
"""
return json.dumps(msg)
[docs] def deserialize_request(self, request):
r"""Deserialize a request message.
Args:
request (str): Serialized request.
Returns:
object: Deserialized request.
"""
return self.deserialize(request)
[docs] def serialize_response(self, response):
r"""Serialize a response message.
Args:
request (object): Request to serialize.
Returns:
str: Serialized request.
"""
return self.serialize(response)
[docs] def call(self, *args, **kwargs):
r"""Send a request."""
raise NotImplementedError # pragma: no cover
[docs] def send_request(self, request, **kwargs):
r"""Send a request.
Args:
request (object): Request to send.
**kwargs: Additional keyword arguments are passed to the call
method.
Returns:
object: Response.
"""
request_str = self.serialize(request)
assert self.for_request
# if not self.for_request:
# x = self.__class__(self.name, *self._args,
# **self._kwargs, for_request=True,
# address=self.opp_address)
# else:
# x = self
return self.process_response(self.call(request_str, **kwargs))
[docs]class FlaskService(ServiceBase):
r"""Flask based service."""
service_type = 'flask'
default_commtype = 'rest'
default_address = 'http://localhost:{port}'
default_port = int(os.environ.get("PORT", 5000))
[docs] @classmethod
def is_installed(cls):
r"""bool: True if the class is fully installed, False otherwise."""
try:
import flask # noqa: F401
return True
except ImportError: # pragma: debug
return False
def __init__(self, *args, **kwargs):
super(FlaskService, self).__init__(*args, **kwargs)
# if str(self.port) not in self.address:
# parts = self.address.split(':')
# if parts[-1].strip('/').isdigit():
# self.port = int(parts[-1].strip('/'))
if not self.address.endswith('/'):
self.address += '/'
[docs] def setup_server(self, *args, **kwargs):
r"""Set up the machinery for receiving requests.
Args:
*args: Arguments are ignored.
**kwargs: Keyword arguments are ignored.
"""
from flask import Flask
from flask import request
from flask import jsonify
self.queue = {}
self.request = request
self.jsonify = jsonify
self.app = Flask(__name__)
@self.app.route('/' + self.name, methods=['POST'])
def _target(*req_args):
return self.process_request(self.request.json, args=req_args)
[docs] def setup_client(self, *args, **kwargs):
r"""Set up the machinery for sending requests."""
pass
[docs] def set_log_level(self, log_level):
r"""Set the logging level.
Args:
log_level (int): Logging level.
"""
super(FlaskService, self).set_log_level(log_level)
from flask.logging import default_handler
werkzeug_logger = logging.getLogger('werkzeug')
default_handler.setLevel(level=log_level)
self.app.logger.setLevel(level=log_level)
werkzeug_logger.setLevel(level=log_level)
[docs] def run_server(self):
r"""Begin listening for requests."""
self.app.run(host='0.0.0.0', port=self.port)
[docs] def shutdown(self):
r"""Shutdown the process from the server."""
# if not self.for_request:
# func = self.request.environ.get('werkzeug.server.shutdown')
# # Explicitly cleaning up the pytest coverage plugin is required
# # to ensure that the server methods are properly covered during
# # cleanup.
# try:
# from pytest_cov.embed import cleanup
# cleanup()
# except ImportError: # pragma: debug
# pass
# if func is None: # pragma: debug
# raise RuntimeError('Not running with the Werkzeug Server')
# func() # pragma: no cover
pass
[docs] def deserialize(self, msg):
r"""Deserialize a message.
Args:
msg (str): Message to deserialize.
Returns:
object: Deserialized message.
"""
return msg # should already be deserialized
[docs] def serialize(self, msg):
r"""Serialize a message.
Args:
msg (object): Message to serialize.
Returns:
str: Serialized message.
"""
return msg # should already be serialized
[docs] def serialize_response(self, response):
r"""Serialize a response message.
Args:
request (object): Request to serialize.
Returns:
str: Serialized request.
"""
return self.jsonify(response)
[docs] def call(self, request, **kwargs):
r"""Send a request.
Args:
request (object): JSON serializable request.
**kwargs: Keyword arguments are ignored.
Returns:
object: Response.
"""
import requests
try:
r = requests.post(self.address + self.name, json=request)
r.raise_for_status()
except BaseException as e:
raise ClientError(e)
return r.json()
[docs]class RMQService(ServiceBase):
r"""RabbitMQ based service."""
service_type = 'rmq'
default_commtype = 'rmq'
default_port = 5672
[docs] @classmethod
def is_installed(cls):
r"""bool: True if the class is fully installed, False otherwise."""
from yggdrasil.communication.RMQComm import check_rmq_server
return check_rmq_server()
def _init_rmq(self, *args, **kwargs):
from yggdrasil.communication.RMQComm import pika, get_rmq_parameters
self.pika = pika
if not self.address:
kwargs['port'] = self.port
self.address, _, _ = get_rmq_parameters(*args, **kwargs)
self.queue = self.name
# Unclear why using a non-default exchange prevents the server
# from starting
self.exchange = ''
parameters = pika.URLParameters(self.address)
self.connection = pika.BlockingConnection(parameters)
self.channel = self.connection.channel()
[docs] def setup_server(self, *args, **kwargs):
r"""Set up the machinery for receiving requests.
Args:
*args: Arguments are used to initialize the RabbitMQ connections
via _init_rmq.
**kwargs: Keyword arguments are used to initialize the RabbitMQ
connections via _init_rmq.
"""
self._init_rmq(*args, **kwargs)
if self.exchange: # pragma: debug
# self.channel.exchange_declare(exchange=self.exchange,
# auto_delete=True)
raise NotImplementedError("There is a bug when using the "
"non-default exchange.")
self.channel.queue_declare(queue=self.queue)
self.channel.basic_qos(prefetch_count=1)
self.consumer_tag = self.channel.basic_consume(
queue=self.queue,
on_message_callback=self._on_request)
cb = functools.partial(self.shutdown, in_callback=True)
self.channel.add_on_cancel_callback(cb)
[docs] def setup_client(self, *args, **kwargs):
r"""Set up the machinery for sending requests.
Args:
*args: Arguments are used to initialize the RabbitMQ connections
via _init_rmq.
**kwargs: Keyword arguments are used to initialize the RabbitMQ
connections via _init_rmq.
"""
self._init_rmq(*args, **kwargs)
result = self.channel.queue_declare(queue='', exclusive=True)
self.callback_queue = result.method.queue
self.consumer_tag = self.channel.basic_consume(
queue=self.callback_queue,
on_message_callback=self._on_response,
auto_ack=True)
[docs] def run_server(self):
r"""Listen for requests."""
try:
self.channel.start_consuming()
except self.pika.exceptions.ChannelWrongStateError: # pragma: debug
pass
[docs] def shutdown(self, in_callback=False):
r"""Shutdown the process from the server."""
if not self.channel: # pragma: debug
return
if self.for_request:
queue = self.callback_queue
in_callback = False
else:
queue = self.queue
if not in_callback:
self.channel.basic_cancel(consumer_tag=self.consumer_tag)
if not self.for_request:
return
self.channel.queue_delete(queue=queue)
self.channel.close()
self.channel = None
self.connection.close()
self.connection = None
def _on_request(self, ch, method, props, body):
response = self.process_request(body)
ch.basic_publish(exchange=self.exchange,
routing_key=props.reply_to,
properties=self.pika.BasicProperties(
correlation_id=props.correlation_id),
body=response)
ch.basic_ack(delivery_tag=method.delivery_tag)
def _on_response(self, ch, method, props, body):
if self.corr_id == props.correlation_id:
self.response.set(body)
@property
def is_running(self):
r"""bool: True if the server is running."""
return (super(RMQService, self).is_running and bool(self.channel))
[docs] def call(self, request, timeout=10.0, **kwargs):
r"""Send a request.
Args:
request (str): Serialized request.
timeout (float, optional): Time (in seconds) that should be waited
for a response to be returned. Defaults to 10.
**kwargs: Keyword arguments are ignored.
Returns:
str: Serialized response.
"""
self.response = ValueEvent()
self.corr_id = str(uuid.uuid4())
try:
self.channel.basic_publish(exchange=self.exchange,
routing_key=self.queue,
properties=self.pika.BasicProperties(
reply_to=self.callback_queue,
correlation_id=self.corr_id),
body=request)
except (self.pika.exceptions.UnroutableError,
self.pika.exceptions.StreamLostError) as e: # pragma: debug
raise ClientError(e)
def process_events():
self.connection.process_data_events()
return self.response.is_set()
def client_error():
raise ClientError("No response received")
if not self.response.is_set():
wait_on_function(
process_events, timeout=timeout, polling_interval=0.5,
on_timeout=client_error)
return self.response.get()
[docs]def create_service_manager_class(service_type=None):
r"""Create an integration manager service with the specified base.
Args:
service_type (ServiceBase, str, optional): Base class that should be
used. Defaults to ('services', 'default_type') configuration
options, if set, and 'flask' if not.
Returns:
type: Subclass of ServiceBase to handle starting/stopping integrations
running as services.
"""
if service_type is None:
service_type = _default_service_type
if isinstance(service_type, str):
cls_map = {'flask': FlaskService, 'rmq': RMQService}
service_type = cls_map[service_type]
class IntegrationServiceManager(service_type):
r"""Manager to track running integrations.
Args:
name (str): Name that should be used to initialize an address for
the service. Defaults to 'ygg_integrations'.
commtype (str, optional): Communicator type that should be used
for the connections to services. Defaults to ('services',
'default_comm') configuration option, if set, and None if not.
is_app (bool, optional): If True, the service manager will be run
as an app and will not be expected to be shut down by clients.
Defaults to False.
**kwargs: Additional keyword arguments are passed to the __init__
method for the service_type class.
"""
def __init__(self, name=None, commtype=None, is_app=False, **kwargs):
if name is None:
name = 'ygg_integrations'
self.integrations = {}
self.stopped_integrations = {}
self.registry = IntegrationServiceRegistry()
if commtype is None:
commtype = _default_commtype
self.commtype = commtype
self.is_app = is_app
super(IntegrationServiceManager, self).__init__(name, **kwargs)
if self.commtype is None:
self.commtype = self.default_commtype
@property
def client_id(self):
r"""str: The ID that should be associated with a client on the
current machine. Defaults to the configuration entry
('services', 'client_id') if it is set and is generated otherwise.
"""
global _client_id
if _client_id is None:
_client_id = str(uuid.uuid4())
return _client_id
def send_request(self, name=None, yamls=None, action='start', **kwargs):
r"""Send a request.
Args:
name (str, tuple, optional): A hashable object that will be
used to reference the integration. If not provided,
the yamls keyword will be used.
yamls (list, str, optional): One or more YAML files defining
a network of models to run as a service. Defaults to None.
action (str, optional): Action that is being requested.
Defaults to 'start'.
**kwargs: Additional keyword arguments are included in the
request.
"""
if isinstance(yamls, str):
yamls = [yamls]
if name is None:
name = yamls
request = dict(kwargs, name=name, yamls=yamls, action=action)
request.setdefault('client_id', self.client_id)
wait_for_complete = ((action in ['start', 'stop', 'shutdown'])
and (service_type != RMQService))
out = super(IntegrationServiceManager, self).send_request(request)
if wait_for_complete and (out['status'] != 'complete'):
def is_complete():
out.update(
super(IntegrationServiceManager, self).send_request(
request))
return (out['status'] == 'complete')
wait_on_function(
is_complete, timeout=30, polling_interval=0.5,
on_timeout=f"Request did not complete: {request}")
return out
def setup_server(self, *args, **kwargs):
r"""Set up the machinery for receiving requests."""
super(IntegrationServiceManager, self).setup_server(*args, **kwargs)
if service_type == FlaskService:
@self.app.route('/')
def landing_page():
from flask import render_template
import yaml
kwargs = {
'address': self.address,
'available': {
k: yaml.dump(v).splitlines()
for k, v in self.registry.registry.items()},
'running': {
k: {k2: v2.printStatus(return_str=True).splitlines()
for k2, v2 in v.items()}
for k, v in self.integrations.items()}}
out = render_template(
'service_manager_index.html', **kwargs)
return out
from yggdrasil.communication import RESTComm
RESTComm.add_comm_server_to_app(self.app)
def stop_server(self):
r"""Stop the server from the client-side."""
assert self.for_request
try:
response = self.send_request(action='shutdown')
except ClientError: # pragma: debug
return
if response.get('pid', None):
kill(response['pid'], _shutdown_signal)
self.shutdown()
def start_integration(self, client_id, x, yamls, **kwargs):
r"""Start an integration if it is not already running.
Args:
client_id (str): ID associated with the client requesting the
integration be started.
x (str, tuple): Hashable object that should be used to
identify the integration being started in the registry of
running integrations.
yamls (list): Set of YAML specification files defining the
integration that should be run as as service.
**kwargs: Additional keyword arguments are passed to get_runner.
Returns:
bool: True if the integration started, False otherwise.
"""
integrations = self.integrations[client_id]
stopped_integrations = self.stopped_integrations[client_id]
if (x in integrations) and (not integrations[x].is_alive):
if not self.stop_integration(client_id, x):
return False
if x in stopped_integrations:
if not self.stop_integration(client_id, x):
return False
stopped_integrations.pop(x)
if x not in integrations:
partial_commtype = {'commtype': self.commtype}
if self.commtype == 'rest':
partial_commtype['client_id'] = client_id
integrations[x] = runner.get_runner(
yamls, complete_partial=True, as_service=True,
partial_commtype=partial_commtype, **kwargs)
integrations[x].run(signal_handler=False)
return True
def _stop_integration(self, client_id, x):
r"""Finish stopping an integration in a thread."""
m = self.integrations[client_id].pop(x)
if m.is_alive:
m.terminate()
m.atexit()
def stop_integration(self, client_id, x):
r"""Stop a running integration.
Args:
client_id (str): ID associated with the client requesting the
integration be stopped.
x (str, tuple): Hashable object associated with the
integration service that should be stopped. If None,
all of the running integrations are stopped.
Returns:
bool: True if the integration has stopped.
Raises:
KeyError: If there is not a running integration associated
with the specified hashable object.
"""
integrations = self.integrations[client_id]
stopped_integrations = self.stopped_integrations[client_id]
if x is None:
return all(self.stop_integration(client_id, k)
for k in (list(integrations.keys())
+ list(stopped_integrations.keys())))
if x in stopped_integrations:
pass
elif x not in integrations:
raise KeyError(f"Integration defined by {x} not running")
elif service_type == RMQService:
self._stop_integration(client_id, x)
return True
else:
mthread = threading.Thread(target=self._stop_integration,
args=(client_id, x,), daemon=True)
mthread.start()
stopped_integrations[x] = mthread
return not stopped_integrations[x].is_alive()
def integration_info(self, client_id, x):
r"""Get information about an integration and how to connect to it.
Args:
client_id (str): ID associated with the client requesting the
integration info.
x (str, tuple): Hashable object associated with the
integration to get information on.
Returns:
dict: Information about the integration.
Raises:
KeyError: If there is not a running integration associated
with the specified hashable object.
"""
integrations = self.integrations[client_id]
if x not in integrations: # pragma: debug
raise KeyError(f"Integration defined by {x} not running")
m = integrations[x].modeldrivers['dummy_model']
out = m['instance'].service_partner
name = 'dummy'
if isinstance(x, str) and (not os.path.isfile(x)):
name = x
out.update(name=name,
args=name,
language='dummy')
return out
@property
def is_running(self):
r"""bool: True if the server is running."""
if not super(IntegrationServiceManager, self).is_running:
return False
if self.for_request:
try:
response = self.send_request(action='ping')
return response['status'] == 'running'
except ClientError:
return False
else: # pragma: debug
# This would only occur if a server calls is_running while it
# is running (perhaps in a future callback?)
return True
def respond(self, request, **kwargs):
r"""Create a response to the request.
Args:
request (dict): Request to respond to.
**kwargs: Additional keyword arguments are ignored.
Returns:
dict: Response to the request.
"""
name = None
action = None
yamls = None
client_id = None
try:
name = request.pop('name')
action = request.pop('action')
yamls = request.pop('yamls')
client_id = request.pop('client_id')
if client_id is not None:
self.integrations.setdefault(client_id, {})
self.stopped_integrations.setdefault(client_id, {})
if isinstance(name, list):
name = tuple(name)
if action == 'start':
if not yamls:
reg = self.registry.registry.get(name, None)
if isinstance(name, tuple):
yamls = list(name)
elif isinstance(name, str) and os.path.isfile(name):
yamls = [name]
elif reg is not None:
yamls = reg['yamls']
for k, v in reg.items():
if k not in ['name', 'yamls']:
request.setdefault(k, v)
else: # pragma: debug
raise RuntimeError("No YAML files specified.")
if self.start_integration(client_id, name, yamls,
**request):
response = dict(
self.integration_info(client_id, name),
status='complete')
else:
response = {'status': 'starting'}
elif action == 'stop':
if self.stop_integration(client_id, name):
response = {'status': 'complete'}
else:
response = {'status': 'stopping'}
elif action == 'shutdown':
if self.stop_integration(client_id, None):
if self.is_app: # pragma: no cover
response = {'status': 'complete'}
else:
self.shutdown()
response = {'status': 'complete',
'pid': os.getpid()}
else:
response = {'status': 'shutting down'}
elif action == 'status':
response = {'status': 'done'}
if name is None:
fmt = ('Address: %s\n'
'Available Services:\n%s\n'
'Running Services:\n%s')
registry_str = '\t' + '\n\t'.join(
pprint.pformat(self.registry.registry).splitlines())
if client_id is None:
clients = list(self.integrations.keys())
else:
clients = [client_id]
running_str = ''
for cli in clients:
running_str += '\tClient %s\n' % cli
for k, v in self.integrations[cli].items():
running_str += '\t\t%s:\n\t\t\t%s' % (
k, '\n\t\t\t'.join(v.printStatus(
return_str=True).splitlines()))
args = (self.address, registry_str, running_str)
response['status'] = fmt % args
else:
assert client_id is not None
response['status'] = (
self.integrations[client_id][name].printStatus(
return_str=True))
elif action == 'ping':
response = {'status': 'running'}
else:
raise RuntimeError(f"Unsupported action: '{action}'")
except BaseException as e:
tb = traceback.format_exc()
response = {'error': str(e), 'traceback': tb}
if action == 'start': # pragma: intermittent
self.respond({'name': name, 'action': 'stop', 'yamls': None,
'client_id': client_id})
return response
def process_response(self, response):
r"""Process a response.
Args:
response (str): Serialized response that should be processed.
Returns:
object: The deserialized, processed response.
Raises:
ServerError: If the response indicates there was an error on
the server during the creation of the response.
"""
response = super(
IntegrationServiceManager, self).process_response(response)
if 'error' in response:
raise ServerError('%s\n%s' % (response['traceback'],
response['error']))
return response
def printStatus(self, level='info', return_str=False):
r"""Print the status of the service manager including available
and running services."""
status = self.send_request(action='status')
if return_str:
msg, _ = self.logger.process(status['status'], {})
return msg
getattr(self.logger, level)(status['status'])
return IntegrationServiceManager
[docs]def IntegrationServiceManager(service_type=None, **kwargs):
r"""Start a management service to track running integrations.
Args:
service_type (ServiceBase, str, optional): Base class that should be
used. Defaults to ('services', 'default_type') configuration
options, if set, and 'flask' if not. If there is an address
provided, the service type will be determined by parsing the
address.
**kwargs: Additional keyword arguments are used to intialized the
manager class instance.
"""
if service_type is None:
if kwargs.get('address', None):
if kwargs['address'].startswith('amqp://'):
service_type = 'rmq'
else:
service_type = 'flask'
cls = create_service_manager_class(service_type=service_type)
return cls(**kwargs)
[docs]class IntegrationServiceRegistry(object):
r"""Class for managing integration services.
Args:
filename (str, optional): File where the registry will be/is stored.
Defaults to '~/.yggdrasil_services.yml'.
"""
def __init__(self, filename=os.path.join('~', '.yggdrasil_services.yml')):
self.filename = os.path.expanduser(filename)
@property
def registry(self):
r"""dict: Existing registry of integrations."""
return self.load()
[docs] def load(self):
r"""Load the dictionary of existing integrations that have been
registered.
Returns:
dict: Existing registry of integrations.
"""
out = {}
if os.path.isfile(self.filename):
with open(self.filename, 'r') as fd:
out = yaml.safe_load(fd)
return out
[docs] def save(self, registry):
r"""Save the registry to self.filename.
Args:
registry (dict): Dictionary of integrations to save.
"""
with open(self.filename, 'w') as fd:
yaml.dump(registry, fd)
[docs] def load_collection(self, name):
r"""Read a collection of integration registry entries from an YAML.
Args:
name (str): Full path to a YAML file containing one or more
registry entries mapping between integration name and YAML
specification files.
Returns:
dict: Loaded registry entries.
"""
with open(name, 'r') as fd:
out = yaml.safe_load(fd.read())
assert isinstance(out, dict)
base_dir = os.path.dirname(name)
for k in out.keys():
v = out.get(k, [])
out[k] = []
for x in v:
if not os.path.isabs(x):
x = os.path.join(base_dir, x)
out[k].append(x)
return out
[docs] def remove(self, name):
r"""Remove an integration service from the registry.
Args:
name (str): Name associated with the integration service that
should be removed from the registry.
Raises:
KeyError: If there is not an integration service associated with
the specified name.
"""
registry = self.load()
if os.path.isfile(name):
names = list(self.load_collection(name).keys())
else:
names = [name]
for k in names:
if k not in registry:
keys = list(self.registry.keys())
raise KeyError(f"There is not an integration service "
f"registered under the name '{k}'. Existing "
f"services are {keys}")
registry.pop(k)
self.save(registry)
[docs] def add_from_repository(self, model_repository, directory=None):
r"""Add integration services to the registry from a repository of
model YAMLs.
Args:
model_repository (str): URL of directory in a Git repository
containing YAMLs that should be added to the model registry.
directory (str, optional): Directory where services from the
model_repository should be cloned. Defaults to
'~/.yggdrasil_service'.
Returns:
str: The directory where the repositories were cloned.
"""
from yggdrasil.yamlfile import clone_github_repo, prep_yaml
if directory is None:
directory = os.path.expanduser(
os.path.join('~', '.yggdrasil_services'))
yaml_dir = clone_github_repo(model_repository,
local_directory=directory)
yaml_files = (glob.glob(os.path.join(yaml_dir, '*.yaml'))
+ glob.glob(os.path.join(yaml_dir, '*.yml')))
for x in yaml_files:
# Calling prep_yaml allows the model repositories to be cloned
# in advance to circumvent th hold place on git cloning on the
# service manager (these models are assumed to be vetted so
# they do not pose a security risk).
prep_yaml(x, directory_for_clones=directory)
self.add(os.path.splitext(os.path.basename(x))[0], x)
return directory
[docs] def add(self, name, yamls=None, **kwargs):
r"""Add an integration service to the registry.
Args:
name (str): Name that will be used to access the integration
service when starting or stopping it.
yamls (str, list): Set of one or more YAML specification files
defining the integration.
**kwargs: Additional keyword arguments are added to the new entry.
Raises:
ValueError: If there is already an integration with the specified
name.
"""
registry = self.load()
if os.path.isfile(name):
assert not yamls
collection = {k: dict(kwargs, name=k, yamls=v)
for k, v in self.load_collection(name).items()}
else:
assert yamls
collection = {name: dict(kwargs, name=name, yamls=yamls)}
for k, v in collection.items():
if (k in registry) and (registry[k] != v):
old = pprint.pformat(registry[k])
new = pprint.pformat(v)
raise ValueError(f"There is an registry integration "
f"associated with the name '{k}'. Remove "
f"the registry entry before adding a new "
f"one.\n"
f" Registry:\n{old}\n New:\n{new}")
registry[k] = v
self.save(registry)
[docs]def validate_model_submission(fname):
r"""Validate a YAML file according to the standards for submission to
the yggdrasil model repository.
Args:
fname (str): YAML file to validate or directory in which to check
each of the YAML files.
"""
from yggdrasil import yamlfile, runner
if isinstance(fname, list):
for x in fname:
validate_model_submission(x)
return
elif os.path.isdir(fname):
files = sorted(glob.glob(os.path.join(fname, '*.yml'))
+ glob.glob(os.path.join(fname, '*.yaml')))
for x in files:
validate_model_submission(x)
return
# 1-2. YAML syntax and schema
yml = yamlfile.parse_yaml(fname, model_submission=True)
# 3a. LICENSE
repo_dir = yml['models'][0]['working_dir']
patterns = ['LICENSE', 'LICENSE.*']
for x in patterns:
if ((glob.glob(os.path.join(repo_dir, x.upper()))
or glob.glob(os.path.join(repo_dir, x.lower())))):
break
else:
raise RuntimeError("Model repository does not contain a LICENSE file.")
# 4. Run & validate
runner.run(fname, validate=True)