neural-amp-modeler

Neural network emulator for guitar amplifiers

commit 8a3df26731d8207e04ddffbe7d889324a3f3b98d
parent 384b06fac6da3d0cfe898e2999dcc0e830a5efe3
Author: Steven Atkinson <steven@atkinson.mn>
Date:   Sun, 29 Jan 2023 00:07:27 -0800

Fix ESR loss (#73)

* Pencil in requirements

* Python package workflow

* install package, trigger on dev branch

* Quick ConvNet tests

* Colab Notebook test

* Fix Issue 15

* Update colab.ipynb

* Bump to version 0.2.1

* Should be good

* Update README.md

* Update README.md

* Parametric model (#26)

* Parametric data and Hypernet
* Validate data; print path if fail
* Affine instead of layer norm, batchnorm momentum
* Tweak batchnorm momentum
* Fix test so data doesn't clip

* Exporting Hypernet models (#28)

* Export for HyperConvNet, tests

* Fix bool in CPP header

* DC loss (#32)

Closes #31

Implement the DC loss of Eq. (19) of https://www.mdpi.com/2076-3417/10/3/766/htm, but without the denominator term.
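
For reference, a minimal sketch of what a denominator-free DC term like this could look like (an illustration of the idea only, not necessarily the exact code in `nam/models/base.py`):

```python
import torch


def dc_loss(preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """
    Squared difference of the per-sequence means, i.e. the numerator of Eq. (19)
    with the denominator dropped, averaged over the batch.

    :param preds: (B, L)
    :param targets: (B, L)
    :return: ()
    """
    return torch.mean(torch.square(targets.mean(dim=-1) - preds.mean(dim=-1)))
```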

* ESR Loss (#33)

Closes #10

Error-to-signal ratio (ESR) loss metric of Eq. (10) of https://www.mdpi.com/2076-3417/10/3/766/htm
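
In the paper's notation, ESR = sum_n (y[n] - y_hat[n])^2 / sum_n y[n]^2. The `nam.models.losses.esr` function added in the diff below averages the per-item ratio over a batch, which is why the diff also warns about computing ESR on minibatches. A quick illustration of that caveat:

```python
import torch

from nam.models.losses import esr

torch.manual_seed(0)
targets = torch.randn(4, 1024)
preds = targets + 0.1 * torch.randn(4, 1024)

# Mean of the per-item ESRs over the minibatch...
batch_esr = esr(preds, targets)
# ...vs. the ESR of the same samples treated as one long signal.
# The denominators differ, so in general these two numbers are not equal.
flat_esr = esr(preds.reshape(1, -1), targets.reshape(1, -1))
print(batch_esr.item(), flat_esr.item())
```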

* tqdm loading data

* LSTM (#35)

Closes #24 

* Implement an LSTM model.
* Config with some decent starting parameters as well as some hints.
* Some improvements to the training script
* Some refactoring to model exporting

Squash of:

* Better callbacks

* Rearrange training input JSONs

* np_to_wav, expose REQUIRED_RATE

* LSTM model

* Version bump to 0.3.0

* Conditional LSTM (#38)

Closes #36 

* A conditional LSTM where the input signal is concatenated with the parametric inputs (see the sketch after this list).
* Example configurations with helpful tips commented.
* Quality-of-life improvements in the trainer script, including support for `ConcatDataset`-type validation datasets
* Slicing of a single pair of WAV files into datasets at different parametric settings.
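
A rough sketch of the concatenation idea from the first bullet (illustrative shapes only, not the actual `CatLSTM` code):

```python
import torch

# Illustrative shapes: B items, L samples, D knob values per item.
B, L, D = 2, 16, 3
x = torch.randn(B, L)      # input audio
params = torch.rand(B, D)  # parametric (knob) settings

# Broadcast the knobs along time and concatenate them with the audio on the
# feature axis, giving a recurrent-network input of shape (B, L, 1 + D).
lstm_input = torch.cat([x[..., None], params[:, None, :].expand(B, L, D)], dim=2)
assert lstm_input.shape == (B, L, 1 + D)
```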

Squash of:

* Better callbacks

* Rearrange training input JSONs

* np_to_wav, expose REQUIRED_RATE

* LSTM model

* Version bump to 0.3.0

* Plot ConcatDatasets

* Flatten datasets inside a ConcatDataset

* CatLSTM

* Tests

* Config for CatLSTM

* Better error message on invalid nx

* Fix export for parametric models (#40)

* Fix export for parametric models
* Version bump to 0.3.1

* Improve docstring

* WaveNet and other improvements (#46)

Smashing together a few things...

* Implements the WaveNet architecture (#43) (and the parametric version, concatenating the knobs as additional inputs i.e. "CatWaveNet")
* Speed up `ConcatDataset` access (#45)
* Deprecate use of `"nx"` in dataset config (#44); see the receptive-field note after this list
* Increment version to 0.4.0
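
For reference, the receptive field of a stack of dilated convolutions (the number of input samples needed per output sample, which is what `"nx"` used to specify by hand) follows from the kernel size and dilations. Example numbers only; the real layer sets come from the model configs:

```python
# Example numbers only; the actual kernel size and dilations come from the config.
kernel_size = 3
dilations = [1, 2, 4, 8, 16, 32, 64, 128]

# Each layer extends the receptive field by (kernel_size - 1) * dilation samples.
receptive_field = 1 + sum((kernel_size - 1) * d for d in dilations)
print(receptive_field)  # 511
```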

Commit notes:

* WaveNet

* Fix invalid broadcasting, rechannel needed

* CatWaveNet

* Faster lookup of data in ConcatDataset

* Gated, exporting weights work

* Fix bugs

* Fix export for parametric models

* Version bump to 0.3.1

* Exporting

* Move bias from the input mixer to the dilated conv, which is always used.

* Fix redundant conv in WaveNet head

* Fix bugs; works with plugin

* Automatically add nx to data config from model, check fewer than once per epoch

* Refactor for multiple layer sets

* Zero out through-connection for init (learn direct paths)

* Layer send to head before 1x1, remove zeroing init

* Fix bug: reintroduce out_length in head term

* Fix receptive field w/ non-2 kernel sizes

* Fix up WaveNet export

* cpp headers for WaveNet and CatWaveNet

* Improve docstring

* LSTM tweaks, etc (#47)

* Tweak default parameters of LSTM configs
* Pre-emphasis filtered loss (#42); see the sketch after this list
* LSTM export C++ header
* Improve printing on the training figures
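
A minimal sketch of the pre-emphasis-filtered loss, consistent with the first-order filter applied inside `Model._mse_loss` in the diff below (0.95 is the paper's coefficient):

```python
import torch
import torch.nn as nn


def pre_emphasized_mse(
    preds: torch.Tensor, targets: torch.Tensor, coef: float = 0.95
) -> torch.Tensor:
    # Apply the first-order pre-emphasis filter z[n] - coef * z[n-1] to both
    # signals, then take the MSE of the filtered signals.
    preds, targets = [z[..., 1:] - coef * z[..., :-1] for z in (preds, targets)]
    return nn.MSELoss()(preds, targets)
```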

* Update config files

* Update README.md

* Update README.md

* Update Colab notebook

* Update Colab notebook

* Input gain (#50)

* Input gain for data sets

Adds parameter `input_gain` for data sets, defaulting to 0 dB (unity gain).
New test, passes.
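
For example (a hypothetical snippet consistent with the new test below; loading the WAV files is omitted):

```python
import math

import torch

from nam.data import Dataset

# Hypothetical stand-ins for the reamp input/output signals.
x = 0.99 * (2.0 * torch.rand(8192) - 1.0)
y = 0.99 * (2.0 * torch.rand(8192) - 1.0)

# input_gain is in dB; the data set scales x by 10 ** (input_gain / 20),
# so ~+6.02 dB doubles the input amplitude seen by the model.
dataset = Dataset(x, y, 3, None, input_gain=20.0 * math.log10(2.0))
```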

* Version bump

* Fix Issue 52

* Losses module

* Black, use esr()

* Black

* Test

* Rename test

* Don't bump version yet

* Remove redundant tests

Diffstat:
M nam/__init__.py                            |  1 +
M nam/data.py                                | 44 +++++++++++++++++++++++++++++++++++++++++---
M nam/models/__init__.py                     |  2 ++
M nam/models/_exportable.py                  |  2 +-
M nam/models/base.py                         | 40 ++++++++++++++++++++++++----------------
A nam/models/losses.py                       | 33 +++++++++++++++++++++++++++++++++
M tests/test_nam/test_data.py                | 25 +++++++++++++++++++++++++
A tests/test_nam/test_models/test_losses.py  | 41 +++++++++++++++++++++++++++++++++++++++++
8 files changed, 168 insertions(+), 20 deletions(-)

diff --git a/nam/__init__.py b/nam/__init__.py
@@ -7,3 +7,4 @@ from ._version import __version__  # Must be before models or else circular
 from . import _core  # noqa F401
 from . import data  # noqa F401
 from . import models  # noqa F401
+from . import train  # noqa F401
diff --git a/nam/data.py b/nam/data.py
@@ -107,11 +107,16 @@ def np_to_wav(
 class AbstractDataset(_Dataset, abc.ABC):
     @abc.abstractmethod
     def __getitem__(
-        self, idx
+        self, idx: int
     ) -> Union[
         Tuple[torch.Tensor, torch.Tensor],
         Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
     ]:
+        """
+        :return:
+            Case 1: Input (N1,), Output (N2,)
+            Case 2: Parameters (D,), Input (N1,), Output (N2,)
+        """
         pass
@@ -134,11 +139,31 @@ class Dataset(AbstractDataset, InitializableFromConfig):
         y_scale: float = 1.0,
         x_path: Optional[Union[str, Path]] = None,
         y_path: Optional[Union[str, Path]] = None,
+        input_gain: float=0.0
     ):
         """
-        :param start: In samples
-        :param stop: In samples
+        :param x: The input signal. A 1D array.
+        :param y: The associated output from the model. A 1D array.
+        :param nx: The number of samples required as input for the model. For example,
+            for a ConvNet, this would be the receptive field.
+        :param ny: How many samples to provide as the output array for a single "datum".
+            It's usually more computationally-efficient to provide a larger `ny` than 1
+            so that the forward pass can process more audio all at once. However, this
+            shouldn't be too large or else you won't be able to provide a large batch
+            size (where each input-output pair could be something substantially
+            different and improve batch diversity).
+        :param start: In samples; clip x and y up to this point.
+        :param stop: In samples; clip x and y past this point.
+        :param y_scale: Multiplies the output signal by a factor (e.g. if the data are
+            too quiet).
         :param delay: In samples. Positive means we get rid of the start of x, end of y.
+        :param input_gain: In dB. If the input signal wasn't fed to the amp at unity
+            gain, you can indicate the gain here. The data set will multipy the raw
+            audio file by the specified gain so that the true input signal amplitude
+            experienced by the signal chain will be provided as input to the model. If
+            you are using a reamping setup, you can estimate this by reamping a
+            completely dry signal (i.e. connecting the interface output directly back
+            into the input with which the guitar was originally recorded.)
         """
         x, y = [z[start:stop] for z in (x, y)]
         if delay is not None:
@@ -148,6 +173,8 @@ class Dataset(AbstractDataset, InitializableFromConfig):
             elif delay < 0:
                 x = x[-delay:]
                 y = y[:delay]
+        x_scale = 10.0 ** (input_gain / 20.0)
+        x = x * x_scale
         y = y * y_scale
         self._x_path = x_path
         self._y_path = y_path
@@ -158,6 +185,11 @@ class Dataset(AbstractDataset, InitializableFromConfig):
         self._ny = ny if ny is not None else len(x) - nx + 1

     def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        :return:
+            Input (NX+NY-1,)
+            Output (NY,)
+        """
         if idx >= len(self):
             raise IndexError(f"Attempted to access datum {idx}, but len is {len(self)}")
         i = idx * self._ny
@@ -271,6 +303,12 @@ class ParametricDataset(Dataset):
         return config, x, y, slices

     def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """
+        :return:
+            Parameter values (D,)
+            Input (NX+NY-1,)
+            Output (NY,)
+        """
         # FIXME don't override signature
         x, y = super().__getitem__(idx)
         return self.vals, x, y
diff --git a/nam/models/__init__.py b/nam/models/__init__.py
@@ -4,6 +4,8 @@
 from . import _base  # noqa F401
 from . import _exportable  # noqa F401
+from . import losses  # noqa F401
+from . import wavenet  # noqa F401
 from .base import Model  # noqa F401
 from .linear import Linear  # noqa F401
 from .conv_net import ConvNet  # noqa F401
diff --git a/nam/models/_exportable.py b/nam/models/_exportable.py
@@ -61,7 +61,7 @@ class Exportable(abc.ABC):
     @abc.abstractmethod
     def _export_config(self):
         """
-        Creates the JSON of the model's archtecture hyperparameters (number of layers, 
+        Creates the JSON of the model's archtecture hyperparameters (number of layers,
         number of units, etc)

         :return: a JSON serializable object
diff --git a/nam/models/base.py b/nam/models/base.py
@@ -21,6 +21,7 @@ import torch.nn as nn
 from .._core import InitializableFromConfig
 from .conv_net import ConvNet
 from .linear import Linear
+from .losses import esr
 from .parametric.catnets import CatLSTM, CatWaveNet
 from .parametric.hyper_net import HyperConvNet
 from .recurrent import LSTM
@@ -30,12 +31,12 @@ from .wavenet import WaveNet
 class ValidationLoss(Enum):
     """
     mse: mean squared error
-    esr: error signal ratio (Eq. (10) from 
+    esr: error signal ratio (Eq. (10) from
         https://www.mdpi.com/2076-3417/10/3/766/htm
-    NOTE: Be careful when computing ESR on minibatches! The average ESR over 
-        a minibatch of data not the same as the ESR of all of the same data in 
-        the minibatch calculated over at once (because of the denominator). 
-        (Hint: think about what happens if one item in the minibatch is all 
+    NOTE: Be careful when computing ESR on minibatches! The average ESR over
+        a minibatch of data not the same as the ESR of all of the same data in
+        the minibatch calculated over at once (because of the denominator).
+        (Hint: think about what happens if one item in the minibatch is all
         zeroes...)
     """

@@ -49,7 +50,7 @@ class LossConfig(InitializableFromConfig):
     :param mask_first: How many of the first samples to ignore when comptuing the loss.
     :param dc_weight: Weight for the DC loss term. If 0, ignored.
     :params val_loss: Which loss to track for the best model checkpoint.
-    :param pre_emph_coef: Coefficient of 1st-order pre-emphasis filter from 
+    :param pre_emph_coef: Coefficient of 1st-order pre-emphasis filter from
         https://www.mdpi.com/2076-3417/10/3/766. Paper value: 0.95.
     """

@@ -68,11 +69,11 @@ class LossConfig(InitializableFromConfig):
         pre_emph_coef = config.get("pre_emph_coef")
         pre_emph_weight = config.get("pre_emph_weight")
         return {
-            "mask_first": mask_first, 
-            "dc_weight": dc_weight, 
-            "val_loss": val_loss, 
+            "mask_first": mask_first,
+            "dc_weight": dc_weight,
+            "val_loss": val_loss,
             "pre_emph_coef": pre_emph_coef,
-            "pre_emph_weight": pre_emph_weight
+            "pre_emph_weight": pre_emph_weight,
         }

     def apply_mask(self, *args):
@@ -120,8 +121,8 @@ class Model(pl.LightningModule, InitializableFromConfig):
     @classmethod
     def parse_config(cls, config):
         """
-        e.g. 
-        
+        e.g.
+
         {
             "net": {
                 "name": "ConvNet",
@@ -144,7 +145,7 @@ class Model(pl.LightningModule, InitializableFromConfig):
                 },
                 "monitor": "val_loss"
             }
-        } 
+        }
         """
         config = super().parse_config(config)
         net_config = config["net"]
@@ -230,16 +231,23 @@ class Model(pl.LightningModule, InitializableFromConfig):
         self.log_dict({"MSE": mse_loss, "ESR": esr_loss, "val_loss": val_loss})
         return val_loss

-    def _esr_loss(self, preds, targets):
+    def _esr_loss(self, preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
         """
         Error signal ratio aka ESR loss.

         Eq. (10), from
         https://www.mdpi.com/2076-3417/10/3/766/htm
+
+        B: Batch size
+        L: Sequence length
+
+        :param preds: (B,L)
+        :param targets: (B,L)
+        :return: ()
         """
-        return nn.MSELoss()(preds, targets) / nn.MSELoss()(targets, 0.0 * targets)
+        return esr(preds, targets)

-    def _mse_loss(self, preds, targets, pre_emph_coef: Optional[float]=None):
+    def _mse_loss(self, preds, targets, pre_emph_coef: Optional[float] = None):
         if pre_emph_coef is not None:
             preds, targets = [
                 z[..., 1:] - pre_emph_coef * z[..., :-1] for z in (preds, targets)
diff --git a/nam/models/losses.py b/nam/models/losses.py
@@ -0,0 +1,33 @@
+# File: losses.py
+# Created Date: Sunday January 22nd 2023
+# Author: Steven Atkinson (steven@atkinson.mn)
+
+"""
+Loss functions
+"""
+
+import torch
+
+
+def esr(preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
+    """
+    ESR of (a batch of) predictions & targets
+
+    :param preds: (N,) or (B,N)
+    :param targets: Same as preds
+    :return: ()
+    """
+    if preds.ndim == 1 and targets.ndim == 1:
+        preds, targets = preds[None], targets[None]
+    if preds.ndim != 2:
+        raise ValueError(
+            f"Expect 2D predictions (batch_size, num_samples). Got {preds.shape}"
+        )
+    if targets.ndim != 2:
+        raise ValueError(
+            f"Expect 2D targets (batch_size, num_samples). Got {targets.shape}"
+        )
+    return torch.mean(
+        torch.mean(torch.square(preds - targets), dim=1)
+        / torch.mean(torch.square(targets), dim=1)
+    )
diff --git a/tests/test_nam/test_data.py b/tests/test_nam/test_data.py
@@ -2,6 +2,8 @@
 # Created Date: Friday May 6th 2022
 # Author: Steven Atkinson (steven@atkinson.mn)

+import math
+
 import pytest
 import torch

@@ -20,5 +22,27 @@ class TestDataset(object):
         x, y = self._create_xy()
         data.Dataset(x, y, 3, None, delay=0)

+    def test_input_gain(self):
+        """
+        Checks correctness of input gain parameter
+        """
+        x_scale = 2.0
+        input_gain = 20.0 * math.log10(x_scale)
+        x, y = self._create_xy()
+        nx = 3
+        ny = None
+        args = (x, y, nx, ny)
+        d1 = data.Dataset(*args)
+        d2 = data.Dataset(*args, input_gain=input_gain)
+
+        sample_x1 = d1[0][0]
+        sample_x2 = d2[0][0]
+        assert torch.allclose(sample_x1 * x_scale, sample_x2)
+
     def _create_xy(self):
         return 0.99 * (2.0 * torch.rand((2, 7)) - 1.0)  # Don't clip
+
+
+if __name__ == "__main__":
+    pytest.main()
\ No newline at end of file
diff --git a/tests/test_nam/test_models/test_losses.py b/tests/test_nam/test_models/test_losses.py
@@ -0,0 +1,41 @@
+# File: test_losses.py
+# Created Date: Saturday January 28th 2023
+# Author: Steven Atkinson (steven@atkinson.mn)
+
+import pytest
+import torch
+import torch.nn as nn
+
+from nam.models import losses
+
+
+def test_esr():
+    """
+    Is the ESR calculation correct?
+    """
+
+    class Model(nn.Module):
+        def forward(self, x):
+            return x
+
+    batch_size, input_length = 3, 5
+    inputs = (
+        torch.linspace(0.1, 1.0, batch_size)[:, None]
+        * torch.full((input_length,), 1.0)[None, :]
+    )  # (batch_size, input_length)
+    target_factor = torch.linspace(0.37, 1.22, batch_size)
+    targets = target_factor[:, None] * inputs  # (batch_size, input_length)
+    # Do the algebra:
+    # y=a*yhat
+    # ESR=(y-yhat)^2 / y^2
+    # ...
+    # =(1/a-1)^2
+    expected_esr = torch.square(1.0 / target_factor - 1.0).mean()
+    model = Model()
+    preds = model(inputs)
+    actual_esr = losses.esr(preds, targets)
+    assert torch.allclose(actual_esr, expected_esr)
+
+
+if __name__ == "__main__":
+    pytest.main()