commit 686fd33adb9e033515931178fab3c211df6e6f00
parent 75b6b14808f0bf7b99f8b3659e27e9cd5f334a38
Author: Steven Atkinson <steven@atkinson.mn>
Date: Sun, 5 Mar 2023 20:01:50 -0600
Fix graceful shutdowns for Lightning in Windows (#118)
* Recover graceful shutdowns, a few other improvements
* Checkpoints
* Recover graceful shutdowns elsewhere for GUI trainer
Diffstat:
3 files changed, 92 insertions(+), 26 deletions(-)
diff --git a/bin/train/main.py b/bin/train/main.py
@@ -2,9 +2,22 @@
# Created Date: Saturday February 5th 2022
# Author: Steven Atkinson (steven@atkinson.mn)
+# Hack to recover graceful shutdowns in Windows.
+# This has to happen ASAP
+# See:
+# https://github.com/sdatkinson/neural-amp-modeler/issues/105
+# https://stackoverflow.com/a/44822794
+def _ensure_graceful_shutdowns():
+ import os
+
+ if os.name == "nt": # OS is Windows
+ os.environ["FOR_DISABLE_CONSOLE_CTRL_HANDLER"] = "1"
+
+
+_ensure_graceful_shutdowns()
+
import json
from argparse import ArgumentParser
-from datetime import datetime
from pathlib import Path
from time import time
from typing import Optional, Union
@@ -18,15 +31,11 @@ from torch.utils.data import DataLoader
from nam.data import ConcatDataset, ParametricDataset, Split, init_dataset
from nam.models import Model
+from nam.util import timestamp
torch.manual_seed(0)
-def timestamp() -> str:
- t = datetime.now()
- return f"{t.year:04d}-{t.month:02d}-{t.day:02d}-{t.hour:02d}-{t.minute:02d}-{t.second:02d}"
-
-
def ensure_outdir(outdir: str) -> Path:
outdir = Path(outdir, timestamp())
outdir.mkdir(parents=True, exist_ok=False)
@@ -72,7 +81,7 @@ def plot(
return
with torch.no_grad():
tx = len(ds.x) / 48_000
- print(f"Run (t={tx})")
+ print(f"Run (t={tx:.2f})")
t0 = time()
args = (ds.vals, ds.x) if isinstance(ds, ParametricDataset) else (ds.x,)
output = model(*args).flatten().cpu().numpy()
@@ -100,12 +109,17 @@ def _create_callbacks(learning_config):
"""
# Checkpoints should be run every time the validation check is run.
# So base it off of learning_config["trainer"]["val_check_interval"] if it's there.
- if "val_check_interval" in learning_config["trainer"]:
+ validate_inside_epoch = "val_check_interval" in learning_config["trainer"]
+ if validate_inside_epoch:
kwargs = {
"every_n_train_steps": learning_config["trainer"]["val_check_interval"]
}
else:
- kwargs = {"every_n_epochs": learning_config["trainer"].get("check_val_every_n_epoch", 1)}
+ kwargs = {
+ "every_n_epochs": learning_config["trainer"].get(
+ "check_val_every_n_epoch", 1
+ )
+ }
checkpoint_best = pl.callbacks.model_checkpoint.ModelCheckpoint(
filename="{epoch:04d}_{step}_{ESR:.3e}_{MSE:.3e}",
@@ -113,10 +127,20 @@ def _create_callbacks(learning_config):
monitor="val_loss",
**kwargs,
)
- checkpoint_last = pl.callbacks.model_checkpoint.ModelCheckpoint(
- filename="checkpoint_last_{epoch:04d}_{step}", **kwargs
+
+ # return [checkpoint_best, checkpoint_last]
+ # The last epoch that was finished.
+ checkpoint_epoch = pl.callbacks.model_checkpoint.ModelCheckpoint(
+ filename="checkpoint_epoch_{epoch:04d}", every_n_epochs=1
)
- return [checkpoint_best, checkpoint_last]
+ if not validate_inside_epoch:
+ return [checkpoint_best, checkpoint_epoch]
+ else:
+ # The last validation pass, whether at the end of an epoch or not
+ checkpoint_last = pl.callbacks.model_checkpoint.ModelCheckpoint(
+ filename="checkpoint_last_{epoch:04d}_{step}", **kwargs
+ )
+ return [checkpoint_best, checkpoint_last, checkpoint_epoch]
def main(args):
@@ -128,6 +152,12 @@ def main(args):
model_config = json.load(fp)
with open(args.learning_config_path, "r") as fp:
learning_config = json.load(fp)
+ main_inner(data_config, model_config, learning_config, outdir, args.no_show)
+
+
+def main_inner(
+ data_config, model_config, learning_config, outdir, no_show, make_plots=True
+):
# Write
for basename, config in (
("data", data_config),
@@ -141,7 +171,9 @@ def main(args):
# Add receptive field to data config:
data_config["common"] = data_config.get("common", {})
if "nx" in data_config["common"]:
- warn(f"Overriding data nx={data_config['common']['nx']} with model requried {model.net.receptive_field}")
+ warn(
+ f"Overriding data nx={data_config['common']['nx']} with model required {model.net.receptive_field}"
+ )
data_config["common"]["nx"] = model.net.receptive_field
dataset_train = init_dataset(data_config, Split.TRAIN)
@@ -170,15 +202,16 @@ def main(args):
**Model.parse_config(model_config),
)
model.eval()
- plot(
- model,
- dataset_validation,
- savefig=Path(outdir, "comparison.png"),
- window_start=100_000,
- window_end=110_000,
- show=False,
- )
- plot(model, dataset_validation, show=not args.no_show)
+ if make_plots:
+ plot(
+ model,
+ dataset_validation,
+ savefig=Path(outdir, "comparison.png"),
+ window_start=100_000,
+ window_end=110_000,
+ show=False,
+ )
+ plot(model, dataset_validation, show=not no_show)
if __name__ == "__main__":
diff --git a/nam/__init__.py b/nam/__init__.py
@@ -2,10 +2,24 @@
# File Created: Tuesday, 2nd February 2021 9:42:50 pm
# Author: Steven Atkinson (steven@atkinson.mn)
+# Hack to recover graceful shutdowns in Windows.
+# This has to happen ASAP
+# See:
+# https://github.com/sdatkinson/neural-amp-modeler/issues/105
+# https://stackoverflow.com/a/44822794
+def _ensure_graceful_shutdowns():
+ import os
+
+ if os.name == "nt": # OS is Windows
+ os.environ["FOR_DISABLE_CONSOLE_CTRL_HANDLER"] = "1"
+
+
+_ensure_graceful_shutdowns()
+
from ._version import __version__ # Must be before models or else circular
from . import _core # noqa F401
from . import data # noqa F401
from . import models # noqa F401
-from . import train # noqa F401
from . import util # noqa F401
+from . import train # noqa F401
diff --git a/nam/train/gui.py b/nam/train/gui.py
@@ -6,9 +6,24 @@
GUI for training
Usage:
->>> import nam.train.gui
+>>> from nam.train.gui import run
+>>> run()
"""
+# Hack to recover graceful shutdowns in Windows.
+# This has to happen ASAP
+# See:
+# https://github.com/sdatkinson/neural-amp-modeler/issues/105
+# https://stackoverflow.com/a/44822794
+def _ensure_graceful_shutdowns():
+ import os
+
+ if os.name == "nt": # OS is Windows
+ os.environ["FOR_DISABLE_CONSOLE_CTRL_HANDLER"] = "1"
+
+
+_ensure_graceful_shutdowns()
+
import tkinter as tk
from dataclasses import dataclass
from enum import Enum
@@ -17,8 +32,8 @@ from tkinter import filedialog
from typing import Callable, Optional, Sequence
try:
- from .. import __version__
- from . import core
+ from nam import __version__
+ from nam.train import core
_install_is_valid = True
except ImportError:
@@ -455,3 +470,7 @@ def run():
_gui.mainloop()
else:
_install_error()
+
+
+if __name__ == "__main__":
+ run()