neural-amp-modeler

Neural network emulator for guitar amplifiers
Log | Files | Refs | README | LICENSE

commit f5b3909ad2be15eb62d44cd43e0bbc88edc1c389
parent 6e87932a3a0f78e34c1c9bbb1924d2fe3c4e0144
Author: Steven Atkinson <steven@atkinson.mn>
Date:   Sat, 29 Apr 2023 13:16:40 -0700

New reamping file `v2_0_0.wav` (#217)

* Support training file 2.0.0

* Version comparison and testing

* Move to latest input version

* Black

* Bump version
Diffstat:
M nam/_version.py | 2 +-
M nam/train/_version.py | 3 +++
M nam/train/colab.py | 18 ++++++++++++------
M nam/train/core.py | 180 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++------------
A tests/test_nam/test_train/__init__.py | 0
A tests/test_nam/test_train/test_version.py | 22 ++++++++++++++++++++++
6 files changed, 191 insertions(+), 34 deletions(-)

diff --git a/nam/_version.py b/nam/_version.py @@ -1 +1 @@ -__version__ = "0.5.2" +__version__ = "0.5.3" diff --git a/nam/train/_version.py b/nam/train/_version.py @@ -13,6 +13,9 @@ class Version: self.minor = minor self.patch = patch + def __eq__(self, other) -> bool: + return self.major == other.major and self.minor == other.minor and self.patch == other.patch + def __lt__(self, other) -> bool: if self.major != other.major: return self.major < other.major diff --git a/nam/train/colab.py b/nam/train/colab.py @@ -15,7 +15,12 @@ from ._version import Version from .core import train -_INPUT_BASENAMES = ((Version(1, 1, 1), "v1_1_1.wav"), (Version(1, 0, 0), "v1.wav")) +_INPUT_BASENAMES = ( + (Version(2, 0, 0), "v2_0_0.wav"), + (Version(1, 1, 1), "v1_1_1.wav"), + (Version(1, 0, 0), "v1.wav"), +) +_LATEST_VERSION = _INPUT_BASENAMES[0] _BUGGY_INPUT_BASENAMES = { # 1.1.0 has the spikes at the wrong spots. "v1_1_0.wav" @@ -29,19 +34,20 @@ def _check_for_files() -> Tuple[Version, str]: for name in _BUGGY_INPUT_BASENAMES: if Path(name).exists(): raise RuntimeError( - f"Detected input signal {name} that has known bugs. Please download the latest input signal, {_INPUT_BASENAMES[0][1]}" + f"Detected input signal {name} that has known bugs. Please download the latest input signal, {_LATEST_VERSION[1]}" ) - for i, (input_version, input_basename) in enumerate(_INPUT_BASENAMES): + for input_version, input_basename in enumerate(_INPUT_BASENAMES): if Path(input_basename).exists(): - if i > 0: + if input_version != _LATEST_VERSION[0]: print( f"WARNING: Using out-of-date input file {input_basename}. " - "Recommend downloading and using the latest version." + "Recommend downloading and using the latest version, " + f"{_LATEST_VERSION[1]}." ) break else: raise FileNotFoundError( - f"Didn't find NAM's input audio file. Please upload {_INPUT_BASENAMES[0][1]}" + f"Didn't find NAM's input audio file. 
Please upload {_LATEST_VERSION[1]}" ) if not Path(_OUTPUT_BASENAME).exists(): raise FileNotFoundError( diff --git a/nam/train/core.py b/nam/train/core.py @@ -9,16 +9,18 @@ Functions used by the GUI trainer. import hashlib from enum import Enum from time import time -from typing import Optional, Union +from typing import Optional, Sequence, Union import matplotlib.pyplot as plt import numpy as np import pytorch_lightning as pl import torch +import torch.nn as nn from torch.utils.data import DataLoader -from ..data import REQUIRED_RATE, Split, init_dataset, wav_to_np +from ..data import REQUIRED_RATE, Split, init_dataset, wav_to_np, wav_to_tensor from ..models import Model +from ..models.losses import esr from ._version import Version @@ -45,6 +47,7 @@ def _detect_input_version(input_path) -> Version: version = { "4d54a958861bf720ec4637f43d44a7ef": Version(1, 0, 0), "7c3b6119c74465f79d96c761a0e27370": Version(1, 1, 1), + "cff9de79975f7fa2ba9060ad77cde04d": Version(2, 0, 0), }.get(file_hash) if version is None: raise RuntimeError( @@ -55,9 +58,13 @@ def _detect_input_version(input_path) -> Version: _V1_BLIP_LOCATIONS = 12_000, 36_000 +_V2_START_BLIP_LOCATIONS = _V1_BLIP_LOCATIONS +_V2_END_BLIP_LOCATIONS = -36_000, -12_000 -def _calibrate_delay_v1(input_path, output_path) -> int: +def _calibrate_delay_v1( + input_path, output_path, locations: Sequence[int] = _V1_BLIP_LOCATIONS +) -> int: lookahead = 1_000 lookback = 10_000 safety_factor = 4 @@ -68,7 +75,7 @@ def _calibrate_delay_v1(input_path, output_path) -> int: trigger_threshold = max(background_level + 0.01, 1.01 * background_level) delays = [] - for blip_index, i in enumerate(_V1_BLIP_LOCATIONS, 1): + for blip_index, i in enumerate(locations, 1): start_looking = i - lookahead stop_looking = i + lookback @@ -106,6 +113,9 @@ def _calibrate_delay_v1(input_path, output_path) -> int: return delay +_calibrate_delay_v2 = _calibrate_delay_v1 + + def _plot_delay_v1(delay: int, input_path: str, output_path: str, 
_nofail=True): print("Plotting the delay for manual inspection...") x = wav_to_np(input_path)[:48_000] @@ -141,15 +151,20 @@ def _plot_delay_v1(delay: int, input_path: str, output_path: str, _nofail=True): plt.show() # This doesn't freeze the notebook +_plot_delay_v2 = _plot_delay_v1 + + def _calibrate_delay( delay: Optional[int], input_version: Version, input_path: str, output_path: str, - silent: bool=False + silent: bool = False, ) -> int: if input_version.major == 1: calibrate, plot = _calibrate_delay_v1, _plot_delay_v1 + elif input_version.major == 2: + calibrate, plot = _calibrate_delay_v2, _plot_delay_v2 else: raise NotImplementedError( f"Input calibration not implemented for input version {input_version}" @@ -164,6 +179,70 @@ def _calibrate_delay( return delay +def _check_v1(*args, **kwargs): + return True + + +def _check_v2(input_path, output_path) -> bool: + print("V2 checks...") + rate = REQUIRED_RATE + y = wav_to_tensor(output_path, rate=rate) + y_val_1 = y[-19 * rate : -10 * rate] + y_val_2 = y[-10 * rate : -1 * rate] + esr_replicate = esr(y_val_1, y_val_2).item() + print(f"Replicate ESR is {esr_replicate:.8f}.") + # Do the blips line up? 
+ # [start/end,replicate] + blips = [ + [y[: rate // 2], y[rate // 2 : rate]], + [y[-rate : -rate // 2], y[-rate // 2 :]], + ] + mse = nn.MSELoss() + mse_0 = mse(blips[0][0], blips[0][1]).item() # Within start + mse_1 = mse(blips[1][0], blips[1][1]).item() # Within end + mse_cross_0 = mse(blips[0][0], blips[1][0]).item() # 1st repeat, start vs end + mse_cross_1 = mse(blips[0][1], blips[1][1]).item() # 2nd repeat, start vs end + + mse_max = max(mse_0, mse_1) + # mse_range = mse_max - min(mse_0, mse_1) + safety_factor = 2.0 + if mse_cross_0 > safety_factor * mse_max or mse_cross_1 > safety_factor * mse_max: + plt.plot() + [ + [ + plt.plot(b, label=f"{startend}, replicate {replicate}") + for replicate, b in enumerate(bi, 1) + ] + for startend, bi in zip(("start", "end"), blips) + ] + plt.xlabel("Sample") + plt.ylabel("Output") + plt.legend() + plt.grid() + print( + "Failed blip checks. Did something change between the start and end of reamping?" + ) + plt.show() + return False + return True + + +def _check(input_path: str, output_path: str, input_version: Version) -> bool: + """ + Ensure that everything should go smoothly + + :return: True if looks good + """ + if input_version.major == 1: + f = _check_v1 + elif input_version.major == 2: + f = _check_v2 + else: + print(f"Checks not implemented for input version {input_version}; skip") + return True + return f(input_path, output_path) + + def _get_wavenet_config(architecture): return { Architecture.STANDARD: { @@ -251,19 +330,40 @@ def _get_wavenet_config(architecture): def _get_configs( + input_version: Version, input_basename: str, output_basename: str, delay: int, epochs: int, architecture: Architecture, + ny: int, lr: float, lr_decay: float, + batch_size: int, ): - val_seconds = 9 - train_val_split = -val_seconds * REQUIRED_RATE + def get_kwargs(): + val_seconds = 9 + rate = REQUIRED_RATE + if input_version.major == 1: + train_val_split = -val_seconds * rate + train_kwargs = {"stop": train_val_split} + 
validation_kwargs = {"start": train_val_split} + elif input_version.major == 2: + blip_seconds = 1 + val_replicates = 2 + train_stop = -(blip_seconds + val_replicates * val_seconds) * rate + validation_start = train_stop + validation_stop = -(blip_seconds + val_seconds) * rate + train_kwargs = {"stop": train_stop} + validation_kwargs = {"start": validation_start, "stop": validation_stop} + else: + raise NotImplementedError(f"kwargs for input version {input_version}") + return train_kwargs, validation_kwargs + + train_kwargs, validation_kwargs = get_kwargs() data_config = { - "train": {"ny": 8192, "stop": train_val_split}, - "validation": {"ny": None, "start": train_val_split}, + "train": {"ny": ny, **train_kwargs}, + "validation": {"ny": None, **validation_kwargs}, "common": { "x_path": input_basename, "y_path": output_basename, @@ -290,7 +390,7 @@ def _get_configs( device_config = {} learning_config = { "train_dataloader": { - "batch_size": 16, + "batch_size": batch_size, "shuffle": True, "pin_memory": True, "drop_last": True, @@ -310,12 +410,12 @@ def _esr(pred: torch.Tensor, target: torch.Tensor) -> float: def _plot( - model, - ds, - window_start: Optional[int] = None, - window_end: Optional[int] = None, - filepath: Optional[str] = None, - silent: bool = False + model, + ds, + window_start: Optional[int] = None, + window_end: Optional[int] = None, + filepath: Optional[str] = None, + silent: bool = False, ): print("Plotting a comparison of your model with the target output...") with torch.no_grad(): @@ -351,6 +451,7 @@ def _plot( if not silent: plt.show() + def train( input_path: str, output_path: str, @@ -359,34 +460,42 @@ def train( epochs=100, delay=None, architecture: Union[Architecture, str] = Architecture.STANDARD, + batch_size: int = 16, + ny: int = 8192, lr=0.004, lr_decay=0.007, seed: Optional[int] = 0, - save_plot: bool=False, - silent: bool=False, - modelname: str="model" + save_plot: bool = False, + silent: bool = False, + modelname: str = "model", ): 
if seed is not None: torch.manual_seed(seed) - # This needs more thought... - # 1. Does the user want me to calibrate the delay? - # 2. Does the user want to see what the chosen (by them or me) delay looks like? if delay is None: if input_version is None: input_version = _detect_input_version(input_path) - delay = _calibrate_delay(delay, input_version, input_path, output_path, silent=silent) + delay = _calibrate_delay( + delay, input_version, input_path, output_path, silent=silent + ) else: print(f"Delay provided as {delay}; skip calibration") + if not _check(input_path, output_path, input_version): + print("Failed checks; exit training") + return + data_config, model_config, learning_config = _get_configs( + input_version, input_path, output_path, delay, epochs, Architecture(architecture), + ny, lr, lr_decay, + batch_size, ) print("Starting training. It's time to kick ass and chew bubblegum!") @@ -424,12 +533,29 @@ def train( model.cpu() model.eval() + def window_kwargs(version: Version): + if version.major == 1: + return dict( + window_start=100_000, # Start of the plotting window, in samples + window_end=101_000, # End of the plotting window, in samples + ) + elif version.major == 2: + # Same validation set even though it's a different spot in the reamp file + return dict( + window_start=100_000, # Start of the plotting window, in samples + window_end=101_000, # End of the plotting window, in samples + ) + # Fallback: + return dict( + window_start=100_000, # Start of the plotting window, in samples + window_end=101_000, # End of the plotting window, in samples + ) + _plot( model, dataset_validation, - window_start=100_000, # Start of the plotting window, in samples - window_end=101_000, # End of the plotting window, in samples - filepath=train_path +'/'+ modelname if save_plot else None, - silent=silent + filepath=train_path + "/" + modelname if save_plot else None, + silent=silent, + **window_kwargs(input_version), ) return model diff --git 
a/tests/test_nam/test_train/__init__.py b/tests/test_nam/test_train/__init__.py diff --git a/tests/test_nam/test_train/test_version.py b/tests/test_nam/test_train/test_version.py @@ -0,0 +1,22 @@ +# File: test_version.py +# Created Date: Saturday April 29th 2023 +# Author: Steven Atkinson (steven@atkinson.mn) + +from nam.train import _version + + +def test_eq(): + assert _version.Version(0, 0, 0) == _version.Version(0, 0, 0) + assert _version.Version(0, 0, 0) != _version.Version(0, 0, 1) + assert _version.Version(0, 0, 0) != _version.Version(0, 1, 0) + assert _version.Version(0, 0, 0) != _version.Version(1, 0, 0) + + +def test_lt(): + assert _version.Version(0, 0, 0) < _version.Version(0, 0, 1) + assert _version.Version(0, 0, 0) < _version.Version(0, 1, 0) + assert _version.Version(0, 0, 0) < _version.Version(1, 0, 0) + + assert _version.Version(1, 2, 3) < _version.Version(2, 0, 0) + + assert not _version.Version(1, 2, 3) < _version.Version(0, 4, 5)