base.py (9430B)
# File: _base.py
# Created Date: Tuesday February 8th 2022
# Author: Steven Atkinson (steven@atkinson.mn)

"""
The foundation of the model without the PyTorch Lightning attributes (losses, training
steps)
"""

import abc as _abc
import math as _math
import pkg_resources as _pkg_resources
from typing import (
    Any as _Any,
    Dict as _Dict,
    Optional as _Optional,
    Tuple as _Tuple,
    Union as _Union,
)

import numpy as _np
import torch as _torch
import torch.nn as _nn

from .._core import InitializableFromConfig as _InitializableFromConfig
from ..data import wav_to_tensor as _wav_to_tensor
from .exportable import Exportable as _Exportable


class _Base(_nn.Module, _InitializableFromConfig, _Exportable):
    """
    Common base for models: keeps an optional sample rate in buffers and provides
    export and metadata (loudness / gain) helpers.
    """

    def __init__(self, sample_rate: _Optional[float] = None):
        """
        :param sample_rate: Sample rate of the model, if known.
        """
        super().__init__()
        # Stored as (persistent) buffers so the sample rate is part of the state dict.
        self.register_buffer(
            "_has_sample_rate",
            _torch.tensor(sample_rate is not None, dtype=_torch.bool),
        )
        self.register_buffer(
            "_sample_rate", _torch.tensor(0.0 if sample_rate is None else sample_rate)
        )

    @property
    @_abc.abstractmethod
    def pad_start_default(self) -> bool:
        """
        Whether `forward()` pads the start of the input when the caller doesn't say.
        """
        pass

    @property
    @_abc.abstractmethod
    def receptive_field(self) -> int:
        """
        Receptive field of the model
        """
        pass

    @_abc.abstractmethod
    def forward(self, *args, **kwargs) -> _torch.Tensor:
        pass

    @classmethod
    def _metadata_loudness_x(cls) -> _torch.Tensor:
        """
        Standardized input signal used to measure loudness, loaded from the packaged
        `loudness_input.wav` resource.
        """
        return _wav_to_tensor(
            _pkg_resources.resource_filename(
                "nam", "models/_resources/loudness_input.wav"
            )
        )

    @property
    def device(self) -> _Optional[_torch.device]:
        """
        Helpful property, where the parameters of the model live.

        :return: The device of the first parameter, or None if the model has no
            parameters.
        """
        # We can do this because the models are tiny and I don't expect a NAM to be on
        # multiple devices
        try:
            return next(self.parameters()).device
        except StopIteration:
            return None

    @property
    def sample_rate(self) -> _Optional[float]:
        """
        The model's sample rate, or None if it hasn't been set.
        """
        return self._sample_rate.item() if self._has_sample_rate else None

    @sample_rate.setter
    def sample_rate(self, val: _Optional[float]):
        # Create the replacement buffers on the same device as the current ones so
        # that setting the sample rate doesn't silently move them to the CPU.
        device = self._sample_rate.device
        self._has_sample_rate = _torch.tensor(
            val is not None, dtype=_torch.bool, device=device
        )
        self._sample_rate = _torch.tensor(0.0 if val is None else val, device=device)

    def _get_export_dict(self):
        """
        Extend the parent's export dict with the model's sample rate.

        :raises RuntimeError: If the parent's dict already has a "sample_rate" key.
        """
        d = super()._get_export_dict()
        sample_rate_key = "sample_rate"
        if sample_rate_key in d:
            raise RuntimeError(
                "Model wants to put 'sample_rate' into model export dict, but the key "
                "is already taken!"
            )
        d[sample_rate_key] = self.sample_rate
        return d

    def _metadata_loudness(self, gain: float = 1.0, db: bool = True) -> float:
        """
        How loud is this model when given a standardized input?

        Loudness is the RMS of the model's output at nominal settings.

        :param gain: Multiplies input signal
        :param db: If True, return 20 * log10(RMS) (dB); otherwise the linear RMS.
        """
        # Only a measurement; no gradients are needed.
        with _torch.no_grad():
            x = self._metadata_loudness_x().to(self.device)
            y = self._at_nominal_settings(gain * x)
            loudness = _torch.sqrt(_torch.mean(_torch.square(y)))
            if db:
                loudness = 20.0 * _torch.log10(loudness)
        return loudness.item()

    def _metadata_gain(self) -> float:
        """
        Between 0 and 1, how much gain / compression does the model seem to have?

        The linear loudness is measured at 11 input gains from 0 to 1 and summed. That
        sum is compared against two references built from the loudness at full input
        gain: a flat curve ("square", full compression) and a straight line from zero
        ("triangle", no compression).
        """
        x = _np.linspace(0.0, 1.0, 11)
        y = _np.array([self._metadata_loudness(gain=gain, db=False) for gain in x])
        #
        # O ^ o  o  o  o  o  o
        # u | o              x   +-------------------------------------+
        # t | o           x      | x: Minimum gain (no compression)    |
        # p | o        x         | o: Max gain (100% compression)      |
        # u | o     x            +-------------------------------------+
        # t | o  x
        #   +------------------->
        #           Input
        #
        max_gain = y[-1] * len(x)  # "Square"
        min_gain = 0.5 * max_gain  # "Triangle"
        gain_range = max_gain - min_gain
        this_gain = y.sum()
        normalized_gain = (this_gain - min_gain) / gain_range
        # Return a plain Python float to match the annotation.
        return float(_np.clip(normalized_gain, 0.0, 1.0))

    def _at_nominal_settings(self, x: _torch.Tensor) -> _torch.Tensor:
        """
        Run the model on `x` at its nominal settings (used for metadata). Subclasses
        must override this.
        """
        # parametric?...
        raise NotImplementedError()

    @_abc.abstractmethod
    def _forward(self, *args) -> _torch.Tensor:
        """
        The true forward method.

        :param x: (N,L1)
        :return: (N,L1-RF+1)
        """
        pass

    def _export_input_output_args(self) -> _Tuple[_Any]:
        """
        Create any other args necessary (e.g. params to eval at)
        """
        return ()

    def _export_input_output(self) -> _Tuple[_np.ndarray, _np.ndarray]:
        """
        Make a test input and the model's output for it, for export.

        The input is one second of silence, one second of a 220 Hz sine at amplitude
        0.5, then one more second of silence.

        :return: (input, output) as NumPy arrays.
        :raises RuntimeError: If the model has no sample rate.
        """
        args = self._export_input_output_args()
        rate = self.sample_rate
        if rate is None:
            raise RuntimeError(
                "Cannot export model's input and output without a sample rate."
            )
        # The sample rate may come back as a float, but here it is a number of
        # samples, which torch.zeros / torch.linspace require to be an int.
        rate = int(rate)
        x = _torch.cat(
            [
                _torch.zeros((rate,)),
                0.5
                * _torch.sin(
                    2.0 * _math.pi * 220.0 * _torch.linspace(0.0, 1.0, rate + 1)[:-1]
                ),
                _torch.zeros((rate,)),
            ]
        )
        # Pad the start so the output is as long as the input.
        return (
            x.detach().cpu().numpy(),
            self(*args, x, pad_start=True).detach().cpu().numpy(),
        )


