commit 92a241f230743536afef4da92c0090829b59dbd4
parent a2f9dbcc1d9dfd6488de7df899decb9507a90957
Author: Steven Atkinson <steven@atkinson.mn>
Date: Sat, 18 Mar 2023 12:08:08 -0500
FFT Loss (#143)
Implement FFT loss
Diffstat:
2 files changed, 22 insertions(+), 2 deletions(-)
diff --git a/nam/models/base.py b/nam/models/base.py
@@ -21,7 +21,7 @@ import torch.nn as nn
from .._core import InitializableFromConfig
from .conv_net import ConvNet
from .linear import Linear
-from .losses import esr
+from .losses import esr, mse_fft
from .parametric.catnets import CatLSTM, CatWaveNet
from .parametric.hyper_net import HyperConvNet
from .recurrent import LSTM
@@ -54,6 +54,7 @@ class LossConfig(InitializableFromConfig):
https://www.mdpi.com/2076-3417/10/3/766. Paper value: 0.95.
"""
+ fourier: bool = False
mask_first: int = 0
dc_weight: float = 0.0
val_loss: ValidationLoss = ValidationLoss.MSE
@@ -63,12 +64,14 @@ class LossConfig(InitializableFromConfig):
@classmethod
def parse_config(cls, config):
config = super().parse_config(config)
+ fourier = config.get("fourier", False)
dc_weight = config.get("dc_weight", 0.0)
val_loss = ValidationLoss(config.get("val_loss", "mse"))
mask_first = config.get("mask_first", 0)
pre_emph_coef = config.get("pre_emph_coef")
pre_emph_weight = config.get("pre_emph_weight")
return {
+ "fourier": fourier,
"mask_first": mask_first,
"dc_weight": dc_weight,
"val_loss": val_loss,
@@ -198,7 +201,10 @@ class Model(pl.LightningModule, InitializableFromConfig):
loss = 0.0
# Prediction aka MSE loss
- loss = loss + self._mse_loss(preds, targets)
+ if self._loss_config.fourier:
+ loss = loss + mse_fft(preds, targets)
+ else:
+ loss = loss + self._mse_loss(preds, targets)
# Pre-emphasized MSE
if self._loss_config.pre_emph_weight is not None:
if (self._loss_config.pre_emph_coef is None) != (
diff --git a/nam/models/losses.py b/nam/models/losses.py
@@ -31,3 +31,17 @@ def esr(preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
torch.mean(torch.square(preds - targets), dim=1)
/ torch.mean(torch.square(targets), dim=1)
)
+
+
+def mse_fft(preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
+ """
+ Fourier loss
+
+ :param preds: (N,) or (B,N)
+ :param targets: Same as preds
+ :return: ()
+ """
+ fp = torch.fft.fft(preds)
+ ft = torch.fft.fft(targets)
+ e = fp - ft
+ return torch.mean(torch.square(e.abs()))