losses.py (2437B)
# File: losses.py
# Created Date: Sunday January 22nd 2023
# Author: Steven Atkinson (steven@atkinson.mn)

"""
Loss functions (ESR, multi-resolution STFT, spectral MSE) and a pre-emphasis
helper.
"""

from typing import Optional as _Optional

import torch as _torch
from auraloss.freq import MultiResolutionSTFTLoss as _MultiResolutionSTFTLoss


def apply_pre_emphasis_filter(x: _torch.Tensor, coef: float) -> _torch.Tensor:
    """
    Apply a first-order pre-emphasis filter along the last dimension.

    Computes y[n] = x[n] - coef * x[n - 1] for n = 1, ..., L-1. The first
    sample has no predecessor, so it is dropped and the output is one sample
    shorter than the input.

    :param x: (*, L)
    :param coef: The coefficient multiplying the previous sample.

    :return: (*, L-1)
    """
    current = x[..., 1:]
    previous = x[..., :-1]
    return current - coef * previous


def esr(preds: _torch.Tensor, targets: _torch.Tensor) -> _torch.Tensor:
    """
    Error-to-signal ratio (ESR), averaged over a batch.

    For every row, the mean squared error between prediction and target is
    divided by the mean squared value of the target. The per-row ratios are
    then averaged into a single scalar.

    :param preds: (N,) or (B,N)
    :param targets: Same as preds
    :return: ()
    :raises ValueError: If, after promoting a pair of 1D inputs to a batch of
        one, either argument is not 2D.
    """
    # A pair of 1D signals is treated as a batch containing one example. A
    # 1D/2D mix is not promoted and is rejected by the checks below.
    if preds.ndim == targets.ndim == 1:
        preds = preds.unsqueeze(0)
        targets = targets.unsqueeze(0)
    if preds.ndim != 2:
        raise ValueError(
            f"Expect 2D predictions (batch_size, num_samples). Got {preds.shape}"
        )
    if targets.ndim != 2:
        raise ValueError(
            f"Expect 2D targets (batch_size, num_samples). Got {targets.shape}"
        )
    error_energy = (preds - targets).square().mean(dim=1)
    # NOTE: a row whose targets are all zero makes this denominator zero, so
    # that row's ratio (and hence the result) becomes inf or nan.
    signal_energy = targets.square().mean(dim=1)
    return (error_energy / signal_energy).mean()


def multi_resolution_stft_loss(
    preds: _torch.Tensor,
    targets: _torch.Tensor,
    loss_func: _Optional[_MultiResolutionSTFTLoss] = None,
    device: _Optional[_torch.device] = None,
) -> _torch.Tensor:
    """
    Experimental multi-resolution short-time Fourier transform loss, computed
    with auraloss's implementation.

    B: Batch size
    L: Sequence length

    :param preds: (B,L)
    :param targets: (B,L)
    :param loss_func: A pre-initialized instance of the loss module. If omitted,
        a fresh default instance is built on every call, so pass one in when
        calling repeatedly.
    :param device: If given, ``preds`` and ``targets`` are moved to this device
        before the loss is evaluated. ``loss_func`` itself is not moved.
    :return: ()
    """
    if loss_func is None:
        loss_func = _MultiResolutionSTFTLoss()
    if device is not None:
        preds = preds.to(device)
        targets = targets.to(device)
    return loss_func(preds, targets)


def mse_fft(preds: _torch.Tensor, targets: _torch.Tensor) -> _torch.Tensor:
    """
    Fourier-domain loss: the mean squared magnitude of the difference between
    the discrete Fourier transforms of ``preds`` and ``targets``.

    The transform runs over the last dimension (the ``torch.fft.fft`` default),
    and the mean is taken over every element of the complex difference.

    NOTE(review): ``torch.fft.fft`` defaults to an unnormalised forward
    transform, so by Parseval's theorem this equals N times the time-domain
    MSE, where N is the length of the last dimension.

    :param preds: (N,) or (B,N)
    :param targets: Same as preds
    :return: ()
    """
    spectral_error = _torch.fft.fft(preds) - _torch.fft.fft(targets)
    return spectral_error.abs().square().mean()