neural-amp-modeler

Neural network emulator for guitar amplifiers
Log | Files | Refs | README | LICENSE

losses.py (2437B)


      1 # File: losses.py
      2 # Created Date: Sunday January 22nd 2023
      3 # Author: Steven Atkinson (steven@atkinson.mn)
      4 
      5 """
      6 Loss functions
      7 """
      8 
      9 from typing import Optional as _Optional
     10 
     11 import torch as _torch
     12 from auraloss.freq import MultiResolutionSTFTLoss as _MultiResolutionSTFTLoss
     13 
     14 
     15 def apply_pre_emphasis_filter(x: _torch.Tensor, coef: float) -> _torch.Tensor:
     16     """
     17     Apply first-order pre-emphsis filter
     18 
     19     :param x: (*, L)
     20     :param coef: The coefficient
     21 
     22     :return: (*, L-1)
     23     """
     24     return x[..., 1:] - coef * x[..., :-1]
     25 
     26 
     27 def esr(preds: _torch.Tensor, targets: _torch.Tensor) -> _torch.Tensor:
     28     """
     29     ESR of (a batch of) predictions & targets
     30 
     31     :param preds: (N,) or (B,N)
     32     :param targets: Same as preds
     33     :return: ()
     34     """
     35     if preds.ndim == 1 and targets.ndim == 1:
     36         preds, targets = preds[None], targets[None]
     37     if preds.ndim != 2:
     38         raise ValueError(
     39             f"Expect 2D predictions (batch_size, num_samples). Got {preds.shape}"
     40         )
     41     if targets.ndim != 2:
     42         raise ValueError(
     43             f"Expect 2D targets (batch_size, num_samples). Got {targets.shape}"
     44         )
     45     return _torch.mean(
     46         _torch.mean(_torch.square(preds - targets), dim=1)
     47         / _torch.mean(_torch.square(targets), dim=1)
     48     )
     49 
     50 
     51 def multi_resolution_stft_loss(
     52     preds: _torch.Tensor,
     53     targets: _torch.Tensor,
     54     loss_func: _Optional[_MultiResolutionSTFTLoss] = None,
     55     device: _Optional[_torch.device] = None,
     56 ) -> _torch.Tensor:
     57     """
     58     Experimental Multi Resolution Short Time Fourier Transform Loss using auraloss implementation.
     59     B: Batch size
     60     L: Sequence length
     61 
     62     :param preds: (B,L)
     63     :param targets: (B,L)
     64     :param loss_func: A pre-initialized instance of the loss function module. Providing
     65         this saves time.
     66     :param device: If provided, send the preds and targets to the provided device.
     67     :return: ()
     68     """
     69     loss_func = _MultiResolutionSTFTLoss() if loss_func is None else loss_func
     70     if device is not None:
     71         preds, targets = [z.to(device) for z in (preds, targets)]
     72     return loss_func(preds, targets)
     73 
     74 
     75 def mse_fft(preds: _torch.Tensor, targets: _torch.Tensor) -> _torch.Tensor:
     76     """
     77     Fourier loss
     78 
     79     :param preds: (N,) or (B,N)
     80     :param targets: Same as preds
     81     :return: ()
     82     """
     83     fp = _torch.fft.fft(preds)
     84     ft = _torch.fft.fft(targets)
     85     e = fp - ft
     86     return _torch.mean(_torch.square(e.abs()))