neural-amp-modeler

Neural network emulator for guitar amplifiers
Log | Files | Refs | README | LICENSE

commit 625aa8a622296d2b8d4378b6519f9dfe26c6d746
parent 26fdad726c47e0035c79a395634ec57b4c427f0c
Author: Steven Atkinson <steven@atkinson.mn>
Date:   Thu, 19 Sep 2024 08:29:12 -0700

[FEATURE] Automatically check for updates (#467)

* Button for update when available

* Rearrange methods, add shut down print statement

* Cleanup imports

* Function for getting current version

* Handle when there's no internet
Diffstat:
Mnam/train/_version.py | 13+++++++++++++
Mnam/train/gui/__init__.py | 101++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++---
Mtests/test_nam/test_train/test_gui/test_main.py | 3---
3 files changed, 111 insertions(+), 6 deletions(-)

diff --git a/nam/train/_version.py b/nam/train/_version.py @@ -6,6 +6,8 @@ Version utility """ +from .._version import __version__ + class Version: def __init__(self, major: int, minor: int, patch: int): @@ -13,6 +15,11 @@ class Version: self.minor = minor self.patch = patch + @classmethod + def from_string(cls, s: str): + major, minor, patch = [int(x) for x in s.split(".")] + return cls(major, minor, patch) + def __eq__(self, other) -> bool: return ( self.major == other.major @@ -21,6 +28,8 @@ class Version: ) def __lt__(self, other) -> bool: + if self == other: + return False if self.major != other.major: return self.major < other.major if self.minor != other.minor: @@ -33,3 +42,7 @@ class Version: PROTEUS_VERSION = Version(4, 0, 0) + + +def get_current_version() -> Version: + return Version.from_string(__version__) diff --git a/nam/train/gui/__init__.py b/nam/train/gui/__init__.py @@ -11,7 +11,9 @@ Usage: """ import re +import requests import tkinter as tk +import subprocess import sys import webbrowser from dataclasses import dataclass @@ -19,7 +21,7 @@ from enum import Enum from functools import partial from pathlib import Path from tkinter import filedialog -from typing import Callable, Dict, Optional, Sequence +from typing import Callable, Dict, NamedTuple, Optional, Sequence try: # 3rd-party and 1st-party imports import torch @@ -33,7 +35,7 @@ try: # 3rd-party and 1st-party imports # Ok private access here--this is technically allowed access from nam.train import metadata from nam.train._names import INPUT_BASENAMES, LATEST_VERSION - from nam.train.metadata import TRAINING_KEY + from nam.train._version import Version, get_current_version _install_is_valid = True _HAVE_ACCELERATOR = torch.cuda.is_available() or torch.backends.mps.is_available() @@ -384,6 +386,7 @@ class _GUIWidgets(Enum): METADATA = "metadata" ADVANCED_OPTIONS = "advanced_options" TRAIN = "train" + UPDATE = "update" class _GUI(object): @@ -446,7 +449,9 @@ class _GUI(object): # Last frames: avdanced options & train in the SE corner: self._frame_advanced_options = tk.Frame(self._root) self._frame_train = tk.Frame(self._root) - # Pack train first so that it's on bottom. + self._frame_update = tk.Frame(self._root) + # Pack must be in reverse order + self._frame_update.pack(side=tk.BOTTOM, anchor="e") self._frame_train.pack(side=tk.BOTTOM, anchor="e") self._frame_advanced_options.pack(side=tk.BOTTOM, anchor="e") @@ -481,6 +486,8 @@ class _GUI(object): ) self._widgets[_GUIWidgets.TRAIN].pack() + self._pack_update_button_if_update_is_available() + self._check_button_states() def get_mrstft_fit(self) -> bool: @@ -569,6 +576,93 @@ class _GUI(object): self._wait_while_func(lambda resume: _UserMetadataGUI(resume, self)) + def _pack_update_button(self, version_from: Version, version_to: Version): + """ + Pack a button that a user can click to update + """ + + def update_nam(): + result = subprocess.run( + [ + f"{sys.executable}", + "-m", + "pip", + "install", + "--upgrade", + "neural-amp-modeler", + ] + ) + if result.returncode == 0: + self._wait_while_func( + (lambda resume, *args, **kwargs: _OkModal(resume, *args, **kwargs)), + "Update complete! Restart NAM for changes to take effect.", + ) + else: + self._wait_while_func( + (lambda resume, *args, **kwargs: _OkModal(resume, *args, **kwargs)), + "Update failed! See logs.", + ) + + self._widgets[_GUIWidgets.UPDATE] = tk.Button( + self._frame_update, + text=f"Update ({str(version_from)} -> {str(version_to)})", + width=_BUTTON_WIDTH, + height=_BUTTON_HEIGHT, + command=update_nam, + ) + self._widgets[_GUIWidgets.UPDATE].pack() + + def _pack_update_button_if_update_is_available(self): + class UpdateInfo(NamedTuple): + available: bool + current_version: Version + new_version: Optional[Version] + + def get_info() -> UpdateInfo: + # TODO error handling + url = f"https://api.github.com/repos/sdatkinson/neural-amp-modeler/releases" + current_version = get_current_version() + try: + response = requests.get(url) + except requests.exceptions.ConnectionError: + print("WARNING: Failed to reach the server to check for updates") + return UpdateInfo( + available=False, current_version=current_version, new_version=None + ) + if response.status_code != 200: + print(f"Failed to fetch releases. Status code: {response.status_code}") + return UpdateInfo( + available=False, current_version=current_version, new_version=None + ) + else: + releases = response.json() + latest_version = None + if releases: + for release in releases: + tag = release["tag_name"] + if not tag.startswith("v"): + print(f"Found invalid version {tag}") + else: + this_version = Version.from_string(tag[1:]) + if latest_version is None or this_version > latest_version: + latest_version = this_version + else: + print("No releases found for this repository.") + update_available = ( + latest_version is not None and latest_version > current_version + ) + return UpdateInfo( + available=update_available, + current_version=current_version, + new_version=latest_version, + ) + + update_info = get_info() + if update_info.available: + self._pack_update_button( + update_info.current_version, update_info.new_version + ) + def _resume(self): self._set_all_widget_states_to(tk.NORMAL) self._check_button_states() @@ -1092,6 +1186,7 @@ def run(): if _install_is_valid: _gui = _GUI() _gui.mainloop() + print("Shut down NAM trainer") else: _install_error() diff --git a/tests/test_nam/test_train/test_gui/test_main.py b/tests/test_nam/test_train/test_gui/test_main.py @@ -2,10 +2,7 @@ # Created Date: Friday May 24th 2024 # Author: Steven Atkinson (steven@atkinson.mn) -import importlib -import os import tkinter as tk -from pathlib import Path import pytest