# test_wavenet.py (1425B)
# File: test_wavenet.py
# Created Date: Friday May 5th 2023
# Author: Steven Atkinson (steven@atkinson.mn)

import pytest as _pytest
import torch as _torch

from nam.models.wavenet import WaveNet as _WaveNet
from nam.train.core import (
    Architecture as _Architecture,
    get_wavenet_config as _get_wavenet_config,
)

from .base import Base as _Base


class TestWaveNet(_Base):
    """Tests for ``WaveNet``, layered on top of the shared ``Base`` test class."""

    @classmethod
    def setup_class(cls):
        """
        Hand the model class, its positional args (none) and its keyword args
        to ``Base.setup_class``.
        """
        # A single layer group in which every size is 1 and there is one
        # dilation of 1.
        single_layer = {
            "input_size": 1,
            "condition_size": 1,
            "head_size": 1,
            "channels": 1,
            "kernel_size": 1,
            "dilations": [1],
        }
        super().setup_class(_WaveNet, (), {"layers_configs": [single_layer]})

    def test_import_weights(self):
        """
        Two models built from the same config should disagree on the same
        input until one imports the other's exported weights, after which
        their outputs should match.
        """
        config = _get_wavenet_config(_Architecture.FEATHER)
        # Built in sequence: first the model whose weights are exported, then
        # the model that will receive them.
        source_model, target_model = [
            _WaveNet.init_from_config(config) for _ in range(2)
        ]

        # Input is 23 samples longer than the receptive field, batch of 2.
        num_samples = source_model.receptive_field + 23
        x = _torch.randn(2, num_samples)

        reference = source_model(x)
        before_import = target_model(x)

        exported = source_model._export_weights()
        target_model.import_weights(exported)
        after_import = target_model(x)

        assert not _torch.allclose(before_import, reference)
        assert _torch.allclose(after_import, reference)


if __name__ == "__main__":
    _pytest.main()