neural-amp-modeler

Neural network emulator for guitar amplifiers
Log | Files | Refs | README | LICENSE

test_data.py (14908B)


      1 # File: test_data.py
      2 # Created Date: Friday May 6th 2022
      3 # Author: Steven Atkinson (steven@atkinson.mn)
      4 
      5 import math
      6 import os
      7 from enum import Enum
      8 from pathlib import Path
      9 from tempfile import TemporaryDirectory
     10 from typing import Optional, Tuple, Union
     11 
     12 import numpy as np
     13 import pytest
     14 import torch
     15 
     16 from nam import data
     17 
     18 _SAMPLE_RATES = (44_100.0, 48_000.0, 88_200.0, 96_000.0)
     19 _DEFAULT_SAMPLE_RATE = 48_000.0
     20 
     21 
     22 class _XYMethod(Enum):
     23     ARANGE = "arange"
     24     RAND = "rand"
     25     STEP = "step"
     26 
     27 
     28 class TestDataset(object):
     29     """
     30     Assertions about nam.data.Dataset
     31     """
     32 
     33     def test_apply_delay_zero(self):
     34         """
     35         Assert proper function of Dataset._apply_delay() when zero delay is given, i.e.
     36         no change.
     37         """
     38         x, y = self._create_xy()
     39         x_out, y_out = data.Dataset._apply_delay(
     40             x, y, 0, data._DelayInterpolationMethod.CUBIC
     41         )
     42         assert torch.all(x == x_out)
     43         assert torch.all(y == y_out)
     44 
     45     @pytest.mark.parametrize("method", (data._DelayInterpolationMethod))
     46     def test_apply_delay_float_negative(self, method):
     47         n = 7
     48         delay = -2.5
     49         x_out, y_out = self._t_apply_delay_float(n, delay, method)
     50 
     51         assert torch.all(x_out == torch.Tensor([3, 4, 5, 6]))
     52         assert torch.all(y_out == torch.Tensor([0.5, 1.5, 2.5, 3.5]))
     53 
     54     @pytest.mark.parametrize("method", (data._DelayInterpolationMethod))
     55     def test_apply_delay_float_positive(self, method):
     56         n = 7
     57         delay = 2.5
     58         x_out, y_out = self._t_apply_delay_float(n, delay, method)
     59 
     60         assert torch.all(x_out == torch.Tensor([0, 1, 2, 3]))
     61         assert torch.all(y_out == torch.Tensor([2.5, 3.5, 4.5, 5.5]))
     62 
     63     def test_apply_delay_int_negative(self):
     64         """
     65         Assert proper function of Dataset._apply_delay() when a positive integer delay
     66         is given.
     67         """
     68         n = 7
     69         delay = -3
     70         x_out, y_out = self._t_apply_delay_int(n, delay)
     71 
     72         assert torch.all(x_out == torch.Tensor([3, 4, 5, 6]))
     73         assert torch.all(y_out == torch.Tensor([0, 1, 2, 3]))
     74 
     75     def test_apply_delay_int_positive(self):
     76         """
     77         Assert proper function of Dataset._apply_delay() when a positive integer delay
     78         is given.
     79         """
     80         n = 7
     81         delay = 3
     82         x_out, y_out = self._t_apply_delay_int(n, delay)
     83 
     84         assert torch.all(x_out == torch.Tensor([0, 1, 2, 3]))
     85         assert torch.all(y_out == torch.Tensor([3, 4, 5, 6]))
     86 
     87     def test_init(self):
     88         x, y = self._create_xy()
     89         data.Dataset(x, y, 3, None, sample_rate=_DEFAULT_SAMPLE_RATE)
     90 
     91     def test_init_sample_rate(self):
     92         x, y = self._create_xy()
     93         sample_rate = _DEFAULT_SAMPLE_RATE
     94         d = data.Dataset(x, y, 3, None, sample_rate=sample_rate)
     95         assert hasattr(d, "sample_rate")
     96         assert isinstance(d.sample_rate, float)
     97         assert d.sample_rate == sample_rate
     98 
     99     def test_init_zero_delay(self):
    100         """
    101         Assert https://github.com/sdatkinson/neural-amp-modeler/issues/15 fixed
    102         """
    103         x, y = self._create_xy()
    104         data.Dataset(x, y, 3, None, delay=0, sample_rate=_DEFAULT_SAMPLE_RATE)
    105 
    106     def test_input_gain(self):
    107         """
    108         Checks correctness of input gain parameter
    109         """
    110         x_scale = 2.0
    111         input_gain = 20.0 * math.log10(x_scale)
    112         x, y = self._create_xy()
    113         nx = 3
    114         ny = None
    115         args = (x, y, nx, ny)
    116         d1 = data.Dataset(*args, sample_rate=_DEFAULT_SAMPLE_RATE)
    117         d2 = data.Dataset(
    118             *args, sample_rate=_DEFAULT_SAMPLE_RATE, input_gain=input_gain
    119         )
    120 
    121         sample_x1 = d1[0][0]
    122         sample_x2 = d2[0][0]
    123         assert torch.allclose(sample_x1 * x_scale, sample_x2)
    124 
    125     @pytest.mark.parametrize("sample_rate", _SAMPLE_RATES)
    126     def test_sample_rates(self, sample_rate: int):
    127         """
    128         Test that datasets with various sample rates can be made
    129         """
    130         x = np.random.rand(16) - 0.5
    131         y = x
    132         with TemporaryDirectory() as tmpdir:
    133             x_path = Path(tmpdir, "input.wav")
    134             y_path = Path(tmpdir, "output.wav")
    135             data.np_to_wav(x, x_path, rate=sample_rate)
    136             data.np_to_wav(y, y_path, rate=sample_rate)
    137             config = {"x_path": str(x_path), "y_path": str(y_path), "nx": 4, "ny": 2}
    138             parsed_config = data.Dataset.parse_config(config)
    139         assert parsed_config["sample_rate"] == sample_rate
    140 
    141     @pytest.mark.parametrize(
    142         "n,start,valid",
    143         (
    144             (13, None, True),  # No start restrictions; nothing wrong
    145             (13, 2, True),  # Starts before the end; fine.
    146             (13, 12, True),  # Starts w/ one to go--ok
    147             (13, 13, False),  # Starts after the end
    148             (13, -5, True),  # Starts counting back from the end, fine
    149             (13, -13, True),  # Starts at the beginning of the array--ok
    150             (13, -14, False),  # Starts before the beginning of the array--invalid
    151         ),
    152     )
    153     def test_validate_start(self, n: int, start: int, valid: bool):
    154         """
    155         Assert that a data set can be successfully instantiated when valid args are
    156         given, including `start`.
    157         Assert that `StartError` is raised if invalid start is provided
    158         """
    159 
    160         def init():
    161             data.Dataset(x, y, nx, ny, start=start, sample_rate=_DEFAULT_SAMPLE_RATE)
    162 
    163         nx = 1
    164         ny = None
    165         x, y = self._create_xy(n=n)
    166         if start is not None:
    167             x[:start] = 0.0  # Ensure silent input before the start
    168         if valid:
    169             init()
    170             assert True  # No problem!
    171         else:
    172             with pytest.raises(data.StartError):
    173                 init()
    174 
    175     @pytest.mark.parametrize(
    176         "start,start_samples,start_seconds,stop,stop_samples,stop_seconds,sample_rate,raises",
    177         (
    178             # Nones across the board (valid)
    179             (None, None, None, None, None, None, None, None),
    180             # start and stop (valid)
    181             (1, None, None, -1, None, None, None, None),
    182             # start_samples and stop_samples (valid)
    183             (None, 1, None, None, -1, None, None, None),
    184             # start_seconds and stop_seconds with sample_rate (valid)
    185             (None, None, 0.5, None, None, -0.5, 2, None),
    186             # Multiple start-like, even if they agree (invalid)
    187             (1, 1, None, None, None, None, None, ValueError),
    188             # Multiple stop-like, even if they agree (invalid)
    189             (None, None, None, -1, -1, None, None, ValueError),
    190             # seconds w/o sample rate (invalid)
    191             (None, None, 1.0, None, None, None, None, ValueError),
    192         ),
    193     )
    194     def test_validate_start_stop(
    195         self,
    196         start: Optional[int],
    197         start_samples: Optional[int],
    198         start_seconds: Optional[Union[int, float]],
    199         stop: Optional[int],
    200         stop_samples: Optional[int],
    201         stop_seconds: Optional[Union[int, float]],
    202         sample_rate: Optional[int],
    203         raises: Optional[Exception],
    204     ):
    205         """
    206         Assert correct behavior of `._validate_start_stop()` class method.
    207         """
    208 
    209         def f():
    210             # Don't provide start/stop that are too large for the fake data plz.
    211             x, y = torch.zeros((2, 32))
    212             data.Dataset._validate_start_stop(
    213                 x,
    214                 y,
    215                 start,
    216                 stop,
    217                 start_samples,
    218                 stop_samples,
    219                 start_seconds,
    220                 stop_seconds,
    221                 sample_rate,
    222             )
    223             assert True
    224 
    225         if raises is None:
    226             f()
    227         else:
    228             with pytest.raises(raises):
    229                 f()
    230 
    231     @pytest.mark.parametrize(
    232         "n,stop,valid",
    233         (
    234             (13, None, True),  # No stop restrictions; nothing wrong
    235             (13, 2, True),  # Stops before the end; fine.
    236             (13, 13, True),  # Stops at the end--ok
    237             (13, 14, False),  # Stops after the end--not ok
    238             (13, -5, True),  # Stops counting back from the end, fine
    239             (13, -12, True),  # Stops w/ one sample--ok
    240             (13, -13, False),  # Stops w/ no samples--not ok
    241         ),
    242     )
    243     def test_validate_stop(self, n: int, stop: int, valid: bool):
    244         def init():
    245             data.Dataset(x, y, nx, ny, stop=stop, sample_rate=_DEFAULT_SAMPLE_RATE)
    246 
    247         nx = 1
    248         ny = None
    249         x, y = self._create_xy(n=n)
    250         if valid:
    251             init()
    252             assert True  # No problem!
    253         else:
    254             with pytest.raises(data.StopError):
    255                 init()
    256 
    257     @pytest.mark.parametrize(
    258         "lenx,leny,valid",
    259         ((3, 3, True), (3, 4, False), (0, 0, False)),  # Lenght mismatch  # Empty!
    260     )
    261     def test_validate_x_y(self, lenx: int, leny: int, valid: bool):
    262         def init():
    263             data.Dataset(x, y, nx, ny, sample_rate=_DEFAULT_SAMPLE_RATE)
    264 
    265         x, y = self._create_xy()
    266         assert len(x) >= lenx, "Invalid test!"
    267         assert len(y) >= leny, "Invalid test!"
    268         x = x[:lenx]
    269         y = y[:leny]
    270         nx = 1
    271         ny = None
    272         if valid:
    273             init()
    274             assert True  # It worked!
    275         else:
    276             with pytest.raises(data.XYError):
    277                 init()
    278 
    279     def _create_xy(
    280         self,
    281         n: int = 7,
    282         method: _XYMethod = _XYMethod.RAND,
    283         must_be_in_valid_range: bool = True,
    284     ) -> Tuple[torch.Tensor, torch.Tensor]:
    285         """
    286         :return: (n,), (n,)
    287         """
    288         if method == _XYMethod.ARANGE:
    289             # note: this isn't "valid" data in the sense that it's beyond (-1, 1).
    290             # But it is useful for the delay code.
    291             assert not must_be_in_valid_range
    292             return tuple(
    293                 torch.tile(torch.arange(n, dtype=torch.float)[None, :], (2, 1))
    294             )
    295         elif method == _XYMethod.RAND:
    296             return tuple(0.99 * (2.0 * torch.rand((2, n)) - 1.0))  # Don't clip
    297         elif method == _XYMethod.STEP:
    298             return tuple(
    299                 torch.tile((torch.linspace(0.0, 1.0, n) > 0.5)[None, :], (2, 1))
    300             )
    301 
    302     def _t_apply_delay_float(
    303         self, n: int, delay: int, method: data._DelayInterpolationMethod
    304     ):
    305         x, y = self._create_xy(
    306             n=n, method=_XYMethod.ARANGE, must_be_in_valid_range=False
    307         )
    308 
    309         x_out, y_out = data.Dataset._apply_delay(x, y, delay, method)
    310         # 7, +/-2.5 -> 4
    311         n_out = n - int(np.ceil(np.abs(delay)))
    312         assert len(x_out) == n_out
    313         assert len(y_out) == n_out
    314 
    315         return x_out, y_out
    316 
    317     def _t_apply_delay_int(self, n: int, delay: int):
    318         x, y = self._create_xy(
    319             n=n, method=_XYMethod.ARANGE, must_be_in_valid_range=False
    320         )
    321 
    322         x_out, y_out = data.Dataset._apply_delay(
    323             x, y, delay, data._DelayInterpolationMethod.CUBIC
    324         )
    325         n_out = n - np.abs(delay)
    326         assert len(x_out) == n_out
    327         assert len(y_out) == n_out
    328 
    329         return x_out, y_out
    330 
    331 
    332 class TestWav(object):
    333     tolerance = 1e-6
    334 
    335     @pytest.fixture(scope="class")
    336     def tmpdir(self):
    337         with TemporaryDirectory() as tmp:
    338             yield tmp
    339 
    340     def test_np_to_wav_to_np(self, tmpdir):
    341         # Create random numpy array
    342         x = np.random.rand(1000)
    343         # Save numpy array as WAV file
    344         filename = os.path.join(tmpdir, "test.wav")
    345         data.np_to_wav(x, filename)
    346         # Load WAV file
    347         y = data.wav_to_np(filename)
    348         # Check if the two arrays are equal
    349         assert y == pytest.approx(x, abs=self.tolerance)
    350 
    351     @pytest.mark.parametrize("sample_rate", _SAMPLE_RATES)
    352     def test_np_to_wav_to_np_sample_rates(self, sample_rate: int):
    353         with TemporaryDirectory() as tmpdir:
    354             # Create random numpy array
    355             x = np.random.rand(8)
    356             # Save numpy array as WAV file with sampling rate of 44 kHz
    357             filename = Path(tmpdir, "x.wav")
    358             data.np_to_wav(x, filename, rate=sample_rate)
    359             # Load WAV file with sampling rate of 44 kHz
    360             y = data.wav_to_np(filename, rate=sample_rate)
    361             # Check if the two arrays are equal
    362             assert y == pytest.approx(x, abs=self.tolerance)
    363 
    364     def test_np_to_wav_to_np_scale_arg(self, tmpdir):
    365         # Create random numpy array
    366         x = np.random.rand(100)
    367         # Save numpy array as WAV file with scaling
    368         filename = os.path.join(tmpdir, "test.wav")
    369         data.np_to_wav(x, filename, scale=None)
    370         # Load WAV file
    371         y = data.wav_to_np(filename)
    372         # Check if the two arrays are equal
    373         assert y == pytest.approx(x, abs=self.tolerance)
    374 
    375     @pytest.mark.parametrize("sample_width", (2, 3))
    376     def test_sample_widths(self, sample_width: int):
    377         """
    378         Test that datasets with various sample widths can be made
    379         """
    380         x = np.random.rand(16) - 0.5
    381         with TemporaryDirectory() as tmpdir:
    382             x_path = Path(tmpdir, "x.wav")
    383             data.np_to_wav(x, x_path, sampwidth=sample_width)
    384             _, info = data.wav_to_np(x_path, info=True)
    385         assert info.sampwidth == sample_width
    386 
    387 
    388 def test_audio_mismatch_shapes_in_order():
    389     """
    390     https://github.com/sdatkinson/neural-amp-modeler/issues/257
    391     """
    392     x_samples, y_samples = 5, 7
    393     num_channels = 1
    394 
    395     x, y = [np.zeros((n, num_channels)) for n in (x_samples, y_samples)]
    396 
    397     with TemporaryDirectory() as tmpdir:
    398         y_path = Path(tmpdir, "y.wav")
    399         data.np_to_wav(y, y_path)
    400         f = lambda: data.wav_to_np(y_path, required_shape=x.shape)
    401 
    402         with pytest.raises(data.AudioShapeMismatchError) as e:
    403             f()
    404 
    405         try:
    406             f()
    407             assert False, "Shouldn't have succeeded!"
    408         except data.AudioShapeMismatchError as e:
    409             # x is loaded first; we expect that y matches.
    410             assert e.shape_expected == (x_samples, num_channels)
    411             assert e.shape_actual == (y_samples, num_channels)
    412 
    413 
    414 def test_register_dataset_initializer():
    415     """
    416     Assert that you can add and use new data sets
    417     """
    418 
    419     class MyDataset(data.Dataset):
    420         pass
    421 
    422     name = "my_dataset"
    423 
    424     data.register_dataset_initializer(name, MyDataset.init_from_config)
    425 
    426     x = np.random.rand(32) - 0.5
    427     y = x
    428     split = data.Split.TRAIN
    429 
    430     with TemporaryDirectory() as tmpdir:
    431         x_path = Path(tmpdir, "x.wav")
    432         y_path = Path(tmpdir, "y.wav")
    433         data.np_to_wav(x, x_path)
    434         data.np_to_wav(y, y_path)
    435         config = {
    436             "type": name,
    437             split.value: {
    438                 "x_path": str(x_path),
    439                 "y_path": str(y_path),
    440                 "nx": 3,
    441                 "ny": 2,
    442             },
    443         }
    444         dataset = data.init_dataset(config, split)
    445     assert isinstance(dataset, MyDataset)
    446 
    447 
    448 if __name__ == "__main__":
    449     pytest.main()