neural-amp-modeler

Neural network emulator for guitar amplifiers
Log | Files | Refs | README | LICENSE

core.py (60431B)


      1 # File: core.py
      2 # Created Date: Tuesday December 20th 2022
      3 # Author: Steven Atkinson (steven@atkinson.mn)
      4 
      5 """
      6 The core of the "simplified trainer"
      7 
      8 Used by the GUI and Colab trainers.
      9 """
     10 
     11 import hashlib as _hashlib
     12 import tkinter as _tk
     13 from copy import deepcopy as _deepcopy
     14 from enum import Enum as _Enum
     15 from functools import partial as _partial
     16 from pathlib import Path as _Path
     17 from time import time as _time
     18 from typing import (
     19     Dict as _Dict,
     20     NamedTuple as _NamedTuple,
     21     Optional as _Optional,
     22     Sequence as _Sequence,
     23     Tuple as _Tuple,
     24     Union as _Union,
     25 )
     26 
     27 import matplotlib.pyplot as _plt
     28 import numpy as _np
     29 import pytorch_lightning as _pl
     30 import torch as _torch
     31 from pydantic import BaseModel as _BaseModel
     32 from pytorch_lightning.utilities.warnings import (
     33     PossibleUserWarning as _PossibleUserWarning,
     34 )
     35 from torch.utils.data import DataLoader as _DataLoader
     36 
     37 from ..data import (
     38     DataError as _DataError,
     39     Split as _Split,
     40     init_dataset as _init_dataset,
     41     wav_to_np as _wav_to_np,
     42     wav_to_tensor as _wav_to_tensor,
     43 )
     44 from ..models.exportable import Exportable as _Exportable
     45 from ..models.losses import esr as _ESR
     46 from ..models.metadata import UserMetadata as _UserMetadata
     47 from ..util import filter_warnings as _filter_warnings
     48 from ._version import PROTEUS_VERSION as _PROTEUS_VERSION, Version as _Version
     49 from .lightning_module import LightningModule as _LightningModule
     50 from . import metadata as _metadata
     51 
# Training using the simplified trainers in NAM is done at 48k.
STANDARD_SAMPLE_RATE = 48_000.0
# Default number of output samples per datum.
# ("ny" = length, in samples, of each datum's output window — TODO confirm
# against the dataset initialization code.)
_NY_DEFAULT = 8192
     56 
     57 
class Architecture(_Enum):
    """
    Preset model architectures selectable in the simplified trainers.

    The values are the strings shown to/provided by users; see
    ``get_lstm_config`` for the hyperparameters each preset maps to.
    """

    STANDARD = "standard"
    LITE = "lite"
    FEATHER = "feather"
    NANO = "nano"
     63 
     64 
class _InputValidationError(ValueError):
    """
    Raised when the provided input file cannot be recognized as any known
    standard input file (see ``_detect_input_version``).
    """

    pass
     67 
     68 
     69 def _detect_input_version(input_path) -> _Tuple[_Version, bool]:
     70     """
     71     Check to see if the input matches any of the known inputs
     72 
     73     :return: version, strong match
     74     """
     75 
     76     def detect_strong(input_path) -> _Optional[_Version]:
     77         def assign_hash(path):
     78             # Use this to create hashes for new files
     79             md5 = _hashlib.md5()
     80             buffer_size = 65536
     81             with open(path, "rb") as f:
     82                 while True:
     83                     data = f.read(buffer_size)
     84                     if not data:
     85                         break
     86                     md5.update(data)
     87             file_hash = md5.hexdigest()
     88             return file_hash
     89 
     90         file_hash = assign_hash(input_path)
     91         print(f"Strong hash: {file_hash}")
     92 
     93         version = {
     94             "4d54a958861bf720ec4637f43d44a7ef": _Version(1, 0, 0),
     95             "7c3b6119c74465f79d96c761a0e27370": _Version(1, 1, 1),
     96             "ede3b9d82135ce10c7ace3bb27469422": _Version(2, 0, 0),
     97             "36cd1af62985c2fac3e654333e36431e": _Version(3, 0, 0),
     98             "80e224bd5622fd6153ff1fd9f34cb3bd": _PROTEUS_VERSION,
     99         }.get(file_hash)
    100         if version is None:
    101             print(
    102                 f"Provided input file {input_path} does not strong-match any known "
    103                 "standard input files."
    104             )
    105         return version
    106 
    107     def detect_weak(input_path) -> _Optional[_Version]:
    108         def assign_hash(path):
    109             Hash = _Optional[str]
    110             Hashes = _Tuple[Hash, Hash]
    111 
    112             def _hash(x: _np.ndarray) -> str:
    113                 return _hashlib.md5(x).hexdigest()
    114 
    115             def assign_hashes_v1(path) -> Hashes:
    116                 # Use this to create recognized hashes for new files
    117                 x, info = _wav_to_np(path, info=True)
    118                 rate = info.rate
    119                 if rate != _V1_DATA_INFO.rate:
    120                     return None, None
    121                 # Times of intervals, in seconds
    122                 t_blips = _V1_DATA_INFO.t_blips
    123                 t_sweep = 3 * rate
    124                 t_white = 3 * rate
    125                 t_validation = _V1_DATA_INFO.t_validate
    126                 # v1 and v2 start with 1 blips, sine sweeps, and white noise
    127                 start_hash = _hash(x[: t_blips + t_sweep + t_white])
    128                 # v1 ends with validation signal
    129                 end_hash = _hash(x[-t_validation:])
    130                 return start_hash, end_hash
    131 
    132             def assign_hashes_v2(path) -> Hashes:
    133                 # Use this to create recognized hashes for new files
    134                 x, info = _wav_to_np(path, info=True)
    135                 rate = info.rate
    136                 if rate != _V2_DATA_INFO.rate:
    137                     return None, None
    138                 # Times of intervals, in seconds
    139                 t_blips = _V2_DATA_INFO.t_blips
    140                 t_sweep = 3 * rate
    141                 t_white = 3 * rate
    142                 t_validation = _V1_DATA_INFO.t_validate
    143                 # v1 and v2 start with 1 blips, sine sweeps, and white noise
    144                 start_hash = _hash(x[: (t_blips + t_sweep + t_white)])
    145                 # v2 ends with 2x validation & blips
    146                 end_hash = _hash(x[-(2 * t_validation + t_blips) :])
    147                 return start_hash, end_hash
    148 
    149             def assign_hashes_v3(path) -> Hashes:
    150                 # Use this to create recognized hashes for new files
    151                 x, info = _wav_to_np(path, info=True)
    152                 rate = info.rate
    153                 if rate != _V3_DATA_INFO.rate:
    154                     return None, None
    155                 # Times of intervals, in seconds
    156                 # See below.
    157                 end_of_start_interval = 17 * rate  # Start at 0
    158                 start_of_end_interval = -9 * rate
    159                 start_hash = _hash(x[:end_of_start_interval])
    160                 end_hash = _hash(x[start_of_end_interval:])
    161                 return start_hash, end_hash
    162 
    163             def assign_hash_v4(path) -> Hash:
    164                 # Use this to create recognized hashes for new files
    165                 x, info = _wav_to_np(path, info=True)
    166                 rate = info.rate
    167                 if rate != _V4_DATA_INFO.rate:
    168                     return None
    169                 # I don't care about anything in the file except the starting blip and
    170                 start_hash = _hash(x[: int(1 * _V4_DATA_INFO.rate)])
    171                 return start_hash
    172 
    173             start_hash_v1, end_hash_v1 = assign_hashes_v1(path)
    174             start_hash_v2, end_hash_v2 = assign_hashes_v2(path)
    175             start_hash_v3, end_hash_v3 = assign_hashes_v3(path)
    176             hash_v4 = assign_hash_v4(path)
    177             return (
    178                 start_hash_v1,
    179                 end_hash_v1,
    180                 start_hash_v2,
    181                 end_hash_v2,
    182                 start_hash_v3,
    183                 end_hash_v3,
    184                 hash_v4,
    185             )
    186 
    187         (
    188             start_hash_v1,
    189             end_hash_v1,
    190             start_hash_v2,
    191             end_hash_v2,
    192             start_hash_v3,
    193             end_hash_v3,
    194             hash_v4,
    195         ) = assign_hash(input_path)
    196         print(
    197             "Weak hashes:\n"
    198             f" Start (v1) : {start_hash_v1}\n"
    199             f" End (v1)   : {end_hash_v1}\n"
    200             f" Start (v2) : {start_hash_v2}\n"
    201             f" End (v2)   : {end_hash_v2}\n"
    202             f" Start (v3) : {start_hash_v3}\n"
    203             f" End (v3)   : {end_hash_v3}\n"
    204             f" Proteus    : {hash_v4}\n"
    205         )
    206 
    207         # Check for matches, starting with most recent. Proteus last since its match is
    208         # the most permissive.
    209         version = {
    210             (
    211                 "dadb5d62f6c3973a59bf01439799809b",
    212                 "8458126969a3f9d8e19a53554eb1fd52",
    213             ): _Version(3, 0, 0)
    214         }.get((start_hash_v3, end_hash_v3))
    215         if version is not None:
    216             return version
    217         version = {
    218             (
    219                 "1c4d94fbcb47e4d820bef611c1d4ae65",
    220                 "28694e7bf9ab3f8ae6ef86e9545d4663",
    221             ): _Version(2, 0, 0)
    222         }.get((start_hash_v2, end_hash_v2))
    223         if version is not None:
    224             return version
    225         version = {
    226             (
    227                 "bb4e140c9299bae67560d280917eb52b",
    228                 "9b2468fcb6e9460a399fc5f64389d353",
    229             ): _Version(
    230                 1, 0, 0
    231             ),  # FIXME!
    232             (
    233                 "9f20c6b5f7fef68dd88307625a573a14",
    234                 "8458126969a3f9d8e19a53554eb1fd52",
    235             ): _Version(1, 1, 1),
    236         }.get((start_hash_v1, end_hash_v1))
    237         if version is not None:
    238             return version
    239         version = {"46151c8030798081acc00a725325a07d": _PROTEUS_VERSION}.get(hash_v4)
    240         return version
    241 
    242     version = detect_strong(input_path)
    243     if version is not None:
    244         strong_match = True
    245         return version, strong_match
    246     print("Falling back to weak-matching...")
    247     version = detect_weak(input_path)
    248     if version is None:
    249         raise _InputValidationError(
    250             f"Input file at {input_path} cannot be recognized as any known version!"
    251         )
    252     strong_match = False
    253 
    254     return version, strong_match
    255 
    256 
