neural-amp-modeler

Neural network emulator for guitar amplifiers

commit 5c16c76c7bdbced07633d7d988e348ff17033e97
parent 31a211326f9468428ba99dae70c9d32ecff2d87a
Author: Steven Atkinson <steven@atkinson.mn>
Date:   Sat, 22 Apr 2023 17:29:32 -0700

Metadata in GUI trainer (#202)

* Start on metadata for GUI

* Add pydantic

* Pin auraloss to 0.3.0

* Metadata in GUI trainer
Diffstat:
M environment_cpu.yml                            |   3 ++-
M environment_gpu.yml                            |   3 ++-
M nam/models/_base.py                            |  43 +++++++++++++++++++++++++++++++++----------
M nam/models/_exportable.py                      |  55 +++++++++++++++++++++++++++++++++++++++++++++++--------
M nam/models/losses.py                           |   2 ++
A nam/models/metadata.py                         |  53 +++++++++++++++++++++++++++++++++++++++++++++++++++++
M nam/train/gui.py                               | 201 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-----------
M requirements.txt                               |   3 ++-
M setup.py                                       |   3 ++-
M tests/test_nam/test_models/test_base.py        |  12 +++++++++---
M tests/test_nam/test_models/test_exportable.py  |  41 ++++++++++++++++++++++++++++++++++++++++-
11 files changed, 366 insertions(+), 53 deletions(-)
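The headline change in the diff below is user-supplied metadata on exported models: a new nam/models/metadata.py module with pydantic models (UserMetadata, GearType, ToneType, Date), a user_metadata argument on Exportable.export() (along with modelname being renamed to basename), and a "Metadata..." dialog in the GUI trainer that forwards the values after training. As a rough sketch of how the pieces fit together, based only on the signatures visible in this diff; trained_model and outdir are placeholders for whatever nam.train.core.train(...) and your chosen output directory provide, the way the GUI does it:

# Hedged sketch of the new export-metadata flow in this commit.
# `trained_model` and `outdir` are placeholders, not defined here.
from nam.models.metadata import GearType, ToneType, UserMetadata

user_metadata = UserMetadata(
    name="My Model",
    modeled_by="Steve",
    gear_type=GearType.AMP,
    gear_make="SteveCo",
    gear_model="SteveAmp",
    tone_type=ToneType.HI_GAIN,
)

# Mirrors the GUI trainer's call after training finishes.
trained_model.net.export(outdir, basename="model", user_metadata=user_metadata)

Enum values are cast to plain strings by _cast_enums() and merged into the .nam file's "metadata" object next to the auto-filled date, loudness, and gain.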

diff --git a/environment_cpu.yml b/environment_cpu.yml
@@ -14,6 +14,7 @@ dependencies:
   - matplotlib
   - numpy
   - pip
+  - pydantic
   - pytest
   - pytest-mock
   - pytorch
@@ -23,7 +24,7 @@ dependencies:
   - tqdm
   - wheel
   - pip:
-    - auraloss
+    - auraloss==0.3.0
     - onnx
     - onnxruntime
     - pre-commit
diff --git a/environment_gpu.yml b/environment_gpu.yml
@@ -15,6 +15,7 @@ dependencies:
   - matplotlib
   - numpy
   - pip
+  - pydantic
   - pytest
   - pytest-mock
   - pytorch
@@ -25,7 +26,7 @@ dependencies:
   - tqdm
   - wheel
   - pip:
-    - auraloss
+    - auraloss==0.3.0
     - onnx
     - onnxruntime  # TODO GPU...
     - pre-commit
diff --git a/nam/models/_base.py b/nam/models/_base.py
@@ -10,7 +10,7 @@ steps)
 import abc
 import math
 import pkg_resources
-from typing import Any, Optional, Tuple
+from typing import Any, Dict, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -37,7 +37,7 @@ class _Base(nn.Module, InitializableFromConfig, Exportable):
     def forward(self, *args, **kwargs) -> torch.Tensor:
         pass
 
-    def _loudness(self, gain: float = 1.0) -> float:
+    def _metadata_loudness(self, gain: float = 1.0, db: bool = True) -> float:
         """
         How loud is this model when given a standardized input?
         In dB
@@ -50,7 +50,33 @@ class _Base(nn.Module, InitializableFromConfig, Exportable):
             )
         )
         y = self._at_nominal_settings(gain * x)
-        return 10.0 * torch.log10(torch.mean(torch.square(y))).item()
+        loudness = torch.sqrt(torch.mean(torch.square(y)))
+        if db:
+            loudness = 20.0 * torch.log10(loudness)
+        return loudness.item()
+
+    def _metadata_gain(self) -> float:
+        """
+        Between 0 and 1, how much gain / compression does the model seem to have?
+        """
+        x = np.linspace(0.0, 1.0, 11)
+        y = np.array([self._metadata_loudness(gain=gain, db=False) for gain in x])
+        #
+        # O  ^ o o o o o o
+        # u  | o       x   +-------------------------------------+
+        # t  | o     x     | x: Minimum gain (no compression)    |
+        # p  | o   x       | o: Max gain (100% compression)      |
+        # u  | o x         +-------------------------------------+
+        # t  | o
+        #    +------------->
+        #      Input
+        #
+        max_gain = y[-1] * len(x)  # "Square"
+        min_gain = 0.5 * max_gain  # "Triangle"
+        gain_range = max_gain - min_gain
+        this_gain = y.sum()
+        normalized_gain = (this_gain - min_gain) / gain_range
+        return np.clip(normalized_gain, 0.0, 1.0)
 
     def _at_nominal_settings(self, x: torch.Tensor) -> torch.Tensor:
         # parametric?...
@@ -91,10 +117,6 @@ class _Base(nn.Module, InitializableFromConfig, Exportable):
                 self(*args, x, pad_start=True).detach().cpu().numpy(),
             )
 
-    def _get_export_dict(self):
-        d = super()._get_export_dict()
-        return d
-
 
 class BaseNet(_Base):
     def forward(self, x: torch.Tensor, pad_start: Optional[bool] = None):
@@ -122,9 +144,10 @@ class BaseNet(_Base):
         """
         pass
 
-    def _get_export_dict(self):
-        d = super()._get_export_dict()
-        d["metadata"]["loudness"] = self._loudness()
+    def _get_non_user_metadata(self) -> Dict[str, Union[str, int, float]]:
+        d = super()._get_non_user_metadata()
+        d["loudness"] = self._metadata_loudness()
+        d["gain"] = self._metadata_gain()
         return d
 
 
diff --git a/nam/models/_exportable.py b/nam/models/_exportable.py
@@ -5,23 +5,44 @@
 import abc
 import json
 import logging
+from datetime import datetime
+from enum import Enum
 from pathlib import Path
-from typing import Tuple
+from typing import Any, Dict, Optional, Tuple, Union
 
 import numpy as np
 
 from .._version import __version__
 from ..data import np_to_wav
+from .metadata import Date, UserMetadata
 
 logger = logging.getLogger(__name__)
 
 
+def _cast_enums(d: Dict[Any, Any]) -> Dict[Any, Any]:
+    """
+    Casts enum-type keys to their values
+    """
+    out = {}
+    for key, val in d.items():
+        if isinstance(val, Enum):
+            val = val.value
+        out[key] = val
+    return out
+
+
 class Exportable(abc.ABC):
     """
     Interface for my custon export format for use in the plugin.
     """
 
-    def export(self, outdir: Path, include_snapshot: bool = False, modelname: str = "model"):
+    def export(
+        self,
+        outdir: Path,
+        include_snapshot: bool = False,
+        basename: str = "model",
+        user_metadata: Optional[UserMetadata] = None,
+    ):
         """
         Interface for exporting.
         You should create at least a `config.json` containing the two fields:
@@ -35,13 +56,15 @@ class Exportable(abc.ABC):
             Can be used to debug e.g. the implementation of the model in the
             plugin.
         """
+        model_dict = self._get_export_dict()
+        model_dict["metadata"].update(
+            {} if user_metadata is None else _cast_enums(user_metadata.dict())
+        )
+
         training = self.training
         self.eval()
-        with open(Path(outdir, modelname + ".nam"), "w") as fp:
-            json.dump(
-                self._get_export_dict(),
-                fp,
-            )
+        with open(Path(outdir, f"{basename}.nam"), "w") as fp:
+            json.dump(model_dict, fp)
         if include_snapshot:
             x, y = self._export_input_output()
             x_path = Path(outdir, "test_inputs.npy")
@@ -101,8 +124,24 @@ class Exportable(abc.ABC):
     def _get_export_dict(self):
         return {
             "version": __version__,
+            "metadata": self._get_non_user_metadata(),
             "architecture": self.__class__.__name__,
             "config": self._export_config(),
-            "metadata": {},
             "weights": self._export_weights().tolist(),
         }
+
+    def _get_non_user_metadata(self) -> Dict[str, Union[str, int, float]]:
+        """
+        Get any metadata that's non-user-provided (date, loudness, gain)
+        """
+        t = datetime.now()
+        return {
+            "date": Date(
+                year=t.year,
+                month=t.month,
+                day=t.day,
+                hour=t.hour,
+                minute=t.minute,
+                second=t.second,
+            ).dict()
+        }
diff --git a/nam/models/losses.py b/nam/models/losses.py
@@ -11,6 +11,8 @@ from typing import Optional
 import torch
 from auraloss.freq import MultiResolutionSTFTLoss
 
+___all__ = ["esr", "multi_resolution_stft_loss"]
+
 
 def esr(preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
     """
diff --git a/nam/models/metadata.py b/nam/models/metadata.py
@@ -0,0 +1,53 @@
+# File: metadata.py
+# Created Date: Wednesday April 12th 2023
+# Author: Steven Atkinson (steven@atkinson.mn)
+
+"""
+Metadata about models
+"""
+
+from enum import Enum
+from typing import Optional
+
+from pydantic import BaseModel
+
+__all__ = ["GearType", "ToneType", "Date", "UserMetadata"]
+
+
+class GearType(Enum):
+    AMP = "amp"
+    PEDAL = "pedal"
+    AMP_CAB = "amp_cab"
+    AMP_PEDAL_CAB = "amp_pedal_cab"
+    PREAMP = "preamp"
+    STUDIO = "studio"
+
+
+class ToneType(Enum):
+    CLEAN = "clean"
+    OVERDRIVE = "overdrive"
+    CRUNCH = "crunch"
+    HI_GAIN = "hi_gain"
+    FUZZ = "fuzz"
+
+
+class Date(BaseModel):
+    year: int
+    month: int
+    day: int
+    hour: int
+    minute: int
+    second: int
+
+
+class UserMetadata(BaseModel):
+    """
+    Metadata that users provide for a NAM model
+    """
+
+    name: Optional[str] = None
+    modeled_by: Optional[str] = None
+    gear_type: Optional[GearType] = None
+    gear_make: Optional[str] = None
+    gear_model: Optional[str] = None
+    tone_type: Optional[ToneType] = None
diff --git a/nam/train/gui.py b/nam/train/gui.py
@@ -29,6 +29,7 @@ import re
 import tkinter as tk
 from dataclasses import dataclass
 from enum import Enum
+from functools import partial
 from pathlib import Path
 from tkinter import filedialog
 from typing import Callable, Optional, Sequence
@@ -36,6 +37,7 @@
 try:
     from nam import __version__
     from nam.train import core
+    from nam.models.metadata import GearType, UserMetadata, ToneType
 
     _install_is_valid = True
 except ImportError:
@@ -47,6 +49,10 @@ _TEXT_WIDTH = 70
 _DEFAULT_NUM_EPOCHS = 100
 _DEFAULT_DELAY = None
 
+_ADVANCED_OPTIONS_LEFT_WIDTH = 12
+_ADVANCED_OPTIONS_RIGHT_WIDTH = 12
+_METADATA_RIGHT_WIDTH = 60
+
 
 @dataclass
 class _AdvancedOptions(object):
@@ -164,6 +170,21 @@ class _GUI(object):
             hooks=[self._check_button_states],
         )
 
+        # Metadata
+        self.user_metadata = UserMetadata()
+        self._frame_metadata = tk.Frame(self._root)
+        self._frame_metadata.pack()
+        self._button_metadata = tk.Button(
+            self._frame_metadata,
+            text="Metadata...",
+            width=_BUTTON_WIDTH,
+            height=_BUTTON_HEIGHT,
+            fg="black",
+            command=self._open_metadata,
+        )
+        self._button_metadata.pack()
+        self.user_metadata_flag = False
+
         # This should probably be to the right somewhere
         self._get_additional_options_frame()
@@ -236,6 +257,14 @@ class _GUI(object):
             ao.mainloop()
             # ...and then re-enable it once it gets closed.
 
+    def _open_metadata(self):
+        """
+        Open dialog for metadata
+        """
+        mdata = _UserMetadataGUI(self)
+        # I should probably disable the main GUI...
+        mdata.mainloop()
+
     def _train(self):
         # Advanced options:
         num_epochs = self.advanced_options.num_epochs
@@ -247,13 +276,13 @@ class _GUI(object):
         # If you're poking around looking for these, then maybe it's time to learn to
         # use the command-line scripts ;)
         lr = 0.004
-        lr_decay = 0.007
+        lr_decay = 0.05
         seed = 0
 
         # Run it
         for file in file_list:
             print("Now training {}".format(file))
-            modelname = re.sub(r"\.wav$", "", file.split("/")[-1])
+            basename = re.sub(r"\.wav$", "", file.split("/")[-1])
 
             trained_model = core.train(
                 self._path_button_input.val,
@@ -267,13 +296,22 @@ class _GUI(object):
                 seed=seed,
                 silent=self._silent.get(),
                 save_plot=self._save_plot.get(),
-                modelname=modelname,
+                modelname=basename,
             )
             print("Model training complete!")
             print("Exporting...")
             outdir = self._path_button_train_destination.val
             print(f"Exporting trained model to {outdir}...")
-            trained_model.net.export(outdir, modelname=modelname)
+            trained_model.net.export(
+                outdir,
+                basename=basename,
+                user_metadata=self.user_metadata
+                if self.user_metadata_flag
+                else UserMetadata(),
+            )
+            # Metadata was only valid for 1 run, so make sure it's not used again unless
+            # the user re-visits the window and clicks "ok"
+            self.user_metadata_flag = False
             print("Done!")
 
     def _check_button_states(self):
@@ -294,8 +332,27 @@ class _GUI(object):
         self._button_train["state"] = tk.NORMAL
 
 
-_ADVANCED_OPTIONS_LEFT_WIDTH = 12
-_ADVANCED_OPTIONS_RIGHT_WIDTH = 12
+# some typing functions
+def _non_negative_int(val):
+    val = int(val)
+    if val < 0:
+        val = 0
+    return val
+
+
+def _int_or_null(val):
+    val = val.rstrip()
+    if val == "null":
+        return val
+    return int(val)
+
+
+def _int_or_null_inv(val):
+    return "null" if val is None else str(val)
+
+
+def _rstripped_str(val):
+    return str(val).rstrip()
 
 
 class _LabeledOptionMenu(object):
@@ -358,7 +415,15 @@ class _LabeledText(object):
     """
     Label (left) and text input (right)
     """
-    def __init__(self, frame: tk.Frame, label: str, default=None, type=None):
+    def __init__(
+        self,
+        frame: tk.Frame,
+        label: str,
+        default=None,
+        type=None,
+        left_width=_ADVANCED_OPTIONS_LEFT_WIDTH,
+        right_width=_ADVANCED_OPTIONS_RIGHT_WIDTH,
+    ):
         """
         :param command: Called to propagate option selection.  Is provided with the
             value corresponding to the radio button selected.
@@ -369,7 +434,7 @@ class _LabeledText(object):
         text_height = 1
         self._label = tk.Label(
             frame,
-            width=_ADVANCED_OPTIONS_LEFT_WIDTH,
+            width=left_width,
             height=label_height,
             fg="black",
             bg=None,
@@ -380,7 +445,7 @@ class _LabeledText(object):
 
         self._text = tk.Text(
             frame,
-            width=_ADVANCED_OPTIONS_RIGHT_WIDTH,
+            width=right_width,
            height=text_height,
             fg="black",
             bg=None,
@@ -426,37 +491,22 @@ class _AdvancedOptionsGUI(object):
         self._frame_epochs = tk.Frame(self._root)
         self._frame_epochs.pack()
 
-        def non_negative_int(val):
-            val = int(val)
-            if val < 0:
-                val = 0
-            return val
-
         self._epochs = _LabeledText(
             self._frame_epochs,
             "Epochs",
             default=str(self._parent.advanced_options.num_epochs),
-            type=non_negative_int,
+            type=_non_negative_int,
         )
 
         # Delay: text box
         self._frame_delay = tk.Frame(self._root)
         self._frame_delay.pack()
 
-        def int_or_null(val):
-            val = val.rstrip()
-            if val == "null":
-                return val
-            return int(val)
-
-        def int_or_null_inv(val):
-            return "null" if val is None else str(val)
-
         self._delay = _LabeledText(
             self._frame_delay,
             "Delay",
-            default=int_or_null_inv(self._parent.advanced_options.delay),
-            type=int_or_null,
+            default=_int_or_null_inv(self._parent.advanced_options.delay),
+            type=_int_or_null,
         )
 
         # "Ok": apply and destory
@@ -490,6 +540,103 @@ class _AdvancedOptionsGUI(object):
         self._root.destroy()
 
 
+class _UserMetadataGUI(object):
+    # Things that are auto-filled:
+    #  Model date
+    #  gain
+    def __init__(self, parent: _GUI):
+        self._parent = parent
+        self._root = tk.Tk()
+        self._root.title("Metadata")
+
+        LabeledText = partial(_LabeledText, right_width=_METADATA_RIGHT_WIDTH)
+
+        # Name
+        self._frame_name = tk.Frame(self._root)
+        self._frame_name.pack()
+        self._name = LabeledText(
+            self._frame_name,
+            "NAM name",
+            default=parent.user_metadata.name,
+            type=_rstripped_str,
+        )
+        # Modeled by
+        self._frame_modeled_by = tk.Frame(self._root)
+        self._frame_modeled_by.pack()
+        self._modeled_by = LabeledText(
+            self._frame_modeled_by,
+            "Modeled by",
+            default=parent.user_metadata.modeled_by,
+            type=_rstripped_str,
+        )
+        # Gear make
+        self._frame_gear_make = tk.Frame(self._root)
+        self._frame_gear_make.pack()
+        self._gear_make = LabeledText(
+            self._frame_gear_make,
+            "Gear make",
+            default=parent.user_metadata.gear_make,
+            type=_rstripped_str,
+        )
+        # Gear model
+        self._frame_gear_model = tk.Frame(self._root)
+        self._frame_gear_model.pack()
+        self._gear_model = LabeledText(
+            self._frame_gear_model,
+            "Gear model",
+            default=parent.user_metadata.gear_model,
+            type=_rstripped_str,
+        )
+        # Gear type
+        self._frame_gear_type = tk.Frame(self._root)
+        self._frame_gear_type.pack()
+        self._gear_type = _LabeledOptionMenu(
+            self._frame_gear_type,
+            "Gear type",
+            GearType,
+            default=parent.user_metadata.gear_type,
+        )
+        # Tone type
+        self._frame_tone_type = tk.Frame(self._root)
+        self._frame_tone_type.pack()
+        self._tone_type = _LabeledOptionMenu(
+            self._frame_tone_type,
+            "Tone type",
+            ToneType,
+            default=parent.user_metadata.tone_type,
+        )
+
+        # "Ok": apply and destory
+        self._frame_ok = tk.Frame(self._root)
+        self._frame_ok.pack()
+        self._button_ok = tk.Button(
+            self._frame_ok,
+            text="Ok",
+            width=_BUTTON_WIDTH,
+            height=_BUTTON_HEIGHT,
+            fg="black",
+            command=self._apply_and_destroy,
+        )
+        self._button_ok.pack()
+
+    def mainloop(self):
+        self._root.mainloop()
+
+    def _apply_and_destroy(self):
+        """
+        Set values to parent and destroy this object
+        """
+        self._parent.user_metadata.name = self._name.get()
+        self._parent.user_metadata.modeled_by = self._modeled_by.get()
+        self._parent.user_metadata.gear_make = self._gear_make.get()
+        self._parent.user_metadata.gear_model = self._gear_model.get()
+        self._parent.user_metadata.gear_type = self._gear_type.get()
+        self._parent.user_metadata.tone_type = self._tone_type.get()
+        self._parent.user_metadata_flag = True
+
+        self._root.destroy()
+
+
 def _install_error():
     window = tk.Tk()
     window.title("ERROR")
diff --git a/requirements.txt b/requirements.txt
@@ -2,7 +2,7 @@
 # Created Date: 2021-01-24
 # Author: Steven Atkinson (steven@atkinson.mn)
 
-auraloss
+auraloss==0.3.0  # 0.4.0 changes API for MRSTFT loss
 black
 flake8
 matplotlib
@@ -11,6 +11,7 @@ onnx
 onnxruntime
 pip
 pre-commit
+pydantic
 pytest
 pytest-mock
 pytorch_lightning
diff --git a/setup.py b/setup.py
@@ -11,9 +11,10 @@ with open(ver_path) as ver_file:
     exec(ver_file.read(), main_ns)
 
 requirements = [
-    "auraloss",
+    "auraloss==0.3.0",
     "matplotlib",
     "numpy",
+    "pydantic",
     "pytorch_lightning",
     "scipy",
     "sounddevice",
diff --git a/tests/test_nam/test_models/test_base.py b/tests/test_nam/test_models/test_base.py
@@ -40,12 +40,18 @@ class _MockBaseNet(_base.BaseNet):
         return self.gain * x
 
 
+def test_metadata_gain():
+    obj = _MockBaseNet(1.0)
+    g = obj._metadata_gain()
+    # It's linear, so gain is zero.
+    assert g == 0.0
+
 
-def test_loudness():
+def test_metadata_loudness():
    obj = _MockBaseNet(1.0)
-    y = obj._loudness()
+    y = obj._metadata_loudness()
     obj.gain = 2.0
-    y2 = obj._loudness()
+    y2 = obj._metadata_loudness()
     assert isinstance(y, float)
     # 2x louder = +6dB
     assert y2 == pytest.approx(y + 20.0 * math.log10(2.0))
diff --git a/tests/test_nam/test_models/test_exportable.py b/tests/test_nam/test_models/test_exportable.py
@@ -7,9 +7,10 @@ Test export behavior of models
 """
 
 import json
+from enum import Enum
 from pathlib import Path
 from tempfile import TemporaryDirectory
-from typing import Tuple
+from typing import Optional, Tuple
 
 import numpy as np
 import pytest
@@ -17,6 +18,7 @@ import torch
 import torch.nn as nn
 
 from nam.models import _exportable
+from nam.models import metadata
 
 
 class TestExportable(object):
@@ -41,6 +43,43 @@ class TestExportable(object):
         assert len(weights_list) == 2
         assert all(isinstance(w, float) for w in weights_list)
 
+    @pytest.mark.parametrize(
+        "user_metadata",
+        (
+            None,
+            metadata.UserMetadata(),
+            metadata.UserMetadata(
+                name="My Model",
+                modeled_by="Steve",
+                gear_type=metadata.GearType.AMP,
+                gear_make="SteveCo",
+                gear_model="SteveAmp",
+                tone_type=metadata.ToneType.HI_GAIN,
+            ),
+        ),
+    )
+    def test_export_metadata(self, user_metadata: Optional[metadata.UserMetadata]):
+        """
+        Assert export behavior when metadata is provided
+        """
+        model = self._get_model()
+        with TemporaryDirectory() as tmpdir:
+            model.export(tmpdir, user_metadata=user_metadata)
+            model_basename = "model.nam"
+            model_path = Path(tmpdir, model_basename)
+            assert model_path.exists()
+            with open(model_path, "r") as fp:
+                model_dict = json.load(fp)
+            metadata_key = "metadata"
+            assert metadata_key in model_dict
+            model_dict_metadata = model_dict[metadata_key]
+            if user_metadata is not None:
+                for key, expected_val in user_metadata.dict().items():
+                    if isinstance(expected_val, Enum):
+                        expected_val = expected_val.value
+                    assert key in model_dict_metadata
+                    assert model_dict_metadata[key] == expected_val
+
     @pytest.mark.parametrize("include_snapshot", (True, False))
     def test_include_snapshot(self, include_snapshot):
         """
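For reference, the _metadata_gain() heuristic added to nam/models/_base.py above sweeps the input gain from 0 to 1, measures the model's linear (non-dB) loudness at each step, and normalizes the summed response between two extremes: a "triangle" response whose loudness grows linearly with input (no compression) and a "square" response that is already at maximum loudness at every level (full compression). A small numeric sketch of that arithmetic for a purely linear model, mirroring the expectation in test_metadata_gain; the standalone x/y arrays here are illustrative, not part of the package:

# Sketch of the _metadata_gain() arithmetic for a linear model (output = gain * input),
# where RMS loudness is simply proportional to the swept input gain.
import numpy as np

x = np.linspace(0.0, 1.0, 11)   # swept input gains, as in _metadata_gain()
y = x                           # linear model: loudness grows proportionally with input

max_gain = y[-1] * len(x)       # "square": max loudness everywhere (100% compression)
min_gain = 0.5 * max_gain       # "triangle": linear growth (no compression)
normalized = (y.sum() - min_gain) / (max_gain - min_gain)

# A linear model sits on the "triangle" floor, so the clipped result is (about) zero.
assert np.isclose(np.clip(normalized, 0.0, 1.0), 0.0)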