def _get_torch_version() -> str:
    """
    Return the installed PyTorch version string.
    """
    return _torch.__version__


class BaseNet(_Base):
    """
    Base for models whose `forward()` runs `._forward()` on a (possibly unbatched)
    signal, with optional start padding and a workaround for an MPS backend bug.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Switched on after the MPS 65,536-sample error is first hit; from then on
        # inputs are processed in chunks (see `._forward_mps_safe()`).
        self._mps_65536_fallback = False

    def forward(self, x: _torch.Tensor, pad_start: _Optional[bool] = None, **kwargs):
        """
        :param x: (L,) or (N,L)
        :param pad_start: If True, prepend receptive_field - 1 zeros so that the
            output is as long as the input. Defaults to `self.pad_start_default`.
        :return: (L',) if x was 1D, otherwise (N,L')
        :raises ValueError: If x (after any padding) is shorter than the receptive
            field.
        """
        pad_start = self.pad_start_default if pad_start is None else pad_start
        unbatched = x.ndim == 1
        if unbatched:
            x = x[None]
        if pad_start:
            # Allocate the padding directly with x's dtype and on x's device.
            x = _torch.cat(
                (x.new_zeros((len(x), self.receptive_field - 1)), x), dim=1
            )
        if x.shape[1] < self.receptive_field:
            raise ValueError(
                f"Input has {x.shape[1]} samples, which is too few for this model with "
                f"receptive field {self.receptive_field}!"
            )
        y = self._forward_mps_safe(x, **kwargs)
        if unbatched:
            y = y[0]
        return y

    def _at_nominal_settings(self, x: _torch.Tensor) -> _torch.Tensor:
        return self(x)

    def _forward_mps_safe(self, x: _torch.Tensor, **kwargs) -> _torch.Tensor:
        """
        Wrap `._forward()` to protect against MPS-unsupported input lengths
        beyond 65,536 samples.

        The first time PyTorch raises its "Output channels > 65536" error, this prints
        a warning, switches on the fallback and retries. With the fallback on, the
        input is run in chunks of at most 65,536 samples and the outputs are
        concatenated.

        Check this again when PyTorch 2.5.2 is released--hopefully it's fixed
        then.

        :param x: (N,L)
        :return: (N,L-RF+1)
        """
        if not self._mps_65536_fallback:
            try:
                return self._forward(x, **kwargs)
            except NotImplementedError as e:
                if "Output channels > 65536 not supported at the MPS device." in str(e):
                    msg = (
                        "Warning: NAM encountered a bug in PyTorch's MPS backend and "
                        "will switch to a fallback."
                    )
                    known_bad_versions = {"2.5.0", "2.5.1"}
                    torch_version = _get_torch_version()
                    if torch_version not in known_bad_versions:
                        msg += (
                            "\n"
                            f"Your version of PyTorch is {torch_version}, which "
                            "wasn't known to have this problem.\n"
                            "Please open an Issue at:\n"
                            "https://github.com/sdatkinson/neural-amp-modeler/issues/507"
                            "\n"
                            f"and report your PyTorch version ({torch_version}) "
                            "so that we can keep track of versions of PyTorch that "
                            "might be avoided."
                        )
                    print(msg)
                    self._mps_65536_fallback = True
                    return self._forward_mps_safe(x, **kwargs)
                else:
                    raise
        else:
            # Stitch together the output one piece at a time to avoid the MPS error.
            # Consecutive chunks overlap by receptive_field - 1 samples, so each
            # chunk's outputs begin exactly where the previous chunk's ended.
            stride = 65_536 - (self.receptive_field - 1)
            # The loop stops as soon as a chunk reaches the end of x, so the last
            # chunk always has at least receptive_field samples of history.
            out_list = []
            for i in range(0, x.shape[1], stride):
                j = min(i + 65_536, x.shape[1])
                xi = x[:, i:j]
                out_list.append(self._forward(xi, **kwargs))
                # Bit hacky, but correct.
                if j == x.shape[1]:
                    break
            return _torch.cat(out_list, dim=1)

    @_abc.abstractmethod
    def _forward(self, x: _torch.Tensor, **kwargs) -> _torch.Tensor:
        """
        The true forward method.

        :param x: (N,L1)
        :return: (N,L1-RF+1)
        """
        pass

    def _get_non_user_metadata(self) -> _Dict[str, _Union[str, int, float]]:
        """
        Extend the parent's metadata with the measured loudness (dB) and gain.
        """
        d = super()._get_non_user_metadata()
        d["loudness"] = self._metadata_loudness()
        d["gain"] = self._metadata_gain()
        return d