class _DataInfo(_BaseModel):
    """
    Layout information for one version of the standardized reamp files.

    All locations/durations are in samples unless noted otherwise.

    :param major_version: Data major version
    """

    # Data major version (4 corresponds to the GuitarML Proteus file).
    major_version: int
    # Expected sample rate of the file, in Hz (None if not fixed).
    rate: _Optional[float]
    # Total length of the blip interval.
    t_blips: int
    # Sample at which the blip interval begins.
    first_blips_start: int
    # Length of one validation segment.
    t_validate: int
    # Sample at which the training segment begins.
    train_start: int
    # Sample at which validation data begins; negative values presumably index
    # from the end of the file — TODO confirm against the dataset code.
    validation_start: int
    # (start, end) of a silent span used to estimate the background noise level
    # when calibrating the blip trigger (see _calibrate_latency_v_all).
    noise_interval: _Tuple[int, int]
    # Blip locations, grouped by blip interval (e.g. V2 has a start group and
    # an end group).
    blip_locations: _Sequence[_Sequence[int]]
    271 
    272 
# V1:
# * Blips at 0:00.25 and 0:00.75 (t_blips covers the first second)
# * The last 9 seconds are the validation signal
_V1_DATA_INFO = _DataInfo(
    major_version=1,
    rate=STANDARD_SAMPLE_RATE,
    t_blips=48_000,
    first_blips_start=0,
    t_validate=432_000,
    train_start=0,
    validation_start=-432_000,
    noise_interval=(0, 6000),
    blip_locations=((12_000, 36_000),),
)
# V2:
# (0:00-0:02) Blips at 0:00.5 and 0:01.5
# (0:02-0:05) Chirps
# (0:05-0:07) Noise
# (0:07-2:50.5) General training data
# (2:50.5-2:51) Silence
# (2:51-3:00) Validation 1
# (3:00-3:09) Validation 2
# (3:09-3:11) Blips at 3:09.5 and 3:10.5
_V2_DATA_INFO = _DataInfo(
    major_version=2,
    rate=STANDARD_SAMPLE_RATE,
    t_blips=96_000,
    first_blips_start=0,
    t_validate=432_000,
    train_start=0,
    validation_start=-960_000,  # 96_000 + 2 * 432_000
    noise_interval=(12_000, 18_000),
    blip_locations=((24_000, 72_000), (-72_000, -24_000)),
)
# V3:
# (0:00-0:09) Validation 1
# (0:09-0:10) Silence
# (0:10-0:12) Blips at 0:10.5 and 0:11.5
# (0:12-0:15) Chirps
# (0:15-0:17) Noise
# (0:17-3:00.5) General training data
# (3:00.5-3:01) Silence
# (3:01-3:10) Validation 2
_V3_DATA_INFO = _DataInfo(
    major_version=3,
    rate=STANDARD_SAMPLE_RATE,
    t_blips=96_000,
    first_blips_start=480_000,
    t_validate=432_000,
    train_start=480_000,
    validation_start=-432_000,
    noise_interval=(492_000, 498_000),
    blip_locations=((504_000, 552_000),),
)
# V4 (aka GuitarML Proteus)
# https://github.com/GuitarML/Releases/releases/download/v1.0.0/Proteus_Capture_Utility.zip
# * 44.1k
# * Odd length...
# * There's a blip on sample zero. This has to be ignored or else over-compensated
#   latencies will come out wrong!
# (0:00-0:01) Blips at 0:00.0 and 0:00.5
# (0:01-0:09) Sine sweeps
# (0:09-0:17) White noise
# (0:17:0.20) Rising white noise (to 0:20.333 appx)
# (0:20-3:30.858) General training data (ends on sample 9,298,872)
# I'm arbitrarily assigning the last 10 seconds as validation data.
_V4_DATA_INFO = _DataInfo(
    major_version=4,
    rate=44_100.0,
    t_blips=44_099,  # Need to ignore the first blip!
    first_blips_start=1,  # Need to ignore the first blip!
    t_validate=441_000,
    # Blips are problematic for training because they don't have preceding silence
    train_start=44_100,
    validation_start=-441_000,
    noise_interval=(6_000, 12_000),
    blip_locations=((22_050,),),
)

# Trigger thresholds used by latency calibration: the output must rise this far
# above the measured background noise level (absolutely / relatively) to count
# as a response to a blip.
_DELAY_CALIBRATION_ABS_THRESHOLD = 0.0003
_DELAY_CALIBRATION_REL_THRESHOLD = 0.001
_DELAY_CALIBRATION_SAFETY_FACTOR = 1  # Might be able to make this zero...
    352 
    353 
    354 def _warn_lookaheads(indices: _Sequence[int]) -> str:
    355     return (
    356         f"WARNING: delays from some blips ({','.join([str(i) for i in indices])}) are "
    357         "at the minimum value possible. This usually means that something is "
    358         "wrong with your data. Check if trianing ends with a poor result!"
    359     )
    360 
    361 
    362 def _calibrate_latency_v_all(
    363     data_info: _DataInfo,
    364     y,
    365     abs_threshold=_DELAY_CALIBRATION_ABS_THRESHOLD,
    366     rel_threshold=_DELAY_CALIBRATION_REL_THRESHOLD,
    367     safety_factor=_DELAY_CALIBRATION_SAFETY_FACTOR,
    368 ) -> _metadata.LatencyCalibration:
    369     """
    370     Calibrate the delay in teh input-output pair based on blips.
    371     This only uses the blips in the first set of blip locations!
    372 
    373     :param y: The output audio, in complete.
    374     """
    375 
    376     def report_any_latency_warnings(
    377         delays: _Sequence[int],
    378     ) -> _metadata.LatencyCalibrationWarnings:
    379         # Warnings associated with any single delay:
    380 
    381         # "Lookahead warning": if the delay is equal to the lookahead, then it's
    382         # probably an error.
    383         lookahead_warnings = [i for i, d in enumerate(delays, 1) if d == -lookahead]
    384         matches_lookahead = len(lookahead_warnings) > 0
    385         if matches_lookahead:
    386             print(_warn_lookaheads(lookahead_warnings))
    387 
    388         # Ensemble warnings
    389 
    390         # If they're _really_ different, then something might be wrong.
    391         max_disagreement_threshold = 20
    392         max_disagreement_too_high = (
    393             _np.max(delays) - _np.min(delays) >= max_disagreement_threshold
    394         )
    395         if max_disagreement_too_high:
    396             print(
    397                 "WARNING: Latencies are anomalously different from each other (more "
    398                 f"than {max_disagreement_threshold} samples). If this model turns out "
    399                 "badly, then you might need to provide the latency manually."
    400             )
    401 
    402         return _metadata.LatencyCalibrationWarnings(
    403             matches_lookahead=matches_lookahead,
    404             disagreement_too_high=max_disagreement_too_high,
    405         )
    406 
    407     lookahead = 1_000
    408     lookback = 10_000
    409     # Calibrate the level for the trigger:
    410     y = y[data_info.first_blips_start : data_info.first_blips_start + data_info.t_blips]
    411     background_level = _np.max(
    412         _np.abs(
    413             y[
    414                 data_info.noise_interval[0]
    415                 - data_info.first_blips_start : data_info.noise_interval[1]
    416                 - data_info.first_blips_start
    417             ]
    418         )
    419     )
    420     trigger_threshold = max(
    421         background_level + abs_threshold,
    422         (1.0 + rel_threshold) * background_level,
    423     )
    424 
    425     y_scans = []
    426     for blip_index, i_abs in enumerate(data_info.blip_locations[0], 1):
    427         # Relative to start of the data
    428         i_rel = i_abs - data_info.first_blips_start
    429         start_looking = i_rel - lookahead
    430         stop_looking = i_rel + lookback
    431         y_scans.append(y[start_looking:stop_looking])
    432     y_scan_average = _np.mean(_np.stack(y_scans), axis=0)
    433     triggered = _np.where(_np.abs(y_scan_average) > trigger_threshold)[0]
    434     if len(triggered) == 0:
    435         msg = (
    436             "No response activated the trigger in response to input spikes. "
    437             "Is something wrong with the reamp?"
    438         )
    439         print(msg)
    440         print("SHARE THIS PLOT IF YOU ASK FOR HELP")
    441         _plt.figure()
    442         _plt.plot(
    443             _np.arange(-lookahead, lookback),
    444             y_scan_average,
    445             color="C0",
    446             label="Signal average",
    447         )
    448         for y_scan in y_scans:
    449             _plt.plot(_np.arange(-lookahead, lookback), y_scan, color="C0", alpha=0.2)
    450         _plt.axvline(x=0, color="C1", linestyle="--", label="Trigger")
    451         _plt.axhline(y=-trigger_threshold, color="k", linestyle="--", label="Threshold")
    452         _plt.axhline(y=trigger_threshold, color="k", linestyle="--")
    453         _plt.xlim((-lookahead, lookback))
    454         _plt.xlabel("Samples")
    455         _plt.ylabel("Response")
    456         _plt.legend()
    457         _plt.title("SHARE THIS PLOT IF YOU ASK FOR HELP")
    458         _plt.show()
    459         raise RuntimeError(msg)
    460     else:
    461         j = triggered[0]
    462         delay = j + start_looking - i_rel
    463 
    464     print(f"Delay based on average is {delay}")
    465     warnings = report_any_latency_warnings([delay])
    466 
    467     delay_post_safety_factor = delay - safety_factor
    468     print(
    469         f"After aplying safety factor of {safety_factor}, the final delay is "
    470         f"{delay_post_safety_factor}"
    471     )
    472     return _metadata.LatencyCalibration(
    473         algorithm_version=1,
    474         delays=[delay],
    475         safety_factor=safety_factor,
    476         recommended=delay_post_safety_factor,
    477         warnings=warnings,
    478     )
    479 
    480 
