neural-amp-modeler

Neural network emulator for guitar amplifiers
Log | Files | Refs | README | LICENSE

commit 0d6e110f85905a1849dbc07b1412943f6dfa1984
parent d1e88bf99043c61ae8a02ba07c09e46416dad7d1
Author: Steven Atkinson <steven@atkinson.mn>
Date:   Thu,  8 Dec 2022 18:43:46 -0800

Improvements to easy-mode Colab trainer

* Visualize the calibrated delay
* Tweak default learning rate to be a little more aggressive

Diffstat:
Mnam/train/colab.py | 74+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-----------
1 file changed, 63 insertions(+), 11 deletions(-)

diff --git a/nam/train/colab.py b/nam/train/colab.py @@ -8,7 +8,7 @@ Hide the mess in Colab to make things look pretty for users. from pathlib import Path from time import time -from typing import Optional, Union +from typing import Optional import matplotlib.pyplot as plt import numpy as np @@ -45,7 +45,7 @@ def _calibrate_delay_v1() -> int: y = wav_to_np(_OUTPUT_BASENAME)[:48_000] background_level = np.max(np.abs(y[:6_000])) - trigger_threshold = background_level + 0.01 + trigger_threshold = max(background_level + 0.01, 1.01 * background_level) j1 = np.where(np.abs(y[j1_start_looking:j2_start_looking]) > trigger_threshold)[0][ 0 ] @@ -59,12 +59,44 @@ def _calibrate_delay_v1() -> int: return delay -def _calibrate_delay() -> int: - print("Delay wasn't provided; attempting to calibrate automatically...") - return _calibrate_delay_v1() +def _plot_delay_v1(delay: int): + print("Plotting the delay for manual inspection...") + x = wav_to_np(_INPUT_BASENAME)[:48_000] + y = wav_to_np(_OUTPUT_BASENAME)[:48_000] + i = np.where(np.abs(x) > 0.1)[0][0] # In case resampled poorly + di = 20 + plt.figure() + # plt.plot(x[i - di : i + di], ".-", label="Input") + plt.plot( + np.arange(-di, di), + y[i - di + delay : i + di + delay], + ".-", + label="Output", + ) + plt.axvline(x=0, linestyle="--", color="C1") + plt.legend() + plt.show() # This doesn't freeze the notebook -def _get_configs(delay: int, epochs: int, stage_1_channels, stage_2_channels): +def _calibrate_delay(delay: Optional[int]) -> int: + calibrate, plot = _calibrate_delay_v1, _plot_delay_v1 + if delay is not None: + print(f"Delay is specified as {delay}") + else: + print("Delay wasn't provided; attempting to calibrate automatically...") + delay = calibrate() + plot(delay) + return delay + + +def _get_configs( + delay: int, + epochs: int, + stage_1_channels: int, + stage_2_channels: int, + lr: float, + lr_decay: float, +): val_seconds = 9 train_val_split = -val_seconds * REQUIRED_RATE data_config = { @@ -109,8 +141,8 @@ def _get_configs(delay: int, epochs: int, stage_1_channels, stage_2_channels): }, }, "loss": {"val_loss": "esr"}, - "optimizer": {"lr": 0.003}, - "lr_scheduler": {"class": "ExponentialLR", "kwargs": {"gamma": 0.993}}, + "optimizer": {"lr": lr}, + "lr_scheduler": {"class": "ExponentialLR", "kwargs": {"gamma": 1.0 - lr_decay}}, } learning_config = { "train_dataloader": { @@ -165,6 +197,7 @@ def _plot( plt.plot(ds.y[window_start:window_end], linestyle="--", label="Target") plt.title(f"ESR={esr:.3f}") plt.legend() + plt.show() def _get_valid_export_directory(): @@ -177,11 +210,30 @@ def _get_valid_export_directory(): return get_path(version) -def run(epochs=100, delay=None, stage_1_channels=16, stage_2_channels=8): +def run( + epochs=100, + delay=None, + stage_1_channels=16, + stage_2_channels=8, + lr=0.004, + lr_decay=0.007, + seed=0, +): + """ + :param epochs: How amny epochs we'll train for. + :param delay: How far the output algs the input due to round-trip latency during + reamping, in samples. + :param stage_1_channels: The number of channels in the WaveNet's first stage. + :param stage_2_channels: The number of channels in the WaveNet's second stage. + :param lr: The initial learning rate + :param lr_decay: The amount by which the learning rate decays each epoch + :param seed: RNG seed for reproducibility. + """ + torch.manual_seed(seed) _check_for_files() - delay = _calibrate_delay() if delay is None else delay + delay = _calibrate_delay(delay) data_config, model_config, learning_config = _get_configs( - delay, epochs, stage_1_channels, stage_2_channels + delay, epochs, stage_1_channels, stage_2_channels, lr, lr_decay ) print("Starting training. Let's rock!")