# File: wavenet.py
# Created Date: Friday July 29th 2022
# Author: Steven Atkinson (steven@atkinson.mn)

"""
WaveNet implementation
https://arxiv.org/abs/1609.03499
"""

import json as _json
from copy import deepcopy as _deepcopy
from pathlib import Path as _Path
from tempfile import TemporaryDirectory as _TemporaryDirectory
from typing import (
    Dict as _Dict,
    Optional as _Optional,
    Sequence as _Sequence,
    Tuple as _Tuple,
)

import numpy as _np
import torch as _torch
import torch.nn as _nn

from ._activations import get_activation as _get_activation
from .base import BaseNet as _BaseNet
from ._names import ACTIVATION_NAME as _ACTIVATION_NAME, CONV_NAME as _CONV_NAME


class Conv1d(_nn.Conv1d):
    def export_weights(self) -> _torch.Tensor:
        tensors = []
        if self.weight is not None:
            tensors.append(self.weight.data.flatten())
        if self.bias is not None:
            tensors.append(self.bias.data.flatten())
        if len(tensors) == 0:
            return _torch.zeros((0,))
        else:
            return _torch.cat(tensors)

    def import_weights(self, weights: _torch.Tensor, i: int) -> int:
        if self.weight is not None:
            n = self.weight.numel()
            self.weight.data = (
                weights[i : i + n].reshape(self.weight.shape).to(self.weight.device)
            )
            i += n
        if self.bias is not None:
            n = self.bias.numel()
            self.bias.data = (
                weights[i : i + n].reshape(self.bias.shape).to(self.bias.device)
            )
            i += n
        return i


class _Layer(_nn.Module):
    def __init__(
        self,
        condition_size: int,
        channels: int,
        kernel_size: int,
        dilation: int,
        activation: str,
        gated: bool,
    ):
        super().__init__()
        # Input mixer takes care of the bias
        mid_channels = 2 * channels if gated else channels
        self._conv = Conv1d(channels, mid_channels, kernel_size, dilation=dilation)
        # Custom init: favors direct input-output
        # self._conv.weight.data.zero_()
        self._input_mixer = Conv1d(condition_size, mid_channels, 1, bias=False)
        self._activation = _get_activation(activation)
        self._activation_name = activation
        self._1x1 = Conv1d(channels, channels, 1)
        self._gated = gated

    @property
    def activation_name(self) -> str:
        return self._activation_name

    @property
    def conv(self) -> Conv1d:
        return self._conv

    @property
    def gated(self) -> bool:
        return self._gated

    @property
    def kernel_size(self) -> int:
        return self._conv.kernel_size[0]

    def export_weights(self) -> _torch.Tensor:
        return _torch.cat(
            [
                self.conv.export_weights(),
                self._input_mixer.export_weights(),
                self._1x1.export_weights(),
            ]
        )

    def forward(
        self, x: _torch.Tensor, h: _Optional[_torch.Tensor], out_length: int
    ) -> _Tuple[_Optional[_torch.Tensor], _torch.Tensor]:
        """
        :param x: (B,C,L1) Input from the previous layer
        :param h: (B,DX,L2) Conditioning input; always mixed in via the
            (bias-free) 1x1 input mixer
        :param out_length: Sample length of the skip-connection output

        :return:
            (B,C,L1-d*(k-1)) Residual output to the next layer (d: dilation,
                k: kernel size)
            (B,C,out_length) Skip-connection output to the head
        """
        zconv = self.conv(x)
        z1 = zconv + self._input_mixer(h)[:, :, -zconv.shape[2] :]
        post_activation = (
            self._activation(z1)
            if not self._gated
            else (
                self._activation(z1[:, : self._channels])
                * _torch.sigmoid(z1[:, self._channels :])
            )
        )
        return (
            x[:, :, -post_activation.shape[2] :] + self._1x1(post_activation),
            post_activation[:, :, -out_length:],
        )

    def import_weights(self, weights: _torch.Tensor, i: int) -> int:
        i = self.conv.import_weights(weights, i)
        i = self._input_mixer.import_weights(weights, i)
        return self._1x1.import_weights(weights, i)

    @property
    def _channels(self) -> int:
        return self._1x1.in_channels
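

# A minimal sketch (not part of the original module) of the flat-weight
# serialization contract above: Conv1d.export_weights flattens the kernel
# first and then the bias, and import_weights consumes them in the same
# order, returning the advanced read cursor. `_demo_conv1d_roundtrip` is a
# hypothetical helper name.
def _demo_conv1d_roundtrip():
    conv = Conv1d(2, 4, 3, dilation=2)
    flat = conv.export_weights()
    # Kernel (4*2*3) plus bias (4) elements, concatenated into one vector
    assert flat.numel() == conv.weight.numel() + conv.bias.numel()
    # The returned cursor should land exactly at the end of the buffer
    assert conv.import_weights(flat, 0) == flat.numel()
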
class _Layers(_nn.Module):
    """
    Takes in the input and condition (and maybe the head input so far); outputs the
    layer output and head input.

    The original WaveNet uses only one of these, but you can stack multiple of this
    module to vary the channels throughout with minimal extra channel-changing conv
    layers.
    """

    def __init__(
        self,
        input_size: int,
        condition_size: int,
        head_size: int,
        channels: int,
        kernel_size: int,
        dilations: _Sequence[int],
        activation: str = "Tanh",
        gated: bool = True,
        head_bias: bool = True,
    ):
        super().__init__()
        self._rechannel = Conv1d(input_size, channels, 1, bias=False)
        self._layers = _nn.ModuleList(
            [
                _Layer(
                    condition_size, channels, kernel_size, dilation, activation, gated
                )
                for dilation in dilations
            ]
        )
        # Convert the head input from channels to head_size
        self._head_rechannel = Conv1d(channels, head_size, 1, bias=head_bias)

        self._config = {
            "input_size": input_size,
            "condition_size": condition_size,
            "head_size": head_size,
            "channels": channels,
            "kernel_size": kernel_size,
            "dilations": dilations,
            "activation": activation,
            "gated": gated,
            "head_bias": head_bias,
        }

    @property
    def receptive_field(self) -> int:
        return 1 + (self._kernel_size - 1) * sum(self._dilations)

    def export_config(self):
        return _deepcopy(self._config)

    def export_weights(self) -> _torch.Tensor:
        return _torch.cat(
            [self._rechannel.export_weights()]
            + [layer.export_weights() for layer in self._layers]
            + [self._head_rechannel.export_weights()]
        )

    def import_weights(self, weights: _torch.Tensor, i: int) -> int:
        i = self._rechannel.import_weights(weights, i)
        for layer in self._layers:
            i = layer.import_weights(weights, i)
        return self._head_rechannel.import_weights(weights, i)

    def forward(
        self,
        x: _torch.Tensor,
        c: _torch.Tensor,
        head_input: _Optional[_torch.Tensor] = None,
    ) -> _Tuple[_torch.Tensor, _torch.Tensor]:
        """
        :param x: (B,Dx,L) layer input
        :param c: (B,Dc,L) condition
        :param head_input: (B,channels,L') accumulated skip-connection input
            from the previous block, if any

        :return:
            (B,head_size,L-R+1) head input
            (B,channels,L-R+1) layer output
        """
        out_length = x.shape[2] - (self.receptive_field - 1)
        x = self._rechannel(x)
        for layer in self._layers:
            x, head_term = layer(x, c, out_length)  # Ensures head_term sample length
            head_input = (
                head_term
                if head_input is None
                else head_input[:, :, -out_length:] + head_term
            )
        return self._head_rechannel(head_input), x

    @property
    def _dilations(self) -> _Sequence[int]:
        return self._config["dilations"]

    @property
    def _kernel_size(self) -> int:
        return self._layers[0].kernel_size
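

# Sketch (not in the original file) checking the receptive-field arithmetic
# above: R = 1 + (kernel_size - 1) * sum(dilations). With kernel_size=3 and
# dilations [1, 2, 4, 8], R = 1 + 2 * 15 = 31 samples.
# `_demo_receptive_field` is a hypothetical helper name.
def _demo_receptive_field():
    layers = _Layers(
        input_size=1,
        condition_size=1,
        head_size=8,
        channels=16,
        kernel_size=3,
        dilations=[1, 2, 4, 8],
    )
    assert layers.receptive_field == 31
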
class _Head(_nn.Module):
    def __init__(
        self,
        in_channels: int,
        channels: int,
        activation: str,
        num_layers: int,
        out_channels: int,
    ):
        super().__init__()

        def block(cx, cy):
            net = _nn.Sequential()
            net.add_module(_ACTIVATION_NAME, _get_activation(activation))
            net.add_module(_CONV_NAME, Conv1d(cx, cy, 1))
            return net

        assert num_layers > 0

        layers = _nn.Sequential()
        cin = in_channels
        for i in range(num_layers):
            layers.add_module(
                f"layer_{i}",
                block(cin, channels if i != num_layers - 1 else out_channels),
            )
            cin = channels
        self._layers = layers

        self._config = {
            "channels": channels,
            "activation": activation,
            "num_layers": num_layers,
            "out_channels": out_channels,
        }

    def export_config(self):
        return _deepcopy(self._config)

    def export_weights(self) -> _torch.Tensor:
        # layer[1] is the Conv1d inside each (activation, conv) block
        return _torch.cat([layer[1].export_weights() for layer in self._layers])

    def forward(self, *args, **kwargs):
        return self._layers(*args, **kwargs)

    def import_weights(self, weights: _torch.Tensor, i: int) -> int:
        for layer in self._layers:
            i = layer[1].import_weights(weights, i)
        return i


class _WaveNet(_nn.Module):
    def __init__(
        self,
        layers_configs: _Sequence[_Dict],
        head_config: _Optional[_Dict] = None,
        head_scale: float = 1.0,
    ):
        super().__init__()

        self._layers = _nn.ModuleList([_Layers(**lc) for lc in layers_configs])
        self._head = None if head_config is None else _Head(**head_config)
        self._head_scale = head_scale

    @property
    def receptive_field(self) -> int:
        return 1 + sum([(layer.receptive_field - 1) for layer in self._layers])

    def export_config(self):
        return {
            "layers": [layers.export_config() for layers in self._layers],
            "head": None if self._head is None else self._head.export_config(),
            "head_scale": self._head_scale,
        }

    def export_weights(self) -> _np.ndarray:
        """
        :return: 1D array
        """
        weights = _torch.cat([layer.export_weights() for layer in self._layers])
        if self._head is not None:
            weights = _torch.cat([weights, self._head.export_weights()])
        weights = _torch.cat([weights.cpu(), _torch.Tensor([self._head_scale])])
        return weights.detach().cpu().numpy()

    def import_weights(self, weights: _torch.Tensor):
        # NOTE: Only the layer stacks are read back; the head (if any) and
        # head_scale are exported but not imported.
        i = 0
        for layer in self._layers:
            i = layer.import_weights(weights, i)

    def forward(self, x: _torch.Tensor) -> _torch.Tensor:
        """
        :param x: (B,Cx,L)
        :return: (B,Cy,L-R+1)
        """
        y, head_input = x, None
        for layer in self._layers:
            head_input, y = layer(y, x, head_input=head_input)
        head_input = self._head_scale * head_input
        return head_input if self._head is None else self._head(head_input)
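

# Hedged example (values illustrative, not from the original file) of the
# nested configs consumed by _WaveNet. The chaining rules implied by
# _WaveNet.forward when stacking blocks: each block's input_size must equal
# the previous block's channels, and its channels must equal the previous
# block's head_size so the running skip connections can be summed.
_EXAMPLE_LAYERS_CONFIGS = [
    dict(
        input_size=1,
        condition_size=1,
        head_size=8,
        channels=16,
        kernel_size=3,
        dilations=[1, 2, 4, 8],
        activation="Tanh",
        gated=False,
        head_bias=False,
    ),
    dict(
        input_size=16,
        condition_size=1,
        head_size=1,
        channels=8,
        kernel_size=3,
        dilations=[16, 32, 64],
        activation="Tanh",
        gated=False,
        head_bias=True,
    ),
]
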
class WaveNet(_BaseNet):
    def __init__(self, *args, sample_rate: _Optional[float] = None, **kwargs):
        super().__init__(sample_rate=sample_rate)
        self._net = _WaveNet(*args, **kwargs)

    @property
    def pad_start_default(self) -> bool:
        return True

    @property
    def receptive_field(self) -> int:
        return self._net.receptive_field

    def export_cpp_header(self, filename: _Path):
        with _TemporaryDirectory() as tmpdir:
            tmpdir = _Path(tmpdir)
            WaveNet.export(self, _Path(tmpdir))  # Hacky...need to work w/ CatWaveNet
            with open(_Path(tmpdir, "model.nam"), "r") as fp:
                _c = _json.load(fp)
            version = _c["version"]
            config = _c["config"]

            if config["head"] is not None:
                raise NotImplementedError("No heads yet")
                # head_scale
                # with_head
                # parametric

            # Strings for the layer array params:
            s_lap = (
                "const std::vector<wavenet::LayerArrayParams> LAYER_ARRAY_PARAMS{\n",
            )
            for i, lc in enumerate(config["layers"], 1):
                s_lap_line = (
                    f' wavenet::LayerArrayParams({lc["input_size"]}, '
                    f'{lc["condition_size"]}, {lc["head_size"]}, {lc["channels"]}, '
                    f'{lc["kernel_size"]}, std::vector<int> '
                    "{"
                    + ", ".join([str(d) for d in lc["dilations"]])
                    + "}, "
                    + (
                        f'"{lc["activation"]}", {str(lc["gated"]).lower()}, '
                        f'{str(lc["head_bias"]).lower()})'
                    )
                )
                if i < len(config["layers"]):
                    s_lap_line += ","
                s_lap_line += "\n"
                s_lap += (s_lap_line,)
            s_lap += ("};\n",)
            s_parametric = self._export_cpp_header_parametric(config.get("parametric"))
            with open(filename, "w") as f:
                f.writelines(
                    (
                        "#pragma once\n",
                        "// Automatically-generated model file\n",
                        "#include <vector>\n",
                        '#include "json.hpp"\n',
                        '#include "wavenet.h"\n',
                        f'#define PYTHON_MODEL_VERSION "{version}"\n',
                    )
                    + s_lap
                    + (
                        f'const float HEAD_SCALE = {config["head_scale"]};\n',
                        "const bool WITH_HEAD = false;\n",
                    )
                    + s_parametric
                    + (
                        "std::vector<float> PARAMS{"
                        + ", ".join([f"{w:.16f}f" for w in _c["weights"]])
                        + "};\n",
                    )
                )

    def import_weights(self, weights: _Sequence[float]):
        if not isinstance(weights, _torch.Tensor):
            weights = _torch.Tensor(weights)
        self._net.import_weights(weights)

    def _export_config(self):
        return self._net.export_config()

    def _export_cpp_header_parametric(self, config):
        if config is not None:
            raise ValueError("Got non-None parametric config")
        return ("nlohmann::json PARAMETRIC {};\n",)

    def _export_weights(self) -> _np.ndarray:
        return self._net.export_weights()

    def _forward(self, x):
        # (B,L) -> (B,1,L) so the conv stack sees a channel dimension
        if x.ndim == 2:
            x = x[:, None, :]
        y = self._net(x)
        assert y.shape[1] == 1
        return y[:, 0, :]
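

if __name__ == "__main__":
    # Smoke-test sketch (not part of the original file) using the example
    # configs above; head_scale=0.02 is illustrative. Drives the inner
    # _WaveNet directly so only code from this module is exercised; the
    # BaseNet wrapper adds padding and export behavior on top.
    net = _WaveNet(layers_configs=_EXAMPLE_LAYERS_CONFIGS, head_scale=0.02)
    print(net.receptive_field)  # 1 + 2*(1+2+4+8) + 2*(16+32+64) = 255
    x = _torch.randn(2, 1, 4096)
    print(net(x).shape)  # (2, 1, 4096 - 255 + 1)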