commit 625aa8a622296d2b8d4378b6519f9dfe26c6d746
parent 26fdad726c47e0035c79a395634ec57b4c427f0c
Author: Steven Atkinson <steven@atkinson.mn>
Date: Thu, 19 Sep 2024 08:29:12 -0700
[FEATURE] Automatically check for updates (#467)
* Button for update when available
* Rearrange methods, add shut down print statement
* Cleanup imports
* Function for getting current version
* Handle when there's no internet
Diffstat:
3 files changed, 111 insertions(+), 6 deletions(-)
diff --git a/nam/train/_version.py b/nam/train/_version.py
@@ -6,6 +6,8 @@
Version utility
"""
+from .._version import __version__
+
class Version:
def __init__(self, major: int, minor: int, patch: int):
@@ -13,6 +15,11 @@ class Version:
self.minor = minor
self.patch = patch
+ @classmethod
+ def from_string(cls, s: str):
+ major, minor, patch = [int(x) for x in s.split(".")]
+ return cls(major, minor, patch)
+
def __eq__(self, other) -> bool:
return (
self.major == other.major
@@ -21,6 +28,8 @@ class Version:
)
def __lt__(self, other) -> bool:
+ if self == other:
+ return False
if self.major != other.major:
return self.major < other.major
if self.minor != other.minor:
@@ -33,3 +42,7 @@ class Version:
PROTEUS_VERSION = Version(4, 0, 0)
+
+
+def get_current_version() -> Version:
+ return Version.from_string(__version__)
diff --git a/nam/train/gui/__init__.py b/nam/train/gui/__init__.py
@@ -11,7 +11,9 @@ Usage:
"""
import re
+import requests
import tkinter as tk
+import subprocess
import sys
import webbrowser
from dataclasses import dataclass
@@ -19,7 +21,7 @@ from enum import Enum
from functools import partial
from pathlib import Path
from tkinter import filedialog
-from typing import Callable, Dict, Optional, Sequence
+from typing import Callable, Dict, NamedTuple, Optional, Sequence
try: # 3rd-party and 1st-party imports
import torch
@@ -33,7 +35,7 @@ try: # 3rd-party and 1st-party imports
# Ok private access here--this is technically allowed access
from nam.train import metadata
from nam.train._names import INPUT_BASENAMES, LATEST_VERSION
- from nam.train.metadata import TRAINING_KEY
+ from nam.train._version import Version, get_current_version
_install_is_valid = True
_HAVE_ACCELERATOR = torch.cuda.is_available() or torch.backends.mps.is_available()
@@ -384,6 +386,7 @@ class _GUIWidgets(Enum):
METADATA = "metadata"
ADVANCED_OPTIONS = "advanced_options"
TRAIN = "train"
+ UPDATE = "update"
class _GUI(object):
@@ -446,7 +449,9 @@ class _GUI(object):
# Last frames: avdanced options & train in the SE corner:
self._frame_advanced_options = tk.Frame(self._root)
self._frame_train = tk.Frame(self._root)
- # Pack train first so that it's on bottom.
+ self._frame_update = tk.Frame(self._root)
+ # Pack must be in reverse order
+ self._frame_update.pack(side=tk.BOTTOM, anchor="e")
self._frame_train.pack(side=tk.BOTTOM, anchor="e")
self._frame_advanced_options.pack(side=tk.BOTTOM, anchor="e")
@@ -481,6 +486,8 @@ class _GUI(object):
)
self._widgets[_GUIWidgets.TRAIN].pack()
+ self._pack_update_button_if_update_is_available()
+
self._check_button_states()
def get_mrstft_fit(self) -> bool:
@@ -569,6 +576,93 @@ class _GUI(object):
self._wait_while_func(lambda resume: _UserMetadataGUI(resume, self))
+ def _pack_update_button(self, version_from: Version, version_to: Version):
+ """
+ Pack a button that a user can click to update
+ """
+
+ def update_nam():
+ result = subprocess.run(
+ [
+ f"{sys.executable}",
+ "-m",
+ "pip",
+ "install",
+ "--upgrade",
+ "neural-amp-modeler",
+ ]
+ )
+ if result.returncode == 0:
+ self._wait_while_func(
+ (lambda resume, *args, **kwargs: _OkModal(resume, *args, **kwargs)),
+ "Update complete! Restart NAM for changes to take effect.",
+ )
+ else:
+ self._wait_while_func(
+ (lambda resume, *args, **kwargs: _OkModal(resume, *args, **kwargs)),
+ "Update failed! See logs.",
+ )
+
+ self._widgets[_GUIWidgets.UPDATE] = tk.Button(
+ self._frame_update,
+ text=f"Update ({str(version_from)} -> {str(version_to)})",
+ width=_BUTTON_WIDTH,
+ height=_BUTTON_HEIGHT,
+ command=update_nam,
+ )
+ self._widgets[_GUIWidgets.UPDATE].pack()
+
+ def _pack_update_button_if_update_is_available(self):
+ class UpdateInfo(NamedTuple):
+ available: bool
+ current_version: Version
+ new_version: Optional[Version]
+
+ def get_info() -> UpdateInfo:
+ # TODO error handling
+ url = f"https://api.github.com/repos/sdatkinson/neural-amp-modeler/releases"
+ current_version = get_current_version()
+ try:
+ response = requests.get(url)
+ except requests.exceptions.ConnectionError:
+ print("WARNING: Failed to reach the server to check for updates")
+ return UpdateInfo(
+ available=False, current_version=current_version, new_version=None
+ )
+ if response.status_code != 200:
+ print(f"Failed to fetch releases. Status code: {response.status_code}")
+ return UpdateInfo(
+ available=False, current_version=current_version, new_version=None
+ )
+ else:
+ releases = response.json()
+ latest_version = None
+ if releases:
+ for release in releases:
+ tag = release["tag_name"]
+ if not tag.startswith("v"):
+ print(f"Found invalid version {tag}")
+ else:
+ this_version = Version.from_string(tag[1:])
+ if latest_version is None or this_version > latest_version:
+ latest_version = this_version
+ else:
+ print("No releases found for this repository.")
+ update_available = (
+ latest_version is not None and latest_version > current_version
+ )
+ return UpdateInfo(
+ available=update_available,
+ current_version=current_version,
+ new_version=latest_version,
+ )
+
+ update_info = get_info()
+ if update_info.available:
+ self._pack_update_button(
+ update_info.current_version, update_info.new_version
+ )
+
def _resume(self):
self._set_all_widget_states_to(tk.NORMAL)
self._check_button_states()
@@ -1092,6 +1186,7 @@ def run():
if _install_is_valid:
_gui = _GUI()
_gui.mainloop()
+ print("Shut down NAM trainer")
else:
_install_error()
diff --git a/tests/test_nam/test_train/test_gui/test_main.py b/tests/test_nam/test_train/test_gui/test_main.py
@@ -2,10 +2,7 @@
# Created Date: Friday May 24th 2024
# Author: Steven Atkinson (steven@atkinson.mn)
-import importlib
-import os
import tkinter as tk
-from pathlib import Path
import pytest