commit f5b3909ad2be15eb62d44cd43e0bbc88edc1c389
parent 6e87932a3a0f78e34c1c9bbb1924d2fe3c4e0144
Author: Steven Atkinson <steven@atkinson.mn>
Date: Sat, 29 Apr 2023 13:16:40 -0700
New reamping file `v2_0_0.wav` (#217)
* Support training file 2.0.0
* Version comparison and testing
* Move to latest input version
* Black
* Bump version
Diffstat:
6 files changed, 191 insertions(+), 34 deletions(-)
diff --git a/nam/_version.py b/nam/_version.py
@@ -1 +1 @@
-__version__ = "0.5.2"
+__version__ = "0.5.3"
diff --git a/nam/train/_version.py b/nam/train/_version.py
@@ -13,6 +13,9 @@ class Version:
self.minor = minor
self.patch = patch
+    def __eq__(self, other) -> bool:
+        """
+        Two versions are equal when major, minor and patch all match.
+
+        :return: NotImplemented for non-Version operands, so Python falls back
+            to its default handling instead of raising AttributeError.
+        """
+        if not isinstance(other, Version):
+            return NotImplemented
+        return (self.major, self.minor, self.patch) == (
+            other.major,
+            other.minor,
+            other.patch,
+        )
+
+    def __hash__(self) -> int:
+        # A class that defines __eq__ without __hash__ becomes unhashable; keep
+        # Version usable in sets and as dict keys, consistently with __eq__.
+        return hash((self.major, self.minor, self.patch))
+
def __lt__(self, other) -> bool:
if self.major != other.major:
return self.major < other.major
diff --git a/nam/train/colab.py b/nam/train/colab.py
@@ -15,7 +15,12 @@ from ._version import Version
from .core import train
-_INPUT_BASENAMES = ((Version(1, 1, 1), "v1_1_1.wav"), (Version(1, 0, 0), "v1.wav"))
+# Supported input files as (version, basename) pairs. The first entry is the
+# one treated as the latest.
+_INPUT_BASENAMES = (
+ (Version(2, 0, 0), "v2_0_0.wav"),
+ (Version(1, 1, 1), "v1_1_1.wav"),
+ (Version(1, 0, 0), "v1.wav"),
+)
+# NOTE: despite its name this is the whole (version, basename) pair; index it
+# with [0] for the Version and [1] for the file name.
+_LATEST_VERSION = _INPUT_BASENAMES[0]
_BUGGY_INPUT_BASENAMES = {
# 1.1.0 has the spikes at the wrong spots.
"v1_1_0.wav"
@@ -29,19 +34,20 @@ def _check_for_files() -> Tuple[Version, str]:
for name in _BUGGY_INPUT_BASENAMES:
if Path(name).exists():
raise RuntimeError(
- f"Detected input signal {name} that has known bugs. Please download the latest input signal, {_INPUT_BASENAMES[0][1]}"
+ f"Detected input signal {name} that has known bugs. Please download the latest input signal, {_LATEST_VERSION[1]}"
)
-    for i, (input_version, input_basename) in enumerate(_INPUT_BASENAMES):
+    # Iterate the (version, basename) pairs directly: wrapping them in
+    # enumerate() would bind input_version to an int index and input_basename
+    # to the whole pair, which Path() rejects.
+    for input_version, input_basename in _INPUT_BASENAMES:
         if Path(input_basename).exists():
-            if i > 0:
+            if input_version != _LATEST_VERSION[0]:
                 print(
                     f"WARNING: Using out-of-date input file {input_basename}. "
-                    "Recommend downloading and using the latest version."
+                    "Recommend downloading and using the latest version, "
+                    f"{_LATEST_VERSION[1]}."
                 )
             break
     else:
         raise FileNotFoundError(
-            f"Didn't find NAM's input audio file. Please upload {_INPUT_BASENAMES[0][1]}"
+            f"Didn't find NAM's input audio file. Please upload {_LATEST_VERSION[1]}"
         )
if not Path(_OUTPUT_BASENAME).exists():
raise FileNotFoundError(
diff --git a/nam/train/core.py b/nam/train/core.py
@@ -9,16 +9,18 @@ Functions used by the GUI trainer.
import hashlib
from enum import Enum
from time import time
-from typing import Optional, Union
+from typing import Optional, Sequence, Union
import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import torch
+import torch.nn as nn
from torch.utils.data import DataLoader
-from ..data import REQUIRED_RATE, Split, init_dataset, wav_to_np
+from ..data import REQUIRED_RATE, Split, init_dataset, wav_to_np, wav_to_tensor
from ..models import Model
+from ..models.losses import esr
from ._version import Version
@@ -45,6 +47,7 @@ def _detect_input_version(input_path) -> Version:
version = {
"4d54a958861bf720ec4637f43d44a7ef": Version(1, 0, 0),
"7c3b6119c74465f79d96c761a0e27370": Version(1, 1, 1),
+ "cff9de79975f7fa2ba9060ad77cde04d": Version(2, 0, 0),
}.get(file_hash)
if version is None:
raise RuntimeError(
@@ -55,9 +58,13 @@ def _detect_input_version(input_path) -> Version:
_V1_BLIP_LOCATIONS = 12_000, 36_000
+_V2_START_BLIP_LOCATIONS = _V1_BLIP_LOCATIONS
+_V2_END_BLIP_LOCATIONS = -36_000, -12_000
-def _calibrate_delay_v1(input_path, output_path) -> int:
+def _calibrate_delay_v1(
+ input_path, output_path, locations: Sequence[int] = _V1_BLIP_LOCATIONS
+) -> int:
lookahead = 1_000
lookback = 10_000
safety_factor = 4
@@ -68,7 +75,7 @@ def _calibrate_delay_v1(input_path, output_path) -> int:
trigger_threshold = max(background_level + 0.01, 1.01 * background_level)
delays = []
- for blip_index, i in enumerate(_V1_BLIP_LOCATIONS, 1):
+ for blip_index, i in enumerate(locations, 1):
start_looking = i - lookahead
stop_looking = i + lookback
@@ -106,6 +113,9 @@ def _calibrate_delay_v1(input_path, output_path) -> int:
return delay
+_calibrate_delay_v2 = _calibrate_delay_v1
+
+
def _plot_delay_v1(delay: int, input_path: str, output_path: str, _nofail=True):
print("Plotting the delay for manual inspection...")
x = wav_to_np(input_path)[:48_000]
@@ -141,15 +151,20 @@ def _plot_delay_v1(delay: int, input_path: str, output_path: str, _nofail=True):
plt.show() # This doesn't freeze the notebook
+_plot_delay_v2 = _plot_delay_v1
+
+
def _calibrate_delay(
delay: Optional[int],
input_version: Version,
input_path: str,
output_path: str,
- silent: bool=False
+ silent: bool = False,
) -> int:
if input_version.major == 1:
calibrate, plot = _calibrate_delay_v1, _plot_delay_v1
+ elif input_version.major == 2:
+ calibrate, plot = _calibrate_delay_v2, _plot_delay_v2
else:
raise NotImplementedError(
f"Input calibration not implemented for input version {input_version}"
@@ -164,6 +179,70 @@ def _calibrate_delay(
return delay
+def _check_v1(*args, **kwargs):
+ # No checks are defined for version-1 input files. Any arguments are
+ # accepted and ignored, and the result is always "looks good".
+ return True
+
+
+def _check_v2(input_path, output_path) -> bool:
+    """
+    Sanity-check a reamp recorded with a version-2 input file.
+
+    Prints the ESR between two consecutive 9-second segments near the end of
+    the output, then compares the two half-second segments at the start of the
+    output with the two at its end.
+
+    :param input_path: Unused by this check.
+    :param output_path: Path to the reamped output audio.
+    :return: True if the start/end segments agree; otherwise the four segments
+        are plotted for inspection and False is returned.
+    """
+    print("V2 checks...")
+    rate = REQUIRED_RATE
+    y = wav_to_tensor(output_path, rate=rate)
+    # Two back-to-back 9-second windows; the second ends 1 second before the
+    # end of the file.
+    y_val_1 = y[-19 * rate : -10 * rate]
+    y_val_2 = y[-10 * rate : -1 * rate]
+    esr_replicate = esr(y_val_1, y_val_2).item()
+    print(f"Replicate ESR is {esr_replicate:.8f}.")
+    # Do the blips line up?
+    # [start/end,replicate]
+    blips = [
+        [y[: rate // 2], y[rate // 2 : rate]],
+        [y[-rate : -rate // 2], y[-rate // 2 :]],
+    ]
+    mse = nn.MSELoss()
+    mse_0 = mse(blips[0][0], blips[0][1]).item()  # Within start
+    mse_1 = mse(blips[1][0], blips[1][1]).item()  # Within end
+    mse_cross_0 = mse(blips[0][0], blips[1][0]).item()  # 1st repeat, start vs end
+    mse_cross_1 = mse(blips[0][1], blips[1][1]).item()  # 2nd repeat, start vs end
+
+    # The start-vs-end error may exceed the worst within-pair error by at most
+    # this factor before the check fails.
+    mse_max = max(mse_0, mse_1)
+    safety_factor = 2.0
+    threshold = safety_factor * mse_max
+    failed = mse_cross_0 > threshold or mse_cross_1 > threshold
+    if not failed:
+        return True
+
+    # Show all four segments so the user can see what drifted.
+    for startend, bi in zip(("start", "end"), blips):
+        for replicate, b in enumerate(bi, 1):
+            plt.plot(b, label=f"{startend}, replicate {replicate}")
+    plt.xlabel("Sample")
+    plt.ylabel("Output")
+    plt.legend()
+    plt.grid()
+    print(
+        "Failed blip checks. Did something change between the start and end of reamping?"
+    )
+    plt.show()
+    return False
+
+
+def _check(input_path: str, output_path: str, input_version: Version) -> bool:
+    """
+    Run the pre-training checks that match the input file's major version.
+
+    :return: True if looks good (also when no checks exist for the version)
+    """
+    checkers = {1: _check_v1, 2: _check_v2}
+    checker = checkers.get(input_version.major)
+    if checker is None:
+        print(f"Checks not implemented for input version {input_version}; skip")
+        return True
+    return checker(input_path, output_path)
+
+
def _get_wavenet_config(architecture):
return {
Architecture.STANDARD: {
@@ -251,19 +330,40 @@ def _get_wavenet_config(architecture):
def _get_configs(
+ input_version: Version,
input_basename: str,
output_basename: str,
delay: int,
epochs: int,
architecture: Architecture,
+ ny: int,
lr: float,
lr_decay: float,
+ batch_size: int,
):
- val_seconds = 9
- train_val_split = -val_seconds * REQUIRED_RATE
+ def get_kwargs():
+ # Build the start/stop slicing kwargs for the train and validation
+ # datasets. Offsets are negative, i.e. counted back from the end of the file.
+ val_seconds = 9
+ rate = REQUIRED_RATE
+ if input_version.major == 1:
+ # v1: the last val_seconds are validation; everything before is training.
+ train_val_split = -val_seconds * rate
+ train_kwargs = {"stop": train_val_split}
+ validation_kwargs = {"start": train_val_split}
+ elif input_version.major == 2:
+ # v2: training stops (blip_seconds + val_replicates * val_seconds) before
+ # the end. Validation is the val_seconds window that starts there, so it
+ # ends (blip_seconds + val_seconds) before the end of the file.
+ blip_seconds = 1
+ val_replicates = 2
+ train_stop = -(blip_seconds + val_replicates * val_seconds) * rate
+ validation_start = train_stop
+ validation_stop = -(blip_seconds + val_seconds) * rate
+ train_kwargs = {"stop": train_stop}
+ validation_kwargs = {"start": validation_start, "stop": validation_stop}
+ else:
+ raise NotImplementedError(f"kwargs for input version {input_version}")
+ return train_kwargs, validation_kwargs
+
+ train_kwargs, validation_kwargs = get_kwargs()
data_config = {
- "train": {"ny": 8192, "stop": train_val_split},
- "validation": {"ny": None, "start": train_val_split},
+ "train": {"ny": ny, **train_kwargs},
+ "validation": {"ny": None, **validation_kwargs},
"common": {
"x_path": input_basename,
"y_path": output_basename,
@@ -290,7 +390,7 @@ def _get_configs(
device_config = {}
learning_config = {
"train_dataloader": {
- "batch_size": 16,
+ "batch_size": batch_size,
"shuffle": True,
"pin_memory": True,
"drop_last": True,
@@ -310,12 +410,12 @@ def _esr(pred: torch.Tensor, target: torch.Tensor) -> float:
def _plot(
- model,
- ds,
- window_start: Optional[int] = None,
- window_end: Optional[int] = None,
- filepath: Optional[str] = None,
- silent: bool = False
+ model,
+ ds,
+ window_start: Optional[int] = None,
+ window_end: Optional[int] = None,
+ filepath: Optional[str] = None,
+ silent: bool = False,
):
print("Plotting a comparison of your model with the target output...")
with torch.no_grad():
@@ -351,6 +451,7 @@ def _plot(
if not silent:
plt.show()
+
def train(
input_path: str,
output_path: str,
@@ -359,34 +460,42 @@ def train(
epochs=100,
delay=None,
architecture: Union[Architecture, str] = Architecture.STANDARD,
+ batch_size: int = 16,
+ ny: int = 8192,
lr=0.004,
lr_decay=0.007,
seed: Optional[int] = 0,
- save_plot: bool=False,
- silent: bool=False,
- modelname: str="model"
+ save_plot: bool = False,
+ silent: bool = False,
+ modelname: str = "model",
):
if seed is not None:
torch.manual_seed(seed)
-    # This needs more thought...
-    # 1. Does the user want me to calibrate the delay?
-    # 2. Does the user want to see what the chosen (by them or me) delay looks like?
+    # Resolve the input version before anything else: _check and _get_configs
+    # below read input_version.major even when the caller supplies the delay,
+    # so leaving it as None on that path would raise AttributeError.
+    if input_version is None:
+        input_version = _detect_input_version(input_path)
     if delay is None:
-        if input_version is None:
-            input_version = _detect_input_version(input_path)
-        delay = _calibrate_delay(delay, input_version, input_path, output_path, silent=silent)
+        delay = _calibrate_delay(
+            delay, input_version, input_path, output_path, silent=silent
+        )
     else:
         print(f"Delay provided as {delay}; skip calibration")
+    if not _check(input_path, output_path, input_version):
+        print("Failed checks; exit training")
+        return
+
data_config, model_config, learning_config = _get_configs(
+ input_version,
input_path,
output_path,
delay,
epochs,
Architecture(architecture),
+ ny,
lr,
lr_decay,
+ batch_size,
)
print("Starting training. It's time to kick ass and chew bubblegum!")
@@ -424,12 +533,29 @@ def train(
model.cpu()
model.eval()
+    def window_kwargs(version: Version):
+        # Start and end of the plotting window, in samples.
+        default_window = (100_000, 101_000)
+        per_major = {
+            1: (100_000, 101_000),
+            # Same validation set even though it's a different spot in the reamp file
+            2: (100_000, 101_000),
+        }
+        # Any other major version falls back to the default window.
+        start, end = per_major.get(version.major, default_window)
+        return dict(window_start=start, window_end=end)
+
_plot(
model,
dataset_validation,
- window_start=100_000, # Start of the plotting window, in samples
- window_end=101_000, # End of the plotting window, in samples
- filepath=train_path +'/'+ modelname if save_plot else None,
- silent=silent
+ filepath=train_path + "/" + modelname if save_plot else None,
+ silent=silent,
+ **window_kwargs(input_version),
)
return model
diff --git a/tests/test_nam/test_train/__init__.py b/tests/test_nam/test_train/__init__.py
diff --git a/tests/test_nam/test_train/test_version.py b/tests/test_nam/test_train/test_version.py
@@ -0,0 +1,22 @@
+# File: test_version.py
+# Created Date: Saturday April 29th 2023
+# Author: Steven Atkinson (steven@atkinson.mn)
+
+from nam.train import _version
+
+
+def test_eq():
+    """A version equals an identical one and differs if any component changes."""
+    zero = _version.Version(0, 0, 0)
+    assert zero == _version.Version(0, 0, 0)
+    for different in ((0, 0, 1), (0, 1, 0), (1, 0, 0)):
+        assert zero != _version.Version(*different)
+
+
+def test_lt():
+    """Exercise `<` with a difference in each component and a negative case."""
+    make = _version.Version
+    zero = make(0, 0, 0)
+    for larger in (make(0, 0, 1), make(0, 1, 0), make(1, 0, 0)):
+        assert zero < larger
+
+    assert make(1, 2, 3) < make(2, 0, 0)
+
+    assert not make(1, 2, 3) < make(0, 4, 5)