neural-amp-modeler

Neural network emulator for guitar amplifiers
Log | Files | Refs | README | LICENSE

test_main.py (6427B)


      1 # File: test_main.py
      2 # Created Date: Sunday April 30th 2023
      3 # Author: Steven Atkinson (steven@atkinson.mn)
      4 
      5 import json
      6 from enum import Enum
      7 from pathlib import Path
      8 from subprocess import check_call
      9 from tempfile import TemporaryDirectory
     10 from typing import Dict, Tuple, Union
     11 
     12 import numpy as np
     13 import pytest
     14 import torch
     15 
     16 from nam.data import np_to_wav
     17 
     18 
     19 class _Device(Enum):
     20     CPU = "cpu"
     21     GPU = "gpu"
     22     MPS = "mps"
     23 
     24 
     25 class Test(object):
     26     @classmethod
     27     def setup_class(cls):
     28         cls._num_samples = 128
     29         cls._num_samples_validation = 15
     30         cls._ny = 2
     31         cls._batch_size = 2
     32 
     33     def test_cpu(self):
     34         self._t_main(_Device.CPU)
     35 
     36     @pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU test")
     37     def test_gpu(self):
     38         self._t_main(_Device.GPU)
     39 
     40     @pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS test")
     41     def test_mps(self):
     42         self._t_main(_Device.MPS)
     43 
     44     @classmethod
     45     def _data_config_path(cls, root_path: Path) -> Path:
     46         return Path(cls._input_path(root_path), "data_config.json")
     47 
     48     def _get_configs(
     49         self, root_path: Path, device: _Device
     50     ) -> Tuple[Dict, Dict, Dict]:  # TODO pydantic models
     51         data_config = {
     52             "train": {
     53                 "start": None,
     54                 "stop": -self._num_samples_validation,
     55                 "ny": self._ny,
     56             },
     57             "validation": {
     58                 "start": -self._num_samples_validation,
     59                 "stop": None,
     60                 "ny": None,
     61             },
     62             "common": {
     63                 "x_path": str(self._x_path(root_path)),
     64                 "y_path": str(self._y_path(root_path)),
     65                 "delay": 0,
     66                 "require_input_pre_silence": None,
     67             },
     68         }
     69         stage_channels = (3, 2)
     70         model_config = {
     71             "net": {
     72                 "name": "WaveNet",
     73                 "config": {
     74                     "layers_configs": [
     75                         {
     76                             "condition_size": 1,
     77                             "input_size": 1,
     78                             "channels": stage_channels[0],
     79                             "head_size": stage_channels[1],
     80                             "kernel_size": 3,
     81                             "dilations": [1],
     82                             "activation": "Tanh",
     83                             "gated": False,
     84                             "head_bias": False,
     85                         },
     86                         {
     87                             "condition_size": 1,
     88                             "input_size": stage_channels[0],
     89                             "channels": stage_channels[1],
     90                             "head_size": 1,
     91                             "kernel_size": 3,
     92                             "dilations": [2],
     93                             "activation": "Tanh",
     94                             "gated": False,
     95                             "head_bias": False,
     96                         },
     97                     ],
     98                     "head_scale": 0.02,
     99                 },
    100             },
    101             "optimizer": {"lr": 0.004},
    102             "lr_scheduler": {"class": "ExponentialLR", "kwargs": {"gamma": 0.993}},
    103         }
    104 
    105         def extra_trainer_kwargs(device) -> Dict[str, Union[int, str]]:
    106             return {
    107                 _Device.GPU: {"accelerator": "gpu", "devices": 1},
    108                 _Device.MPS: {"accelerator": "mps", "devices": 1},
    109             }.get(device, {})
    110 
    111         learning_config = {
    112             "train_dataloader": {
    113                 "batch_size": 3,
    114                 "shuffle": True,
    115                 "pin_memory": True,
    116                 "drop_last": True,
    117                 "num_workers": 0,
    118             },
    119             "val_dataloader": {},
    120             "trainer": {"max_epochs": 2, **extra_trainer_kwargs(device)},
    121             "trainer_fit_kwargs": {},
    122         }
    123 
    124         return data_config, model_config, learning_config
    125 
    126     def _get_data(self) -> Tuple[np.ndarray, np.ndarray]:
    127         """
    128         :return: (N,), (N,)
    129         """
    130         x = np.random.rand(self._num_samples) - 0.5
    131         y = 1.1 * x
    132         return x, y
    133 
    134     @classmethod
    135     def _input_path(cls, root_path: Path, ensure: bool = False) -> Path:
    136         p = Path(root_path, "inputs")
    137         if ensure:
    138             p.mkdir()
    139         return p
    140 
    141     @classmethod
    142     def _learning_config_path(cls, root_path: Path) -> Path:
    143         return Path(cls._input_path(root_path), "learning_config.json")
    144 
    145     @classmethod
    146     def _model_config_path(cls, root_path: Path) -> Path:
    147         return Path(cls._input_path(root_path), "model_config.json")
    148 
    149     @classmethod
    150     def _output_path(cls, root_path: Path, ensure: bool = False) -> Path:
    151         p = Path(root_path, "outputs")
    152         if ensure:
    153             p.mkdir()
    154         return p
    155 
    156     def _setup_files(self, root_path: Path, device: _Device):
    157         x, y = self._get_data()
    158         np_to_wav(x, self._x_path(root_path))
    159         np_to_wav(y, self._y_path(root_path))
    160         data_config, model_config, learning_config = self._get_configs(
    161             root_path, device
    162         )
    163         with open(self._data_config_path(root_path), "w") as fp:
    164             json.dump(data_config, fp)
    165         with open(self._model_config_path(root_path), "w") as fp:
    166             json.dump(model_config, fp)
    167         with open(self._learning_config_path(root_path), "w") as fp:
    168             json.dump(learning_config, fp)
    169 
    170     def _t_main(self, device: _Device):
    171         """
    172         End-to-end test of bin/train/main.py
    173         """
    174         with TemporaryDirectory() as tempdir:
    175             tempdir = Path(tempdir)
    176             self._input_path(tempdir, ensure=True)
    177             self._setup_files(tempdir, device)
    178             check_call(
    179                 [
    180                     "nam-full",  # HACK not DRY w/ setup.py
    181                     str(self._data_config_path(tempdir)),
    182                     str(self._model_config_path(tempdir)),
    183                     str(self._learning_config_path(tempdir)),
    184                     str(self._output_path(tempdir, ensure=True)),
    185                     "--no-show",
    186                 ]
    187             )
    188 
    189     @classmethod
    190     def _x_path(cls, root_path: Path) -> Path:
    191         return Path(cls._input_path(root_path), "input.wav")
    192 
    193     @classmethod
    194     def _y_path(cls, root_path: Path) -> Path:
    195         return Path(cls._input_path(root_path), "output.wav")
    196 
    197 
    198 if __name__ == "__main__":
    199     pytest.main()