neural-amp-modeler

Neural network emulator for guitar amplifiers
Log | Files | Refs | README | LICENSE

__init__.py (44138B)


      1 # File: gui.py
      2 # Created Date: Saturday February 25th 2023
      3 # Author: Steven Atkinson (steven@atkinson.mn)
      4 
      5 """
      6 GUI for training
      7 
      8 Usage:
      9 >>> from nam.train.gui import run
     10 >>> run()
     11 """
     12 
     13 import abc as _abc
     14 import re as _re
     15 import requests as _requests
     16 import tkinter as _tk
     17 import subprocess as _subprocess
     18 import sys as _sys
     19 import webbrowser as _webbrowser
     20 from dataclasses import dataclass as _dataclass
     21 from enum import Enum as _Enum
     22 from functools import partial as _partial
     23 
     24 try:  # Not supported in Colab
     25     from idlelib.tooltip import Hovertip
     26 except ModuleNotFoundError:
     27     # Hovertips won't work
     28     class Hovertip(object):
     29         """
     30         Shell class
     31         """
     32 
     33         def __init__(self, *args, **kwargs):
     34             pass
     35 
     36 
     37 from pathlib import Path as _Path
     38 from tkinter import filedialog as _filedialog
     39 from typing import (
     40     Any as _Any,
     41     Callable as _Callable,
     42     Dict as _Dict,
     43     NamedTuple as _NamedTuple,
     44     Optional as _Optional,
     45     Sequence as _Sequence,
     46 )
     47 
# Third-party (torch) and NAM imports are guarded so that a broken install is
# recorded in _install_is_valid instead of raising at this point.
# NOTE(review): names bound here (_core, _settings, _Version, ...) are used in
# annotations that are evaluated at definition time later in this module (e.g.
# AdvancedOptions.architecture, _PathButton.__init__'s path_key), so when the
# except-branch is taken the module presumably still fails with NameError unless
# annotations are deferred somewhere — confirm.
try:  # 3rd-party and 1st-party imports
    import torch as _torch

    from nam import __version__
    from nam.data import Split as _Split
    from nam.train import core as _core
    from nam.train.gui._resources import settings as _settings
    from nam.models.metadata import (
        GearType as _GearType,
        UserMetadata as _UserMetadata,
        ToneType as _ToneType,
    )

    # Ok private access here--this is technically allowed access
    from nam.train import metadata as _metadata
    from nam.train._names import (
        INPUT_BASENAMES as _INPUT_BASENAMES,
        LATEST_VERSION as _LATEST_VERSION,
    )
    from nam.train._version import (
        Version as _Version,
        get_current_version as _get_current_version,
    )

    # Flag for callers that need to know whether the imports above succeeded (not
    # read in the visible part of this file).
    _install_is_valid = True
    # True when torch reports CUDA or Apple MPS; selects the larger training
    # defaults below.
    # NOTE(review): on torch builds without ``torch.backends.mps`` this line would
    # raise AttributeError, which the ImportError handler does not catch — confirm
    # the minimum supported torch version.
    _HAVE_ACCELERATOR = _torch.cuda.is_available() or _torch.backends.mps.is_available()
