commit 0ee6fd6c3a0c918035156dc9a6c54bee8d9470bb
parent c241c3e9f70f94047f9c4b98d00d94ea1f6d1721
Author: Steven Atkinson <steven@atkinson.mn>
Date: Sat, 13 Apr 2024 18:05:39 -0700
[FEATURE] GUI: Resume training from checkpoint (#402)
* Add special error for loading with incompatible checkpoint
* Support training from checkpoint
* Error modal when training fails because of checkpoint
* Remove straggling mainloop comment
Diffstat:
3 files changed, 229 insertions(+), 68 deletions(-)
diff --git a/nam/train/_errors.py b/nam/train/_errors.py
@@ -0,0 +1,18 @@
+# File: _errors.py
+# Created Date: Saturday April 13th 2024
+# Author: Steven Atkinson (steven@atkinson.mn)
+
+"""
+"What could go wrong?"
+"""
+
+__all__ = ["IncompatibleCheckpointError"]
+
+
+class IncompatibleCheckpointError(RuntimeError):
+    """
+    Signals that loading a model failed because the checkpoint does not fit the
+    model or its hyperparameters.
+
+    Derives from RuntimeError, so handlers for RuntimeError also catch it.
+    """
diff --git a/nam/train/core.py b/nam/train/core.py
@@ -11,6 +11,7 @@ import tkinter as tk
from copy import deepcopy
from enum import Enum
from functools import partial
+from pathlib import Path
from time import time
from typing import Dict, Optional, Sequence, Tuple, Union
@@ -26,6 +27,7 @@ from ..data import Split, init_dataset, wav_to_np, wav_to_tensor
from ..models import Model
from ..models.losses import esr
from ..util import filter_warnings
+from ._errors import IncompatibleCheckpointError
from ._version import PROTEUS_VERSION, Version
__all__ = ["train"]
@@ -868,6 +870,7 @@ def _get_configs(
lr_decay: float,
batch_size: int,
fit_cab: bool,
+ checkpoint: Optional[Path] = None,
):
def get_kwargs(data_info: _DataInfo):
if data_info.major_version == 1:
@@ -957,6 +960,8 @@ def _get_configs(
if fit_cab:
model_config["loss"]["pre_emph_mrstft_weight"] = _CAB_MRSTFT_PRE_EMPH_WEIGHT
model_config["loss"]["pre_emph_mrstft_coef"] = _CAB_MRSTFT_PRE_EMPH_COEF
+ if checkpoint:
+ model_config["checkpoint_path"] = checkpoint
if torch.cuda.is_available():
device_config = {"accelerator": "gpu", "devices": 1}
@@ -1130,6 +1135,7 @@ def train(
local: bool = False,
fit_cab: bool = False,
threshold_esr: Optional[bool] = None,
+ checkpoint: Optional[Path] = None,
) -> Optional[Model]:
"""
:param threshold_esr: Stop training if ESR is better than this. Ignore if `None`.
@@ -1178,6 +1184,7 @@ def train(
lr_decay,
batch_size,
fit_cab,
+ checkpoint=checkpoint,
)
print("Starting training. It's time to kick ass and chew bubblegum!")
@@ -1186,7 +1193,16 @@ def train(
# * Model is re-instantiated after training anyways.
# (Hacky) solution: set sample rate in model from dataloader after second
# instantiation from final checkpoint.
- model = Model.init_from_config(model_config)
+    # Translate a state_dict mismatch into a dedicated error type that callers
+    # can catch separately.
+    try:
+        model = Model.init_from_config(model_config)
+    except RuntimeError as e:
+        # NOTE(review): this substring looks like the message emitted when loading
+        # a state_dict fails -- confirm it stays stable across library versions.
+        if "Error(s) in loading state_dict for Model:" in str(e):
+            # Chain explicitly so the original traceback is kept as __cause__.
+            raise IncompatibleCheckpointError(
+                "Model initialization failed; the checkpoint used seems to be "
+                f"incompatible.\n\nOriginal error:\n\n{e}"
+            ) from e
+        # Any other RuntimeError: re-raise untouched, traceback intact.
+        raise
train_dataloader, val_dataloader = _get_dataloaders(
data_config, learning_config, model
)
@@ -1204,7 +1220,8 @@ def train(
)
# Suppress the PossibleUserWarning about num_workers (Issue 345)
with filter_warnings("ignore", category=PossibleUserWarning):
- trainer.fit(model, train_dataloader, val_dataloader)
+ trainer_fit_kwargs = {} if checkpoint is None else {"ckpt_path": checkpoint}
+ trainer.fit(model, train_dataloader, val_dataloader, **trainer_fit_kwargs)
# Go to best checkpoint
best_checkpoint = trainer.checkpoint_callback.best_model_path
diff --git a/nam/train/gui.py b/nam/train/gui.py
@@ -43,6 +43,7 @@ try: # 3rd-party and 1st-party imports
from nam.models.metadata import GearType, UserMetadata, ToneType
# Ok private access here--this is technically allowed access
+ from nam.train._errors import IncompatibleCheckpointError
from nam.train._names import INPUT_BASENAMES, LATEST_VERSION
_install_is_valid = True
@@ -66,6 +67,7 @@ _TEXT_WIDTH = 70
_DEFAULT_DELAY = None
_DEFAULT_IGNORE_CHECKS = False
_DEFAULT_THRESHOLD_ESR = None
+_DEFAULT_CHECKPOINT = None
_ADVANCED_OPTIONS_LEFT_WIDTH = 12
_ADVANCED_OPTIONS_RIGHT_WIDTH = 12
@@ -82,6 +84,7 @@ class _AdvancedOptions(object):
:param ignore_checks: Keep going even if a check says that something is wrong.
:param threshold_esr: Stop training if the ESR gets better than this. If None, don't
stop.
+ :param checkpoint: If provided, try to restart from this checkpoint.
"""
architecture: core.Architecture
@@ -89,6 +92,7 @@ class _AdvancedOptions(object):
latency: Optional[int]
ignore_checks: bool
threshold_esr: Optional[float]
+ checkpoint: Optional[Path]
class _PathType(Enum):
@@ -109,13 +113,19 @@ class _PathButton(object):
info_str: str,
path_type: _PathType,
hooks: Optional[Sequence[Callable[[], None]]] = None,
+ color_when_not_set: str = "#EF0000", # Darker red
+ default: Optional[Path] = None,
):
+ """
+ :param hooks: Callables run at the end of setting the value.
+ """
self._button_text = button_text
self._info_str = info_str
- self._path: Optional[Path] = None
+ self._path: Optional[Path] = default
self._path_type = path_type
self._frame = frame
- self._button = tk.Button(
+ self._widgets = {}
+ self._widgets["button"] = tk.Button(
self._frame,
text=button_text,
width=_BUTTON_WIDTH,
@@ -123,8 +133,8 @@ class _PathButton(object):
fg="black",
command=self._set_val,
)
- self._button.pack(side=tk.LEFT)
- self._label = tk.Label(
+ self._widgets["button"].pack(side=tk.LEFT)
+ self._widgets["label"] = tk.Label(
self._frame,
width=_TEXT_WIDTH,
height=_BUTTON_HEIGHT,
@@ -132,23 +142,38 @@ class _PathButton(object):
bg=None,
anchor="w",
)
- self._label.pack(side=tk.LEFT)
+ self._widgets["label"].pack(side=tk.LEFT)
self._hooks = hooks
+ self._color_when_not_set = color_when_not_set
self._set_text()
+    def __setitem__(self, key, val):
+        """
+        Tk-style item assignment. Only the "state" key is accepted; its value is
+        forwarded to every widget owned by this object.
+
+        :raises RuntimeError: For any key other than "state".
+        """
+        if key != "state":
+            raise RuntimeError(
+                f"{self.__class__.__name__} instance does not support item assignment for non-state key {key}!"
+            )
+        for owned_widget in self._widgets.values():
+            owned_widget["state"] = val
+
@property
def val(self) -> Optional[Path]:
return self._path
def _set_text(self):
if self._path is None:
- self._label["fg"] = "#EF0000" # Darker red
- self._label["text"] = self._info_str
+ self._widgets["label"]["fg"] = self._color_when_not_set
+ self._widgets["label"]["text"] = self._info_str
else:
val = self.val
val = val[0] if isinstance(val, tuple) and len(val) == 1 else val
- self._label["fg"] = "black"
- self._label["text"] = f"{self._button_text.capitalize()} set to {val}"
+ self._widgets["label"]["fg"] = "black"
+ self._widgets["label"][
+ "text"
+ ] = f"{self._button_text.capitalize()} set to {val}"
def _set_val(self):
res = {
@@ -169,7 +194,7 @@ class _InputPathButton(_PathButton):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Download the training file!
- self._button_download_input = tk.Button(
+ self._widgets["button_download_input"] = tk.Button(
self._frame,
text="Download input file",
width=_BUTTON_WIDTH,
@@ -177,7 +202,7 @@ class _InputPathButton(_PathButton):
fg="black",
command=self._download_input_file,
)
- self._button_download_input.pack(side=tk.RIGHT)
+ self._widgets["button_download_input"].pack(side=tk.RIGHT)
@classmethod
def _download_input_file(cls):
@@ -201,6 +226,29 @@ class _InputPathButton(_PathButton):
return
+class _ClearablePathButton(_PathButton):
+    """
+    A path button with an additional "Clear" button that un-sets the path.
+
+    Takes the same arguments as `_PathButton`. Unless the caller specifies
+    `color_when_not_set`, the "not set" label is drawn in black.
+    """
+
+    def __init__(self, *args, **kwargs):
+        # setdefault (instead of a hard-coded keyword) so that a caller-supplied
+        # color_when_not_set does not collide with ours and raise a TypeError.
+        kwargs.setdefault("color_when_not_set", "black")
+        super().__init__(*args, **kwargs)
+        # Button that resets the selection.
+        self._widgets["button_clear"] = tk.Button(
+            self._frame,
+            text="Clear",
+            width=_BUTTON_WIDTH,
+            height=_BUTTON_HEIGHT,
+            fg="black",
+            command=self._clear_path,
+        )
+        self._widgets["button_clear"].pack(side=tk.RIGHT)
+
+    def _clear_path(self):
+        """
+        Forget the selected path and refresh the label.
+        """
+        # NOTE(review): hooks are not run here -- confirm that is intended.
+        self._path = None
+        self._set_text()
+
+
class _CheckboxKeys(Enum):
"""
Keys for checkboxes
@@ -212,15 +260,50 @@ class _CheckboxKeys(Enum):
IGNORE_DATA_CHECKS = "ignore_data_checks"
+class _BasicModal(object):
+    """
+    Pop-up window with a message and an "Ok" button.
+
+    :param resume_main: Zero-argument callable invoked once this window is closed.
+    :param msg: Text to display.
+    """
+
+    def __init__(self, resume_main, msg: str):
+        self._resume_main = resume_main
+        self._root = tk.Toplevel()
+        # Closing through the window manager must go through _close as well;
+        # otherwise resume_main would never be called.
+        self._root.protocol("WM_DELETE_WINDOW", self._close)
+        self._text = tk.Label(self._root, text=msg)
+        self._text.pack()
+        self._ok = tk.Button(
+            self._root,
+            text="Ok",
+            width=_BUTTON_WIDTH,
+            height=_BUTTON_HEIGHT,
+            fg="black",
+            command=self._close,
+        )
+        self._ok.pack()
+
+    def _close(self):
+        """
+        Destroy the window, then hand control back through resume_main.
+        """
+        self._root.destroy()
+        self._resume_main()
+
+
+class _GUIWidgets(Enum):
+ """
+ Keys under which _GUI registers its main widgets in its widget dictionary.
+ """
+
+ # NOTE(review): _GUI stores the metadata button under the plain string
+ # "metadata" rather than METADATA -- confirm whether that is intended.
+ INPUT_PATH = "input_path"
+ OUTPUT_PATH = "output_path"
+ TRAINING_DESTINATION = "training_destination"
+ METADATA = "metadata"
+ ADVANCED_OPTIONS = "advanced_options"
+ TRAIN = "train"
+
+
class _GUI(object):
def __init__(self):
self._root = tk.Tk()
self._root.title(f"NAM Trainer - v{__version__}")
+ self._widgets = {}
# Buttons for paths:
self._frame_input = tk.Frame(self._root)
self._frame_input.pack(anchor="w")
- self._path_button_input = _InputPathButton(
+ self._widgets[_GUIWidgets.INPUT_PATH] = _InputPathButton(
self._frame_input,
"Input Audio",
f"Select input (DI) file (e.g. {LATEST_VERSION.name})",
@@ -230,7 +313,7 @@ class _GUI(object):
self._frame_output_path = tk.Frame(self._root)
self._frame_output_path.pack(anchor="w")
- self._path_button_output = _PathButton(
+ self._widgets[_GUIWidgets.OUTPUT_PATH] = _PathButton(
self._frame_output_path,
"Output Audio",
"Select output (reamped) file - (Choose MULTIPLE FILES to enable BATCH TRAINING)",
@@ -240,7 +323,7 @@ class _GUI(object):
self._frame_train_destination = tk.Frame(self._root)
self._frame_train_destination.pack(anchor="w")
- self._path_button_train_destination = _PathButton(
+ self._widgets[_GUIWidgets.TRAINING_DESTINATION] = _PathButton(
self._frame_train_destination,
"Train Destination",
"Select training output directory",
@@ -252,7 +335,7 @@ class _GUI(object):
self.user_metadata = UserMetadata()
self._frame_metadata = tk.Frame(self._root)
self._frame_metadata.pack(anchor="w")
- self._button_metadata = tk.Button(
+ self._widgets["metadata"] = tk.Button(
self._frame_metadata,
text="Metadata...",
width=_BUTTON_WIDTH,
@@ -260,7 +343,7 @@ class _GUI(object):
fg="black",
command=self._open_metadata,
)
- self._button_metadata.pack()
+ self._widgets["metadata"].pack()
self.user_metadata_flag = False
# This should probably be to the right somewhere
@@ -281,10 +364,11 @@ class _GUI(object):
_DEFAULT_DELAY,
_DEFAULT_IGNORE_CHECKS,
_DEFAULT_THRESHOLD_ESR,
+ _DEFAULT_CHECKPOINT,
)
# Window to edit them:
- self._button_advanced_options = tk.Button(
+ self._widgets[_GUIWidgets.ADVANCED_OPTIONS] = tk.Button(
self._frame_advanced_options,
text="Advanced options...",
width=_BUTTON_WIDTH,
@@ -292,11 +376,11 @@ class _GUI(object):
fg="black",
command=self._open_advanced_options,
)
- self._button_advanced_options.pack()
+ self._widgets[_GUIWidgets.ADVANCED_OPTIONS].pack()
# Train button
- self._button_train = tk.Button(
+ self._widgets[_GUIWidgets.TRAIN] = tk.Button(
self._frame_train,
text="Train",
width=_BUTTON_WIDTH,
@@ -304,7 +388,7 @@ class _GUI(object):
fg="black",
command=self._train,
)
- self._button_train.pack()
+ self._widgets[_GUIWidgets.TRAIN].pack()
self._check_button_states()
@@ -316,14 +400,14 @@ class _GUI(object):
if any(
pb.val is None
for pb in (
- self._path_button_input,
- self._path_button_output,
- self._path_button_train_destination,
+ self._widgets[_GUIWidgets.INPUT_PATH],
+ self._widgets[_GUIWidgets.OUTPUT_PATH],
+ self._widgets[_GUIWidgets.TRAINING_DESTINATION],
)
):
- self._button_train["state"] = tk.DISABLED
+ self._widgets[_GUIWidgets.TRAIN]["state"] = tk.DISABLED
return
- self._button_train["state"] = tk.NORMAL
+ self._widgets[_GUIWidgets.TRAIN]["state"] = tk.NORMAL
def _get_additional_options_frame(self):
# Checkboxes
@@ -346,6 +430,7 @@ class _GUI(object):
self._frame_checkboxes, text=text, variable=variable
)
self._checkboxes[key] = Checkbox(variable, check_button)
+ self._widgets[key] = check_button # For tracking in set-all-widgets ops
self._checkboxes: Dict[_CheckboxKeys, Checkbox] = dict()
make_checkbox(_CheckboxKeys.FIT_CAB, "Cab modeling", False)
@@ -370,30 +455,39 @@ class _GUI(object):
def mainloop(self):
self._root.mainloop()
+ def _disable(self):
+ """
+ Put every registered widget into the disabled state.
+ """
+ self._set_all_widget_states_to(tk.DISABLED)
+
def _open_advanced_options(self):
"""
- Open advanced options
+ Open window for advanced options
"""
- ao = _AdvancedOptionsGUI(self)
- # I should probably disable the main GUI...
- ao.mainloop()
- # ...and then re-enable it once it gets closed.
+
+ self._wait_while_func(lambda resume: _AdvancedOptionsGUI(resume, self))
def _open_metadata(self):
"""
- Open dialog for metadata
+ Open window for metadata
"""
- mdata = _UserMetadataGUI(self)
- # I should probably disable the main GUI...
- mdata.mainloop()
+
+ self._wait_while_func(lambda resume: _UserMetadataGUI(resume, self))
+
+ def _resume(self):
+ """
+ Re-enable the GUI: set every registered widget back to normal, then re-run
+ the button-state check.
+ """
+ self._set_all_widget_states_to(tk.NORMAL)
+ # Must come after the blanket enable: the check may disable "Train" again.
+ self._check_button_states()
+
+    def _set_all_widget_states_to(self, state):
+        """
+        Assign `state` to the "state" item of every registered widget.
+        """
+        for registered in self._widgets.values():
+            registered["state"] = state
def _train(self):
# Advanced options:
num_epochs = self.advanced_options.num_epochs
architecture = self.advanced_options.architecture
delay = self.advanced_options.latency
- file_list = self._path_button_output.val
+ file_list = self._widgets[_GUIWidgets.OUTPUT_PATH].val
threshold_esr = self.advanced_options.threshold_esr
+ checkpoint = self.advanced_options.checkpoint
# Advanced-er options
# If you're poking around looking for these, then maybe it's time to learn to
@@ -408,33 +502,43 @@ class _GUI(object):
print("Now training {}".format(file))
basename = re.sub(r"\.wav$", "", file.split("/")[-1])
- trained_model = core.train(
- self._path_button_input.val,
- file,
- self._path_button_train_destination.val,
- epochs=num_epochs,
- delay=delay,
- architecture=architecture,
- batch_size=batch_size,
- lr=lr,
- lr_decay=lr_decay,
- seed=seed,
- silent=self._checkboxes[_CheckboxKeys.SILENT_TRAINING].variable.get(),
- save_plot=self._checkboxes[_CheckboxKeys.SAVE_PLOT].variable.get(),
- modelname=basename,
- ignore_checks=self._checkboxes[
- _CheckboxKeys.IGNORE_DATA_CHECKS
- ].variable.get(),
- local=True,
- fit_cab=self._checkboxes[_CheckboxKeys.FIT_CAB].variable.get(),
- threshold_esr=threshold_esr,
- )
+ try:
+ trained_model = core.train(
+ self._widgets[_GUIWidgets.INPUT_PATH].val,
+ file,
+ self._widgets[_GUIWidgets.TRAINING_DESTINATION].val,
+ epochs=num_epochs,
+ delay=delay,
+ architecture=architecture,
+ batch_size=batch_size,
+ lr=lr,
+ lr_decay=lr_decay,
+ seed=seed,
+ silent=self._checkboxes[
+ _CheckboxKeys.SILENT_TRAINING
+ ].variable.get(),
+ save_plot=self._checkboxes[_CheckboxKeys.SAVE_PLOT].variable.get(),
+ modelname=basename,
+ ignore_checks=self._checkboxes[
+ _CheckboxKeys.IGNORE_DATA_CHECKS
+ ].variable.get(),
+ local=True,
+ fit_cab=self._checkboxes[_CheckboxKeys.FIT_CAB].variable.get(),
+ threshold_esr=threshold_esr,
+ checkpoint=checkpoint,
+ )
+ except IncompatibleCheckpointError as e:
+ trained_model = None
+ self._wait_while_func(
+ _BasicModal, "Training failed due to incompatible checkpoint!"
+ )
+
if trained_model is None:
print("Model training failed! Skip exporting...")
continue
print("Model training complete!")
print("Exporting...")
- outdir = self._path_button_train_destination.val
+ outdir = self._widgets[_GUIWidgets.TRAINING_DESTINATION].val
print(f"Exporting trained model to {outdir}...")
trained_model.net.export(
outdir,
@@ -449,6 +553,15 @@ class _GUI(object):
# the user re-visits the window and clicks "ok"
self.user_metadata_flag = False
+    def _wait_while_func(self, func, *args, **kwargs):
+        """
+        Disable this GUI while something happens.
+
+        :param func: Called as `func(self._resume, *args, **kwargs)`. It _needs_ to
+            call the provided resume callable when it's ready to release this GUI.
+        :raises Exception: Whatever `func` raises, after this GUI is re-enabled.
+        """
+        self._disable()
+        try:
+            func(self._resume, *args, **kwargs)
+        except Exception:
+            # func failed before it could hand control back; without this the GUI
+            # would stay disabled for good.
+            self._resume()
+            raise
+
# some typing functions
def _non_negative_int(val):
@@ -594,9 +707,10 @@ class _AdvancedOptionsGUI(object):
A window to hold advanced options (Architecture and number of epochs)
"""
- def __init__(self, parent: _GUI):
+ def __init__(self, resume_main, parent: _GUI):
+ self._resume_main = resume_main
self._parent = parent
- self._root = tk.Tk()
+ self._root = tk.Toplevel()
self._root.title("Advanced Options")
# Architecture: radio buttons
@@ -641,6 +755,17 @@ class _AdvancedOptionsGUI(object):
type=_float_or_null,
)
+ # Restart from a checkpoint
+ self._frame_checkpoint = tk.Frame(self._root)
+ self._frame_checkpoint.pack()
+ self._path_button_checkpoint = _ClearablePathButton(
+ self._frame_checkpoint,
+ "Checkpoint",
+ "[Optional] Select a checkpoint (.ckpt file) to restart training from",
+ _PathType.FILE,
+ default=self._parent.advanced_options.checkpoint,
+ )
+
# "Ok": apply and destory
self._frame_ok = tk.Frame(self._root)
self._frame_ok.pack()
@@ -654,9 +779,6 @@ class _AdvancedOptionsGUI(object):
)
self._button_ok.pack()
- def mainloop(self):
- self._root.mainloop()
-
def _apply_and_destroy(self):
"""
Set values to parent and destroy this object
@@ -676,14 +798,20 @@ class _AdvancedOptionsGUI(object):
self._parent.advanced_options.threshold_esr = (
None if threshold_esr == "null" else threshold_esr
)
+ checkpoint_path = self._path_button_checkpoint.val
+ self._parent.advanced_options.checkpoint = (
+ None if checkpoint_path is None else Path(checkpoint_path)
+ )
self._root.destroy()
+ self._resume_main()
class _UserMetadataGUI(object):
# Things that are auto-filled:
# Model date
# gain
- def __init__(self, parent: _GUI):
+ def __init__(self, resume_main, parent: _GUI):
+ self._resume_main = resume_main
self._parent = parent
self._root = tk.Tk()
self._root.title("Metadata")
@@ -758,9 +886,6 @@ class _UserMetadataGUI(object):
)
self._button_ok.pack()
- def mainloop(self):
- self._root.mainloop()
-
def _apply_and_destroy(self):
"""
Set values to parent and destroy this object
@@ -774,6 +899,7 @@ class _UserMetadataGUI(object):
self._parent.user_metadata_flag = True
self._root.destroy()
+ self._resume_main()
def _install_error():