commit c241c3e9f70f94047f9c4b98d00d94ea1f6d1721
parent 6cb0266e3a9b23f6726e9451dd6435e30db90a2e
Author: Steven Atkinson <steven@atkinson.mn>
Date: Sat, 13 Apr 2024 12:17:57 -0700
[GUI] Advanced option for loss-based early stopping (#401)
* Update easy_colab.ipynb
Check installed packages
Quiet install
Both output to logs/
* Add loss-based early stopping
Stop training in the GUI trainer if the loss falls below a threshold that
can be set under Advanced options.
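For scripted (non-GUI) use, the same option can be passed straight to `train()`.
A minimal sketch, assuming `nam` is installed; the file paths, the `train_path`
keyword (inferred from its use as `default_root_dir` in the diff below), and the
0.01 threshold are placeholders, not recommended values:

    from nam.train.core import train

    # Placeholder paths/threshold; point these at your own reamp capture.
    model = train(
        input_path="input.wav",        # DI signal fed into the amp
        output_path="output.wav",      # reamped signal captured from the amp
        train_path="./train_outputs",  # where checkpoints and logs are written
        threshold_esr=0.01,            # stop once validation ESR is better than this
    )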
Diffstat:
2 files changed, 85 insertions(+), 25 deletions(-)
diff --git a/nam/train/core.py b/nam/train/core.py
@@ -1079,6 +1079,36 @@ def _nasty_checks_modal():
modal.mainloop()
+class _ValidationStopping(pl.callbacks.EarlyStopping):
+ """
+ Callback that stops training once the validation metric is good enough,
+ without the other conditions (such as patience) that EarlyStopping usually enforces.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.patience = np.inf
+
+
+def _get_callbacks(threshold_esr: Optional[float]):
+ callbacks = [
+ pl.callbacks.model_checkpoint.ModelCheckpoint(
+ filename="checkpoint_best_{epoch:04d}_{step}_{ESR:.4g}_{MSE:.3e}",
+ save_top_k=3,
+ monitor="val_loss",
+ every_n_epochs=1,
+ ),
+ pl.callbacks.model_checkpoint.ModelCheckpoint(
+ filename="checkpoint_last_{epoch:04d}_{step}", every_n_epochs=1
+ ),
+ ]
+ if threshold_esr is not None:
+ callbacks.append(
+ _ValidationStopping(monitor="ESR", stopping_threshold=threshold_esr)
+ )
+ return callbacks
+
+
def train(
input_path: str,
output_path: str,
@@ -1099,7 +1129,11 @@ def train(
ignore_checks: bool = False,
local: bool = False,
fit_cab: bool = False,
+ threshold_esr: Optional[float] = None,
) -> Optional[Model]:
+ """
+ :param threshold_esr: Stop training if ESR is better than this. Ignored if `None`.
+ """
if seed is not None:
torch.manual_seed(seed)
@@ -1164,17 +1198,7 @@ def train(
)
trainer = pl.Trainer(
- callbacks=[
- pl.callbacks.model_checkpoint.ModelCheckpoint(
- filename="checkpoint_best_{epoch:04d}_{step}_{ESR:.4g}_{MSE:.3e}",
- save_top_k=3,
- monitor="val_loss",
- every_n_epochs=1,
- ),
- pl.callbacks.model_checkpoint.ModelCheckpoint(
- filename="checkpoint_last_{epoch:04d}_{step}", every_n_epochs=1
- ),
- ],
+ callbacks=_get_callbacks(threshold_esr),
default_root_dir=train_path,
**learning_config["trainer"],
)
diff --git a/nam/train/gui.py b/nam/train/gui.py
@@ -65,6 +65,7 @@ _TEXT_WIDTH = 70
_DEFAULT_DELAY = None
_DEFAULT_IGNORE_CHECKS = False
+_DEFAULT_THRESHOLD_ESR = None
_ADVANCED_OPTIONS_LEFT_WIDTH = 12
_ADVANCED_OPTIONS_RIGHT_WIDTH = 12
@@ -73,10 +74,21 @@ _METADATA_RIGHT_WIDTH = 60
@dataclass
class _AdvancedOptions(object):
+ """
+ :param architecture: Which architecture to use.
+ :param num_epochs: How many epochs to train for.
+ :param latency: Latency between the input and output audio, in samples.
+ None means we don't know and it has to be calibrated.
+ :param ignore_checks: Keep going even if a check says that something is wrong.
+ :param threshold_esr: Stop training if the ESR gets better than this. If None, don't
+ stop.
+ """
+
architecture: core.Architecture
num_epochs: int
- delay: Optional[int]
+ latency: Optional[int]
ignore_checks: bool
+ threshold_esr: Optional[float]
class _PathType(Enum):
@@ -268,6 +280,7 @@ class _GUI(object):
_DEFAULT_NUM_EPOCHS,
_DEFAULT_DELAY,
_DEFAULT_IGNORE_CHECKS,
+ _DEFAULT_THRESHOLD_ESR,
)
# Window to edit them:
@@ -378,8 +391,9 @@ class _GUI(object):
# Advanced options:
num_epochs = self.advanced_options.num_epochs
architecture = self.advanced_options.architecture
- delay = self.advanced_options.delay
+ delay = self.advanced_options.latency
file_list = self._path_button_output.val
+ threshold_esr = self.advanced_options.threshold_esr
# Advanced-er options
# If you're poking around looking for these, then maybe it's time to learn to
@@ -413,6 +427,7 @@ class _GUI(object):
].variable.get(),
local=True,
fit_cab=self._checkboxes[_CheckboxKeys.FIT_CAB].variable.get(),
+ threshold_esr=threshold_esr,
)
if trained_model is None:
print("Model training failed! Skip exporting...")
@@ -443,14 +458,18 @@ def _non_negative_int(val):
return val
-def _int_or_null(val):
+def _type_or_null(T, val):
val = val.rstrip()
if val == "null":
return val
- return int(val)
+ return T(val)
+
+
+_int_or_null = partial(_type_or_null, int)
+_float_or_null = partial(_type_or_null, float)
-def _int_or_null_inv(val):
+def _type_or_null_inv(val):
return "null" if val is None else str(val)
@@ -602,16 +621,26 @@ class _AdvancedOptionsGUI(object):
)
# Delay: text box
- self._frame_delay = tk.Frame(self._root)
- self._frame_delay.pack()
+ self._frame_latency = tk.Frame(self._root)
+ self._frame_latency.pack()
- self._delay = _LabeledText(
- self._frame_delay,
- "Delay",
- default=_int_or_null_inv(self._parent.advanced_options.delay),
+ self._latency = _LabeledText(
+ self._frame_latency,
+ "Reamp latency",
+ default=_type_or_null_inv(self._parent.advanced_options.latency),
type=_int_or_null,
)
+ # Threshold ESR
+ self._frame_threshold_esr = tk.Frame(self._root)
+ self._frame_threshold_esr.pack()
+ self._threshold_esr = _LabeledText(
+ self._frame_threshold_esr,
+ "Threshold ESR",
+ default=_type_or_null_inv(self._parent.advanced_options.threshold_esr),
+ type=_float_or_null,
+ )
+
# "Ok": apply and destory
self._frame_ok = tk.Frame(self._root)
self._frame_ok.pack()
@@ -636,10 +665,17 @@ class _AdvancedOptionsGUI(object):
epochs = self._epochs.get()
if epochs is not None:
self._parent.advanced_options.num_epochs = epochs
- delay = self._delay.get()
+ latency = self._latency.get()
# Value None is returned as "null" to disambiguate from non-set.
- if delay is not None:
- self._parent.advanced_options.delay = None if delay == "null" else delay
+ if latency is not None:
+ self._parent.advanced_options.latency = (
+ None if latency == "null" else latency
+ )
+ threshold_esr = self._threshold_esr.get()
+ if threshold_esr is not None:
+ self._parent.advanced_options.threshold_esr = (
+ None if threshold_esr == "null" else threshold_esr
+ )
self._root.destroy()
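A note on the "null" convention used by these text fields (a sketch based on the
helpers added above; the values are illustrative): the literal string "null" is
passed through as a sentinel that the apply handler maps back to None, anything
else is coerced to the target type, and `_type_or_null_inv` renders None back as
"null" for display.

    from functools import partial

    def _type_or_null(T, val):
        # "null" passes through as a sentinel; the caller maps it to None.
        val = val.rstrip()
        if val == "null":
            return val
        return T(val)

    def _type_or_null_inv(val):
        # Inverse direction: None is displayed as the string "null".
        return "null" if val is None else str(val)

    _float_or_null = partial(_type_or_null, float)

    assert _float_or_null("0.01") == 0.01
    assert _float_or_null("null") == "null"   # becomes None when applied
    assert _type_or_null_inv(None) == "null"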