test_losses.py (1727B)
# File: test_losses.py
# Created Date: Saturday January 28th 2023
# Author: Steven Atkinson (steven@atkinson.mn)

import pytest
import torch
import torch.nn as nn

from nam.models import losses


@pytest.mark.parametrize(
    "x,coef,y_expected",
    (
        # The expected values follow y[t] = x[t] - coef * x[t - 1] for t >= 1,
        # so the output is one sample shorter than the input.
        (torch.tensor([0.0, 1.0, 2.0]), 1.0, torch.tensor([1.0, 1.0])),
        (torch.tensor([0.0, 1.0, 2.0]), 0.5, torch.tensor([1.0, 1.5])),
        # Batched (2D) input: each row is filtered independently along the
        # last dimension (nothing carries over from one row to the next).
        (
            torch.tensor([[0.0, 1.0, 0.0], [1.0, 1.5, 2.0]]),
            0.5,
            torch.tensor([[1.0, -0.5], [1.0, 1.25]]),
        ),
    ),
)
def test_apply_pre_emphasis_filter_1d(
    x: torch.Tensor, coef: float, y_expected: torch.Tensor
):
    """
    Does the pre-emphasis filter produce the expected values and shape?

    Despite the "_1d" suffix, the parametrization also covers a batched 2D
    input.

    :param x: Input signal.
    :param coef: Pre-emphasis coefficient.
    :param y_expected: Expected filtered output.
    """
    y_actual = losses.apply_pre_emphasis_filter(x, coef)
    assert isinstance(y_actual, torch.Tensor)
    assert y_actual.ndim == y_expected.ndim
    assert y_actual.shape == y_expected.shape
    assert torch.allclose(y_actual, y_expected)


def test_esr():
    """
    Is the ESR calculation correct?

    Predictions are the inputs themselves (identity model). Targets are the
    inputs scaled by a per-row factor a. This gives a closed-form expected
    ESR for each row.
    """

    class Model(nn.Module):
        # Identity model: predictions are exactly the inputs.
        def forward(self, x):
            return x

    batch_size, input_length = 3, 5
    # Each row is a constant signal (0.1, 0.55, 1.0), shape (batch_size, input_length).
    inputs = (
        torch.linspace(0.1, 1.0, batch_size)[:, None]
        * torch.ones(input_length)[None, :]
    )
    target_factor = torch.linspace(0.37, 1.22, batch_size)
    targets = target_factor[:, None] * inputs  # (batch_size, input_length)
    # Do the algebra, per row, with y = a * yhat:
    #   ESR = sum((y - yhat)^2) / sum(y^2)
    #       = (a - 1)^2 / a^2
    #       = (1/a - 1)^2
    # The expected value is the batch mean of these per-row ESRs.
    expected_esr = torch.square(1.0 / target_factor - 1.0).mean()
    model = Model()
    preds = model(inputs)
    actual_esr = losses.esr(preds, targets)
    assert torch.allclose(actual_esr, expected_esr)


if __name__ == "__main__":
    # Run only this file. A bare pytest.main() would collect every test under
    # the current working directory instead.
    pytest.main([__file__])