neural-amp-modeler

Neural network emulator for guitar amplifiers

base.py


# File: _base.py
# Created Date: Tuesday February 8th 2022
# Author: Steven Atkinson (steven@atkinson.mn)

"""
The foundation of the model without the PyTorch Lightning attributes (losses, training
steps)
"""

import abc as _abc
import math as _math
import pkg_resources as _pkg_resources
from typing import (
    Any as _Any,
    Dict as _Dict,
    Optional as _Optional,
    Tuple as _Tuple,
    Union as _Union,
)

import numpy as _np
import torch as _torch
import torch.nn as _nn

from .._core import InitializableFromConfig as _InitializableFromConfig
from ..data import wav_to_tensor as _wav_to_tensor
from .exportable import Exportable as _Exportable


class _Base(_nn.Module, _InitializableFromConfig, _Exportable):
    def __init__(self, sample_rate: _Optional[float] = None):
        super().__init__()
        self.register_buffer(
            "_has_sample_rate",
            _torch.tensor(sample_rate is not None, dtype=_torch.bool),
        )
        self.register_buffer(
            "_sample_rate", _torch.tensor(0.0 if sample_rate is None else sample_rate)
        )

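    # Design note on the two buffers above: registering a bool flag alongside
    # the value lets an *optional* sample rate survive a state_dict round trip
    # (buffers are serialized with the model; a plain attribute holding `None`
    # would not be). `sample_rate=None` is encoded as (flag=False, value=0.0).
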
    @property
    @_abc.abstractmethod
    def pad_start_default(self) -> bool:
        pass

    @property
    @_abc.abstractmethod
    def receptive_field(self) -> int:
        """
        Receptive field of the model
        """
        pass

    @_abc.abstractmethod
    def forward(self, *args, **kwargs) -> _torch.Tensor:
        pass

    @classmethod
    def _metadata_loudness_x(cls) -> _torch.Tensor:
        return _wav_to_tensor(
            _pkg_resources.resource_filename(
                "nam", "models/_resources/loudness_input.wav"
            )
        )

    @property
    def device(self) -> _Optional[_torch.device]:
        """
        Convenience property: the device on which the model's parameters live.
        """
        # We can do this because the models are tiny and I don't expect a NAM
        # to be on multiple devices.
        try:
            return next(self.parameters()).device
        except StopIteration:
            return None

    @property
    def sample_rate(self) -> _Optional[float]:
        return self._sample_rate.item() if self._has_sample_rate else None

    @sample_rate.setter
    def sample_rate(self, val: _Optional[float]):
        self._has_sample_rate = _torch.tensor(val is not None, dtype=_torch.bool)
        self._sample_rate = _torch.tensor(0.0 if val is None else val)

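    # Usage: `net.sample_rate = 48000.0` (or `None`); reading the property
    # back returns a float or None accordingly.
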
    def _get_export_dict(self):
        d = super()._get_export_dict()
        sample_rate_key = "sample_rate"
        if sample_rate_key in d:
            raise RuntimeError(
                "Model wants to put 'sample_rate' into model export dict, but the key "
                "is already taken!"
            )
        d[sample_rate_key] = self.sample_rate
        return d

    def _metadata_loudness(self, gain: float = 1.0, db: bool = True) -> float:
        """
        How loud is this model when given a standardized input?

        :param gain: Multiplies the input signal.
        :param db: If True, return loudness in dB; otherwise, linear RMS.
        """
        x = self._metadata_loudness_x().to(self.device)
        y = self._at_nominal_settings(gain * x)
        loudness = _torch.sqrt(_torch.mean(_torch.square(y)))
        if db:
            loudness = 20.0 * _torch.log10(loudness)
        return loudness.item()

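    # For intuition: loudness here is RMS converted to dB. A full-scale sine
    # (amplitude 1.0) has RMS 1/sqrt(2) ~= 0.707, i.e. 20*log10(0.707) ~= -3 dB,
    # so typical model loudness values land below 0 dB.
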
    def _metadata_gain(self) -> float:
        """
        Between 0 and 1, how much gain / compression does the model seem to have?
        """
        x = _np.linspace(0.0, 1.0, 11)
        y = _np.array([self._metadata_loudness(gain=gain, db=False) for gain in x])
        #
        # O ^ o o o o o o
        # u | o       x   +-------------------------------------+
        # t | o     x     | x: Minimum gain (no compression)    |
        # p | o   x       | o: Max gain     (100% compression)  |
        # u | o x         +-------------------------------------+
        # t | o
        #   +------------->
        #       Input
        #
        max_gain = y[-1] * len(x)  # "Square"
        min_gain = 0.5 * max_gain  # "Triangle"
        gain_range = max_gain - min_gain
        this_gain = y.sum()
        normalized_gain = (this_gain - min_gain) / gain_range
        return _np.clip(normalized_gain, 0.0, 1.0)

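    # Worked example of the normalization above: with no compression, loudness
    # grows linearly with gain, so y.sum() over the 11 points equals the
    # "triangle" area 0.5 * y[-1] * len(x) and normalized_gain = 0. With total
    # compression, loudness is flat at y[-1], so y.sum() equals the "square"
    # area y[-1] * len(x) and normalized_gain = 1.
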
    def _at_nominal_settings(self, x: _torch.Tensor) -> _torch.Tensor:
        # Subclasses define what "nominal settings" means (e.g. for parametric
        # models); non-parametric models simply run the input through.
        raise NotImplementedError()

    @_abc.abstractmethod
    def _forward(self, *args) -> _torch.Tensor:
        """
        The true forward method.

        :param x: (N,L1)
        :return: (N,L1-RF+1)
        """
        pass

    def _export_input_output_args(self) -> _Tuple[_Any, ...]:
        """
        Create any other args necessary (e.g. params to eval at)
        """
        return ()

    def _export_input_output(self) -> _Tuple[_np.ndarray, _np.ndarray]:
        args = self._export_input_output_args()
        rate = self.sample_rate
        if rate is None:
            raise RuntimeError(
                "Cannot export model's input and output without a sample rate."
            )
        # The sample rate is stored as a float; tensor sizes need ints.
        rate = int(rate)
        # Test signal: 1 second of silence, 1 second of a 220 Hz sine at
        # amplitude 0.5, then 1 more second of silence.
        x = _torch.cat(
            [
                _torch.zeros((rate,)),
                0.5
                * _torch.sin(
                    2.0 * _math.pi * 220.0 * _torch.linspace(0.0, 1.0, rate + 1)[:-1]
                ),
                _torch.zeros((rate,)),
            ]
        )
        # Use pad_start=True so the output has the same length as the input.
        return (
            x.detach().cpu().numpy(),
            self(*args, x, pad_start=True).detach().cpu().numpy(),
        )


def _get_torch_version() -> str:
    return _torch.__version__


class BaseNet(_Base):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._mps_65536_fallback = False

    def forward(self, x: _torch.Tensor, pad_start: _Optional[bool] = None, **kwargs):
        pad_start = self.pad_start_default if pad_start is None else pad_start
        scalar = x.ndim == 1
        if scalar:
            x = x[None]
        if pad_start:
            x = _torch.cat(
                (_torch.zeros((len(x), self.receptive_field - 1)).to(x.device), x),
                dim=1,
            )
        if x.shape[1] < self.receptive_field:
            raise ValueError(
                f"Input has {x.shape[1]} samples, which is too few for this model with "
                f"receptive field {self.receptive_field}!"
            )
        y = self._forward_mps_safe(x, **kwargs)
        if scalar:
            y = y[0]
        return y

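    # Shape sketch: with receptive field R, an unpadded input of shape (N, L)
    # yields output (N, L - R + 1); with pad_start=True, R - 1 zeros are
    # prepended, so the output has the same length L as the input.
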
    def _at_nominal_settings(self, x: _torch.Tensor) -> _torch.Tensor:
        return self(x)

    def _forward_mps_safe(self, x: _torch.Tensor, **kwargs) -> _torch.Tensor:
        """
        Wrap `._forward()` to protect against MPS-unsupported input lengths
        beyond 65,536 samples.

        Check this again when PyTorch 2.5.2 is released; hopefully it's fixed
        then.
        """
        if not self._mps_65536_fallback:
            try:
                return self._forward(x, **kwargs)
            except NotImplementedError as e:
                if "Output channels > 65536 not supported at the MPS device." in str(e):
                    msg = (
                        "Warning: NAM encountered a bug in PyTorch's MPS backend and "
                        "will switch to a fallback."
                    )
                    known_bad_versions = {"2.5.0", "2.5.1"}
                    torch_version = _get_torch_version()
                    if torch_version not in known_bad_versions:
                        msg += (
                            "\n"
                            f"Your version of PyTorch is {torch_version}, which "
                            "wasn't known to have this problem.\n"
                            "Please open an Issue at:\n"
                            "https://github.com/sdatkinson/neural-amp-modeler/issues/507"
                            "\n"
                            f"and report your PyTorch version ({torch_version}) "
                            "so that we can keep track of versions of PyTorch that "
                            "should be avoided."
                        )
                    print(msg)
                    self._mps_65536_fallback = True
                    return self._forward_mps_safe(x, **kwargs)
                else:
                    raise e
        else:
            # Stitch together the output one piece at a time to avoid the MPS error
            stride = 65_536 - (self.receptive_field - 1)
            # We need to make sure that each segment carries the history
            # required by the receptive field.
            out_list = []
            for i in range(0, x.shape[1], stride):
                j = min(i + 65_536, x.shape[1])
                xi = x[:, i:j]
                out_list.append(self._forward(xi, **kwargs))
                # Bit hacky, but correct.
                if j == x.shape[1]:
                    break
            return _torch.cat(out_list, dim=1)

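    # Why the stitching works: each 65,536-sample chunk produces
    # 65,536 - (R - 1) output samples (R = receptive field), which is exactly
    # `stride`. Successive chunks start `stride` samples apart, so their
    # outputs abut with no gaps or overlaps, and concatenating along dim=1
    # reproduces the full-length result.
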
    @_abc.abstractmethod
    def _forward(self, x: _torch.Tensor, **kwargs) -> _torch.Tensor:
        """
        The true forward method.

        :param x: (N,L1)
        :return: (N,L1-RF+1)
        """
        pass

    def _get_non_user_metadata(self) -> _Dict[str, _Union[str, int, float]]:
        d = super()._get_non_user_metadata()
        d["loudness"] = self._metadata_loudness()
        d["gain"] = self._metadata_gain()
        return d
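

# ---------------------------------------------------------------------------
# Illustration only (not part of the library): a minimal sketch of what a
# concrete `BaseNet` subclass supplies. `_ToyNet` and its fixed 6 dB cut are
# invented for this example; a real model additionally implements the export
# hooks inherited from `Exportable`.


class _ToyNet(BaseNet):
    """
    Toy model: drop the warm-up samples, then apply a fixed 6 dB cut.
    """

    @property
    def pad_start_default(self) -> bool:
        return True

    @property
    def receptive_field(self) -> int:
        return 8192

    def _forward(self, x: _torch.Tensor) -> _torch.Tensor:
        # (N, L) -> (N, L - receptive_field + 1), per the contract above.
        return 0.5 * x[:, self.receptive_field - 1 :]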