test_main.py (6427B)
# File: test_main.py
# Created Date: Sunday April 30th 2023
# Author: Steven Atkinson (steven@atkinson.mn)

import json
from enum import Enum
from pathlib import Path
from subprocess import check_call
from tempfile import TemporaryDirectory
from typing import Dict, Tuple, Union

import numpy as np
import pytest
import torch

from nam.data import np_to_wav


class _Device(Enum):
    CPU = "cpu"
    GPU = "gpu"
    MPS = "mps"


class Test(object):
    @classmethod
    def setup_class(cls):
        cls._num_samples = 128
        cls._num_samples_validation = 15
        cls._ny = 2
        cls._batch_size = 2

    def test_cpu(self):
        self._t_main(_Device.CPU)

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU test")
    def test_gpu(self):
        self._t_main(_Device.GPU)

    @pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS test")
    def test_mps(self):
        self._t_main(_Device.MPS)

    @classmethod
    def _data_config_path(cls, root_path: Path) -> Path:
        return Path(cls._input_path(root_path), "data_config.json")

    def _get_configs(
        self, root_path: Path, device: _Device
    ) -> Tuple[Dict, Dict, Dict]:  # TODO pydantic models
        data_config = {
            "train": {
                "start": None,
                "stop": -self._num_samples_validation,
                "ny": self._ny,
            },
            "validation": {
                "start": -self._num_samples_validation,
                "stop": None,
                "ny": None,
            },
            "common": {
                "x_path": str(self._x_path(root_path)),
                "y_path": str(self._y_path(root_path)),
                "delay": 0,
                "require_input_pre_silence": None,
            },
        }
        stage_channels = (3, 2)
        model_config = {
            "net": {
                "name": "WaveNet",
                "config": {
                    "layers_configs": [
                        {
                            "condition_size": 1,
                            "input_size": 1,
                            "channels": stage_channels[0],
                            "head_size": stage_channels[1],
                            "kernel_size": 3,
                            "dilations": [1],
                            "activation": "Tanh",
                            "gated": False,
                            "head_bias": False,
                        },
                        {
                            "condition_size": 1,
                            "input_size": stage_channels[0],
                            "channels": stage_channels[1],
                            "head_size": 1,
                            "kernel_size": 3,
                            "dilations": [2],
                            "activation": "Tanh",
                            "gated": False,
                            "head_bias": False,
                        },
                    ],
                    "head_scale": 0.02,
                },
            },
            "optimizer": {"lr": 0.004},
            "lr_scheduler": {"class": "ExponentialLR", "kwargs": {"gamma": 0.993}},
        }

        def extra_trainer_kwargs(device) -> Dict[str, Union[int, str]]:
            return {
                _Device.GPU: {"accelerator": "gpu", "devices": 1},
                _Device.MPS: {"accelerator": "mps", "devices": 1},
            }.get(device, {})

        learning_config = {
            "train_dataloader": {
                "batch_size": 3,
                "shuffle": True,
                "pin_memory": True,
                "drop_last": True,
                "num_workers": 0,
            },
            "val_dataloader": {},
            "trainer": {"max_epochs": 2, **extra_trainer_kwargs(device)},
            "trainer_fit_kwargs": {},
        }

        return data_config, model_config, learning_config

    def _get_data(self) -> Tuple[np.ndarray, np.ndarray]:
        """
        :return: (N,), (N,)
        """
        x = np.random.rand(self._num_samples) - 0.5
        y = 1.1 * x
        return x, y

    @classmethod
    def _input_path(cls, root_path: Path, ensure: bool = False) -> Path:
        p = Path(root_path, "inputs")
        if ensure:
            p.mkdir()
        return p

    @classmethod
    def _learning_config_path(cls, root_path: Path) -> Path:
        return Path(cls._input_path(root_path), "learning_config.json")

    @classmethod
    def _model_config_path(cls, root_path: Path) -> Path:
        return Path(cls._input_path(root_path), "model_config.json")

    @classmethod
    def _output_path(cls, root_path: Path, ensure: bool = False) -> Path:
        p = Path(root_path, "outputs")
        if ensure:
            p.mkdir()
        return p

    def _setup_files(self, root_path: Path, device: _Device):
        x, y = self._get_data()
        np_to_wav(x, self._x_path(root_path))
        np_to_wav(y, self._y_path(root_path))
        data_config, model_config, learning_config = self._get_configs(
            root_path, device
        )
        with open(self._data_config_path(root_path), "w") as fp:
            json.dump(data_config, fp)
        with open(self._model_config_path(root_path), "w") as fp:
            json.dump(model_config, fp)
        with open(self._learning_config_path(root_path), "w") as fp:
            json.dump(learning_config, fp)

    def _t_main(self, device: _Device):
        """
        End-to-end test of bin/train/main.py
        """
        with TemporaryDirectory() as tempdir:
            tempdir = Path(tempdir)
            self._input_path(tempdir, ensure=True)
            self._setup_files(tempdir, device)
            check_call(
                [
                    "nam-full",  # HACK not DRY w/ setup.py
                    str(self._data_config_path(tempdir)),
                    str(self._model_config_path(tempdir)),
                    str(self._learning_config_path(tempdir)),
                    str(self._output_path(tempdir, ensure=True)),
                    "--no-show",
                ]
            )

    @classmethod
    def _x_path(cls, root_path: Path) -> Path:
        return Path(cls._input_path(root_path), "input.wav")

    @classmethod
    def _y_path(cls, root_path: Path) -> Path:
        return Path(cls._input_path(root_path), "output.wav")


if __name__ == "__main__":
    pytest.main()