except ImportError:
    _install_is_valid = False
    # Fall back to the CPU training defaults.
    _HAVE_ACCELERATOR = False
     77 
     78 if _HAVE_ACCELERATOR:
     79     _DEFAULT_NUM_EPOCHS = 100
     80     _DEFAULT_BATCH_SIZE = 16
     81     _DEFAULT_LR_DECAY = 0.007
     82 else:
     83     _DEFAULT_NUM_EPOCHS = 20
     84     _DEFAULT_BATCH_SIZE = 1
     85     _DEFAULT_LR_DECAY = 0.05
     86 _BUTTON_WIDTH = 20
     87 _BUTTON_HEIGHT = 2
     88 _TEXT_WIDTH = 70
     89 
     90 _DEFAULT_DELAY = None
     91 _DEFAULT_IGNORE_CHECKS = False
     92 _DEFAULT_THRESHOLD_ESR = None
     93 
     94 _ADVANCED_OPTIONS_LEFT_WIDTH = 12
     95 _ADVANCED_OPTIONS_RIGHT_WIDTH = 12
     96 _METADATA_LEFT_WIDTH = 19
     97 _METADATA_RIGHT_WIDTH = 60
     98 
     99 
    100 def _is_mac() -> bool:
    101     return _sys.platform == "darwin"
    102 
    103 
    104 _SYSTEM_TEXT_COLOR = "systemTextColor" if _is_mac() else "black"
    105 
    106 
    107 @_dataclass
    108 class AdvancedOptions(object):
    109     """
    110     :param architecture: Which architecture to use.
    111     :param num_epochs: How many epochs to train for.
    112     :param latency: Latency between the input and output audio, in samples.
    113         None means we don't know and it has to be calibrated.
    114     :param ignore_checks: Keep going even if a check says that something is wrong.
    115     :param threshold_esr: Stop training if the ESR gets better than this. If None, don't
    116         stop.
    117     """
    118 
    119     architecture: _core.Architecture
    120     num_epochs: int
    121     latency: _Optional[int]
    122     ignore_checks: bool
    123     threshold_esr: _Optional[float]
    124 
    125 
    126 class _PathType(_Enum):
    127     FILE = "file"
    128     DIRECTORY = "directory"
    129     MULTIFILE = "multifile"
    130 
    131 
    132 class _PathButton(object):
    133     """
    134     Button and the path
    135     """
    136 
    137     def __init__(
    138         self,
    139         frame: _tk.Frame,
    140         button_text: str,
    141         info_str: str,
    142         path_type: _PathType,
    143         path_key: _settings.PathKey,
    144         hooks: _Optional[_Sequence[_Callable[[], None]]] = None,
    145         color_when_not_set: str = "#EF0000",  # Darker red
    146         color_when_set: str = _SYSTEM_TEXT_COLOR,
    147         default: _Optional[_Path] = None,
    148     ):
    149         """
    150         :param hooks: Callables run at the end of setting the value.
    151         """
    152         self._button_text = button_text
    153         self._info_str = info_str
    154         self._path: _Optional[_Path] = default
    155         self._path_type = path_type
    156         self._path_key = path_key
    157         self._frame = frame
    158         self._widgets = {}
    159         self._widgets["button"] = _tk.Button(
    160             self._frame,
    161             text=button_text,
    162             width=_BUTTON_WIDTH,
    163             height=_BUTTON_HEIGHT,
    164             command=self._set_val,
    165         )
    166         self._widgets["button"].pack(side=_tk.LEFT)
    167         self._widgets["label"] = _tk.Label(
    168             self._frame,
    169             width=_TEXT_WIDTH,
    170             height=_BUTTON_HEIGHT,
    171             bg=None,
    172             anchor="w",
    173         )
    174         self._widgets["label"].pack(side=_tk.LEFT)
    175         self._hooks = hooks
    176         self._color_when_not_set = color_when_not_set
    177         self._color_when_set = color_when_set
    178         self._set_text()
    179 
    180     def __setitem__(self, key, val):
    181         """
    182         Implement tk-style setter for state
    183         """
    184         if key == "state":
    185             for widget in self._widgets.values():
    186                 widget["state"] = val
    187         else:
    188             raise RuntimeError(
    189                 f"{self.__class__.__name__} instance does not support item assignment for non-state key {key}!"
    190             )
    191 
    192     @property
    193     def val(self) -> _Optional[_Path]:
    194         return self._path
    195 
    196     def _set_text(self):
    197         if self._path is None:
    198             self._widgets["label"]["fg"] = self._color_when_not_set
    199             self._widgets["label"]["text"] = self._info_str
    200         else:
    201             val = self.val
    202             val = val[0] if isinstance(val, tuple) and len(val) == 1 else val
    203             self._widgets["label"]["fg"] = self._color_when_set
    204             self._widgets["label"][
    205                 "text"
    206             ] = f"{self._button_text.capitalize()} set to {val}"
    207 
    208     def _set_val(self):
    209         last_path = _settings.get_last_path(self._path_key)
    210         if last_path is None:
    211             initial_dir = None
    212         elif not last_path.is_dir():
    213             initial_dir = last_path.parent
    214         else:
    215             initial_dir = last_path
    216         result = {
    217             _PathType.FILE: _filedialog.askopenfilename,
    218             _PathType.DIRECTORY: _filedialog.askdirectory,
    219             _PathType.MULTIFILE: _filedialog.askopenfilenames,
    220         }[self._path_type](initialdir=str(initial_dir))
    221         if result != "":
    222             self._path = result
    223             _settings.set_last_path(
    224                 self._path_key,
    225                 _Path(result[0] if self._path_type == _PathType.MULTIFILE else result),
    226             )
    227         self._set_text()
    228 
    229         if self._hooks is not None:
    230             for h in self._hooks:
    231                 h()
    232 
    233 
    234 class _InputPathButton(_PathButton):
    235     def __init__(self, *args, **kwargs):
    236         super().__init__(*args, **kwargs)
    237         # Download the training file!
    238         self._widgets["button_download_input"] = _tk.Button(
    239             self._frame,
    240             text="Download input file",
    241             width=_BUTTON_WIDTH,
    242             height=_BUTTON_HEIGHT,
    243             command=self._download_input_file,
    244         )
    245         self._widgets["button_download_input"].pack(side=_tk.RIGHT)
    246 
    247     @classmethod
    248     def _download_input_file(cls):
    249         file_urls = {
    250             "input.wav": "https://drive.google.com/file/d/1KbaS4oXXNEuh2aCPLwKrPdf5KFOjda8G/view?usp=drive_link",
    251             "v3_0_0.wav": "https://drive.google.com/file/d/1Pgf8PdE0rKB1TD4TRPKbpNo1ByR3IOm9/view?usp=drive_link",
    252             "v2_0_0.wav": "https://drive.google.com/file/d/1xnyJP_IZ7NuyDSTJfn-Jmc5lw0IE7nfu/view?usp=drive_link",
    253             "v1_1_1.wav": "",
    254             "v1.wav": "",
    255         }
    256         # Pick the most recent file.
    257         for input_basename in _INPUT_BASENAMES:
    258             name = input_basename.name
    259             url = file_urls.get(name)
    260             if url:
    261                 if name != _LATEST_VERSION.name:
    262                     print(
    263                         f"WARNING: File {name} is out of date. "
    264                         "This needs to be updated!"
    265                     )
    266                 _webbrowser.open(url)
    267                 return
    268 
    269 
    270 class _CheckboxKeys(_Enum):
    271     """
    272     Keys for checkboxes
    273     """
    274 
    275     SILENT_TRAINING = "silent_training"
    276     SAVE_PLOT = "save_plot"
    277 
    278 
    279 class _TopLevelWithOk(_tk.Toplevel):
    280     """
    281     Toplevel with an Ok button (provide yourself!)
    282     """
    283 
    284     def __init__(
    285         self, on_ok: _Callable[[None], None], resume_main: _Callable[[None], None]
    286     ):
    287         """
    288         :param on_ok: What to do when "Ok" button is pressed
    289         """
    290         super().__init__()
    291         self._on_ok = on_ok
    292         self._resume_main = resume_main
    293 
    294     def destroy(self, pressed_ok: bool = False):
    295         if pressed_ok:
    296             self._on_ok()
    297         self._resume_main()
    298         super().destroy()
    299 
    300 
    301 class _TopLevelWithYesNo(_tk.Toplevel):
    302     """
    303     Toplevel holding functions for yes/no buttons to close
    304     """
    305 
    306     def __init__(
    307         self,
    308         on_yes: _Callable[[None], None],
    309         on_no: _Callable[[None], None],
    310         on_close: _Optional[_Callable[[None], None]],
    311         resume_main: _Callable[[None], None],
    312     ):
    313         """
    314         :param on_yes: What to do when "Yes" button is pressed.
    315         :param on_no: What to do when "No" button is pressed.
    316         :param on_close: Do this regardless when closing (via yes/no/x) before
    317             resuming.
    318         """
    319         super().__init__()
    320         self._on_yes = on_yes
    321         self._on_no = on_no
    322         self._on_close = on_close
    323         self._resume_main = resume_main
    324 
    325     def destroy(self, pressed_yes: bool = False, pressed_no: bool = False):
    326         if pressed_yes:
    327             self._on_yes()
    328         if pressed_no:
    329             self._on_no()
    330         if self._on_close is not None:
    331             self._on_close()
    332         self._resume_main()
    333         super().destroy()
    334 
    335 
    336 class _OkModal(object):
    337     """
    338     Message and OK button
    339     """
    340 
    341     def __init__(self, resume_main, msg: str, label_kwargs: _Optional[dict] = None):
    342         label_kwargs = {} if label_kwargs is None else label_kwargs
    343 
    344         self._root = _TopLevelWithOk((lambda: None), resume_main)
    345         self._text = _tk.Label(self._root, text=msg, **label_kwargs)
    346         self._text.pack()
    347         self._ok = _tk.Button(
    348             self._root,
    349             text="Ok",
    350             width=_BUTTON_WIDTH,
    351             height=_BUTTON_HEIGHT,
    352             command=lambda: self._root.destroy(pressed_ok=True),
    353         )
    354         self._ok.pack()
    355 
    356 
    357 class _YesNoModal(object):
    358     """
    359     Modal w/ yes/no buttons
    360     """
    361 
    362     def __init__(
    363         self,
    364         on_yes: _Callable[[None], None],
    365         on_no: _Callable[[None], None],
    366         resume_main,
    367         msg: str,
    368         on_close: _Optional[_Callable[[None], None]] = None,
    369         label_kwargs: _Optional[dict] = None,
    370     ):
    371         label_kwargs = {} if label_kwargs is None else label_kwargs
    372         self._root = _TopLevelWithYesNo(on_yes, on_no, on_close, resume_main)
    373         self._text = _tk.Label(self._root, text=msg, **label_kwargs)
    374         self._text.pack()
    375         self._buttons_frame = _tk.Frame(self._root)
    376         self._buttons_frame.pack()
    377         self._yes = _tk.Button(
    378             self._buttons_frame,
    379             text="Yes",
    380             width=_BUTTON_WIDTH,
    381             height=_BUTTON_HEIGHT,
    382             command=lambda: self._root.destroy(pressed_yes=True),
    383         )
    384         self._yes.pack(side=_tk.LEFT)
    385         self._no = _tk.Button(
    386             self._buttons_frame,
    387             text="No",
    388             width=_BUTTON_WIDTH,
    389             height=_BUTTON_HEIGHT,
    390             command=lambda: self._root.destroy(pressed_no=True),
    391         )
    392         self._no.pack(side=_tk.RIGHT)
    393 
    394 
    395 class _GUIWidgets(_Enum):
    396     INPUT_PATH = "input_path"
    397     OUTPUT_PATH = "output_path"
    398     TRAINING_DESTINATION = "training_destination"
    399     METADATA = "metadata"
    400     ADVANCED_OPTIONS = "advanced_options"
    401     TRAIN = "train"
    402     UPDATE = "update"
    403 
    404 
    405 @_dataclass
    406 class Checkbox(object):
    407     variable: _tk.BooleanVar
    408     check_button: _tk.Checkbutton
    409 
    410 
    411 class GUI(object):
    def __init__(self):
        """
        Build the main trainer window.

        Creates the path selectors (input, output, training destination), the
        metadata button and option checkboxes, and the bottom-right column of
        advanced-options / train / (optional) update buttons, then sets the initial
        enabled state of the Train button.
        """
        self._root = _tk.Tk()
        self._root.title(f"NAM Trainer - v{__version__}")
        # Every widget registered here has its "state" toggled together by
        # _set_all_widget_states_to().
        self._widgets = {}

        # Buttons for paths:
        self._frame_input = _tk.Frame(self._root)
        self._frame_input.pack(anchor="w")
        self._widgets[_GUIWidgets.INPUT_PATH] = _InputPathButton(
            self._frame_input,
            "Input Audio",
            f"Select input (DI) file (e.g. {_LATEST_VERSION.name})",
            _PathType.FILE,
            _settings.PathKey.INPUT_FILE,
            hooks=[self._check_button_states],
        )

        self._frame_output_path = _tk.Frame(self._root)
        self._frame_output_path.pack(anchor="w")
        # MULTIFILE: selecting several outputs trains one model per file.
        self._widgets[_GUIWidgets.OUTPUT_PATH] = _PathButton(
            self._frame_output_path,
            "Output Audio",
            "Select output (reamped) file - (Choose MULTIPLE FILES to enable BATCH TRAINING)",
            _PathType.MULTIFILE,
            _settings.PathKey.OUTPUT_FILE,
            hooks=[self._check_button_states],
        )

        self._frame_train_destination = _tk.Frame(self._root)
        self._frame_train_destination.pack(anchor="w")
        self._widgets[_GUIWidgets.TRAINING_DESTINATION] = _PathButton(
            self._frame_train_destination,
            "Train Destination",
            "Select training output directory",
            _PathType.DIRECTORY,
            _settings.PathKey.TRAINING_DESTINATION,
            hooks=[self._check_button_states],
        )

        # Metadata
        self.user_metadata = _UserMetadata()
        self._frame_metadata = _tk.Frame(self._root)
        self._frame_metadata.pack(anchor="w")
        # NOTE(review): keyed by the plain string "metadata" rather than
        # _GUIWidgets.METADATA; other code may look it up by this string, so
        # confirm before switching to the enum.
        self._widgets["metadata"] = _tk.Button(
            self._frame_metadata,
            text="Metadata...",
            width=_BUTTON_WIDTH,
            height=_BUTTON_HEIGHT,
            command=self._open_metadata,
        )
        self._widgets["metadata"].pack()
        # self.user_metadata is only passed to training while this flag is True;
        # _train2 resets it after each run.
        self.user_metadata_flag = False

        # This should probably be to the right somewhere
        self._get_additional_options_frame()

        # Last frames: advanced options & train in the SE corner:
        self._frame_advanced_options = _tk.Frame(self._root)
        self._frame_train = _tk.Frame(self._root)
        self._frame_update = _tk.Frame(self._root)
        # Pack must be in reverse order
        self._frame_update.pack(side=_tk.BOTTOM, anchor="e")
        self._frame_train.pack(side=_tk.BOTTOM, anchor="e")
        self._frame_advanced_options.pack(side=_tk.BOTTOM, anchor="e")

        # Advanced options for training
        default_architecture = _core.Architecture.STANDARD
        # Positional order: architecture, num_epochs, latency, ignore_checks,
        # threshold_esr.
        self.advanced_options = AdvancedOptions(
            default_architecture,
            _DEFAULT_NUM_EPOCHS,
            _DEFAULT_DELAY,
            _DEFAULT_IGNORE_CHECKS,
            _DEFAULT_THRESHOLD_ESR,
        )
        # Window to edit them:

        self._widgets[_GUIWidgets.ADVANCED_OPTIONS] = _tk.Button(
            self._frame_advanced_options,
            text="Advanced options...",
            width=_BUTTON_WIDTH,
            height=_BUTTON_HEIGHT,
            command=self._open_advanced_options,
        )
        self._widgets[_GUIWidgets.ADVANCED_OPTIONS].pack()

        # Train button

        self._widgets[_GUIWidgets.TRAIN] = _tk.Button(
            self._frame_train,
            text="Train",
            width=_BUTTON_WIDTH,
            height=_BUTTON_HEIGHT,
            command=self._train,
        )
        self._widgets[_GUIWidgets.TRAIN].pack()

        # Queries the GitHub releases API (network access) while the window is
        # being built.
        self._pack_update_button_if_update_is_available()

        # Train starts disabled unless every path already has a value.
        self._check_button_states()
    511 
    def core_train_kwargs(self) -> _Dict[str, _Any]:
        """
        Extra keyword arguments forwarded to ``core.train``.

        Learning rate and seed are fixed. Learning-rate decay and batch size come
        from the module defaults, which depend on whether an accelerator was
        detected.
        """
        kwargs: _Dict[str, _Any] = dict(
            lr=0.004,
            lr_decay=_DEFAULT_LR_DECAY,
            batch_size=_DEFAULT_BATCH_SIZE,
            seed=0,
        )
        return kwargs
    522 
    def get_mrstft_fit(self) -> bool:
        """
        Use a pre-emphasized multi-resolution short-time Fourier transform loss
        during training.

        This improves agreement in the high frequencies, usually with a minimal loss
        in ESR.

        :return: Always True for now.
        """
        # Leave this as a public method to anticipate an extension to make it
        # changeable.
        return True
    534 
    def _check_button_states(self):
        """
        Enable the Train button only when the input, output and destination paths
        all have values; otherwise disable it.
        """
        path_buttons = [
            self._widgets[key]
            for key in (
                _GUIWidgets.INPUT_PATH,
                _GUIWidgets.OUTPUT_PATH,
                _GUIWidgets.TRAINING_DESTINATION,
            )
        ]
        ready = all(button.val is not None for button in path_buttons)
        self._widgets[_GUIWidgets.TRAIN]["state"] = _tk.NORMAL if ready else _tk.DISABLED
    551 
    def _get_additional_options_frame(self):
        """
        Build the frame of option checkboxes (silent training, auto-save ESR plot).

        Each checkbox is registered in ``self._checkboxes`` (by ``_CheckboxKeys``)
        and in ``self._widgets``, then gridded in one column starting at row 1.
        """
        # Checkboxes
        # TODO get these definitions into __init__()
        self._frame_checkboxes = _tk.Frame(self._root)
        self._frame_checkboxes.pack(side=_tk.LEFT)

        def make_checkbox(
            key: _CheckboxKeys, text: str, default_value: bool
        ) -> Checkbox:
            variable = _tk.BooleanVar()
            variable.set(default_value)
            check_button = _tk.Checkbutton(
                self._frame_checkboxes, text=text, variable=variable
            )
            checkbox = Checkbox(variable, check_button)
            self._checkboxes[key] = checkbox
            self._widgets[key] = check_button  # For tracking in set-all-widgets ops
            return checkbox

        self._checkboxes: _Dict[_CheckboxKeys, Checkbox] = dict()
        make_checkbox(
            _CheckboxKeys.SILENT_TRAINING,
            "Silent run (suggested for batch training)",
            False,
        )
        make_checkbox(_CheckboxKeys.SAVE_PLOT, "Save ESR plot automatically", True)

        # Grid them, one per row in insertion order:
        for row, checkbox in enumerate(self._checkboxes.values(), start=1):
            checkbox.check_button.grid(row=row, column=1, sticky="W")
    583 
    def mainloop(self):
        """Enter the Tk event loop of the main window."""
        root = self._root
        root.mainloop()
    586 
    def _disable(self):
        """Disable every widget tracked in ``self._widgets``."""
        self._set_all_widget_states_to(state=_tk.DISABLED)
    589 
    def _open_advanced_options(self):
        """
        Open the advanced-options window via ``_wait_while_func``, which supplies the
        resume callback the window needs.
        """

        def make_window(resume):
            return AdvancedOptionsGUI(resume, self)

        self._wait_while_func(make_window)
    596 
    def _open_metadata(self):
        """
        Open the user-metadata window via ``_wait_while_func``, which supplies the
        resume callback the window needs.
        """

        def make_window(resume):
            return UserMetadataGUI(resume, self)

        self._wait_while_func(make_window)
    603 
    def _pack_update_button(self, version_from: _Version, version_to: _Version):
        """
        Pack a button that a user can click to update NAM with pip.

        :param version_from: Installed version, shown on the button.
        :param version_to: Newer version, shown on the button.
        """

        def show_message(resume, *args, **kwargs):
            # Factory for _wait_while_func, which presumably supplies ``resume``.
            return _OkModal(resume, *args, **kwargs)

        def update_nam():
            # Upgrade using the interpreter that is running the GUI. This call is
            # synchronous, so the GUI handles no events until pip finishes.
            result = _subprocess.run(
                [
                    _sys.executable,
                    "-m",
                    "pip",
                    "install",
                    "--upgrade",
                    "neural-amp-modeler",
                ]
            )
            if result.returncode == 0:
                msg = "Update complete! Restart NAM for changes to take effect."
            else:
                msg = "Update failed! See logs."
            self._wait_while_func(show_message, msg)

        self._widgets[_GUIWidgets.UPDATE] = _tk.Button(
            self._frame_update,
            text=f"Update ({version_from} -> {version_to})",
            width=_BUTTON_WIDTH,
            height=_BUTTON_HEIGHT,
            command=update_nam,
        )
        self._widgets[_GUIWidgets.UPDATE].pack()
    639 
    def _pack_update_button_if_update_is_available(self):
        """
        Ask the GitHub releases API whether a newer NAM release exists and, if so,
        pack an update button.

        This is best-effort. Connection problems, timeouts, non-200 responses and
        unparseable responses print a message and skip the button rather than
        stopping the GUI from starting.
        """

        class UpdateInfo(_NamedTuple):
            available: bool
            current_version: _Version
            new_version: _Optional[_Version]

        url = "https://api.github.com/repos/sdatkinson/neural-amp-modeler/releases"
        # Seconds to wait for the server. Without a timeout, requests may block
        # indefinitely and the GUI would never appear.
        timeout_seconds = 10.0

        def get_info() -> UpdateInfo:
            current_version = _get_current_version()
            no_update = UpdateInfo(
                available=False, current_version=current_version, new_version=None
            )
            try:
                response = _requests.get(url, timeout=timeout_seconds)
            except _requests.exceptions.RequestException:
                # Covers connection errors and timeouts alike.
                print("WARNING: Failed to reach the server to check for updates")
                return no_update
            if response.status_code != 200:
                print(f"Failed to fetch releases. Status code: {response.status_code}")
                return no_update
            try:
                releases = response.json()
            except ValueError:
                # requests' JSON decoding errors derive from ValueError.
                print("WARNING: Failed to parse the list of releases")
                return no_update
            latest_version = None
            if releases:
                for release in releases:
                    tag = release["tag_name"]
                    if not tag.startswith("v"):
                        print(f"Found invalid version {tag}")
                    else:
                        this_version = _Version.from_string(tag[1:])
                        if latest_version is None or this_version > latest_version:
                            latest_version = this_version
            else:
                print("No releases found for this repository.")
            update_available = (
                latest_version is not None and latest_version > current_version
            )
            return UpdateInfo(
                available=update_available,
                current_version=current_version,
                new_version=latest_version,
            )

        update_info = get_info()
        if update_info.available:
            self._pack_update_button(
                update_info.current_version, update_info.new_version
            )
    690 
    def _resume(self):
        """
        Re-enable all tracked widgets, then re-disable Train if any required path
        is still missing.
        """
        self._set_all_widget_states_to(state=_tk.NORMAL)
        # Enabling everything also enabled Train; re-apply its rule.
        self._check_button_states()
    694 
    def _set_all_widget_states_to(self, state):
        """
        Assign the Tk ``state`` option on every widget tracked in ``self._widgets``.

        :param state: e.g. ``tk.NORMAL`` or ``tk.DISABLED``.
        """
        widgets = self._widgets
        for key in widgets:
            widgets[key]["state"] = state
    698 
    def _train(self):
        """
        Validate the chosen input/output files, and train only if validation passed.

        On failure, ``_validate_all_data`` is responsible for offering the user a way
        to train anyway.
        """
        widgets = self._widgets
        passed = self._validate_all_data(
            widgets[_GUIWidgets.INPUT_PATH].val, widgets[_GUIWidgets.OUTPUT_PATH].val
        )
        if not passed:
            return
        self._train2()
    706 
    def _train2(self, ignore_checks=False):
        """
        Train one model per selected output file and export each successful model to
        the training destination.

        :param ignore_checks: Forwarded to ``core.train``; keep going even if a data
            check fails.
        """
        input_path = self._widgets[_GUIWidgets.INPUT_PATH].val
        file_list = self._widgets[_GUIWidgets.OUTPUT_PATH].val
        outdir = self._widgets[_GUIWidgets.TRAINING_DESTINATION].val

        # Advanced options:
        num_epochs = self.advanced_options.num_epochs
        architecture = self.advanced_options.architecture
        user_latency = self.advanced_options.latency
        threshold_esr = self.advanced_options.threshold_esr

        # Run it
        for file in file_list:
            print(f"Now training {file}")
            # Model name: the file name without its directory and without a trailing
            # ".wav" (matched case-insensitively, so "amp.WAV" -> "amp").
            # Path(...).name also understands "\" separators on Windows.
            basename = _re.sub(r"\.wav$", "", _Path(file).name, flags=_re.IGNORECASE)
            # Metadata from the metadata window is used only if the user confirmed it.
            user_metadata = (
                self.user_metadata if self.user_metadata_flag else _UserMetadata()
            )

            train_output = _core.train(
                input_path,
                file,
                outdir,
                epochs=num_epochs,
                latency=user_latency,
                architecture=architecture,
                silent=self._checkboxes[_CheckboxKeys.SILENT_TRAINING].variable.get(),
                save_plot=self._checkboxes[_CheckboxKeys.SAVE_PLOT].variable.get(),
                modelname=basename,
                ignore_checks=ignore_checks,
                local=True,
                fit_mrstft=self.get_mrstft_fit(),
                threshold_esr=threshold_esr,
                user_metadata=user_metadata,
                **self.core_train_kwargs(),
            )

            if train_output.model is None:
                print("Model training failed! Skip exporting...")
                continue
            print("Model training complete!")
            print("Exporting...")
            print(f"Exporting trained model to {outdir}...")
            train_output.model.net.export(
                outdir,
                basename=basename,
                user_metadata=user_metadata,
                other_metadata={
                    _metadata.TRAINING_KEY: train_output.metadata.model_dump()
                },
            )
            print("Done!")

        # Metadata was only valid for 1 run (possibly a batch), so make sure it's not
        # used again unless the user re-visits the window and clicks "ok".
        self.user_metadata_flag = False
    763 
    def _validate_all_data(
        self, input_path: _Path, output_paths: _Sequence[_Path]
    ) -> bool:
        """
        Validate the input file and every output file.
        If something doesn't pass, then alert the user and, when the problems are
        not critical, ask them whether they want to continue.

        :param input_path: Input audio file; it must be recognized as a
            standardized input file or training cannot proceed.
        :param output_paths: Output audio files, each validated against the input
            using the reamp latency from the advanced options (may be None).
        :return: whether we passed (NOTE: Training in spite of failure is
            triggered by a modal that is produced on failure: if the user answers
            "yes", the modal's ``on_close`` callback calls
            ``self._train2(ignore_checks=True)``.)
        """

        def make_message_for_file(
            output_path: _Path, validation_output: _core.DataValidationOutput
        ) -> str:
            """
            State the file and explain what's wrong with it, one bullet per problem.
            """
            # TODO put this closer to what it looks at, i.e. core.DataValidationOutput
            # Only the file name is shown (the outputs presumably share a directory).
            msg = f"\t{_Path(output_path).name}:\n"
            if not validation_output.sample_rate.passed:
                msg += (
                    "\t\t* There are different sample rates for the input ("
                    f"{validation_output.sample_rate.input}) and output ("
                    f"{validation_output.sample_rate.output}).\n"
                )
            if not validation_output.length.passed:
                msg += (
                    "\t\t* The input and output audio files are too different in length"
                )
                if validation_output.length.delta_seconds > 0:
                    msg += (
                        f" (the output is {validation_output.length.delta_seconds:.2f} "
                        "seconds longer than the input)\n"
                    )
                else:
                    msg += (
                        f" (the output is {-validation_output.length.delta_seconds:.2f}"
                        " seconds shorter than the input)\n"
                    )
            # Latency warnings only matter when the latency was calibrated
            # automatically, i.e. the user did not provide one.
            if validation_output.latency.manual is None:
                if validation_output.latency.calibration.warnings.matches_lookahead:
                    msg += (
                        "\t\t* The calibrated latency is the maximum allowed. This is "
                        "probably because the latency calibration was triggered by noise.\n"
                    )
                if validation_output.latency.calibration.warnings.disagreement_too_high:
                    msg += "\t\t* The calculated latencies are too different from each other.\n"
            if not validation_output.checks.passed:
                msg += "\t\t* A data check failed (TODO in more detail).\n"
            if not validation_output.pytorch.passed:
                msg += "\t\t* PyTorch data set errors:\n"
                for split in _Split:
                    split_validation = getattr(validation_output.pytorch, split.value)
                    if not split_validation.passed:
                        # Nested one level below the "PyTorch data set errors" bullet.
                        msg += f"\t\t\t* {split.value:10s}: {split_validation.msg}\n"
            return msg

        # Validate input
        input_validation = _core.validate_input(input_path)
        if not input_validation.passed:
            self._wait_while_func(
                (lambda resume, *args, **kwargs: _OkModal(resume, *args, **kwargs)),
                f"Input file {input_path} is not recognized as a standardized input "
                "file.\nTraining cannot proceed.",
            )
            return False

        user_latency = self.advanced_options.latency
        file_validation_outputs = {
            output_path: _core.validate_data(
                input_path,
                output_path,
                user_latency,
            )
            for output_path in output_paths
        }
        if any(not fv.passed for fv in file_validation_outputs.values()):
            msg = "The following output files failed checks:\n" + "".join(
                [
                    make_message_for_file(output_path, fv)
                    for output_path, fv in file_validation_outputs.items()
                    if not fv.passed
                ]
            )
            if all(fv.passed_critical for fv in file_validation_outputs.values()):
                # Only non-critical failures: let the user decide.
                msg += "\nIgnore and proceed?"

                # Hacky to listen to the modal: the callbacks record the answer here.
                modal_listener = {"proceed": False}

                def on_yes():
                    modal_listener["proceed"] = True

                def on_no():
                    modal_listener["proceed"] = False

                def on_close():
                    # Re-launch training, skipping the checks, if the user said yes.
                    if modal_listener["proceed"]:
                        self._train2(ignore_checks=True)

                self._wait_while_func(
                    (
                        lambda resume, on_yes, on_no, *args, **kwargs: _YesNoModal(
                            on_yes, on_no, resume, *args, **kwargs
                        )
                    ),
                    on_yes=on_yes,
                    on_no=on_no,
                    msg=msg,
                    on_close=on_close,
                    label_kwargs={"justify": "left"},
                )
                return False  # we still failed checks so say so.
            else:
                msg += "\nCritical errors found, cannot ignore."
                self._wait_while_func(
                    lambda resume, msg, **kwargs: _OkModal(resume, msg, **kwargs),
                    msg=msg,
                    label_kwargs={"justify": "left"},
                )
                return False

        return True
    890 
    def _wait_while_func(self, func, *args, **kwargs):
        """
        Disable this GUI, then hand control to ``func``.

        ``func`` is invoked as ``func(self._resume, *args, **kwargs)``. It _needs_
        to call the resume callback it receives once it is done; otherwise this
        GUI stays disabled.
        """
        self._disable()
        release = self._resume
        func(release, *args, **kwargs)
    899 
    900 
    901 # some typing functions
    902 def _non_negative_int(val):
    903     val = int(val)
    904     if val < 0:
    905         val = 0
    906     return val
    907 
    908 
    909 class _TypeOrNull(object):
    910     def __init__(self, T, null_str=""):
    911         """
    912         :param T: tpe to cast to on .forward()
    913         """
    914         self._T = T
    915         self._null_str = null_str
    916 
    917     @property
    918     def null_str(self) -> str:
    919         """
    920         What str is displayed when for "None"
    921         """
    922         return self._null_str
    923 
    924     def forward(self, val: str):
    925         val = val.rstrip()
    926         return None if val == self._null_str else self._T(val)
    927 
    928     def inverse(self, val) -> str:
    929         return self._null_str if val is None else str(val)
    930 
    931 
    932 _int_or_null = _TypeOrNull(int)
    933 _float_or_null = _TypeOrNull(float)
    934 
    935 
    936 def _rstripped_str(val):
    937     return str(val).rstrip()
    938 
    939 
    940 class _SettingWidget(_abc.ABC):
    941     """
    942     A widget for the user to interact with to set something
    943     """
    944 
    945     @_abc.abstractmethod
    946     def get(self):
    947         pass
    948 
    949 
    950 class LabeledOptionMenu(_SettingWidget):
    951     """
    952     Label (left) and radio buttons (right)
    953     """
    954 
    955     def __init__(
    956         self,
    957         frame: _tk.Frame,
    958         label: str,
    959         choices: _Enum,
    960         default: _Optional[_Enum] = None,
    961     ):
    962         """
    963         :param command: Called to propagate option selection. Is provided with the
    964             value corresponding to the radio button selected.
    965         """
    966         self._frame = frame
    967         self._choices = choices
    968         height = _BUTTON_HEIGHT
    969         bg = None
    970         self._label = _tk.Label(
    971             frame,
    972             width=_ADVANCED_OPTIONS_LEFT_WIDTH,
    973             height=height,
    974             bg=bg,
    975             anchor="w",
    976             text=label,
    977         )
    978         self._label.pack(side=_tk.LEFT)
    979 
    980         frame_menu = _tk.Frame(frame)
    981         frame_menu.pack(side=_tk.RIGHT)
    982 
    983         self._selected_value = None
    984         default = (list(choices)[0] if default is None else default).value
    985         self._menu = _tk.OptionMenu(
    986             frame_menu,
    987             _tk.StringVar(master=frame, value=default, name=label),
    988             # default,
    989             *[choice.value for choice in choices],  #  if choice.value!=default],
    990             command=self._set,
    991         )
    992         self._menu.config(width=_ADVANCED_OPTIONS_RIGHT_WIDTH)
    993         self._menu.pack(side=_tk.RIGHT)
    994         # Initialize
    995         self._set(default)
    996 
    997     def get(self) -> _Enum:
    998         return self._selected_value
    999 
   1000     def _set(self, val: str):
   1001         """
   1002         Set the value selected
   1003         """
   1004         self._selected_value = self._choices(val)
   1005 
   1006 
   1007 class _Hovertip(Hovertip):
   1008     """
   1009     Adjustments:
   1010 
   1011     * Always black text (macOS)
   1012     """
   1013 
   1014     def showcontents(self):
   1015         # Override
   1016         label = _tk.Label(
   1017             self.tipwindow,
   1018             text=self.text,
   1019             justify=_tk.LEFT,
   1020             background="#ffffe0",
   1021             relief=_tk.SOLID,
   1022             borderwidth=1,
   1023             fg="black",
   1024         )
   1025         label.pack()
   1026 
   1027 
   1028 class LabeledText(_SettingWidget):
   1029     """
   1030     Label (left) and text input (right)
   1031     """
   1032 
   1033     def __init__(
   1034         self,
   1035         frame: _tk.Frame,
   1036         label: str,
   1037         default=None,
   1038         type=None,
   1039         left_width=_ADVANCED_OPTIONS_LEFT_WIDTH,
   1040         right_width=_ADVANCED_OPTIONS_RIGHT_WIDTH,
   1041     ):
   1042         """
   1043         :param command: Called to propagate option selection. Is provided with the
   1044             value corresponding to the radio button selected.
   1045         :param type: If provided, casts value to given type
   1046         :param left_width: How much space to use on the left side (text)
   1047         :param right_width: How much space for the Text field
   1048         """
   1049         self._frame = frame
   1050         label_height = 2
   1051         text_height = 1
   1052         self._label = _tk.Label(
   1053             frame,
   1054             width=left_width,
   1055             height=label_height,
   1056             bg=None,
   1057             anchor="e",
   1058             text=label,
   1059         )
   1060         self._label.pack(side=_tk.LEFT)
   1061 
   1062         self._text = _tk.Text(
   1063             frame,
   1064             width=right_width,
   1065             height=text_height,
   1066             bg=None,
   1067         )
   1068         self._text.pack(side=_tk.RIGHT)
   1069 
   1070         self._type = (lambda x: x) if type is None else type
   1071 
   1072         if default is not None:
   1073             self._text.insert("1.0", str(default))
   1074 
   1075         # You can assign a tooltip for the label if you'd like.
   1076         self.label_tooltip: _Optional[_Hovertip] = None
   1077 
   1078     @property
   1079     def label(self) -> _tk.Label:
   1080         return self._label
   1081 
   1082     def get(self):
   1083         """
   1084         Attempt to get and return the value.
   1085         May throw a tk.TclError indicating something went wrong getting the value.
   1086         """
   1087         # "1.0" means Line 1, character zero (wat)
   1088         return self._type(self._text.get("1.0", _tk.END))
   1089 
   1090 
   1091 class AdvancedOptionsGUI(object):
   1092     """
   1093     A window to hold advanced options (Architecture and number of epochs)
   1094     """
   1095 
   1096     def __init__(self, resume_main, parent: GUI):
   1097         self._parent = parent
   1098         self._root = _TopLevelWithOk(self.apply, resume_main)
   1099         self._root.title("Advanced Options")
   1100 
   1101         self.pack()
   1102 
   1103         # "Ok": apply and destroy
   1104         self._frame_ok = _tk.Frame(self._root)
   1105         self._frame_ok.pack()
   1106         self._button_ok = _tk.Button(
   1107             self._frame_ok,
   1108             text="Ok",
   1109             width=_BUTTON_WIDTH,
   1110             height=_BUTTON_HEIGHT,
   1111             command=lambda: self._root.destroy(pressed_ok=True),
   1112         )
   1113         self._button_ok.pack()
   1114 
   1115     def apply(self):
   1116         """
   1117         Set values to parent and destroy this object
   1118         """
   1119 
   1120         def safe_apply(name):
   1121             try:
   1122                 setattr(
   1123                     self._parent.advanced_options, name, getattr(self, "_" + name).get()
   1124                 )
   1125             except ValueError:
   1126                 pass
   1127 
   1128         # TODO could clean up more / see `.pack_options()`
   1129         for name in ("architecture", "num_epochs", "latency", "threshold_esr"):
   1130             safe_apply(name)
   1131 
   1132     def pack(self):
   1133         # TODO things that are `_SettingWidget`s are named carefully, need to make this
   1134         # easier to work with.
   1135 
   1136         # Architecture: radio buttons
   1137         self._frame_architecture = _tk.Frame(self._root)
   1138         self._frame_architecture.pack()
   1139         self._architecture = LabeledOptionMenu(
   1140             self._frame_architecture,
   1141             "Architecture",
   1142             _core.Architecture,
   1143             default=self._parent.advanced_options.architecture,
   1144         )
   1145 
   1146         # Number of epochs: text box
   1147         self._frame_epochs = _tk.Frame(self._root)
   1148         self._frame_epochs.pack()
   1149 
   1150         self._num_epochs = LabeledText(
   1151             self._frame_epochs,
   1152             "Epochs",
   1153             default=str(self._parent.advanced_options.num_epochs),
   1154             type=_non_negative_int,
   1155         )
   1156 
   1157         # Delay: text box
   1158         self._frame_latency = _tk.Frame(self._root)
   1159         self._frame_latency.pack()
   1160 
   1161         self._latency = LabeledText(
   1162             self._frame_latency,
   1163             "Reamp latency",
   1164             default=_int_or_null.inverse(self._parent.advanced_options.latency),
   1165             type=_int_or_null.forward,
   1166         )
   1167 
   1168         # Threshold ESR
   1169         self._frame_threshold_esr = _tk.Frame(self._root)
   1170         self._frame_threshold_esr.pack()
   1171         self._threshold_esr = LabeledText(
   1172             self._frame_threshold_esr,
   1173             "Threshold ESR",
   1174             default=_float_or_null.inverse(self._parent.advanced_options.threshold_esr),
   1175             type=_float_or_null.forward,
   1176         )
   1177 
   1178 
   1179 class UserMetadataGUI(object):
   1180     # Things that are auto-filled:
   1181     # Model date
   1182     # gain
   1183     def __init__(self, resume_main, parent: GUI):
   1184         self._parent = parent
   1185         self._root = _TopLevelWithOk(self.apply, resume_main)
   1186         self._root.title("Metadata")
   1187 
   1188         # Pack all the widgets
   1189         self.pack()
   1190 
   1191         # "Ok": apply and destroy
   1192         self._frame_ok = _tk.Frame(self._root)
   1193         self._frame_ok.pack()
   1194         self._button_ok = _tk.Button(
   1195             self._frame_ok,
   1196             text="Ok",
   1197             width=_BUTTON_WIDTH,
   1198             height=_BUTTON_HEIGHT,
   1199             command=lambda: self._root.destroy(pressed_ok=True),
   1200         )
   1201         self._button_ok.pack()
   1202 
   1203     def apply(self):
   1204         """
   1205         Set values to parent and destroy this object
   1206         """
   1207 
   1208         def safe_apply(name):
   1209             try:
   1210                 setattr(
   1211                     self._parent.user_metadata, name, getattr(self, "_" + name).get()
   1212                 )
   1213             except ValueError:
   1214                 pass
   1215 
   1216         # TODO could clean up more / see `.pack()`
   1217         for name in (
   1218             "name",
   1219             "modeled_by",
   1220             "gear_make",
   1221             "gear_model",
   1222             "gear_type",
   1223             "tone_type",
   1224             "input_level_dbu",
   1225             "output_level_dbu",
   1226         ):
   1227             safe_apply(name)
   1228         self._parent.user_metadata_flag = True
   1229 
   1230     def pack(self):
   1231         # TODO things that are `_SettingWidget`s are named carefully, need to make this
   1232         # easier to work with.
   1233 
   1234         LabeledText_ = _partial(
   1235             LabeledText,
   1236             left_width=_METADATA_LEFT_WIDTH,
   1237             right_width=_METADATA_RIGHT_WIDTH,
   1238         )
   1239         parent = self._parent
   1240 
   1241         # Name
   1242         self._frame_name = _tk.Frame(self._root)
   1243         self._frame_name.pack()
   1244         self._name = LabeledText_(
   1245             self._frame_name,
   1246             "NAM name",
   1247             default=parent.user_metadata.name,
   1248             type=_rstripped_str,
   1249         )
   1250         # Modeled by
   1251         self._frame_modeled_by = _tk.Frame(self._root)
   1252         self._frame_modeled_by.pack()
   1253         self._modeled_by = LabeledText_(
   1254             self._frame_modeled_by,
   1255             "Modeled by",
   1256             default=parent.user_metadata.modeled_by,
   1257             type=_rstripped_str,
   1258         )
   1259         # Gear make
   1260         self._frame_gear_make = _tk.Frame(self._root)
   1261         self._frame_gear_make.pack()
   1262         self._gear_make = LabeledText_(
   1263             self._frame_gear_make,
   1264             "Gear make",
   1265             default=parent.user_metadata.gear_make,
   1266             type=_rstripped_str,
   1267         )
   1268         # Gear model
   1269         self._frame_gear_model = _tk.Frame(self._root)
   1270         self._frame_gear_model.pack()
   1271         self._gear_model = LabeledText_(
   1272             self._frame_gear_model,
   1273             "Gear model",
   1274             default=parent.user_metadata.gear_model,
   1275             type=_rstripped_str,
   1276         )
   1277         # Calibration: input & output dBu
   1278         self._frame_input_dbu = _tk.Frame(self._root)
   1279         self._frame_input_dbu.pack()
   1280         self._input_level_dbu = LabeledText_(
   1281             self._frame_input_dbu,
   1282             "Reamp send level (dBu)",
   1283             default=_float_or_null.inverse(parent.user_metadata.input_level_dbu),
   1284             type=_float_or_null.forward,
   1285         )
   1286         self._input_level_dbu.label_tooltip = _Hovertip(
   1287             anchor_widget=self._input_level_dbu.label,
   1288             text=(
   1289                 "(Ok to leave blank)\n\n"
   1290                 "Play a sine wave with frequency 1kHz and peak amplitude 0dBFS. Use\n"
   1291                 "a multimeter to measure the RMS voltage of the signal at the jack\n"
   1292                 "that connects to your gear, and convert to dBu.\n"
   1293                 "Record the value here."
   1294             ),
   1295         )
   1296         self._frame_output_dbu = _tk.Frame(self._root)
   1297         self._frame_output_dbu.pack()
   1298         self._output_level_dbu = LabeledText_(
   1299             self._frame_output_dbu,
   1300             "Reamp return level (dBu)",
   1301             default=_float_or_null.inverse(parent.user_metadata.output_level_dbu),
   1302             type=_float_or_null.forward,
   1303         )
   1304         self._output_level_dbu.label_tooltip = _Hovertip(
   1305             anchor_widget=self._output_level_dbu.label,
   1306             text=(
   1307                 "(Ok to leave blank)\n\n"
   1308                 "Play a sine wave with frequency 1kHz into your interface where\n"
   1309                 "you're recording your gear. Keeping the interface's input gain\n"
   1310                 "trimmed as you will use it when recording, adjust the sine wave\n"
   1311                 "until the input peaks at exactly 0dBFS in your DAW. Measure the RMS\n"
   1312                 "voltage and convert to dBu.\n"
   1313                 "Record the value here."
   1314             ),
   1315         )
   1316         # Gear type
   1317         self._frame_gear_type = _tk.Frame(self._root)
   1318         self._frame_gear_type.pack()
   1319         self._gear_type = LabeledOptionMenu(
   1320             self._frame_gear_type,
   1321             "Gear type",
   1322             _GearType,
   1323             default=parent.user_metadata.gear_type,
   1324         )
   1325         # Tone type
   1326         self._frame_tone_type = _tk.Frame(self._root)
   1327         self._frame_tone_type.pack()
   1328         self._tone_type = LabeledOptionMenu(
   1329             self._frame_tone_type,
   1330             "Tone type",
   1331             _ToneType,
   1332             default=parent.user_metadata.tone_type,
   1333         )
   1334 
   1335 
   1336 def _install_error():
   1337     window = _tk.Tk()
   1338     window.title("ERROR")
   1339     label = _tk.Label(
   1340         window,
   1341         width=45,
   1342         height=2,
   1343         text="The NAM training software has not been installed correctly.",
   1344     )
   1345     label.pack()
   1346     button = _tk.Button(window, width=10, height=2, text="Quit", command=window.destroy)
   1347     button.pack()
   1348     window.mainloop()
   1349 
   1350 
   1351 def run():
   1352     if _install_is_valid:
   1353         _gui = GUI()
   1354         _gui.mainloop()
   1355         print("Shut down NAM trainer")
   1356     else:
   1357         _install_error()
   1358 
   1359 
   1360 if __name__ == "__main__":
   1361     run()