commit 0d6e110f85905a1849dbc07b1412943f6dfa1984
parent d1e88bf99043c61ae8a02ba07c09e46416dad7d1
Author: Steven Atkinson <steven@atkinson.mn>
Date: Thu, 8 Dec 2022 18:43:46 -0800
Improvements to easy-mode Colab trainer
* Visualize the calibrated delay
* Tweak default learning rate to be a little more aggressive
Diffstat:
M | nam/train/colab.py | | | 74 | +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++----------- |
1 file changed, 63 insertions(+), 11 deletions(-)
diff --git a/nam/train/colab.py b/nam/train/colab.py
@@ -8,7 +8,7 @@ Hide the mess in Colab to make things look pretty for users.
from pathlib import Path
from time import time
-from typing import Optional, Union
+from typing import Optional
import matplotlib.pyplot as plt
import numpy as np
@@ -45,7 +45,7 @@ def _calibrate_delay_v1() -> int:
y = wav_to_np(_OUTPUT_BASENAME)[:48_000]
background_level = np.max(np.abs(y[:6_000]))
- trigger_threshold = background_level + 0.01
+ trigger_threshold = max(background_level + 0.01, 1.01 * background_level)
j1 = np.where(np.abs(y[j1_start_looking:j2_start_looking]) > trigger_threshold)[0][
0
]
@@ -59,12 +59,44 @@ def _calibrate_delay_v1() -> int:
return delay
-def _calibrate_delay() -> int:
- print("Delay wasn't provided; attempting to calibrate automatically...")
- return _calibrate_delay_v1()
+def _plot_delay_v1(delay: int):
+ print("Plotting the delay for manual inspection...")
+ x = wav_to_np(_INPUT_BASENAME)[:48_000]
+ y = wav_to_np(_OUTPUT_BASENAME)[:48_000]
+ i = np.where(np.abs(x) > 0.1)[0][0] # In case resampled poorly
+ di = 20
+ plt.figure()
+ # plt.plot(x[i - di : i + di], ".-", label="Input")
+ plt.plot(
+ np.arange(-di, di),
+ y[i - di + delay : i + di + delay],
+ ".-",
+ label="Output",
+ )
+ plt.axvline(x=0, linestyle="--", color="C1")
+ plt.legend()
+ plt.show() # This doesn't freeze the notebook
-def _get_configs(delay: int, epochs: int, stage_1_channels, stage_2_channels):
+def _calibrate_delay(delay: Optional[int]) -> int:
+ calibrate, plot = _calibrate_delay_v1, _plot_delay_v1
+ if delay is not None:
+ print(f"Delay is specified as {delay}")
+ else:
+ print("Delay wasn't provided; attempting to calibrate automatically...")
+ delay = calibrate()
+ plot(delay)
+ return delay
+
+
+def _get_configs(
+ delay: int,
+ epochs: int,
+ stage_1_channels: int,
+ stage_2_channels: int,
+ lr: float,
+ lr_decay: float,
+):
val_seconds = 9
train_val_split = -val_seconds * REQUIRED_RATE
data_config = {
@@ -109,8 +141,8 @@ def _get_configs(delay: int, epochs: int, stage_1_channels, stage_2_channels):
},
},
"loss": {"val_loss": "esr"},
- "optimizer": {"lr": 0.003},
- "lr_scheduler": {"class": "ExponentialLR", "kwargs": {"gamma": 0.993}},
+ "optimizer": {"lr": lr},
+ "lr_scheduler": {"class": "ExponentialLR", "kwargs": {"gamma": 1.0 - lr_decay}},
}
learning_config = {
"train_dataloader": {
@@ -165,6 +197,7 @@ def _plot(
plt.plot(ds.y[window_start:window_end], linestyle="--", label="Target")
plt.title(f"ESR={esr:.3f}")
plt.legend()
+ plt.show()
def _get_valid_export_directory():
@@ -177,11 +210,30 @@ def _get_valid_export_directory():
return get_path(version)
-def run(epochs=100, delay=None, stage_1_channels=16, stage_2_channels=8):
+def run(
+ epochs=100,
+ delay=None,
+ stage_1_channels=16,
+ stage_2_channels=8,
+ lr=0.004,
+ lr_decay=0.007,
+ seed=0,
+):
+ """
+ :param epochs: How many epochs we'll train for.
+ :param delay: How far the output lags the input due to round-trip latency during
+ reamping, in samples.
+ :param stage_1_channels: The number of channels in the WaveNet's first stage.
+ :param stage_2_channels: The number of channels in the WaveNet's second stage.
+ :param lr: The initial learning rate
+ :param lr_decay: The fraction by which the learning rate decays each epoch (gamma = 1 - lr_decay).
+ :param seed: RNG seed for reproducibility.
+ """
+ torch.manual_seed(seed)
_check_for_files()
- delay = _calibrate_delay() if delay is None else delay
+ delay = _calibrate_delay(delay)
data_config, model_config, learning_config = _get_configs(
- delay, epochs, stage_1_channels, stage_2_channels
+ delay, epochs, stage_1_channels, stage_2_channels, lr, lr_decay
)
print("Starting training. Let's rock!")