neural-amp-modeler

Neural network emulator for guitar amplifiers

commit 7fa0fec58aa6d5600c1a104695e690353fbfa8de
parent a94215def62f69875d504c207fe993ea56ed8068
Author: honkkis <mikko.honkala@gmail.com>
Date:   Fri, 24 Mar 2023 05:49:43 +0200

Added experimental multiresolution STFT loss using auraloss. (#147)

Added an experimental multi-resolution STFT loss using the auraloss implementation.
Diffstat:
M nam/models/base.py | 38 ++++++++++++++++++++++++++++++++++++--
1 file changed, 36 insertions(+), 2 deletions(-)
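
The new mstft_weight field defaults to 0.0, which leaves the loss disabled.
A minimal sketch of turning it on, assuming this commit's LossConfig can be
constructed with keyword arguments and that its remaining fields all have
defaults (as the ones visible in the diff below do); the 2e-4 value follows
the comment in the diff:

    # Sketch: enabling the experimental MRSTFT term via LossConfig.
    # Assumption: all other LossConfig fields keep their defaults.
    from nam.models.base import LossConfig

    loss_config = LossConfig(mstft_weight=2e-4)  # 0.0 (the default) disables it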

diff --git a/nam/models/base.py b/nam/models/base.py
@@ -54,6 +54,7 @@ class LossConfig(InitializableFromConfig):
     https://www.mdpi.com/2076-3417/10/3/766. Paper value: 0.95.
     """
 
+    mstft_weight: float = 0.0  # 0.0 means no multiresolution stft loss, 2e-4 works pretty well if one wants to use it
     fourier: bool = False
     mask_first: int = 0
     dc_weight: float = 0.0
@@ -110,6 +111,7 @@ class Model(pl.LightningModule, InitializableFromConfig):
         self._optimizer_config = {} if optimizer_config is None else optimizer_config
         self._scheduler_config = scheduler_config
         self._loss_config = LossConfig() if loss_config is None else loss_config
+        self._mrstft = None
 
     @classmethod
     def init_from_config(cls, config):
@@ -173,6 +175,15 @@ class Model(pl.LightningModule, InitializableFromConfig):
     def net(self) -> nn.Module:
         return self._net
 
+    def initialize_losses(self):
+        if self._loss_config.mstft_weight > 0.0:
+            import auraloss
+            self._mrstft = auraloss.freq.MultiResolutionSTFTLoss()
+
+    def setup(self, stage):
+        super().setup(stage)
+        self.initialize_losses()
+
     def configure_optimizers(self):
         optimizer = torch.optim.Adam(self.parameters(), **self._optimizer_config)
         if self._scheduler_config is None:
@@ -186,7 +197,6 @@ class Model(pl.LightningModule, InitializableFromConfig):
             if key in self._scheduler_config:
                 lr_scheduler_config[key] = self._scheduler_config[key]
         return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_config}
-
     def forward(self, *args, **kwargs):
         return self.net(*args, **kwargs)
 
@@ -205,6 +215,8 @@ class Model(pl.LightningModule, InitializableFromConfig):
                 loss = loss + mse_fft(preds, targets)
             else:
                 loss = loss + self._mse_loss(preds, targets)
+        if self._loss_config.mstft_weight > 0.0 and self._mrstft is not None:
+            loss = loss + self._loss_config.mstft_weight * self._mrstft_loss(preds, targets)
         # Pre-emphasized MSE
         if self._loss_config.pre_emph_weight is not None:
             if (self._loss_config.pre_emph_coef is None) != (
@@ -234,7 +246,11 @@ class Model(pl.LightningModule, InitializableFromConfig):
         val_loss = {ValidationLoss.MSE: mse_loss, ValidationLoss.ESR: esr_loss}[
             self._loss_config.val_loss
         ]
-        self.log_dict({"MSE": mse_loss, "ESR": esr_loss, "val_loss": val_loss})
+        dict_to_log = {"MSE": mse_loss, "ESR": esr_loss, "val_loss": val_loss}
+        if self._loss_config.mstft_weight > 0.0 and self._mrstft is not None:
+            mrstft_loss = self._mrstft_loss(preds, targets)
+            dict_to_log.update({"MRSTFT": mrstft_loss})
+        self.log_dict(dict_to_log)
         return val_loss
 
     def _esr_loss(self, preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
@@ -259,3 +275,20 @@ class Model(pl.LightningModule, InitializableFromConfig):
             z[..., 1:] - pre_emph_coef * z[..., :-1] for z in (preds, targets)
         ]
         return nn.MSELoss()(preds, targets)
+
+    def _mrstft_loss(self, preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
+        """
+        Experimental Multi Resolution Short Time Fourier Transform Loss using auraloss implementation.
+        B: Batch size
+        L: Sequence length
+
+        :param preds: (B,L)
+        :param targets: (B,L)
+        :return: ()
+        """
+        device = 'cpu'  # not all platforms support this on gpu yet
+        preds_cpu = preds.to(device)
+        targets_cpu = targets.to(device)
+
+        loss = self._mrstft(preds_cpu, targets_cpu)
+        return loss
\ No newline at end of file
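
For reference, a self-contained sketch of what the new _mrstft_loss()
computes, assuming auraloss is installed; the (B, L) shapes follow the
docstring above, and the move to CPU mirrors the commit's workaround for
platforms where the underlying STFT ops are not yet supported on GPU:

    import torch
    import auraloss

    # Multi-resolution STFT loss with auraloss's default resolutions.
    mrstft = auraloss.freq.MultiResolutionSTFTLoss()

    # (B, L) batches of predicted and target audio, e.g. one second at 48 kHz.
    preds = torch.randn(2, 48000)
    targets = torch.randn(2, 48000)

    # Scalar loss, computed on CPU. During training it is scaled by
    # mstft_weight before being added to the total loss.
    loss = mrstft(preds.to("cpu"), targets.to("cpu"))
    print(float(loss))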