neural-amp-modeler

Neural network emulator for guitar amplifiers
Log | Files | Refs | README | LICENSE

commit 0ee6fd6c3a0c918035156dc9a6c54bee8d9470bb
parent c241c3e9f70f94047f9c4b98d00d94ea1f6d1721
Author: Steven Atkinson <steven@atkinson.mn>
Date:   Sat, 13 Apr 2024 18:05:39 -0700

[FEATURE] GUI: Resume training from checkpoint (#402)

* Add special error for loading with incompatible checkpoint

* Support training from checkpoint

* Error modal when training fails because of checkpoint

* Remove straggling mainloop comment
Diffstat:
Anam/train/_errors.py | 18++++++++++++++++++
Mnam/train/core.py | 21+++++++++++++++++++--
Mnam/train/gui.py | 258+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++--------------------
3 files changed, 229 insertions(+), 68 deletions(-)

diff --git a/nam/train/_errors.py b/nam/train/_errors.py @@ -0,0 +1,18 @@ +# File: _errors.py +# Created Date: Saturday April 13th 2024 +# Author: Steven Atkinson (steven@atkinson.mn) + +""" +"What could go wrong?" +""" + +__all__ = ["IncompatibleCheckpointError"] + + +class IncompatibleCheckpointError(RuntimeError): + """ + Raised when model loading fails because the checkpoint didn't match the model + or its hyperparameters + """ + + pass diff --git a/nam/train/core.py b/nam/train/core.py @@ -11,6 +11,7 @@ import tkinter as tk from copy import deepcopy from enum import Enum from functools import partial +from pathlib import Path from time import time from typing import Dict, Optional, Sequence, Tuple, Union @@ -26,6 +27,7 @@ from ..data import Split, init_dataset, wav_to_np, wav_to_tensor from ..models import Model from ..models.losses import esr from ..util import filter_warnings +from ._errors import IncompatibleCheckpointError from ._version import PROTEUS_VERSION, Version __all__ = ["train"] @@ -868,6 +870,7 @@ def _get_configs( lr_decay: float, batch_size: int, fit_cab: bool, + checkpoint: Optional[Path] = None, ): def get_kwargs(data_info: _DataInfo): if data_info.major_version == 1: @@ -957,6 +960,8 @@ def _get_configs( if fit_cab: model_config["loss"]["pre_emph_mrstft_weight"] = _CAB_MRSTFT_PRE_EMPH_WEIGHT model_config["loss"]["pre_emph_mrstft_coef"] = _CAB_MRSTFT_PRE_EMPH_COEF + if checkpoint: + model_config["checkpoint_path"] = checkpoint if torch.cuda.is_available(): device_config = {"accelerator": "gpu", "devices": 1} @@ -1130,6 +1135,7 @@ def train( local: bool = False, fit_cab: bool = False, threshold_esr: Optional[bool] = None, + checkpoint: Optional[Path] = None, ) -> Optional[Model]: """ :param threshold_esr: Stop training if ESR is better than this. Ignore if `None`. @@ -1178,6 +1184,7 @@ def train( lr_decay, batch_size, fit_cab, + checkpoint=checkpoint, ) print("Starting training. It's time to kick ass and chew bubblegum!") @@ -1186,7 +1193,16 @@ def train( # * Model is re-instantiated after training anyways. # (Hacky) solution: set sample rate in model from dataloader after second # instantiation from final checkpoint. - model = Model.init_from_config(model_config) + try: + model = Model.init_from_config(model_config) + except RuntimeError as e: + if "Error(s) in loading state_dict for Model:" in str(e): + raise IncompatibleCheckpointError( + "Model initialization failed; the checkpoint used seems to be " + f"incompatible.\n\nOriginal error:\n\n{e}" + ) + else: + raise e train_dataloader, val_dataloader = _get_dataloaders( data_config, learning_config, model ) @@ -1204,7 +1220,8 @@ def train( ) # Suppress the PossibleUserWarning about num_workers (Issue 345) with filter_warnings("ignore", category=PossibleUserWarning): - trainer.fit(model, train_dataloader, val_dataloader) + trainer_fit_kwargs = {} if checkpoint is None else {"ckpt_path": checkpoint} + trainer.fit(model, train_dataloader, val_dataloader, **trainer_fit_kwargs) # Go to best checkpoint best_checkpoint = trainer.checkpoint_callback.best_model_path diff --git a/nam/train/gui.py b/nam/train/gui.py @@ -43,6 +43,7 @@ try: # 3rd-party and 1st-party imports from nam.models.metadata import GearType, UserMetadata, ToneType # Ok private access here--this is technically allowed access + from nam.train._errors import IncompatibleCheckpointError from nam.train._names import INPUT_BASENAMES, LATEST_VERSION _install_is_valid = True @@ -66,6 +67,7 @@ _TEXT_WIDTH = 70 _DEFAULT_DELAY = None _DEFAULT_IGNORE_CHECKS = False _DEFAULT_THRESHOLD_ESR = None +_DEFAULT_CHECKPOINT = None _ADVANCED_OPTIONS_LEFT_WIDTH = 12 _ADVANCED_OPTIONS_RIGHT_WIDTH = 12 @@ -82,6 +84,7 @@ class _AdvancedOptions(object): :param ignore_checks: Keep going even if a check says that something is wrong. :param threshold_esr: Stop training if the ESR gets better than this. If None, don't stop. + :param checkpoint: If provided, try to restart from this checkpoint. """ architecture: core.Architecture @@ -89,6 +92,7 @@ class _AdvancedOptions(object): latency: Optional[int] ignore_checks: bool threshold_esr: Optional[float] + checkpoint: Optional[Path] class _PathType(Enum): @@ -109,13 +113,19 @@ class _PathButton(object): info_str: str, path_type: _PathType, hooks: Optional[Sequence[Callable[[], None]]] = None, + color_when_not_set: str = "#EF0000", # Darker red + default: Optional[Path] = None, ): + """ + :param hooks: Callables run at the end of setting the value. + """ self._button_text = button_text self._info_str = info_str - self._path: Optional[Path] = None + self._path: Optional[Path] = default self._path_type = path_type self._frame = frame - self._button = tk.Button( + self._widgets = {} + self._widgets["button"] = tk.Button( self._frame, text=button_text, width=_BUTTON_WIDTH, @@ -123,8 +133,8 @@ class _PathButton(object): fg="black", command=self._set_val, ) - self._button.pack(side=tk.LEFT) - self._label = tk.Label( + self._widgets["button"].pack(side=tk.LEFT) + self._widgets["label"] = tk.Label( self._frame, width=_TEXT_WIDTH, height=_BUTTON_HEIGHT, @@ -132,23 +142,38 @@ class _PathButton(object): bg=None, anchor="w", ) - self._label.pack(side=tk.LEFT) + self._widgets["label"].pack(side=tk.LEFT) self._hooks = hooks + self._color_when_not_set = color_when_not_set self._set_text() + def __setitem__(self, key, val): + """ + Implement tk-style setter for state + """ + if key == "state": + for widget in self._widgets.values(): + widget["state"] = val + else: + raise RuntimeError( + f"{self.__class__.__name__} instance does not support item assignment for non-state key {key}!" + ) + @property def val(self) -> Optional[Path]: return self._path def _set_text(self): if self._path is None: - self._label["fg"] = "#EF0000" # Darker red - self._label["text"] = self._info_str + self._widgets["label"]["fg"] = self._color_when_not_set + self._widgets["label"]["text"] = self._info_str else: val = self.val val = val[0] if isinstance(val, tuple) and len(val) == 1 else val - self._label["fg"] = "black" - self._label["text"] = f"{self._button_text.capitalize()} set to {val}" + self._widgets["label"]["fg"] = "black" + self._widgets["label"][ + "text" + ] = f"{self._button_text.capitalize()} set to {val}" def _set_val(self): res = { @@ -169,7 +194,7 @@ class _InputPathButton(_PathButton): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) # Download the training file! - self._button_download_input = tk.Button( + self._widgets["button_download_input"] = tk.Button( self._frame, text="Download input file", width=_BUTTON_WIDTH, @@ -177,7 +202,7 @@ class _InputPathButton(_PathButton): fg="black", command=self._download_input_file, ) - self._button_download_input.pack(side=tk.RIGHT) + self._widgets["button_download_input"].pack(side=tk.RIGHT) @classmethod def _download_input_file(cls): @@ -201,6 +226,29 @@ class _InputPathButton(_PathButton): return +class _ClearablePathButton(_PathButton): + """ + Can clear a path + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, color_when_not_set="black", **kwargs) + # Download the training file! + self._widgets["button_clear"] = tk.Button( + self._frame, + text="Clear", + width=_BUTTON_WIDTH, + height=_BUTTON_HEIGHT, + fg="black", + command=self._clear_path, + ) + self._widgets["button_clear"].pack(side=tk.RIGHT) + + def _clear_path(self): + self._path = None + self._set_text() + + class _CheckboxKeys(Enum): """ Keys for checkboxes @@ -212,15 +260,50 @@ class _CheckboxKeys(Enum): IGNORE_DATA_CHECKS = "ignore_data_checks" +class _BasicModal(object): + """ + Message and OK button + """ + + def __init__(self, resume_main, msg: str): + self._root = tk.Toplevel() + self._text = tk.Label(self._root, text=msg) + self._text.pack() + self._ok = tk.Button( + self._root, + text="Ok", + width=_BUTTON_WIDTH, + height=_BUTTON_HEIGHT, + fg="black", + command=self._close, + ) + self._ok.pack() + self._resume_main = resume_main + + def _close(self): + self._root.destroy() + self._resume_main() + + +class _GUIWidgets(Enum): + INPUT_PATH = "input_path" + OUTPUT_PATH = "output_path" + TRAINING_DESTINATION = "training_destination" + METADATA = "metadata" + ADVANCED_OPTIONS = "advanced_options" + TRAIN = "train" + + class _GUI(object): def __init__(self): self._root = tk.Tk() self._root.title(f"NAM Trainer - v{__version__}") + self._widgets = {} # Buttons for paths: self._frame_input = tk.Frame(self._root) self._frame_input.pack(anchor="w") - self._path_button_input = _InputPathButton( + self._widgets[_GUIWidgets.INPUT_PATH] = _InputPathButton( self._frame_input, "Input Audio", f"Select input (DI) file (e.g. {LATEST_VERSION.name})", @@ -230,7 +313,7 @@ class _GUI(object): self._frame_output_path = tk.Frame(self._root) self._frame_output_path.pack(anchor="w") - self._path_button_output = _PathButton( + self._widgets[_GUIWidgets.OUTPUT_PATH] = _PathButton( self._frame_output_path, "Output Audio", "Select output (reamped) file - (Choose MULTIPLE FILES to enable BATCH TRAINING)", @@ -240,7 +323,7 @@ class _GUI(object): self._frame_train_destination = tk.Frame(self._root) self._frame_train_destination.pack(anchor="w") - self._path_button_train_destination = _PathButton( + self._widgets[_GUIWidgets.TRAINING_DESTINATION] = _PathButton( self._frame_train_destination, "Train Destination", "Select training output directory", @@ -252,7 +335,7 @@ class _GUI(object): self.user_metadata = UserMetadata() self._frame_metadata = tk.Frame(self._root) self._frame_metadata.pack(anchor="w") - self._button_metadata = tk.Button( + self._widgets["metadata"] = tk.Button( self._frame_metadata, text="Metadata...", width=_BUTTON_WIDTH, @@ -260,7 +343,7 @@ class _GUI(object): fg="black", command=self._open_metadata, ) - self._button_metadata.pack() + self._widgets["metadata"].pack() self.user_metadata_flag = False # This should probably be to the right somewhere @@ -281,10 +364,11 @@ class _GUI(object): _DEFAULT_DELAY, _DEFAULT_IGNORE_CHECKS, _DEFAULT_THRESHOLD_ESR, + _DEFAULT_CHECKPOINT, ) # Window to edit them: - self._button_advanced_options = tk.Button( + self._widgets[_GUIWidgets.ADVANCED_OPTIONS] = tk.Button( self._frame_advanced_options, text="Advanced options...", width=_BUTTON_WIDTH, @@ -292,11 +376,11 @@ class _GUI(object): fg="black", command=self._open_advanced_options, ) - self._button_advanced_options.pack() + self._widgets[_GUIWidgets.ADVANCED_OPTIONS].pack() # Train button - self._button_train = tk.Button( + self._widgets[_GUIWidgets.TRAIN] = tk.Button( self._frame_train, text="Train", width=_BUTTON_WIDTH, @@ -304,7 +388,7 @@ class _GUI(object): fg="black", command=self._train, ) - self._button_train.pack() + self._widgets[_GUIWidgets.TRAIN].pack() self._check_button_states() @@ -316,14 +400,14 @@ class _GUI(object): if any( pb.val is None for pb in ( - self._path_button_input, - self._path_button_output, - self._path_button_train_destination, + self._widgets[_GUIWidgets.INPUT_PATH], + self._widgets[_GUIWidgets.OUTPUT_PATH], + self._widgets[_GUIWidgets.TRAINING_DESTINATION], ) ): - self._button_train["state"] = tk.DISABLED + self._widgets[_GUIWidgets.TRAIN]["state"] = tk.DISABLED return - self._button_train["state"] = tk.NORMAL + self._widgets[_GUIWidgets.TRAIN]["state"] = tk.NORMAL def _get_additional_options_frame(self): # Checkboxes @@ -346,6 +430,7 @@ class _GUI(object): self._frame_checkboxes, text=text, variable=variable ) self._checkboxes[key] = Checkbox(variable, check_button) + self._widgets[key] = check_button # For tracking in set-all-widgets ops self._checkboxes: Dict[_CheckboxKeys, Checkbox] = dict() make_checkbox(_CheckboxKeys.FIT_CAB, "Cab modeling", False) @@ -370,30 +455,39 @@ class _GUI(object): def mainloop(self): self._root.mainloop() + def _disable(self): + self._set_all_widget_states_to(tk.DISABLED) + def _open_advanced_options(self): """ - Open advanced options + Open window for advanced options """ - ao = _AdvancedOptionsGUI(self) - # I should probably disable the main GUI... - ao.mainloop() - # ...and then re-enable it once it gets closed. + + self._wait_while_func(lambda resume: _AdvancedOptionsGUI(resume, self)) def _open_metadata(self): """ - Open dialog for metadata + Open window for metadata """ - mdata = _UserMetadataGUI(self) - # I should probably disable the main GUI... - mdata.mainloop() + + self._wait_while_func(lambda resume: _UserMetadataGUI(resume, self)) + + def _resume(self): + self._set_all_widget_states_to(tk.NORMAL) + self._check_button_states() + + def _set_all_widget_states_to(self, state): + for widget in self._widgets.values(): + widget["state"] = state def _train(self): # Advanced options: num_epochs = self.advanced_options.num_epochs architecture = self.advanced_options.architecture delay = self.advanced_options.latency - file_list = self._path_button_output.val + file_list = self._widgets[_GUIWidgets.OUTPUT_PATH].val threshold_esr = self.advanced_options.threshold_esr + checkpoint = self.advanced_options.checkpoint # Advanced-er options # If you're poking around looking for these, then maybe it's time to learn to @@ -408,33 +502,43 @@ class _GUI(object): print("Now training {}".format(file)) basename = re.sub(r"\.wav$", "", file.split("/")[-1]) - trained_model = core.train( - self._path_button_input.val, - file, - self._path_button_train_destination.val, - epochs=num_epochs, - delay=delay, - architecture=architecture, - batch_size=batch_size, - lr=lr, - lr_decay=lr_decay, - seed=seed, - silent=self._checkboxes[_CheckboxKeys.SILENT_TRAINING].variable.get(), - save_plot=self._checkboxes[_CheckboxKeys.SAVE_PLOT].variable.get(), - modelname=basename, - ignore_checks=self._checkboxes[ - _CheckboxKeys.IGNORE_DATA_CHECKS - ].variable.get(), - local=True, - fit_cab=self._checkboxes[_CheckboxKeys.FIT_CAB].variable.get(), - threshold_esr=threshold_esr, - ) + try: + trained_model = core.train( + self._widgets[_GUIWidgets.INPUT_PATH].val, + file, + self._widgets[_GUIWidgets.TRAINING_DESTINATION].val, + epochs=num_epochs, + delay=delay, + architecture=architecture, + batch_size=batch_size, + lr=lr, + lr_decay=lr_decay, + seed=seed, + silent=self._checkboxes[ + _CheckboxKeys.SILENT_TRAINING + ].variable.get(), + save_plot=self._checkboxes[_CheckboxKeys.SAVE_PLOT].variable.get(), + modelname=basename, + ignore_checks=self._checkboxes[ + _CheckboxKeys.IGNORE_DATA_CHECKS + ].variable.get(), + local=True, + fit_cab=self._checkboxes[_CheckboxKeys.FIT_CAB].variable.get(), + threshold_esr=threshold_esr, + checkpoint=checkpoint, + ) + except IncompatibleCheckpointError as e: + trained_model = None + self._wait_while_func( + _BasicModal, "Training failed due to incompatible checkpoint!" + ) + if trained_model is None: print("Model training failed! Skip exporting...") continue print("Model training complete!") print("Exporting...") - outdir = self._path_button_train_destination.val + outdir = self._widgets[_GUIWidgets.TRAINING_DESTINATION].val print(f"Exporting trained model to {outdir}...") trained_model.net.export( outdir, @@ -449,6 +553,15 @@ class _GUI(object): # the user re-visits the window and clicks "ok" self.user_metadata_flag = False + def _wait_while_func(self, func, *args, **kwargs): + """ + Disable this GUI while something happens. + That function _needs_ to call the provided self._resume when it's ready to + release me! + """ + self._disable() + func(self._resume, *args, **kwargs) + # some typing functions def _non_negative_int(val): @@ -594,9 +707,10 @@ class _AdvancedOptionsGUI(object): A window to hold advanced options (Architecture and number of epochs) """ - def __init__(self, parent: _GUI): + def __init__(self, resume_main, parent: _GUI): + self._resume_main = resume_main self._parent = parent - self._root = tk.Tk() + self._root = tk.Toplevel() self._root.title("Advanced Options") # Architecture: radio buttons @@ -641,6 +755,17 @@ class _AdvancedOptionsGUI(object): type=_float_or_null, ) + # Restart from a checkpoint + self._frame_checkpoint = tk.Frame(self._root) + self._frame_checkpoint.pack() + self._path_button_checkpoint = _ClearablePathButton( + self._frame_checkpoint, + "Checkpoint", + "[Optional] Select a checkpoint (.ckpt file) to restart training from", + _PathType.FILE, + default=self._parent.advanced_options.checkpoint, + ) + # "Ok": apply and destory self._frame_ok = tk.Frame(self._root) self._frame_ok.pack() @@ -654,9 +779,6 @@ class _AdvancedOptionsGUI(object): ) self._button_ok.pack() - def mainloop(self): - self._root.mainloop() - def _apply_and_destroy(self): """ Set values to parent and destroy this object @@ -676,14 +798,20 @@ class _AdvancedOptionsGUI(object): self._parent.advanced_options.threshold_esr = ( None if threshold_esr == "null" else threshold_esr ) + checkpoint_path = self._path_button_checkpoint.val + self._parent.advanced_options.checkpoint = ( + None if checkpoint_path is None else Path(checkpoint_path) + ) self._root.destroy() + self._resume_main() class _UserMetadataGUI(object): # Things that are auto-filled: # Model date # gain - def __init__(self, parent: _GUI): + def __init__(self, resume_main, parent: _GUI): + self._resume_main = resume_main self._parent = parent self._root = tk.Tk() self._root.title("Metadata") @@ -758,9 +886,6 @@ class _UserMetadataGUI(object): ) self._button_ok.pack() - def mainloop(self): - self._root.mainloop() - def _apply_and_destroy(self): """ Set values to parent and destroy this object @@ -774,6 +899,7 @@ class _UserMetadataGUI(object): self._parent.user_metadata_flag = True self._root.destroy() + self._resume_main() def _install_error():