commit 01e00c86b134d72976325ac0eb1f04fd946f3dac
parent b71db729a7fa1be4cea3bd72836c0677318954ab
Author: Steven Atkinson <steven@atkinson.mn>
Date: Mon, 13 May 2024 08:32:29 -0700
Get user metadata in checkpoint .nams (#410)
Diffstat:
 nam/train/core.py | 26 +++++++++++++++++---------
 nam/train/gui.py  | 10 +++++-----
 2 files changed, 22 insertions(+), 14 deletions(-)
diff --git a/nam/train/core.py b/nam/train/core.py
@@ -27,6 +27,7 @@ from ..data import Split, init_dataset, wav_to_np, wav_to_tensor
from ..models import Model
from ..models.exportable import Exportable
from ..models.losses import esr
+from ..models.metadata import UserMetadata
from ..util import filter_warnings
from ._version import PROTEUS_VERSION, Version
@@ -1097,6 +1098,10 @@ class _ModelCheckpoint(pl.callbacks.model_checkpoint.ModelCheckpoint):
Extension to model checkpoint to save a .nam file as well as the .ckpt file.
"""
+ def __init__(self, *args, user_metadata: Optional[UserMetadata] = None, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._user_metadata = user_metadata
+
_NAM_FILE_EXTENSION = Exportable.FILE_EXTENSION
@classmethod
@@ -1121,10 +1126,7 @@ class _ModelCheckpoint(pl.callbacks.model_checkpoint.ModelCheckpoint):
outdir = nam_filepath.parent
# HACK: Assume the extension
basename = nam_filepath.name[: -len(self._NAM_FILE_EXTENSION)]
- nam_model.export(
- outdir,
- basename=basename,
- )
+ nam_model.export(outdir, basename=basename, user_metadata=self._user_metadata)
def _remove_checkpoint(self, trainer: pl.Trainer, filepath: str) -> None:
super()._remove_checkpoint(trainer, filepath)
@@ -1133,16 +1135,21 @@ class _ModelCheckpoint(pl.callbacks.model_checkpoint.ModelCheckpoint):
nam_path.unlink()
-def _get_callbacks(threshold_esr: Optional[float]):
+def _get_callbacks(
+ threshold_esr: Optional[float], user_metadata: Optional[UserMetadata] = None
+):
callbacks = [
_ModelCheckpoint(
filename="checkpoint_best_{epoch:04d}_{step}_{ESR:.4g}_{MSE:.3e}",
save_top_k=3,
monitor="val_loss",
every_n_epochs=1,
+ user_metadata=user_metadata,
),
_ModelCheckpoint(
- filename="checkpoint_last_{epoch:04d}_{step}", every_n_epochs=1
+ filename="checkpoint_last_{epoch:04d}_{step}",
+ every_n_epochs=1,
+ user_metadata=user_metadata,
),
]
if threshold_esr is not None:
@@ -1173,10 +1180,12 @@ def train(
local: bool = False,
fit_cab: bool = False,
threshold_esr: Optional[float] = None,
+ user_metadata: Optional[UserMetadata] = None,
) -> Optional[Model]:
"""
:param threshold_esr: Stop training if ESR is better than this. Ignore if `None`.
"""
+
if seed is not None:
torch.manual_seed(seed)
@@ -1243,7 +1252,7 @@ def train(
model.net.sample_rate = sample_rate
trainer = pl.Trainer(
- callbacks=_get_callbacks(threshold_esr),
+ callbacks=_get_callbacks(threshold_esr, user_metadata=user_metadata),
default_root_dir=train_path,
**learning_config["trainer"],
)
@@ -1260,8 +1269,7 @@ def train(
)
model.cpu()
model.eval()
- # HACK set again
- model.net.sample_rate = sample_rate
+ model.net.sample_rate = sample_rate # Hack, part 2
def window_kwargs(version: Version):
if version.major == 1:
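Note: the core.py changes above thread user_metadata from train() through _get_callbacks() into each _ModelCheckpoint, which forwards it to export() whenever a .nam file is written alongside a .ckpt. A minimal standalone sketch of the new callback constructor follows; it is illustrative only and not part of this commit, and _ModelCheckpoint is private to nam.train.core:

    from nam.models.metadata import UserMetadata
    from nam.train.core import _ModelCheckpoint  # private; imported here only for illustration

    # Mirrors the "best" checkpoint built in _get_callbacks(); user_metadata is the new argument.
    callback = _ModelCheckpoint(
        filename="checkpoint_best_{epoch:04d}_{step}_{ESR:.4g}_{MSE:.3e}",
        save_top_k=3,
        monitor="val_loss",
        every_n_epochs=1,
        user_metadata=UserMetadata(),  # defaults; the .nam written for each checkpoint now carries it
    )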
diff --git a/nam/train/gui.py b/nam/train/gui.py
@@ -495,6 +495,9 @@ class _GUI(object):
for file in file_list:
print("Now training {}".format(file))
basename = re.sub(r"\.wav$", "", file.split("/")[-1])
+ user_metadata = (
+ self.user_metadata if self.user_metadata_flag else UserMetadata()
+ )
trained_model = core.train(
self._widgets[_GUIWidgets.INPUT_PATH].val,
@@ -516,6 +519,7 @@ class _GUI(object):
local=True,
fit_cab=self._checkboxes[_CheckboxKeys.FIT_CAB].variable.get(),
threshold_esr=threshold_esr,
+ user_metadata=user_metadata,
)
if trained_model is None:
@@ -526,11 +530,7 @@ class _GUI(object):
outdir = self._widgets[_GUIWidgets.TRAINING_DESTINATION].val
print(f"Exporting trained model to {outdir}...")
trained_model.net.export(
- outdir,
- basename=basename,
- user_metadata=(
- self.user_metadata if self.user_metadata_flag else UserMetadata()
- ),
+ outdir, basename=basename, user_metadata=user_metadata
)
print("Done!")