neural-amp-modeler

Neural network emulator for guitar amplifiers

recurrent.py (10600B)


# File: recurrent.py
# Created Date: Saturday July 2nd 2022
# Author: Steven Atkinson (steven@atkinson.mn)

"""
Recurrent models (LSTM)

TODO batch_first=False (I get it...)
"""

import abc as _abc
import json as _json
from pathlib import Path as _Path
from tempfile import TemporaryDirectory as _TemporaryDirectory
from typing import Optional as Optional, Tuple as _Tuple

import numpy as _np
import torch as _torch
import torch.nn as _nn

from .base import BaseNet as _BaseNet


class _L(_nn.LSTM):
    """
    Tweaks to the PyTorch LSTM module
    * Up the remembering
    """

    def reset_parameters(self) -> None:
        super().reset_parameters()
        # https://danijar.com/tips-for-training-recurrent-neural-networks/
        # forget += 1
        # ifgo
        value = 2.0
        idx_input = slice(0, self.hidden_size)
        idx_forget = slice(self.hidden_size, 2 * self.hidden_size)
        for layer in range(self.num_layers):
            for prefix in ("i", "h"):
                # Lower the input gate as well to balance out the scale of the cell state.
                getattr(self, f"bias_{prefix}h_l{layer}").data[idx_input] -= value
                getattr(self, f"bias_{prefix}h_l{layer}").data[idx_forget] += value


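# Illustrative sketch (added for exposition; not part of the upstream module).
# PyTorch packs the four LSTM gates in (input, forget, cell, output) order along
# dim 0 of each bias/weight tensor, which is why the slices above pick out the
# first and second hidden_size-sized blocks. The helper below just prints the
# per-gate bias means so the nudge applied in reset_parameters() is visible.
def _demo_forget_gate_bias(hidden_size: int = 8) -> None:
    lstm = _L(1, hidden_size, batch_first=True)
    bias = lstm.bias_ih_l0.data
    idx_input = slice(0, hidden_size)
    idx_forget = slice(hidden_size, 2 * hidden_size)
    # Expect the input-gate block to sit around -2 and the forget-gate block
    # around +2 relative to PyTorch's default (roughly zero-mean) initialization.
    print(bias[idx_input].mean().item(), bias[idx_forget].mean().item())

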
# State:
# L: Number of LSTM layers
# DH: Hidden state dimension
# [0]: hidden (L,DH)
# [1]: cell (L,DH)
_LSTMHiddenType = _torch.Tensor
_LSTMCellType = _torch.Tensor
_LSTMHiddenCellType = _Tuple[_LSTMHiddenType, _LSTMCellType]


# TODO get this somewhere more core-ish
class _ExportsWeights(_abc.ABC):
    @_abc.abstractmethod
    def export_weights(self) -> _np.ndarray:
        """
        :return: a 1D array of weights
        """
        pass


class _Linear(_nn.Linear, _ExportsWeights):
    def export_weights(self):
        return _np.concatenate(
            [
                self.weight.data.detach().cpu().numpy().flatten(),
                self.bias.data.detach().cpu().numpy().flatten(),
            ]
        )


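# Minimal sketch (illustrative only; not part of the upstream module): the head
# exports its weight matrix first, then its bias, as one flat array, so a
# hidden_size -> 1 head contributes hidden_size + 1 values to the weight blob.
def _demo_head_export(hidden_size: int = 4) -> None:
    head = _Linear(hidden_size, 1)
    exported = head.export_weights()
    assert exported.shape == (hidden_size + 1,)

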
class LSTM(_BaseNet):
    """
    LSTM model; also serves as the base class for other recurrent architectures
    (e.g. CatLSTM).
    """

    def __init__(
        self,
        hidden_size,
        train_burn_in: Optional[int] = None,
        train_truncate: Optional[int] = None,
        input_size: int = 1,
        sample_rate: Optional[float] = None,
        **lstm_kwargs,
    ):
        """
        :param hidden_size: Hidden state dimension of the LSTM.
        :param train_burn_in: Detach the calculations from the first (this many)
            samples when training, in order to burn in the hidden state.
        :param train_truncate: Detach the hidden & cell states every this many steps
            during training so that backpropagation through time is faster and so
            that the model sees more realistic starting states for h(t0) & c(t0)
            than all zeros.
            TODO recognition head to start the hidden state in a good place?
        :param input_size: Usually 1 (mono input). A catnet extending this might
            change it and provide the parametric inputs as additional input
            dimensions.
        :param sample_rate: Sample rate that the model expects; passed through to the
            BaseNet parent class.
        """
        super().__init__(sample_rate=sample_rate)
        if "batch_first" in lstm_kwargs:
            raise ValueError("batch_first cannot be set.")
        self._input_size = input_size
        self._core = _L(self._input_size, hidden_size, batch_first=True, **lstm_kwargs)
        self._head = self._init_head(hidden_size)
        self._train_burn_in = train_burn_in
        self._train_truncate = train_truncate
        self._initial_cell = _nn.Parameter(
            _torch.zeros((lstm_kwargs.get("num_layers", 1), hidden_size))
        )
        self._initial_hidden = _nn.Parameter(
            _torch.zeros((lstm_kwargs.get("num_layers", 1), hidden_size))
        )
        self._get_initial_state_burn_in = 48_000

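    # Example construction (illustrative only; the hyperparameter values below are
    # made up, not recommendations from this module):
    #
    #   model = LSTM(
    #       hidden_size=16,
    #       train_burn_in=4096,
    #       train_truncate=512,
    #       num_layers=1,
    #       sample_rate=48_000.0,
    #   )
    #
    # Extra keyword arguments (like num_layers) are forwarded to the underlying
    # torch.nn.LSTM; batch_first is always True here and cannot be overridden.
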
    @property
    def input_device(self) -> _torch.device:
        """
        What device does the input need to be on?
        """
        return self._core.bias_ih_l0.device

    @property
    def receptive_field(self) -> int:
        return 1

    @property
    def pad_start_default(self) -> bool:
        # I should simplify this...
        return True

    def export_cpp_header(self, filename: _Path):
        with _TemporaryDirectory() as tmpdir:
            tmpdir = _Path(tmpdir)
            LSTM.export(self, _Path(tmpdir))  # Hacky... need to work w/ CatLSTM
            with open(_Path(tmpdir, "model.nam"), "r") as fp:
                _c = _json.load(fp)
            version = _c["version"]
            config = _c["config"]
            s_parametric = self._export_cpp_header_parametric(config.get("parametric"))
            with open(filename, "w") as f:
                f.writelines(
                    (
                        "#pragma once\n",
                        "// Automatically-generated model file\n",
                        "#include <vector>\n",
                        '#include "json.hpp"\n',
                        '#include "lstm.h"\n',
                        f'#define PYTHON_MODEL_VERSION "{version}"\n',
                        f'const int NUM_LAYERS = {config["num_layers"]};\n',
                        f'const int INPUT_SIZE = {config["input_size"]};\n',
                        f'const int HIDDEN_SIZE = {config["hidden_size"]};\n',
                    )
                    + s_parametric
                    + (
                        "std::vector<float> PARAMS{"
                        + ", ".join([f"{w:.16f}f" for w in _c["weights"]])
                        + "};\n",
                    )
                )

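    # Usage sketch (illustrative; the path is hypothetical): export_cpp_header
    # writes a standalone C++ header whose constants mirror _export_config()
    # (NUM_LAYERS, INPUT_SIZE, HIDDEN_SIZE) plus the flattened weights as a
    # std::vector<float> PARAMS.
    #
    #   model.export_cpp_header(_Path("model_data.h"))
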
    def _apply_head(self, features: _torch.Tensor) -> _torch.Tensor:
        """
        :param features: (B,S,DH)
        :return: (B,S)
        """
        return self._head(features)[:, :, 0]

    def _forward(
        self, x: _torch.Tensor, initial_state: Optional[_LSTMHiddenCellType] = None
    ) -> _torch.Tensor:
        """
        :param x: (B,L) or (B,L,D)
        :param initial_state: Optional (hidden, cell) state to start from instead of
            the learned initial state.
        :return: (B,L)
        """

        def process_in_blocks(x, hidden_state=None):
            # Run the (possibly very long) sequence through the LSTM in fixed-size
            # blocks, carrying the hidden state between blocks.
            # See: https://github.com/sdatkinson/neural-amp-modeler/issues/450
            BLOCK_SIZE = 65_535
            outputs = []
            for i in range(0, x.shape[1], BLOCK_SIZE):
                out, hidden_state = self._core(
                    x[:, i : i + BLOCK_SIZE, :], hidden_state
                )
                outputs.append(out)
            return _torch.cat(outputs, dim=1), hidden_state  # assert batch_first

        last_hidden_state = (
            self._initial_state(len(x)) if initial_state is None else initial_state
        )
        if x.ndim == 2:
            x = x[:, :, None]
        if not self.training or self._train_truncate is None:
            output_features = process_in_blocks(x, last_hidden_state)[0]
        else:
            output_features_list = []
            if self._train_burn_in is not None:
                last_output_features, last_hidden_state = process_in_blocks(
                    x[:, : self._train_burn_in, :], last_hidden_state
                )
                output_features_list.append(last_output_features.detach())
            burn_in_offset = 0 if self._train_burn_in is None else self._train_burn_in
            for i in range(burn_in_offset, x.shape[1], self._train_truncate):
                if i > burn_in_offset:
                    # Detach between truncation windows, but not before the first
                    # window, so that the learned initial state still gets gradients.
                    last_hidden_state = tuple(z.detach() for z in last_hidden_state)
                last_output_features, last_hidden_state = process_in_blocks(
                    x[:, i : i + self._train_truncate, :], last_hidden_state
                )
                output_features_list.append(last_output_features)
            output_features = _torch.cat(output_features_list, dim=1)
        return self._apply_head(output_features)

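    # Illustrative walk-through of the truncation schedule above (numbers are made
    # up). With train_burn_in=4096 and train_truncate=512 on an 8192-sample batch:
    #
    #   * samples [0, 4096) are run first and their outputs detached (burn-in);
    #   * then windows [4096, 4608), [4608, 5120), ... are run in turn;
    #   * the hidden/cell states are detached between windows (but not before the
    #     first window), so later windows backpropagate at most train_truncate
    #     steps, while the first window also backpropagates through the burn-in
    #     into the learned initial state.
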
    def _export_cell_weights(
        self, i: int, hidden_state: _torch.Tensor, cell_state: _torch.Tensor
    ) -> _np.ndarray:
        """
        * weight matrix (xh -> ifco)
        * bias vector
        * Initial hidden state
        * Initial cell state
        """

        tensors = [
            _torch.cat(
                [
                    getattr(self._core, f"weight_ih_l{i}").data,
                    getattr(self._core, f"weight_hh_l{i}").data,
                ],
                dim=1,
            ),
            getattr(self._core, f"bias_ih_l{i}").data
            + getattr(self._core, f"bias_hh_l{i}").data,
            hidden_state,
            cell_state,
        ]
        return _np.concatenate([z.detach().cpu().numpy().flatten() for z in tensors])

    def _export_config(self):
        return {
            "input_size": self._core.input_size,
            "hidden_size": self._core.hidden_size,
            "num_layers": self._core.num_layers,
        }

    def _export_cpp_header_parametric(self, config):
        # TODO refactor to merge w/ WaveNet implementation
        if config is not None:
            raise ValueError("Got non-None parametric config")
        return ("nlohmann::json PARAMETRIC {};\n",)

    def _export_weights(self):
        """
        * Loop over cells:
            * weight matrix (xh -> ifco)
            * bias vector
            * Initial hidden state
            * Initial cell state
        * Head weights
        * Head bias
        """
        return _np.concatenate(
            [
                self._export_cell_weights(i, h, c)
                for i, (h, c) in enumerate(zip(*self._get_initial_state()))
            ]
            + [self._head.export_weights()]
        )

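    # Size-check sketch (illustrative; relies only on standard PyTorch LSTM shapes):
    # for layer i with input width DX_i (DX_0 = input_size, DX_i = hidden_size for
    # the later layers) and hidden size DH, the exported block contains
    #
    #   4*DH*(DX_i + DH)   concatenated weight matrix (xh -> gates)
    # + 4*DH               summed bias vector
    # + DH + DH            initial hidden and cell states
    #
    # and the head appends DH weights + 1 bias at the very end.
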
    def _get_initial_state(self, inputs=None) -> _LSTMHiddenCellType:
        """
        Convenience function to find a good hidden state to start the plugin at.

        DX=input size
        L=num layers
        S=sequence length
        :param inputs: (1,S,DX)

        :return: (L,DH), (L,DH)
        """
        inputs = (
            _torch.zeros((1, self._get_initial_state_burn_in, 1))
            if inputs is None
            else inputs
        ).to(self.input_device)
        _, (h, c) = self._core(inputs)
        return h, c

    def _init_head(self, hidden_size: int) -> _ExportsWeights:
        return _Linear(hidden_size, 1)

    def _initial_state(self, n: Optional[int]) -> _LSTMHiddenCellType:
        """
        Literally what the forward pass starts with.
        A zero state would be the default; this one should be better since it can be
        learned.
        """
        return (
            (self._initial_hidden, self._initial_cell)
            if n is None
            else (
                _torch.tile(self._initial_hidden[:, None], (1, n, 1)),
                _torch.tile(self._initial_cell[:, None], (1, n, 1)),
            )
        )
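

# Usage sketch (added for exposition; not part of the upstream module). The
# hyperparameters and tensor sizes are made up for illustration, and in practice
# the model is driven through the BaseNet training/export APIs rather than by
# calling _forward() directly.
def _example_usage() -> None:
    model = LSTM(hidden_size=16, num_layers=1, train_burn_in=4096, train_truncate=512)
    # Learned initial state, tiled for a batch of 2: hidden and cell are each
    # shaped (num_layers, batch, hidden_size).
    h, c = model._initial_state(2)
    assert h.shape == (1, 2, 16) and c.shape == (1, 2, 16)
    # A batch of two short mono signals; the model produces one output sample per
    # input sample.
    x = _torch.zeros((2, 8192))
    y = model._forward(x)
    assert y.shape == (2, 8192)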