commit 8a3df26731d8207e04ddffbe7d889324a3f3b98d
parent 384b06fac6da3d0cfe898e2999dcc0e830a5efe3
Author: Steven Atkinson <steven@atkinson.mn>
Date: Sun, 29 Jan 2023 00:07:27 -0800
Fix ESR loss (#73)
* Pencil in requirements
* Python package workflow
* install package, trigger on dev branch
* Quick ConvNet tests
* Colab Notebook test
* Fix Issue 15
* Update colab.ipynb
* Bump to version 0.2.1
* Should be good
* Update README.md
* Update README.md
* Parametric model (#26)
* Parametric data and Hypernet
* Validate data; print path on failure
* Affine instead of layer norm, batchnorm momentum
* Tweak batchnorm momentum
* Fix test so data doesn't clip
* Exporting Hypernet models (#28)
* Export for HyperConvNet, tests
* Fix bool in CPP header
* DC loss (#32)
Closes #31
Implement the DC loss of Eq. (19) of https://www.mdpi.com/2076-3417/10/3/766/htm, but without the denominator term.
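As a hedged reading of that note (the implementation itself isn't shown in this commit's diff), dropping the normalizing denominator from Eq. (19) leaves the squared mean of the residual over a length-N sequence, with y the target and yhat the prediction:

    % DC term: Eq. (19) numerator only; the paper's denominator is omitted per the note above
    \mathcal{L}_{\mathrm{DC}} = \left( \frac{1}{N} \sum_{n=1}^{N} \left( y[n] - \hat{y}[n] \right) \right)^{2}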
* ESR Loss (#33)
Closes #10
Error-to-signal ratio (ESR) loss metric of Eq. (10) of https://www.mdpi.com/2076-3417/10/3/766/htm
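For reference, the per-sequence form implemented by nam/models/losses.py in this commit (averaged over the items in a batch), with y the target and yhat the prediction over N samples:

    % Error-to-signal ratio: error power relative to target power
    \mathrm{ESR} = \frac{\sum_{n=1}^{N} \left( y[n] - \hat{y}[n] \right)^{2}}{\sum_{n=1}^{N} y[n]^{2}}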
* tqdm loading data
* LSTM (#35)
Closes #24
* Implement an LSTM model.
* Config with some decent starting parameters as well as some hints.
* Some improvements to the training script
* Some refactoring to model exporting
Squash of:
* Better callbacks
* Rearrange training input JSONs
* np_to_wav, expose REQUIRED_RATE
* LSTM model
* Version bump to 0.3.0
* Conditional LSTM (#38)
Closes #36
* A conditional LSTM where the input signal is concatenated with the parametric inputs.
* Example configurations with helpful tips commented.
* Quality-of-life improvements in the trainer script, including functionality for `ConcatDataset`-type validation datasets
* Slicing of a single pair of WAV files into datasets at different parametric settings.
Squash of:
* Better callbacks
* Rearrange training input JSONs
* np_to_wav, expose REQUIRED_RATE
* LSTM model
* Version bump to 0.3.0
* Plot ConcatDatasets
* Flatten datasets inside a ConcatDataset
* CatLSTM
* Tests
* Config for CatLSTM
* Better error message on invalid nx
* Fix export for parametric models (#40)
* Fix export for parametric models
* Version bump to 0.3.1
* Improve docstring
* WaveNet and other improvements (#46)
Smashing together a few things...
* Implements the WaveNet architecture (#43) (and the parametric version, concatenating the knobs as additional inputs i.e. "CatWaveNet")
* Speed up `ConcatDataset` access (#45)
* Deprecate use of `"nx"` in dataset config (#44)
* Increment version to 0.4.0
Commit notes:
* WaveNet
* Fix invalid broadcasting, rechannel needed
* CatWaveNet
* Faster lookup of data in ConcatDataset
* Gated, exporting weights work
* Fix bugs
* Fix export for parametric models
* Version bump to 0.3.1
* Exporting
* Move bias from the input mixer to the dilated conv, which is always used.
* Fix redundant conv in WaveNet head
* Fix bugs; works with plugin
* Automatically add nx to data config from model; check less than once per epoch
* Refactor for multiple layer sets
* Zero out through-connection for init (learn direct paths)
* Layer send to head before 1x1, remove zeroing init
* Fix bug: reintroduce out_length in head term
* Fix receptive field w/ non-2 kernel sizes
* Fix up WaveNet export
* cpp headers for WaveNet and CatWaveNet
* Improve docstring
* LSTM tweaks, etc (#47)
* Tweak default parameters of LSTM configs
* Pre-emphasis filtered loss (#42)
* LSTM export C++ header
* Improve printing on the training figures
* Update config files
* Update README.md
* Update README.md
* Update Colab notebook
* Update Colab notebook
* Input gain (#50)
* Input gain for data sets
Adds parameter `input_gain` for data sets, defaulting to 0 dB (unity gain)
New test, passes.
* Version bump
* Fix Issue 52
* Losses module
* Black, use esr()
* Black
* Test
* Rename test
* Don't bump version yet
* Remove redundant tests
Diffstat:
8 files changed, 168 insertions(+), 20 deletions(-)
diff --git a/nam/__init__.py b/nam/__init__.py
@@ -7,3 +7,4 @@ from ._version import __version__ # Must be before models or else circular
from . import _core # noqa F401
from . import data # noqa F401
from . import models # noqa F401
+from . import train # noqa F401
diff --git a/nam/data.py b/nam/data.py
@@ -107,11 +107,16 @@ def np_to_wav(
class AbstractDataset(_Dataset, abc.ABC):
@abc.abstractmethod
def __getitem__(
- self, idx
+ self, idx: int
) -> Union[
Tuple[torch.Tensor, torch.Tensor],
Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
]:
+ """
+ :return:
+ Case 1: Input (N1,), Output (N2,)
+ Case 2: Parameters (D,), Input (N1,), Output (N2,)
+ """
pass
@@ -134,11 +139,31 @@ class Dataset(AbstractDataset, InitializableFromConfig):
y_scale: float = 1.0,
x_path: Optional[Union[str, Path]] = None,
y_path: Optional[Union[str, Path]] = None,
+ input_gain: float = 0.0,
):
"""
- :param start: In samples
- :param stop: In samples
+ :param x: The input signal. A 1D array.
+ :param y: The associated output from the model. A 1D array.
+ :param nx: The number of samples required as input for the model. For example,
+ for a ConvNet, this would be the receptive field.
+ :param ny: How many samples to provide as the output array for a single "datum".
+ It's usually more computationally efficient to provide an `ny` larger than 1
+ so that the forward pass can process more audio at once. However, it
+ shouldn't be too large, or else you won't be able to use a large batch
+ size (where each input-output pair can come from a substantially different
+ part of the signal, improving batch diversity).
+ :param start: In samples; clip x and y up to this point.
+ :param stop: In samples; clip x and y past this point.
+ :param y_scale: Multiplies the output signal by a factor (e.g. if the data are
+ too quiet).
:param delay: In samples. Positive means we get rid of the start of x, end of y.
+ :param input_gain: In dB. If the input signal wasn't fed to the amp at unity
+ gain, you can indicate the gain here. The data set will multiply the raw
+ audio file by the specified gain so that the true input signal amplitude
+ experienced by the signal chain will be provided as input to the model. If
+ you are using a reamping setup, you can estimate this by reamping a
+ completely dry signal (i.e. connecting the interface output directly back
+ into the input with which the guitar was originally recorded).
"""
x, y = [z[start:stop] for z in (x, y)]
if delay is not None:
@@ -148,6 +173,8 @@ class Dataset(AbstractDataset, InitializableFromConfig):
elif delay < 0:
x = x[-delay:]
y = y[:delay]
+ x_scale = 10.0 ** (input_gain / 20.0)
+ x = x * x_scale
y = y * y_scale
self._x_path = x_path
self._y_path = y_path
@@ -158,6 +185,11 @@ class Dataset(AbstractDataset, InitializableFromConfig):
self._ny = ny if ny is not None else len(x) - nx + 1
def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ :return:
+ Input (NX+NY-1,)
+ Output (NY,)
+ """
if idx >= len(self):
raise IndexError(f"Attempted to access datum {idx}, but len is {len(self)}")
i = idx * self._ny
@@ -271,6 +303,12 @@ class ParametricDataset(Dataset):
return config, x, y, slices
def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """
+ :return:
+ Parameter values (D,)
+ Input (NX+NY-1,)
+ Output (NY,)
+ """
# FIXME don't override signature
x, y = super().__getitem__(idx)
return self.vals, x, y
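A minimal usage sketch of the windowing and gain behavior documented in the nam/data.py changes above; the signal lengths, window sizes, and gain value are hypothetical, while the constructor arguments and returned shapes are the ones documented in this diff:

    import math

    import torch

    from nam.data import Dataset

    # Stand-ins for the input (dry) and output (reamped) signals; in practice
    # these would be loaded from WAV files.
    x = 2.0 * torch.rand(48_000) - 1.0
    y = 2.0 * torch.rand(48_000) - 1.0

    nx, ny = 8192, 4096  # model receptive field; output samples per datum
    input_gain = 20.0 * math.log10(0.5)  # the amp saw the signal at half the recorded level

    ds = Dataset(x, y, nx, ny, input_gain=input_gain)
    xi, yi = ds[0]
    assert xi.shape == (nx + ny - 1,)  # input window, per the __getitem__ docstring
    assert yi.shape == (ny,)  # output window
    # Internally, x is multiplied by 10 ** (input_gain / 20) before windowing.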
diff --git a/nam/models/__init__.py b/nam/models/__init__.py
@@ -4,6 +4,8 @@
from . import _base # noqa F401
from . import _exportable # noqa F401
+from . import losses # noqa F401
+from . import wavenet # noqa F401
from .base import Model # noqa F401
from .linear import Linear # noqa F401
from .conv_net import ConvNet # noqa F401
diff --git a/nam/models/_exportable.py b/nam/models/_exportable.py
@@ -61,7 +61,7 @@ class Exportable(abc.ABC):
@abc.abstractmethod
def _export_config(self):
"""
- Creates the JSON of the model's archtecture hyperparameters (number of layers,
+ Creates the JSON of the model's architecture hyperparameters (number of layers,
number of units, etc)
:return: a JSON serializable object
diff --git a/nam/models/base.py b/nam/models/base.py
@@ -21,6 +21,7 @@ import torch.nn as nn
from .._core import InitializableFromConfig
from .conv_net import ConvNet
from .linear import Linear
+from .losses import esr
from .parametric.catnets import CatLSTM, CatWaveNet
from .parametric.hyper_net import HyperConvNet
from .recurrent import LSTM
@@ -30,12 +31,12 @@ from .wavenet import WaveNet
class ValidationLoss(Enum):
"""
mse: mean squared error
- esr: error signal ratio (Eq. (10) from
+ esr: error-to-signal ratio (Eq. (10) from
https://www.mdpi.com/2076-3417/10/3/766/htm
- NOTE: Be careful when computing ESR on minibatches! The average ESR over
- a minibatch of data not the same as the ESR of all of the same data in
- the minibatch calculated over at once (because of the denominator).
- (Hint: think about what happens if one item in the minibatch is all
+ NOTE: Be careful when computing ESR on minibatches! The average ESR over
+ a minibatch of data is not the same as the ESR of all of the data in
+ the minibatch calculated at once (because of the denominator).
+ (Hint: think about what happens if one item in the minibatch is all
zeroes...)
"""
@@ -49,7 +50,7 @@ class LossConfig(InitializableFromConfig):
:param mask_first: How many of the first samples to ignore when computing the loss.
:param dc_weight: Weight for the DC loss term. If 0, ignored.
:param val_loss: Which loss to track for the best model checkpoint.
- :param pre_emph_coef: Coefficient of 1st-order pre-emphasis filter from
+ :param pre_emph_coef: Coefficient of 1st-order pre-emphasis filter from
https://www.mdpi.com/2076-3417/10/3/766. Paper value: 0.95.
"""
@@ -68,11 +69,11 @@ class LossConfig(InitializableFromConfig):
pre_emph_coef = config.get("pre_emph_coef")
pre_emph_weight = config.get("pre_emph_weight")
return {
- "mask_first": mask_first,
- "dc_weight": dc_weight,
- "val_loss": val_loss,
+ "mask_first": mask_first,
+ "dc_weight": dc_weight,
+ "val_loss": val_loss,
"pre_emph_coef": pre_emph_coef,
- "pre_emph_weight": pre_emph_weight
+ "pre_emph_weight": pre_emph_weight,
}
def apply_mask(self, *args):
@@ -120,8 +121,8 @@ class Model(pl.LightningModule, InitializableFromConfig):
@classmethod
def parse_config(cls, config):
"""
- e.g.
-
+ e.g.
+
{
"net": {
"name": "ConvNet",
@@ -144,7 +145,7 @@ class Model(pl.LightningModule, InitializableFromConfig):
},
"monitor": "val_loss"
}
- }
+ }
"""
config = super().parse_config(config)
net_config = config["net"]
@@ -230,16 +231,23 @@ class Model(pl.LightningModule, InitializableFromConfig):
self.log_dict({"MSE": mse_loss, "ESR": esr_loss, "val_loss": val_loss})
return val_loss
- def _esr_loss(self, preds, targets):
+ def _esr_loss(self, preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
"""
Error signal ratio aka ESR loss.
Eq. (10), from
https://www.mdpi.com/2076-3417/10/3/766/htm
+
+ B: Batch size
+ L: Sequence length
+
+ :param preds: (B,L)
+ :param targets: (B,L)
+ :return: ()
"""
- return nn.MSELoss()(preds, targets) / nn.MSELoss()(targets, 0.0 * targets)
+ return esr(preds, targets)
- def _mse_loss(self, preds, targets, pre_emph_coef: Optional[float]=None):
+ def _mse_loss(self, preds, targets, pre_emph_coef: Optional[float] = None):
if pre_emph_coef is not None:
preds, targets = [
z[..., 1:] - pre_emph_coef * z[..., :-1] for z in (preds, targets)
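A small sketch of the NOTE above, using the esr() helper added in nam/models/losses.py below: averaging per-item ESRs over a minibatch is not the same as computing ESR over all of the samples pooled together, and an all-zero item makes the difference dramatic:

    import torch

    from nam.models.losses import esr

    targets = torch.tensor([[1.0, 1.0, 1.0], [0.0, 0.0, 0.0]])  # second item is all zeros
    preds = targets + 0.1

    batch_mean = esr(preds, targets)  # mean of per-item ESRs -> inf (zero denominator)
    pooled = esr(preds.reshape(1, -1), targets.reshape(1, -1))  # all samples at once -> finite
    print(batch_mean, pooled)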
diff --git a/nam/models/losses.py b/nam/models/losses.py
@@ -0,0 +1,33 @@
+# File: losses.py
+# Created Date: Sunday January 22nd 2023
+# Author: Steven Atkinson (steven@atkinson.mn)
+
+"""
+Loss functions
+"""
+
+import torch
+
+
+def esr(preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
+ """
+ ESR of (a batch of) predictions & targets
+
+ :param preds: (N,) or (B,N)
+ :param targets: Same as preds
+ :return: ()
+ """
+ if preds.ndim == 1 and targets.ndim == 1:
+ preds, targets = preds[None], targets[None]
+ if preds.ndim != 2:
+ raise ValueError(
+ f"Expect 2D predictions (batch_size, num_samples). Got {preds.shape}"
+ )
+ if targets.ndim != 2:
+ raise ValueError(
+ f"Expect 2D targets (batch_size, num_samples). Got {targets.shape}"
+ )
+ return torch.mean(
+ torch.mean(torch.square(preds - targets), dim=1)
+ / torch.mean(torch.square(targets), dim=1)
+ )
diff --git a/tests/test_nam/test_data.py b/tests/test_nam/test_data.py
@@ -2,6 +2,8 @@
# Created Date: Friday May 6th 2022
# Author: Steven Atkinson (steven@atkinson.mn)
+import math
+
import pytest
import torch
@@ -20,5 +22,27 @@ class TestDataset(object):
x, y = self._create_xy()
data.Dataset(x, y, 3, None, delay=0)
+
+ def test_input_gain(self):
+ """
+ Checks correctness of input gain parameter
+ """
+ x_scale = 2.0
+ input_gain = 20.0 * math.log10(x_scale)
+ x, y = self._create_xy()
+ nx = 3
+ ny = None
+ args = (x, y, nx, ny)
+ d1 = data.Dataset(*args)
+ d2 = data.Dataset(*args, input_gain=input_gain)
+
+ sample_x1 = d1[0][0]
+ sample_x2 = d2[0][0]
+ assert torch.allclose(sample_x1 * x_scale, sample_x2)
+
def _create_xy(self):
return 0.99 * (2.0 * torch.rand((2, 7)) - 1.0) # Don't clip
+
+
+if __name__ == "__main__":
+ pytest.main()
\ No newline at end of file
diff --git a/tests/test_nam/test_models/test_losses.py b/tests/test_nam/test_models/test_losses.py
@@ -0,0 +1,41 @@
+# File: test_losses.py
+# Created Date: Saturday January 28th 2023
+# Author: Steven Atkinson (steven@atkinson.mn)
+
+import pytest
+import torch
+import torch.nn as nn
+
+from nam.models import losses
+
+
+def test_esr():
+ """
+ Is the ESR calculation correct?
+ """
+
+ class Model(nn.Module):
+ def forward(self, x):
+ return x
+
+ batch_size, input_length = 3, 5
+ inputs = (
+ torch.linspace(0.1, 1.0, batch_size)[:, None]
+ * torch.full((input_length,), 1.0)[None, :]
+ ) # (batch_size, input_length)
+ target_factor = torch.linspace(0.37, 1.22, batch_size)
+ targets = target_factor[:, None] * inputs # (batch_size, input_length)
+ # Do the algebra:
+ # y=a*yhat
+ # ESR=(y-yhat)^2 / y^2
+    # =(a*yhat-yhat)^2 / (a*yhat)^2 = (a-1)^2 / a^2
+ # =(1/a-1)^2
+ expected_esr = torch.square(1.0 / target_factor - 1.0).mean()
+ model = Model()
+ preds = model(inputs)
+ actual_esr = losses.esr(preds, targets)
+ assert torch.allclose(actual_esr, expected_esr)
+
+
+if __name__ == "__main__":
+ pytest.main()