test_data.py (14908B)
# File: test_data.py
# Created Date: Friday May 6th 2022
# Author: Steven Atkinson (steven@atkinson.mn)

import math
import os
from enum import Enum
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Optional, Tuple, Type, Union

import numpy as np
import pytest
import torch

from nam import data

# Sample rates exercised by the parametrized WAV / dataset tests.
_SAMPLE_RATES = (44_100.0, 48_000.0, 88_200.0, 96_000.0)
# Sample rate passed to Dataset whenever a test does not care about its value.
_DEFAULT_SAMPLE_RATE = 48_000.0


class _XYMethod(Enum):
    """
    Ways that `TestDataset._create_xy()` can synthesize an (x, y) pair.
    """

    ARANGE = "arange"
    RAND = "rand"
    STEP = "step"


class TestDataset(object):
    """
    Assertions about nam.data.Dataset
    """

    def test_apply_delay_zero(self):
        """
        Assert proper function of Dataset._apply_delay() when zero delay is given, i.e.
        no change.
        """
        x, y = self._create_xy()
        x_out, y_out = data.Dataset._apply_delay(
            x, y, 0, data._DelayInterpolationMethod.CUBIC
        )
        assert torch.all(x == x_out)
        assert torch.all(y == y_out)

    @pytest.mark.parametrize("method", data._DelayInterpolationMethod)
    def test_apply_delay_float_negative(self, method):
        """
        Assert proper function of Dataset._apply_delay() when a negative
        fractional delay is given, for every interpolation method.
        """
        n = 7
        delay = -2.5
        x_out, y_out = self._t_apply_delay_float(n, delay, method)

        assert torch.all(x_out == torch.Tensor([3, 4, 5, 6]))
        assert torch.all(y_out == torch.Tensor([0.5, 1.5, 2.5, 3.5]))

    @pytest.mark.parametrize("method", data._DelayInterpolationMethod)
    def test_apply_delay_float_positive(self, method):
        """
        Assert proper function of Dataset._apply_delay() when a positive
        fractional delay is given, for every interpolation method.
        """
        n = 7
        delay = 2.5
        x_out, y_out = self._t_apply_delay_float(n, delay, method)

        assert torch.all(x_out == torch.Tensor([0, 1, 2, 3]))
        assert torch.all(y_out == torch.Tensor([2.5, 3.5, 4.5, 5.5]))

    def test_apply_delay_int_negative(self):
        """
        Assert proper function of Dataset._apply_delay() when a negative integer delay
        is given.
        """
        n = 7
        delay = -3
        x_out, y_out = self._t_apply_delay_int(n, delay)

        assert torch.all(x_out == torch.Tensor([3, 4, 5, 6]))
        assert torch.all(y_out == torch.Tensor([0, 1, 2, 3]))

    def test_apply_delay_int_positive(self):
        """
        Assert proper function of Dataset._apply_delay() when a positive integer delay
        is given.
        """
        n = 7
        delay = 3
        x_out, y_out = self._t_apply_delay_int(n, delay)

        assert torch.all(x_out == torch.Tensor([0, 1, 2, 3]))
        assert torch.all(y_out == torch.Tensor([3, 4, 5, 6]))

    def test_init(self):
        """
        Assert that a Dataset can be instantiated from in-range random data.
        """
        x, y = self._create_xy()
        data.Dataset(x, y, 3, None, sample_rate=_DEFAULT_SAMPLE_RATE)

    def test_init_sample_rate(self):
        """
        Assert that the sample rate given at init is exposed as a float
        `.sample_rate` attribute with the same value.
        """
        x, y = self._create_xy()
        sample_rate = _DEFAULT_SAMPLE_RATE
        d = data.Dataset(x, y, 3, None, sample_rate=sample_rate)
        assert hasattr(d, "sample_rate")
        assert isinstance(d.sample_rate, float)
        assert d.sample_rate == sample_rate

    def test_init_zero_delay(self):
        """
        Assert https://github.com/sdatkinson/neural-amp-modeler/issues/15 fixed
        """
        x, y = self._create_xy()
        data.Dataset(x, y, 3, None, delay=0, sample_rate=_DEFAULT_SAMPLE_RATE)

    def test_input_gain(self):
        """
        Checks correctness of input gain parameter
        """
        x_scale = 2.0
        # Express the linear scale factor as a gain in decibels.
        input_gain = 20.0 * math.log10(x_scale)
        x, y = self._create_xy()
        nx = 3
        ny = None
        args = (x, y, nx, ny)
        d1 = data.Dataset(*args, sample_rate=_DEFAULT_SAMPLE_RATE)
        d2 = data.Dataset(
            *args, sample_rate=_DEFAULT_SAMPLE_RATE, input_gain=input_gain
        )

        sample_x1 = d1[0][0]
        sample_x2 = d2[0][0]
        assert torch.allclose(sample_x1 * x_scale, sample_x2)

    @pytest.mark.parametrize("sample_rate", _SAMPLE_RATES)
    def test_sample_rates(self, sample_rate: float):
        """
        Test that datasets with various sample rates can be made
        """
        x = np.random.rand(16) - 0.5
        y = x
        with TemporaryDirectory() as tmpdir:
            x_path = Path(tmpdir, "input.wav")
            y_path = Path(tmpdir, "output.wav")
            data.np_to_wav(x, x_path, rate=sample_rate)
            data.np_to_wav(y, y_path, rate=sample_rate)
            config = {"x_path": str(x_path), "y_path": str(y_path), "nx": 4, "ny": 2}
            parsed_config = data.Dataset.parse_config(config)
        assert parsed_config["sample_rate"] == sample_rate

    @pytest.mark.parametrize(
        "n,start,valid",
        (
            (13, None, True),  # No start restrictions; nothing wrong
            (13, 2, True),  # Starts before the end; fine.
            (13, 12, True),  # Starts w/ one to go--ok
            (13, 13, False),  # Starts after the end
            (13, -5, True),  # Starts counting back from the end, fine
            (13, -13, True),  # Starts at the beginning of the array--ok
            (13, -14, False),  # Starts before the beginning of the array--invalid
        ),
    )
    def test_validate_start(self, n: int, start: Optional[int], valid: bool):
        """
        Assert that a data set can be successfully instantiated when valid args are
        given, including `start`.
        Assert that `StartError` is raised if invalid start is provided
        """

        def init():
            data.Dataset(x, y, nx, ny, start=start, sample_rate=_DEFAULT_SAMPLE_RATE)

        nx = 1
        ny = None
        x, y = self._create_xy(n=n)
        if start is not None:
            x[:start] = 0.0  # Ensure silent input before the start
        if valid:
            init()  # No problem!
        else:
            with pytest.raises(data.StartError):
                init()

    @pytest.mark.parametrize(
        "start,start_samples,start_seconds,stop,stop_samples,stop_seconds,sample_rate,raises",
        (
            # Nones across the board (valid)
            (None, None, None, None, None, None, None, None),
            # start and stop (valid)
            (1, None, None, -1, None, None, None, None),
            # start_samples and stop_samples (valid)
            (None, 1, None, None, -1, None, None, None),
            # start_seconds and stop_seconds with sample_rate (valid)
            (None, None, 0.5, None, None, -0.5, 2, None),
            # Multiple start-like, even if they agree (invalid)
            (1, 1, None, None, None, None, None, ValueError),
            # Multiple stop-like, even if they agree (invalid)
            (None, None, None, -1, -1, None, None, ValueError),
            # seconds w/o sample rate (invalid)
            (None, None, 1.0, None, None, None, None, ValueError),
        ),
    )
    def test_validate_start_stop(
        self,
        start: Optional[int],
        start_samples: Optional[int],
        start_seconds: Optional[Union[int, float]],
        stop: Optional[int],
        stop_samples: Optional[int],
        stop_seconds: Optional[Union[int, float]],
        sample_rate: Optional[int],
        raises: Optional[Type[Exception]],
    ):
        """
        Assert correct behavior of `._validate_start_stop()` class method.

        :param raises: Exception class expected from the call, or None if the
            arguments are expected to be accepted.
        """

        def f():
            # Don't provide start/stop that are too large for the fake data plz.
            x, y = torch.zeros((2, 32))
            data.Dataset._validate_start_stop(
                x,
                y,
                start,
                stop,
                start_samples,
                stop_samples,
                start_seconds,
                stop_seconds,
                sample_rate,
            )

        if raises is None:
            f()
        else:
            with pytest.raises(raises):
                f()

    @pytest.mark.parametrize(
        "n,stop,valid",
        (
            (13, None, True),  # No stop restrictions; nothing wrong
            (13, 2, True),  # Stops before the end; fine.
            (13, 13, True),  # Stops at the end--ok
            (13, 14, False),  # Stops after the end--not ok
            (13, -5, True),  # Stops counting back from the end, fine
            (13, -12, True),  # Stops w/ one sample--ok
            (13, -13, False),  # Stops w/ no samples--not ok
        ),
    )
    def test_validate_stop(self, n: int, stop: Optional[int], valid: bool):
        """
        Assert that a data set can be successfully instantiated when a valid
        `stop` is given.
        Assert that `StopError` is raised if invalid stop is provided
        """

        def init():
            data.Dataset(x, y, nx, ny, stop=stop, sample_rate=_DEFAULT_SAMPLE_RATE)

        nx = 1
        ny = None
        x, y = self._create_xy(n=n)
        if valid:
            init()  # No problem!
        else:
            with pytest.raises(data.StopError):
                init()

    @pytest.mark.parametrize(
        "lenx,leny,valid",
        (
            (3, 3, True),
            (3, 4, False),  # Length mismatch
            (0, 0, False),  # Empty!
        ),
    )
    def test_validate_x_y(self, lenx: int, leny: int, valid: bool):
        """
        Assert that a data set can be instantiated from x and y of matching,
        nonzero length.
        Assert that `XYError` is raised otherwise.
        """

        def init():
            data.Dataset(x, y, nx, ny, sample_rate=_DEFAULT_SAMPLE_RATE)

        x, y = self._create_xy()
        assert len(x) >= lenx, "Invalid test!"
        assert len(y) >= leny, "Invalid test!"
        x = x[:lenx]
        y = y[:leny]
        nx = 1
        ny = None
        if valid:
            init()  # It worked!
        else:
            with pytest.raises(data.XYError):
                init()

    def _create_xy(
        self,
        n: int = 7,
        method: _XYMethod = _XYMethod.RAND,
        must_be_in_valid_range: bool = True,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Synthesize an (x, y) pair of length-n signals.

        :param n: Number of samples in each signal.
        :param method: How the signals are generated.
        :param must_be_in_valid_range: Must be False to use ARANGE, whose values
            go beyond (-1, 1).
        :raises ValueError: If `method` is not handled.
        :return: (n,), (n,)
        """
        if method == _XYMethod.ARANGE:
            # note: this isn't "valid" data in the sense that it's beyond (-1, 1).
            # But it is useful for the delay code.
            assert not must_be_in_valid_range
            return tuple(
                torch.tile(torch.arange(n, dtype=torch.float)[None, :], (2, 1))
            )
        elif method == _XYMethod.RAND:
            return tuple(0.99 * (2.0 * torch.rand((2, n)) - 1.0))  # Don't clip
        elif method == _XYMethod.STEP:
            # NOTE(review): the comparison yields boolean tensors, not floats;
            # cast if a caller ever needs float data.
            return tuple(
                torch.tile((torch.linspace(0.0, 1.0, n) > 0.5)[None, :], (2, 1))
            )
        # Fail loudly instead of implicitly returning None.
        raise ValueError(f"Unrecognized method {method}")

    def _t_apply_delay_float(
        self, n: int, delay: float, method: data._DelayInterpolationMethod
    ):
        """
        Apply a fractional delay to ARANGE data and check the output lengths.

        :param n: Number of input samples.
        :param delay: Delay to apply (may be fractional, either sign).
        :param method: Interpolation method handed to Dataset._apply_delay().
        :return: The delayed (x, y)
        """
        x, y = self._create_xy(
            n=n, method=_XYMethod.ARANGE, must_be_in_valid_range=False
        )

        x_out, y_out = data.Dataset._apply_delay(x, y, delay, method)
        # 7, +/-2.5 -> 4
        n_out = n - int(np.ceil(np.abs(delay)))
        assert len(x_out) == n_out
        assert len(y_out) == n_out

        return x_out, y_out

    def _t_apply_delay_int(self, n: int, delay: int):
        """
        Apply an integer delay to ARANGE data and check the output lengths.

        :param n: Number of input samples.
        :param delay: Delay to apply, in samples (either sign).
        :return: The delayed (x, y)
        """
        x, y = self._create_xy(
            n=n, method=_XYMethod.ARANGE, must_be_in_valid_range=False
        )

        x_out, y_out = data.Dataset._apply_delay(
            x, y, delay, data._DelayInterpolationMethod.CUBIC
        )
        n_out = n - np.abs(delay)
        assert len(x_out) == n_out
        assert len(y_out) == n_out

        return x_out, y_out


class TestWav(object):
    """
    Assertions about the WAV helpers nam.data.np_to_wav() and nam.data.wav_to_np()
    """

    # Absolute tolerance allowed when comparing a round-tripped signal.
    tolerance = 1e-6

    @pytest.fixture(scope="class")
    def tmpdir(self):
        """
        Temporary directory shared by the tests of this class that request it.
        """
        with TemporaryDirectory() as tmp:
            yield tmp

    def test_np_to_wav_to_np(self, tmpdir):
        """
        Assert that an array survives a write-then-read round trip.
        """
        # Create random numpy array
        x = np.random.rand(1000)
        # Save numpy array as WAV file
        filename = os.path.join(tmpdir, "test.wav")
        data.np_to_wav(x, filename)
        # Load WAV file
        y = data.wav_to_np(filename)
        # Check if the two arrays are equal
        assert y == pytest.approx(x, abs=self.tolerance)

    @pytest.mark.parametrize("sample_rate", _SAMPLE_RATES)
    def test_np_to_wav_to_np_sample_rates(self, sample_rate: float):
        """
        Assert that an array survives a round trip at each supported sample rate.
        """
        with TemporaryDirectory() as tmpdir:
            # Create random numpy array
            x = np.random.rand(8)
            # Save numpy array as WAV file with the parametrized sampling rate
            filename = Path(tmpdir, "x.wav")
            data.np_to_wav(x, filename, rate=sample_rate)
            # Load WAV file, requiring the same sampling rate
            y = data.wav_to_np(filename, rate=sample_rate)
            # Check if the two arrays are equal
            assert y == pytest.approx(x, abs=self.tolerance)

    def test_np_to_wav_to_np_scale_arg(self, tmpdir):
        """
        Assert that an array survives a round trip when `scale=None` is passed.
        """
        # Create random numpy array
        x = np.random.rand(100)
        # Save numpy array as WAV file, passing the scale argument explicitly
        filename = os.path.join(tmpdir, "test.wav")
        data.np_to_wav(x, filename, scale=None)
        # Load WAV file
        y = data.wav_to_np(filename)
        # Check if the two arrays are equal
        assert y == pytest.approx(x, abs=self.tolerance)

    @pytest.mark.parametrize("sample_width", (2, 3))
    def test_sample_widths(self, sample_width: int):
        """
        Test that WAV files with various sample widths can be made, and that the
        width is reported back when loading with `info=True`.
        """
        x = np.random.rand(16) - 0.5
        with TemporaryDirectory() as tmpdir:
            x_path = Path(tmpdir, "x.wav")
            data.np_to_wav(x, x_path, sampwidth=sample_width)
            _, info = data.wav_to_np(x_path, info=True)
            assert info.sampwidth == sample_width


def test_audio_mismatch_shapes_in_order():
    """
    https://github.com/sdatkinson/neural-amp-modeler/issues/257
    """
    x_samples, y_samples = 5, 7
    num_channels = 1

    x, y = [np.zeros((n, num_channels)) for n in (x_samples, y_samples)]

    with TemporaryDirectory() as tmpdir:
        y_path = Path(tmpdir, "y.wav")
        data.np_to_wav(y, y_path)

        # Load while the file still exists; the raised error is inspected below.
        with pytest.raises(data.AudioShapeMismatchError) as exc_info:
            data.wav_to_np(y_path, required_shape=x.shape)

    # x is loaded first; we expect that y matches.
    error = exc_info.value
    assert error.shape_expected == (x_samples, num_channels)
    assert error.shape_actual == (y_samples, num_channels)


def test_register_dataset_initializer():
    """
    Assert that you can add and use new data sets
    """

    class MyDataset(data.Dataset):
        pass

    name = "my_dataset"

    data.register_dataset_initializer(name, MyDataset.init_from_config)

    x = np.random.rand(32) - 0.5
    y = x
    split = data.Split.TRAIN

    with TemporaryDirectory() as tmpdir:
        x_path = Path(tmpdir, "x.wav")
        y_path = Path(tmpdir, "y.wav")
        data.np_to_wav(x, x_path)
        data.np_to_wav(y, y_path)
        config = {
            "type": name,
            split.value: {
                "x_path": str(x_path),
                "y_path": str(y_path),
                "nx": 3,
                "ny": 2,
            },
        }
        dataset = data.init_dataset(config, split)
    assert isinstance(dataset, MyDataset)


if __name__ == "__main__":
    pytest.main()