# Per-version convenience wrappers: _calibrate_latency_v_all with the
# corresponding standard file layout baked in.
_calibrate_latency_v1 = _partial(_calibrate_latency_v_all, _V1_DATA_INFO)
_calibrate_latency_v2 = _partial(_calibrate_latency_v_all, _V2_DATA_INFO)
_calibrate_latency_v3 = _partial(_calibrate_latency_v_all, _V3_DATA_INFO)
_calibrate_latency_v4 = _partial(_calibrate_latency_v_all, _V4_DATA_INFO)
    485 
    486 
    487 def _plot_latency_v_all(
    488     data_info: _DataInfo, latency: int, input_path: str, output_path: str, _nofail=True
    489 ):
    490     print("Plotting the latency for manual inspection...")
    491     x = _wav_to_np(input_path)[
    492         data_info.first_blips_start : data_info.first_blips_start + data_info.t_blips
    493     ]
    494     y = _wav_to_np(output_path)[
    495         data_info.first_blips_start : data_info.first_blips_start + data_info.t_blips
    496     ]
    497     # Only get the blips we really want.
    498     i = _np.where(_np.abs(x) > 0.5 * _np.abs(x).max())[0]
    499     if len(i) == 0:
    500         print("Failed to find the spike in the input file.")
    501         print(
    502             "Plotting the input and output; there should be spikes at around the "
    503             "marked locations."
    504         )
    505         t = _np.arange(
    506             data_info.first_blips_start, data_info.first_blips_start + data_info.t_blips
    507         )
    508         expected_spikes = data_info.blip_locations[0]  # For v1 specifically
    509         fig, axs = _plt.subplots(len((x, y)), 1)
    510         for ax, curve in zip(axs, (x, y)):
    511             ax.plot(t, curve)
    512             [ax.axvline(x=es, color="C1", linestyle="--") for es in expected_spikes]
    513         _plt.show()
    514         if _nofail:
    515             raise RuntimeError("Failed to plot delay")
    516     else:
    517         _plt.figure()
    518         di = 20
    519         # V1's got not a spike but a longer plateau; take the front of it.
    520         if data_info.major_version == 1:
    521             i = [i[0]]
    522         for e, ii in enumerate(i, 1):
    523             _plt.plot(
    524                 _np.arange(-di, di),
    525                 y[ii - di + latency : ii + di + latency],
    526                 ".-",
    527                 label=f"Output {e}",
    528             )
    529         _plt.axvline(x=0, linestyle="--", color="k")
    530         _plt.legend()
    531         _plt.show()  # This doesn't freeze the notebook
    532 
    533 
# Per-version convenience wrappers: _plot_latency_v_all with the corresponding
# standard file layout baked in.
_plot_latency_v1 = _partial(_plot_latency_v_all, _V1_DATA_INFO)
_plot_latency_v2 = _partial(_plot_latency_v_all, _V2_DATA_INFO)
_plot_latency_v3 = _partial(_plot_latency_v_all, _V3_DATA_INFO)
_plot_latency_v4 = _partial(_plot_latency_v_all, _V4_DATA_INFO)
    538 
    539 
    540 def _analyze_latency(
    541     user_latency: _Optional[int],
    542     input_version: _Version,
    543     input_path: str,
    544     output_path: str,
    545     silent: bool = False,
    546 ) -> _metadata.Latency:
    547     """
    548     :param is_proteus: Forget the version; d
    549     """
    550     if input_version.major == 1:
    551         calibrate, plot = _calibrate_latency_v1, _plot_latency_v1
    552     elif input_version.major == 2:
    553         calibrate, plot = _calibrate_latency_v2, _plot_latency_v2
    554     elif input_version.major == 3:
    555         calibrate, plot = _calibrate_latency_v3, _plot_latency_v3
    556     elif input_version.major == 4:
    557         calibrate, plot = _calibrate_latency_v4, _plot_latency_v4
    558     else:
    559         raise NotImplementedError(
    560             f"Input calibration not implemented for input version {input_version}"
    561         )
    562     if user_latency is not None:
    563         print(f"Delay is specified as {user_latency}")
    564     calibration_output = calibrate(_wav_to_np(output_path))
    565     latency = (
    566         user_latency if user_latency is not None else calibration_output.recommended
    567     )
    568     if not silent:
    569         plot(latency, input_path, output_path)
    570 
    571     return _metadata.Latency(manual=user_latency, calibration=calibration_output)
    572 
    573 
    574 def get_lstm_config(architecture):
    575     return {
    576         Architecture.STANDARD: {
    577             "num_layers": 1,
    578             "hidden_size": 24,
    579             "train_burn_in": 4096,
    580             "train_truncate": 512,
    581         },
    582         Architecture.LITE: {
    583             "num_layers": 2,
    584             "hidden_size": 8,
    585             "train_burn_in": 4096,
    586             "train_truncate": 512,
    587         },
    588         Architecture.FEATHER: {
    589             "num_layers": 1,
    590             "hidden_size": 16,
    591             "train_burn_in": 4096,
    592             "train_truncate": 512,
    593         },
    594         Architecture.NANO: {
    595             "num_layers": 1,
    596             "hidden_size": 12,
    597             "train_burn_in": 4096,
    598             "train_truncate": 512,
    599         },
    600     }[architecture]
    601 
    602 
    603 def _check_v1(*args, **kwargs) -> _metadata.DataChecks:
    604     return _metadata.DataChecks(version=1, passed=True)
    605 
    606 
    607 def _esr_validation_replicate_msg(threshold: float) -> str:
    608     return (
    609         f"Validation replicates have a self-ESR of over {threshold}. "
    610         "Your gear doesn't sound like itself when played twice!\n\n"
    611         "Possible causes:\n"
    612         " * Your signal chain is too noisy.\n"
    613         " * There's a time-based effect (chorus, delay, reverb) turned on.\n"
    614         " * Some knob got moved while reamping.\n"
    615         " * You started reamping before the amp had time to warm up fully."
    616     )
    617 
    618 
