core.py (60431B)
1 # File: core.py 2 # Created Date: Tuesday December 20th 2022 3 # Author: Steven Atkinson (steven@atkinson.mn) 4 5 """ 6 The core of the "simplified trainer" 7 8 Used by the GUI and Colab trainers. 9 """ 10 11 import hashlib as _hashlib 12 import tkinter as _tk 13 from copy import deepcopy as _deepcopy 14 from enum import Enum as _Enum 15 from functools import partial as _partial 16 from pathlib import Path as _Path 17 from time import time as _time 18 from typing import ( 19 Dict as _Dict, 20 NamedTuple as _NamedTuple, 21 Optional as _Optional, 22 Sequence as _Sequence, 23 Tuple as _Tuple, 24 Union as _Union, 25 ) 26 27 import matplotlib.pyplot as _plt 28 import numpy as _np 29 import pytorch_lightning as _pl 30 import torch as _torch 31 from pydantic import BaseModel as _BaseModel 32 from pytorch_lightning.utilities.warnings import ( 33 PossibleUserWarning as _PossibleUserWarning, 34 ) 35 from torch.utils.data import DataLoader as _DataLoader 36 37 from ..data import ( 38 DataError as _DataError, 39 Split as _Split, 40 init_dataset as _init_dataset, 41 wav_to_np as _wav_to_np, 42 wav_to_tensor as _wav_to_tensor, 43 ) 44 from ..models.exportable import Exportable as _Exportable 45 from ..models.losses import esr as _ESR 46 from ..models.metadata import UserMetadata as _UserMetadata 47 from ..util import filter_warnings as _filter_warnings 48 from ._version import PROTEUS_VERSION as _PROTEUS_VERSION, Version as _Version 49 from .lightning_module import LightningModule as _LightningModule 50 from . import metadata as _metadata 51 52 # Training using the simplified trainers in NAM is done at 48k. 53 STANDARD_SAMPLE_RATE = 48_000.0 54 # Default number of output samples per datum. 55 _NY_DEFAULT = 8192 56 57 58 class Architecture(_Enum): 59 STANDARD = "standard" 60 LITE = "lite" 61 FEATHER = "feather" 62 NANO = "nano" 63 64 65 class _InputValidationError(ValueError): 66 pass 67 68 69 def _detect_input_version(input_path) -> _Tuple[_Version, bool]: 70 """ 71 Check to see if the input matches any of the known inputs 72 73 :return: version, strong match 74 """ 75 76 def detect_strong(input_path) -> _Optional[_Version]: 77 def assign_hash(path): 78 # Use this to create hashes for new files 79 md5 = _hashlib.md5() 80 buffer_size = 65536 81 with open(path, "rb") as f: 82 while True: 83 data = f.read(buffer_size) 84 if not data: 85 break 86 md5.update(data) 87 file_hash = md5.hexdigest() 88 return file_hash 89 90 file_hash = assign_hash(input_path) 91 print(f"Strong hash: {file_hash}") 92 93 version = { 94 "4d54a958861bf720ec4637f43d44a7ef": _Version(1, 0, 0), 95 "7c3b6119c74465f79d96c761a0e27370": _Version(1, 1, 1), 96 "ede3b9d82135ce10c7ace3bb27469422": _Version(2, 0, 0), 97 "36cd1af62985c2fac3e654333e36431e": _Version(3, 0, 0), 98 "80e224bd5622fd6153ff1fd9f34cb3bd": _PROTEUS_VERSION, 99 }.get(file_hash) 100 if version is None: 101 print( 102 f"Provided input file {input_path} does not strong-match any known " 103 "standard input files." 104 ) 105 return version 106 107 def detect_weak(input_path) -> _Optional[_Version]: 108 def assign_hash(path): 109 Hash = _Optional[str] 110 Hashes = _Tuple[Hash, Hash] 111 112 def _hash(x: _np.ndarray) -> str: 113 return _hashlib.md5(x).hexdigest() 114 115 def assign_hashes_v1(path) -> Hashes: 116 # Use this to create recognized hashes for new files 117 x, info = _wav_to_np(path, info=True) 118 rate = info.rate 119 if rate != _V1_DATA_INFO.rate: 120 return None, None 121 # Times of intervals, in seconds 122 t_blips = _V1_DATA_INFO.t_blips 123 t_sweep = 3 * rate 124 t_white = 3 * rate 125 t_validation = _V1_DATA_INFO.t_validate 126 # v1 and v2 start with 1 blips, sine sweeps, and white noise 127 start_hash = _hash(x[: t_blips + t_sweep + t_white]) 128 # v1 ends with validation signal 129 end_hash = _hash(x[-t_validation:]) 130 return start_hash, end_hash 131 132 def assign_hashes_v2(path) -> Hashes: 133 # Use this to create recognized hashes for new files 134 x, info = _wav_to_np(path, info=True) 135 rate = info.rate 136 if rate != _V2_DATA_INFO.rate: 137 return None, None 138 # Times of intervals, in seconds 139 t_blips = _V2_DATA_INFO.t_blips 140 t_sweep = 3 * rate 141 t_white = 3 * rate 142 t_validation = _V1_DATA_INFO.t_validate 143 # v1 and v2 start with 1 blips, sine sweeps, and white noise 144 start_hash = _hash(x[: (t_blips + t_sweep + t_white)]) 145 # v2 ends with 2x validation & blips 146 end_hash = _hash(x[-(2 * t_validation + t_blips) :]) 147 return start_hash, end_hash 148 149 def assign_hashes_v3(path) -> Hashes: 150 # Use this to create recognized hashes for new files 151 x, info = _wav_to_np(path, info=True) 152 rate = info.rate 153 if rate != _V3_DATA_INFO.rate: 154 return None, None 155 # Times of intervals, in seconds 156 # See below. 157 end_of_start_interval = 17 * rate # Start at 0 158 start_of_end_interval = -9 * rate 159 start_hash = _hash(x[:end_of_start_interval]) 160 end_hash = _hash(x[start_of_end_interval:]) 161 return start_hash, end_hash 162 163 def assign_hash_v4(path) -> Hash: 164 # Use this to create recognized hashes for new files 165 x, info = _wav_to_np(path, info=True) 166 rate = info.rate 167 if rate != _V4_DATA_INFO.rate: 168 return None 169 # I don't care about anything in the file except the starting blip and 170 start_hash = _hash(x[: int(1 * _V4_DATA_INFO.rate)]) 171 return start_hash 172 173 start_hash_v1, end_hash_v1 = assign_hashes_v1(path) 174 start_hash_v2, end_hash_v2 = assign_hashes_v2(path) 175 start_hash_v3, end_hash_v3 = assign_hashes_v3(path) 176 hash_v4 = assign_hash_v4(path) 177 return ( 178 start_hash_v1, 179 end_hash_v1, 180 start_hash_v2, 181 end_hash_v2, 182 start_hash_v3, 183 end_hash_v3, 184 hash_v4, 185 ) 186 187 ( 188 start_hash_v1, 189 end_hash_v1, 190 start_hash_v2, 191 end_hash_v2, 192 start_hash_v3, 193 end_hash_v3, 194 hash_v4, 195 ) = assign_hash(input_path) 196 print( 197 "Weak hashes:\n" 198 f" Start (v1) : {start_hash_v1}\n" 199 f" End (v1) : {end_hash_v1}\n" 200 f" Start (v2) : {start_hash_v2}\n" 201 f" End (v2) : {end_hash_v2}\n" 202 f" Start (v3) : {start_hash_v3}\n" 203 f" End (v3) : {end_hash_v3}\n" 204 f" Proteus : {hash_v4}\n" 205 ) 206 207 # Check for matches, starting with most recent. Proteus last since its match is 208 # the most permissive. 209 version = { 210 ( 211 "dadb5d62f6c3973a59bf01439799809b", 212 "8458126969a3f9d8e19a53554eb1fd52", 213 ): _Version(3, 0, 0) 214 }.get((start_hash_v3, end_hash_v3)) 215 if version is not None: 216 return version 217 version = { 218 ( 219 "1c4d94fbcb47e4d820bef611c1d4ae65", 220 "28694e7bf9ab3f8ae6ef86e9545d4663", 221 ): _Version(2, 0, 0) 222 }.get((start_hash_v2, end_hash_v2)) 223 if version is not None: 224 return version 225 version = { 226 ( 227 "bb4e140c9299bae67560d280917eb52b", 228 "9b2468fcb6e9460a399fc5f64389d353", 229 ): _Version( 230 1, 0, 0 231 ), # FIXME! 232 ( 233 "9f20c6b5f7fef68dd88307625a573a14", 234 "8458126969a3f9d8e19a53554eb1fd52", 235 ): _Version(1, 1, 1), 236 }.get((start_hash_v1, end_hash_v1)) 237 if version is not None: 238 return version 239 version = {"46151c8030798081acc00a725325a07d": _PROTEUS_VERSION}.get(hash_v4) 240 return version 241 242 version = detect_strong(input_path) 243 if version is not None: 244 strong_match = True 245 return version, strong_match 246 print("Falling back to weak-matching...") 247 version = detect_weak(input_path) 248 if version is None: 249 raise _InputValidationError( 250 f"Input file at {input_path} cannot be recognized as any known version!" 251 ) 252 strong_match = False 253 254 return version, strong_match 255 256 257 class _DataInfo(_BaseModel): 258 """ 259 :param major_version: Data major version 260 """ 261 262 major_version: int 263 rate: _Optional[float] 264 t_blips: int 265 first_blips_start: int 266 t_validate: int 267 train_start: int 268 validation_start: int 269 noise_interval: _Tuple[int, int] 270 blip_locations: _Sequence[_Sequence[int]] 271 272 273 _V1_DATA_INFO = _DataInfo( 274 major_version=1, 275 rate=STANDARD_SAMPLE_RATE, 276 t_blips=48_000, 277 first_blips_start=0, 278 t_validate=432_000, 279 train_start=0, 280 validation_start=-432_000, 281 noise_interval=(0, 6000), 282 blip_locations=((12_000, 36_000),), 283 ) 284 # V2: 285 # (0:00-0:02) Blips at 0:00.5 and 0:01.5 286 # (0:02-0:05) Chirps 287 # (0:05-0:07) Noise 288 # (0:07-2:50.5) General training data 289 # (2:50.5-2:51) Silence 290 # (2:51-3:00) Validation 1 291 # (3:00-3:09) Validation 2 292 # (3:09-3:11) Blips at 3:09.5 and 3:10.5 293 _V2_DATA_INFO = _DataInfo( 294 major_version=2, 295 rate=STANDARD_SAMPLE_RATE, 296 t_blips=96_000, 297 first_blips_start=0, 298 t_validate=432_000, 299 train_start=0, 300 validation_start=-960_000, # 96_000 + 2 * 432_000 301 noise_interval=(12_000, 18_000), 302 blip_locations=((24_000, 72_000), (-72_000, -24_000)), 303 ) 304 # V3: 305 # (0:00-0:09) Validation 1 306 # (0:09-0:10) Silence 307 # (0:10-0:12) Blips at 0:10.5 and 0:11.5 308 # (0:12-0:15) Chirps 309 # (0:15-0:17) Noise 310 # (0:17-3:00.5) General training data 311 # (3:00.5-3:01) Silence 312 # (3:01-3:10) Validation 2 313 _V3_DATA_INFO = _DataInfo( 314 major_version=3, 315 rate=STANDARD_SAMPLE_RATE, 316 t_blips=96_000, 317 first_blips_start=480_000, 318 t_validate=432_000, 319 train_start=480_000, 320 validation_start=-432_000, 321 noise_interval=(492_000, 498_000), 322 blip_locations=((504_000, 552_000),), 323 ) 324 # V4 (aka GuitarML Proteus) 325 # https://github.com/GuitarML/Releases/releases/download/v1.0.0/Proteus_Capture_Utility.zip 326 # * 44.1k 327 # * Odd length... 328 # * There's a blip on sample zero. This has to be ignored or else over-compensated 329 # latencies will come out wrong! 330 # (0:00-0:01) Blips at 0:00.0 and 0:00.5 331 # (0:01-0:09) Sine sweeps 332 # (0:09-0:17) White noise 333 # (0:17:0.20) Rising white noise (to 0:20.333 appx) 334 # (0:20-3:30.858) General training data (ends on sample 9,298,872) 335 # I'm arbitrarily assigning the last 10 seconds as validation data. 336 _V4_DATA_INFO = _DataInfo( 337 major_version=4, 338 rate=44_100.0, 339 t_blips=44_099, # Need to ignore the first blip! 340 first_blips_start=1, # Need to ignore the first blip! 341 t_validate=441_000, 342 # Blips are problematic for training because they don't have preceding silence 343 train_start=44_100, 344 validation_start=-441_000, 345 noise_interval=(6_000, 12_000), 346 blip_locations=((22_050,),), 347 ) 348 349 _DELAY_CALIBRATION_ABS_THRESHOLD = 0.0003 350 _DELAY_CALIBRATION_REL_THRESHOLD = 0.001 351 _DELAY_CALIBRATION_SAFETY_FACTOR = 1 # Might be able to make this zero... 352 353 354 def _warn_lookaheads(indices: _Sequence[int]) -> str: 355 return ( 356 f"WARNING: delays from some blips ({','.join([str(i) for i in indices])}) are " 357 "at the minimum value possible. This usually means that something is " 358 "wrong with your data. Check if trianing ends with a poor result!" 359 ) 360 361 362 def _calibrate_latency_v_all( 363 data_info: _DataInfo, 364 y, 365 abs_threshold=_DELAY_CALIBRATION_ABS_THRESHOLD, 366 rel_threshold=_DELAY_CALIBRATION_REL_THRESHOLD, 367 safety_factor=_DELAY_CALIBRATION_SAFETY_FACTOR, 368 ) -> _metadata.LatencyCalibration: 369 """ 370 Calibrate the delay in teh input-output pair based on blips. 371 This only uses the blips in the first set of blip locations! 372 373 :param y: The output audio, in complete. 374 """ 375 376 def report_any_latency_warnings( 377 delays: _Sequence[int], 378 ) -> _metadata.LatencyCalibrationWarnings: 379 # Warnings associated with any single delay: 380 381 # "Lookahead warning": if the delay is equal to the lookahead, then it's 382 # probably an error. 383 lookahead_warnings = [i for i, d in enumerate(delays, 1) if d == -lookahead] 384 matches_lookahead = len(lookahead_warnings) > 0 385 if matches_lookahead: 386 print(_warn_lookaheads(lookahead_warnings)) 387 388 # Ensemble warnings 389 390 # If they're _really_ different, then something might be wrong. 391 max_disagreement_threshold = 20 392 max_disagreement_too_high = ( 393 _np.max(delays) - _np.min(delays) >= max_disagreement_threshold 394 ) 395 if max_disagreement_too_high: 396 print( 397 "WARNING: Latencies are anomalously different from each other (more " 398 f"than {max_disagreement_threshold} samples). If this model turns out " 399 "badly, then you might need to provide the latency manually." 400 ) 401 402 return _metadata.LatencyCalibrationWarnings( 403 matches_lookahead=matches_lookahead, 404 disagreement_too_high=max_disagreement_too_high, 405 ) 406 407 lookahead = 1_000 408 lookback = 10_000 409 # Calibrate the level for the trigger: 410 y = y[data_info.first_blips_start : data_info.first_blips_start + data_info.t_blips] 411 background_level = _np.max( 412 _np.abs( 413 y[ 414 data_info.noise_interval[0] 415 - data_info.first_blips_start : data_info.noise_interval[1] 416 - data_info.first_blips_start 417 ] 418 ) 419 ) 420 trigger_threshold = max( 421 background_level + abs_threshold, 422 (1.0 + rel_threshold) * background_level, 423 ) 424 425 y_scans = [] 426 for blip_index, i_abs in enumerate(data_info.blip_locations[0], 1): 427 # Relative to start of the data 428 i_rel = i_abs - data_info.first_blips_start 429 start_looking = i_rel - lookahead 430 stop_looking = i_rel + lookback 431 y_scans.append(y[start_looking:stop_looking]) 432 y_scan_average = _np.mean(_np.stack(y_scans), axis=0) 433 triggered = _np.where(_np.abs(y_scan_average) > trigger_threshold)[0] 434 if len(triggered) == 0: 435 msg = ( 436 "No response activated the trigger in response to input spikes. " 437 "Is something wrong with the reamp?" 438 ) 439 print(msg) 440 print("SHARE THIS PLOT IF YOU ASK FOR HELP") 441 _plt.figure() 442 _plt.plot( 443 _np.arange(-lookahead, lookback), 444 y_scan_average, 445 color="C0", 446 label="Signal average", 447 ) 448 for y_scan in y_scans: 449 _plt.plot(_np.arange(-lookahead, lookback), y_scan, color="C0", alpha=0.2) 450 _plt.axvline(x=0, color="C1", linestyle="--", label="Trigger") 451 _plt.axhline(y=-trigger_threshold, color="k", linestyle="--", label="Threshold") 452 _plt.axhline(y=trigger_threshold, color="k", linestyle="--") 453 _plt.xlim((-lookahead, lookback)) 454 _plt.xlabel("Samples") 455 _plt.ylabel("Response") 456 _plt.legend() 457 _plt.title("SHARE THIS PLOT IF YOU ASK FOR HELP") 458 _plt.show() 459 raise RuntimeError(msg) 460 else: 461 j = triggered[0] 462 delay = j + start_looking - i_rel 463 464 print(f"Delay based on average is {delay}") 465 warnings = report_any_latency_warnings([delay]) 466 467 delay_post_safety_factor = delay - safety_factor 468 print( 469 f"After aplying safety factor of {safety_factor}, the final delay is " 470 f"{delay_post_safety_factor}" 471 ) 472 return _metadata.LatencyCalibration( 473 algorithm_version=1, 474 delays=[delay], 475 safety_factor=safety_factor, 476 recommended=delay_post_safety_factor, 477 warnings=warnings, 478 ) 479 480 481 _calibrate_latency_v1 = _partial(_calibrate_latency_v_all, _V1_DATA_INFO) 482 _calibrate_latency_v2 = _partial(_calibrate_latency_v_all, _V2_DATA_INFO) 483 _calibrate_latency_v3 = _partial(_calibrate_latency_v_all, _V3_DATA_INFO) 484 _calibrate_latency_v4 = _partial(_calibrate_latency_v_all, _V4_DATA_INFO) 485 486 487 def _plot_latency_v_all( 488 data_info: _DataInfo, latency: int, input_path: str, output_path: str, _nofail=True 489 ): 490 print("Plotting the latency for manual inspection...") 491 x = _wav_to_np(input_path)[ 492 data_info.first_blips_start : data_info.first_blips_start + data_info.t_blips 493 ] 494 y = _wav_to_np(output_path)[ 495 data_info.first_blips_start : data_info.first_blips_start + data_info.t_blips 496 ] 497 # Only get the blips we really want. 498 i = _np.where(_np.abs(x) > 0.5 * _np.abs(x).max())[0] 499 if len(i) == 0: 500 print("Failed to find the spike in the input file.") 501 print( 502 "Plotting the input and output; there should be spikes at around the " 503 "marked locations." 504 ) 505 t = _np.arange( 506 data_info.first_blips_start, data_info.first_blips_start + data_info.t_blips 507 ) 508 expected_spikes = data_info.blip_locations[0] # For v1 specifically 509 fig, axs = _plt.subplots(len((x, y)), 1) 510 for ax, curve in zip(axs, (x, y)): 511 ax.plot(t, curve) 512 [ax.axvline(x=es, color="C1", linestyle="--") for es in expected_spikes] 513 _plt.show() 514 if _nofail: 515 raise RuntimeError("Failed to plot delay") 516 else: 517 _plt.figure() 518 di = 20 519 # V1's got not a spike but a longer plateau; take the front of it. 520 if data_info.major_version == 1: 521 i = [i[0]] 522 for e, ii in enumerate(i, 1): 523 _plt.plot( 524 _np.arange(-di, di), 525 y[ii - di + latency : ii + di + latency], 526 ".-", 527 label=f"Output {e}", 528 ) 529 _plt.axvline(x=0, linestyle="--", color="k") 530 _plt.legend() 531 _plt.show() # This doesn't freeze the notebook 532 533 534 _plot_latency_v1 = _partial(_plot_latency_v_all, _V1_DATA_INFO) 535 _plot_latency_v2 = _partial(_plot_latency_v_all, _V2_DATA_INFO) 536 _plot_latency_v3 = _partial(_plot_latency_v_all, _V3_DATA_INFO) 537 _plot_latency_v4 = _partial(_plot_latency_v_all, _V4_DATA_INFO) 538 539 540 def _analyze_latency( 541 user_latency: _Optional[int], 542 input_version: _Version, 543 input_path: str, 544 output_path: str, 545 silent: bool = False, 546 ) -> _metadata.Latency: 547 """ 548 :param is_proteus: Forget the version; d 549 """ 550 if input_version.major == 1: 551 calibrate, plot = _calibrate_latency_v1, _plot_latency_v1 552 elif input_version.major == 2: 553 calibrate, plot = _calibrate_latency_v2, _plot_latency_v2 554 elif input_version.major == 3: 555 calibrate, plot = _calibrate_latency_v3, _plot_latency_v3 556 elif input_version.major == 4: 557 calibrate, plot = _calibrate_latency_v4, _plot_latency_v4 558 else: 559 raise NotImplementedError( 560 f"Input calibration not implemented for input version {input_version}" 561 ) 562 if user_latency is not None: 563 print(f"Delay is specified as {user_latency}") 564 calibration_output = calibrate(_wav_to_np(output_path)) 565 latency = ( 566 user_latency if user_latency is not None else calibration_output.recommended 567 ) 568 if not silent: 569 plot(latency, input_path, output_path) 570 571 return _metadata.Latency(manual=user_latency, calibration=calibration_output) 572 573 574 def get_lstm_config(architecture): 575 return { 576 Architecture.STANDARD: { 577 "num_layers": 1, 578 "hidden_size": 24, 579 "train_burn_in": 4096, 580 "train_truncate": 512, 581 }, 582 Architecture.LITE: { 583 "num_layers": 2, 584 "hidden_size": 8, 585 "train_burn_in": 4096, 586 "train_truncate": 512, 587 }, 588 Architecture.FEATHER: { 589 "num_layers": 1, 590 "hidden_size": 16, 591 "train_burn_in": 4096, 592 "train_truncate": 512, 593 }, 594 Architecture.NANO: { 595 "num_layers": 1, 596 "hidden_size": 12, 597 "train_burn_in": 4096, 598 "train_truncate": 512, 599 }, 600 }[architecture] 601 602 603 def _check_v1(*args, **kwargs) -> _metadata.DataChecks: 604 return _metadata.DataChecks(version=1, passed=True) 605 606 607 def _esr_validation_replicate_msg(threshold: float) -> str: 608 return ( 609 f"Validation replicates have a self-ESR of over {threshold}. " 610 "Your gear doesn't sound like itself when played twice!\n\n" 611 "Possible causes:\n" 612 " * Your signal chain is too noisy.\n" 613 " * There's a time-based effect (chorus, delay, reverb) turned on.\n" 614 " * Some knob got moved while reamping.\n" 615 " * You started reamping before the amp had time to warm up fully." 616 ) 617 618 619 def _check_v2( 620 input_path, output_path, delay: int, silent: bool 621 ) -> _metadata.DataChecks: 622 with _torch.no_grad(): 623 print("V2 checks...") 624 rate = _V2_DATA_INFO.rate 625 y = _wav_to_tensor(output_path, rate=rate) 626 t_blips = _V2_DATA_INFO.t_blips 627 t_validate = _V2_DATA_INFO.t_validate 628 y_val_1 = y[-(t_blips + 2 * t_validate) : -(t_blips + t_validate)] 629 y_val_2 = y[-(t_blips + t_validate) : -t_blips] 630 esr_replicate = _ESR(y_val_1, y_val_2).item() 631 print(f"Replicate ESR is {esr_replicate:.8f}.") 632 esr_replicate_threshold = 0.01 633 if esr_replicate > esr_replicate_threshold: 634 print(_esr_validation_replicate_msg(esr_replicate_threshold)) 635 636 # Do the blips line up? 637 # If the ESR is too bad, then flag it. 638 print("Checking blips...") 639 640 def get_blips(y): 641 """ 642 :return: [start/end,replicate] 643 """ 644 i0, i1 = _V2_DATA_INFO.blip_locations[0] 645 j0, j1 = _V2_DATA_INFO.blip_locations[1] 646 647 i0, i1, j0, j1 = [i + delay for i in (i0, i1, j0, j1)] 648 start = -10 649 end = 1000 650 blips = _torch.stack( 651 [ 652 _torch.stack([y[i0 + start : i0 + end], y[i1 + start : i1 + end]]), 653 _torch.stack([y[j0 + start : j0 + end], y[j1 + start : j1 + end]]), 654 ] 655 ) 656 return blips 657 658 blips = get_blips(y) 659 esr_0 = _ESR(blips[0][0], blips[0][1]).item() # Within start 660 esr_1 = _ESR(blips[1][0], blips[1][1]).item() # Within end 661 esr_cross_0 = _ESR(blips[0][0], blips[1][0]).item() # 1st repeat, start vs end 662 esr_cross_1 = _ESR(blips[0][1], blips[1][1]).item() # 2nd repeat, start vs end 663 664 print(" ESRs:") 665 print(f" Start : {esr_0}") 666 print(f" End : {esr_1}") 667 print(f" Cross (1) : {esr_cross_0}") 668 print(f" Cross (2) : {esr_cross_1}") 669 670 esr_threshold = 1.0e-2 671 672 def plot_esr_blip_error( 673 show_plot: bool, 674 msg: str, 675 arrays: _Sequence[_Sequence[float]], 676 labels: _Sequence[str], 677 ): 678 """ 679 :param silent: Whether to make and show a plot about it 680 """ 681 if show_plot: 682 _plt.figure() 683 [_plt.plot(array, label=label) for array, label in zip(arrays, labels)] 684 _plt.xlabel("Sample") 685 _plt.ylabel("Output") 686 _plt.legend() 687 _plt.grid() 688 print(msg) 689 if show_plot: 690 _plt.show() 691 print( 692 "This is known to be a very sensitive test, so training will continue. " 693 "If the model doesn't look good, then this may be why!" 694 ) 695 696 # Check consecutive blips 697 show_blip_plots = False 698 for e, blip_pair, when in zip((esr_0, esr_1), blips, ("start", "end")): 699 if e >= esr_threshold: 700 plot_esr_blip_error( 701 show_blip_plots, 702 f"Failed consecutive blip check at {when} of training signal. The " 703 "target tone doesn't seem to be replicable over short timespans." 704 "\n\n" 705 " Possible causes:\n\n" 706 " * Your recording setup is really noisy.\n" 707 " * There's a noise gate that's messing things up.\n" 708 " * There's a time-based effect (chorus, delay, reverb) in " 709 "the signal chain", 710 blip_pair, 711 ("Replicate 1", "Replicate 2"), 712 ) 713 return _metadata.DataChecks(version=2, passed=False) 714 # Check blips between start & end of train signal 715 for e, blip_pair, replicate in zip( 716 (esr_cross_0, esr_cross_1), blips.permute(1, 0, 2), (1, 2) 717 ): 718 if e >= esr_threshold: 719 plot_esr_blip_error( 720 show_blip_plots, 721 f"Failed start-to-end blip check for blip replicate {replicate}. " 722 "The target tone doesn't seem to be same at the end of the reamp " 723 "as it was at the start. Did some setting change during reamping?", 724 blip_pair, 725 (f"Start, replicate {replicate}", f"End, replicate {replicate}"), 726 ) 727 return _metadata.DataChecks(version=2, passed=False) 728 return _metadata.DataChecks(version=2, passed=True) 729 730 731 def _check_v3( 732 input_path, output_path, silent: bool, *args, **kwargs 733 ) -> _metadata.DataChecks: 734 with _torch.no_grad(): 735 print("V3 checks...") 736 rate = _V3_DATA_INFO.rate 737 y = _wav_to_tensor(output_path, rate=rate) 738 n = len(_wav_to_tensor(input_path)) # to End-crop output 739 y_val_1 = y[: _V3_DATA_INFO.t_validate] 740 y_val_2 = y[n - _V3_DATA_INFO.t_validate : n] 741 esr_replicate = _ESR(y_val_1, y_val_2).item() 742 print(f"Replicate ESR is {esr_replicate:.8f}.") 743 esr_replicate_threshold = 0.01 744 if esr_replicate > esr_replicate_threshold: 745 print(_esr_validation_replicate_msg(esr_replicate_threshold)) 746 if not silent: 747 _plt.figure() 748 t = _np.arange(len(y_val_1)) / rate 749 _plt.plot(t, y_val_1, label="Validation 1") 750 _plt.plot(t, y_val_2, label="Validation 2") 751 _plt.xlabel("Time (sec)") 752 _plt.legend() 753 _plt.title("V3 check: Validation replicate FAILURE") 754 _plt.show() 755 return _metadata.DataChecks(version=3, passed=False) 756 return _metadata.DataChecks(version=3, passed=True) 757 758 759 def _check_v4( 760 input_path, output_path, silent: bool, *args, **kwargs 761 ) -> _metadata.DataChecks: 762 # Things we can't check: 763 # Latency compensation agreement 764 # Data replicability 765 print("Using Proteus audio file. Standard data checks aren't possible!") 766 signal, info = _wav_to_np(output_path, info=True) 767 passed = True 768 if info.rate != _V4_DATA_INFO.rate: 769 print( 770 f"Output signal has sample rate {info.rate}; expected {_V4_DATA_INFO.rate}!" 771 ) 772 passed = False 773 # I don't care what's in the files except that they're long enough to hold the blip 774 # and the last 10 seconds I decided to use as validation 775 required_length = int((1.0 + 10.0) * _V4_DATA_INFO.rate) 776 if len(signal) < required_length: 777 print( 778 "File doesn't meet the minimum length requirements for latency compensation and validation signal!" 779 ) 780 passed = False 781 return _metadata.DataChecks(version=4, passed=passed) 782 783 784 def _check_data( 785 input_path: str, output_path: str, input_version: _Version, delay: int, silent: bool 786 ) -> _Optional[_metadata.DataChecks]: 787 """ 788 Ensure that everything should go smoothly 789 790 :return: True if looks good 791 """ 792 if input_version.major == 1: 793 f = _check_v1 794 elif input_version.major == 2: 795 f = _check_v2 796 elif input_version.major == 3: 797 f = _check_v3 798 elif input_version.major == 4: 799 f = _check_v4 800 else: 801 print(f"Checks not implemented for input version {input_version}; skip") 802 return None 803 out = f(input_path, output_path, delay, silent) 804 # Issue 442: Deprecate inputs 805 if input_version.major != 3: 806 print( 807 f"Input version {input_version} is deprecated and will be removed in " 808 "version 0.11 of the trainer. To continue using it, you must ignore checks." 809 ) 810 out.passed = False 811 return out 812 813 814 def get_wavenet_config(architecture): 815 return { 816 Architecture.STANDARD: { 817 "layers_configs": [ 818 { 819 "input_size": 1, 820 "condition_size": 1, 821 "channels": 16, 822 "head_size": 8, 823 "kernel_size": 3, 824 "dilations": [1, 2, 4, 8, 16, 32, 64, 128, 256, 512], 825 "activation": "Tanh", 826 "gated": False, 827 "head_bias": False, 828 }, 829 { 830 "condition_size": 1, 831 "input_size": 16, 832 "channels": 8, 833 "head_size": 1, 834 "kernel_size": 3, 835 "dilations": [1, 2, 4, 8, 16, 32, 64, 128, 256, 512], 836 "activation": "Tanh", 837 "gated": False, 838 "head_bias": True, 839 }, 840 ], 841 "head_scale": 0.02, 842 }, 843 Architecture.LITE: { 844 "layers_configs": [ 845 { 846 "input_size": 1, 847 "condition_size": 1, 848 "channels": 12, 849 "head_size": 6, 850 "kernel_size": 3, 851 "dilations": [1, 2, 4, 8, 16, 32, 64], 852 "activation": "Tanh", 853 "gated": False, 854 "head_bias": False, 855 }, 856 { 857 "condition_size": 1, 858 "input_size": 12, 859 "channels": 6, 860 "head_size": 1, 861 "kernel_size": 3, 862 "dilations": [128, 256, 512, 1, 2, 4, 8, 16, 32, 64, 128, 256, 512], 863 "activation": "Tanh", 864 "gated": False, 865 "head_bias": True, 866 }, 867 ], 868 "head_scale": 0.02, 869 }, 870 Architecture.FEATHER: { 871 "layers_configs": [ 872 { 873 "input_size": 1, 874 "condition_size": 1, 875 "channels": 8, 876 "head_size": 4, 877 "kernel_size": 3, 878 "dilations": [1, 2, 4, 8, 16, 32, 64], 879 "activation": "Tanh", 880 "gated": False, 881 "head_bias": False, 882 }, 883 { 884 "condition_size": 1, 885 "input_size": 8, 886 "channels": 4, 887 "head_size": 1, 888 "kernel_size": 3, 889 "dilations": [128, 256, 512, 1, 2, 4, 8, 16, 32, 64, 128, 256, 512], 890 "activation": "Tanh", 891 "gated": False, 892 "head_bias": True, 893 }, 894 ], 895 "head_scale": 0.02, 896 }, 897 Architecture.NANO: { 898 "layers_configs": [ 899 { 900 "input_size": 1, 901 "condition_size": 1, 902 "channels": 4, 903 "head_size": 2, 904 "kernel_size": 3, 905 "dilations": [1, 2, 4, 8, 16, 32, 64], 906 "activation": "Tanh", 907 "gated": False, 908 "head_bias": False, 909 }, 910 { 911 "condition_size": 1, 912 "input_size": 4, 913 "channels": 2, 914 "head_size": 1, 915 "kernel_size": 3, 916 "dilations": [128, 256, 512, 1, 2, 4, 8, 16, 32, 64, 128, 256, 512], 917 "activation": "Tanh", 918 "gated": False, 919 "head_bias": True, 920 }, 921 ], 922 "head_scale": 0.02, 923 }, 924 }[architecture] 925 926 927 _CAB_MRSTFT_PRE_EMPH_WEIGHT = 2.0e-4 928 _CAB_MRSTFT_PRE_EMPH_COEF = 0.85 929 930 931 def _get_data_config( 932 input_version: _Version, 933 input_path: _Path, 934 output_path: _Path, 935 ny: int, 936 latency: int, 937 ) -> dict: 938 def get_split_kwargs(data_info: _DataInfo): 939 if data_info.major_version == 1: 940 train_val_split = data_info.validation_start 941 train_kwargs = {"stop_samples": train_val_split} 942 validation_kwargs = {"start_samples": train_val_split} 943 elif data_info.major_version == 2: 944 validation_start = data_info.validation_start 945 train_stop = validation_start 946 validation_stop = validation_start + data_info.t_validate 947 train_kwargs = {"stop_samples": train_stop} 948 validation_kwargs = { 949 "start_samples": validation_start, 950 "stop_samples": validation_stop, 951 } 952 elif data_info.major_version == 3: 953 validation_start = data_info.validation_start 954 train_stop = validation_start 955 train_kwargs = {"start_samples": 480_000, "stop_samples": train_stop} 956 validation_kwargs = {"start_samples": validation_start} 957 elif data_info.major_version == 4: 958 validation_start = data_info.validation_start 959 train_stop = validation_start 960 train_kwargs = {"stop_samples": train_stop} 961 # Proteus doesn't have silence to get a clean split. Bite the bullet. 962 print( 963 "Using Proteus files:\n" 964 " * There isn't a silent point to split the validation set, so some of " 965 "your gear's response from the train set will leak into the start of " 966 "the validation set and impact validation accuracy (Bypassing data " 967 "quality check)\n" 968 " * Since the validation set is different, the ESRs reported for this " 969 "model aren't comparable to those from the other 'NAM' training files." 970 ) 971 validation_kwargs = { 972 "start_samples": validation_start, 973 "require_input_pre_silence": False, 974 } 975 else: 976 raise NotImplementedError(f"kwargs for input version {input_version}") 977 return train_kwargs, validation_kwargs 978 979 data_info = { 980 1: _V1_DATA_INFO, 981 2: _V2_DATA_INFO, 982 3: _V3_DATA_INFO, 983 4: _V4_DATA_INFO, 984 }[input_version.major] 985 train_kwargs, validation_kwargs = get_split_kwargs(data_info) 986 data_config = { 987 "train": {"ny": ny, **train_kwargs}, 988 "validation": {"ny": None, **validation_kwargs}, 989 "common": { 990 "x_path": input_path, 991 "y_path": output_path, 992 "delay": latency, 993 "allow_unequal_lengths": True, 994 }, 995 } 996 return data_config 997 998 999 def _get_configs( 1000 input_version: _Version, 1001 input_path: str, 1002 output_path: str, 1003 latency: int, 1004 epochs: int, 1005 model_type: str, 1006 architecture: Architecture, 1007 ny: int, 1008 lr: float, 1009 lr_decay: float, 1010 batch_size: int, 1011 fit_mrstft: bool, 1012 ): 1013 data_config = _get_data_config( 1014 input_version=input_version, 1015 input_path=input_path, 1016 output_path=output_path, 1017 ny=ny, 1018 latency=latency, 1019 ) 1020 1021 if model_type == "WaveNet": 1022 model_config = { 1023 "net": { 1024 "name": "WaveNet", 1025 # This should do decently. If you really want a nice model, try turning up 1026 # "channels" in the first block and "input_size" in the second from 12 to 16. 1027 "config": get_wavenet_config(architecture), 1028 }, 1029 "loss": {"val_loss": "esr"}, 1030 "optimizer": {"lr": lr}, 1031 "lr_scheduler": { 1032 "class": "ExponentialLR", 1033 "kwargs": {"gamma": 1.0 - lr_decay}, 1034 }, 1035 } 1036 else: 1037 model_config = { 1038 "net": { 1039 "name": "LSTM", 1040 "config": get_lstm_config(architecture), 1041 }, 1042 "loss": { 1043 "val_loss": "mse", 1044 "mask_first": 4096, 1045 "pre_emph_weight": 1.0, 1046 "pre_emph_coef": 0.85, 1047 }, 1048 "optimizer": {"lr": 0.01}, 1049 "lr_scheduler": {"class": "ExponentialLR", "kwargs": {"gamma": 0.995}}, 1050 } 1051 if fit_mrstft: 1052 model_config["loss"]["pre_emph_mrstft_weight"] = _CAB_MRSTFT_PRE_EMPH_WEIGHT 1053 model_config["loss"]["pre_emph_mrstft_coef"] = _CAB_MRSTFT_PRE_EMPH_COEF 1054 1055 if _torch.cuda.is_available(): 1056 device_config = {"accelerator": "gpu", "devices": 1} 1057 elif _torch.backends.mps.is_available(): 1058 device_config = {"accelerator": "mps", "devices": 1} 1059 else: 1060 print("WARNING: No GPU was found. Training will be very slow!") 1061 device_config = {} 1062 learning_config = { 1063 "train_dataloader": { 1064 "batch_size": batch_size, 1065 "shuffle": True, 1066 "pin_memory": True, 1067 "drop_last": True, 1068 "num_workers": 0, 1069 }, 1070 "val_dataloader": {}, 1071 "trainer": {"max_epochs": epochs, **device_config}, 1072 } 1073 return data_config, model_config, learning_config 1074 1075 1076 def _get_dataloaders( 1077 data_config: _Dict, learning_config: _Dict, model: _LightningModule 1078 ) -> _Tuple[_DataLoader, _DataLoader]: 1079 data_config, learning_config = [ 1080 _deepcopy(c) for c in (data_config, learning_config) 1081 ] 1082 data_config["common"]["nx"] = model.net.receptive_field 1083 dataset_train = _init_dataset(data_config, _Split.TRAIN) 1084 dataset_validation = _init_dataset(data_config, _Split.VALIDATION) 1085 train_dataloader = _DataLoader(dataset_train, **learning_config["train_dataloader"]) 1086 val_dataloader = _DataLoader( 1087 dataset_validation, **learning_config["val_dataloader"] 1088 ) 1089 return train_dataloader, val_dataloader 1090 1091 1092 def _esr(pred: _torch.Tensor, target: _torch.Tensor) -> float: 1093 return ( 1094 _torch.mean(_torch.square(pred - target)).item() 1095 / _torch.mean(_torch.square(target)).item() 1096 ) 1097 1098 1099 def _plot( 1100 model, 1101 ds, 1102 window_start: _Optional[int] = None, 1103 window_end: _Optional[int] = None, 1104 filepath: _Optional[str] = None, 1105 silent: bool = False, 1106 ) -> float: 1107 """ 1108 :return: The ESR 1109 """ 1110 print("Plotting a comparison of your model with the target output...") 1111 with _torch.no_grad(): 1112 tx = len(ds.x) / 48_000 1113 print(f"Run (t={tx:.2f} sec)") 1114 t0 = _time() 1115 output = model(ds.x).flatten().cpu().numpy() 1116 t1 = _time() 1117 print(f"Took {t1 - t0:.2f} sec ({tx / (t1 - t0):.2f}x)") 1118 1119 esr = _esr(_torch.Tensor(output), ds.y) 1120 # Trying my best to put numbers to it... 1121 if esr < 0.01: 1122 esr_comment = "Great!" 1123 elif esr < 0.035: 1124 esr_comment = "Not bad!" 1125 elif esr < 0.1: 1126 esr_comment = "...This *might* sound ok!" 1127 elif esr < 0.3: 1128 esr_comment = "...This probably won't sound great :(" 1129 else: 1130 esr_comment = "...Something seems to have gone wrong." 1131 print(f"Error-signal ratio = {esr:.4g}") 1132 print(esr_comment) 1133 1134 _plt.figure(figsize=(16, 5)) 1135 _plt.plot(output[window_start:window_end], label="Prediction") 1136 _plt.plot(ds.y[window_start:window_end], linestyle="--", label="Target") 1137 _plt.title(f"ESR={esr:.4g}") 1138 _plt.legend() 1139 if filepath is not None: 1140 _plt.savefig(filepath + ".png") 1141 if not silent: 1142 _plt.show() 1143 return esr 1144 1145 1146 def _print_nasty_checks_warning(): 1147 """ 1148 "ffs" -Dom 1149 """ 1150 print( 1151 "\n" 1152 "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX\n" 1153 "X X\n" 1154 "X WARNING: X\n" 1155 "X X\n" 1156 "X You are ignoring the checks! Your model might turn out bad! X\n" 1157 "X X\n" 1158 "X I warned you! X\n" 1159 "X X\n" 1160 "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX\n" 1161 ) 1162 1163 1164 def _nasty_checks_modal(): 1165 msg = "You are ignoring the checks!\nYour model might turn out bad!" 1166 1167 root = _tk.Tk() 1168 root.withdraw() # hide the root window 1169 modal = _tk.Toplevel(root) 1170 modal.geometry("300x100") 1171 modal.title("Warning!") 1172 label = _tk.Label(modal, text=msg) 1173 label.pack(pady=10) 1174 ok_button = _tk.Button( 1175 modal, 1176 text="I can only blame myself!", 1177 command=lambda: [modal.destroy(), root.quit()], 1178 ) 1179 ok_button.pack() 1180 modal.grab_set() # disable interaction with root window while modal is open 1181 modal.mainloop() 1182 1183 1184 class _ValidationStopping(_pl.callbacks.EarlyStopping): 1185 """ 1186 Callback to indicate to stop training if the validation metric is good enough, 1187 without the other conditions that EarlyStopping usually forces like patience. 1188 """ 1189 1190 def __init__(self, *args, **kwargs): 1191 super().__init__(*args, **kwargs) 1192 self.patience = _np.inf 1193 1194 1195 class _ModelCheckpoint(_pl.callbacks.model_checkpoint.ModelCheckpoint): 1196 """ 1197 Extension to model checkpoint to save a .nam file as well as the .ckpt file. 1198 """ 1199 1200 def __init__( 1201 self, 1202 *args, 1203 user_metadata: _Optional[_UserMetadata] = None, 1204 settings_metadata: _Optional[_metadata.Settings] = None, 1205 data_metadata: _Optional[_metadata.Data] = None, 1206 **kwargs, 1207 ): 1208 super().__init__(*args, **kwargs) 1209 self._user_metadata = user_metadata 1210 self._settings_metadata = settings_metadata 1211 self._data_metadata = data_metadata 1212 1213 _NAM_FILE_EXTENSION = _Exportable.FILE_EXTENSION 1214 1215 @classmethod 1216 def _get_nam_filepath(cls, filepath: str) -> _Path: 1217 """ 1218 Given a .ckpt filepath, figure out a .nam for it. 1219 """ 1220 if not filepath.endswith(cls.FILE_EXTENSION): 1221 raise ValueError( 1222 f"Checkpoint filepath {filepath} doesn't end in expected extension " 1223 f"{cls.FILE_EXTENSION}" 1224 ) 1225 return _Path(filepath[: -len(cls.FILE_EXTENSION)] + cls._NAM_FILE_EXTENSION) 1226 1227 @property 1228 def _include_other_metadata(self) -> bool: 1229 return self._settings_metadata is not None and self._data_metadata is not None 1230 1231 def _save_checkpoint(self, trainer: _pl.Trainer, filepath: str): 1232 # Save the .ckpt: 1233 super()._save_checkpoint(trainer, filepath) 1234 # Save the .nam: 1235 nam_filepath = self._get_nam_filepath(filepath) 1236 pl_model: _LightningModule = trainer.model 1237 nam_model = pl_model.net 1238 outdir = nam_filepath.parent 1239 # HACK: Assume the extension 1240 basename = nam_filepath.name[: -len(self._NAM_FILE_EXTENSION)] 1241 other_metadata = ( 1242 None 1243 if not self._include_other_metadata 1244 else { 1245 _metadata.TRAINING_KEY: _metadata.TrainingMetadata( 1246 settings=self._settings_metadata, 1247 data=self._data_metadata, 1248 validation_esr=None, # TODO how to get this? 1249 ).model_dump() 1250 } 1251 ) 1252 nam_model.export( 1253 outdir, 1254 basename=basename, 1255 user_metadata=self._user_metadata, 1256 other_metadata=other_metadata, 1257 ) 1258 1259 def _remove_checkpoint(self, trainer: _pl.Trainer, filepath: str) -> None: 1260 super()._remove_checkpoint(trainer, filepath) 1261 nam_path = self._get_nam_filepath(filepath) 1262 if nam_path.exists(): 1263 nam_path.unlink() 1264 1265 1266 def get_callbacks( 1267 threshold_esr: _Optional[float], 1268 user_metadata: _Optional[_UserMetadata] = None, 1269 settings_metadata: _Optional[_metadata.Settings] = None, 1270 data_metadata: _Optional[_metadata.Data] = None, 1271 ): 1272 callbacks = [ 1273 _ModelCheckpoint( 1274 filename="checkpoint_best_{epoch:04d}_{step}_{ESR:.4g}_{MSE:.3e}", 1275 save_top_k=3, 1276 monitor="val_loss", 1277 every_n_epochs=1, 1278 user_metadata=user_metadata, 1279 settings_metadata=settings_metadata, 1280 data_metadata=data_metadata, 1281 ), 1282 _ModelCheckpoint( 1283 filename="checkpoint_last_{epoch:04d}_{step}", 1284 every_n_epochs=1, 1285 user_metadata=user_metadata, 1286 settings_metadata=settings_metadata, 1287 data_metadata=data_metadata, 1288 ), 1289 ] 1290 if threshold_esr is not None: 1291 callbacks.append( 1292 _ValidationStopping(monitor="ESR", stopping_threshold=threshold_esr) 1293 ) 1294 return callbacks 1295 1296 1297 class TrainOutput(_NamedTuple): 1298 """ 1299 :param model: The trained model 1300 :param simpliifed_trianer_metadata: The metadata summarizing training with the 1301 simplified trainer. 1302 """ 1303 1304 model: _Optional[_LightningModule] 1305 metadata: _metadata.TrainingMetadata 1306 1307 1308 def _get_final_latency(latency_analysis: _metadata.Latency) -> int: 1309 if latency_analysis.manual is not None: 1310 latency = latency_analysis.manual 1311 print(f"Latency provided as {latency_analysis.manual}; override calibration") 1312 else: 1313 latency = latency_analysis.calibration.recommended 1314 print(f"Set latency to recommended {latency_analysis.calibration.recommended}") 1315 return latency 1316 1317 1318 def train( 1319 input_path: str, 1320 output_path: str, 1321 train_path: str, 1322 input_version: _Optional[_Version] = None, # Deprecate? 1323 epochs=100, 1324 delay: _Optional[int] = None, 1325 latency: _Optional[int] = None, 1326 model_type: str = "WaveNet", 1327 architecture: _Union[Architecture, str] = Architecture.STANDARD, 1328 batch_size: int = 16, 1329 ny: int = _NY_DEFAULT, 1330 lr=0.004, 1331 lr_decay=0.007, 1332 seed: _Optional[int] = 0, 1333 save_plot: bool = False, 1334 silent: bool = False, 1335 modelname: str = "model", 1336 ignore_checks: bool = False, 1337 local: bool = False, 1338 fit_mrstft: bool = True, 1339 threshold_esr: _Optional[bool] = None, 1340 user_metadata: _Optional[_UserMetadata] = None, 1341 fast_dev_run: _Union[bool, int] = False, 1342 ) -> _Optional[TrainOutput]: 1343 """ 1344 :param lr_decay: =1-gamma for Exponential learning rate decay. 1345 :param threshold_esr: Stop training if ESR is better than this. Ignore if `None`. 1346 :param fast_dev_run: One-step training, used for tests. 1347 """ 1348 1349 def parse_user_latency( 1350 delay: _Optional[int], latency: _Optional[int] 1351 ) -> _Optional[int]: 1352 if delay is not None: 1353 if latency is not None: 1354 raise ValueError("Both delay and latency are provided; use latency!") 1355 print("WARNING: use of `delay` is deprecated; use `latency` instead") 1356 return delay 1357 return latency 1358 1359 if seed is not None: 1360 _torch.manual_seed(seed) 1361 1362 # HACK: We need to check the sample rates and lengths of the audio here or else 1363 # It will look like a bad self-ESR (Issue 473) 1364 # Can move this into the "v3 checks" once the others are deprecated. 1365 # And honestly remake this whole thing as a data processing pipeline. 1366 sample_rate_validation = _check_audio_sample_rates(input_path, output_path) 1367 if not sample_rate_validation.passed: 1368 raise ValueError( 1369 "Different sample rates detected for input " 1370 f"({sample_rate_validation.input}) and output " 1371 f"({sample_rate_validation.output}) audio!" 1372 ) 1373 length_validation = _check_audio_lengths(input_path, output_path) 1374 if not length_validation.passed: 1375 raise ValueError( 1376 "Your recording differs in length from the input file by " 1377 f"{length_validation.delta_seconds:.2f} seconds. Check your reamp " 1378 "in your DAW and ensure that they are the same length." 1379 ) 1380 1381 if input_version is None: 1382 input_version, strong_match = _detect_input_version(input_path) 1383 1384 user_latency = parse_user_latency(delay, latency) 1385 latency_analysis = _analyze_latency( 1386 user_latency, input_version, input_path, output_path, silent=silent 1387 ) 1388 final_latency = _get_final_latency(latency_analysis) 1389 1390 data_check_output = _check_data( 1391 input_path, output_path, input_version, final_latency, silent 1392 ) 1393 if data_check_output is not None: 1394 if data_check_output.passed: 1395 print("-Checks passed") 1396 else: 1397 print("Failed checks!") 1398 if ignore_checks: 1399 if local and not silent: 1400 _nasty_checks_modal() 1401 else: 1402 _print_nasty_checks_warning() 1403 elif not local: # And not ignore_checks 1404 print( 1405 "(To disable this check, run AT YOUR OWN RISK with " 1406 "`ignore_checks=True`.)" 1407 ) 1408 if not ignore_checks: 1409 print("Exiting core training...") 1410 return TrainOutput( 1411 model=None, 1412 metadata=_metadata.TrainingMetadata( 1413 settings=_metadata.Settings(ignore_checks=ignore_checks), 1414 data=_metadata.Data( 1415 latency=latency_analysis, checks=data_check_output 1416 ), 1417 validation_esr=None, 1418 ), 1419 ) 1420 1421 data_config, model_config, learning_config = _get_configs( 1422 input_version, 1423 input_path, 1424 output_path, 1425 final_latency, 1426 epochs, 1427 model_type, 1428 Architecture(architecture), 1429 ny, 1430 lr, 1431 lr_decay, 1432 batch_size, 1433 fit_mrstft, 1434 ) 1435 assert ( 1436 "fast_dev_run" not in learning_config 1437 ), "fast_dev_run is set as a kwarg to train()" 1438 1439 print("Starting training. It's time to kick ass and chew bubblegum!") 1440 # Issue: 1441 # * Model needs sample rate from data, but data set needs nx from model. 1442 # * Model is re-instantiated after training anyways. 1443 # (Hacky) solution: set sample rate in model from dataloader after second 1444 # instantiation from final checkpoint. 1445 model = _LightningModule.init_from_config(model_config) 1446 train_dataloader, val_dataloader = _get_dataloaders( 1447 data_config, learning_config, model 1448 ) 1449 if train_dataloader.dataset.sample_rate != val_dataloader.dataset.sample_rate: 1450 raise RuntimeError( 1451 "Train and validation data loaders have different data set sample rates: " 1452 f"{train_dataloader.dataset.sample_rate}, " 1453 f"{val_dataloader.dataset.sample_rate}" 1454 ) 1455 sample_rate = train_dataloader.dataset.sample_rate 1456 model.net.sample_rate = sample_rate 1457 1458 # Put together the metadata that's needed in checkpoints: 1459 settings_metadata = _metadata.Settings(ignore_checks=ignore_checks) 1460 data_metadata = _metadata.Data(latency=latency_analysis, checks=data_check_output) 1461 1462 trainer = _pl.Trainer( 1463 callbacks=get_callbacks( 1464 threshold_esr, 1465 user_metadata=user_metadata, 1466 settings_metadata=settings_metadata, 1467 data_metadata=data_metadata, 1468 ), 1469 default_root_dir=train_path, 1470 fast_dev_run=fast_dev_run, 1471 **learning_config["trainer"], 1472 ) 1473 # Suppress the PossibleUserWarning about num_workers (Issue 345) 1474 with _filter_warnings("ignore", category=_PossibleUserWarning): 1475 trainer.fit(model, train_dataloader, val_dataloader) 1476 1477 # Go to best checkpoint 1478 best_checkpoint = trainer.checkpoint_callback.best_model_path 1479 if best_checkpoint != "": 1480 model = _LightningModule.load_from_checkpoint( 1481 trainer.checkpoint_callback.best_model_path, 1482 **_LightningModule.parse_config(model_config), 1483 ) 1484 model.cpu() 1485 model.eval() 1486 model.net.sample_rate = sample_rate # Hack, part 2 1487 1488 def window_kwargs(version: _Version): 1489 if version.major == 1: 1490 return dict( 1491 window_start=100_000, # Start of the plotting window, in samples 1492 window_end=101_000, # End of the plotting window, in samples 1493 ) 1494 elif version.major == 2: 1495 # Same validation set even though it's a different spot in the reamp file 1496 return dict( 1497 window_start=100_000, # Start of the plotting window, in samples 1498 window_end=101_000, # End of the plotting window, in samples 1499 ) 1500 # Fallback: 1501 return dict( 1502 window_start=100_000, # Start of the plotting window, in samples 1503 window_end=101_000, # End of the plotting window, in samples 1504 ) 1505 1506 validation_esr = _plot( 1507 model, 1508 val_dataloader.dataset, 1509 filepath=train_path + "/" + modelname if save_plot else None, 1510 silent=silent, 1511 **window_kwargs(input_version), 1512 ) 1513 return TrainOutput( 1514 model=model, 1515 metadata=_metadata.TrainingMetadata( 1516 settings=settings_metadata, 1517 data=data_metadata, 1518 validation_esr=validation_esr, 1519 ), 1520 ) 1521 1522 1523 class DataInputValidation(_BaseModel): 1524 passed: bool 1525 1526 1527 def validate_input(input_path) -> DataInputValidation: 1528 """ 1529 :return: Could it be validated? 1530 """ 1531 try: 1532 _detect_input_version(input_path) 1533 # succeeded... 1534 return DataInputValidation(passed=True) 1535 except _InputValidationError as e: 1536 print(f"Input validation failed!\n\n{e}") 1537 return DataInputValidation(passed=False) 1538 1539 1540 class _PyTorchDataSplitValidation(_BaseModel): 1541 """ 1542 :param msg: On exception, catch and assign. Otherwise None 1543 """ 1544 1545 passed: bool 1546 msg: _Optional[str] 1547 1548 1549 class _PyTorchDataValidation(_BaseModel): 1550 passed: bool 1551 train: _PyTorchDataSplitValidation # cf Split.TRAIN 1552 validation: _PyTorchDataSplitValidation # Split.VALIDATION 1553 1554 1555 class _SampleRateValidation(_BaseModel): 1556 passed: bool 1557 input: int 1558 output: int 1559 1560 1561 class _LengthValidation(_BaseModel): 1562 passed: bool 1563 delta_seconds: float 1564 1565 1566 class DataValidationOutput(_BaseModel): 1567 passed: bool 1568 passed_critical: bool 1569 sample_rate: _SampleRateValidation 1570 length: _LengthValidation 1571 input_version: str 1572 latency: _metadata.Latency 1573 checks: _metadata.DataChecks 1574 pytorch: _PyTorchDataValidation 1575 1576 1577 def _check_audio_sample_rates( 1578 input_path: _Path, 1579 output_path: _Path, 1580 ) -> _SampleRateValidation: 1581 _, x_info = _wav_to_np(input_path, info=True) 1582 _, y_info = _wav_to_np(output_path, info=True) 1583 1584 return _SampleRateValidation( 1585 passed=x_info.rate == y_info.rate, 1586 input=x_info.rate, 1587 output=y_info.rate, 1588 ) 1589 1590 1591 def _check_audio_lengths( 1592 input_path: _Path, 1593 output_path: _Path, 1594 max_under_seconds: _Optional[float] = 0.0, 1595 max_over_seconds: _Optional[float] = 1.0, 1596 ) -> _LengthValidation: 1597 """ 1598 Check that the input and output have the right lengths compared to each 1599 other. 1600 1601 :param input_path: Path to input audio 1602 :param output_path: Path to output audio 1603 :param max_under_seconds: If not None, the maximum amount by which the 1604 output can be shorter than the input. Should be non-negative i.e. a 1605 value of 1.0 means that the output can't be more than a second shorter 1606 than the input. 1607 :param max_over_seconds: If not None, the maximum amount by which the 1608 output can be longer than the input. Should be non-negative i.e. a 1609 value of 1.0 means that the output can't be more than a second longer 1610 than the input. 1611 """ 1612 x, x_info = _wav_to_np(input_path, info=True) 1613 y, y_info = _wav_to_np(output_path, info=True) 1614 1615 length_input = len(x) / x_info.rate 1616 length_output = len(y) / y_info.rate 1617 delta_seconds = length_output - length_input 1618 1619 passed = True 1620 if max_under_seconds is not None and delta_seconds < -max_under_seconds: 1621 passed = False 1622 if max_over_seconds is not None and delta_seconds > max_over_seconds: 1623 passed = False 1624 1625 return _LengthValidation(passed=passed, delta_seconds=delta_seconds) 1626 1627 1628 def validate_data( 1629 input_path: _Path, 1630 output_path: _Path, 1631 user_latency: _Optional[int], 1632 num_output_samples_per_datum: int = _NY_DEFAULT, 1633 ): 1634 """ 1635 Just do the checks to make sure that the data are ok. 1636 1637 * Version identification 1638 * Latency calibration 1639 * Other checks 1640 """ 1641 print("Validating data...") 1642 passed = True # Until proven otherwise 1643 passed_critical = True # These can't be ignored 1644 1645 sample_rate_validation = _check_audio_sample_rates(input_path, output_path) 1646 passed = passed and sample_rate_validation.passed 1647 passed_critical = passed_critical and sample_rate_validation.passed 1648 1649 length_validation = _check_audio_lengths(input_path, output_path) 1650 passed = passed and length_validation.passed 1651 passed_critical = passed_critical and length_validation.passed 1652 1653 # Data version ID 1654 input_version, strong_match = _detect_input_version(input_path) 1655 1656 # Latency analysis 1657 latency_analysis = _analyze_latency( 1658 user_latency, input_version, input_path, output_path, silent=True 1659 ) 1660 if latency_analysis.manual is None and any( 1661 val for val in latency_analysis.calibration.warnings.model_dump().values() 1662 ): 1663 passed = False 1664 final_latency = _get_final_latency(latency_analysis) 1665 1666 # Other data checks based on input file version 1667 data_checks = _check_data( 1668 input_path, 1669 output_path, 1670 input_version, 1671 latency_analysis.calibration.recommended, 1672 silent=True, 1673 ) 1674 passed = passed and data_checks.passed 1675 1676 # Finally, try to make the PyTorch Dataset objects and note any failures: 1677 data_config = _get_data_config( 1678 input_version=input_version, 1679 input_path=input_path, 1680 output_path=output_path, 1681 ny=num_output_samples_per_datum, 1682 latency=final_latency, 1683 ) 1684 # HACK this should depend on the model that's going to be used, but I think it will 1685 # be unlikely to make a difference. Still, would be nice to fix. 1686 data_config["common"]["nx"] = 4096 1687 1688 pytorch_data_split_validation_dict: _Dict[str, _PyTorchDataSplitValidation] = {} 1689 for split in _Split: 1690 try: 1691 _init_dataset(data_config, split) 1692 pytorch_data_split_validation_dict[split.value] = ( 1693 _PyTorchDataSplitValidation(passed=True, msg=None) 1694 ) 1695 except _DataError as e: 1696 pytorch_data_split_validation_dict[split.value] = ( 1697 _PyTorchDataSplitValidation(passed=False, msg=str(e)) 1698 ) 1699 pytorch_data_validation = _PyTorchDataValidation( 1700 passed=all(v.passed for v in pytorch_data_split_validation_dict.values()), 1701 **pytorch_data_split_validation_dict, 1702 ) 1703 passed = passed and pytorch_data_validation.passed 1704 passed_critical = passed_critical and pytorch_data_validation.passed 1705 1706 return DataValidationOutput( 1707 passed=passed, 1708 passed_critical=passed_critical, 1709 sample_rate=sample_rate_validation, 1710 length=length_validation, 1711 input_version=str(input_version), 1712 latency=latency_analysis, 1713 checks=data_checks, 1714 pytorch=pytorch_data_validation, 1715 )