neural-amp-modeler

Neural network emulator for guitar amplifiers
Log | Files | Refs | README | LICENSE

colab.py (4889B)


      1 # File: colab.py
      2 # Created Date: Sunday December 4th 2022
      3 # Author: Steven Atkinson (steven@atkinson.mn)
      4 
      5 """
      6 Hide the mess in Colab to make things look pretty for users.
      7 """
      8 
      9 from pathlib import Path as _Path
     10 from typing import Optional as _Optional, Tuple as _Tuple
     11 
     12 from ..models.metadata import UserMetadata as _UserMetadata
     13 from ._names import (
     14     INPUT_BASENAMES as _INPUT_BASENAMES,
     15     LATEST_VERSION as _LATEST_VERSION,
     16     Version as _Version,
     17 )
     18 from ._version import PROTEUS_VERSION as _PROTEUS_VERSION, Version as _Version
     19 from .core import TrainOutput as _TrainOutput, train as _train
     20 from .metadata import TRAINING_KEY as _TRAINING_KEY
     21 
     22 _BUGGY_INPUT_BASENAMES = {
     23     # 1.1.0 has the spikes at the wrong spots.
     24     "v1_1_0.wav"
     25 }
     26 _OUTPUT_BASENAME = "output.wav"
     27 _TRAIN_PATH = "."
     28 
     29 
     30 def _check_for_files() -> _Tuple[_Version, str]:
     31     # TODO use hash logic as in GUI trainer!
     32     print("Checking that we have all of the required audio files...")
     33     for name in _BUGGY_INPUT_BASENAMES:
     34         if _Path(name).exists():
     35             raise RuntimeError(
     36                 f"Detected input signal {name} that has known bugs. Please download the latest input signal, {_LATEST_VERSION[1]}"
     37             )
     38     for input_version, input_basename, other_names in _INPUT_BASENAMES:
     39         if _Path(input_basename).exists():
     40             if input_version == _PROTEUS_VERSION:
     41                 print(f"Using Proteus input file...")
     42             elif input_version != _LATEST_VERSION.version:
     43                 print(
     44                     f"WARNING: Using out-of-date input file {input_basename}. "
     45                     "Recommend downloading and using the latest version, "
     46                     f"{_LATEST_VERSION.name}."
     47                 )
     48             break
     49         if other_names is not None:
     50             for other_name in other_names:
     51                 if _Path(other_name).exists():
     52                     raise RuntimeError(
     53                         f"Found out-of-date input file {other_name}. Rename it to {input_basename} and re-run."
     54                     )
     55     else:
     56         raise FileNotFoundError(
     57             f"Didn't find NAM's input audio file. Please upload {_LATEST_VERSION.name}"
     58         )
     59     # We found it
     60     if not _Path(_OUTPUT_BASENAME).exists():
     61         raise FileNotFoundError(
     62             f"Didn't find your reamped output audio file. Please upload {_OUTPUT_BASENAME}."
     63         )
     64     if input_version != _PROTEUS_VERSION:
     65         print(f"Found {input_basename}, version {input_version}")
     66     else:
     67         print(f"Found Proteus input {input_basename}.")
     68     return input_version, input_basename
     69 
     70 
     71 def _get_valid_export_directory():
     72     def get_path(version):
     73         return _Path("exported_models", f"version_{version}")
     74 
     75     version = 0
     76     while get_path(version).exists():
     77         version += 1
     78     return get_path(version)
     79 
     80 
     81 def run(
     82     epochs: int = 100,
     83     delay: _Optional[int] = None,
     84     model_type: str = "WaveNet",
     85     architecture: str = "standard",
     86     lr: float = 0.004,
     87     lr_decay: float = 0.007,
     88     seed: _Optional[int] = 0,
     89     user_metadata: _Optional[_UserMetadata] = None,
     90     ignore_checks: bool = False,
     91     fit_mrstft: bool = True,
     92 ):
     93     """
     94     :param epochs: How many epochs we'll train for.
     95     :param delay: How far the output algs the input due to round-trip latency during
     96         reamping, in samples.
     97     :param stage_1_channels: The number of channels in the WaveNet's first stage.
     98     :param stage_2_channels: The number of channels in the WaveNet's second stage.
     99     :param lr: The initial learning rate
    100     :param lr_decay: The amount by which the learning rate decays each epoch
    101     :param seed: RNG seed for reproducibility.
    102     :param user_metadata: User-specified metadata to include in the .nam file.
    103     :param ignore_checks: Ignores the data quality checks and YOLOs it
    104     """
    105 
    106     input_version, input_basename = _check_for_files()
    107 
    108     train_output: _TrainOutput = _train(
    109         input_basename,
    110         _OUTPUT_BASENAME,
    111         _TRAIN_PATH,
    112         input_version=input_version,
    113         epochs=epochs,
    114         latency=delay,
    115         model_type=model_type,
    116         architecture=architecture,
    117         lr=lr,
    118         lr_decay=lr_decay,
    119         seed=seed,
    120         local=False,
    121         ignore_checks=ignore_checks,
    122         fit_mrstft=fit_mrstft,
    123     )
    124     model = train_output.model
    125     training_metadata = train_output.metadata
    126 
    127     if model is None:
    128         print("No model returned; skip exporting!")
    129     else:
    130         print("Exporting your model...")
    131         model_export_outdir = _get_valid_export_directory()
    132         model_export_outdir.mkdir(parents=True, exist_ok=False)
    133         model.net.export(
    134             model_export_outdir,
    135             user_metadata=user_metadata,
    136             other_metadata={_TRAINING_KEY: training_metadata.model_dump()},
    137         )
    138         print(f"Model exported to {model_export_outdir}. Enjoy!")