neural-amp-modeler

Neural network emulator for guitar amplifiers
Log | Files | Refs | README | LICENSE

test_losses.py (1727B)


      1 # File: test_losses.py
      2 # Created Date: Saturday January 28th 2023
      3 # Author: Steven Atkinson (steven@atkinson.mn)
      4 
      5 import pytest
      6 import torch
      7 import torch.nn as nn
      8 
      9 from nam.models import losses
     10 
     11 
     12 @pytest.mark.parametrize(
     13     "x,coef,y_expected",
     14     (
     15         (torch.Tensor([0.0, 1.0, 2.0]), 1.0, torch.Tensor([1.0, 1.0])),
     16         (torch.Tensor([0.0, 1.0, 2.0]), 0.5, torch.Tensor([1.0, 1.5])),
     17         (
     18             torch.Tensor([[0.0, 1.0, 0.0], [1.0, 1.5, 2.0]]),
     19             0.5,
     20             torch.Tensor([[1.0, -0.5], [1.0, 1.25]]),
     21         ),
     22     ),
     23 )
     24 def test_apply_pre_emphasis_filter_1d(
     25     x: torch.Tensor, coef: float, y_expected: torch.Tensor
     26 ):
     27     y_actual = losses.apply_pre_emphasis_filter(x, coef)
     28     assert isinstance(y_actual, torch.Tensor)
     29     assert y_actual.ndim == y_expected.ndim
     30     assert y_actual.shape == y_expected.shape
     31     assert torch.allclose(y_actual, y_expected)
     32 
     33 
     34 def test_esr():
     35     """
     36     Is the ESR calculation correct?
     37     """
     38 
     39     class Model(nn.Module):
     40         def forward(self, x):
     41             return x
     42 
     43     batch_size, input_length = 3, 5
     44     inputs = (
     45         torch.linspace(0.1, 1.0, batch_size)[:, None]
     46         * torch.full((input_length,), 1.0)[None, :]
     47     )  # (batch_size, input_length)
     48     target_factor = torch.linspace(0.37, 1.22, batch_size)
     49     targets = target_factor[:, None] * inputs  # (batch_size, input_length)
     50     # Do the algebra:
     51     # y=a*yhat
     52     # ESR=(y-yhat)^2 / y^2
     53     # ...
     54     # =(1/a-1)^2
     55     expected_esr = torch.square(1.0 / target_factor - 1.0).mean()
     56     model = Model()
     57     preds = model(inputs)
     58     actual_esr = losses.esr(preds, targets)
     59     assert torch.allclose(actual_esr, expected_esr)
     60 
     61 
     62 if __name__ == "__main__":
     63     pytest.main()