commit a98cfbd9dd0c52fb1b80cb435aff9492e77a064d
parent 0d840eee7a0d16628368fb01d3965ae32f306902
Author: Steven Atkinson <steven@atkinson.mn>
Date: Sun, 22 Sep 2024 14:31:34 -0700
Make some attributes public (#470)
* Define some public attributes
* nam.train.core.get_lstm_config
* nam.train.core.get_wavenet_config
* nam.train.gui.GUI.core_train_kwargs
* nam.train.gui.LabeledOptionMenu
* nam.train.gui.LabeledText
* Fix test
Diffstat:
3 files changed, 40 insertions(+), 39 deletions(-)
diff --git a/nam/train/core.py b/nam/train/core.py
@@ -551,7 +551,7 @@ def _analyze_latency(
return metadata.Latency(manual=user_latency, calibration=calibration_output)
-def _get_lstm_config(architecture):
+def get_lstm_config(architecture):
return {
Architecture.STANDARD: {
"num_layers": 1,
@@ -788,7 +788,7 @@ def _check_data(
return out
-def _get_wavenet_config(architecture):
+def get_wavenet_config(architecture):
return {
Architecture.STANDARD: {
"layers_configs": [
@@ -996,7 +996,7 @@ def _get_configs(
"name": "WaveNet",
# This should do decently. If you really want a nice model, try turning up
# "channels" in the first block and "input_size" in the second from 12 to 16.
- "config": _get_wavenet_config(architecture),
+ "config": get_wavenet_config(architecture),
},
"loss": {"val_loss": "esr"},
"optimizer": {"lr": lr},
@@ -1009,7 +1009,7 @@ def _get_configs(
model_config = {
"net": {
"name": "LSTM",
- "config": _get_lstm_config(architecture),
+ "config": get_lstm_config(architecture),
},
"loss": {
"val_loss": "mse",
@@ -1563,13 +1563,13 @@ def validate_data(
for split in Split:
try:
init_dataset(data_config, split)
- pytorch_data_split_validation_dict[
- split.value
- ] = _PyTorchDataSplitValidation(passed=True, msg=None)
+ pytorch_data_split_validation_dict[split.value] = (
+ _PyTorchDataSplitValidation(passed=True, msg=None)
+ )
except DataError as e:
- pytorch_data_split_validation_dict[
- split.value
- ] = _PyTorchDataSplitValidation(passed=False, msg=str(e))
+ pytorch_data_split_validation_dict[split.value] = (
+ _PyTorchDataSplitValidation(passed=False, msg=str(e))
+ )
pytorch_data_validation = _PyTorchDataValidation(
passed=all(v.passed for v in pytorch_data_split_validation_dict.values()),
**pytorch_data_split_validation_dict,
diff --git a/nam/train/gui/__init__.py b/nam/train/gui/__init__.py
@@ -21,7 +21,7 @@ from enum import Enum
from functools import partial
from pathlib import Path
from tkinter import filedialog
-from typing import Callable, Dict, NamedTuple, Optional, Sequence
+from typing import Any, Callable, Dict, NamedTuple, Optional, Sequence
try: # 3rd-party and 1st-party imports
import torch
@@ -496,6 +496,17 @@ class GUI(object):
self._check_button_states()
+ def core_train_kwargs(self) -> Dict[str, Any]:
+ """
+ Get any additional kwargs to provide to `core.train`
+ """
+ return {
+ "lr": 0.004,
+ "lr_decay": _DEFAULT_LR_DECAY,
+ "batch_size": _DEFAULT_BATCH_SIZE,
+ "seed": 0,
+ }
+
def get_mrstft_fit(self) -> bool:
"""
Use a pre-emphasized multi-resolution short-time Fourier transform loss during
@@ -690,13 +701,6 @@ class GUI(object):
file_list = self._widgets[_GUIWidgets.OUTPUT_PATH].val
threshold_esr = self.advanced_options.threshold_esr
- # Advanced-er options
- # If you're poking around looking for these, then maybe it's time to learn to
- # use the command-line scripts ;)
- lr = 0.004
- lr_decay = _DEFAULT_LR_DECAY
- batch_size = _DEFAULT_BATCH_SIZE
- seed = 0
# Run it
for file in file_list:
print(f"Now training {file}")
@@ -712,10 +716,6 @@ class GUI(object):
epochs=num_epochs,
latency=user_latency,
architecture=architecture,
- batch_size=batch_size,
- lr=lr,
- lr_decay=lr_decay,
- seed=seed,
silent=self._checkboxes[_CheckboxKeys.SILENT_TRAINING].variable.get(),
save_plot=self._checkboxes[_CheckboxKeys.SAVE_PLOT].variable.get(),
modelname=basename,
@@ -724,6 +724,7 @@ class GUI(object):
fit_mrstft=self.get_mrstft_fit(),
threshold_esr=threshold_esr,
user_metadata=user_metadata,
+ **self.core_train_kwargs(),
)
if train_output.model is None:
@@ -743,8 +744,8 @@ class GUI(object):
)
print("Done!")
- # Metadata was only valid for 1 run, so make sure it's not used again unless
- # the user re-visits the window and clicks "ok"
+ # Metadata was only valid for 1 run (possibly a batch), so make sure it's not
+ # used again unless the user re-visits the window and clicks "ok".
self.user_metadata_flag = False
def _validate_all_data(
@@ -884,7 +885,7 @@ def _rstripped_str(val):
return str(val).rstrip()
-class _LabeledOptionMenu(object):
+class LabeledOptionMenu(object):
"""
Label (left) and radio buttons (right)
"""
@@ -937,7 +938,7 @@ class _LabeledOptionMenu(object):
self._selected_value = self._choices(val)
-class _LabeledText(object):
+class LabeledText(object):
"""
Label (left) and text input (right)
"""
@@ -1040,7 +1041,7 @@ class AdvancedOptionsGUI(object):
# Architecture: radio buttons
self._frame_architecture = tk.Frame(self._root)
self._frame_architecture.pack()
- self._architecture = _LabeledOptionMenu(
+ self._architecture = LabeledOptionMenu(
self._frame_architecture,
"Architecture",
core.Architecture,
@@ -1051,7 +1052,7 @@ class AdvancedOptionsGUI(object):
self._frame_epochs = tk.Frame(self._root)
self._frame_epochs.pack()
- self._epochs = _LabeledText(
+ self._epochs = LabeledText(
self._frame_epochs,
"Epochs",
default=str(self._parent.advanced_options.num_epochs),
@@ -1062,7 +1063,7 @@ class AdvancedOptionsGUI(object):
self._frame_latency = tk.Frame(self._root)
self._frame_latency.pack()
- self._latency = _LabeledText(
+ self._latency = LabeledText(
self._frame_latency,
"Reamp latency",
default=_type_or_null_inv(self._parent.advanced_options.latency),
@@ -1072,7 +1073,7 @@ class AdvancedOptionsGUI(object):
# Threshold ESR
self._frame_threshold_esr = tk.Frame(self._root)
self._frame_threshold_esr.pack()
- self._threshold_esr = _LabeledText(
+ self._threshold_esr = LabeledText(
self._frame_threshold_esr,
"Threshold ESR",
default=_type_or_null_inv(self._parent.advanced_options.threshold_esr),
@@ -1089,12 +1090,12 @@ class _UserMetadataGUI(object):
self._root = _TopLevelWithOk(self._apply, resume_main)
self._root.title("Metadata")
- LabeledText = partial(_LabeledText, right_width=_METADATA_RIGHT_WIDTH)
+ LabeledText_ = partial(LabeledText, right_width=_METADATA_RIGHT_WIDTH)
# Name
self._frame_name = tk.Frame(self._root)
self._frame_name.pack()
- self._name = LabeledText(
+ self._name = LabeledText_(
self._frame_name,
"NAM name",
default=parent.user_metadata.name,
@@ -1103,7 +1104,7 @@ class _UserMetadataGUI(object):
# Modeled by
self._frame_modeled_by = tk.Frame(self._root)
self._frame_modeled_by.pack()
- self._modeled_by = LabeledText(
+ self._modeled_by = LabeledText_(
self._frame_modeled_by,
"Modeled by",
default=parent.user_metadata.modeled_by,
@@ -1112,7 +1113,7 @@ class _UserMetadataGUI(object):
# Gear make
self._frame_gear_make = tk.Frame(self._root)
self._frame_gear_make.pack()
- self._gear_make = LabeledText(
+ self._gear_make = LabeledText_(
self._frame_gear_make,
"Gear make",
default=parent.user_metadata.gear_make,
@@ -1121,7 +1122,7 @@ class _UserMetadataGUI(object):
# Gear model
self._frame_gear_model = tk.Frame(self._root)
self._frame_gear_model.pack()
- self._gear_model = LabeledText(
+ self._gear_model = LabeledText_(
self._frame_gear_model,
"Gear model",
default=parent.user_metadata.gear_model,
@@ -1130,7 +1131,7 @@ class _UserMetadataGUI(object):
# Gear type
self._frame_gear_type = tk.Frame(self._root)
self._frame_gear_type.pack()
- self._gear_type = _LabeledOptionMenu(
+ self._gear_type = LabeledOptionMenu(
self._frame_gear_type,
"Gear type",
GearType,
@@ -1139,7 +1140,7 @@ class _UserMetadataGUI(object):
# Tone type
self._frame_tone_type = tk.Frame(self._root)
self._frame_tone_type.pack()
- self._tone_type = _LabeledOptionMenu(
+ self._tone_type = LabeledOptionMenu(
self._frame_tone_type,
"Tone type",
ToneType,
diff --git a/tests/test_nam/test_models/test_wavenet.py b/tests/test_nam/test_models/test_wavenet.py
@@ -6,7 +6,7 @@ import pytest
import torch
from nam.models.wavenet import WaveNet
-from nam.train.core import Architecture, _get_wavenet_config
+from nam.train.core import Architecture, get_wavenet_config
# from .base import Base
@@ -14,7 +14,7 @@ from nam.train.core import Architecture, _get_wavenet_config
class TestWaveNet(object):
def test_import_weights(self):
- config = _get_wavenet_config(Architecture.FEATHER)
+ config = get_wavenet_config(Architecture.FEATHER)
model_1 = WaveNet.init_from_config(config)
model_2 = WaveNet.init_from_config(config)