__init__.py (44138B)
1 # File: gui.py 2 # Created Date: Saturday February 25th 2023 3 # Author: Steven Atkinson (steven@atkinson.mn) 4 5 """ 6 GUI for training 7 8 Usage: 9 >>> from nam.train.gui import run 10 >>> run() 11 """ 12 13 import abc as _abc 14 import re as _re 15 import requests as _requests 16 import tkinter as _tk 17 import subprocess as _subprocess 18 import sys as _sys 19 import webbrowser as _webbrowser 20 from dataclasses import dataclass as _dataclass 21 from enum import Enum as _Enum 22 from functools import partial as _partial 23 24 try: # Not supported in Colab 25 from idlelib.tooltip import Hovertip 26 except ModuleNotFoundError: 27 # Hovertips won't work 28 class Hovertip(object): 29 """ 30 Shell class 31 """ 32 33 def __init__(self, *args, **kwargs): 34 pass 35 36 37 from pathlib import Path as _Path 38 from tkinter import filedialog as _filedialog 39 from typing import ( 40 Any as _Any, 41 Callable as _Callable, 42 Dict as _Dict, 43 NamedTuple as _NamedTuple, 44 Optional as _Optional, 45 Sequence as _Sequence, 46 ) 47 48 try: # 3rd-party and 1st-party imports 49 import torch as _torch 50 51 from nam import __version__ 52 from nam.data import Split as _Split 53 from nam.train import core as _core 54 from nam.train.gui._resources import settings as _settings 55 from nam.models.metadata import ( 56 GearType as _GearType, 57 UserMetadata as _UserMetadata, 58 ToneType as _ToneType, 59 ) 60 61 # Ok private access here--this is technically allowed access 62 from nam.train import metadata as _metadata 63 from nam.train._names import ( 64 INPUT_BASENAMES as _INPUT_BASENAMES, 65 LATEST_VERSION as _LATEST_VERSION, 66 ) 67 from nam.train._version import ( 68 Version as _Version, 69 get_current_version as _get_current_version, 70 ) 71 72 _install_is_valid = True 73 _HAVE_ACCELERATOR = _torch.cuda.is_available() or _torch.backends.mps.is_available() 74 except ImportError: 75 _install_is_valid = False 76 _HAVE_ACCELERATOR = False 77 78 if _HAVE_ACCELERATOR: 79 _DEFAULT_NUM_EPOCHS = 100 80 _DEFAULT_BATCH_SIZE = 16 81 _DEFAULT_LR_DECAY = 0.007 82 else: 83 _DEFAULT_NUM_EPOCHS = 20 84 _DEFAULT_BATCH_SIZE = 1 85 _DEFAULT_LR_DECAY = 0.05 86 _BUTTON_WIDTH = 20 87 _BUTTON_HEIGHT = 2 88 _TEXT_WIDTH = 70 89 90 _DEFAULT_DELAY = None 91 _DEFAULT_IGNORE_CHECKS = False 92 _DEFAULT_THRESHOLD_ESR = None 93 94 _ADVANCED_OPTIONS_LEFT_WIDTH = 12 95 _ADVANCED_OPTIONS_RIGHT_WIDTH = 12 96 _METADATA_LEFT_WIDTH = 19 97 _METADATA_RIGHT_WIDTH = 60 98 99 100 def _is_mac() -> bool: 101 return _sys.platform == "darwin" 102 103 104 _SYSTEM_TEXT_COLOR = "systemTextColor" if _is_mac() else "black" 105 106 107 @_dataclass 108 class AdvancedOptions(object): 109 """ 110 :param architecture: Which architecture to use. 111 :param num_epochs: How many epochs to train for. 112 :param latency: Latency between the input and output audio, in samples. 113 None means we don't know and it has to be calibrated. 114 :param ignore_checks: Keep going even if a check says that something is wrong. 115 :param threshold_esr: Stop training if the ESR gets better than this. If None, don't 116 stop. 117 """ 118 119 architecture: _core.Architecture 120 num_epochs: int 121 latency: _Optional[int] 122 ignore_checks: bool 123 threshold_esr: _Optional[float] 124 125 126 class _PathType(_Enum): 127 FILE = "file" 128 DIRECTORY = "directory" 129 MULTIFILE = "multifile" 130 131 132 class _PathButton(object): 133 """ 134 Button and the path 135 """ 136 137 def __init__( 138 self, 139 frame: _tk.Frame, 140 button_text: str, 141 info_str: str, 142 path_type: _PathType, 143 path_key: _settings.PathKey, 144 hooks: _Optional[_Sequence[_Callable[[], None]]] = None, 145 color_when_not_set: str = "#EF0000", # Darker red 146 color_when_set: str = _SYSTEM_TEXT_COLOR, 147 default: _Optional[_Path] = None, 148 ): 149 """ 150 :param hooks: Callables run at the end of setting the value. 151 """ 152 self._button_text = button_text 153 self._info_str = info_str 154 self._path: _Optional[_Path] = default 155 self._path_type = path_type 156 self._path_key = path_key 157 self._frame = frame 158 self._widgets = {} 159 self._widgets["button"] = _tk.Button( 160 self._frame, 161 text=button_text, 162 width=_BUTTON_WIDTH, 163 height=_BUTTON_HEIGHT, 164 command=self._set_val, 165 ) 166 self._widgets["button"].pack(side=_tk.LEFT) 167 self._widgets["label"] = _tk.Label( 168 self._frame, 169 width=_TEXT_WIDTH, 170 height=_BUTTON_HEIGHT, 171 bg=None, 172 anchor="w", 173 ) 174 self._widgets["label"].pack(side=_tk.LEFT) 175 self._hooks = hooks 176 self._color_when_not_set = color_when_not_set 177 self._color_when_set = color_when_set 178 self._set_text() 179 180 def __setitem__(self, key, val): 181 """ 182 Implement tk-style setter for state 183 """ 184 if key == "state": 185 for widget in self._widgets.values(): 186 widget["state"] = val 187 else: 188 raise RuntimeError( 189 f"{self.__class__.__name__} instance does not support item assignment for non-state key {key}!" 190 ) 191 192 @property 193 def val(self) -> _Optional[_Path]: 194 return self._path 195 196 def _set_text(self): 197 if self._path is None: 198 self._widgets["label"]["fg"] = self._color_when_not_set 199 self._widgets["label"]["text"] = self._info_str 200 else: 201 val = self.val 202 val = val[0] if isinstance(val, tuple) and len(val) == 1 else val 203 self._widgets["label"]["fg"] = self._color_when_set 204 self._widgets["label"][ 205 "text" 206 ] = f"{self._button_text.capitalize()} set to {val}" 207 208 def _set_val(self): 209 last_path = _settings.get_last_path(self._path_key) 210 if last_path is None: 211 initial_dir = None 212 elif not last_path.is_dir(): 213 initial_dir = last_path.parent 214 else: 215 initial_dir = last_path 216 result = { 217 _PathType.FILE: _filedialog.askopenfilename, 218 _PathType.DIRECTORY: _filedialog.askdirectory, 219 _PathType.MULTIFILE: _filedialog.askopenfilenames, 220 }[self._path_type](initialdir=str(initial_dir)) 221 if result != "": 222 self._path = result 223 _settings.set_last_path( 224 self._path_key, 225 _Path(result[0] if self._path_type == _PathType.MULTIFILE else result), 226 ) 227 self._set_text() 228 229 if self._hooks is not None: 230 for h in self._hooks: 231 h() 232 233 234 class _InputPathButton(_PathButton): 235 def __init__(self, *args, **kwargs): 236 super().__init__(*args, **kwargs) 237 # Download the training file! 238 self._widgets["button_download_input"] = _tk.Button( 239 self._frame, 240 text="Download input file", 241 width=_BUTTON_WIDTH, 242 height=_BUTTON_HEIGHT, 243 command=self._download_input_file, 244 ) 245 self._widgets["button_download_input"].pack(side=_tk.RIGHT) 246 247 @classmethod 248 def _download_input_file(cls): 249 file_urls = { 250 "input.wav": "https://drive.google.com/file/d/1KbaS4oXXNEuh2aCPLwKrPdf5KFOjda8G/view?usp=drive_link", 251 "v3_0_0.wav": "https://drive.google.com/file/d/1Pgf8PdE0rKB1TD4TRPKbpNo1ByR3IOm9/view?usp=drive_link", 252 "v2_0_0.wav": "https://drive.google.com/file/d/1xnyJP_IZ7NuyDSTJfn-Jmc5lw0IE7nfu/view?usp=drive_link", 253 "v1_1_1.wav": "", 254 "v1.wav": "", 255 } 256 # Pick the most recent file. 257 for input_basename in _INPUT_BASENAMES: 258 name = input_basename.name 259 url = file_urls.get(name) 260 if url: 261 if name != _LATEST_VERSION.name: 262 print( 263 f"WARNING: File {name} is out of date. " 264 "This needs to be updated!" 265 ) 266 _webbrowser.open(url) 267 return 268 269 270 class _CheckboxKeys(_Enum): 271 """ 272 Keys for checkboxes 273 """ 274 275 SILENT_TRAINING = "silent_training" 276 SAVE_PLOT = "save_plot" 277 278 279 class _TopLevelWithOk(_tk.Toplevel): 280 """ 281 Toplevel with an Ok button (provide yourself!) 282 """ 283 284 def __init__( 285 self, on_ok: _Callable[[None], None], resume_main: _Callable[[None], None] 286 ): 287 """ 288 :param on_ok: What to do when "Ok" button is pressed 289 """ 290 super().__init__() 291 self._on_ok = on_ok 292 self._resume_main = resume_main 293 294 def destroy(self, pressed_ok: bool = False): 295 if pressed_ok: 296 self._on_ok() 297 self._resume_main() 298 super().destroy() 299 300 301 class _TopLevelWithYesNo(_tk.Toplevel): 302 """ 303 Toplevel holding functions for yes/no buttons to close 304 """ 305 306 def __init__( 307 self, 308 on_yes: _Callable[[None], None], 309 on_no: _Callable[[None], None], 310 on_close: _Optional[_Callable[[None], None]], 311 resume_main: _Callable[[None], None], 312 ): 313 """ 314 :param on_yes: What to do when "Yes" button is pressed. 315 :param on_no: What to do when "No" button is pressed. 316 :param on_close: Do this regardless when closing (via yes/no/x) before 317 resuming. 318 """ 319 super().__init__() 320 self._on_yes = on_yes 321 self._on_no = on_no 322 self._on_close = on_close 323 self._resume_main = resume_main 324 325 def destroy(self, pressed_yes: bool = False, pressed_no: bool = False): 326 if pressed_yes: 327 self._on_yes() 328 if pressed_no: 329 self._on_no() 330 if self._on_close is not None: 331 self._on_close() 332 self._resume_main() 333 super().destroy() 334 335 336 class _OkModal(object): 337 """ 338 Message and OK button 339 """ 340 341 def __init__(self, resume_main, msg: str, label_kwargs: _Optional[dict] = None): 342 label_kwargs = {} if label_kwargs is None else label_kwargs 343 344 self._root = _TopLevelWithOk((lambda: None), resume_main) 345 self._text = _tk.Label(self._root, text=msg, **label_kwargs) 346 self._text.pack() 347 self._ok = _tk.Button( 348 self._root, 349 text="Ok", 350 width=_BUTTON_WIDTH, 351 height=_BUTTON_HEIGHT, 352 command=lambda: self._root.destroy(pressed_ok=True), 353 ) 354 self._ok.pack() 355 356 357 class _YesNoModal(object): 358 """ 359 Modal w/ yes/no buttons 360 """ 361 362 def __init__( 363 self, 364 on_yes: _Callable[[None], None], 365 on_no: _Callable[[None], None], 366 resume_main, 367 msg: str, 368 on_close: _Optional[_Callable[[None], None]] = None, 369 label_kwargs: _Optional[dict] = None, 370 ): 371 label_kwargs = {} if label_kwargs is None else label_kwargs 372 self._root = _TopLevelWithYesNo(on_yes, on_no, on_close, resume_main) 373 self._text = _tk.Label(self._root, text=msg, **label_kwargs) 374 self._text.pack() 375 self._buttons_frame = _tk.Frame(self._root) 376 self._buttons_frame.pack() 377 self._yes = _tk.Button( 378 self._buttons_frame, 379 text="Yes", 380 width=_BUTTON_WIDTH, 381 height=_BUTTON_HEIGHT, 382 command=lambda: self._root.destroy(pressed_yes=True), 383 ) 384 self._yes.pack(side=_tk.LEFT) 385 self._no = _tk.Button( 386 self._buttons_frame, 387 text="No", 388 width=_BUTTON_WIDTH, 389 height=_BUTTON_HEIGHT, 390 command=lambda: self._root.destroy(pressed_no=True), 391 ) 392 self._no.pack(side=_tk.RIGHT) 393 394 395 class _GUIWidgets(_Enum): 396 INPUT_PATH = "input_path" 397 OUTPUT_PATH = "output_path" 398 TRAINING_DESTINATION = "training_destination" 399 METADATA = "metadata" 400 ADVANCED_OPTIONS = "advanced_options" 401 TRAIN = "train" 402 UPDATE = "update" 403 404 405 @_dataclass 406 class Checkbox(object): 407 variable: _tk.BooleanVar 408 check_button: _tk.Checkbutton 409 410 411 class GUI(object): 412 def __init__(self): 413 self._root = _tk.Tk() 414 self._root.title(f"NAM Trainer - v{__version__}") 415 self._widgets = {} 416 417 # Buttons for paths: 418 self._frame_input = _tk.Frame(self._root) 419 self._frame_input.pack(anchor="w") 420 self._widgets[_GUIWidgets.INPUT_PATH] = _InputPathButton( 421 self._frame_input, 422 "Input Audio", 423 f"Select input (DI) file (e.g. {_LATEST_VERSION.name})", 424 _PathType.FILE, 425 _settings.PathKey.INPUT_FILE, 426 hooks=[self._check_button_states], 427 ) 428 429 self._frame_output_path = _tk.Frame(self._root) 430 self._frame_output_path.pack(anchor="w") 431 self._widgets[_GUIWidgets.OUTPUT_PATH] = _PathButton( 432 self._frame_output_path, 433 "Output Audio", 434 "Select output (reamped) file - (Choose MULTIPLE FILES to enable BATCH TRAINING)", 435 _PathType.MULTIFILE, 436 _settings.PathKey.OUTPUT_FILE, 437 hooks=[self._check_button_states], 438 ) 439 440 self._frame_train_destination = _tk.Frame(self._root) 441 self._frame_train_destination.pack(anchor="w") 442 self._widgets[_GUIWidgets.TRAINING_DESTINATION] = _PathButton( 443 self._frame_train_destination, 444 "Train Destination", 445 "Select training output directory", 446 _PathType.DIRECTORY, 447 _settings.PathKey.TRAINING_DESTINATION, 448 hooks=[self._check_button_states], 449 ) 450 451 # Metadata 452 self.user_metadata = _UserMetadata() 453 self._frame_metadata = _tk.Frame(self._root) 454 self._frame_metadata.pack(anchor="w") 455 self._widgets["metadata"] = _tk.Button( 456 self._frame_metadata, 457 text="Metadata...", 458 width=_BUTTON_WIDTH, 459 height=_BUTTON_HEIGHT, 460 command=self._open_metadata, 461 ) 462 self._widgets["metadata"].pack() 463 self.user_metadata_flag = False 464 465 # This should probably be to the right somewhere 466 self._get_additional_options_frame() 467 468 # Last frames: avdanced options & train in the SE corner: 469 self._frame_advanced_options = _tk.Frame(self._root) 470 self._frame_train = _tk.Frame(self._root) 471 self._frame_update = _tk.Frame(self._root) 472 # Pack must be in reverse order 473 self._frame_update.pack(side=_tk.BOTTOM, anchor="e") 474 self._frame_train.pack(side=_tk.BOTTOM, anchor="e") 475 self._frame_advanced_options.pack(side=_tk.BOTTOM, anchor="e") 476 477 # Advanced options for training 478 default_architecture = _core.Architecture.STANDARD 479 self.advanced_options = AdvancedOptions( 480 default_architecture, 481 _DEFAULT_NUM_EPOCHS, 482 _DEFAULT_DELAY, 483 _DEFAULT_IGNORE_CHECKS, 484 _DEFAULT_THRESHOLD_ESR, 485 ) 486 # Window to edit them: 487 488 self._widgets[_GUIWidgets.ADVANCED_OPTIONS] = _tk.Button( 489 self._frame_advanced_options, 490 text="Advanced options...", 491 width=_BUTTON_WIDTH, 492 height=_BUTTON_HEIGHT, 493 command=self._open_advanced_options, 494 ) 495 self._widgets[_GUIWidgets.ADVANCED_OPTIONS].pack() 496 497 # Train button 498 499 self._widgets[_GUIWidgets.TRAIN] = _tk.Button( 500 self._frame_train, 501 text="Train", 502 width=_BUTTON_WIDTH, 503 height=_BUTTON_HEIGHT, 504 command=self._train, 505 ) 506 self._widgets[_GUIWidgets.TRAIN].pack() 507 508 self._pack_update_button_if_update_is_available() 509 510 self._check_button_states() 511 512 def core_train_kwargs(self) -> _Dict[str, _Any]: 513 """ 514 Get any additional kwargs to provide to `core.train` 515 """ 516 return { 517 "lr": 0.004, 518 "lr_decay": _DEFAULT_LR_DECAY, 519 "batch_size": _DEFAULT_BATCH_SIZE, 520 "seed": 0, 521 } 522 523 def get_mrstft_fit(self) -> bool: 524 """ 525 Use a pre-emphasized multi-resolution shot-time Fourier transform loss during 526 training. 527 528 This improves agreement in the high frequencies, usually with a minimial loss in 529 ESR. 530 """ 531 # Leave this as a public method to anticipate an extension to make it 532 # changeable. 533 return True 534 535 def _check_button_states(self): 536 """ 537 Determine if any buttons should be disabled 538 """ 539 # Train button is disabled unless all paths are set 540 if any( 541 pb.val is None 542 for pb in ( 543 self._widgets[_GUIWidgets.INPUT_PATH], 544 self._widgets[_GUIWidgets.OUTPUT_PATH], 545 self._widgets[_GUIWidgets.TRAINING_DESTINATION], 546 ) 547 ): 548 self._widgets[_GUIWidgets.TRAIN]["state"] = _tk.DISABLED 549 return 550 self._widgets[_GUIWidgets.TRAIN]["state"] = _tk.NORMAL 551 552 def _get_additional_options_frame(self): 553 # Checkboxes 554 # TODO get these definitions into __init__() 555 self._frame_checkboxes = _tk.Frame(self._root) 556 self._frame_checkboxes.pack(side=_tk.LEFT) 557 row = 1 558 559 def make_checkbox( 560 key: _CheckboxKeys, text: str, default_value: bool 561 ) -> Checkbox: 562 variable = _tk.BooleanVar() 563 variable.set(default_value) 564 check_button = _tk.Checkbutton( 565 self._frame_checkboxes, text=text, variable=variable 566 ) 567 self._checkboxes[key] = Checkbox(variable, check_button) 568 self._widgets[key] = check_button # For tracking in set-all-widgets ops 569 570 self._checkboxes: _Dict[_CheckboxKeys, Checkbox] = dict() 571 make_checkbox( 572 _CheckboxKeys.SILENT_TRAINING, 573 "Silent run (suggested for batch training)", 574 False, 575 ) 576 make_checkbox(_CheckboxKeys.SAVE_PLOT, "Save ESR plot automatically", True) 577 578 # Grid them: 579 row = 1 580 for v in self._checkboxes.values(): 581 v.check_button.grid(row=row, column=1, sticky="W") 582 row += 1 583 584 def mainloop(self): 585 self._root.mainloop() 586 587 def _disable(self): 588 self._set_all_widget_states_to(_tk.DISABLED) 589 590 def _open_advanced_options(self): 591 """ 592 Open window for advanced options 593 """ 594 595 self._wait_while_func(lambda resume: AdvancedOptionsGUI(resume, self)) 596 597 def _open_metadata(self): 598 """ 599 Open window for metadata 600 """ 601 602 self._wait_while_func(lambda resume: UserMetadataGUI(resume, self)) 603 604 def _pack_update_button(self, version_from: _Version, version_to: _Version): 605 """ 606 Pack a button that a user can click to update 607 """ 608 609 def update_nam(): 610 result = _subprocess.run( 611 [ 612 f"{_sys.executable}", 613 "-m", 614 "pip", 615 "install", 616 "--upgrade", 617 "neural-amp-modeler", 618 ] 619 ) 620 if result.returncode == 0: 621 self._wait_while_func( 622 (lambda resume, *args, **kwargs: _OkModal(resume, *args, **kwargs)), 623 "Update complete! Restart NAM for changes to take effect.", 624 ) 625 else: 626 self._wait_while_func( 627 (lambda resume, *args, **kwargs: _OkModal(resume, *args, **kwargs)), 628 "Update failed! See logs.", 629 ) 630 631 self._widgets[_GUIWidgets.UPDATE] = _tk.Button( 632 self._frame_update, 633 text=f"Update ({str(version_from)} -> {str(version_to)})", 634 width=_BUTTON_WIDTH, 635 height=_BUTTON_HEIGHT, 636 command=update_nam, 637 ) 638 self._widgets[_GUIWidgets.UPDATE].pack() 639 640 def _pack_update_button_if_update_is_available(self): 641 class UpdateInfo(_NamedTuple): 642 available: bool 643 current_version: _Version 644 new_version: _Optional[_Version] 645 646 def get_info() -> UpdateInfo: 647 # TODO error handling 648 url = f"https://api.github.com/repos/sdatkinson/neural-amp-modeler/releases" 649 current_version = _get_current_version() 650 try: 651 response = _requests.get(url) 652 except _requests.exceptions.ConnectionError: 653 print("WARNING: Failed to reach the server to check for updates") 654 return UpdateInfo( 655 available=False, current_version=current_version, new_version=None 656 ) 657 if response.status_code != 200: 658 print(f"Failed to fetch releases. Status code: {response.status_code}") 659 return UpdateInfo( 660 available=False, current_version=current_version, new_version=None 661 ) 662 else: 663 releases = response.json() 664 latest_version = None 665 if releases: 666 for release in releases: 667 tag = release["tag_name"] 668 if not tag.startswith("v"): 669 print(f"Found invalid version {tag}") 670 else: 671 this_version = _Version.from_string(tag[1:]) 672 if latest_version is None or this_version > latest_version: 673 latest_version = this_version 674 else: 675 print("No releases found for this repository.") 676 update_available = ( 677 latest_version is not None and latest_version > current_version 678 ) 679 return UpdateInfo( 680 available=update_available, 681 current_version=current_version, 682 new_version=latest_version, 683 ) 684 685 update_info = get_info() 686 if update_info.available: 687 self._pack_update_button( 688 update_info.current_version, update_info.new_version 689 ) 690 691 def _resume(self): 692 self._set_all_widget_states_to(_tk.NORMAL) 693 self._check_button_states() 694 695 def _set_all_widget_states_to(self, state): 696 for widget in self._widgets.values(): 697 widget["state"] = state 698 699 def _train(self): 700 input_path = self._widgets[_GUIWidgets.INPUT_PATH].val 701 output_paths = self._widgets[_GUIWidgets.OUTPUT_PATH].val 702 # Validate all files before running: 703 success = self._validate_all_data(input_path, output_paths) 704 if success: 705 self._train2() 706 707 def _train2(self, ignore_checks=False): 708 input_path = self._widgets[_GUIWidgets.INPUT_PATH].val 709 710 # Advanced options: 711 num_epochs = self.advanced_options.num_epochs 712 architecture = self.advanced_options.architecture 713 user_latency = self.advanced_options.latency 714 file_list = self._widgets[_GUIWidgets.OUTPUT_PATH].val 715 threshold_esr = self.advanced_options.threshold_esr 716 717 # Run it 718 for file in file_list: 719 print(f"Now training {file}") 720 basename = _re.sub(r"\.wav$", "", file.split("/")[-1]) 721 user_metadata = ( 722 self.user_metadata if self.user_metadata_flag else _UserMetadata() 723 ) 724 725 train_output = _core.train( 726 input_path, 727 file, 728 self._widgets[_GUIWidgets.TRAINING_DESTINATION].val, 729 epochs=num_epochs, 730 latency=user_latency, 731 architecture=architecture, 732 silent=self._checkboxes[_CheckboxKeys.SILENT_TRAINING].variable.get(), 733 save_plot=self._checkboxes[_CheckboxKeys.SAVE_PLOT].variable.get(), 734 modelname=basename, 735 ignore_checks=ignore_checks, 736 local=True, 737 fit_mrstft=self.get_mrstft_fit(), 738 threshold_esr=threshold_esr, 739 user_metadata=user_metadata, 740 **self.core_train_kwargs(), 741 ) 742 743 if train_output.model is None: 744 print("Model training failed! Skip exporting...") 745 continue 746 print("Model training complete!") 747 print("Exporting...") 748 outdir = self._widgets[_GUIWidgets.TRAINING_DESTINATION].val 749 print(f"Exporting trained model to {outdir}...") 750 train_output.model.net.export( 751 outdir, 752 basename=basename, 753 user_metadata=user_metadata, 754 other_metadata={ 755 _metadata.TRAINING_KEY: train_output.metadata.model_dump() 756 }, 757 ) 758 print("Done!") 759 760 # Metadata was only valid for 1 run (possibly a batch), so make sure it's not 761 # used again unless the user re-visits the window and clicks "ok". 762 self.user_metadata_flag = False 763 764 def _validate_all_data( 765 self, input_path: _Path, output_paths: _Sequence[_Path] 766 ) -> bool: 767 """ 768 Validate all the data. 769 If something doesn't pass, then alert the user and ask them whether they 770 want to continue. 771 772 :return: whether we passed (NOTE: Training in spite of failure is 773 triggered by a modal that is produced on failure.) 774 """ 775 776 def make_message_for_file( 777 output_path: str, validation_output: _core.DataValidationOutput 778 ) -> str: 779 """ 780 State the file and explain what's wrong with it. 781 """ 782 # TODO put this closer to what it looks at, i.e. core.DataValidationOutput 783 msg = ( 784 f"\t{_Path(output_path).name}:\n" # They all have the same directory so 785 ) 786 if not validation_output.sample_rate.passed: 787 msg += ( 788 "\t\t There are different sample rates for the input (" 789 f"{validation_output.sample_rate.input}) and output (" 790 f"{validation_output.sample_rate.output}).\n" 791 ) 792 if not validation_output.length.passed: 793 msg += ( 794 "\t\t* The input and output audio files are too different in length" 795 ) 796 if validation_output.length.delta_seconds > 0: 797 msg += ( 798 f" (the output is {validation_output.length.delta_seconds:.2f} " 799 "seconds longer than the input)\n" 800 ) 801 else: 802 msg += ( 803 f" (the output is {-validation_output.length.delta_seconds:.2f}" 804 " seconds shorter than the input)\n" 805 ) 806 if validation_output.latency.manual is None: 807 if validation_output.latency.calibration.warnings.matches_lookahead: 808 msg += ( 809 "\t\t* The calibrated latency is the maximum allowed. This is " 810 "probably because the latency calibration was triggered by noise.\n" 811 ) 812 if validation_output.latency.calibration.warnings.disagreement_too_high: 813 msg += "\t\t* The calculated latencies are too different from each other.\n" 814 if not validation_output.checks.passed: 815 msg += "\t\t* A data check failed (TODO in more detail).\n" 816 if not validation_output.pytorch.passed: 817 msg += "\t\t* PyTorch data set errors:\n" 818 for split in _Split: 819 split_validation = getattr(validation_output.pytorch, split.value) 820 if not split_validation.passed: 821 msg += f" * {split.value:10s}: {split_validation.msg}\n" 822 return msg 823 824 # Validate input 825 input_validation = _core.validate_input(input_path) 826 if not input_validation.passed: 827 self._wait_while_func( 828 (lambda resume, *args, **kwargs: _OkModal(resume, *args, **kwargs)), 829 f"Input file {input_path} is not recognized as a standardized input " 830 "file.\nTraining cannot proceed.", 831 ) 832 return False 833 834 user_latency = self.advanced_options.latency 835 file_validation_outputs = { 836 output_path: _core.validate_data( 837 input_path, 838 output_path, 839 user_latency, 840 ) 841 for output_path in output_paths 842 } 843 if any(not fv.passed for fv in file_validation_outputs.values()): 844 msg = "The following output files failed checks:\n" + "".join( 845 [ 846 make_message_for_file(output_path, fv) 847 for output_path, fv in file_validation_outputs.items() 848 if not fv.passed 849 ] 850 ) 851 if all(fv.passed_critical for fv in file_validation_outputs.values()): 852 msg += "\nIgnore and proceed?" 853 854 # Hacky to listen to the modal: 855 modal_listener = {"proceed": False, "still_open": True} 856 857 def on_yes(): 858 modal_listener["proceed"] = True 859 860 def on_no(): 861 modal_listener["proceed"] = False 862 863 def on_close(): 864 if modal_listener["proceed"]: 865 self._train2(ignore_checks=True) 866 867 self._wait_while_func( 868 ( 869 lambda resume, on_yes, on_no, *args, **kwargs: _YesNoModal( 870 on_yes, on_no, resume, *args, **kwargs 871 ) 872 ), 873 on_yes=on_yes, 874 on_no=on_no, 875 msg=msg, 876 on_close=on_close, 877 label_kwargs={"justify": "left"}, 878 ) 879 return False # we still failed checks so say so. 880 else: 881 msg += "\nCritical errors found, cannot ignore." 882 self._wait_while_func( 883 lambda resume, msg, **kwargs: _OkModal(resume, msg, **kwargs), 884 msg=msg, 885 label_kwargs={"justify": "left"}, 886 ) 887 return False 888 889 return True 890 891 def _wait_while_func(self, func, *args, **kwargs): 892 """ 893 Disable this GUI while something happens. 894 That function _needs_ to call the provided self._resume when it's ready to 895 release me! 896 """ 897 self._disable() 898 func(self._resume, *args, **kwargs) 899 900 901 # some typing functions 902 def _non_negative_int(val): 903 val = int(val) 904 if val < 0: 905 val = 0 906 return val 907 908 909 class _TypeOrNull(object): 910 def __init__(self, T, null_str=""): 911 """ 912 :param T: tpe to cast to on .forward() 913 """ 914 self._T = T 915 self._null_str = null_str 916 917 @property 918 def null_str(self) -> str: 919 """ 920 What str is displayed when for "None" 921 """ 922 return self._null_str 923 924 def forward(self, val: str): 925 val = val.rstrip() 926 return None if val == self._null_str else self._T(val) 927 928 def inverse(self, val) -> str: 929 return self._null_str if val is None else str(val) 930 931 932 _int_or_null = _TypeOrNull(int) 933 _float_or_null = _TypeOrNull(float) 934 935 936 def _rstripped_str(val): 937 return str(val).rstrip() 938 939 940 class _SettingWidget(_abc.ABC): 941 """ 942 A widget for the user to interact with to set something 943 """ 944 945 @_abc.abstractmethod 946 def get(self): 947 pass 948 949 950 class LabeledOptionMenu(_SettingWidget): 951 """ 952 Label (left) and radio buttons (right) 953 """ 954 955 def __init__( 956 self, 957 frame: _tk.Frame, 958 label: str, 959 choices: _Enum, 960 default: _Optional[_Enum] = None, 961 ): 962 """ 963 :param command: Called to propagate option selection. Is provided with the 964 value corresponding to the radio button selected. 965 """ 966 self._frame = frame 967 self._choices = choices 968 height = _BUTTON_HEIGHT 969 bg = None 970 self._label = _tk.Label( 971 frame, 972 width=_ADVANCED_OPTIONS_LEFT_WIDTH, 973 height=height, 974 bg=bg, 975 anchor="w", 976 text=label, 977 ) 978 self._label.pack(side=_tk.LEFT) 979 980 frame_menu = _tk.Frame(frame) 981 frame_menu.pack(side=_tk.RIGHT) 982 983 self._selected_value = None 984 default = (list(choices)[0] if default is None else default).value 985 self._menu = _tk.OptionMenu( 986 frame_menu, 987 _tk.StringVar(master=frame, value=default, name=label), 988 # default, 989 *[choice.value for choice in choices], # if choice.value!=default], 990 command=self._set, 991 ) 992 self._menu.config(width=_ADVANCED_OPTIONS_RIGHT_WIDTH) 993 self._menu.pack(side=_tk.RIGHT) 994 # Initialize 995 self._set(default) 996 997 def get(self) -> _Enum: 998 return self._selected_value 999 1000 def _set(self, val: str): 1001 """ 1002 Set the value selected 1003 """ 1004 self._selected_value = self._choices(val) 1005 1006 1007 class _Hovertip(Hovertip): 1008 """ 1009 Adjustments: 1010 1011 * Always black text (macOS) 1012 """ 1013 1014 def showcontents(self): 1015 # Override 1016 label = _tk.Label( 1017 self.tipwindow, 1018 text=self.text, 1019 justify=_tk.LEFT, 1020 background="#ffffe0", 1021 relief=_tk.SOLID, 1022 borderwidth=1, 1023 fg="black", 1024 ) 1025 label.pack() 1026 1027 1028 class LabeledText(_SettingWidget): 1029 """ 1030 Label (left) and text input (right) 1031 """ 1032 1033 def __init__( 1034 self, 1035 frame: _tk.Frame, 1036 label: str, 1037 default=None, 1038 type=None, 1039 left_width=_ADVANCED_OPTIONS_LEFT_WIDTH, 1040 right_width=_ADVANCED_OPTIONS_RIGHT_WIDTH, 1041 ): 1042 """ 1043 :param command: Called to propagate option selection. Is provided with the 1044 value corresponding to the radio button selected. 1045 :param type: If provided, casts value to given type 1046 :param left_width: How much space to use on the left side (text) 1047 :param right_width: How much space for the Text field 1048 """ 1049 self._frame = frame 1050 label_height = 2 1051 text_height = 1 1052 self._label = _tk.Label( 1053 frame, 1054 width=left_width, 1055 height=label_height, 1056 bg=None, 1057 anchor="e", 1058 text=label, 1059 ) 1060 self._label.pack(side=_tk.LEFT) 1061 1062 self._text = _tk.Text( 1063 frame, 1064 width=right_width, 1065 height=text_height, 1066 bg=None, 1067 ) 1068 self._text.pack(side=_tk.RIGHT) 1069 1070 self._type = (lambda x: x) if type is None else type 1071 1072 if default is not None: 1073 self._text.insert("1.0", str(default)) 1074 1075 # You can assign a tooltip for the label if you'd like. 1076 self.label_tooltip: _Optional[_Hovertip] = None 1077 1078 @property 1079 def label(self) -> _tk.Label: 1080 return self._label 1081 1082 def get(self): 1083 """ 1084 Attempt to get and return the value. 1085 May throw a tk.TclError indicating something went wrong getting the value. 1086 """ 1087 # "1.0" means Line 1, character zero (wat) 1088 return self._type(self._text.get("1.0", _tk.END)) 1089 1090 1091 class AdvancedOptionsGUI(object): 1092 """ 1093 A window to hold advanced options (Architecture and number of epochs) 1094 """ 1095 1096 def __init__(self, resume_main, parent: GUI): 1097 self._parent = parent 1098 self._root = _TopLevelWithOk(self.apply, resume_main) 1099 self._root.title("Advanced Options") 1100 1101 self.pack() 1102 1103 # "Ok": apply and destroy 1104 self._frame_ok = _tk.Frame(self._root) 1105 self._frame_ok.pack() 1106 self._button_ok = _tk.Button( 1107 self._frame_ok, 1108 text="Ok", 1109 width=_BUTTON_WIDTH, 1110 height=_BUTTON_HEIGHT, 1111 command=lambda: self._root.destroy(pressed_ok=True), 1112 ) 1113 self._button_ok.pack() 1114 1115 def apply(self): 1116 """ 1117 Set values to parent and destroy this object 1118 """ 1119 1120 def safe_apply(name): 1121 try: 1122 setattr( 1123 self._parent.advanced_options, name, getattr(self, "_" + name).get() 1124 ) 1125 except ValueError: 1126 pass 1127 1128 # TODO could clean up more / see `.pack_options()` 1129 for name in ("architecture", "num_epochs", "latency", "threshold_esr"): 1130 safe_apply(name) 1131 1132 def pack(self): 1133 # TODO things that are `_SettingWidget`s are named carefully, need to make this 1134 # easier to work with. 1135 1136 # Architecture: radio buttons 1137 self._frame_architecture = _tk.Frame(self._root) 1138 self._frame_architecture.pack() 1139 self._architecture = LabeledOptionMenu( 1140 self._frame_architecture, 1141 "Architecture", 1142 _core.Architecture, 1143 default=self._parent.advanced_options.architecture, 1144 ) 1145 1146 # Number of epochs: text box 1147 self._frame_epochs = _tk.Frame(self._root) 1148 self._frame_epochs.pack() 1149 1150 self._num_epochs = LabeledText( 1151 self._frame_epochs, 1152 "Epochs", 1153 default=str(self._parent.advanced_options.num_epochs), 1154 type=_non_negative_int, 1155 ) 1156 1157 # Delay: text box 1158 self._frame_latency = _tk.Frame(self._root) 1159 self._frame_latency.pack() 1160 1161 self._latency = LabeledText( 1162 self._frame_latency, 1163 "Reamp latency", 1164 default=_int_or_null.inverse(self._parent.advanced_options.latency), 1165 type=_int_or_null.forward, 1166 ) 1167 1168 # Threshold ESR 1169 self._frame_threshold_esr = _tk.Frame(self._root) 1170 self._frame_threshold_esr.pack() 1171 self._threshold_esr = LabeledText( 1172 self._frame_threshold_esr, 1173 "Threshold ESR", 1174 default=_float_or_null.inverse(self._parent.advanced_options.threshold_esr), 1175 type=_float_or_null.forward, 1176 ) 1177 1178 1179 class UserMetadataGUI(object): 1180 # Things that are auto-filled: 1181 # Model date 1182 # gain 1183 def __init__(self, resume_main, parent: GUI): 1184 self._parent = parent 1185 self._root = _TopLevelWithOk(self.apply, resume_main) 1186 self._root.title("Metadata") 1187 1188 # Pack all the widgets 1189 self.pack() 1190 1191 # "Ok": apply and destroy 1192 self._frame_ok = _tk.Frame(self._root) 1193 self._frame_ok.pack() 1194 self._button_ok = _tk.Button( 1195 self._frame_ok, 1196 text="Ok", 1197 width=_BUTTON_WIDTH, 1198 height=_BUTTON_HEIGHT, 1199 command=lambda: self._root.destroy(pressed_ok=True), 1200 ) 1201 self._button_ok.pack() 1202 1203 def apply(self): 1204 """ 1205 Set values to parent and destroy this object 1206 """ 1207 1208 def safe_apply(name): 1209 try: 1210 setattr( 1211 self._parent.user_metadata, name, getattr(self, "_" + name).get() 1212 ) 1213 except ValueError: 1214 pass 1215 1216 # TODO could clean up more / see `.pack()` 1217 for name in ( 1218 "name", 1219 "modeled_by", 1220 "gear_make", 1221 "gear_model", 1222 "gear_type", 1223 "tone_type", 1224 "input_level_dbu", 1225 "output_level_dbu", 1226 ): 1227 safe_apply(name) 1228 self._parent.user_metadata_flag = True 1229 1230 def pack(self): 1231 # TODO things that are `_SettingWidget`s are named carefully, need to make this 1232 # easier to work with. 1233 1234 LabeledText_ = _partial( 1235 LabeledText, 1236 left_width=_METADATA_LEFT_WIDTH, 1237 right_width=_METADATA_RIGHT_WIDTH, 1238 ) 1239 parent = self._parent 1240 1241 # Name 1242 self._frame_name = _tk.Frame(self._root) 1243 self._frame_name.pack() 1244 self._name = LabeledText_( 1245 self._frame_name, 1246 "NAM name", 1247 default=parent.user_metadata.name, 1248 type=_rstripped_str, 1249 ) 1250 # Modeled by 1251 self._frame_modeled_by = _tk.Frame(self._root) 1252 self._frame_modeled_by.pack() 1253 self._modeled_by = LabeledText_( 1254 self._frame_modeled_by, 1255 "Modeled by", 1256 default=parent.user_metadata.modeled_by, 1257 type=_rstripped_str, 1258 ) 1259 # Gear make 1260 self._frame_gear_make = _tk.Frame(self._root) 1261 self._frame_gear_make.pack() 1262 self._gear_make = LabeledText_( 1263 self._frame_gear_make, 1264 "Gear make", 1265 default=parent.user_metadata.gear_make, 1266 type=_rstripped_str, 1267 ) 1268 # Gear model 1269 self._frame_gear_model = _tk.Frame(self._root) 1270 self._frame_gear_model.pack() 1271 self._gear_model = LabeledText_( 1272 self._frame_gear_model, 1273 "Gear model", 1274 default=parent.user_metadata.gear_model, 1275 type=_rstripped_str, 1276 ) 1277 # Calibration: input & output dBu 1278 self._frame_input_dbu = _tk.Frame(self._root) 1279 self._frame_input_dbu.pack() 1280 self._input_level_dbu = LabeledText_( 1281 self._frame_input_dbu, 1282 "Reamp send level (dBu)", 1283 default=_float_or_null.inverse(parent.user_metadata.input_level_dbu), 1284 type=_float_or_null.forward, 1285 ) 1286 self._input_level_dbu.label_tooltip = _Hovertip( 1287 anchor_widget=self._input_level_dbu.label, 1288 text=( 1289 "(Ok to leave blank)\n\n" 1290 "Play a sine wave with frequency 1kHz and peak amplitude 0dBFS. Use\n" 1291 "a multimeter to measure the RMS voltage of the signal at the jack\n" 1292 "that connects to your gear, and convert to dBu.\n" 1293 "Record the value here." 1294 ), 1295 ) 1296 self._frame_output_dbu = _tk.Frame(self._root) 1297 self._frame_output_dbu.pack() 1298 self._output_level_dbu = LabeledText_( 1299 self._frame_output_dbu, 1300 "Reamp return level (dBu)", 1301 default=_float_or_null.inverse(parent.user_metadata.output_level_dbu), 1302 type=_float_or_null.forward, 1303 ) 1304 self._output_level_dbu.label_tooltip = _Hovertip( 1305 anchor_widget=self._output_level_dbu.label, 1306 text=( 1307 "(Ok to leave blank)\n\n" 1308 "Play a sine wave with frequency 1kHz into your interface where\n" 1309 "you're recording your gear. Keeping the interface's input gain\n" 1310 "trimmed as you will use it when recording, adjust the sine wave\n" 1311 "until the input peaks at exactly 0dBFS in your DAW. Measure the RMS\n" 1312 "voltage and convert to dBu.\n" 1313 "Record the value here." 1314 ), 1315 ) 1316 # Gear type 1317 self._frame_gear_type = _tk.Frame(self._root) 1318 self._frame_gear_type.pack() 1319 self._gear_type = LabeledOptionMenu( 1320 self._frame_gear_type, 1321 "Gear type", 1322 _GearType, 1323 default=parent.user_metadata.gear_type, 1324 ) 1325 # Tone type 1326 self._frame_tone_type = _tk.Frame(self._root) 1327 self._frame_tone_type.pack() 1328 self._tone_type = LabeledOptionMenu( 1329 self._frame_tone_type, 1330 "Tone type", 1331 _ToneType, 1332 default=parent.user_metadata.tone_type, 1333 ) 1334 1335 1336 def _install_error(): 1337 window = _tk.Tk() 1338 window.title("ERROR") 1339 label = _tk.Label( 1340 window, 1341 width=45, 1342 height=2, 1343 text="The NAM training software has not been installed correctly.", 1344 ) 1345 label.pack() 1346 button = _tk.Button(window, width=10, height=2, text="Quit", command=window.destroy) 1347 button.pack() 1348 window.mainloop() 1349 1350 1351 def run(): 1352 if _install_is_valid: 1353 _gui = GUI() 1354 _gui.mainloop() 1355 print("Shut down NAM trainer") 1356 else: 1357 _install_error() 1358 1359 1360 if __name__ == "__main__": 1361 run()