def _check_v2(
    input_path, output_path, delay: int, silent: bool
) -> _metadata.DataChecks:
    """
    Run the V2 data-quality checks on the reamped output.

    Two families of checks:

    1. The two validation replicates near the end of the file are compared via
       ESR; a large mismatch suggests noise or a time-varying signal chain.
    2. The blip pairs at the start and end of the training signal are compared
       (within each pair, and start-vs-end across the file) to check that the
       rig's response is replicable over short and long timespans.

    :param input_path: NOTE(review): currently unused by this check.
    :param output_path: Path to the reamped output .wav file.
    :param delay: Latency, in samples, used to align the blip windows.
    :param silent: NOTE(review): currently unused here; the blip plots are
        disabled via the hard-coded ``show_blip_plots`` below.
    :return: DataChecks(version=2) with passed=False if any blip check fails.
        A high replicate ESR only warns; it does not fail the check.
    """
    with _torch.no_grad():
        print("V2 checks...")
        rate = _V2_DATA_INFO.rate
        y = _wav_to_tensor(output_path, rate=rate)
        t_blips = _V2_DATA_INFO.t_blips
        t_validate = _V2_DATA_INFO.t_validate
        # The file ends with [validation 1][validation 2][blips]; slice the two
        # validation takes out from the end.
        y_val_1 = y[-(t_blips + 2 * t_validate) : -(t_blips + t_validate)]
        y_val_2 = y[-(t_blips + t_validate) : -t_blips]
        esr_replicate = _ESR(y_val_1, y_val_2).item()
        print(f"Replicate ESR is {esr_replicate:.8f}.")
        esr_replicate_threshold = 0.01
        if esr_replicate > esr_replicate_threshold:
            print(_esr_validation_replicate_msg(esr_replicate_threshold))

        # Do the blips line up?
        # If the ESR is too bad, then flag it.
        print("Checking blips...")

        def get_blips(y):
            """
            Extract the four delay-compensated blip windows.

            :return: [start/end,replicate]
            """
            # First group: blips at the start of the file; second group: blips
            # at the end (negative indices).
            i0, i1 = _V2_DATA_INFO.blip_locations[0]
            j0, j1 = _V2_DATA_INFO.blip_locations[1]

            # Shift all four locations by the calibrated latency.
            i0, i1, j0, j1 = [i + delay for i in (i0, i1, j0, j1)]
            start = -10
            end = 1000
            blips = _torch.stack(
                [
                    _torch.stack([y[i0 + start : i0 + end], y[i1 + start : i1 + end]]),
                    _torch.stack([y[j0 + start : j0 + end], y[j1 + start : j1 + end]]),
                ]
            )
            return blips

        blips = get_blips(y)
        esr_0 = _ESR(blips[0][0], blips[0][1]).item()  # Within start
        esr_1 = _ESR(blips[1][0], blips[1][1]).item()  # Within end
        esr_cross_0 = _ESR(blips[0][0], blips[1][0]).item()  # 1st repeat, start vs end
        esr_cross_1 = _ESR(blips[0][1], blips[1][1]).item()  # 2nd repeat, start vs end

        print("  ESRs:")
        print(f"    Start     : {esr_0}")
        print(f"    End       : {esr_1}")
        print(f"    Cross (1) : {esr_cross_0}")
        print(f"    Cross (2) : {esr_cross_1}")

        esr_threshold = 1.0e-2

        def plot_esr_blip_error(
            show_plot: bool,
            msg: str,
            arrays: _Sequence[_Sequence[float]],
            labels: _Sequence[str],
        ):
            """
            Report a failed blip check, optionally with a plot of the
            offending blip pair.

            :param show_plot: Whether to make and show a plot about it
            :param msg: Explanation printed to the user.
            :param arrays: The blip windows to plot.
            :param labels: Legend labels, parallel to `arrays`.
            """
            if show_plot:
                _plt.figure()
                [_plt.plot(array, label=label) for array, label in zip(arrays, labels)]
                _plt.xlabel("Sample")
                _plt.ylabel("Output")
                _plt.legend()
                _plt.grid()
            print(msg)
            if show_plot:
                _plt.show()
            print(
                "This is known to be a very sensitive test, so training will continue. "
                "If the model doesn't look good, then this may be why!"
            )

        # Check consecutive blips
        show_blip_plots = False
        for e, blip_pair, when in zip((esr_0, esr_1), blips, ("start", "end")):
            if e >= esr_threshold:
                plot_esr_blip_error(
                    show_blip_plots,
                    f"Failed consecutive blip check at {when} of training signal. The "
                    "target tone doesn't seem to be replicable over short timespans."
                    "\n\n"
                    "  Possible causes:\n\n"
                    "    * Your recording setup is really noisy.\n"
                    "    * There's a noise gate that's messing things up.\n"
                    "    * There's a time-based effect (chorus, delay, reverb) in "
                    "the signal chain",
                    blip_pair,
                    ("Replicate 1", "Replicate 2"),
                )
                return _metadata.DataChecks(version=2, passed=False)
        # Check blips between start & end of train signal
        for e, blip_pair, replicate in zip(
            (esr_cross_0, esr_cross_1), blips.permute(1, 0, 2), (1, 2)
        ):
            if e >= esr_threshold:
                plot_esr_blip_error(
                    show_blip_plots,
                    f"Failed start-to-end blip check for blip replicate {replicate}. "
                    "The target tone doesn't seem to be same at the end of the reamp "
                    "as it was at the start. Did some setting change during reamping?",
                    blip_pair,
                    (f"Start, replicate {replicate}", f"End, replicate {replicate}"),
                )
                return _metadata.DataChecks(version=2, passed=False)
        return _metadata.DataChecks(version=2, passed=True)
    729 
    730 
    731 def _check_v3(
    732     input_path, output_path, silent: bool, *args, **kwargs
    733 ) -> _metadata.DataChecks:
    734     with _torch.no_grad():
    735         print("V3 checks...")
    736         rate = _V3_DATA_INFO.rate
    737         y = _wav_to_tensor(output_path, rate=rate)
    738         n = len(_wav_to_tensor(input_path))  # to End-crop output
    739         y_val_1 = y[: _V3_DATA_INFO.t_validate]
    740         y_val_2 = y[n - _V3_DATA_INFO.t_validate : n]
    741         esr_replicate = _ESR(y_val_1, y_val_2).item()
    742         print(f"Replicate ESR is {esr_replicate:.8f}.")
    743         esr_replicate_threshold = 0.01
    744         if esr_replicate > esr_replicate_threshold:
    745             print(_esr_validation_replicate_msg(esr_replicate_threshold))
    746             if not silent:
    747                 _plt.figure()
    748                 t = _np.arange(len(y_val_1)) / rate
    749                 _plt.plot(t, y_val_1, label="Validation 1")
    750                 _plt.plot(t, y_val_2, label="Validation 2")
    751                 _plt.xlabel("Time (sec)")
    752                 _plt.legend()
    753                 _plt.title("V3 check: Validation replicate FAILURE")
    754                 _plt.show()
    755             return _metadata.DataChecks(version=3, passed=False)
    756     return _metadata.DataChecks(version=3, passed=True)
    757 
    758 
    759 def _check_v4(
    760     input_path, output_path, silent: bool, *args, **kwargs
    761 ) -> _metadata.DataChecks:
    762     # Things we can't check:
    763     # Latency compensation agreement
    764     # Data replicability
    765     print("Using Proteus audio file. Standard data checks aren't possible!")
    766     signal, info = _wav_to_np(output_path, info=True)
    767     passed = True
    768     if info.rate != _V4_DATA_INFO.rate:
    769         print(
    770             f"Output signal has sample rate {info.rate}; expected {_V4_DATA_INFO.rate}!"
    771         )
    772         passed = False
    773     # I don't care what's in the files except that they're long enough to hold the blip
    774     # and the last 10 seconds I decided to use as validation
    775     required_length = int((1.0 + 10.0) * _V4_DATA_INFO.rate)
    776     if len(signal) < required_length:
    777         print(
    778             "File doesn't meet the minimum length requirements for latency compensation and validation signal!"
    779         )
    780         passed = False
    781     return _metadata.DataChecks(version=4, passed=passed)
    782 
    783 
    784 def _check_data(
    785     input_path: str, output_path: str, input_version: _Version, delay: int, silent: bool
    786 ) -> _Optional[_metadata.DataChecks]:
    787     """
    788     Ensure that everything should go smoothly
    789 
    790     :return: True if looks good
    791     """
    792     if input_version.major == 1:
    793         f = _check_v1
    794     elif input_version.major == 2:
    795         f = _check_v2
    796     elif input_version.major == 3:
    797         f = _check_v3
    798     elif input_version.major == 4:
    799         f = _check_v4
    800     else:
    801         print(f"Checks not implemented for input version {input_version}; skip")
    802         return None
    803     out = f(input_path, output_path, delay, silent)
    804     # Issue 442: Deprecate inputs
    805     if input_version.major != 3:
    806         print(
    807             f"Input version {input_version} is deprecated and will be removed in "
    808             "version 0.11 of the trainer. To continue using it, you must ignore checks."
    809         )
    810         out.passed = False
    811     return out
    812 
    813 
    814 def get_wavenet_config(architecture):
    815     return {
    816         Architecture.STANDARD: {
    817             "layers_configs": [
    818                 {
    819                     "input_size": 1,
    820                     "condition_size": 1,
    821                     "channels": 16,
    822                     "head_size": 8,
    823                     "kernel_size": 3,
    824                     "dilations": [1, 2, 4, 8, 16, 32, 64, 128, 256, 512],
    825                     "activation": "Tanh",
    826                     "gated": False,
    827                     "head_bias": False,
    828                 },
    829                 {
    830                     "condition_size": 1,
    831                     "input_size": 16,
    832                     "channels": 8,
    833                     "head_size": 1,
    834                     "kernel_size": 3,
    835                     "dilations": [1, 2, 4, 8, 16, 32, 64, 128, 256, 512],
    836                     "activation": "Tanh",
    837                     "gated": False,
    838                     "head_bias": True,
    839                 },
    840             ],
    841             "head_scale": 0.02,
    842         },
    843         Architecture.LITE: {
    844             "layers_configs": [
    845                 {
    846                     "input_size": 1,
    847                     "condition_size": 1,
    848                     "channels": 12,
    849                     "head_size": 6,
    850                     "kernel_size": 3,
    851                     "dilations": [1, 2, 4, 8, 16, 32, 64],
    852                     "activation": "Tanh",
    853                     "gated": False,
    854                     "head_bias": False,
    855                 },
    856                 {
    857                     "condition_size": 1,
    858                     "input_size": 12,
    859                     "channels": 6,
    860                     "head_size": 1,
    861                     "kernel_size": 3,
    862                     "dilations": [128, 256, 512, 1, 2, 4, 8, 16, 32, 64, 128, 256, 512],
    863                     "activation": "Tanh",
    864                     "gated": False,
    865                     "head_bias": True,
    866                 },
    867             ],
    868             "head_scale": 0.02,
    869         },
    870         Architecture.FEATHER: {
    871             "layers_configs": [
    872                 {
    873                     "input_size": 1,
    874                     "condition_size": 1,
    875                     "channels": 8,
    876                     "head_size": 4,
    877                     "kernel_size": 3,
    878                     "dilations": [1, 2, 4, 8, 16, 32, 64],
    879                     "activation": "Tanh",
    880                     "gated": False,
    881                     "head_bias": False,
    882                 },
    883                 {
    884                     "condition_size": 1,
    885                     "input_size": 8,
    886                     "channels": 4,
    887                     "head_size": 1,
    888                     "kernel_size": 3,
    889                     "dilations": [128, 256, 512, 1, 2, 4, 8, 16, 32, 64, 128, 256, 512],
    890                     "activation": "Tanh",
    891                     "gated": False,
    892                     "head_bias": True,
    893                 },
    894             ],
    895             "head_scale": 0.02,
    896         },
    897         Architecture.NANO: {
    898             "layers_configs": [
    899                 {
    900                     "input_size": 1,
    901                     "condition_size": 1,
    902                     "channels": 4,
    903                     "head_size": 2,
    904                     "kernel_size": 3,
    905                     "dilations": [1, 2, 4, 8, 16, 32, 64],
    906                     "activation": "Tanh",
    907                     "gated": False,
    908                     "head_bias": False,
    909                 },
    910                 {
    911                     "condition_size": 1,
    912                     "input_size": 4,
    913                     "channels": 2,
    914                     "head_size": 1,
    915                     "kernel_size": 3,
    916                     "dilations": [128, 256, 512, 1, 2, 4, 8, 16, 32, 64, 128, 256, 512],
    917                     "activation": "Tanh",
    918                     "gated": False,
    919                     "head_bias": True,
    920                 },
    921             ],
    922             "head_scale": 0.02,
    923         },
    924     }[architecture]
    925 
    926 
