neural-amp-modeler

Neural network emulator for guitar amplifiers
Log | Files | Refs | README | LICENSE

commit 92a241f230743536afef4da92c0090829b59dbd4
parent a2f9dbcc1d9dfd6488de7df899decb9507a90957
Author: Steven Atkinson <steven@atkinson.mn>
Date:   Sat, 18 Mar 2023 12:08:08 -0500

FFT Loss (#143)

Implement FFT loss
Diffstat:
Mnam/models/base.py | 10++++++++--
Mnam/models/losses.py | 14++++++++++++++
2 files changed, 22 insertions(+), 2 deletions(-)

diff --git a/nam/models/base.py b/nam/models/base.py @@ -21,7 +21,7 @@ import torch.nn as nn from .._core import InitializableFromConfig from .conv_net import ConvNet from .linear import Linear -from .losses import esr +from .losses import esr, mse_fft from .parametric.catnets import CatLSTM, CatWaveNet from .parametric.hyper_net import HyperConvNet from .recurrent import LSTM @@ -54,6 +54,7 @@ class LossConfig(InitializableFromConfig): https://www.mdpi.com/2076-3417/10/3/766. Paper value: 0.95. """ + fourier: bool = False mask_first: int = 0 dc_weight: float = 0.0 val_loss: ValidationLoss = ValidationLoss.MSE @@ -63,12 +64,14 @@ class LossConfig(InitializableFromConfig): @classmethod def parse_config(cls, config): config = super().parse_config(config) + fourier = config.get("fourier", False) dc_weight = config.get("dc_weight", 0.0) val_loss = ValidationLoss(config.get("val_loss", "mse")) mask_first = config.get("mask_first", 0) pre_emph_coef = config.get("pre_emph_coef") pre_emph_weight = config.get("pre_emph_weight") return { + "fourier": fourier, "mask_first": mask_first, "dc_weight": dc_weight, "val_loss": val_loss, @@ -198,7 +201,10 @@ class Model(pl.LightningModule, InitializableFromConfig): loss = 0.0 # Prediction aka MSE loss - loss = loss + self._mse_loss(preds, targets) + if self._loss_config.fourier: + loss = loss + mse_fft(preds, targets) + else: + loss = loss + self._mse_loss(preds, targets) # Pre-emphasized MSE if self._loss_config.pre_emph_weight is not None: if (self._loss_config.pre_emph_coef is None) != ( diff --git a/nam/models/losses.py b/nam/models/losses.py @@ -31,3 +31,17 @@ def esr(preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: torch.mean(torch.square(preds - targets), dim=1) / torch.mean(torch.square(targets), dim=1) ) + + +def mse_fft(preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: + """ + Fourier loss + + :param preds: (N,) or (B,N) + :param targets: Same as preds + :return: () + """ + fp = torch.fft.fft(preds) + ft = torch.fft.fft(targets) + e = fp - ft + return torch.mean(torch.square(e.abs()))