commit fae4a1301ada7253e39f3bbcdf61dc0ec79e0ea3
parent ac32ef536139892ec9318ec63471211b0ce77259
Author: Steven Atkinson <steven@atkinson.mn>
Date: Tue, 14 May 2024 23:14:06 -0700
[FEATURE] GUI trainer: Track the last directories used for input/output/train destination (#417)
* Move gui.py to gui/__init__.py
* Track the last paths for input file, output file, and training destination
* Fix
Diffstat:
5 files changed, 994 insertions(+), 909 deletions(-)
diff --git a/.gitignore b/.gitignore
@@ -132,3 +132,6 @@ dmypy.json
# Training outputs
./bin/train/outputs/
+
+# cache files
+./nam/train/gui/_resources/settings.json
diff --git a/nam/train/gui.py b/nam/train/gui.py
@@ -1,909 +0,0 @@
-# File: gui.py
-# Created Date: Saturday February 25th 2023
-# Author: Steven Atkinson (steven@atkinson.mn)
-
-"""
-GUI for training
-
-Usage:
->>> from nam.train.gui import run
->>> run()
-"""
-
-
-# Hack to recover graceful shutdowns in Windows.
-# This has to happen ASAP
-# See:
-# https://github.com/sdatkinson/neural-amp-modeler/issues/105
-# https://stackoverflow.com/a/44822794
-def _ensure_graceful_shutdowns():
- import os
-
- if os.name == "nt": # OS is Windows
- os.environ["FOR_DISABLE_CONSOLE_CTRL_HANDLER"] = "1"
-
-
-_ensure_graceful_shutdowns()
-
-import re
-import tkinter as tk
-import webbrowser
-from dataclasses import dataclass
-from enum import Enum
-from functools import partial
-from pathlib import Path
-from tkinter import filedialog
-from typing import Callable, Dict, Optional, Sequence
-
-try: # 3rd-party and 1st-party imports
- import torch
-
- from nam import __version__
- from nam.train import core
- from nam.models.metadata import GearType, UserMetadata, ToneType
-
- # Ok private access here--this is technically allowed access
- from nam.train._names import INPUT_BASENAMES, LATEST_VERSION
-
- _install_is_valid = True
- _HAVE_ACCELERATOR = torch.cuda.is_available() or torch.backends.mps.is_available()
-except ImportError:
- _install_is_valid = False
- _HAVE_ACCELERATOR = False
-
-if _HAVE_ACCELERATOR:
- _DEFAULT_NUM_EPOCHS = 100
- _DEFAULT_BATCH_SIZE = 16
- _DEFAULT_LR_DECAY = 0.007
-else:
- _DEFAULT_NUM_EPOCHS = 20
- _DEFAULT_BATCH_SIZE = 1
- _DEFAULT_LR_DECAY = 0.05
-_BUTTON_WIDTH = 20
-_BUTTON_HEIGHT = 2
-_TEXT_WIDTH = 70
-
-_DEFAULT_DELAY = None
-_DEFAULT_IGNORE_CHECKS = False
-_DEFAULT_THRESHOLD_ESR = None
-
-_ADVANCED_OPTIONS_LEFT_WIDTH = 12
-_ADVANCED_OPTIONS_RIGHT_WIDTH = 12
-_METADATA_RIGHT_WIDTH = 60
-
-
-@dataclass
-class _AdvancedOptions(object):
- """
- :param architecture: Which architecture to use.
- :param num_epochs: How many epochs to train for.
- :param latency: Latency between the input and output audio, in samples.
- None means we don't know and it has to be calibrated.
- :param ignore_checks: Keep going even if a check says that something is wrong.
- :param threshold_esr: Stop training if the ESR gets better than this. If None, don't
- stop.
- """
-
- architecture: core.Architecture
- num_epochs: int
- latency: Optional[int]
- ignore_checks: bool
- threshold_esr: Optional[float]
-
-
-class _PathType(Enum):
- FILE = "file"
- DIRECTORY = "directory"
- MULTIFILE = "multifile"
-
-
-class _PathButton(object):
- """
- Button and the path
- """
-
- def __init__(
- self,
- frame: tk.Frame,
- button_text: str,
- info_str: str,
- path_type: _PathType,
- hooks: Optional[Sequence[Callable[[], None]]] = None,
- color_when_not_set: str = "#EF0000", # Darker red
- default: Optional[Path] = None,
- ):
- """
- :param hooks: Callables run at the end of setting the value.
- """
- self._button_text = button_text
- self._info_str = info_str
- self._path: Optional[Path] = default
- self._path_type = path_type
- self._frame = frame
- self._widgets = {}
- self._widgets["button"] = tk.Button(
- self._frame,
- text=button_text,
- width=_BUTTON_WIDTH,
- height=_BUTTON_HEIGHT,
- fg="black",
- command=self._set_val,
- )
- self._widgets["button"].pack(side=tk.LEFT)
- self._widgets["label"] = tk.Label(
- self._frame,
- width=_TEXT_WIDTH,
- height=_BUTTON_HEIGHT,
- fg="black",
- bg=None,
- anchor="w",
- )
- self._widgets["label"].pack(side=tk.LEFT)
- self._hooks = hooks
- self._color_when_not_set = color_when_not_set
- self._set_text()
-
- def __setitem__(self, key, val):
- """
- Implement tk-style setter for state
- """
- if key == "state":
- for widget in self._widgets.values():
- widget["state"] = val
- else:
- raise RuntimeError(
- f"{self.__class__.__name__} instance does not support item assignment for non-state key {key}!"
- )
-
- @property
- def val(self) -> Optional[Path]:
- return self._path
-
- def _set_text(self):
- if self._path is None:
- self._widgets["label"]["fg"] = self._color_when_not_set
- self._widgets["label"]["text"] = self._info_str
- else:
- val = self.val
- val = val[0] if isinstance(val, tuple) and len(val) == 1 else val
- self._widgets["label"]["fg"] = "black"
- self._widgets["label"][
- "text"
- ] = f"{self._button_text.capitalize()} set to {val}"
-
- def _set_val(self):
- res = {
- _PathType.FILE: filedialog.askopenfilename,
- _PathType.DIRECTORY: filedialog.askdirectory,
- _PathType.MULTIFILE: filedialog.askopenfilenames,
- }[self._path_type]()
- if res != "":
- self._path = res
- self._set_text()
-
- if self._hooks is not None:
- for h in self._hooks:
- h()
-
-
-class _InputPathButton(_PathButton):
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- # Download the training file!
- self._widgets["button_download_input"] = tk.Button(
- self._frame,
- text="Download input file",
- width=_BUTTON_WIDTH,
- height=_BUTTON_HEIGHT,
- fg="black",
- command=self._download_input_file,
- )
- self._widgets["button_download_input"].pack(side=tk.RIGHT)
-
- @classmethod
- def _download_input_file(cls):
- file_urls = {
- "v3_0_0.wav": "https://drive.google.com/file/d/1Pgf8PdE0rKB1TD4TRPKbpNo1ByR3IOm9/view?usp=drive_link",
- "v2_0_0.wav": "https://drive.google.com/file/d/1xnyJP_IZ7NuyDSTJfn-Jmc5lw0IE7nfu/view?usp=drive_link",
- "v1_1_1.wav": "",
- "v1.wav": "",
- }
- # Pick the most recent file.
- for input_basename in INPUT_BASENAMES:
- name = input_basename.name
- url = file_urls.get(name)
- if url:
- if name != LATEST_VERSION.name:
- print(
- f"WARNING: File {name} is out of date. "
- "This needs to be updated!"
- )
- webbrowser.open(url)
- return
-
-
-class _ClearablePathButton(_PathButton):
- """
- Can clear a path
- """
-
- def __init__(self, *args, **kwargs):
- super().__init__(*args, color_when_not_set="black", **kwargs)
- # Download the training file!
- self._widgets["button_clear"] = tk.Button(
- self._frame,
- text="Clear",
- width=_BUTTON_WIDTH,
- height=_BUTTON_HEIGHT,
- fg="black",
- command=self._clear_path,
- )
- self._widgets["button_clear"].pack(side=tk.RIGHT)
-
- def _clear_path(self):
- self._path = None
- self._set_text()
-
-
-class _CheckboxKeys(Enum):
- """
- Keys for checkboxes
- """
-
- FIT_CAB = "fit_cab"
- SILENT_TRAINING = "silent_training"
- SAVE_PLOT = "save_plot"
- IGNORE_DATA_CHECKS = "ignore_data_checks"
-
-
-class _TopLevelWithOk(tk.Toplevel):
- """
- Toplevel with an Ok button (provide yourself!)
- """
-
- def __init__(
- self, on_ok: Callable[[None], None], resume_main: Callable[[None], None]
- ):
- """
- :param on_ok: What to do when "Ok" button is pressed
- """
- super().__init__()
- self._on_ok = on_ok
- self._resume_main = resume_main
-
- def destroy(self, pressed_ok: bool = False):
- if pressed_ok:
- self._on_ok()
- self._resume_main()
- super().destroy()
-
-
-class _BasicModal(object):
- """
- Message and OK button
- """
-
- def __init__(self, resume_main, msg: str):
- self._root = _TopLevelWithOk((lambda: None), resume_main)
- self._text = tk.Label(self._root, text=msg)
- self._text.pack()
- self._ok = tk.Button(
- self._root,
- text="Ok",
- width=_BUTTON_WIDTH,
- height=_BUTTON_HEIGHT,
- fg="black",
- command=lambda: self._root.destroy(pressed_ok=True),
- )
- self._ok.pack()
-
-
-class _GUIWidgets(Enum):
- INPUT_PATH = "input_path"
- OUTPUT_PATH = "output_path"
- TRAINING_DESTINATION = "training_destination"
- METADATA = "metadata"
- ADVANCED_OPTIONS = "advanced_options"
- TRAIN = "train"
-
-
-class _GUI(object):
- def __init__(self):
- self._root = tk.Tk()
- self._root.title(f"NAM Trainer - v{__version__}")
- self._widgets = {}
-
- # Buttons for paths:
- self._frame_input = tk.Frame(self._root)
- self._frame_input.pack(anchor="w")
- self._widgets[_GUIWidgets.INPUT_PATH] = _InputPathButton(
- self._frame_input,
- "Input Audio",
- f"Select input (DI) file (e.g. {LATEST_VERSION.name})",
- _PathType.FILE,
- hooks=[self._check_button_states],
- )
-
- self._frame_output_path = tk.Frame(self._root)
- self._frame_output_path.pack(anchor="w")
- self._widgets[_GUIWidgets.OUTPUT_PATH] = _PathButton(
- self._frame_output_path,
- "Output Audio",
- "Select output (reamped) file - (Choose MULTIPLE FILES to enable BATCH TRAINING)",
- _PathType.MULTIFILE,
- hooks=[self._check_button_states],
- )
-
- self._frame_train_destination = tk.Frame(self._root)
- self._frame_train_destination.pack(anchor="w")
- self._widgets[_GUIWidgets.TRAINING_DESTINATION] = _PathButton(
- self._frame_train_destination,
- "Train Destination",
- "Select training output directory",
- _PathType.DIRECTORY,
- hooks=[self._check_button_states],
- )
-
- # Metadata
- self.user_metadata = UserMetadata()
- self._frame_metadata = tk.Frame(self._root)
- self._frame_metadata.pack(anchor="w")
- self._widgets["metadata"] = tk.Button(
- self._frame_metadata,
- text="Metadata...",
- width=_BUTTON_WIDTH,
- height=_BUTTON_HEIGHT,
- fg="black",
- command=self._open_metadata,
- )
- self._widgets["metadata"].pack()
- self.user_metadata_flag = False
-
- # This should probably be to the right somewhere
- self._get_additional_options_frame()
-
- # Last frames: avdanced options & train in the SE corner:
- self._frame_advanced_options = tk.Frame(self._root)
- self._frame_train = tk.Frame(self._root)
- # Pack train first so that it's on bottom.
- self._frame_train.pack(side=tk.BOTTOM, anchor="e")
- self._frame_advanced_options.pack(side=tk.BOTTOM, anchor="e")
-
- # Advanced options for training
- default_architecture = core.Architecture.STANDARD
- self.advanced_options = _AdvancedOptions(
- default_architecture,
- _DEFAULT_NUM_EPOCHS,
- _DEFAULT_DELAY,
- _DEFAULT_IGNORE_CHECKS,
- _DEFAULT_THRESHOLD_ESR,
- )
- # Window to edit them:
-
- self._widgets[_GUIWidgets.ADVANCED_OPTIONS] = tk.Button(
- self._frame_advanced_options,
- text="Advanced options...",
- width=_BUTTON_WIDTH,
- height=_BUTTON_HEIGHT,
- fg="black",
- command=self._open_advanced_options,
- )
- self._widgets[_GUIWidgets.ADVANCED_OPTIONS].pack()
-
- # Train button
-
- self._widgets[_GUIWidgets.TRAIN] = tk.Button(
- self._frame_train,
- text="Train",
- width=_BUTTON_WIDTH,
- height=_BUTTON_HEIGHT,
- fg="black",
- command=self._train,
- )
- self._widgets[_GUIWidgets.TRAIN].pack()
-
- self._check_button_states()
-
- def _check_button_states(self):
- """
- Determine if any buttons should be disabled
- """
- # Train button is disabled unless all paths are set
- if any(
- pb.val is None
- for pb in (
- self._widgets[_GUIWidgets.INPUT_PATH],
- self._widgets[_GUIWidgets.OUTPUT_PATH],
- self._widgets[_GUIWidgets.TRAINING_DESTINATION],
- )
- ):
- self._widgets[_GUIWidgets.TRAIN]["state"] = tk.DISABLED
- return
- self._widgets[_GUIWidgets.TRAIN]["state"] = tk.NORMAL
-
- def _get_additional_options_frame(self):
- # Checkboxes
- # TODO get these definitions into __init__()
- self._frame_checkboxes = tk.Frame(self._root)
- self._frame_checkboxes.pack(side=tk.LEFT)
- row = 1
-
- @dataclass
- class Checkbox(object):
- variable: tk.BooleanVar
- check_button: tk.Checkbutton
-
- def make_checkbox(
- key: _CheckboxKeys, text: str, default_value: bool
- ) -> Checkbox:
- variable = tk.BooleanVar()
- variable.set(default_value)
- check_button = tk.Checkbutton(
- self._frame_checkboxes, text=text, variable=variable
- )
- self._checkboxes[key] = Checkbox(variable, check_button)
- self._widgets[key] = check_button # For tracking in set-all-widgets ops
-
- self._checkboxes: Dict[_CheckboxKeys, Checkbox] = dict()
- make_checkbox(_CheckboxKeys.FIT_CAB, "Cab modeling", False)
- make_checkbox(
- _CheckboxKeys.SILENT_TRAINING,
- "Silent run (suggested for batch training)",
- False,
- )
- make_checkbox(_CheckboxKeys.SAVE_PLOT, "Save ESR plot automatically", True)
- make_checkbox(
- _CheckboxKeys.IGNORE_DATA_CHECKS,
- "Ignore data quality checks (DO AT YOUR OWN RISK!)",
- False,
- )
-
- # Grid them:
- row = 1
- for v in self._checkboxes.values():
- v.check_button.grid(row=row, column=1, sticky="W")
- row += 1
-
- def mainloop(self):
- self._root.mainloop()
-
- def _disable(self):
- self._set_all_widget_states_to(tk.DISABLED)
-
- def _open_advanced_options(self):
- """
- Open window for advanced options
- """
-
- self._wait_while_func(lambda resume: _AdvancedOptionsGUI(resume, self))
-
- def _open_metadata(self):
- """
- Open window for metadata
- """
-
- self._wait_while_func(lambda resume: _UserMetadataGUI(resume, self))
-
- def _resume(self):
- self._set_all_widget_states_to(tk.NORMAL)
- self._check_button_states()
-
- def _set_all_widget_states_to(self, state):
- for widget in self._widgets.values():
- widget["state"] = state
-
- def _train(self):
- # Advanced options:
- num_epochs = self.advanced_options.num_epochs
- architecture = self.advanced_options.architecture
- delay = self.advanced_options.latency
- file_list = self._widgets[_GUIWidgets.OUTPUT_PATH].val
- threshold_esr = self.advanced_options.threshold_esr
-
- # Advanced-er options
- # If you're poking around looking for these, then maybe it's time to learn to
- # use the command-line scripts ;)
- lr = 0.004
- lr_decay = _DEFAULT_LR_DECAY
- batch_size = _DEFAULT_BATCH_SIZE
- seed = 0
-
- # Run it
- for file in file_list:
- print("Now training {}".format(file))
- basename = re.sub(r"\.wav$", "", file.split("/")[-1])
- user_metadata = (
- self.user_metadata if self.user_metadata_flag else UserMetadata()
- )
-
- trained_model = core.train(
- self._widgets[_GUIWidgets.INPUT_PATH].val,
- file,
- self._widgets[_GUIWidgets.TRAINING_DESTINATION].val,
- epochs=num_epochs,
- delay=delay,
- architecture=architecture,
- batch_size=batch_size,
- lr=lr,
- lr_decay=lr_decay,
- seed=seed,
- silent=self._checkboxes[_CheckboxKeys.SILENT_TRAINING].variable.get(),
- save_plot=self._checkboxes[_CheckboxKeys.SAVE_PLOT].variable.get(),
- modelname=basename,
- ignore_checks=self._checkboxes[
- _CheckboxKeys.IGNORE_DATA_CHECKS
- ].variable.get(),
- local=True,
- fit_cab=self._checkboxes[_CheckboxKeys.FIT_CAB].variable.get(),
- threshold_esr=threshold_esr,
- user_metadata=user_metadata,
- )
-
- if trained_model is None:
- print("Model training failed! Skip exporting...")
- continue
- print("Model training complete!")
- print("Exporting...")
- outdir = self._widgets[_GUIWidgets.TRAINING_DESTINATION].val
- print(f"Exporting trained model to {outdir}...")
- trained_model.net.export(
- outdir, basename=basename, user_metadata=user_metadata
- )
- print("Done!")
-
- # Metadata was only valid for 1 run, so make sure it's not used again unless
- # the user re-visits the window and clicks "ok"
- self.user_metadata_flag = False
-
- def _wait_while_func(self, func, *args, **kwargs):
- """
- Disable this GUI while something happens.
- That function _needs_ to call the provided self._resume when it's ready to
- release me!
- """
- self._disable()
- func(self._resume, *args, **kwargs)
-
-
-# some typing functions
-def _non_negative_int(val):
- val = int(val)
- if val < 0:
- val = 0
- return val
-
-
-def _type_or_null(T, val):
- val = val.rstrip()
- if val == "null":
- return val
- return T(val)
-
-
-_int_or_null = partial(_type_or_null, int)
-_float_or_null = partial(_type_or_null, float)
-
-
-def _type_or_null_inv(val):
- return "null" if val is None else str(val)
-
-
-def _rstripped_str(val):
- return str(val).rstrip()
-
-
-class _LabeledOptionMenu(object):
- """
- Label (left) and radio buttons (right)
- """
-
- def __init__(
- self, frame: tk.Frame, label: str, choices: Enum, default: Optional[Enum] = None
- ):
- """
- :param command: Called to propagate option selection. Is provided with the
- value corresponding to the radio button selected.
- """
- self._frame = frame
- self._choices = choices
- height = _BUTTON_HEIGHT
- bg = None
- fg = "black"
- self._label = tk.Label(
- frame,
- width=_ADVANCED_OPTIONS_LEFT_WIDTH,
- height=height,
- fg=fg,
- bg=bg,
- anchor="w",
- text=label,
- )
- self._label.pack(side=tk.LEFT)
-
- frame_menu = tk.Frame(frame)
- frame_menu.pack(side=tk.RIGHT)
-
- self._selected_value = None
- default = (list(choices)[0] if default is None else default).value
- self._menu = tk.OptionMenu(
- frame_menu,
- tk.StringVar(master=frame, value=default, name=label),
- # default,
- *[choice.value for choice in choices], # if choice.value!=default],
- command=self._set,
- )
- self._menu.config(width=_ADVANCED_OPTIONS_RIGHT_WIDTH)
- self._menu.pack(side=tk.RIGHT)
- # Initialize
- self._set(default)
-
- def get(self) -> Enum:
- return self._selected_value
-
- def _set(self, val: str):
- """
- Set the value selected
- """
- self._selected_value = self._choices(val)
-
-
-class _LabeledText(object):
- """
- Label (left) and text input (right)
- """
-
- def __init__(
- self,
- frame: tk.Frame,
- label: str,
- default=None,
- type=None,
- left_width=_ADVANCED_OPTIONS_LEFT_WIDTH,
- right_width=_ADVANCED_OPTIONS_RIGHT_WIDTH,
- ):
- """
- :param command: Called to propagate option selection. Is provided with the
- value corresponding to the radio button selected.
- :param type: If provided, casts value to given type
- """
- self._frame = frame
- label_height = 2
- text_height = 1
- self._label = tk.Label(
- frame,
- width=left_width,
- height=label_height,
- fg="black",
- bg=None,
- anchor="w",
- text=label,
- )
- self._label.pack(side=tk.LEFT)
-
- self._text = tk.Text(
- frame,
- width=right_width,
- height=text_height,
- fg="black",
- bg=None,
- )
- self._text.pack(side=tk.RIGHT)
-
- self._type = type
-
- if default is not None:
- self._text.insert("1.0", str(default))
-
- def get(self):
- try:
- val = self._text.get("1.0", tk.END) # Line 1, character zero (wat)
- if self._type is not None:
- val = self._type(val)
- return val
- except tk.TclError:
- return None
-
-
-class _AdvancedOptionsGUI(object):
- """
- A window to hold advanced options (Architecture and number of epochs)
- """
-
- def __init__(self, resume_main, parent: _GUI):
- self._parent = parent
- self._root = _TopLevelWithOk(self._apply, resume_main)
- self._root.title("Advanced Options")
-
- # Architecture: radio buttons
- self._frame_architecture = tk.Frame(self._root)
- self._frame_architecture.pack()
- self._architecture = _LabeledOptionMenu(
- self._frame_architecture,
- "Architecture",
- core.Architecture,
- default=self._parent.advanced_options.architecture,
- )
-
- # Number of epochs: text box
- self._frame_epochs = tk.Frame(self._root)
- self._frame_epochs.pack()
-
- self._epochs = _LabeledText(
- self._frame_epochs,
- "Epochs",
- default=str(self._parent.advanced_options.num_epochs),
- type=_non_negative_int,
- )
-
- # Delay: text box
- self._frame_latency = tk.Frame(self._root)
- self._frame_latency.pack()
-
- self._latency = _LabeledText(
- self._frame_latency,
- "Reamp latency",
- default=_type_or_null_inv(self._parent.advanced_options.latency),
- type=_int_or_null,
- )
-
- # Threshold ESR
- self._frame_threshold_esr = tk.Frame(self._root)
- self._frame_threshold_esr.pack()
- self._threshold_esr = _LabeledText(
- self._frame_threshold_esr,
- "Threshold ESR",
- default=_type_or_null_inv(self._parent.advanced_options.threshold_esr),
- type=_float_or_null,
- )
-
- # "Ok": apply and destory
- self._frame_ok = tk.Frame(self._root)
- self._frame_ok.pack()
- self._button_ok = tk.Button(
- self._frame_ok,
- text="Ok",
- width=_BUTTON_WIDTH,
- height=_BUTTON_HEIGHT,
- fg="black",
- command=lambda: self._root.destroy(pressed_ok=True),
- )
- self._button_ok.pack()
-
- def _apply(self):
- """
- Set values to parent and destroy this object
- """
- self._parent.advanced_options.architecture = self._architecture.get()
- epochs = self._epochs.get()
- if epochs is not None:
- self._parent.advanced_options.num_epochs = epochs
- latency = self._latency.get()
- # Value None is returned as "null" to disambiguate from non-set.
- if latency is not None:
- self._parent.advanced_options.latency = (
- None if latency == "null" else latency
- )
- threshold_esr = self._threshold_esr.get()
- if threshold_esr is not None:
- self._parent.advanced_options.threshold_esr = (
- None if threshold_esr == "null" else threshold_esr
- )
-
-
-class _UserMetadataGUI(object):
- # Things that are auto-filled:
- # Model date
- # gain
- def __init__(self, resume_main, parent: _GUI):
- self._parent = parent
- self._root = _TopLevelWithOk(self._apply, resume_main)
- self._root.title("Metadata")
-
- LabeledText = partial(_LabeledText, right_width=_METADATA_RIGHT_WIDTH)
-
- # Name
- self._frame_name = tk.Frame(self._root)
- self._frame_name.pack()
- self._name = LabeledText(
- self._frame_name,
- "NAM name",
- default=parent.user_metadata.name,
- type=_rstripped_str,
- )
- # Modeled by
- self._frame_modeled_by = tk.Frame(self._root)
- self._frame_modeled_by.pack()
- self._modeled_by = LabeledText(
- self._frame_modeled_by,
- "Modeled by",
- default=parent.user_metadata.modeled_by,
- type=_rstripped_str,
- )
- # Gear make
- self._frame_gear_make = tk.Frame(self._root)
- self._frame_gear_make.pack()
- self._gear_make = LabeledText(
- self._frame_gear_make,
- "Gear make",
- default=parent.user_metadata.gear_make,
- type=_rstripped_str,
- )
- # Gear model
- self._frame_gear_model = tk.Frame(self._root)
- self._frame_gear_model.pack()
- self._gear_model = LabeledText(
- self._frame_gear_model,
- "Gear model",
- default=parent.user_metadata.gear_model,
- type=_rstripped_str,
- )
- # Gear type
- self._frame_gear_type = tk.Frame(self._root)
- self._frame_gear_type.pack()
- self._gear_type = _LabeledOptionMenu(
- self._frame_gear_type,
- "Gear type",
- GearType,
- default=parent.user_metadata.gear_type,
- )
- # Tone type
- self._frame_tone_type = tk.Frame(self._root)
- self._frame_tone_type.pack()
- self._tone_type = _LabeledOptionMenu(
- self._frame_tone_type,
- "Tone type",
- ToneType,
- default=parent.user_metadata.tone_type,
- )
-
- # "Ok": apply and destory
- self._frame_ok = tk.Frame(self._root)
- self._frame_ok.pack()
- self._button_ok = tk.Button(
- self._frame_ok,
- text="Ok",
- width=_BUTTON_WIDTH,
- height=_BUTTON_HEIGHT,
- fg="black",
- command=lambda: self._root.destroy(pressed_ok=True),
- )
- self._button_ok.pack()
-
- def _apply(self):
- """
- Set values to parent and destroy this object
- """
- self._parent.user_metadata.name = self._name.get()
- self._parent.user_metadata.modeled_by = self._modeled_by.get()
- self._parent.user_metadata.gear_make = self._gear_make.get()
- self._parent.user_metadata.gear_model = self._gear_model.get()
- self._parent.user_metadata.gear_type = self._gear_type.get()
- self._parent.user_metadata.tone_type = self._tone_type.get()
- self._parent.user_metadata_flag = True
-
-
-def _install_error():
- window = tk.Tk()
- window.title("ERROR")
- label = tk.Label(
- window,
- width=45,
- height=2,
- text="The NAM training software has not been installed correctly.",
- )
- label.pack()
- button = tk.Button(window, width=10, height=2, text="Quit", command=window.destroy)
- button.pack()
- window.mainloop()
-
-
-def run():
- if _install_is_valid:
- _gui = _GUI()
- _gui.mainloop()
- else:
- _install_error()
-
-
-if __name__ == "__main__":
- run()
diff --git a/nam/train/gui/__init__.py b/nam/train/gui/__init__.py
@@ -0,0 +1,928 @@
+# File: gui.py
+# Created Date: Saturday February 25th 2023
+# Author: Steven Atkinson (steven@atkinson.mn)
+
+"""
+GUI for training
+
+Usage:
+>>> from nam.train.gui import run
+>>> run()
+"""
+
+
+# Hack to recover graceful shutdowns in Windows.
+# This has to happen ASAP
+# See:
+# https://github.com/sdatkinson/neural-amp-modeler/issues/105
+# https://stackoverflow.com/a/44822794
+def _ensure_graceful_shutdowns():
+ import os
+
+ if os.name == "nt": # OS is Windows
+ os.environ["FOR_DISABLE_CONSOLE_CTRL_HANDLER"] = "1"
+
+
+_ensure_graceful_shutdowns()
+
+import re
+import tkinter as tk
+import webbrowser
+from dataclasses import dataclass
+from enum import Enum
+from functools import partial
+from pathlib import Path
+from tkinter import filedialog
+from typing import Callable, Dict, Optional, Sequence
+
+try: # 3rd-party and 1st-party imports
+ import torch
+
+ from nam import __version__
+ from nam.train import core
+ from nam.train.gui._resources import settings
+ from nam.models.metadata import GearType, UserMetadata, ToneType
+
+ # Ok private access here--this is technically allowed access
+ from nam.train._names import INPUT_BASENAMES, LATEST_VERSION
+
+ _install_is_valid = True
+ _HAVE_ACCELERATOR = torch.cuda.is_available() or torch.backends.mps.is_available()
+except ImportError:
+ _install_is_valid = False
+ _HAVE_ACCELERATOR = False
+
+__all__ = ["run"]
+
+if _HAVE_ACCELERATOR:
+ _DEFAULT_NUM_EPOCHS = 100
+ _DEFAULT_BATCH_SIZE = 16
+ _DEFAULT_LR_DECAY = 0.007
+else:
+ _DEFAULT_NUM_EPOCHS = 20
+ _DEFAULT_BATCH_SIZE = 1
+ _DEFAULT_LR_DECAY = 0.05
+_BUTTON_WIDTH = 20
+_BUTTON_HEIGHT = 2
+_TEXT_WIDTH = 70
+
+_DEFAULT_DELAY = None
+_DEFAULT_IGNORE_CHECKS = False
+_DEFAULT_THRESHOLD_ESR = None
+
+_ADVANCED_OPTIONS_LEFT_WIDTH = 12
+_ADVANCED_OPTIONS_RIGHT_WIDTH = 12
+_METADATA_RIGHT_WIDTH = 60
+
+
+@dataclass
+class _AdvancedOptions(object):
+ """
+ :param architecture: Which architecture to use.
+ :param num_epochs: How many epochs to train for.
+ :param latency: Latency between the input and output audio, in samples.
+ None means we don't know and it has to be calibrated.
+ :param ignore_checks: Keep going even if a check says that something is wrong.
+ :param threshold_esr: Stop training if the ESR gets better than this. If None, don't
+ stop.
+ """
+
+ architecture: core.Architecture
+ num_epochs: int
+ latency: Optional[int]
+ ignore_checks: bool
+ threshold_esr: Optional[float]
+
+
+class _PathType(Enum):
+ FILE = "file"
+ DIRECTORY = "directory"
+ MULTIFILE = "multifile"
+
+
+class _PathButton(object):
+ """
+ Button and the path
+ """
+
+ def __init__(
+ self,
+ frame: tk.Frame,
+ button_text: str,
+ info_str: str,
+ path_type: _PathType,
+ path_key: settings.PathKey,
+ hooks: Optional[Sequence[Callable[[], None]]] = None,
+ color_when_not_set: str = "#EF0000", # Darker red
+ default: Optional[Path] = None,
+ ):
+ """
+ :param hooks: Callables run at the end of setting the value.
+ """
+ self._button_text = button_text
+ self._info_str = info_str
+ self._path: Optional[Path] = default
+ self._path_type = path_type
+ self._path_key = path_key
+ self._frame = frame
+ self._widgets = {}
+ self._widgets["button"] = tk.Button(
+ self._frame,
+ text=button_text,
+ width=_BUTTON_WIDTH,
+ height=_BUTTON_HEIGHT,
+ fg="black",
+ command=self._set_val,
+ )
+ self._widgets["button"].pack(side=tk.LEFT)
+ self._widgets["label"] = tk.Label(
+ self._frame,
+ width=_TEXT_WIDTH,
+ height=_BUTTON_HEIGHT,
+ fg="black",
+ bg=None,
+ anchor="w",
+ )
+ self._widgets["label"].pack(side=tk.LEFT)
+ self._hooks = hooks
+ self._color_when_not_set = color_when_not_set
+ self._set_text()
+
+ def __setitem__(self, key, val):
+ """
+ Implement tk-style setter for state
+ """
+ if key == "state":
+ for widget in self._widgets.values():
+ widget["state"] = val
+ else:
+ raise RuntimeError(
+ f"{self.__class__.__name__} instance does not support item assignment for non-state key {key}!"
+ )
+
+ @property
+ def val(self) -> Optional[Path]:
+ return self._path
+
+ def _set_text(self):
+ if self._path is None:
+ self._widgets["label"]["fg"] = self._color_when_not_set
+ self._widgets["label"]["text"] = self._info_str
+ else:
+ val = self.val
+ val = val[0] if isinstance(val, tuple) and len(val) == 1 else val
+ self._widgets["label"]["fg"] = "black"
+ self._widgets["label"][
+ "text"
+ ] = f"{self._button_text.capitalize()} set to {val}"
+
+ def _set_val(self):
+ last_path = settings.get_last_path(self._path_key)
+ if last_path is None:
+ initial_dir = None
+ elif not last_path.is_dir():
+ initial_dir = last_path.parent
+ else:
+ initial_dir = last_path
+ result = {
+ _PathType.FILE: filedialog.askopenfilename,
+ _PathType.DIRECTORY: filedialog.askdirectory,
+ _PathType.MULTIFILE: filedialog.askopenfilenames,
+ }[self._path_type](initialdir=str(initial_dir))
+ if result != "":
+ self._path = result
+ settings.set_last_path(
+ self._path_key,
+ Path(result[0] if self._path_type == _PathType.MULTIFILE else result),
+ )
+ self._set_text()
+
+ if self._hooks is not None:
+ for h in self._hooks:
+ h()
+
+
+class _InputPathButton(_PathButton):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ # Download the training file!
+ self._widgets["button_download_input"] = tk.Button(
+ self._frame,
+ text="Download input file",
+ width=_BUTTON_WIDTH,
+ height=_BUTTON_HEIGHT,
+ fg="black",
+ command=self._download_input_file,
+ )
+ self._widgets["button_download_input"].pack(side=tk.RIGHT)
+
+ @classmethod
+ def _download_input_file(cls):
+ file_urls = {
+ "v3_0_0.wav": "https://drive.google.com/file/d/1Pgf8PdE0rKB1TD4TRPKbpNo1ByR3IOm9/view?usp=drive_link",
+ "v2_0_0.wav": "https://drive.google.com/file/d/1xnyJP_IZ7NuyDSTJfn-Jmc5lw0IE7nfu/view?usp=drive_link",
+ "v1_1_1.wav": "",
+ "v1.wav": "",
+ }
+ # Pick the most recent file.
+ for input_basename in INPUT_BASENAMES:
+ name = input_basename.name
+ url = file_urls.get(name)
+ if url:
+ if name != LATEST_VERSION.name:
+ print(
+ f"WARNING: File {name} is out of date. "
+ "This needs to be updated!"
+ )
+ webbrowser.open(url)
+ return
+
+
+class _ClearablePathButton(_PathButton):
+ """
+ Can clear a path
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, color_when_not_set="black", **kwargs)
+ # Download the training file!
+ self._widgets["button_clear"] = tk.Button(
+ self._frame,
+ text="Clear",
+ width=_BUTTON_WIDTH,
+ height=_BUTTON_HEIGHT,
+ fg="black",
+ command=self._clear_path,
+ )
+ self._widgets["button_clear"].pack(side=tk.RIGHT)
+
+ def _clear_path(self):
+ self._path = None
+ self._set_text()
+
+
+class _CheckboxKeys(Enum):
+ """
+ Keys for checkboxes
+ """
+
+ FIT_CAB = "fit_cab"
+ SILENT_TRAINING = "silent_training"
+ SAVE_PLOT = "save_plot"
+ IGNORE_DATA_CHECKS = "ignore_data_checks"
+
+
+class _TopLevelWithOk(tk.Toplevel):
+ """
+ Toplevel with an Ok button (provide yourself!)
+ """
+
+ def __init__(
+ self, on_ok: Callable[[None], None], resume_main: Callable[[None], None]
+ ):
+ """
+ :param on_ok: What to do when "Ok" button is pressed
+ """
+ super().__init__()
+ self._on_ok = on_ok
+ self._resume_main = resume_main
+
+ def destroy(self, pressed_ok: bool = False):
+ if pressed_ok:
+ self._on_ok()
+ self._resume_main()
+ super().destroy()
+
+
+class _BasicModal(object):
+ """
+ Message and OK button
+ """
+
+ def __init__(self, resume_main, msg: str):
+ self._root = _TopLevelWithOk((lambda: None), resume_main)
+ self._text = tk.Label(self._root, text=msg)
+ self._text.pack()
+ self._ok = tk.Button(
+ self._root,
+ text="Ok",
+ width=_BUTTON_WIDTH,
+ height=_BUTTON_HEIGHT,
+ fg="black",
+ command=lambda: self._root.destroy(pressed_ok=True),
+ )
+ self._ok.pack()
+
+
+class _GUIWidgets(Enum):
+ INPUT_PATH = "input_path"
+ OUTPUT_PATH = "output_path"
+ TRAINING_DESTINATION = "training_destination"
+ METADATA = "metadata"
+ ADVANCED_OPTIONS = "advanced_options"
+ TRAIN = "train"
+
+
+class _GUI(object):
+ def __init__(self):
+ self._root = tk.Tk()
+ self._root.title(f"NAM Trainer - v{__version__}")
+ self._widgets = {}
+
+ # Buttons for paths:
+ self._frame_input = tk.Frame(self._root)
+ self._frame_input.pack(anchor="w")
+ self._widgets[_GUIWidgets.INPUT_PATH] = _InputPathButton(
+ self._frame_input,
+ "Input Audio",
+ f"Select input (DI) file (e.g. {LATEST_VERSION.name})",
+ _PathType.FILE,
+ settings.PathKey.INPUT_FILE,
+ hooks=[self._check_button_states],
+ )
+
+ self._frame_output_path = tk.Frame(self._root)
+ self._frame_output_path.pack(anchor="w")
+ self._widgets[_GUIWidgets.OUTPUT_PATH] = _PathButton(
+ self._frame_output_path,
+ "Output Audio",
+ "Select output (reamped) file - (Choose MULTIPLE FILES to enable BATCH TRAINING)",
+ _PathType.MULTIFILE,
+ settings.PathKey.OUTPUT_FILE,
+ hooks=[self._check_button_states],
+ )
+
+ self._frame_train_destination = tk.Frame(self._root)
+ self._frame_train_destination.pack(anchor="w")
+ self._widgets[_GUIWidgets.TRAINING_DESTINATION] = _PathButton(
+ self._frame_train_destination,
+ "Train Destination",
+ "Select training output directory",
+ _PathType.DIRECTORY,
+ settings.PathKey.TRAINING_DESTINATION,
+ hooks=[self._check_button_states],
+ )
+
+ # Metadata
+ self.user_metadata = UserMetadata()
+ self._frame_metadata = tk.Frame(self._root)
+ self._frame_metadata.pack(anchor="w")
+ self._widgets["metadata"] = tk.Button(
+ self._frame_metadata,
+ text="Metadata...",
+ width=_BUTTON_WIDTH,
+ height=_BUTTON_HEIGHT,
+ fg="black",
+ command=self._open_metadata,
+ )
+ self._widgets["metadata"].pack()
+ self.user_metadata_flag = False
+
+ # This should probably be to the right somewhere
+ self._get_additional_options_frame()
+
+ # Last frames: avdanced options & train in the SE corner:
+ self._frame_advanced_options = tk.Frame(self._root)
+ self._frame_train = tk.Frame(self._root)
+ # Pack train first so that it's on bottom.
+ self._frame_train.pack(side=tk.BOTTOM, anchor="e")
+ self._frame_advanced_options.pack(side=tk.BOTTOM, anchor="e")
+
+ # Advanced options for training
+ default_architecture = core.Architecture.STANDARD
+ self.advanced_options = _AdvancedOptions(
+ default_architecture,
+ _DEFAULT_NUM_EPOCHS,
+ _DEFAULT_DELAY,
+ _DEFAULT_IGNORE_CHECKS,
+ _DEFAULT_THRESHOLD_ESR,
+ )
+ # Window to edit them:
+
+ self._widgets[_GUIWidgets.ADVANCED_OPTIONS] = tk.Button(
+ self._frame_advanced_options,
+ text="Advanced options...",
+ width=_BUTTON_WIDTH,
+ height=_BUTTON_HEIGHT,
+ fg="black",
+ command=self._open_advanced_options,
+ )
+ self._widgets[_GUIWidgets.ADVANCED_OPTIONS].pack()
+
+ # Train button
+
+ self._widgets[_GUIWidgets.TRAIN] = tk.Button(
+ self._frame_train,
+ text="Train",
+ width=_BUTTON_WIDTH,
+ height=_BUTTON_HEIGHT,
+ fg="black",
+ command=self._train,
+ )
+ self._widgets[_GUIWidgets.TRAIN].pack()
+
+ self._check_button_states()
+
+ def _check_button_states(self):
+ """
+ Determine if any buttons should be disabled
+ """
+ # Train button is disabled unless all paths are set
+ if any(
+ pb.val is None
+ for pb in (
+ self._widgets[_GUIWidgets.INPUT_PATH],
+ self._widgets[_GUIWidgets.OUTPUT_PATH],
+ self._widgets[_GUIWidgets.TRAINING_DESTINATION],
+ )
+ ):
+ self._widgets[_GUIWidgets.TRAIN]["state"] = tk.DISABLED
+ return
+ self._widgets[_GUIWidgets.TRAIN]["state"] = tk.NORMAL
+
+ def _get_additional_options_frame(self):
+ # Checkboxes
+ # TODO get these definitions into __init__()
+ self._frame_checkboxes = tk.Frame(self._root)
+ self._frame_checkboxes.pack(side=tk.LEFT)
+ row = 1
+
+ @dataclass
+ class Checkbox(object):
+ variable: tk.BooleanVar
+ check_button: tk.Checkbutton
+
+ def make_checkbox(
+ key: _CheckboxKeys, text: str, default_value: bool
+ ) -> Checkbox:
+ variable = tk.BooleanVar()
+ variable.set(default_value)
+ check_button = tk.Checkbutton(
+ self._frame_checkboxes, text=text, variable=variable
+ )
+ self._checkboxes[key] = Checkbox(variable, check_button)
+ self._widgets[key] = check_button # For tracking in set-all-widgets ops
+
+ self._checkboxes: Dict[_CheckboxKeys, Checkbox] = dict()
+ make_checkbox(_CheckboxKeys.FIT_CAB, "Cab modeling", False)
+ make_checkbox(
+ _CheckboxKeys.SILENT_TRAINING,
+ "Silent run (suggested for batch training)",
+ False,
+ )
+ make_checkbox(_CheckboxKeys.SAVE_PLOT, "Save ESR plot automatically", True)
+ make_checkbox(
+ _CheckboxKeys.IGNORE_DATA_CHECKS,
+ "Ignore data quality checks (DO AT YOUR OWN RISK!)",
+ False,
+ )
+
+ # Grid them:
+ row = 1
+ for v in self._checkboxes.values():
+ v.check_button.grid(row=row, column=1, sticky="W")
+ row += 1
+
+ def mainloop(self):
+ self._root.mainloop()
+
+ def _disable(self):
+ self._set_all_widget_states_to(tk.DISABLED)
+
+ def _open_advanced_options(self):
+ """
+ Open window for advanced options
+ """
+
+ self._wait_while_func(lambda resume: _AdvancedOptionsGUI(resume, self))
+
+ def _open_metadata(self):
+ """
+ Open window for metadata
+ """
+
+ self._wait_while_func(lambda resume: _UserMetadataGUI(resume, self))
+
+ def _resume(self):
+ self._set_all_widget_states_to(tk.NORMAL)
+ self._check_button_states()
+
+ def _set_all_widget_states_to(self, state):
+ for widget in self._widgets.values():
+ widget["state"] = state
+
+ def _train(self):
+ # Advanced options:
+ num_epochs = self.advanced_options.num_epochs
+ architecture = self.advanced_options.architecture
+ delay = self.advanced_options.latency
+ file_list = self._widgets[_GUIWidgets.OUTPUT_PATH].val
+ threshold_esr = self.advanced_options.threshold_esr
+
+ # Advanced-er options
+ # If you're poking around looking for these, then maybe it's time to learn to
+ # use the command-line scripts ;)
+ lr = 0.004
+ lr_decay = _DEFAULT_LR_DECAY
+ batch_size = _DEFAULT_BATCH_SIZE
+ seed = 0
+
+ # Run it
+ for file in file_list:
+ print("Now training {}".format(file))
+ basename = re.sub(r"\.wav$", "", file.split("/")[-1])
+ user_metadata = (
+ self.user_metadata if self.user_metadata_flag else UserMetadata()
+ )
+
+ trained_model = core.train(
+ self._widgets[_GUIWidgets.INPUT_PATH].val,
+ file,
+ self._widgets[_GUIWidgets.TRAINING_DESTINATION].val,
+ epochs=num_epochs,
+ delay=delay,
+ architecture=architecture,
+ batch_size=batch_size,
+ lr=lr,
+ lr_decay=lr_decay,
+ seed=seed,
+ silent=self._checkboxes[_CheckboxKeys.SILENT_TRAINING].variable.get(),
+ save_plot=self._checkboxes[_CheckboxKeys.SAVE_PLOT].variable.get(),
+ modelname=basename,
+ ignore_checks=self._checkboxes[
+ _CheckboxKeys.IGNORE_DATA_CHECKS
+ ].variable.get(),
+ local=True,
+ fit_cab=self._checkboxes[_CheckboxKeys.FIT_CAB].variable.get(),
+ threshold_esr=threshold_esr,
+ user_metadata=user_metadata,
+ )
+
+ if trained_model is None:
+ print("Model training failed! Skip exporting...")
+ continue
+ print("Model training complete!")
+ print("Exporting...")
+ outdir = self._widgets[_GUIWidgets.TRAINING_DESTINATION].val
+ print(f"Exporting trained model to {outdir}...")
+ trained_model.net.export(
+ outdir, basename=basename, user_metadata=user_metadata
+ )
+ print("Done!")
+
+ # Metadata was only valid for 1 run, so make sure it's not used again unless
+ # the user re-visits the window and clicks "ok"
+ self.user_metadata_flag = False
+
+ def _wait_while_func(self, func, *args, **kwargs):
+ """
+ Disable this GUI while something happens.
+ That function _needs_ to call the provided self._resume when it's ready to
+ release me!
+ """
+ self._disable()
+ func(self._resume, *args, **kwargs)
+
+
+# some typing functions
+def _non_negative_int(val):
+ val = int(val)
+ if val < 0:
+ val = 0
+ return val
+
+
+def _type_or_null(T, val):
+ val = val.rstrip()
+ if val == "null":
+ return val
+ return T(val)
+
+
+_int_or_null = partial(_type_or_null, int)
+_float_or_null = partial(_type_or_null, float)
+
+
+def _type_or_null_inv(val):
+ return "null" if val is None else str(val)
+
+
+def _rstripped_str(val):
+ return str(val).rstrip()
+
+
+class _LabeledOptionMenu(object):
+ """
+ Label (left) and radio buttons (right)
+ """
+
+ def __init__(
+ self, frame: tk.Frame, label: str, choices: Enum, default: Optional[Enum] = None
+ ):
+ """
+ :param command: Called to propagate option selection. Is provided with the
+ value corresponding to the radio button selected.
+ """
+ self._frame = frame
+ self._choices = choices
+ height = _BUTTON_HEIGHT
+ bg = None
+ fg = "black"
+ self._label = tk.Label(
+ frame,
+ width=_ADVANCED_OPTIONS_LEFT_WIDTH,
+ height=height,
+ fg=fg,
+ bg=bg,
+ anchor="w",
+ text=label,
+ )
+ self._label.pack(side=tk.LEFT)
+
+ frame_menu = tk.Frame(frame)
+ frame_menu.pack(side=tk.RIGHT)
+
+ self._selected_value = None
+ default = (list(choices)[0] if default is None else default).value
+ self._menu = tk.OptionMenu(
+ frame_menu,
+ tk.StringVar(master=frame, value=default, name=label),
+ # default,
+ *[choice.value for choice in choices], # if choice.value!=default],
+ command=self._set,
+ )
+ self._menu.config(width=_ADVANCED_OPTIONS_RIGHT_WIDTH)
+ self._menu.pack(side=tk.RIGHT)
+ # Initialize
+ self._set(default)
+
+ def get(self) -> Enum:
+ return self._selected_value
+
+ def _set(self, val: str):
+ """
+ Set the value selected
+ """
+ self._selected_value = self._choices(val)
+
+
+class _LabeledText(object):
+ """
+ Label (left) and text input (right)
+ """
+
+ def __init__(
+ self,
+ frame: tk.Frame,
+ label: str,
+ default=None,
+ type=None,
+ left_width=_ADVANCED_OPTIONS_LEFT_WIDTH,
+ right_width=_ADVANCED_OPTIONS_RIGHT_WIDTH,
+ ):
+ """
+ :param command: Called to propagate option selection. Is provided with the
+ value corresponding to the radio button selected.
+ :param type: If provided, casts value to given type
+ """
+ self._frame = frame
+ label_height = 2
+ text_height = 1
+ self._label = tk.Label(
+ frame,
+ width=left_width,
+ height=label_height,
+ fg="black",
+ bg=None,
+ anchor="w",
+ text=label,
+ )
+ self._label.pack(side=tk.LEFT)
+
+ self._text = tk.Text(
+ frame,
+ width=right_width,
+ height=text_height,
+ fg="black",
+ bg=None,
+ )
+ self._text.pack(side=tk.RIGHT)
+
+ self._type = type
+
+ if default is not None:
+ self._text.insert("1.0", str(default))
+
+ def get(self):
+ try:
+ val = self._text.get("1.0", tk.END) # Line 1, character zero (wat)
+ if self._type is not None:
+ val = self._type(val)
+ return val
+ except tk.TclError:
+ return None
+
+
+class _AdvancedOptionsGUI(object):
+ """
+ A window to hold advanced options (Architecture and number of epochs)
+ """
+
+ def __init__(self, resume_main, parent: _GUI):
+ self._parent = parent
+ self._root = _TopLevelWithOk(self._apply, resume_main)
+ self._root.title("Advanced Options")
+
+ # Architecture: radio buttons
+ self._frame_architecture = tk.Frame(self._root)
+ self._frame_architecture.pack()
+ self._architecture = _LabeledOptionMenu(
+ self._frame_architecture,
+ "Architecture",
+ core.Architecture,
+ default=self._parent.advanced_options.architecture,
+ )
+
+ # Number of epochs: text box
+ self._frame_epochs = tk.Frame(self._root)
+ self._frame_epochs.pack()
+
+ self._epochs = _LabeledText(
+ self._frame_epochs,
+ "Epochs",
+ default=str(self._parent.advanced_options.num_epochs),
+ type=_non_negative_int,
+ )
+
+ # Delay: text box
+ self._frame_latency = tk.Frame(self._root)
+ self._frame_latency.pack()
+
+ self._latency = _LabeledText(
+ self._frame_latency,
+ "Reamp latency",
+ default=_type_or_null_inv(self._parent.advanced_options.latency),
+ type=_int_or_null,
+ )
+
+ # Threshold ESR
+ self._frame_threshold_esr = tk.Frame(self._root)
+ self._frame_threshold_esr.pack()
+ self._threshold_esr = _LabeledText(
+ self._frame_threshold_esr,
+ "Threshold ESR",
+ default=_type_or_null_inv(self._parent.advanced_options.threshold_esr),
+ type=_float_or_null,
+ )
+
+ # "Ok": apply and destory
+ self._frame_ok = tk.Frame(self._root)
+ self._frame_ok.pack()
+ self._button_ok = tk.Button(
+ self._frame_ok,
+ text="Ok",
+ width=_BUTTON_WIDTH,
+ height=_BUTTON_HEIGHT,
+ fg="black",
+ command=lambda: self._root.destroy(pressed_ok=True),
+ )
+ self._button_ok.pack()
+
+ def _apply(self):
+ """
+ Set values to parent and destroy this object
+ """
+ self._parent.advanced_options.architecture = self._architecture.get()
+ epochs = self._epochs.get()
+ if epochs is not None:
+ self._parent.advanced_options.num_epochs = epochs
+ latency = self._latency.get()
+ # Value None is returned as "null" to disambiguate from non-set.
+ if latency is not None:
+ self._parent.advanced_options.latency = (
+ None if latency == "null" else latency
+ )
+ threshold_esr = self._threshold_esr.get()
+ if threshold_esr is not None:
+ self._parent.advanced_options.threshold_esr = (
+ None if threshold_esr == "null" else threshold_esr
+ )
+
+
+class _UserMetadataGUI(object):
+ # Things that are auto-filled:
+ # Model date
+ # gain
+ def __init__(self, resume_main, parent: _GUI):
+ self._parent = parent
+ self._root = _TopLevelWithOk(self._apply, resume_main)
+ self._root.title("Metadata")
+
+ LabeledText = partial(_LabeledText, right_width=_METADATA_RIGHT_WIDTH)
+
+ # Name
+ self._frame_name = tk.Frame(self._root)
+ self._frame_name.pack()
+ self._name = LabeledText(
+ self._frame_name,
+ "NAM name",
+ default=parent.user_metadata.name,
+ type=_rstripped_str,
+ )
+ # Modeled by
+ self._frame_modeled_by = tk.Frame(self._root)
+ self._frame_modeled_by.pack()
+ self._modeled_by = LabeledText(
+ self._frame_modeled_by,
+ "Modeled by",
+ default=parent.user_metadata.modeled_by,
+ type=_rstripped_str,
+ )
+ # Gear make
+ self._frame_gear_make = tk.Frame(self._root)
+ self._frame_gear_make.pack()
+ self._gear_make = LabeledText(
+ self._frame_gear_make,
+ "Gear make",
+ default=parent.user_metadata.gear_make,
+ type=_rstripped_str,
+ )
+ # Gear model
+ self._frame_gear_model = tk.Frame(self._root)
+ self._frame_gear_model.pack()
+ self._gear_model = LabeledText(
+ self._frame_gear_model,
+ "Gear model",
+ default=parent.user_metadata.gear_model,
+ type=_rstripped_str,
+ )
+ # Gear type
+ self._frame_gear_type = tk.Frame(self._root)
+ self._frame_gear_type.pack()
+ self._gear_type = _LabeledOptionMenu(
+ self._frame_gear_type,
+ "Gear type",
+ GearType,
+ default=parent.user_metadata.gear_type,
+ )
+ # Tone type
+ self._frame_tone_type = tk.Frame(self._root)
+ self._frame_tone_type.pack()
+ self._tone_type = _LabeledOptionMenu(
+ self._frame_tone_type,
+ "Tone type",
+ ToneType,
+ default=parent.user_metadata.tone_type,
+ )
+
+ # "Ok": apply and destory
+ self._frame_ok = tk.Frame(self._root)
+ self._frame_ok.pack()
+ self._button_ok = tk.Button(
+ self._frame_ok,
+ text="Ok",
+ width=_BUTTON_WIDTH,
+ height=_BUTTON_HEIGHT,
+ fg="black",
+ command=lambda: self._root.destroy(pressed_ok=True),
+ )
+ self._button_ok.pack()
+
+ def _apply(self):
+ """
+ Set values to parent and destroy this object
+ """
+ self._parent.user_metadata.name = self._name.get()
+ self._parent.user_metadata.modeled_by = self._modeled_by.get()
+ self._parent.user_metadata.gear_make = self._gear_make.get()
+ self._parent.user_metadata.gear_model = self._gear_model.get()
+ self._parent.user_metadata.gear_type = self._gear_type.get()
+ self._parent.user_metadata.tone_type = self._tone_type.get()
+ self._parent.user_metadata_flag = True
+
+
+def _install_error():
+ window = tk.Tk()
+ window.title("ERROR")
+ label = tk.Label(
+ window,
+ width=45,
+ height=2,
+ text="The NAM training software has not been installed correctly.",
+ )
+ label.pack()
+ button = tk.Button(window, width=10, height=2, text="Quit", command=window.destroy)
+ button.pack()
+ window.mainloop()
+
+
+def run():
+ if _install_is_valid:
+ _gui = _GUI()
+ _gui.mainloop()
+ else:
+ _install_error()
+
+
+if __name__ == "__main__":
+ run()
diff --git a/nam/train/gui/_resources/__init__.py b/nam/train/gui/_resources/__init__.py
@@ -0,0 +1,7 @@
+# File: __init__.py
+# Created Date: Tuesday May 14th 2024
+# Author: Steven Atkinson (steven@atkinson.mn)
+
+from . import settings
+
+__all__ = ["settings"]
diff --git a/nam/train/gui/_resources/settings.py b/nam/train/gui/_resources/settings.py
@@ -0,0 +1,56 @@
+# File: settings.py
+# Created Date: Tuesday May 14th 2024
+# Author: Steven Atkinson (steven@atkinson.mn)
+
+import json
+from enum import Enum
+from functools import partial
+from pathlib import Path
+from typing import Optional
+
+__all__ = ["PathKey", "get_last_path", "set_last_path"]
+
+_THIS_DIR = Path(__file__).parent.resolve()
+_SETTINGS_JSON_PATH = Path(_THIS_DIR, "settings.json")
+_LAST_PATHS_KEY = "last_paths"
+
+
+class PathKey(Enum):
+ INPUT_FILE = "input_file"
+ OUTPUT_FILE = "output_file"
+ TRAINING_DESTINATION = "training_destination"
+
+
+def get_last_path(path_key: PathKey) -> Optional[Path]:
+ s = _get_settings()
+ if _LAST_PATHS_KEY not in s:
+ return None
+ last_path = s[_LAST_PATHS_KEY].get(path_key.value)
+ if last_path is None:
+ return None
+ assert isinstance(last_path, str)
+ return Path(last_path)
+
+
+def set_last_path(path_key: PathKey, path: Path):
+ s = _get_settings()
+ if _LAST_PATHS_KEY not in s:
+ s[_LAST_PATHS_KEY] = {}
+ s[_LAST_PATHS_KEY][path_key.value] = str(path)
+ _write_settings(s)
+
+
+def _get_settings() -> dict:
+ """
+ Make sure that ./settings.json exists; if it does, then read it. If not, empty dict.
+ """
+
+ if not _SETTINGS_JSON_PATH.exists():
+ _write_settings({})
+ with open(_SETTINGS_JSON_PATH, "r") as fp:
+ return json.load(fp)
+
+
+def _write_settings(obj: dict):
+ with open(_SETTINGS_JSON_PATH, "w") as fp:
+ json.dump(obj, fp, indent=4)