Source code for yggdrasil.drivers.PyTorchModelDriver

import os
from yggdrasil.drivers.DSLModelDriver import DSLModelDriver
from yggdrasil import rapidjson
# TODO: Allow model to be trained by input and return weights?


[docs]class PyTorchModelDriver(DSLModelDriver): r"""Class for handling PyTorch models.""" _schema_subtype_description = 'Model is a PyTorch model' _schema_required = ['weights'] _schema_properties = { 'weights': { 'type': 'string', 'description': ('Path to file where model weights ' 'are saved')}, 'input_transform': { 'type': 'function', 'description': ('Transformation that should be applied to ' 'input to get it into the format expected by ' 'the model (including transformation to ' 'pytorch tensors as necessary). This ' 'function should return a tuple of ' 'arguments for the model.')}, 'output_transform': { 'type': 'function', 'description': ('Transformation that should be applied to ' 'model output to get it into a format that ' 'can be serialized by yggdrasil (i.e. not ' 'a pytorch Tensor or model sepecific type).')}, } language = 'pytorch' language_ext = '.py' # '.pth' interface_dependencies = ['torch']
[docs] @classmethod def language_version(cls, **kwargs): r"""Determine the version of this language. Args: **kwargs: Keyword arguments are passed to cls.run_executable. Returns: str: Version of compiler/interpreter for this language. """ try: import torch return torch.__version__ except ImportError: # pragma: debug raise RuntimeError("roadrunner not installed.")
@property def model_wrapper_args(self): r"""tuple: Positional arguments for the model wrapper.""" return (self.model_file, self.weights, ) @property def model_wrapper_kwargs(self): r"""dict: Keyword arguments for the model wrapper.""" out = super(PyTorchModelDriver, self).model_wrapper_kwargs out.update( {'inputs': self.inputs, 'outputs': self.outputs, 'working_dir': self.working_dir, 'input_transform': self.input_transform, 'output_transform': self.output_transform}) return out
[docs] @classmethod def model_wrapper(cls, model_file, weights_file, inputs=[], outputs=[], env=None, working_dir=None, input_transform=None, output_transform=None): r"""Model wrapper.""" import torch from yggdrasil.languages.Python.YggInterface import ( YggInput, YggOutput) if env is not None: os.environ.update(env) if working_dir is not None: os.chdir(working_dir) # Create input/output comms input_map = {} output_map = {} input_vars = [] output_vars = [] for x in inputs: input_map[x['name']] = { 'vars': [v['name'] for v in x.get('vars', [])], 'comm': YggInput(x['name'], new_process=True)} if not input_map[x['name']]['vars']: input_map[x['name']]['vars'].append(x['name']) input_vars += input_map[x['name']]['vars'] for x in outputs: x_vars = [v['name'] for v in x.get('vars', [])] output_map[x['name']] = { 'as_array': x.get('as_array', False), 'vars': x_vars, 'comm': YggOutput(x['name'], new_process=True, field_names=x_vars)} if not output_map[x['name']]['vars']: output_map[x['name']]['vars'].append(x['name']) output_vars += output_map[x['name']]['vars'] # Create model model = rapidjson.normalize(model_file, {'type': 'class'})() model.load_state_dict(torch.load(weights_file)) model.eval() while True: flag = False values = {} for k, v in input_map.items(): flag, value = v['comm'].recv_dict(key_order=v['vars']) if not flag: print(f"No more input from {k}") break values.update(value) if not flag: break args = [values[k] for k in input_vars] if input_transform: args = input_transform(*args) output = model(*args) if output_transform: output = output_transform(output) if len(output_vars) == 1: output = [output] output = {k: v for k, v in zip(output_vars, output)} for k, v in output_map.items(): iout = {ik: output[ik] for ik in v['vars']} flag = v['comm'].send_dict(iout, key_order=v['vars']) if not flag: # pragma: debug raise RuntimeError(f"Error sending to {k}")
[docs] @classmethod def get_testing_options(cls, **kwargs): r"""Method to return a dictionary of testing options for this class. Returns: dict: Dictionary of variables to use for testing. Key/value pairs: kwargs (dict): Keyword arguments for driver instance. deps (list): Dependencies to install. """ out = super(PyTorchModelDriver, cls).get_testing_options() out['kwargs']['weights'] = 'pytorch_model_weights.pth' out['requires_partner'] = True return out