neural-amp-modeler

Neural network emulator for guitar amplifiers

wavenet.py (14240B)


# File: wavenet.py
# Created Date: Friday July 29th 2022
# Author: Steven Atkinson (steven@atkinson.mn)

"""
WaveNet implementation
https://arxiv.org/abs/1609.03499
"""

import json as _json
from copy import deepcopy as _deepcopy
from pathlib import Path as _Path
from tempfile import TemporaryDirectory as _TemporaryDirectory
from typing import (
    Dict as _Dict,
    Optional as _Optional,
    Sequence as _Sequence,
    Tuple as _Tuple,
)

import numpy as _np
import torch as _torch
import torch.nn as _nn

from ._activations import get_activation as _get_activation
from .base import BaseNet as _BaseNet
from ._names import ACTIVATION_NAME as _ACTIVATION_NAME, CONV_NAME as _CONV_NAME


class Conv1d(_nn.Conv1d):
    def export_weights(self) -> _torch.Tensor:
        tensors = []
        if self.weight is not None:
            tensors.append(self.weight.data.flatten())
        if self.bias is not None:
            tensors.append(self.bias.data.flatten())
        if len(tensors) == 0:
            return _torch.zeros((0,))
        else:
            return _torch.cat(tensors)

    def import_weights(self, weights: _torch.Tensor, i: int) -> int:
        if self.weight is not None:
            n = self.weight.numel()
            self.weight.data = (
                weights[i : i + n].reshape(self.weight.shape).to(self.weight.device)
            )
            i += n
        if self.bias is not None:
            n = self.bias.numel()
            self.bias.data = (
                weights[i : i + n].reshape(self.bias.shape).to(self.bias.device)
            )
            i += n
        return i

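The two methods above define the flat weight format used throughout this file: each convolution contributes its flattened kernel followed by its bias, and import_weights reads the same values back from a running offset. A minimal round-trip sketch using the Conv1d subclass above (the layer sizes are arbitrary):

    import torch

    src = Conv1d(4, 8, kernel_size=3, dilation=2)
    dst = Conv1d(4, 8, kernel_size=3, dilation=2)
    flat = src.export_weights()            # shape (8*4*3 + 8,) = (104,)
    next_offset = dst.import_weights(flat, 0)
    assert next_offset == flat.numel()
    assert torch.allclose(dst.weight, src.weight)
    assert torch.allclose(dst.bias, src.bias)
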
class _Layer(_nn.Module):
    def __init__(
        self,
        condition_size: int,
        channels: int,
        kernel_size: int,
        dilation: int,
        activation: str,
        gated: bool,
    ):
        super().__init__()
        # Input mixer takes care of the bias
        mid_channels = 2 * channels if gated else channels
        self._conv = Conv1d(channels, mid_channels, kernel_size, dilation=dilation)
        # Custom init: favors direct input-output
        # self._conv.weight.data.zero_()
        self._input_mixer = Conv1d(condition_size, mid_channels, 1, bias=False)
        self._activation = _get_activation(activation)
        self._activation_name = activation
        self._1x1 = Conv1d(channels, channels, 1)
        self._gated = gated

    @property
    def activation_name(self) -> str:
        return self._activation_name

    @property
    def conv(self) -> Conv1d:
        return self._conv

    @property
    def gated(self) -> bool:
        return self._gated

    @property
    def kernel_size(self) -> int:
        return self._conv.kernel_size[0]

    def export_weights(self) -> _torch.Tensor:
        return _torch.cat(
            [
                self.conv.export_weights(),
                self._input_mixer.export_weights(),
                self._1x1.export_weights(),
            ]
        )

    def forward(
        self, x: _torch.Tensor, h: _Optional[_torch.Tensor], out_length: int
    ) -> _Tuple[_Optional[_torch.Tensor], _torch.Tensor]:
        """
        :param x: (B,C,L1) From the last layer
        :param h: (B,DX,L2) Conditioning
        :param out_length: Trim the head term to this many samples (from the end)

        :return:
            (B,C,L1-(k-1)*d) to the next layer
            (B,C,out_length) to the head/mixer
        """
        zconv = self.conv(x)
        z1 = zconv + self._input_mixer(h)[:, :, -zconv.shape[2] :]
        post_activation = (
            self._activation(z1)
            if not self._gated
            else (
                self._activation(z1[:, : self._channels])
                * _torch.sigmoid(z1[:, self._channels :])
            )
        )
        return (
            x[:, :, -post_activation.shape[2] :] + self._1x1(post_activation),
            post_activation[:, :, -out_length:],
        )

    def import_weights(self, weights: _torch.Tensor, i: int) -> int:
        i = self.conv.import_weights(weights, i)
        i = self._input_mixer.import_weights(weights, i)
        return self._1x1.import_weights(weights, i)

    @property
    def _channels(self) -> int:
        return self._1x1.in_channels

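For reference, the gated branch in forward() above splits the 2*channels conv output in half and applies an activation-times-sigmoid gate as in the WaveNet paper, while the skip connection trims the input to match the shortened output. A shape-only sketch of that arithmetic with plain tensors (sizes are hypothetical, and Tanh stands in for whatever activation is configured):

    import torch

    B, C, L = 1, 8, 32                             # batch, channels, input length
    k, d = 3, 4                                    # kernel size, dilation
    z1 = torch.randn(B, 2 * C, L - (k - 1) * d)    # conv + input-mixer sum (gated case)
    post = torch.tanh(z1[:, :C]) * torch.sigmoid(z1[:, C:])
    print(post.shape)                              # torch.Size([1, 8, 24]); fed to the 1x1 and the skip sum
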
class _Layers(_nn.Module):
    """
    Takes in the input and condition (and maybe the head input so far); outputs the
    layer output and head input.

    The original WaveNet only uses one of these, but you can stack multiple of this
    module to vary the channels throughout with minimal extra channel-changing conv
    layers.
    """

    def __init__(
        self,
        input_size: int,
        condition_size: int,
        head_size,
        channels: int,
        kernel_size: int,
        dilations: _Sequence[int],
        activation: str = "Tanh",
        gated: bool = True,
        head_bias: bool = True,
    ):
        super().__init__()
        self._rechannel = Conv1d(input_size, channels, 1, bias=False)
        self._layers = _nn.ModuleList(
            [
                _Layer(
                    condition_size, channels, kernel_size, dilation, activation, gated
                )
                for dilation in dilations
            ]
        )
        # Convert the head input from channels to head_size
        self._head_rechannel = Conv1d(channels, head_size, 1, bias=head_bias)

        self._config = {
            "input_size": input_size,
            "condition_size": condition_size,
            "head_size": head_size,
            "channels": channels,
            "kernel_size": kernel_size,
            "dilations": dilations,
            "activation": activation,
            "gated": gated,
            "head_bias": head_bias,
        }

    @property
    def receptive_field(self) -> int:
        return 1 + (self._kernel_size - 1) * sum(self._dilations)

    def export_config(self):
        return _deepcopy(self._config)

    def export_weights(self) -> _torch.Tensor:
        return _torch.cat(
            [self._rechannel.export_weights()]
            + [layer.export_weights() for layer in self._layers]
            + [self._head_rechannel.export_weights()]
        )

    def import_weights(self, weights: _torch.Tensor, i: int) -> int:
        i = self._rechannel.import_weights(weights, i)
        for layer in self._layers:
            i = layer.import_weights(weights, i)
        return self._head_rechannel.import_weights(weights, i)

    def forward(
        self,
        x: _torch.Tensor,
        c: _torch.Tensor,
        head_input: _Optional[_torch.Tensor] = None,
    ) -> _Tuple[_torch.Tensor, _torch.Tensor]:
        """
        :param x: (B,Dx,L) layer input
        :param c: (B,Dc,L) condition
        :param head_input: Accumulated head input from a previous layer array, if any

        :return:
            (B,head_size,L-R+1) head input
            (B,channels,L-R+1) layer output
        """
        out_length = x.shape[2] - (self.receptive_field - 1)
        x = self._rechannel(x)
        for layer in self._layers:
            x, head_term = layer(x, c, out_length)  # Ensures head_term sample length
            head_input = (
                head_term
                if head_input is None
                else head_input[:, :, -out_length:] + head_term
            )
        return self._head_rechannel(head_input), x

    @property
    def _dilations(self) -> _Sequence[int]:
        return self._config["dilations"]

    @property
    def _kernel_size(self) -> int:
        return self._layers[0].kernel_size

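The receptive_field property above works out to R = 1 + (kernel_size - 1) * sum(dilations). A quick check with an illustrative doubling dilation schedule (not necessarily the schedule any particular model uses):

    kernel_size = 3
    dilations = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]
    receptive_field = 1 + (kernel_size - 1) * sum(dilations)
    print(receptive_field)   # 1 + 2 * 1023 = 2047 samples
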
class _Head(_nn.Module):
    def __init__(
        self,
        in_channels: int,
        channels: int,
        activation: str,
        num_layers: int,
        out_channels: int,
    ):
        super().__init__()

        def block(cx, cy):
            net = _nn.Sequential()
            net.add_module(_ACTIVATION_NAME, _get_activation(activation))
            net.add_module(_CONV_NAME, Conv1d(cx, cy, 1))
            return net

        assert num_layers > 0

        layers = _nn.Sequential()
        cin = in_channels
        for i in range(num_layers):
            layers.add_module(
                f"layer_{i}",
                block(cin, channels if i != num_layers - 1 else out_channels),
            )
            cin = channels
        self._layers = layers

        self._config = {
            "channels": channels,
            "activation": activation,
            "num_layers": num_layers,
            "out_channels": out_channels,
        }

    def export_config(self):
        return _deepcopy(self._config)

    def export_weights(self) -> _torch.Tensor:
        return _torch.cat([layer[1].export_weights() for layer in self._layers])

    def forward(self, *args, **kwargs):
        return self._layers(*args, **kwargs)

    def import_weights(self, weights: _torch.Tensor, i: int) -> int:
        for layer in self._layers:
            i = layer[1].import_weights(weights, i)
        return i

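Structurally, the head above is a stack of (activation, 1x1 conv) blocks, with the last block mapping to out_channels. For num_layers=2 it is equivalent to something like the following sketch (hypothetical sizes, assuming a Tanh activation):

    import torch.nn as nn

    head = nn.Sequential(
        nn.Sequential(nn.Tanh(), nn.Conv1d(8, 16, 1)),   # layer_0: in_channels -> channels
        nn.Sequential(nn.Tanh(), nn.Conv1d(16, 1, 1)),   # layer_1: channels -> out_channels
    )
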
class _WaveNet(_nn.Module):
    def __init__(
        self,
        layers_configs: _Sequence[_Dict],
        head_config: _Optional[_Dict] = None,
        head_scale: float = 1.0,
    ):
        super().__init__()

        self._layers = _nn.ModuleList([_Layers(**lc) for lc in layers_configs])
        self._head = None if head_config is None else _Head(**head_config)
        self._head_scale = head_scale

    @property
    def receptive_field(self) -> int:
        return 1 + sum([(layer.receptive_field - 1) for layer in self._layers])

    def export_config(self):
        return {
            "layers": [layers.export_config() for layers in self._layers],
            "head": None if self._head is None else self._head.export_config(),
            "head_scale": self._head_scale,
        }

    def export_weights(self) -> _np.ndarray:
        """
        :return: 1D array
        """
        weights = _torch.cat([layer.export_weights() for layer in self._layers])
        if self._head is not None:
            weights = _torch.cat([weights, self._head.export_weights()])
        weights = _torch.cat([weights.cpu(), _torch.Tensor([self._head_scale])])
        return weights.detach().cpu().numpy()

    def import_weights(self, weights: _torch.Tensor):
        i = 0
        for layer in self._layers:
            i = layer.import_weights(weights, i)

    def forward(self, x: _torch.Tensor) -> _torch.Tensor:
        """
        :param x: (B,Cx,L)
        :return: (B,Cy,L-R+1), where R is the receptive field
        """
        y, head_input = x, None
        for layer in self._layers:
            head_input, y = layer(y, x, head_input=head_input)
        head_input = self._head_scale * head_input
        return head_input if self._head is None else self._head(head_input)

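Note that every layer array above receives the raw input x as its condition, and the per-array receptive fields compose as R_total = 1 + sum(R_i - 1). For two arrays with the illustrative schedule from the earlier sketch:

    per_array = 1 + 2 * sum(2 ** p for p in range(10))   # 2047
    total = 1 + 2 * (per_array - 1)                      # 4093 samples
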
class WaveNet(_BaseNet):
    def __init__(self, *args, sample_rate: _Optional[float] = None, **kwargs):
        super().__init__(sample_rate=sample_rate)
        self._net = _WaveNet(*args, **kwargs)

    @property
    def pad_start_default(self) -> bool:
        return True

    @property
    def receptive_field(self) -> int:
        return self._net.receptive_field

    def export_cpp_header(self, filename: _Path):
        with _TemporaryDirectory() as tmpdir:
            tmpdir = _Path(tmpdir)
            WaveNet.export(self, _Path(tmpdir))  # Hacky...need to work w/ CatWaveNet
            with open(_Path(tmpdir, "model.nam"), "r") as fp:
                _c = _json.load(fp)
            version = _c["version"]
            config = _c["config"]

            if config["head"] is not None:
                raise NotImplementedError("No heads yet")
            # head_scale
            # with_head
            # parametric

            # String for layer array params:
            s_lap = (
                "const std::vector<wavenet::LayerArrayParams> LAYER_ARRAY_PARAMS{\n",
            )
            for i, lc in enumerate(config["layers"], 1):
                s_lap_line = (
                    f'  wavenet::LayerArrayParams({lc["input_size"]}, '
                    f'{lc["condition_size"]}, {lc["head_size"]}, {lc["channels"]}, '
                    f'{lc["kernel_size"]}, std::vector<int> '
                    "{"
                    + ", ".join([str(d) for d in lc["dilations"]])
                    + "}, "
                    + (
                        f'"{lc["activation"]}", {str(lc["gated"]).lower()}, '
                        f'{str(lc["head_bias"]).lower()})'
                    )
                )
                if i < len(config["layers"]):
                    s_lap_line += ","
                s_lap_line += "\n"
                s_lap += (s_lap_line,)
            s_lap += ("};\n",)
            s_parametric = self._export_cpp_header_parametric(config.get("parametric"))
            with open(filename, "w") as f:
                f.writelines(
                    (
                        "#pragma once\n",
                        "// Automatically-generated model file\n",
                        "#include <vector>\n",
                        '#include "json.hpp"\n',
                        '#include "wavenet.h"\n',
                        f'#define PYTHON_MODEL_VERSION "{version}"\n',
                    )
                    + s_lap
                    + (
                        f'const float HEAD_SCALE = {config["head_scale"]};\n',
                        "const bool WITH_HEAD = false;\n",
                    )
                    + s_parametric
                    + (
                        "std::vector<float> PARAMS{"
                        + ", ".join([f"{w:.16f}f" for w in _c["weights"]])
                        + "};\n",
                    )
                )

    def import_weights(self, weights: _Sequence[float]):
        if not isinstance(weights, _torch.Tensor):
            weights = _torch.Tensor(weights)
        self._net.import_weights(weights)

    def _export_config(self):
        return self._net.export_config()

    def _export_cpp_header_parametric(self, config):
        if config is not None:
            raise ValueError("Got non-None parametric config")
        return ("nlohmann::json PARAMETRIC {};\n",)

    def _export_weights(self) -> _np.ndarray:
        return self._net.export_weights()

    def _forward(self, x):
        if x.ndim == 2:
            x = x[:, None, :]
        y = self._net(x)
        assert y.shape[1] == 1
        return y[:, 0, :]
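
Putting it together, a WaveNet is configured with a list of layer-array configs whose keys mirror _Layers.__init__, plus an optional head config and a head_scale. The sketch below is illustrative rather than the project's default configuration; the import path and sizes are assumptions, and the internal _forward is called directly to show the raw (unpadded) output length:

    import torch
    # from nam.models.wavenet import WaveNet   # assumed install path; adjust as needed

    dilations = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]   # illustrative schedule
    model = WaveNet(
        layers_configs=[
            dict(input_size=1, condition_size=1, head_size=8, channels=16,
                 kernel_size=3, dilations=dilations, activation="Tanh",
                 gated=False, head_bias=False),
            dict(input_size=16, condition_size=1, head_size=1, channels=8,
                 kernel_size=3, dilations=dilations, activation="Tanh",
                 gated=False, head_bias=True),
        ],
        head_scale=0.02,
        sample_rate=48000.0,
    )
    x = torch.randn(1, 48000)   # (B, L) mono audio; _forward adds the channel dim
    y = model._forward(x)       # (B, L - receptive_field + 1); the public forward,
                                # defined on BaseNet, pads the start by default
    print(model.receptive_field, y.shape)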