neural-amp-modeler

Neural network emulator for guitar amplifiers
Log | Files | Refs | README | LICENSE

test_wavenet.py (1425B)


      1 # File: test_wavenet.py
      2 # Created Date: Friday May 5th 2023
      3 # Author: Steven Atkinson (steven@atkinson.mn)
      4 
      5 import pytest as _pytest
      6 import torch as _torch
      7 
      8 from nam.models.wavenet import WaveNet as _WaveNet
      9 from nam.train.core import (
     10     Architecture as _Architecture,
     11     get_wavenet_config as _get_wavenet_config,
     12 )
     13 
     14 from .base import Base as _Base
     15 
     16 
     17 class TestWaveNet(_Base):
     18     @classmethod
     19     def setup_class(cls):
     20         C = _WaveNet
     21         args = ()
     22         kwargs = {
     23             "layers_configs": [
     24                 {
     25                     "input_size": 1,
     26                     "condition_size": 1,
     27                     "head_size": 1,
     28                     "channels": 1,
     29                     "kernel_size": 1,
     30                     "dilations": [1],
     31                 }
     32             ]
     33         }
     34         super().setup_class(C, args, kwargs)
     35 
     36     def test_import_weights(self):
     37         config = _get_wavenet_config(_Architecture.FEATHER)
     38         model_1 = _WaveNet.init_from_config(config)
     39         model_2 = _WaveNet.init_from_config(config)
     40 
     41         batch_size = 2
     42         x = _torch.randn(batch_size, model_1.receptive_field + 23)
     43 
     44         y1 = model_1(x)
     45         y2_before = model_2(x)
     46 
     47         model_2.import_weights(model_1._export_weights())
     48         y2_after = model_2(x)
     49 
     50         assert not _torch.allclose(y2_before, y1)
     51         assert _torch.allclose(y2_after, y1)
     52 
     53 
     54 if __name__ == "__main__":
     55     _pytest.main()