# Weight and pre-emphasis coefficient for the multi-resolution STFT ("MRSTFT")
# loss term that's added to the model's loss config when `fit_mrstft` is
# enabled (see `_get_configs`).
_CAB_MRSTFT_PRE_EMPH_WEIGHT = 2.0e-4
_CAB_MRSTFT_PRE_EMPH_COEF = 0.85
    929 
    930 
    931 def _get_data_config(
    932     input_version: _Version,
    933     input_path: _Path,
    934     output_path: _Path,
    935     ny: int,
    936     latency: int,
    937 ) -> dict:
    938     def get_split_kwargs(data_info: _DataInfo):
    939         if data_info.major_version == 1:
    940             train_val_split = data_info.validation_start
    941             train_kwargs = {"stop_samples": train_val_split}
    942             validation_kwargs = {"start_samples": train_val_split}
    943         elif data_info.major_version == 2:
    944             validation_start = data_info.validation_start
    945             train_stop = validation_start
    946             validation_stop = validation_start + data_info.t_validate
    947             train_kwargs = {"stop_samples": train_stop}
    948             validation_kwargs = {
    949                 "start_samples": validation_start,
    950                 "stop_samples": validation_stop,
    951             }
    952         elif data_info.major_version == 3:
    953             validation_start = data_info.validation_start
    954             train_stop = validation_start
    955             train_kwargs = {"start_samples": 480_000, "stop_samples": train_stop}
    956             validation_kwargs = {"start_samples": validation_start}
    957         elif data_info.major_version == 4:
    958             validation_start = data_info.validation_start
    959             train_stop = validation_start
    960             train_kwargs = {"stop_samples": train_stop}
    961             # Proteus doesn't have silence to get a clean split. Bite the bullet.
    962             print(
    963                 "Using Proteus files:\n"
    964                 " * There isn't a silent point to split the validation set, so some of "
    965                 "your gear's response from the train set will leak into the start of "
    966                 "the validation set and impact validation accuracy (Bypassing data "
    967                 "quality check)\n"
    968                 " * Since the validation set is different, the ESRs reported for this "
    969                 "model aren't comparable to those from the other 'NAM' training files."
    970             )
    971             validation_kwargs = {
    972                 "start_samples": validation_start,
    973                 "require_input_pre_silence": False,
    974             }
    975         else:
    976             raise NotImplementedError(f"kwargs for input version {input_version}")
    977         return train_kwargs, validation_kwargs
    978 
    979     data_info = {
    980         1: _V1_DATA_INFO,
    981         2: _V2_DATA_INFO,
    982         3: _V3_DATA_INFO,
    983         4: _V4_DATA_INFO,
    984     }[input_version.major]
    985     train_kwargs, validation_kwargs = get_split_kwargs(data_info)
    986     data_config = {
    987         "train": {"ny": ny, **train_kwargs},
    988         "validation": {"ny": None, **validation_kwargs},
    989         "common": {
    990             "x_path": input_path,
    991             "y_path": output_path,
    992             "delay": latency,
    993             "allow_unequal_lengths": True,
    994         },
    995     }
    996     return data_config
    997 
    998 
    999 def _get_configs(
   1000     input_version: _Version,
   1001     input_path: str,
   1002     output_path: str,
   1003     latency: int,
   1004     epochs: int,
   1005     model_type: str,
   1006     architecture: Architecture,
   1007     ny: int,
   1008     lr: float,
   1009     lr_decay: float,
   1010     batch_size: int,
   1011     fit_mrstft: bool,
   1012 ):
   1013     data_config = _get_data_config(
   1014         input_version=input_version,
   1015         input_path=input_path,
   1016         output_path=output_path,
   1017         ny=ny,
   1018         latency=latency,
   1019     )
   1020 
   1021     if model_type == "WaveNet":
   1022         model_config = {
   1023             "net": {
   1024                 "name": "WaveNet",
   1025                 # This should do decently. If you really want a nice model, try turning up
   1026                 # "channels" in the first block and "input_size" in the second from 12 to 16.
   1027                 "config": get_wavenet_config(architecture),
   1028             },
   1029             "loss": {"val_loss": "esr"},
   1030             "optimizer": {"lr": lr},
   1031             "lr_scheduler": {
   1032                 "class": "ExponentialLR",
   1033                 "kwargs": {"gamma": 1.0 - lr_decay},
   1034             },
   1035         }
   1036     else:
   1037         model_config = {
   1038             "net": {
   1039                 "name": "LSTM",
   1040                 "config": get_lstm_config(architecture),
   1041             },
   1042             "loss": {
   1043                 "val_loss": "mse",
   1044                 "mask_first": 4096,
   1045                 "pre_emph_weight": 1.0,
   1046                 "pre_emph_coef": 0.85,
   1047             },
   1048             "optimizer": {"lr": 0.01},
   1049             "lr_scheduler": {"class": "ExponentialLR", "kwargs": {"gamma": 0.995}},
   1050         }
   1051     if fit_mrstft:
   1052         model_config["loss"]["pre_emph_mrstft_weight"] = _CAB_MRSTFT_PRE_EMPH_WEIGHT
   1053         model_config["loss"]["pre_emph_mrstft_coef"] = _CAB_MRSTFT_PRE_EMPH_COEF
   1054 
   1055     if _torch.cuda.is_available():
   1056         device_config = {"accelerator": "gpu", "devices": 1}
   1057     elif _torch.backends.mps.is_available():
   1058         device_config = {"accelerator": "mps", "devices": 1}
   1059     else:
   1060         print("WARNING: No GPU was found. Training will be very slow!")
   1061         device_config = {}
   1062     learning_config = {
   1063         "train_dataloader": {
   1064             "batch_size": batch_size,
   1065             "shuffle": True,
   1066             "pin_memory": True,
   1067             "drop_last": True,
   1068             "num_workers": 0,
   1069         },
   1070         "val_dataloader": {},
   1071         "trainer": {"max_epochs": epochs, **device_config},
   1072     }
   1073     return data_config, model_config, learning_config
   1074 
   1075 
   1076 def _get_dataloaders(
   1077     data_config: _Dict, learning_config: _Dict, model: _LightningModule
   1078 ) -> _Tuple[_DataLoader, _DataLoader]:
   1079     data_config, learning_config = [
   1080         _deepcopy(c) for c in (data_config, learning_config)
   1081     ]
   1082     data_config["common"]["nx"] = model.net.receptive_field
   1083     dataset_train = _init_dataset(data_config, _Split.TRAIN)
   1084     dataset_validation = _init_dataset(data_config, _Split.VALIDATION)
   1085     train_dataloader = _DataLoader(dataset_train, **learning_config["train_dataloader"])
   1086     val_dataloader = _DataLoader(
   1087         dataset_validation, **learning_config["val_dataloader"]
   1088     )
   1089     return train_dataloader, val_dataloader
   1090 
   1091 
   1092 def _esr(pred: _torch.Tensor, target: _torch.Tensor) -> float:
   1093     return (
   1094         _torch.mean(_torch.square(pred - target)).item()
   1095         / _torch.mean(_torch.square(target)).item()
   1096     )
   1097 
   1098 
   1099 def _plot(
   1100     model,
   1101     ds,
   1102     window_start: _Optional[int] = None,
   1103     window_end: _Optional[int] = None,
   1104     filepath: _Optional[str] = None,
   1105     silent: bool = False,
   1106 ) -> float:
   1107     """
   1108     :return: The ESR
   1109     """
   1110     print("Plotting a comparison of your model with the target output...")
   1111     with _torch.no_grad():
   1112         tx = len(ds.x) / 48_000
   1113         print(f"Run (t={tx:.2f} sec)")
   1114         t0 = _time()
   1115         output = model(ds.x).flatten().cpu().numpy()
   1116         t1 = _time()
   1117         print(f"Took {t1 - t0:.2f} sec ({tx / (t1 - t0):.2f}x)")
   1118 
   1119     esr = _esr(_torch.Tensor(output), ds.y)
   1120     # Trying my best to put numbers to it...
   1121     if esr < 0.01:
   1122         esr_comment = "Great!"
   1123     elif esr < 0.035:
   1124         esr_comment = "Not bad!"
   1125     elif esr < 0.1:
   1126         esr_comment = "...This *might* sound ok!"
   1127     elif esr < 0.3:
   1128         esr_comment = "...This probably won't sound great :("
   1129     else:
   1130         esr_comment = "...Something seems to have gone wrong."
   1131     print(f"Error-signal ratio = {esr:.4g}")
   1132     print(esr_comment)
   1133 
   1134     _plt.figure(figsize=(16, 5))
   1135     _plt.plot(output[window_start:window_end], label="Prediction")
   1136     _plt.plot(ds.y[window_start:window_end], linestyle="--", label="Target")
   1137     _plt.title(f"ESR={esr:.4g}")
   1138     _plt.legend()
   1139     if filepath is not None:
   1140         _plt.savefig(filepath + ".png")
   1141     if not silent:
   1142         _plt.show()
   1143     return esr
   1144 
   1145 
   1146 def _print_nasty_checks_warning():
   1147     """
   1148     "ffs" -Dom
   1149     """
   1150     print(
   1151         "\n"
   1152         "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX\n"
   1153         "X                                                                          X\n"
   1154         "X                                WARNING:                                  X\n"
   1155         "X                                                                          X\n"
   1156         "X       You are ignoring the checks! Your model might turn out bad!        X\n"
   1157         "X                                                                          X\n"
   1158         "X                              I warned you!                               X\n"
   1159         "X                                                                          X\n"
   1160         "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX\n"
   1161     )
   1162 
   1163 
   1164 def _nasty_checks_modal():
   1165     msg = "You are ignoring the checks!\nYour model might turn out bad!"
   1166 
   1167     root = _tk.Tk()
   1168     root.withdraw()  # hide the root window
   1169     modal = _tk.Toplevel(root)
   1170     modal.geometry("300x100")
   1171     modal.title("Warning!")
   1172     label = _tk.Label(modal, text=msg)
   1173     label.pack(pady=10)
   1174     ok_button = _tk.Button(
   1175         modal,
   1176         text="I can only blame myself!",
   1177         command=lambda: [modal.destroy(), root.quit()],
   1178     )
   1179     ok_button.pack()
   1180     modal.grab_set()  # disable interaction with root window while modal is open
   1181     modal.mainloop()
   1182 
   1183 
