commit 7fa0fec58aa6d5600c1a104695e690353fbfa8de
parent a94215def62f69875d504c207fe993ea56ed8068
Author: honkkis <mikko.honkala@gmail.com>
Date: Fri, 24 Mar 2023 05:49:43 +0200
Added experimental multiresolution STFT loss using auraloss. (#147)
Adds an experimental multi-resolution STFT (MR-STFT) loss term, computed with the auraloss implementation and weighted by a new `mstft_weight` loss-config field (default 0.0, i.e. disabled).
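
For anyone trying this: a minimal sketch of switching the term on. It assumes
LossConfig can be constructed directly with the new field (in a JSON training
config the same key would sit under the model's "loss" section); 2e-4 is the
value suggested in the field comment below, not a tuned recommendation.

    from nam.models.base import LossConfig, Model

    # Sketch: enable the experimental MR-STFT term. The default of 0.0
    # leaves training behavior unchanged from before this commit.
    loss_config = LossConfig(mstft_weight=2e-4)
    model = Model(net, loss_config=loss_config)  # `net`: the nn.Module being trained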
Diffstat:
1 file changed, 36 insertions(+), 2 deletions(-)
diff --git a/nam/models/base.py b/nam/models/base.py
@@ -54,6 +54,7 @@ class LossConfig(InitializableFromConfig):
https://www.mdpi.com/2076-3417/10/3/766. Paper value: 0.95.
"""
+    mstft_weight: float = 0.0  # 0.0 disables the multi-resolution STFT loss; 2e-4 works well if you want to use it
fourier: bool = False
mask_first: int = 0
dc_weight: float = 0.0
@@ -110,6 +111,7 @@ class Model(pl.LightningModule, InitializableFromConfig):
self._optimizer_config = {} if optimizer_config is None else optimizer_config
self._scheduler_config = scheduler_config
self._loss_config = LossConfig() if loss_config is None else loss_config
+        self._mrstft = None
@classmethod
def init_from_config(cls, config):
@@ -173,6 +175,15 @@ class Model(pl.LightningModule, InitializableFromConfig):
def net(self) -> nn.Module:
return self._net
+    def initialize_losses(self):
+        if self._loss_config.mstft_weight > 0.0:
+            import auraloss
+            self._mrstft = auraloss.freq.MultiResolutionSTFTLoss()
+
+    def setup(self, stage):
+        super().setup(stage)
+        self.initialize_losses()
+
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), **self._optimizer_config)
if self._scheduler_config is None:
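
Review note: auraloss is imported lazily above, so it only needs to be installed
when mstft_weight > 0.0. A slightly more defensive variant could fail with a
clearer message when the optional dependency is missing; a hypothetical sketch,
not part of this commit:

    def initialize_losses(self):
        if self._loss_config.mstft_weight > 0.0:
            try:
                import auraloss
            except ImportError as e:
                raise ImportError(
                    "mstft_weight > 0 requires the optional 'auraloss' package "
                    "(pip install auraloss)"
                ) from e
            self._mrstft = auraloss.freq.MultiResolutionSTFTLoss()
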
@@ -186,7 +197,6 @@ class Model(pl.LightningModule, InitializableFromConfig):
if key in self._scheduler_config:
lr_scheduler_config[key] = self._scheduler_config[key]
return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_config}
-
def forward(self, *args, **kwargs):
return self.net(*args, **kwargs)
@@ -205,6 +215,8 @@ class Model(pl.LightningModule, InitializableFromConfig):
loss = loss + mse_fft(preds, targets)
else:
loss = loss + self._mse_loss(preds, targets)
+        if self._loss_config.mstft_weight > 0.0 and self._mrstft is not None:
+            loss = loss + self._loss_config.mstft_weight * self._mrstft_loss(preds, targets)
# Pre-emphasized MSE
if self._loss_config.pre_emph_weight is not None:
if (self._loss_config.pre_emph_coef is None) != (
@@ -234,7 +246,11 @@ class Model(pl.LightningModule, InitializableFromConfig):
val_loss = {ValidationLoss.MSE: mse_loss, ValidationLoss.ESR: esr_loss}[
self._loss_config.val_loss
]
-        self.log_dict({"MSE": mse_loss, "ESR": esr_loss, "val_loss": val_loss})
+        dict_to_log = {"MSE": mse_loss, "ESR": esr_loss, "val_loss": val_loss}
+        if self._loss_config.mstft_weight > 0.0 and self._mrstft is not None:
+            mrstft_loss = self._mrstft_loss(preds, targets)
+            dict_to_log.update({"MRSTFT": mrstft_loss})
+        self.log_dict(dict_to_log)
return val_loss
def _esr_loss(self, preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
@@ -259,3 +275,20 @@ class Model(pl.LightningModule, InitializableFromConfig):
z[..., 1:] - pre_emph_coef * z[..., :-1] for z in (preds, targets)
]
return nn.MSELoss()(preds, targets)
+
+    def _mrstft_loss(self, preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
+        """
+        Experimental multi-resolution short-time Fourier transform (MR-STFT) loss, using the auraloss implementation.
+        B: Batch size
+        L: Sequence length
+
+        :param preds: (B,L)
+        :param targets: (B,L)
+        :return: ()
+        """
+        device = "cpu"  # Not all platforms support the required STFT ops on GPU yet
+        preds_cpu = preds.to(device)
+        targets_cpu = targets.to(device)
+
+        loss = self._mrstft(preds_cpu, targets_cpu)
+        return loss
+
\ No newline at end of file
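
For context, a self-contained sketch of what the new term computes. Shapes
follow the (B,L) convention in the docstring above, and the CPU round-trip
mirrors the workaround in _mrstft_loss; the STFT resolutions are auraloss's
defaults, nothing here is specific to this commit:

    import torch
    import auraloss

    # Multi-resolution STFT loss: spectral-convergence and log-magnitude
    # terms, averaged over several STFT resolutions (auraloss defaults).
    mrstft = auraloss.freq.MultiResolutionSTFTLoss()

    B, L = 4, 8192
    preds, targets = torch.randn(B, L), torch.randn(B, L)

    # Compute on CPU, matching the commit's GPU-support workaround.
    loss = mrstft(preds.cpu(), targets.cpu())  # scalar tensor, shape ()
    print(float(loss))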