neural-amp-modeler

Neural network emulator for guitar amplifiers
Log | Files | Refs | README | LICENSE

commit 01e00c86b134d72976325ac0eb1f04fd946f3dac
parent b71db729a7fa1be4cea3bd72836c0677318954ab
Author: Steven Atkinson <steven@atkinson.mn>
Date:   Mon, 13 May 2024 08:32:29 -0700

Get user metadata in checkpoint .nams (#410)


Diffstat:
Mnam/train/core.py | 26+++++++++++++++++---------
Mnam/train/gui.py | 10+++++-----
2 files changed, 22 insertions(+), 14 deletions(-)

diff --git a/nam/train/core.py b/nam/train/core.py @@ -27,6 +27,7 @@ from ..data import Split, init_dataset, wav_to_np, wav_to_tensor from ..models import Model from ..models.exportable import Exportable from ..models.losses import esr +from ..models.metadata import UserMetadata from ..util import filter_warnings from ._version import PROTEUS_VERSION, Version @@ -1097,6 +1098,10 @@ class _ModelCheckpoint(pl.callbacks.model_checkpoint.ModelCheckpoint): Extension to model checkpoint to save a .nam file as well as the .ckpt file. """ + def __init__(self, *args, user_metadata: Optional[UserMetadata] = None, **kwargs): + super().__init__(*args, **kwargs) + self._user_metadata = user_metadata + _NAM_FILE_EXTENSION = Exportable.FILE_EXTENSION @classmethod @@ -1121,10 +1126,7 @@ class _ModelCheckpoint(pl.callbacks.model_checkpoint.ModelCheckpoint): outdir = nam_filepath.parent # HACK: Assume the extension basename = nam_filepath.name[: -len(self._NAM_FILE_EXTENSION)] - nam_model.export( - outdir, - basename=basename, - ) + nam_model.export(outdir, basename=basename, user_metadata=self._user_metadata) def _remove_checkpoint(self, trainer: pl.Trainer, filepath: str) -> None: super()._remove_checkpoint(trainer, filepath) @@ -1133,16 +1135,21 @@ class _ModelCheckpoint(pl.callbacks.model_checkpoint.ModelCheckpoint): nam_path.unlink() -def _get_callbacks(threshold_esr: Optional[float]): +def _get_callbacks( + threshold_esr: Optional[float], user_metadata: Optional[UserMetadata] = None +): callbacks = [ _ModelCheckpoint( filename="checkpoint_best_{epoch:04d}_{step}_{ESR:.4g}_{MSE:.3e}", save_top_k=3, monitor="val_loss", every_n_epochs=1, + user_metadata=user_metadata, ), _ModelCheckpoint( - filename="checkpoint_last_{epoch:04d}_{step}", every_n_epochs=1 + filename="checkpoint_last_{epoch:04d}_{step}", + every_n_epochs=1, + user_metadata=user_metadata, ), ] if threshold_esr is not None: @@ -1173,10 +1180,12 @@ def train( local: bool = False, fit_cab: bool = False, threshold_esr: Optional[bool] = None, + user_metadata: Optional[UserMetadata] = None, ) -> Optional[Model]: """ :param threshold_esr: Stop training if ESR is better than this. Ignore if `None`. """ + if seed is not None: torch.manual_seed(seed) @@ -1243,7 +1252,7 @@ def train( model.net.sample_rate = sample_rate trainer = pl.Trainer( - callbacks=_get_callbacks(threshold_esr), + callbacks=_get_callbacks(threshold_esr, user_metadata=user_metadata), default_root_dir=train_path, **learning_config["trainer"], ) @@ -1260,8 +1269,7 @@ def train( ) model.cpu() model.eval() - # HACK set again - model.net.sample_rate = sample_rate + model.net.sample_rate = sample_rate # Hack, part 2 def window_kwargs(version: Version): if version.major == 1: diff --git a/nam/train/gui.py b/nam/train/gui.py @@ -495,6 +495,9 @@ class _GUI(object): for file in file_list: print("Now training {}".format(file)) basename = re.sub(r"\.wav$", "", file.split("/")[-1]) + user_metadata = ( + self.user_metadata if self.user_metadata_flag else UserMetadata() + ) trained_model = core.train( self._widgets[_GUIWidgets.INPUT_PATH].val, @@ -516,6 +519,7 @@ class _GUI(object): local=True, fit_cab=self._checkboxes[_CheckboxKeys.FIT_CAB].variable.get(), threshold_esr=threshold_esr, + user_metadata=user_metadata, ) if trained_model is None: @@ -526,11 +530,7 @@ class _GUI(object): outdir = self._widgets[_GUIWidgets.TRAINING_DESTINATION].val print(f"Exporting trained model to {outdir}...") trained_model.net.export( - outdir, - basename=basename, - user_metadata=( - self.user_metadata if self.user_metadata_flag else UserMetadata() - ), + outdir, basename=basename, user_metadata=user_metadata ) print("Done!")