neural-amp-modeler

Neural network emulator for guitar amplifiers
Log | Files | Refs | README | LICENSE

commit a98cfbd9dd0c52fb1b80cb435aff9492e77a064d
parent 0d840eee7a0d16628368fb01d3965ae32f306902
Author: Steven Atkinson <steven@atkinson.mn>
Date:   Sun, 22 Sep 2024 14:31:34 -0700

Make some attributes public (#470)

* Define some public attributes

* nam.train.core.get_lstm_config
* nam.train.core.get_wavenet_config
* nam.train.gui.GUI.core_train_kwargs
* nam.train.gui.LabeledOptionMenu
* nam.train.gui.LabeledText

* Fix test
Diffstat:
Mnam/train/core.py | 20++++++++++----------
Mnam/train/gui/__init__.py | 55++++++++++++++++++++++++++++---------------------------
Mtests/test_nam/test_models/test_wavenet.py | 4++--
3 files changed, 40 insertions(+), 39 deletions(-)

diff --git a/nam/train/core.py b/nam/train/core.py @@ -551,7 +551,7 @@ def _analyze_latency( return metadata.Latency(manual=user_latency, calibration=calibration_output) -def _get_lstm_config(architecture): +def get_lstm_config(architecture): return { Architecture.STANDARD: { "num_layers": 1, @@ -788,7 +788,7 @@ def _check_data( return out -def _get_wavenet_config(architecture): +def get_wavenet_config(architecture): return { Architecture.STANDARD: { "layers_configs": [ @@ -996,7 +996,7 @@ def _get_configs( "name": "WaveNet", # This should do decently. If you really want a nice model, try turning up # "channels" in the first block and "input_size" in the second from 12 to 16. - "config": _get_wavenet_config(architecture), + "config": get_wavenet_config(architecture), }, "loss": {"val_loss": "esr"}, "optimizer": {"lr": lr}, @@ -1009,7 +1009,7 @@ def _get_configs( model_config = { "net": { "name": "LSTM", - "config": _get_lstm_config(architecture), + "config": get_lstm_config(architecture), }, "loss": { "val_loss": "mse", @@ -1563,13 +1563,13 @@ def validate_data( for split in Split: try: init_dataset(data_config, split) - pytorch_data_split_validation_dict[ - split.value - ] = _PyTorchDataSplitValidation(passed=True, msg=None) + pytorch_data_split_validation_dict[split.value] = ( + _PyTorchDataSplitValidation(passed=True, msg=None) + ) except DataError as e: - pytorch_data_split_validation_dict[ - split.value - ] = _PyTorchDataSplitValidation(passed=False, msg=str(e)) + pytorch_data_split_validation_dict[split.value] = ( + _PyTorchDataSplitValidation(passed=False, msg=str(e)) + ) pytorch_data_validation = _PyTorchDataValidation( passed=all(v.passed for v in pytorch_data_split_validation_dict.values()), **pytorch_data_split_validation_dict, diff --git a/nam/train/gui/__init__.py b/nam/train/gui/__init__.py @@ -21,7 +21,7 @@ from enum import Enum from functools import partial from pathlib import Path from tkinter import filedialog -from typing import Callable, Dict, NamedTuple, Optional, Sequence +from typing import Any, Callable, Dict, NamedTuple, Optional, Sequence try: # 3rd-party and 1st-party imports import torch @@ -496,6 +496,17 @@ class GUI(object): self._check_button_states() + def core_train_kwargs(self) -> Dict[str, Any]: + """ + Get any additional kwargs to provide to `core.train` + """ + return { + "lr": 0.004, + "lr_decay": _DEFAULT_LR_DECAY, + "batch_size": _DEFAULT_BATCH_SIZE, + "seed": 0, + } + def get_mrstft_fit(self) -> bool: """ Use a pre-emphasized multi-resolution shot-time Fourier transform loss during @@ -690,13 +701,6 @@ class GUI(object): file_list = self._widgets[_GUIWidgets.OUTPUT_PATH].val threshold_esr = self.advanced_options.threshold_esr - # Advanced-er options - # If you're poking around looking for these, then maybe it's time to learn to - # use the command-line scripts ;) - lr = 0.004 - lr_decay = _DEFAULT_LR_DECAY - batch_size = _DEFAULT_BATCH_SIZE - seed = 0 # Run it for file in file_list: print(f"Now training {file}") @@ -712,10 +716,6 @@ class GUI(object): epochs=num_epochs, latency=user_latency, architecture=architecture, - batch_size=batch_size, - lr=lr, - lr_decay=lr_decay, - seed=seed, silent=self._checkboxes[_CheckboxKeys.SILENT_TRAINING].variable.get(), save_plot=self._checkboxes[_CheckboxKeys.SAVE_PLOT].variable.get(), modelname=basename, @@ -724,6 +724,7 @@ class GUI(object): fit_mrstft=self.get_mrstft_fit(), threshold_esr=threshold_esr, user_metadata=user_metadata, + **self.core_train_kwargs(), ) if train_output.model is None: @@ -743,8 +744,8 @@ class GUI(object): ) print("Done!") - # Metadata was only valid for 1 run, so make sure it's not used again unless - # the user re-visits the window and clicks "ok" + # Metadata was only valid for 1 run (possibly a batch), so make sure it's not + # used again unless the user re-visits the window and clicks "ok". self.user_metadata_flag = False def _validate_all_data( @@ -884,7 +885,7 @@ def _rstripped_str(val): return str(val).rstrip() -class _LabeledOptionMenu(object): +class LabeledOptionMenu(object): """ Label (left) and radio buttons (right) """ @@ -937,7 +938,7 @@ class _LabeledOptionMenu(object): self._selected_value = self._choices(val) -class _LabeledText(object): +class LabeledText(object): """ Label (left) and text input (right) """ @@ -1040,7 +1041,7 @@ class AdvancedOptionsGUI(object): # Architecture: radio buttons self._frame_architecture = tk.Frame(self._root) self._frame_architecture.pack() - self._architecture = _LabeledOptionMenu( + self._architecture = LabeledOptionMenu( self._frame_architecture, "Architecture", core.Architecture, @@ -1051,7 +1052,7 @@ class AdvancedOptionsGUI(object): self._frame_epochs = tk.Frame(self._root) self._frame_epochs.pack() - self._epochs = _LabeledText( + self._epochs = LabeledText( self._frame_epochs, "Epochs", default=str(self._parent.advanced_options.num_epochs), @@ -1062,7 +1063,7 @@ class AdvancedOptionsGUI(object): self._frame_latency = tk.Frame(self._root) self._frame_latency.pack() - self._latency = _LabeledText( + self._latency = LabeledText( self._frame_latency, "Reamp latency", default=_type_or_null_inv(self._parent.advanced_options.latency), @@ -1072,7 +1073,7 @@ class AdvancedOptionsGUI(object): # Threshold ESR self._frame_threshold_esr = tk.Frame(self._root) self._frame_threshold_esr.pack() - self._threshold_esr = _LabeledText( + self._threshold_esr = LabeledText( self._frame_threshold_esr, "Threshold ESR", default=_type_or_null_inv(self._parent.advanced_options.threshold_esr), @@ -1089,12 +1090,12 @@ class _UserMetadataGUI(object): self._root = _TopLevelWithOk(self._apply, resume_main) self._root.title("Metadata") - LabeledText = partial(_LabeledText, right_width=_METADATA_RIGHT_WIDTH) + LabeledText_ = partial(LabeledText, right_width=_METADATA_RIGHT_WIDTH) # Name self._frame_name = tk.Frame(self._root) self._frame_name.pack() - self._name = LabeledText( + self._name = LabeledText_( self._frame_name, "NAM name", default=parent.user_metadata.name, @@ -1103,7 +1104,7 @@ class _UserMetadataGUI(object): # Modeled by self._frame_modeled_by = tk.Frame(self._root) self._frame_modeled_by.pack() - self._modeled_by = LabeledText( + self._modeled_by = LabeledText_( self._frame_modeled_by, "Modeled by", default=parent.user_metadata.modeled_by, @@ -1112,7 +1113,7 @@ class _UserMetadataGUI(object): # Gear make self._frame_gear_make = tk.Frame(self._root) self._frame_gear_make.pack() - self._gear_make = LabeledText( + self._gear_make = LabeledText_( self._frame_gear_make, "Gear make", default=parent.user_metadata.gear_make, @@ -1121,7 +1122,7 @@ class _UserMetadataGUI(object): # Gear model self._frame_gear_model = tk.Frame(self._root) self._frame_gear_model.pack() - self._gear_model = LabeledText( + self._gear_model = LabeledText_( self._frame_gear_model, "Gear model", default=parent.user_metadata.gear_model, @@ -1130,7 +1131,7 @@ class _UserMetadataGUI(object): # Gear type self._frame_gear_type = tk.Frame(self._root) self._frame_gear_type.pack() - self._gear_type = _LabeledOptionMenu( + self._gear_type = LabeledOptionMenu( self._frame_gear_type, "Gear type", GearType, @@ -1139,7 +1140,7 @@ class _UserMetadataGUI(object): # Tone type self._frame_tone_type = tk.Frame(self._root) self._frame_tone_type.pack() - self._tone_type = _LabeledOptionMenu( + self._tone_type = LabeledOptionMenu( self._frame_tone_type, "Tone type", ToneType, diff --git a/tests/test_nam/test_models/test_wavenet.py b/tests/test_nam/test_models/test_wavenet.py @@ -6,7 +6,7 @@ import pytest import torch from nam.models.wavenet import WaveNet -from nam.train.core import Architecture, _get_wavenet_config +from nam.train.core import Architecture, get_wavenet_config # from .base import Base @@ -14,7 +14,7 @@ from nam.train.core import Architecture, _get_wavenet_config class TestWaveNet(object): def test_import_weights(self): - config = _get_wavenet_config(Architecture.FEATHER) + config = get_wavenet_config(Architecture.FEATHER) model_1 = WaveNet.init_from_config(config) model_2 = WaveNet.init_from_config(config)