neural-amp-modeler

Neural network emulator for guitar amplifiers
Log | Files | Refs | README | LICENSE

commit 686fd33adb9e033515931178fab3c211df6e6f00
parent 75b6b14808f0bf7b99f8b3659e27e9cd5f334a38
Author: Steven Atkinson <steven@atkinson.mn>
Date:   Sun,  5 Mar 2023 20:01:50 -0600

Fix graceful shutdowns for Lightning in Windows (#118)

* Recover graceful shutdowns, a few other improvements

* Checkpoints

* Recover graceful shutdowns elsewhere for GUI trainer
Diffstat:
Mbin/train/main.py | 77+++++++++++++++++++++++++++++++++++++++++++++++++++++++----------------------
Mnam/__init__.py | 16+++++++++++++++-
Mnam/train/gui.py | 25++++++++++++++++++++++---
3 files changed, 92 insertions(+), 26 deletions(-)

diff --git a/bin/train/main.py b/bin/train/main.py @@ -2,9 +2,22 @@ # Created Date: Saturday February 5th 2022 # Author: Steven Atkinson (steven@atkinson.mn) +# Hack to recover graceful shutdowns in Windows. +# This has to happen ASAP +# See: +# https://github.com/sdatkinson/neural-amp-modeler/issues/105 +# https://stackoverflow.com/a/44822794 +def _ensure_graceful_shutdowns(): + import os + + if os.name == "nt": # OS is Windows + os.environ["FOR_DISABLE_CONSOLE_CTRL_HANDLER"] = "1" + + +_ensure_graceful_shutdowns() + import json from argparse import ArgumentParser -from datetime import datetime from pathlib import Path from time import time from typing import Optional, Union @@ -18,15 +31,11 @@ from torch.utils.data import DataLoader from nam.data import ConcatDataset, ParametricDataset, Split, init_dataset from nam.models import Model +from nam.util import timestamp torch.manual_seed(0) -def timestamp() -> str: - t = datetime.now() - return f"{t.year:04d}-{t.month:02d}-{t.day:02d}-{t.hour:02d}-{t.minute:02d}-{t.second:02d}" - - def ensure_outdir(outdir: str) -> Path: outdir = Path(outdir, timestamp()) outdir.mkdir(parents=True, exist_ok=False) @@ -72,7 +81,7 @@ def plot( return with torch.no_grad(): tx = len(ds.x) / 48_000 - print(f"Run (t={tx})") + print(f"Run (t={tx:.2f})") t0 = time() args = (ds.vals, ds.x) if isinstance(ds, ParametricDataset) else (ds.x,) output = model(*args).flatten().cpu().numpy() @@ -100,12 +109,17 @@ def _create_callbacks(learning_config): """ # Checkpoints should be run every time the validation check is run. # So base it off of learning_config["trainer"]["val_check_interval"] if it's there. - if "val_check_interval" in learning_config["trainer"]: + validate_inside_epoch = "val_check_interval" in learning_config["trainer"] + if validate_inside_epoch: kwargs = { "every_n_train_steps": learning_config["trainer"]["val_check_interval"] } else: - kwargs = {"every_n_epochs": learning_config["trainer"].get("check_val_every_n_epoch", 1)} + kwargs = { + "every_n_epochs": learning_config["trainer"].get( + "check_val_every_n_epoch", 1 + ) + } checkpoint_best = pl.callbacks.model_checkpoint.ModelCheckpoint( filename="{epoch:04d}_{step}_{ESR:.3e}_{MSE:.3e}", @@ -113,10 +127,20 @@ def _create_callbacks(learning_config): monitor="val_loss", **kwargs, ) - checkpoint_last = pl.callbacks.model_checkpoint.ModelCheckpoint( - filename="checkpoint_last_{epoch:04d}_{step}", **kwargs + + # return [checkpoint_best, checkpoint_last] + # The last epoch that was finished. + checkpoint_epoch = pl.callbacks.model_checkpoint.ModelCheckpoint( + filename="checkpoint_epoch_{epoch:04d}", every_n_epochs=1 ) - return [checkpoint_best, checkpoint_last] + if not validate_inside_epoch: + return [checkpoint_best, checkpoint_epoch] + else: + # The last validation pass, whether at the end of an epoch or not + checkpoint_last = pl.callbacks.model_checkpoint.ModelCheckpoint( + filename="checkpoint_last_{epoch:04d}_{step}", **kwargs + ) + return [checkpoint_best, checkpoint_last, checkpoint_epoch] def main(args): @@ -128,6 +152,12 @@ def main(args): model_config = json.load(fp) with open(args.learning_config_path, "r") as fp: learning_config = json.load(fp) + main_inner(data_config, model_config, learning_config, outdir, args.no_show) + + +def main_inner( + data_config, model_config, learning_config, outdir, no_show, make_plots=True +): # Write for basename, config in ( ("data", data_config), @@ -141,7 +171,9 @@ def main(args): # Add receptive field to data config: data_config["common"] = data_config.get("common", {}) if "nx" in data_config["common"]: - warn(f"Overriding data nx={data_config['common']['nx']} with model requried {model.net.receptive_field}") + warn( + f"Overriding data nx={data_config['common']['nx']} with model requried {model.net.receptive_field}" + ) data_config["common"]["nx"] = model.net.receptive_field dataset_train = init_dataset(data_config, Split.TRAIN) @@ -170,15 +202,16 @@ def main(args): **Model.parse_config(model_config), ) model.eval() - plot( - model, - dataset_validation, - savefig=Path(outdir, "comparison.png"), - window_start=100_000, - window_end=110_000, - show=False, - ) - plot(model, dataset_validation, show=not args.no_show) + if make_plots: + plot( + model, + dataset_validation, + savefig=Path(outdir, "comparison.png"), + window_start=100_000, + window_end=110_000, + show=False, + ) + plot(model, dataset_validation, show=not no_show) if __name__ == "__main__": diff --git a/nam/__init__.py b/nam/__init__.py @@ -2,10 +2,24 @@ # File Created: Tuesday, 2nd February 2021 9:42:50 pm # Author: Steven Atkinson (steven@atkinson.mn) +# Hack to recover graceful shutdowns in Windows. +# This has to happen ASAP +# See: +# https://github.com/sdatkinson/neural-amp-modeler/issues/105 +# https://stackoverflow.com/a/44822794 +def _ensure_graceful_shutdowns(): + import os + + if os.name == "nt": # OS is Windows + os.environ["FOR_DISABLE_CONSOLE_CTRL_HANDLER"] = "1" + + +_ensure_graceful_shutdowns() + from ._version import __version__ # Must be before models or else circular from . import _core # noqa F401 from . import data # noqa F401 from . import models # noqa F401 -from . import train # noqa F401 from . import util # noqa F401 +from . import train # noqa F401 diff --git a/nam/train/gui.py b/nam/train/gui.py @@ -6,9 +6,24 @@ GUI for training Usage: ->>> import nam.train.gui +>>> from nam.train.gui import run +>>> run() """ +# Hack to recover graceful shutdowns in Windows. +# This has to happen ASAP +# See: +# https://github.com/sdatkinson/neural-amp-modeler/issues/105 +# https://stackoverflow.com/a/44822794 +def _ensure_graceful_shutdowns(): + import os + + if os.name == "nt": # OS is Windows + os.environ["FOR_DISABLE_CONSOLE_CTRL_HANDLER"] = "1" + + +_ensure_graceful_shutdowns() + import tkinter as tk from dataclasses import dataclass from enum import Enum @@ -17,8 +32,8 @@ from tkinter import filedialog from typing import Callable, Optional, Sequence try: - from .. import __version__ - from . import core + from nam import __version__ + from nam.train import core _install_is_valid = True except ImportError: @@ -455,3 +470,7 @@ def run(): _gui.mainloop() else: _install_error() + + +if __name__ == "__main__": + run()