commit 6f9b7245f3b7540b8472cd4454d4ff55d4977604
parent af479901f47a3ae86ac6c697bd78835cb1fe1ed0
Author: Steven Atkinson <steven@atkinson.mn>
Date: Fri, 12 Apr 2024 19:46:14 -0700
[ENHANCEMENT] GUI: Button to download the training input audio file (#397)
Update GUI
Download button
Organize the buttons a bit
Diffstat:
2 files changed, 59 insertions(+), 16 deletions(-)
diff --git a/nam/train/_names.py b/nam/train/_names.py
@@ -14,7 +14,7 @@ class VersionAndName(NamedTuple):
name: str
-# From most the least recently-released
+# From most- to the least-recently-released:
INPUT_BASENAMES = (
VersionAndName(Version(3, 0, 0), "v3_0_0.wav"),
VersionAndName(Version(2, 0, 0), "v2_0_0.wav"),
diff --git a/nam/train/gui.py b/nam/train/gui.py
@@ -27,6 +27,7 @@ _ensure_graceful_shutdowns()
import re
import tkinter as tk
+import webbrowser
from dataclasses import dataclass
from enum import Enum
from functools import partial
@@ -42,7 +43,7 @@ try: # 3rd-party and 1st-party imports
from nam.models.metadata import GearType, UserMetadata, ToneType
# Ok private access here--this is technically allowed access
- from nam.train._names import LATEST_VERSION
+ from nam.train._names import INPUT_BASENAMES, LATEST_VERSION
_install_is_valid = True
_HAVE_ACCELERATOR = torch.cuda.is_available() or torch.backends.mps.is_available()
@@ -101,8 +102,9 @@ class _PathButton(object):
self._info_str = info_str
self._path: Optional[Path] = None
self._path_type = path_type
+ self._frame = frame
self._button = tk.Button(
- frame,
+ self._frame,
text=button_text,
width=_BUTTON_WIDTH,
height=_BUTTON_HEIGHT,
@@ -111,14 +113,14 @@ class _PathButton(object):
)
self._button.pack(side=tk.LEFT)
self._label = tk.Label(
- frame,
+ self._frame,
width=_TEXT_WIDTH,
height=_BUTTON_HEIGHT,
fg="black",
bg=None,
anchor="w",
)
- self._label.pack(side=tk.RIGHT)
+ self._label.pack(side=tk.LEFT)
self._hooks = hooks
self._set_text()
@@ -151,6 +153,42 @@ class _PathButton(object):
h()
+class _InputPathButton(_PathButton):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ # Download the training file!
+ self._button_download_input = tk.Button(
+ self._frame,
+ text="Download input file",
+ width=_BUTTON_WIDTH,
+ height=_BUTTON_HEIGHT,
+ fg="black",
+ command=self._download_input_file,
+ )
+ self._button_download_input.pack(side=tk.RIGHT)
+
+ @classmethod
+ def _download_input_file(cls):
+ file_urls = {
+ "v3_0_0.wav": "https://drive.google.com/file/d/1Pgf8PdE0rKB1TD4TRPKbpNo1ByR3IOm9/view?usp=drive_link",
+ "v2_0_0.wav": "https://drive.google.com/file/d/1xnyJP_IZ7NuyDSTJfn-Jmc5lw0IE7nfu/view?usp=drive_link",
+ "v1_1_1.wav": "",
+ "v1.wav": "",
+ }
+ # Pick the most recent file.
+ for input_basename in INPUT_BASENAMES:
+ name = input_basename.name
+ url = file_urls.get(name)
+ if url:
+ if name != LATEST_VERSION.name:
+ print(
+ f"WARNING: File {name} is out of date. "
+ "This needs to be updated!"
+ )
+ webbrowser.open(url)
+ return
+
+
class _CheckboxKeys(Enum):
"""
Keys for checkboxes
@@ -168,10 +206,10 @@ class _GUI(object):
self._root.title(f"NAM Trainer - v{__version__}")
# Buttons for paths:
- self._frame_input_path = tk.Frame(self._root)
- self._frame_input_path.pack()
- self._path_button_input = _PathButton(
- self._frame_input_path,
+ self._frame_input = tk.Frame(self._root)
+ self._frame_input.pack(anchor="w")
+ self._path_button_input = _InputPathButton(
+ self._frame_input,
"Input Audio",
f"Select input (DI) file (e.g. {LATEST_VERSION.name})",
_PathType.FILE,
@@ -179,7 +217,7 @@ class _GUI(object):
)
self._frame_output_path = tk.Frame(self._root)
- self._frame_output_path.pack()
+ self._frame_output_path.pack(anchor="w")
self._path_button_output = _PathButton(
self._frame_output_path,
"Output Audio",
@@ -189,7 +227,7 @@ class _GUI(object):
)
self._frame_train_destination = tk.Frame(self._root)
- self._frame_train_destination.pack()
+ self._frame_train_destination.pack(anchor="w")
self._path_button_train_destination = _PathButton(
self._frame_train_destination,
"Train Destination",
@@ -201,7 +239,7 @@ class _GUI(object):
# Metadata
self.user_metadata = UserMetadata()
self._frame_metadata = tk.Frame(self._root)
- self._frame_metadata.pack()
+ self._frame_metadata.pack(anchor="w")
self._button_metadata = tk.Button(
self._frame_metadata,
text="Metadata...",
@@ -216,6 +254,13 @@ class _GUI(object):
# This should probably be to the right somewhere
self._get_additional_options_frame()
+ # Last frames: avdanced options & train in the SE corner:
+ self._frame_advanced_options = tk.Frame(self._root)
+ self._frame_train = tk.Frame(self._root)
+ # Pack train first so that it's on bottom.
+ self._frame_train.pack(side=tk.BOTTOM, anchor="e")
+ self._frame_advanced_options.pack(side=tk.BOTTOM, anchor="e")
+
# Advanced options for training
default_architecture = core.Architecture.STANDARD
self.advanced_options = _AdvancedOptions(
@@ -225,8 +270,7 @@ class _GUI(object):
_DEFAULT_IGNORE_CHECKS,
)
# Window to edit them:
- self._frame_advanced_options = tk.Frame(self._root)
- self._frame_advanced_options.pack()
+
self._button_advanced_options = tk.Button(
self._frame_advanced_options,
text="Advanced options...",
@@ -238,8 +282,7 @@ class _GUI(object):
self._button_advanced_options.pack()
# Train button
- self._frame_train = tk.Frame(self._root)
- self._frame_train.pack()
+
self._button_train = tk.Button(
self._frame_train,
text="Train",