recurrent.py (10600B)
1 # File: recurrent.py 2 # Created Date: Saturday July 2nd 2022 3 # Author: Steven Atkinson (steven@atkinson.mn) 4 5 """ 6 Recurrent models (LSTM) 7 8 TODO batch_first=False (I get it...) 9 """ 10 11 import abc as _abc 12 import json as _json 13 from pathlib import Path as _Path 14 from tempfile import TemporaryDirectory as _TemporaryDirectory 15 from typing import Optional as Optional, Tuple as _Tuple 16 17 import numpy as _np 18 import torch as _torch 19 import torch.nn as _nn 20 21 from .base import BaseNet as _BaseNet 22 23 24 class _L(_nn.LSTM): 25 """ 26 Tweaks to PyTorch LSTM module 27 * Up the remembering 28 """ 29 30 def reset_parameters(self) -> None: 31 super().reset_parameters() 32 # https://danijar.com/tips-for-training-recurrent-neural-networks/ 33 # forget += 1 34 # ifgo 35 value = 2.0 36 idx_input = slice(0, self.hidden_size) 37 idx_forget = slice(self.hidden_size, 2 * self.hidden_size) 38 for layer in range(self.num_layers): 39 for input in ("i", "h"): 40 # Balance out the scale of the cell w/ a -=1 41 getattr(self, f"bias_{input}h_l{layer}").data[idx_input] -= value 42 getattr(self, f"bias_{input}h_l{layer}").data[idx_forget] += value 43 44 45 # State: 46 # L: Number of LSTM layers 47 # DH: Hidden state dimension 48 # [0]: hidden (L,DH) 49 # [1]: cell (L,DH) 50 _LSTMHiddenType = _torch.Tensor 51 _LSTMCellType = _torch.Tensor 52 _LSTMHiddenCellType = _Tuple[_LSTMHiddenType, _LSTMCellType] 53 54 55 # TODO get this somewhere more core-ish 56 class _ExportsWeights(_abc.ABC): 57 @_abc.abstractmethod 58 def export_weights(self) -> _np.ndarray: 59 """ 60 :return: a 1D array of weights 61 """ 62 pass 63 64 65 class _Linear(_nn.Linear, _ExportsWeights): 66 def export_weights(self): 67 return _np.concatenate( 68 [ 69 self.weight.data.detach().cpu().numpy().flatten(), 70 self.bias.data.detach().cpu().numpy().flatten(), 71 ] 72 ) 73 74 75 class LSTM(_BaseNet): 76 """ 77 ABC for recurrent architectures 78 """ 79 80 def __init__( 81 self, 82 hidden_size, 83 train_burn_in: Optional[int] = None, 84 train_truncate: Optional[int] = None, 85 input_size: int = 1, 86 sample_rate: Optional[float] = None, 87 **lstm_kwargs, 88 ): 89 """ 90 :param hidden_size: for LSTM 91 :param train_burn_in: Detach calculations from first (this many) samples when 92 training to burn in the hidden state. 93 :param train_truncate: detach the hidden & cell states every this many steps 94 during training so that backpropagation through time is faster + to simulate 95 better starting states for h(t0)&c(t0) (instead of zeros) 96 TODO recognition head to start the hidden state in a good place? 97 :param input_size: Usually 1 (mono input). A catnet extending this might change 98 it and provide the parametric inputs as additional input dimensions. 99 """ 100 super().__init__(sample_rate=sample_rate) 101 if "batch_first" in lstm_kwargs: 102 raise ValueError("batch_first cannot be set.") 103 self._input_size = input_size 104 self._core = _L(self._input_size, hidden_size, batch_first=True, **lstm_kwargs) 105 self._head = self._init_head(hidden_size) 106 self._train_burn_in = train_burn_in 107 self._train_truncate = train_truncate 108 self._initial_cell = _nn.Parameter( 109 _torch.zeros((lstm_kwargs.get("num_layers", 1), hidden_size)) 110 ) 111 self._initial_hidden = _nn.Parameter( 112 _torch.zeros((lstm_kwargs.get("num_layers", 1), hidden_size)) 113 ) 114 self._get_initial_state_burn_in = 48_000 115 116 @property 117 def input_device(self) -> _torch.device: 118 """ 119 What device does the input need to be on? 120 """ 121 return self._core.bias_ih_l0.device 122 123 @property 124 def receptive_field(self) -> int: 125 return 1 126 127 @property 128 def pad_start_default(self) -> bool: 129 # I should simplify this... 130 return True 131 132 def export_cpp_header(self, filename: _Path): 133 with _TemporaryDirectory() as tmpdir: 134 tmpdir = _Path(tmpdir) 135 LSTM.export(self, _Path(tmpdir)) # Hacky...need to work w/ CatLSTM 136 with open(_Path(tmpdir, "model.nam"), "r") as fp: 137 _c = _json.load(fp) 138 version = _c["version"] 139 config = _c["config"] 140 s_parametric = self._export_cpp_header_parametric(config.get("parametric")) 141 with open(filename, "w") as f: 142 f.writelines( 143 ( 144 "#pragma once\n", 145 "// Automatically-generated model file\n", 146 "#include <vector>\n", 147 '#include "json.hpp"\n', 148 '#include "lstm.h"\n', 149 f'#define PYTHON_MODEL_VERSION "{version}"\n', 150 f'const int NUM_LAYERS = {config["num_layers"]};\n', 151 f'const int INPUT_SIZE = {config["input_size"]};\n', 152 f'const int HIDDEN_SIZE = {config["hidden_size"]};\n', 153 ) 154 + s_parametric 155 + ( 156 "std::vector<float> PARAMS{" 157 + ", ".join([f"{w:.16f}f" for w in _c["weights"]]) 158 + "};\n", 159 ) 160 ) 161 162 def _apply_head(self, features: _torch.Tensor) -> _torch.Tensor: 163 """ 164 :param features: (B,S,DH) 165 :return: (B,S) 166 """ 167 return self._head(features)[:, :, 0] 168 169 def _forward( 170 self, x: _torch.Tensor, initial_state: Optional[_LSTMHiddenCellType] = None 171 ) -> _torch.Tensor: 172 """ 173 :param x: (B,L) or (B,L,D) 174 :return: (B,L) 175 """ 176 177 def process_in_blocks(x, hidden_state=None): 178 # See: https://github.com/sdatkinson/neural-amp-modeler/issues/450 179 BLOCK_SIZE = 65_535 180 outputs = [] 181 for i in range(0, x.shape[1], BLOCK_SIZE): 182 out, hidden_state = self._core( 183 x[:, i : i + BLOCK_SIZE, :], hidden_state 184 ) 185 outputs.append(out) 186 return _torch.cat(outputs, dim=1), hidden_state # assert batch_first 187 188 last_hidden_state = ( 189 self._initial_state(len(x)) if initial_state is None else initial_state 190 ) 191 if x.ndim == 2: 192 x = x[:, :, None] 193 if not self.training or self._train_truncate is None: 194 output_features = process_in_blocks(x, last_hidden_state)[0] 195 else: 196 output_features_list = [] 197 if self._train_burn_in is not None: 198 last_output_features, last_hidden_state = process_in_blocks( 199 x[:, : self._train_burn_in, :], last_hidden_state 200 ) 201 output_features_list.append(last_output_features.detach()) 202 burn_in_offset = 0 if self._train_burn_in is None else self._train_burn_in 203 for i in range(burn_in_offset, x.shape[1], self._train_truncate): 204 if i > burn_in_offset: 205 # Don't detach the burn-in state so that we can learn it. 206 last_hidden_state = tuple(z.detach() for z in last_hidden_state) 207 last_output_features, last_hidden_state = process_in_blocks( 208 x[:, i : i + self._train_truncate, :], last_hidden_state 209 ) 210 output_features_list.append(last_output_features) 211 output_features = _torch.cat(output_features_list, dim=1) 212 return self._apply_head(output_features) 213 214 def _export_cell_weights( 215 self, i: int, hidden_state: _torch.Tensor, cell_state: _torch.Tensor 216 ) -> _np.ndarray: 217 """ 218 * weight matrix (xh -> ifco) 219 * bias vector 220 * Initial hidden state 221 * Initial cell state 222 """ 223 224 tensors = [ 225 _torch.cat( 226 [ 227 getattr(self._core, f"weight_ih_l{i}").data, 228 getattr(self._core, f"weight_hh_l{i}").data, 229 ], 230 dim=1, 231 ), 232 getattr(self._core, f"bias_ih_l{i}").data 233 + getattr(self._core, f"bias_hh_l{i}").data, 234 hidden_state, 235 cell_state, 236 ] 237 return _np.concatenate([z.detach().cpu().numpy().flatten() for z in tensors]) 238 239 def _export_config(self): 240 return { 241 "input_size": self._core.input_size, 242 "hidden_size": self._core.hidden_size, 243 "num_layers": self._core.num_layers, 244 } 245 246 def _export_cpp_header_parametric(self, config): 247 # TODO refactor to merge w/ WaveNet implementation 248 if config is not None: 249 raise ValueError("Got non-None parametric config") 250 return ("nlohmann::json PARAMETRIC {};\n",) 251 252 def _export_weights(self): 253 """ 254 * Loop over cells: 255 * weight matrix (xh -> ifco) 256 * bias vector 257 * Initial hidden state 258 * Initial cell state 259 * Head weights 260 * Head bias 261 """ 262 return _np.concatenate( 263 [ 264 self._export_cell_weights(i, h, c) 265 for i, (h, c) in enumerate(zip(*self._get_initial_state())) 266 ] 267 + [self._head.export_weights()] 268 ) 269 270 def _get_initial_state(self, inputs=None) -> _LSTMHiddenCellType: 271 """ 272 Convenience function to find a good hidden state to start the plugin at 273 274 DX=input size 275 L=num layers 276 S=sequence length 277 :param inputs: (1,S,DX) 278 279 :return: (L,DH), (L,DH) 280 """ 281 inputs = ( 282 _torch.zeros((1, self._get_initial_state_burn_in, 1)) 283 if inputs is None 284 else inputs 285 ).to(self.input_device) 286 _, (h, c) = self._core(inputs) 287 return h, c 288 289 def _init_head(self, hidden_size: int) -> _ExportsWeights: 290 return _Linear(hidden_size, 1) 291 292 def _initial_state(self, n: Optional[int]) -> _LSTMHiddenCellType: 293 """ 294 Literally what the forward pass starts with. 295 Default is zeroes; this should be better since it can be learned. 296 """ 297 return ( 298 (self._initial_hidden, self._initial_cell) 299 if n is None 300 else ( 301 _torch.tile(self._initial_hidden[:, None], (1, n, 1)), 302 _torch.tile(self._initial_cell[:, None], (1, n, 1)), 303 ) 304 )