neural-amp-modeler

Neural network emulator for guitar amplifiers

commit 633330f6bb4739a8ce1d6fb1067aca3a02375612
parent e28321d49aa0f2cb37b5d14fda3e0d5a09ac129a
Author: Steven Atkinson <steven@atkinson.mn>
Date:   Sat, 27 Jul 2024 14:06:00 -0700

[FEATURE,BREAKING] Define full-featured trainer as a CLI entry point (#445)

* feat: make the CLI available without cloning the repo

* fix: remove accidental output folder

* chore: bump version

* fix: revert standalone CLI with pip

* feat: add cli command to "console_scripts"

* Revisions

* remove bin/train/main.py
* Move config files
* Update docs

* Remove bin/train/main.py
Move bin/train/easy_colab.ipynb to colab.ipynb

* Fix bin test

* Some reorganizing

* More rearranging and cleanup of full trainer

---------

Co-authored-by: Eraz1997 <eraz1997@live.it>
Co-authored-by: Enrico Schifano <enrs@bendingspoons.com>
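
For context: the console_scripts entry points this change installs (see the setup.py hunk below) are "nam" for the GUI trainer and "nam-full" for the full-featured trainer. The sketch below is not part of the commit; it shows a roughly equivalent programmatic call into the relocated trainer in nam/train/full.py, using the config files that now live under nam_full_configs/. The output directory name is made up for illustration, and full.main() requires that the directory already exist (the console script in nam/cli.py instead creates a timestamped subdirectory for you).

    # Sketch only: programmatic equivalent of
    #   nam-full data.json model.json learning.json outdir --no-show
    import json
    from pathlib import Path

    from nam.train.full import main as nam_full_main

    def _load_json(path):
        with open(path, "r") as fp:
            return json.load(fp)

    outdir = Path("outputs", "MyAmp")          # hypothetical output location
    outdir.mkdir(parents=True, exist_ok=True)  # full.main() raises if outdir is missing
    nam_full_main(
        _load_json("nam_full_configs/data/single_pair.json"),
        _load_json("nam_full_configs/models/wavenet.json"),
        _load_json("nam_full_configs/learning/default.json"),
        outdir,
        no_show=True,  # like --no-show: skip the interactive matplotlib window
    )
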
Diffstat:
D bin/train/main.py | 235 -------------------------------------------------------------------------------
R bin/train/easy_colab.ipynb -> colab.ipynb | 0
M docs/source/installation.rst | 2 ++
M docs/source/tutorials/colab.rst | 6 +++---
D docs/source/tutorials/command-line.rst | 110 -------------------------------------------------------------------------------
A docs/source/tutorials/full.rst | 89 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
M docs/source/tutorials/main.rst | 4 ++--
A nam/cli.py | 112 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
A nam/train/full.py | 198 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
M nam/train/gui/__init__.py | 66 ------------------------------------------------------------------
R bin/train/inputs/data/single_pair.json -> nam_full_configs/data/single_pair.json | 0
R bin/train/inputs/data/two_pairs.json -> nam_full_configs/data/two_pairs.json | 0
R bin/train/inputs/learning/default.json -> nam_full_configs/learning/default.json | 0
R bin/train/inputs/learning/demo.json -> nam_full_configs/learning/demo.json | 0
R bin/train/inputs/models/convnet.json -> nam_full_configs/models/convnet.json | 0
R bin/train/inputs/models/demonet.json -> nam_full_configs/models/demonet.json | 0
R bin/train/inputs/models/lstm.json -> nam_full_configs/models/lstm.json | 0
R bin/train/inputs/models/wavenet.json -> nam_full_configs/models/wavenet.json | 0
M setup.py | 3 ++-
M tests/test_bin/test_train/test_main.py | 7 +------
A tests/test_nam/test_cli.py | 71 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
M tests/test_nam/test_train/test_gui.py | 60 ------------------------------------------------------------
22 files changed, 480 insertions(+), 483 deletions(-)

diff --git a/bin/train/main.py b/bin/train/main.py @@ -1,235 +0,0 @@ -# File: train.py -# Created Date: Saturday February 5th 2022 -# Author: Steven Atkinson (steven@atkinson.mn) - - -# Hack to recover graceful shutdowns in Windows. -# This has to happen ASAP -# See: -# https://github.com/sdatkinson/neural-amp-modeler/issues/105 -# https://stackoverflow.com/a/44822794 -def _ensure_graceful_shutdowns(): - import os - - if os.name == "nt": # OS is Windows - os.environ["FOR_DISABLE_CONSOLE_CTRL_HANDLER"] = "1" - - -_ensure_graceful_shutdowns() - -import json -from argparse import ArgumentParser -from pathlib import Path -from time import time -from typing import Optional, Union -from warnings import warn - -import matplotlib.pyplot as plt -import numpy as np -import pytorch_lightning as pl -from pytorch_lightning.utilities.warnings import PossibleUserWarning -import torch -from torch.utils.data import DataLoader - -from nam.data import ConcatDataset, Split, init_dataset -from nam.models import Model -from nam.util import filter_warnings, timestamp - -torch.manual_seed(0) - - -def ensure_outdir(outdir: str) -> Path: - outdir = Path(outdir, timestamp()) - outdir.mkdir(parents=True, exist_ok=False) - return outdir - - -def _rms(x: Union[np.ndarray, torch.Tensor]) -> float: - if isinstance(x, np.ndarray): - return np.sqrt(np.mean(np.square(x))) - elif isinstance(x, torch.Tensor): - return torch.sqrt(torch.mean(torch.square(x))).item() - else: - raise TypeError(type(x)) - - -def plot( - model, - ds, - savefig=None, - show=True, - window_start: Optional[int] = None, - window_end: Optional[int] = None, -): - if isinstance(ds, ConcatDataset): - - def extend_savefig(i, savefig): - if savefig is None: - return None - savefig = Path(savefig) - extension = savefig.name.split(".")[-1] - stem = savefig.name[: -len(extension) - 1] - return Path(savefig.parent, f"{stem}_{i}.{extension}") - - for i, ds_i in enumerate(ds.datasets): - plot( - model, - ds_i, - savefig=extend_savefig(i, savefig), - show=show and i == len(ds.datasets) - 1, - window_start=window_start, - window_end=window_end, - ) - return - with torch.no_grad(): - tx = len(ds.x) / 48_000 - print(f"Run (t={tx:.2f})") - t0 = time() - output = model(ds.x).flatten().cpu().numpy() - t1 = time() - try: - rt = f"{tx / (t1 - t0):.2f}" - except ZeroDivisionError as e: - rt = "???" - print(f"Took {t1 - t0:.2f} ({rt}x)") - - plt.figure(figsize=(16, 5)) - plt.plot(output[window_start:window_end], label="Prediction") - plt.plot(ds.y[window_start:window_end], linestyle="--", label="Target") - nrmse = _rms(torch.Tensor(output) - ds.y) / _rms(ds.y) - esr = nrmse**2 - plt.title(f"ESR={esr:.3f}") - plt.legend() - if savefig is not None: - plt.savefig(savefig) - if show: - plt.show() - - -def _create_callbacks(learning_config): - """ - Checkpointing, essentially - """ - # Checkpoints should be run every time the validation check is run. - # So base it off of learning_config["trainer"]["val_check_interval"] if it's there. 
- validate_inside_epoch = "val_check_interval" in learning_config["trainer"] - if validate_inside_epoch: - kwargs = { - "every_n_train_steps": learning_config["trainer"]["val_check_interval"] - } - else: - kwargs = { - "every_n_epochs": learning_config["trainer"].get( - "check_val_every_n_epoch", 1 - ) - } - - checkpoint_best = pl.callbacks.model_checkpoint.ModelCheckpoint( - filename="{epoch:04d}_{step}_{ESR:.3e}_{MSE:.3e}", - save_top_k=3, - monitor="val_loss", - **kwargs, - ) - - # return [checkpoint_best, checkpoint_last] - # The last epoch that was finished. - checkpoint_epoch = pl.callbacks.model_checkpoint.ModelCheckpoint( - filename="checkpoint_epoch_{epoch:04d}", every_n_epochs=1 - ) - if not validate_inside_epoch: - return [checkpoint_best, checkpoint_epoch] - else: - # The last validation pass, whether at the end of an epoch or not - checkpoint_last = pl.callbacks.model_checkpoint.ModelCheckpoint( - filename="checkpoint_last_{epoch:04d}_{step}", **kwargs - ) - return [checkpoint_best, checkpoint_last, checkpoint_epoch] - - -def main(args): - outdir = ensure_outdir(args.outdir) - # Read - with open(args.data_config_path, "r") as fp: - data_config = json.load(fp) - with open(args.model_config_path, "r") as fp: - model_config = json.load(fp) - with open(args.learning_config_path, "r") as fp: - learning_config = json.load(fp) - main_inner(data_config, model_config, learning_config, outdir, args.no_show) - - -def main_inner( - data_config, model_config, learning_config, outdir, no_show, make_plots=True -): - # Write - for basename, config in ( - ("data", data_config), - ("model", model_config), - ("learning", learning_config), - ): - with open(Path(outdir, f"config_{basename}.json"), "w") as fp: - json.dump(config, fp, indent=4) - - model = Model.init_from_config(model_config) - # Add receptive field to data config: - data_config["common"] = data_config.get("common", {}) - if "nx" in data_config["common"]: - warn( - f"Overriding data nx={data_config['common']['nx']} with model requried {model.net.receptive_field}" - ) - data_config["common"]["nx"] = model.net.receptive_field - - dataset_train = init_dataset(data_config, Split.TRAIN) - dataset_validation = init_dataset(data_config, Split.VALIDATION) - if dataset_train.sample_rate != dataset_validation.sample_rate: - raise RuntimeError( - "Train and validation data loaders have different data set sample rates: " - f"{dataset_train.sample_rate}, {dataset_validation.sample_rate}" - ) - model.net.sample_rate = dataset_train.sample_rate - train_dataloader = DataLoader(dataset_train, **learning_config["train_dataloader"]) - val_dataloader = DataLoader(dataset_validation, **learning_config["val_dataloader"]) - - trainer = pl.Trainer( - callbacks=_create_callbacks(learning_config), - default_root_dir=outdir, - **learning_config["trainer"], - ) - with filter_warnings("ignore", category=PossibleUserWarning): - trainer.fit( - model, - train_dataloader, - val_dataloader, - **learning_config.get("trainer_fit_kwargs", {}), - ) - # Go to best checkpoint - best_checkpoint = trainer.checkpoint_callback.best_model_path - if best_checkpoint != "": - model = Model.load_from_checkpoint( - trainer.checkpoint_callback.best_model_path, - **Model.parse_config(model_config), - ) - model.cpu() - model.eval() - if make_plots: - plot( - model, - dataset_validation, - savefig=Path(outdir, "comparison.png"), - window_start=100_000, - window_end=110_000, - show=False, - ) - plot(model, dataset_validation, show=not no_show) - # Export! 
- model.net.export(outdir) - - -if __name__ == "__main__": - parser = ArgumentParser() - parser.add_argument("data_config_path", type=str) - parser.add_argument("model_config_path", type=str) - parser.add_argument("learning_config_path", type=str) - parser.add_argument("outdir") - parser.add_argument("--no-show", action="store_true", help="Don't show plots") - main(parser.parse_args()) diff --git a/bin/train/easy_colab.ipynb b/colab.ipynb diff --git a/docs/source/installation.rst b/docs/source/installation.rst @@ -1,3 +1,5 @@ +.. _installation: + Local Installation ================== diff --git a/docs/source/tutorials/colab.rst b/docs/source/tutorials/colab.rst @@ -3,7 +3,7 @@ Training in the cloud with Google Colab If you don't have a good computer for training ML models, you use Google Colab to train in the cloud using the pre-made Jupyter notebook at -`bin/train/easy_colab.ipynb <https://github.com/sdatkinson/neural-amp-modeler/blob/main/bin/train/easy_colab.ipynb>`_, +`colab.ipynb <https://github.com/sdatkinson/neural-amp-modeler/blob/main/colab.ipynb>`_, which is designed to be used with `Google Colab <https://colab.research.google.com/>`_. @@ -11,10 +11,10 @@ Opening the notebook -------------------- To open the notebook in Colab, follow -`this link <https://colab.research.google.com/github/sdatkinson/neural-amp-modeler/blob/d248a71/bin/train/easy_colab.ipynb>`_. +`this link <https://colab.research.google.com/github/sdatkinson/neural-amp-modeler/blob/d248a71/blob/main/colab.ipynb>`_. .. note:: Most browsers work, but Firefox can be a bit temperamental. This isn't - NAM's fault; Colab just prefers Chrome (unsurprisingly). + NAM's fault; Google Colab just prefers Chrome (unsurprisingly). You'll be met with a screen like this: diff --git a/docs/source/tutorials/command-line.rst b/docs/source/tutorials/command-line.rst @@ -1,110 +0,0 @@ -Training locally from the command line -====================================== - -The command line trainer is the full-featured option for training models with -NAM. - -Installation ------------- - -Currently, you'll want to clone the source repo to train from the command line. - -Installation uses `Anaconda <https://www.anaconda.com/>`_ for package management. - -For computers with a CUDA-capable GPU (recommended): - -.. code-block:: console - - conda env create -f environment_gpu.yml - -.. note:: You may need to modify the CUDA version if your GPU is older. Have a - look at - `nVIDIA's documentation <https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html#cuda-major-component-versions__table-cuda-toolkit-driver-versions>`_ - if you're not sure. - -Otherwise, for a CPU-only install (will train much more slowly): - -.. code-block:: console - - conda env create -f environment_cpu.yml - -.. note:: If Anaconda takes a long time "`Solving environment...`", then you can - speed up installing the environment by using the mamba experimental sovler - with ``--experimental-solver=libmamba``. - -Then activate the environment you've created with - -.. code-block:: console - - conda activate nam - -Training --------- - -Since the command-line trainer is intended for maximum flexibiility, you can -train from any input/output pair of reamp files you want. 
However, if you want -to skip the reamping and use some pre-made files for your first time, you can -download these files: - -* `v1_1_1.wav <https://drive.google.com/file/d/1CMj2uv_x8GIs-3X1reo7squHOVfkOa6s/view?usp=drive_link>`_ - (input) -* `output.wav <https://drive.google.com/file/d/1e0pDzsWgtqBU87NGqa-4FbriDCkccg3q/view?usp=drive_link>`_ - (output) - -Next, edit ``bin/train/data/single_pair.json`` to point to relevant audio files: - -.. code-block:: json - - "common": { - "x_path": "C:\\path\\to\\v1_1_1.wav", - "y_path": "C:\\path\\to\\output.wav", - "delay": 0 - } - -.. note:: If you're providing your own audio files, then you need to provide - the latency (in samples) between the input and output file. A positive - number of samples means that the output lags the input by the provided - number of samples; a negative value means that the output `precedes` the - input (e.g. because your DAW over-compensated). If you're not sure exactly - how much latency there is, it's usually a good idea to add a few samples - just so that the model doesn't need to predict the future! - -Next, to train, open up a terminal. Activate your nam environment and call the -training with - -.. code-block:: console - - python bin/train/main.py \ - bin/train/inputs/data/single_pair.json \ - bin/train/inputs/models/demonet.json \ - bin/train/inputs/learning/demo.json \ - bin/train/outputs/MyAmp - -* ``data/single_pair.json`` contains the information about the data you're - training on. -* ``models/demonet.json`` contains information about the model architecture that - is being trained. The example used here uses a `feather` configured `wavenet`. -* ``learning/demo.json`` contains information about the training run itself - (e.g. number of epochs). - -The configuration above runs a short (demo) training. For a real training you -may prefer to run something like: - -.. code-block:: console - - python bin/train/main.py \ - bin/train/inputs/data/single_pair.json \ - bin/train/inputs/models/wavenet.json \ - bin/train/inputs/learning/default.json \ - bin/train/outputs/MyAmp - -.. note:: NAM uses - `PyTorch Lightning <https://lightning.ai/pages/open-source/>`_ - under the hood as a modeling framework, and you can control many of the - PyTorch Lightning configuration options from - ``bin/train/inputs/learning/default.json``. - -Once training is done, a file called ``model.nam`` is created in the output -directory. To use it, point -`the plugin <https://github.com/sdatkinson/NeuralAmpModelerPlugin>`_ at the file -and you're good to go! diff --git a/docs/source/tutorials/full.rst b/docs/source/tutorials/full.rst @@ -0,0 +1,89 @@ +Training locally with the full-featured NAM +=========================================== + +The command line trainer is the full-featured option for training models with +NAM. To start, you'll want to follow the installation instructions here at +:ref:`installation`. + +After completing this, you will be able to use the full-featured NAM trainer by +typing + +.. code-block:: console + + $ nam-full + +from the command line. + +Training +-------- + +Training uses three configuration files to specify: + +1. What data you're training with: (``nam_full_configs/data/``), +2. What model architecture you're using (``nam_full_configs/models/``), and +3. Details of the learning algorithm model (``nam_full_configs/learning/``). 
+ +To train a model of your own gear, you'll need to have a paired input/output +signal from it (either by reamping a pre-recorded test signal or by +simultaneously recording your DI and the effected tone). For your first time, +you can download the following pre-made files: + +* `v1_1_1.wav <https://drive.google.com/file/d/1CMj2uv_x8GIs-3X1reo7squHOVfkOa6s/view?usp=drive_link>`_ + (input) +* `output.wav <https://drive.google.com/file/d/1e0pDzsWgtqBU87NGqa-4FbriDCkccg3q/view?usp=drive_link>`_ + (output) + +Next, make a file called e.g. ``data.json`` by copying +`nam_full_configs/data/single_pair.json <https://github.com/sdatkinson/neural-amp-modeler/blob/main/nam_full_configs/data/single_pair.json>`_ +and editing it to point to your audio files like this: + +.. code-block:: json + + "common": { + "x_path": "C:\\path\\to\\v1_1_1.wav", + "y_path": "C:\\path\\to\\output.wav", + "delay": 0 + } + +.. note:: If you're providing your own audio files, then you need to provide + the latency (in samples) between the input and output file. A positive + number of samples means that the output lags the input by the provided + number of samples; a negative value means that the output `precedes` the + input (e.g. because your DAW over-compensated). If you're not sure exactly + how much latency there is, it's usually a good idea to add a few samples + just so that the model doesn't need to "predict the future"! + +Next, copy to e.g. ``model.json`` a file for whicever model architecture you want to +use (e.g. +`nam_full_configs/models/wavenet.json <https://github.com/sdatkinson/neural-amp-modeler/blob/main/nam_full_configs/models/wavenet.json>`_ +for the standard WaveNet from the simplified trainers), and copy to e.g. +``learning.json`` the contents of +`nam_full_configs/learning/demo.json <https://github.com/sdatkinson/neural-amp-modeler/blob/main/nam_full_configs/learning/demo.json>`_ +(for a quick demo run) or +`default.json <https://github.com/sdatkinson/neural-amp-modeler/blob/main/nam_full_configs/learning/default.json>`_ +(for something more like a normal use case). + +Next, to train, open up a terminal. Activate your ``nam`` environment and call +the training script with + +.. code-block:: console + + nam-full \ + path/to/data.json \ + path/to/model.json \ + path/to/learning.json \ + path/to/outputs + +where the first three input paths are where you saved for files, and you choose +the final output path to save your training results where you'd like. + +.. note:: NAM uses + `PyTorch Lightning <https://lightning.ai/pages/open-source/>`_ + under the hood as a modeling framework, and you can control many of the + PyTorch Lightning configuration options from + ``nam_full_configs/learning/default.json``. + +Once training is done, a file called ``model.nam`` is created in the output +directory. To use it, point +`the plugin <https://github.com/sdatkinson/NeuralAmpModelerPlugin>`_ at the file +and you're good to go! diff --git a/docs/source/tutorials/main.rst b/docs/source/tutorials/main.rst @@ -6,4 +6,4 @@ Tutorials colab gui - command-line -\ No newline at end of file + full +\ No newline at end of file diff --git a/nam/cli.py b/nam/cli.py @@ -0,0 +1,112 @@ +# File: cli.py +# Created Date: Saturday July 27th 2024 +# Author: Steven Atkinson (steven@atkinson.mn) + +""" +Command line interface entry points (GUI trainer, full trainer) +""" + + +# This must happen first +def _ensure_graceful_shutdowns(): + """ + Hack to recover graceful shutdowns in Windows. 
+ This has to happen ASAP + See: + https://github.com/sdatkinson/neural-amp-modeler/issues/105 + https://stackoverflow.com/a/44822794 + """ + import os + + if os.name == "nt": # OS is Windows + os.environ["FOR_DISABLE_CONSOLE_CTRL_HANDLER"] = "1" + + +_ensure_graceful_shutdowns() + + +# This must happen ASAP but not before the graceful shutdown hack +def _apply_extensions(): + """ + Find and apply extensions to NAM + """ + + def removesuffix(s: str, suffix: str) -> str: + # Remove once 3.8 is dropped + if len(suffix) == 0: + return s + return s[: -len(suffix)] if s.endswith(suffix) else s + + import importlib + import os + import sys + + # DRY: Make sure this matches the test! + home_path = os.environ["HOMEPATH"] if os.name == "nt" else os.environ["HOME"] + extensions_path = os.path.join(home_path, ".neural-amp-modeler", "extensions") + if not os.path.exists(extensions_path): + return + if not os.path.isdir(extensions_path): + print( + f"WARNING: non-directory object found at expected extensions path {extensions_path}; skip" + ) + print("Applying extensions...") + if extensions_path not in sys.path: + sys.path.append(extensions_path) + extensions_path_not_in_sys_path = True + else: + extensions_path_not_in_sys_path = False + for name in os.listdir(extensions_path): + if name in {"__pycache__", ".DS_Store"}: + continue + try: + importlib.import_module(removesuffix(name, ".py")) # Runs it + print(f" {name} [SUCCESS]") + except Exception as e: + print(f" {name} [FAILED]") + print(e) + if extensions_path_not_in_sys_path: + for i, p in enumerate(sys.path): + if p == extensions_path: + sys.path = sys.path[:i] + sys.path[i + 1 :] + break + else: + raise RuntimeError("Failed to remove extensions path from sys.path?") + print("Done!") + + +_apply_extensions() + +import json +from argparse import ArgumentParser +from pathlib import Path + +from nam.train.full import main as _nam_full +from nam.train.gui import run as nam_gui # noqa F401 Used as an entry point +from nam.util import timestamp + + +def nam_full(): + parser = ArgumentParser() + parser.add_argument("data_config_path", type=str) + parser.add_argument("model_config_path", type=str) + parser.add_argument("learning_config_path", type=str) + parser.add_argument("outdir") + parser.add_argument("--no-show", action="store_true", help="Don't show plots") + + args = parser.parse_args() + + def ensure_outdir(outdir: str) -> Path: + outdir = Path(outdir, timestamp()) + outdir.mkdir(parents=True, exist_ok=False) + return outdir + + outdir = ensure_outdir(args.outdir) + # Read + with open(args.data_config_path, "r") as fp: + data_config = json.load(fp) + with open(args.model_config_path, "r") as fp: + model_config = json.load(fp) + with open(args.learning_config_path, "r") as fp: + learning_config = json.load(fp) + _nam_full(data_config, model_config, learning_config, outdir, args.no_show) diff --git a/nam/train/full.py b/nam/train/full.py @@ -0,0 +1,198 @@ +# File: full.py +# Created Date: Tuesday March 26th 2024 +# Author: Enrico Schifano (eraz1997@live.it) + +import json +from pathlib import Path +from time import time +from typing import Optional, Union +from warnings import warn + +import matplotlib.pyplot as plt +import numpy as np +import pytorch_lightning as pl +from pytorch_lightning.utilities.warnings import PossibleUserWarning +import torch +from torch.utils.data import DataLoader + +from nam.data import ConcatDataset, Split, init_dataset +from nam.models import Model +from nam.util import filter_warnings + +torch.manual_seed(0) + + +def 
_rms(x: Union[np.ndarray, torch.Tensor]) -> float: + if isinstance(x, np.ndarray): + return np.sqrt(np.mean(np.square(x))) + elif isinstance(x, torch.Tensor): + return torch.sqrt(torch.mean(torch.square(x))).item() + else: + raise TypeError(type(x)) + + +def _plot( + model, + ds, + savefig=None, + show=True, + window_start: Optional[int] = None, + window_end: Optional[int] = None, +): + if isinstance(ds, ConcatDataset): + + def extend_savefig(i, savefig): + if savefig is None: + return None + savefig = Path(savefig) + extension = savefig.name.split(".")[-1] + stem = savefig.name[: -len(extension) - 1] + return Path(savefig.parent, f"{stem}_{i}.{extension}") + + for i, ds_i in enumerate(ds.datasets): + _plot( + model, + ds_i, + savefig=extend_savefig(i, savefig), + show=show and i == len(ds.datasets) - 1, + window_start=window_start, + window_end=window_end, + ) + return + with torch.no_grad(): + tx = len(ds.x) / 48_000 + print(f"Run (t={tx:.2f})") + t0 = time() + output = model(ds.x).flatten().cpu().numpy() + t1 = time() + try: + rt = f"{tx / (t1 - t0):.2f}" + except ZeroDivisionError as e: + rt = "???" + print(f"Took {t1 - t0:.2f} ({rt}x)") + + plt.figure(figsize=(16, 5)) + plt.plot(output[window_start:window_end], label="Prediction") + plt.plot(ds.y[window_start:window_end], linestyle="--", label="Target") + nrmse = _rms(torch.Tensor(output) - ds.y) / _rms(ds.y) + esr = nrmse**2 + plt.title(f"ESR={esr:.3f}") + plt.legend() + if savefig is not None: + plt.savefig(savefig) + if show: + plt.show() + + +def _create_callbacks(learning_config): + """ + Checkpointing, essentially + """ + # Checkpoints should be run every time the validation check is run. + # So base it off of learning_config["trainer"]["val_check_interval"] if it's there. + validate_inside_epoch = "val_check_interval" in learning_config["trainer"] + if validate_inside_epoch: + kwargs = { + "every_n_train_steps": learning_config["trainer"]["val_check_interval"] + } + else: + kwargs = { + "every_n_epochs": learning_config["trainer"].get( + "check_val_every_n_epoch", 1 + ) + } + + checkpoint_best = pl.callbacks.model_checkpoint.ModelCheckpoint( + filename="{epoch:04d}_{step}_{ESR:.3e}_{MSE:.3e}", + save_top_k=3, + monitor="val_loss", + **kwargs, + ) + + # return [checkpoint_best, checkpoint_last] + # The last epoch that was finished. 
+ checkpoint_epoch = pl.callbacks.model_checkpoint.ModelCheckpoint( + filename="checkpoint_epoch_{epoch:04d}", every_n_epochs=1 + ) + if not validate_inside_epoch: + return [checkpoint_best, checkpoint_epoch] + else: + # The last validation pass, whether at the end of an epoch or not + checkpoint_last = pl.callbacks.model_checkpoint.ModelCheckpoint( + filename="checkpoint_last_{epoch:04d}_{step}", **kwargs + ) + return [checkpoint_best, checkpoint_last, checkpoint_epoch] + + +def main( + data_config, + model_config, + learning_config, + outdir: Path, + no_show: bool = False, + make_plots=True, +): + if not outdir.exists(): + raise RuntimeError(f"No output location found at {outdir}") + # Write + for basename, config in ( + ("data", data_config), + ("model", model_config), + ("learning", learning_config), + ): + with open(Path(outdir, f"config_{basename}.json"), "w") as fp: + json.dump(config, fp, indent=4) + + model = Model.init_from_config(model_config) + # Add receptive field to data config: + data_config["common"] = data_config.get("common", {}) + if "nx" in data_config["common"]: + warn( + f"Overriding data nx={data_config['common']['nx']} with model requried {model.net.receptive_field}" + ) + data_config["common"]["nx"] = model.net.receptive_field + + dataset_train = init_dataset(data_config, Split.TRAIN) + dataset_validation = init_dataset(data_config, Split.VALIDATION) + if dataset_train.sample_rate != dataset_validation.sample_rate: + raise RuntimeError( + "Train and validation data loaders have different data set sample rates: " + f"{dataset_train.sample_rate}, {dataset_validation.sample_rate}" + ) + model.net.sample_rate = dataset_train.sample_rate + train_dataloader = DataLoader(dataset_train, **learning_config["train_dataloader"]) + val_dataloader = DataLoader(dataset_validation, **learning_config["val_dataloader"]) + + trainer = pl.Trainer( + callbacks=_create_callbacks(learning_config), + default_root_dir=outdir, + **learning_config["trainer"], + ) + with filter_warnings("ignore", category=PossibleUserWarning): + trainer.fit( + model, + train_dataloader, + val_dataloader, + **learning_config.get("trainer_fit_kwargs", {}), + ) + # Go to best checkpoint + best_checkpoint = trainer.checkpoint_callback.best_model_path + if best_checkpoint != "": + model = Model.load_from_checkpoint( + trainer.checkpoint_callback.best_model_path, + **Model.parse_config(model_config), + ) + model.cpu() + model.eval() + if make_plots: + _plot( + model, + dataset_validation, + savefig=Path(outdir, "comparison.png"), + window_start=100_000, + window_end=110_000, + show=False, + ) + _plot(model, dataset_validation, show=not no_show) + # Export! + model.net.export(outdir) diff --git a/nam/train/gui/__init__.py b/nam/train/gui/__init__.py @@ -10,72 +10,6 @@ Usage: >>> run() """ - -# Hack to recover graceful shutdowns in Windows. -# This has to happen ASAP -# See: -# https://github.com/sdatkinson/neural-amp-modeler/issues/105 -# https://stackoverflow.com/a/44822794 -def _ensure_graceful_shutdowns(): - import os - - if os.name == "nt": # OS is Windows - os.environ["FOR_DISABLE_CONSOLE_CTRL_HANDLER"] = "1" - - -_ensure_graceful_shutdowns() - - -def _apply_extensions(): - def removesuffix(s: str, suffix: str) -> str: - # Remove once 3.8 is dropped - if len(suffix) == 0: - return s - return s[: -len(suffix)] if s.endswith(suffix) else s - - import importlib - import os - import sys - - # DRY: Make sure this matches the test! 
- home_path = os.environ["HOMEPATH"] if os.name == "nt" else os.environ["HOME"] - extensions_path = os.path.join( - home_path, ".neural-amp-modeler", "extensions" - ) - if not os.path.exists(extensions_path): - return - if not os.path.isdir(extensions_path): - print( - f"WARNING: non-directory object found at expected extensions path {extensions_path}; skip" - ) - print("Applying extensions...") - if extensions_path not in sys.path: - sys.path.append(extensions_path) - extensions_path_not_in_sys_path = True - else: - extensions_path_not_in_sys_path = False - for name in os.listdir(extensions_path): - if name in {"__pycache__", ".DS_Store"}: - continue - try: - importlib.import_module(removesuffix(name, ".py")) # Runs it - print(f" {name} [SUCCESS]") - except Exception as e: - print(f" {name} [FAILED]") - print(e) - if extensions_path_not_in_sys_path: - for i, p in enumerate(sys.path): - if p == extensions_path: - sys.path = sys.path[:i] + sys.path[i + 1 :] - break - else: - raise RuntimeError("Failed to remove extensions path from sys.path?") - print("Done!") - - -_apply_extensions() - - import re import tkinter as tk import sys diff --git a/bin/train/inputs/data/single_pair.json b/nam_full_configs/data/single_pair.json diff --git a/bin/train/inputs/data/two_pairs.json b/nam_full_configs/data/two_pairs.json diff --git a/bin/train/inputs/learning/default.json b/nam_full_configs/learning/default.json diff --git a/bin/train/inputs/learning/demo.json b/nam_full_configs/learning/demo.json diff --git a/bin/train/inputs/models/convnet.json b/nam_full_configs/models/convnet.json diff --git a/bin/train/inputs/models/demonet.json b/nam_full_configs/models/demonet.json diff --git a/bin/train/inputs/models/lstm.json b/nam_full_configs/models/lstm.json diff --git a/bin/train/inputs/models/wavenet.json b/nam_full_configs/models/wavenet.json diff --git a/setup.py b/setup.py @@ -54,7 +54,8 @@ setup( include_package_data=True, entry_points={ "console_scripts": [ - "nam = nam.train.gui:run", + "nam = nam.cli:nam_gui", # GUI trainer + "nam-full = nam.cli:nam_full", # Full-featured trainer ] }, ) diff --git a/tests/test_bin/test_train/test_main.py b/tests/test_bin/test_train/test_main.py @@ -15,10 +15,6 @@ import torch from nam.data import np_to_wav -_BIN_TRAIN_MAIN_PY_PATH = Path(__file__).absolute().parent.parent.parent.parent / Path( - "bin", "train", "main.py" -) - class _Device(Enum): CPU = "cpu" @@ -181,8 +177,7 @@ class Test(object): self._setup_files(tempdir, device) check_call( [ - "python", - str(_BIN_TRAIN_MAIN_PY_PATH), + "nam-full", # HACK not DRY w/ setup.py str(self._data_config_path(tempdir)), str(self._model_config_path(tempdir)), str(self._learning_config_path(tempdir)), diff --git a/tests/test_nam/test_cli.py b/tests/test_nam/test_cli.py @@ -0,0 +1,71 @@ +# File: test_cli.py +# Created Date: Saturday July 27th 2024 +# Author: Steven Atkinson (steven@atkinson.mn) + +import importlib +import os +from pathlib import Path + +import pytest + + +def test_extensions(): + """ + Test that we can use a simple extension. + """ + # DRY: Make sure this matches the code! 
+ home_path = os.environ["HOMEPATH"] if os.name == "nt" else os.environ["HOME"] + extensions_path = os.path.join(home_path, ".neural-amp-modeler", "extensions") + + def get_name(): + i = 0 + while True: + basename = f"test_extension_{i}.py" + path = Path(extensions_path, basename) + if not path.exists(): + return path + else: + i += 1 + + path = get_name() + path.parent.mkdir(parents=True, exist_ok=True) + + try: + # Make the extension + # It's going to set an attribute inside nam.core. We'll know the extension worked if + # that attr is set. + attr_name = "my_test_attr" + attr_val = "THIS IS A TEST ATTRIBUTE I SHOULDN'T BE HERE" + with open(path, "w") as f: + f.writelines( + [ + 'print("RUNNING TEST!")\n', + "from nam.train import core\n", + f'name = "{attr_name}"\n', + "assert not hasattr(core, name)\n" + f'setattr(core, name, "{attr_val}")\n', + ] + ) + + # Now trigger the extension by importing the CLI module: + from nam import cli + + # If some other test already imported this, then we need to trigger a re-load or + # else the extension won't get picked up! + importlib.reload(cli) + + # Now let's have a look: + from nam.train import core + + assert hasattr(core, attr_name) + assert getattr(core, attr_name) == attr_val + finally: + if path.exists(): + path.unlink() + # You might want to comment that .unlink() and uncomment this if this test isn't + # passing and you're struggling: + # pass + + +if __name__ == "__main__": + pytest.main() diff --git a/tests/test_nam/test_train/test_gui.py b/tests/test_nam/test_train/test_gui.py @@ -22,65 +22,5 @@ class TestPathButton(object): label.pack() -def test_extensions(): - """ - Test that we can use a simple extension. - """ - # DRY: Make sure this matches the code! - home_path = os.environ["HOMEPATH"] if os.name == "nt" else os.environ["HOME"] - extensions_path = os.path.join( - home_path, ".neural-amp-modeler", "extensions" - ) - - def get_name(): - i = 0 - while True: - basename = f"test_extension_{i}.py" - path = Path(extensions_path, basename) - if not path.exists(): - return path - else: - i += 1 - - path = get_name() - path.parent.mkdir(parents=True, exist_ok=True) - - try: - # Make the extension - # It's going to set an attribute inside nam.core. We'll know the extension worked if - # that attr is set. - attr_name = "my_test_attr" - attr_val = "THIS IS A TEST ATTRIBUTE I SHOULDN'T BE HERE" - with open(path, "w") as f: - f.writelines( - [ - 'print("RUNNING TEST!")\n', - "from nam.train import core\n", - f'name = "{attr_name}"\n', - "assert not hasattr(core, name)\n" - f'setattr(core, name, "{attr_val}")\n', - ] - ) - - # Now trigger the extension by importing from the GUI module: - from nam.train import gui # noqa F401 - - # If some other test already imported this, then we need to trigger a re-load or - # else the extension won't get picked up! - importlib.reload(gui) - - # Now let's have a look: - from nam.train import core - - assert hasattr(core, attr_name) - assert getattr(core, attr_name) == attr_val - finally: - if path.exists(): - path.unlink() - # You might want to comment that .unlink() and uncomment this if this test isn't - # passing and you're struggling: - # pass - - if __name__ == "__main__": pytest.main()
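
A closing note on the extension hook that this commit moves from nam/train/gui/__init__.py into nam/cli.py: at import time, every module found in ~/.neural-amp-modeler/extensions (HOMEPATH on Windows, HOME otherwise) is imported, which is exactly what the new tests/test_nam/test_cli.py exercises. A minimal sketch of such an extension follows; the filename and the attribute it sets are made up for illustration.

    # ~/.neural-amp-modeler/extensions/my_extension.py  (hypothetical filename)
    # Imported automatically by nam.cli at startup, so module-level code runs
    # before either trainer starts; the test above uses the same mechanism to
    # patch an attribute onto nam.train.core.
    from nam.train import core

    print("my_extension loaded")
    core.my_extension_marker = True  # illustrative marker, mirroring the test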