class _ValidationStopping(_pl.callbacks.EarlyStopping):
    """
    Callback to indicate to stop training if the validation metric is good enough,
    without the other conditions that EarlyStopping usually forces like patience.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Infinite patience: never stop for lack of improvement — only the
        # parent's `stopping_threshold` can end training.
        self.patience = _np.inf
   1193 
   1194 
class _ModelCheckpoint(_pl.callbacks.model_checkpoint.ModelCheckpoint):
    """
    Extension to model checkpoint to save a .nam file as well as the .ckpt file.
    """

    def __init__(
        self,
        *args,
        user_metadata: _Optional[_UserMetadata] = None,
        settings_metadata: _Optional[_metadata.Settings] = None,
        data_metadata: _Optional[_metadata.Data] = None,
        **kwargs,
    ):
        """
        :param user_metadata: User-provided metadata to embed in each exported
            .nam file.
        :param settings_metadata: Trainer-settings metadata; training metadata
            is only embedded when `data_metadata` is also provided.
        :param data_metadata: Data/latency-check metadata; training metadata is
            only embedded when `settings_metadata` is also provided.
        """
        super().__init__(*args, **kwargs)
        self._user_metadata = user_metadata
        self._settings_metadata = settings_metadata
        self._data_metadata = data_metadata

    # Extension for the exported model file, taken from Exportable.
    _NAM_FILE_EXTENSION = _Exportable.FILE_EXTENSION

    @classmethod
    def _get_nam_filepath(cls, filepath: str) -> _Path:
        """
        Given a .ckpt filepath, figure out a .nam for it.

        :raises ValueError: If `filepath` doesn't end with the checkpoint
            extension (`cls.FILE_EXTENSION`, inherited from the parent).
        """
        if not filepath.endswith(cls.FILE_EXTENSION):
            raise ValueError(
                f"Checkpoint filepath {filepath} doesn't end in expected extension "
                f"{cls.FILE_EXTENSION}"
            )
        return _Path(filepath[: -len(cls.FILE_EXTENSION)] + cls._NAM_FILE_EXTENSION)

    @property
    def _include_other_metadata(self) -> bool:
        # Training metadata is only embedded when both pieces are available.
        return self._settings_metadata is not None and self._data_metadata is not None

    def _save_checkpoint(self, trainer: _pl.Trainer, filepath: str):
        """
        Save the Lightning .ckpt, then export the sibling .nam next to it.
        """
        # Save the .ckpt:
        super()._save_checkpoint(trainer, filepath)
        # Save the .nam:
        nam_filepath = self._get_nam_filepath(filepath)
        pl_model: _LightningModule = trainer.model
        nam_model = pl_model.net
        outdir = nam_filepath.parent
        # HACK: Assume the extension
        basename = nam_filepath.name[: -len(self._NAM_FILE_EXTENSION)]
        other_metadata = (
            None
            if not self._include_other_metadata
            else {
                _metadata.TRAINING_KEY: _metadata.TrainingMetadata(
                    settings=self._settings_metadata,
                    data=self._data_metadata,
                    validation_esr=None,  # TODO how to get this?
                ).model_dump()
            }
        )
        nam_model.export(
            outdir,
            basename=basename,
            user_metadata=self._user_metadata,
            other_metadata=other_metadata,
        )

    def _remove_checkpoint(self, trainer: _pl.Trainer, filepath: str) -> None:
        """
        Remove the .ckpt and, if it exists, the sibling exported .nam file.
        """
        super()._remove_checkpoint(trainer, filepath)
        nam_path = self._get_nam_filepath(filepath)
        if nam_path.exists():
            nam_path.unlink()
   1264 
   1265 
   1266 def get_callbacks(
   1267     threshold_esr: _Optional[float],
   1268     user_metadata: _Optional[_UserMetadata] = None,
   1269     settings_metadata: _Optional[_metadata.Settings] = None,
   1270     data_metadata: _Optional[_metadata.Data] = None,
   1271 ):
   1272     callbacks = [
   1273         _ModelCheckpoint(
   1274             filename="checkpoint_best_{epoch:04d}_{step}_{ESR:.4g}_{MSE:.3e}",
   1275             save_top_k=3,
   1276             monitor="val_loss",
   1277             every_n_epochs=1,
   1278             user_metadata=user_metadata,
   1279             settings_metadata=settings_metadata,
   1280             data_metadata=data_metadata,
   1281         ),
   1282         _ModelCheckpoint(
   1283             filename="checkpoint_last_{epoch:04d}_{step}",
   1284             every_n_epochs=1,
   1285             user_metadata=user_metadata,
   1286             settings_metadata=settings_metadata,
   1287             data_metadata=data_metadata,
   1288         ),
   1289     ]
   1290     if threshold_esr is not None:
   1291         callbacks.append(
   1292             _ValidationStopping(monitor="ESR", stopping_threshold=threshold_esr)
   1293         )
   1294     return callbacks
   1295 
   1296 
class TrainOutput(_NamedTuple):
    """
    Result of a simplified-trainer run.

    :param model: The trained model (None when training was aborted, e.g. on
        failed data checks).
    :param metadata: The metadata summarizing training with the simplified
        trainer.
    """

    model: _Optional[_LightningModule]
    metadata: _metadata.TrainingMetadata
   1306 
   1307 
   1308 def _get_final_latency(latency_analysis: _metadata.Latency) -> int:
   1309     if latency_analysis.manual is not None:
   1310         latency = latency_analysis.manual
   1311         print(f"Latency provided as {latency_analysis.manual}; override calibration")
   1312     else:
   1313         latency = latency_analysis.calibration.recommended
   1314         print(f"Set latency to recommended {latency_analysis.calibration.recommended}")
   1315     return latency
   1316 
   1317 
   1318 def train(
   1319     input_path: str,
   1320     output_path: str,
   1321     train_path: str,
   1322     input_version: _Optional[_Version] = None,  # Deprecate?
   1323     epochs=100,
   1324     delay: _Optional[int] = None,
   1325     latency: _Optional[int] = None,
   1326     model_type: str = "WaveNet",
   1327     architecture: _Union[Architecture, str] = Architecture.STANDARD,
   1328     batch_size: int = 16,
   1329     ny: int = _NY_DEFAULT,
   1330     lr=0.004,
   1331     lr_decay=0.007,
   1332     seed: _Optional[int] = 0,
   1333     save_plot: bool = False,
   1334     silent: bool = False,
   1335     modelname: str = "model",
   1336     ignore_checks: bool = False,
   1337     local: bool = False,
   1338     fit_mrstft: bool = True,
   1339     threshold_esr: _Optional[bool] = None,
   1340     user_metadata: _Optional[_UserMetadata] = None,
   1341     fast_dev_run: _Union[bool, int] = False,
   1342 ) -> _Optional[TrainOutput]:
   1343     """
   1344     :param lr_decay: =1-gamma for Exponential learning rate decay.
   1345     :param threshold_esr: Stop training if ESR is better than this. Ignore if `None`.
   1346     :param fast_dev_run: One-step training, used for tests.
   1347     """
   1348 
   1349     def parse_user_latency(
   1350         delay: _Optional[int], latency: _Optional[int]
   1351     ) -> _Optional[int]:
   1352         if delay is not None:
   1353             if latency is not None:
   1354                 raise ValueError("Both delay and latency are provided; use latency!")
   1355             print("WARNING: use of `delay` is deprecated; use `latency` instead")
   1356             return delay
   1357         return latency
   1358 
   1359     if seed is not None:
   1360         _torch.manual_seed(seed)
   1361 
   1362     # HACK: We need to check the sample rates and lengths of the audio here or else
   1363     # It will look like a bad self-ESR (Issue 473)
   1364     # Can move this into the "v3 checks" once the others are deprecated.
   1365     # And honestly remake this whole thing as a data processing pipeline.
   1366     sample_rate_validation = _check_audio_sample_rates(input_path, output_path)
   1367     if not sample_rate_validation.passed:
   1368         raise ValueError(
   1369             "Different sample rates detected for input "
   1370             f"({sample_rate_validation.input}) and output "
   1371             f"({sample_rate_validation.output}) audio!"
   1372         )
   1373     length_validation = _check_audio_lengths(input_path, output_path)
   1374     if not length_validation.passed:
   1375         raise ValueError(
   1376             "Your recording differs in length from the input file by "
   1377             f"{length_validation.delta_seconds:.2f} seconds. Check your reamp "
   1378             "in your DAW and ensure that they are the same length."
   1379         )
   1380 
   1381     if input_version is None:
   1382         input_version, strong_match = _detect_input_version(input_path)
   1383 
   1384     user_latency = parse_user_latency(delay, latency)
   1385     latency_analysis = _analyze_latency(
   1386         user_latency, input_version, input_path, output_path, silent=silent
   1387     )
   1388     final_latency = _get_final_latency(latency_analysis)
   1389 
   1390     data_check_output = _check_data(
   1391         input_path, output_path, input_version, final_latency, silent
   1392     )
   1393     if data_check_output is not None:
   1394         if data_check_output.passed:
   1395             print("-Checks passed")
   1396         else:
   1397             print("Failed checks!")
   1398             if ignore_checks:
   1399                 if local and not silent:
   1400                     _nasty_checks_modal()
   1401                 else:
   1402                     _print_nasty_checks_warning()
   1403             elif not local:  # And not ignore_checks
   1404                 print(
   1405                     "(To disable this check, run AT YOUR OWN RISK with "
   1406                     "`ignore_checks=True`.)"
   1407                 )
   1408             if not ignore_checks:
   1409                 print("Exiting core training...")
   1410                 return TrainOutput(
   1411                     model=None,
   1412                     metadata=_metadata.TrainingMetadata(
   1413                         settings=_metadata.Settings(ignore_checks=ignore_checks),
   1414                         data=_metadata.Data(
   1415                             latency=latency_analysis, checks=data_check_output
   1416                         ),
   1417                         validation_esr=None,
   1418                     ),
   1419                 )
   1420 
   1421     data_config, model_config, learning_config = _get_configs(
   1422         input_version,
   1423         input_path,
   1424         output_path,
   1425         final_latency,
   1426         epochs,
   1427         model_type,
   1428         Architecture(architecture),
   1429         ny,
   1430         lr,
   1431         lr_decay,
   1432         batch_size,
   1433         fit_mrstft,
   1434     )
   1435     assert (
   1436         "fast_dev_run" not in learning_config
   1437     ), "fast_dev_run is set as a kwarg to train()"
   1438 
   1439     print("Starting training. It's time to kick ass and chew bubblegum!")
   1440     # Issue:
   1441     # * Model needs sample rate from data, but data set needs nx from model.
   1442     # * Model is re-instantiated after training anyways.
   1443     # (Hacky) solution: set sample rate in model from dataloader after second
   1444     # instantiation from final checkpoint.
   1445     model = _LightningModule.init_from_config(model_config)
   1446     train_dataloader, val_dataloader = _get_dataloaders(
   1447         data_config, learning_config, model
   1448     )
   1449     if train_dataloader.dataset.sample_rate != val_dataloader.dataset.sample_rate:
   1450         raise RuntimeError(
   1451             "Train and validation data loaders have different data set sample rates: "
   1452             f"{train_dataloader.dataset.sample_rate}, "
   1453             f"{val_dataloader.dataset.sample_rate}"
   1454         )
   1455     sample_rate = train_dataloader.dataset.sample_rate
   1456     model.net.sample_rate = sample_rate
   1457 
   1458     # Put together the metadata that's needed in checkpoints:
   1459     settings_metadata = _metadata.Settings(ignore_checks=ignore_checks)
   1460     data_metadata = _metadata.Data(latency=latency_analysis, checks=data_check_output)
   1461 
   1462     trainer = _pl.Trainer(
   1463         callbacks=get_callbacks(
   1464             threshold_esr,
   1465             user_metadata=user_metadata,
   1466             settings_metadata=settings_metadata,
   1467             data_metadata=data_metadata,
   1468         ),
   1469         default_root_dir=train_path,
   1470         fast_dev_run=fast_dev_run,
   1471         **learning_config["trainer"],
   1472     )
   1473     # Suppress the PossibleUserWarning about num_workers (Issue 345)
   1474     with _filter_warnings("ignore", category=_PossibleUserWarning):
   1475         trainer.fit(model, train_dataloader, val_dataloader)
   1476 
   1477     # Go to best checkpoint
   1478     best_checkpoint = trainer.checkpoint_callback.best_model_path
   1479     if best_checkpoint != "":
   1480         model = _LightningModule.load_from_checkpoint(
   1481             trainer.checkpoint_callback.best_model_path,
   1482             **_LightningModule.parse_config(model_config),
   1483         )
   1484     model.cpu()
   1485     model.eval()
   1486     model.net.sample_rate = sample_rate  # Hack, part 2
   1487 
   1488     def window_kwargs(version: _Version):
   1489         if version.major == 1:
   1490             return dict(
   1491                 window_start=100_000,  # Start of the plotting window, in samples
   1492                 window_end=101_000,  # End of the plotting window, in samples
   1493             )
   1494         elif version.major == 2:
   1495             # Same validation set even though it's a different spot in the reamp file
   1496             return dict(
   1497                 window_start=100_000,  # Start of the plotting window, in samples
   1498                 window_end=101_000,  # End of the plotting window, in samples
   1499             )
   1500         # Fallback:
   1501         return dict(
   1502             window_start=100_000,  # Start of the plotting window, in samples
   1503             window_end=101_000,  # End of the plotting window, in samples
   1504         )
   1505 
   1506     validation_esr = _plot(
   1507         model,
   1508         val_dataloader.dataset,
   1509         filepath=train_path + "/" + modelname if save_plot else None,
   1510         silent=silent,
   1511         **window_kwargs(input_version),
   1512     )
   1513     return TrainOutput(
   1514         model=model,
   1515         metadata=_metadata.TrainingMetadata(
   1516             settings=settings_metadata,
   1517             data=data_metadata,
   1518             validation_esr=validation_esr,
   1519         ),
   1520     )
   1521 
   1522 
class DataInputValidation(_BaseModel):
    """
    Result of attempting to validate the input audio file.

    :param passed: Whether the input file's version could be detected.
    """

    passed: bool
   1525 
   1526 
   1527 def validate_input(input_path) -> DataInputValidation:
   1528     """
   1529     :return: Could it be validated?
   1530     """
   1531     try:
   1532         _detect_input_version(input_path)
   1533         # succeeded...
   1534         return DataInputValidation(passed=True)
   1535     except _InputValidationError as e:
   1536         print(f"Input validation failed!\n\n{e}")
   1537         return DataInputValidation(passed=False)
   1538 
   1539 
class _PyTorchDataSplitValidation(_BaseModel):
    """
    Result of initializing the PyTorch dataset for a single split.

    :param passed: Whether the dataset was initialized without a `DataError`.
    :param msg: On exception, catch and assign. Otherwise None
    """

    passed: bool
    msg: _Optional[str]
   1547 
   1548 
class _PyTorchDataValidation(_BaseModel):
    """
    Per-split results of attempting to build the PyTorch datasets.

    :param passed: Whether both splits initialized successfully.
    :param train: Result for the training split.
    :param validation: Result for the validation split.
    """

    passed: bool
    train: _PyTorchDataSplitValidation  # cf Split.TRAIN
    validation: _PyTorchDataSplitValidation  # Split.VALIDATION
   1553 
   1554 
class _SampleRateValidation(_BaseModel):
    """
    Result of comparing the sample rates of the input and output files.

    :param passed: Whether the two sample rates match.
    :param input: Sample rate of the input file.
    :param output: Sample rate of the output file.
    """

    passed: bool
    input: int
    output: int
   1559 
   1560 
class _LengthValidation(_BaseModel):
    """
    Result of comparing the durations of the input and output files.

    :param passed: Whether the length difference is within tolerance.
    :param delta_seconds: Output duration minus input duration, in seconds.
    """

    passed: bool
    delta_seconds: float
   1564 
   1565 
class DataValidationOutput(_BaseModel):
    """
    Full report produced by `validate_data()`.

    :param passed: Whether every check passed.
    :param passed_critical: Whether all non-ignorable checks passed.
    :param sample_rate: Input/output sample rate comparison.
    :param length: Input/output duration comparison.
    :param input_version: String form of the detected input file version.
    :param latency: Latency analysis results.
    :param checks: Version-specific data checks.
    :param pytorch: Per-split PyTorch dataset initialization results.
    """

    passed: bool
    passed_critical: bool
    sample_rate: _SampleRateValidation
    length: _LengthValidation
    input_version: str
    latency: _metadata.Latency
    checks: _metadata.DataChecks
    pytorch: _PyTorchDataValidation
   1575 
   1576 
   1577 def _check_audio_sample_rates(
   1578     input_path: _Path,
   1579     output_path: _Path,
   1580 ) -> _SampleRateValidation:
   1581     _, x_info = _wav_to_np(input_path, info=True)
   1582     _, y_info = _wav_to_np(output_path, info=True)
   1583 
   1584     return _SampleRateValidation(
   1585         passed=x_info.rate == y_info.rate,
   1586         input=x_info.rate,
   1587         output=y_info.rate,
   1588     )
   1589 
   1590 
   1591 def _check_audio_lengths(
   1592     input_path: _Path,
   1593     output_path: _Path,
   1594     max_under_seconds: _Optional[float] = 0.0,
   1595     max_over_seconds: _Optional[float] = 1.0,
   1596 ) -> _LengthValidation:
   1597     """
   1598     Check that the input and output have the right lengths compared to each
   1599     other.
   1600 
   1601     :param input_path: Path to input audio
   1602     :param output_path: Path to output audio
   1603     :param max_under_seconds: If not None, the maximum amount by which the
   1604         output can be shorter than the input. Should be non-negative i.e. a
   1605         value of 1.0 means that the output can't be more than a second shorter
   1606         than the input.
   1607     :param max_over_seconds: If not None, the maximum amount by which the
   1608         output can be longer than the input. Should be non-negative i.e. a
   1609         value of 1.0 means that the output can't be more than a second longer
   1610         than the input.
   1611     """
   1612     x, x_info = _wav_to_np(input_path, info=True)
   1613     y, y_info = _wav_to_np(output_path, info=True)
   1614 
   1615     length_input = len(x) / x_info.rate
   1616     length_output = len(y) / y_info.rate
   1617     delta_seconds = length_output - length_input
   1618 
   1619     passed = True
   1620     if max_under_seconds is not None and delta_seconds < -max_under_seconds:
   1621         passed = False
   1622     if max_over_seconds is not None and delta_seconds > max_over_seconds:
   1623         passed = False
   1624 
   1625     return _LengthValidation(passed=passed, delta_seconds=delta_seconds)
   1626 
   1627 
   1628 def validate_data(
   1629     input_path: _Path,
   1630     output_path: _Path,
   1631     user_latency: _Optional[int],
   1632     num_output_samples_per_datum: int = _NY_DEFAULT,
   1633 ):
   1634     """
   1635     Just do the checks to make sure that the data are ok.
   1636 
   1637     * Version identification
   1638     * Latency calibration
   1639     * Other checks
   1640     """
   1641     print("Validating data...")
   1642     passed = True  # Until proven otherwise
   1643     passed_critical = True  # These can't be ignored
   1644 
   1645     sample_rate_validation = _check_audio_sample_rates(input_path, output_path)
   1646     passed = passed and sample_rate_validation.passed
   1647     passed_critical = passed_critical and sample_rate_validation.passed
   1648 
   1649     length_validation = _check_audio_lengths(input_path, output_path)
   1650     passed = passed and length_validation.passed
   1651     passed_critical = passed_critical and length_validation.passed
   1652 
   1653     # Data version ID
   1654     input_version, strong_match = _detect_input_version(input_path)
   1655 
   1656     # Latency analysis
   1657     latency_analysis = _analyze_latency(
   1658         user_latency, input_version, input_path, output_path, silent=True
   1659     )
   1660     if latency_analysis.manual is None and any(
   1661         val for val in latency_analysis.calibration.warnings.model_dump().values()
   1662     ):
   1663         passed = False
   1664     final_latency = _get_final_latency(latency_analysis)
   1665 
   1666     # Other data checks based on input file version
   1667     data_checks = _check_data(
   1668         input_path,
   1669         output_path,
   1670         input_version,
   1671         latency_analysis.calibration.recommended,
   1672         silent=True,
   1673     )
   1674     passed = passed and data_checks.passed
   1675 
   1676     # Finally, try to make the PyTorch Dataset objects and note any failures:
   1677     data_config = _get_data_config(
   1678         input_version=input_version,
   1679         input_path=input_path,
   1680         output_path=output_path,
   1681         ny=num_output_samples_per_datum,
   1682         latency=final_latency,
   1683     )
   1684     # HACK this should depend on the model that's going to be used, but I think it will
   1685     # be unlikely to make a difference. Still, would be nice to fix.
   1686     data_config["common"]["nx"] = 4096
   1687 
   1688     pytorch_data_split_validation_dict: _Dict[str, _PyTorchDataSplitValidation] = {}
   1689     for split in _Split:
   1690         try:
   1691             _init_dataset(data_config, split)
   1692             pytorch_data_split_validation_dict[split.value] = (
   1693                 _PyTorchDataSplitValidation(passed=True, msg=None)
   1694             )
   1695         except _DataError as e:
   1696             pytorch_data_split_validation_dict[split.value] = (
   1697                 _PyTorchDataSplitValidation(passed=False, msg=str(e))
   1698             )
   1699     pytorch_data_validation = _PyTorchDataValidation(
   1700         passed=all(v.passed for v in pytorch_data_split_validation_dict.values()),
   1701         **pytorch_data_split_validation_dict,
   1702     )
   1703     passed = passed and pytorch_data_validation.passed
   1704     passed_critical = passed_critical and pytorch_data_validation.passed
   1705 
   1706     return DataValidationOutput(
   1707         passed=passed,
   1708         passed_critical=passed_critical,
   1709         sample_rate=sample_rate_validation,
   1710         length=length_validation,
   1711         input_version=str(input_version),
   1712         latency=latency_analysis,
   1713         checks=data_checks,
   1714         pytorch=pytorch_data_validation,
   1715     )