commit 5c16c76c7bdbced07633d7d988e348ff17033e97
parent 31a211326f9468428ba99dae70c9d32ecff2d87a
Author: Steven Atkinson <steven@atkinson.mn>
Date: Sat, 22 Apr 2023 17:29:32 -0700
Metadata in GUI trainer (#202)
* Start on metadata for GUI
* Add pydantic
* Pin auraloss to 0.3.0
* Metadata in GUI trainer
Diffstat:
11 files changed, 366 insertions(+), 53 deletions(-)
diff --git a/environment_cpu.yml b/environment_cpu.yml
@@ -14,6 +14,7 @@ dependencies:
- matplotlib
- numpy
- pip
+ - pydantic
- pytest
- pytest-mock
- pytorch
@@ -23,7 +24,7 @@ dependencies:
- tqdm
- wheel
- pip:
- - auraloss
+ - auraloss==0.3.0
- onnx
- onnxruntime
- pre-commit
diff --git a/environment_gpu.yml b/environment_gpu.yml
@@ -15,6 +15,7 @@ dependencies:
- matplotlib
- numpy
- pip
+ - pydantic
- pytest
- pytest-mock
- pytorch
@@ -25,7 +26,7 @@ dependencies:
- tqdm
- wheel
- pip:
- - auraloss
+ - auraloss==0.3.0
- onnx
- onnxruntime # TODO GPU...
- pre-commit
diff --git a/nam/models/_base.py b/nam/models/_base.py
@@ -10,7 +10,7 @@ steps)
import abc
import math
import pkg_resources
-from typing import Any, Optional, Tuple
+from typing import Any, Dict, Optional, Tuple, Union
import numpy as np
import torch
@@ -37,7 +37,7 @@ class _Base(nn.Module, InitializableFromConfig, Exportable):
def forward(self, *args, **kwargs) -> torch.Tensor:
pass
- def _loudness(self, gain: float = 1.0) -> float:
+ def _metadata_loudness(self, gain: float = 1.0, db: bool = True) -> float:
"""
How loud is this model when given a standardized input?
In dB
@@ -50,7 +50,33 @@ class _Base(nn.Module, InitializableFromConfig, Exportable):
)
)
y = self._at_nominal_settings(gain * x)
- return 10.0 * torch.log10(torch.mean(torch.square(y))).item()
+ loudness = torch.sqrt(torch.mean(torch.square(y)))
+ if db:
+ loudness = 20.0 * torch.log10(loudness)
+ return loudness.item()
+
+    def _metadata_gain(self) -> float:
+        """
+        Between 0 and 1, how much gain / compression does the model seem to have?
+
+        The linear-scale output loudness is sampled at 11 evenly spaced input
+        gains in [0, 1] and summed. That sum is then placed between two reference
+        sums, both anchored on the loudness measured at full input gain:
+
+        * 0.0 ("triangle"): loudness grows in proportion to the input gain.
+        * 1.0 ("square"): loudness is already at its full-gain value for every
+          input gain.
+
+        :return: The normalized gain as a plain Python float, clipped to [0, 1].
+            0.0 is returned when the model is silent at full input gain, because
+            the normalization is undefined in that case.
+        """
+        x = np.linspace(0.0, 1.0, 11)
+        y = np.array([self._metadata_loudness(gain=gain, db=False) for gain in x])
+        #
+        #  O ^
+        #  u |  o  o  o  o  o  o
+        #  t |  o           x     +----------------------------------+
+        #  p |  o        x        | x: Minimum gain (no compression) |
+        #  u |  o     x           | o: Max gain (100% compression)   |
+        #  t |  o  x              +----------------------------------+
+        #    +------------------>
+        #           Input
+        #
+        max_gain = y[-1] * len(x)  # "Square"
+        min_gain = 0.5 * max_gain  # "Triangle"
+        gain_range = max_gain - min_gain
+        if not gain_range > 0.0:
+            # Zero (silent at full gain) or NaN range: dividing would yield NaN,
+            # which json.dump would write as the non-standard token `NaN`.
+            return 0.0
+        this_gain = y.sum()
+        normalized_gain = (this_gain - min_gain) / gain_range
+        # Convert the NumPy scalar so the result matches the annotated type.
+        return float(np.clip(normalized_gain, 0.0, 1.0))
def _at_nominal_settings(self, x: torch.Tensor) -> torch.Tensor:
# parametric?...
@@ -91,10 +117,6 @@ class _Base(nn.Module, InitializableFromConfig, Exportable):
self(*args, x, pad_start=True).detach().cpu().numpy(),
)
- def _get_export_dict(self):
- d = super()._get_export_dict()
- return d
-
class BaseNet(_Base):
def forward(self, x: torch.Tensor, pad_start: Optional[bool] = None):
@@ -122,9 +144,10 @@ class BaseNet(_Base):
"""
pass
- def _get_export_dict(self):
- d = super()._get_export_dict()
- d["metadata"]["loudness"] = self._loudness()
+ def _get_non_user_metadata(self) -> Dict[str, Union[str, int, float]]:
+ d = super()._get_non_user_metadata()
+ d["loudness"] = self._metadata_loudness()
+ d["gain"] = self._metadata_gain()
return d
diff --git a/nam/models/_exportable.py b/nam/models/_exportable.py
@@ -5,23 +5,44 @@
import abc
import json
import logging
+from datetime import datetime
+from enum import Enum
from pathlib import Path
-from typing import Tuple
+from typing import Any, Dict, Optional, Tuple, Union
import numpy as np
from .._version import __version__
from ..data import np_to_wav
+from .metadata import Date, UserMetadata
logger = logging.getLogger(__name__)
+def _cast_enums(d: Dict[Any, Any]) -> Dict[Any, Any]:
+    """
+    Return a new dict in which every value that is an `Enum` member has been
+    replaced by its `.value`. Keys and all other values are passed through as-is.
+    """
+    return {
+        key: (val.value if isinstance(val, Enum) else val) for key, val in d.items()
+    }
+
+
class Exportable(abc.ABC):
"""
Interface for my custon export format for use in the plugin.
"""
- def export(self, outdir: Path, include_snapshot: bool = False, modelname: str = "model"):
+ def export(
+ self,
+ outdir: Path,
+ include_snapshot: bool = False,
+ basename: str = "model",
+ user_metadata: Optional[UserMetadata] = None,
+ ):
"""
Interface for exporting.
You should create at least a `config.json` containing the two fields:
@@ -35,13 +56,15 @@ class Exportable(abc.ABC):
Can be used to debug e.g. the implementation of the model in the
plugin.
"""
+ model_dict = self._get_export_dict()
+ model_dict["metadata"].update(
+ {} if user_metadata is None else _cast_enums(user_metadata.dict())
+ )
+
training = self.training
self.eval()
- with open(Path(outdir, modelname + ".nam"), "w") as fp:
- json.dump(
- self._get_export_dict(),
- fp,
- )
+ with open(Path(outdir, f"{basename}.nam"), "w") as fp:
+ json.dump(model_dict, fp)
if include_snapshot:
x, y = self._export_input_output()
x_path = Path(outdir, "test_inputs.npy")
@@ -101,8 +124,24 @@ class Exportable(abc.ABC):
def _get_export_dict(self):
return {
"version": __version__,
+ "metadata": self._get_non_user_metadata(),
"architecture": self.__class__.__name__,
"config": self._export_config(),
- "metadata": {},
"weights": self._export_weights().tolist(),
}
+
+    def _get_non_user_metadata(self) -> Dict[str, Union[str, int, float]]:
+        """
+        Build the metadata that is filled in automatically rather than supplied
+        by the user (date, loudness, gain).
+
+        This implementation only provides the current date and time, stored
+        under the "date" key.
+        """
+        now = datetime.now()
+        fields = ("year", "month", "day", "hour", "minute", "second")
+        date = Date(**{name: getattr(now, name) for name in fields})
+        return {"date": date.dict()}
diff --git a/nam/models/losses.py b/nam/models/losses.py
@@ -11,6 +11,8 @@ from typing import Optional
import torch
from auraloss.freq import MultiResolutionSTFTLoss
+__all__ = ["esr", "multi_resolution_stft_loss"]
+
def esr(preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
"""
diff --git a/nam/models/metadata.py b/nam/models/metadata.py
@@ -0,0 +1,53 @@
+# File: metadata.py
+# Created Date: Wednesday April 12th 2023
+# Author: Steven Atkinson (steven@atkinson.mn)
+
+"""
+Metadata about models
+"""
+
+from enum import Enum
+from typing import Optional
+
+from pydantic import BaseModel
+
+__all__ = ["GearType", "ToneType", "Date", "UserMetadata"]
+
+
+class GearType(Enum):
+    """
+    Allowed values for the `gear_type` field of `UserMetadata`.
+    """
+
+    AMP = "amp"
+    PEDAL = "pedal"
+    AMP_CAB = "amp_cab"
+    AMP_PEDAL_CAB = "amp_pedal_cab"
+    PREAMP = "preamp"
+    STUDIO = "studio"
+
+
+class ToneType(Enum):
+    """
+    Allowed values for the `tone_type` field of `UserMetadata`.
+    """
+
+    CLEAN = "clean"
+    OVERDRIVE = "overdrive"
+    CRUNCH = "crunch"
+    HI_GAIN = "hi_gain"
+    FUZZ = "fuzz"
+
+
+class Date(BaseModel):
+    """
+    A date and time of day, stored as separate integer fields.
+    """
+
+    year: int
+    month: int
+    day: int
+    hour: int
+    minute: int
+    second: int
+
+
+class UserMetadata(BaseModel):
+    """
+    Metadata that users provide for a NAM model
+    """
+
+    # Every field is optional and defaults to None (i.e. "not provided").
+    name: Optional[str] = None
+    modeled_by: Optional[str] = None
+    gear_type: Optional[GearType] = None
+    gear_make: Optional[str] = None
+    gear_model: Optional[str] = None
+    tone_type: Optional[ToneType] = None
diff --git a/nam/train/gui.py b/nam/train/gui.py
@@ -29,6 +29,7 @@ import re
import tkinter as tk
from dataclasses import dataclass
from enum import Enum
+from functools import partial
from pathlib import Path
from tkinter import filedialog
from typing import Callable, Optional, Sequence
@@ -36,6 +37,7 @@ from typing import Callable, Optional, Sequence
try:
from nam import __version__
from nam.train import core
+ from nam.models.metadata import GearType, UserMetadata, ToneType
_install_is_valid = True
except ImportError:
@@ -47,6 +49,10 @@ _TEXT_WIDTH = 70
_DEFAULT_NUM_EPOCHS = 100
_DEFAULT_DELAY = None
+_ADVANCED_OPTIONS_LEFT_WIDTH = 12
+_ADVANCED_OPTIONS_RIGHT_WIDTH = 12
+_METADATA_RIGHT_WIDTH = 60
+
@dataclass
class _AdvancedOptions(object):
@@ -164,6 +170,21 @@ class _GUI(object):
hooks=[self._check_button_states],
)
+ # Metadata
+ self.user_metadata = UserMetadata()
+ self._frame_metadata = tk.Frame(self._root)
+ self._frame_metadata.pack()
+ self._button_metadata = tk.Button(
+ self._frame_metadata,
+ text="Metadata...",
+ width=_BUTTON_WIDTH,
+ height=_BUTTON_HEIGHT,
+ fg="black",
+ command=self._open_metadata,
+ )
+ self._button_metadata.pack()
+ self.user_metadata_flag = False
+
# This should probably be to the right somewhere
self._get_additional_options_frame()
@@ -236,6 +257,14 @@ class _GUI(object):
ao.mainloop()
# ...and then re-enable it once it gets closed.
+    def _open_metadata(self):
+        """
+        Create the metadata dialog for this GUI and run its event loop.
+        """
+        # I should probably disable the main GUI...
+        _UserMetadataGUI(self).mainloop()
+
def _train(self):
# Advanced options:
num_epochs = self.advanced_options.num_epochs
@@ -247,13 +276,13 @@ class _GUI(object):
# If you're poking around looking for these, then maybe it's time to learn to
# use the command-line scripts ;)
lr = 0.004
- lr_decay = 0.007
+ lr_decay = 0.05
seed = 0
# Run it
for file in file_list:
print("Now training {}".format(file))
- modelname = re.sub(r"\.wav$", "", file.split("/")[-1])
+ basename = re.sub(r"\.wav$", "", file.split("/")[-1])
trained_model = core.train(
self._path_button_input.val,
@@ -267,13 +296,22 @@ class _GUI(object):
seed=seed,
silent=self._silent.get(),
save_plot=self._save_plot.get(),
- modelname=modelname,
+ modelname=basename,
)
print("Model training complete!")
print("Exporting...")
outdir = self._path_button_train_destination.val
print(f"Exporting trained model to {outdir}...")
- trained_model.net.export(outdir, modelname=modelname)
+ trained_model.net.export(
+ outdir,
+ basename=basename,
+ user_metadata=self.user_metadata
+ if self.user_metadata_flag
+ else UserMetadata(),
+ )
+ # Metadata was only valid for 1 run, so make sure it's not used again unless
+ # the user re-visits the window and clicks "ok"
+ self.user_metadata_flag = False
print("Done!")
def _check_button_states(self):
@@ -294,8 +332,27 @@ class _GUI(object):
self._button_train["state"] = tk.NORMAL
-_ADVANCED_OPTIONS_LEFT_WIDTH = 12
-_ADVANCED_OPTIONS_RIGHT_WIDTH = 12
+# some typing functions
+def _non_negative_int(val):
+    """
+    Convert `val` with `int()`; negative results are clamped to 0.
+    """
+    return max(int(val), 0)
+
+
+def _int_or_null(val):
+    """
+    Convert text to an int, except that the literal "null" (trailing whitespace
+    ignored) is handed back unchanged as the string "null".
+    """
+    stripped = val.rstrip()
+    return stripped if stripped == "null" else int(stripped)
+
+
+def _int_or_null_inv(val):
+    """
+    Text form of an optional value: "null" for None, `str(val)` otherwise.
+    """
+    if val is None:
+        return "null"
+    return str(val)
+
+
+def _rstripped_str(val):
+    """
+    Convert `val` to `str` and drop trailing whitespace.
+    """
+    text = str(val)
+    return text.rstrip()
class _LabeledOptionMenu(object):
@@ -358,7 +415,15 @@ class _LabeledText(object):
Label (left) and text input (right)
"""
- def __init__(self, frame: tk.Frame, label: str, default=None, type=None):
+ def __init__(
+ self,
+ frame: tk.Frame,
+ label: str,
+ default=None,
+ type=None,
+ left_width=_ADVANCED_OPTIONS_LEFT_WIDTH,
+ right_width=_ADVANCED_OPTIONS_RIGHT_WIDTH,
+ ):
"""
:param command: Called to propagate option selection. Is provided with the
value corresponding to the radio button selected.
@@ -369,7 +434,7 @@ class _LabeledText(object):
text_height = 1
self._label = tk.Label(
frame,
- width=_ADVANCED_OPTIONS_LEFT_WIDTH,
+ width=left_width,
height=label_height,
fg="black",
bg=None,
@@ -380,7 +445,7 @@ class _LabeledText(object):
self._text = tk.Text(
frame,
- width=_ADVANCED_OPTIONS_RIGHT_WIDTH,
+ width=right_width,
height=text_height,
fg="black",
bg=None,
@@ -426,37 +491,22 @@ class _AdvancedOptionsGUI(object):
self._frame_epochs = tk.Frame(self._root)
self._frame_epochs.pack()
- def non_negative_int(val):
- val = int(val)
- if val < 0:
- val = 0
- return val
-
self._epochs = _LabeledText(
self._frame_epochs,
"Epochs",
default=str(self._parent.advanced_options.num_epochs),
- type=non_negative_int,
+ type=_non_negative_int,
)
# Delay: text box
self._frame_delay = tk.Frame(self._root)
self._frame_delay.pack()
- def int_or_null(val):
- val = val.rstrip()
- if val == "null":
- return val
- return int(val)
-
- def int_or_null_inv(val):
- return "null" if val is None else str(val)
-
self._delay = _LabeledText(
self._frame_delay,
"Delay",
- default=int_or_null_inv(self._parent.advanced_options.delay),
- type=int_or_null,
+ default=_int_or_null_inv(self._parent.advanced_options.delay),
+ type=_int_or_null,
)
# "Ok": apply and destory
@@ -490,6 +540,103 @@ class _AdvancedOptionsGUI(object):
self._root.destroy()
+class _UserMetadataGUI(object):
+    """
+    Window for editing the parent GUI's `user_metadata`.
+
+    It shows one labeled text box for each free-text field and one option menu
+    for each enum field. "Ok" copies the widget values to the parent and closes
+    the window.
+    """
+
+    # Things that are auto-filled:
+    # Model date
+    # gain
+    def __init__(self, parent: _GUI):
+        """
+        Build the window, using `parent.user_metadata` for the widget defaults.
+
+        :param parent: The main GUI whose `user_metadata` is shown and, on "Ok",
+            overwritten.
+        """
+        self._parent = parent
+        self._root = tk.Tk()
+        self._root.title("Metadata")
+
+        # Same labeled text widget as elsewhere, but with a wider input box.
+        LabeledText = partial(_LabeledText, right_width=_METADATA_RIGHT_WIDTH)
+
+        # Name
+        self._frame_name = tk.Frame(self._root)
+        self._frame_name.pack()
+        self._name = LabeledText(
+            self._frame_name,
+            "NAM name",
+            default=parent.user_metadata.name,
+            type=_rstripped_str,
+        )
+        # Modeled by
+        self._frame_modeled_by = tk.Frame(self._root)
+        self._frame_modeled_by.pack()
+        self._modeled_by = LabeledText(
+            self._frame_modeled_by,
+            "Modeled by",
+            default=parent.user_metadata.modeled_by,
+            type=_rstripped_str,
+        )
+        # Gear make
+        self._frame_gear_make = tk.Frame(self._root)
+        self._frame_gear_make.pack()
+        self._gear_make = LabeledText(
+            self._frame_gear_make,
+            "Gear make",
+            default=parent.user_metadata.gear_make,
+            type=_rstripped_str,
+        )
+        # Gear model
+        self._frame_gear_model = tk.Frame(self._root)
+        self._frame_gear_model.pack()
+        self._gear_model = LabeledText(
+            self._frame_gear_model,
+            "Gear model",
+            default=parent.user_metadata.gear_model,
+            type=_rstripped_str,
+        )
+        # Gear type
+        self._frame_gear_type = tk.Frame(self._root)
+        self._frame_gear_type.pack()
+        self._gear_type = _LabeledOptionMenu(
+            self._frame_gear_type,
+            "Gear type",
+            GearType,
+            default=parent.user_metadata.gear_type,
+        )
+        # Tone type
+        self._frame_tone_type = tk.Frame(self._root)
+        self._frame_tone_type.pack()
+        self._tone_type = _LabeledOptionMenu(
+            self._frame_tone_type,
+            "Tone type",
+            ToneType,
+            default=parent.user_metadata.tone_type,
+        )
+
+        # "Ok": apply and destroy
+        self._frame_ok = tk.Frame(self._root)
+        self._frame_ok.pack()
+        self._button_ok = tk.Button(
+            self._frame_ok,
+            text="Ok",
+            width=_BUTTON_WIDTH,
+            height=_BUTTON_HEIGHT,
+            fg="black",
+            command=self._apply_and_destroy,
+        )
+        self._button_ok.pack()
+
+    def mainloop(self):
+        """
+        Run the event loop of this window.
+        """
+        self._root.mainloop()
+
+    def _apply_and_destroy(self):
+        """
+        Set values to parent and destroy this object
+        """
+        self._parent.user_metadata.name = self._name.get()
+        self._parent.user_metadata.modeled_by = self._modeled_by.get()
+        self._parent.user_metadata.gear_make = self._gear_make.get()
+        self._parent.user_metadata.gear_model = self._gear_model.get()
+        self._parent.user_metadata.gear_type = self._gear_type.get()
+        self._parent.user_metadata.tone_type = self._tone_type.get()
+        # Tell the parent that the user confirmed these values with "Ok".
+        self._parent.user_metadata_flag = True
+
+        self._root.destroy()
+
+
def _install_error():
window = tk.Tk()
window.title("ERROR")
diff --git a/requirements.txt b/requirements.txt
@@ -2,7 +2,7 @@
# Created Date: 2021-01-24
# Author: Steven Atkinson (steven@atkinson.mn)
-auraloss
+auraloss==0.3.0 # 0.4.0 changes API for MRSTFT loss
black
flake8
matplotlib
@@ -11,6 +11,7 @@ onnx
onnxruntime
pip
pre-commit
+pydantic
pytest
pytest-mock
pytorch_lightning
diff --git a/setup.py b/setup.py
@@ -11,9 +11,10 @@ with open(ver_path) as ver_file:
exec(ver_file.read(), main_ns)
requirements = [
- "auraloss",
+ "auraloss==0.3.0",
"matplotlib",
"numpy",
+ "pydantic",
"pytorch_lightning",
"scipy",
"sounddevice",
diff --git a/tests/test_nam/test_models/test_base.py b/tests/test_nam/test_models/test_base.py
@@ -40,12 +40,18 @@ class _MockBaseNet(_base.BaseNet):
return self.gain * x
+def test_metadata_gain():
+    """
+    A purely linear model should be reported as having no gain / compression.
+    """
+    obj = _MockBaseNet(1.0)
+    g = obj._metadata_gain()
+    # It's linear, so gain is zero. The value is derived from floating-point
+    # loudness sums, so allow for round-off instead of requiring exactly 0.0.
+    assert g == pytest.approx(0.0, abs=1.0e-6)
+
-def test_loudness():
+def test_metadata_loudness():
obj = _MockBaseNet(1.0)
- y = obj._loudness()
+ y = obj._metadata_loudness()
obj.gain = 2.0
- y2 = obj._loudness()
+ y2 = obj._metadata_loudness()
assert isinstance(y, float)
# 2x louder = +6dB
assert y2 == pytest.approx(y + 20.0 * math.log10(2.0))
diff --git a/tests/test_nam/test_models/test_exportable.py b/tests/test_nam/test_models/test_exportable.py
@@ -7,9 +7,10 @@ Test export behavior of models
"""
import json
+from enum import Enum
from pathlib import Path
from tempfile import TemporaryDirectory
-from typing import Tuple
+from typing import Optional, Tuple
import numpy as np
import pytest
@@ -17,6 +18,7 @@ import torch
import torch.nn as nn
from nam.models import _exportable
+from nam.models import metadata
class TestExportable(object):
@@ -41,6 +43,43 @@ class TestExportable(object):
assert len(weights_list) == 2
assert all(isinstance(w, float) for w in weights_list)
+    @pytest.mark.parametrize(
+        "user_metadata",
+        (
+            None,
+            metadata.UserMetadata(),
+            metadata.UserMetadata(
+                name="My Model",
+                modeled_by="Steve",
+                gear_type=metadata.GearType.AMP,
+                gear_make="SteveCo",
+                gear_model="SteveAmp",
+                tone_type=metadata.ToneType.HI_GAIN,
+            ),
+        ),
+    )
+    def test_export_metadata(self, user_metadata: Optional[metadata.UserMetadata]):
+        """
+        Assert export behavior when metadata is provided
+
+        The cases are: no metadata object (None), an all-default `UserMetadata`,
+        and a fully populated one.
+        """
+        model = self._get_model()
+        with TemporaryDirectory() as tmpdir:
+            model.export(tmpdir, user_metadata=user_metadata)
+            # No basename is passed to export(), so the default "model" is used.
+            model_basename = "model.nam"
+            model_path = Path(tmpdir, model_basename)
+            assert model_path.exists()
+            with open(model_path, "r") as fp:
+                model_dict = json.load(fp)
+        # The metadata section must be present in every case, including None.
+        metadata_key = "metadata"
+        assert metadata_key in model_dict
+        model_dict_metadata = model_dict[metadata_key]
+        if user_metadata is not None:
+            for key, expected_val in user_metadata.dict().items():
+                # The exported file holds the enum's value, not the enum member.
+                if isinstance(expected_val, Enum):
+                    expected_val = expected_val.value
+                assert key in model_dict_metadata
+                assert model_dict_metadata[key] == expected_val
+
@pytest.mark.parametrize("include_snapshot", (True, False))
def test_include_snapshot(self, include_snapshot):
"""