colab.py (4889B)
# File: colab.py
# Created Date: Sunday December 4th 2022
# Author: Steven Atkinson (steven@atkinson.mn)

"""
Hide the mess in Colab to make things look pretty for users.
"""

from pathlib import Path as _Path
from typing import Optional as _Optional, Tuple as _Tuple

from ..models.metadata import UserMetadata as _UserMetadata
from ._names import (
    INPUT_BASENAMES as _INPUT_BASENAMES,
    LATEST_VERSION as _LATEST_VERSION,
    Version as _Version,
)

# NOTE(review): `_Version` is imported from both `._names` and `._version`; the
# binding below is the one in effect. Presumably both refer to the same class --
# confirm, then drop one of the two.
from ._version import PROTEUS_VERSION as _PROTEUS_VERSION, Version as _Version
from .core import TrainOutput as _TrainOutput, train as _train
from .metadata import TRAINING_KEY as _TRAINING_KEY

_BUGGY_INPUT_BASENAMES = {
    # 1.1.0 has the spikes at the wrong spots.
    "v1_1_0.wav"
}
# Name the reamped output file must have in the working directory.
_OUTPUT_BASENAME = "output.wav"
# Passed to the trainer as its third positional argument.
_TRAIN_PATH = "."


def _check_for_files() -> _Tuple[_Version, str]:
    """
    Check that the audio files needed for training are in the working directory.

    Entries of ``_INPUT_BASENAMES`` are tried in order; the first one whose
    basename exists is used.

    :return: ``(input_version, input_basename)`` of the input file that was
        found.
    :raises RuntimeError: If a known-buggy input file is present, or if an
        input file is present under one of its out-of-date names.
    :raises FileNotFoundError: If none of the known input files is present, or
        if the reamped output file is missing.
    """
    # TODO use hash logic as in GUI trainer!
    print("Checking that we have all of the required audio files...")
    for name in _BUGGY_INPUT_BASENAMES:
        if _Path(name).exists():
            raise RuntimeError(
                f"Detected input signal {name} that has known bugs. "
                "Please download the latest input signal, "
                f"{_LATEST_VERSION.name}"
            )
    for input_version, input_basename, other_names in _INPUT_BASENAMES:
        if _Path(input_basename).exists():
            if input_version == _PROTEUS_VERSION:
                print("Using Proteus input file...")
            elif input_version != _LATEST_VERSION.version:
                print(
                    f"WARNING: Using out-of-date input file {input_basename}. "
                    "Recommend downloading and using the latest version, "
                    f"{_LATEST_VERSION.name}."
                )
            break
        if other_names is not None:
            # The canonical basename is missing; see whether the same file was
            # uploaded under one of its alternative names.
            for other_name in other_names:
                if _Path(other_name).exists():
                    raise RuntimeError(
                        f"Found out-of-date input file {other_name}. "
                        f"Rename it to {input_basename} and re-run."
                    )
    else:
        # The loop finished without `break`: no known input file exists.
        raise FileNotFoundError(
            "Didn't find NAM's input audio file. "
            f"Please upload {_LATEST_VERSION.name}"
        )
    # We found it
    if not _Path(_OUTPUT_BASENAME).exists():
        raise FileNotFoundError(
            "Didn't find your reamped output audio file. "
            f"Please upload {_OUTPUT_BASENAME}."
        )
    if input_version != _PROTEUS_VERSION:
        print(f"Found {input_basename}, version {input_version}")
    else:
        print(f"Found Proteus input {input_basename}.")
    return input_version, input_basename


def _get_valid_export_directory() -> _Path:
    """
    Pick the first ``exported_models/version_<N>`` path that does not exist yet.

    ``N`` starts at 0 and counts up. The directory is not created here.

    :return: The first non-existent versioned export path.
    """

    def get_path(version: int) -> _Path:
        return _Path("exported_models", f"version_{version}")

    version = 0
    while get_path(version).exists():
        version += 1
    return get_path(version)


def run(
    epochs: int = 100,
    delay: _Optional[int] = None,
    model_type: str = "WaveNet",
    architecture: str = "standard",
    lr: float = 0.004,
    lr_decay: float = 0.007,
    seed: _Optional[int] = 0,
    user_metadata: _Optional[_UserMetadata] = None,
    ignore_checks: bool = False,
    fit_mrstft: bool = True,
) -> None:
    """
    Check for the required audio files, train a model, and export it.

    :param epochs: How many epochs we'll train for.
    :param delay: How far the output lags the input due to round-trip latency
        during reamping, in samples. Passed to the trainer as ``latency``.
    :param model_type: The type of model to train. Forwarded unchanged to the
        trainer.
    :param architecture: The architecture to use for the model. Forwarded
        unchanged to the trainer.
    :param lr: The initial learning rate
    :param lr_decay: The amount by which the learning rate decays each epoch
    :param seed: RNG seed for reproducibility.
    :param user_metadata: User-specified metadata to include in the .nam file.
    :param ignore_checks: Ignores the data quality checks and YOLOs it
    :param fit_mrstft: Forwarded unchanged to the trainer. Presumably toggles a
        multi-resolution STFT loss term -- confirm against the trainer.
    :raises RuntimeError: If the input audio file check rejects a file.
    :raises FileNotFoundError: If a required audio file is missing.
    """

    input_version, input_basename = _check_for_files()

    train_output: _TrainOutput = _train(
        input_basename,
        _OUTPUT_BASENAME,
        _TRAIN_PATH,
        input_version=input_version,
        epochs=epochs,
        latency=delay,
        model_type=model_type,
        architecture=architecture,
        lr=lr,
        lr_decay=lr_decay,
        seed=seed,
        local=False,
        ignore_checks=ignore_checks,
        fit_mrstft=fit_mrstft,
    )
    model = train_output.model
    training_metadata = train_output.metadata

    if model is None:
        print("No model returned; skip exporting!")
    else:
        print("Exporting your model...")
        model_export_outdir = _get_valid_export_directory()
        # exist_ok=False: fail loudly rather than write into a directory that
        # appeared after the path was picked.
        model_export_outdir.mkdir(parents=True, exist_ok=False)
        model.net.export(
            model_export_outdir,
            user_metadata=user_metadata,
            other_metadata={_TRAINING_KEY: training_metadata.model_dump()},
        )
        print(f"Model exported to {model_export_outdir}. Enjoy!")