commit 6ede731da4cca4ee2bea76d6d0b38b925d98df93
parent 66c052b3e57169dc0e107e069f00cf987467db40
Author: Steven Atkinson <steven@atkinson.mn>
Date: Sun, 21 Jul 2024 13:08:07 -0700
Extensions (#441)
* Implement extensions for GUI trainer
* Test
* mkdir
* removesuffix
Diffstat:
2 files changed, 115 insertions(+), 1 deletion(-)
diff --git a/nam/train/gui/__init__.py b/nam/train/gui/__init__.py
@@ -25,6 +25,56 @@ def _ensure_graceful_shutdowns():
_ensure_graceful_shutdowns()
+
+def _apply_extensions():
+ def removesuffix(s: str, suffix: str) -> str:
+ # Remove once 3.8 is dropped
+ if len(suffix) == 0:
+ return s
+ return s[: -len(suffix)] if s.endswith(suffix) else s
+
+ import importlib
+ import os
+ import sys
+
+ # DRY: Make sure this matches the test!
+ extensions_path = os.path.join(
+ os.environ["HOME"], ".neural-amp-modeler", "extensions"
+ )
+ if not os.path.exists(extensions_path):
+ return
+ if not os.path.isdir(extensions_path):
+ print(
+ f"WARNING: non-directory object found at expected extensions path {extensions_path}; skip"
+ )
+ print("Applying extensions...")
+ if extensions_path not in sys.path:
+ sys.path.append(extensions_path)
+ extensions_path_not_in_sys_path = True
+ else:
+ extensions_path_not_in_sys_path = False
+ for name in os.listdir(extensions_path):
+ if name in {"__pycache__", ".DS_Store"}:
+ continue
+ try:
+ importlib.import_module(removesuffix(name, ".py")) # Runs it
+ print(f" {name} [SUCCESS]")
+ except Exception as e:
+ print(f" {name} [FAILED]")
+ print(e)
+ if extensions_path_not_in_sys_path:
+ for i, p in enumerate(sys.path):
+ if p == extensions_path:
+ sys.path = sys.path[:i] + sys.path[i + 1 :]
+ break
+ else:
+ raise RuntimeError("Failed to remove extensions path from sys.path?")
+ print("Done!")
+
+
+_apply_extensions()
+
+
import re
import tkinter as tk
import sys
@@ -677,7 +727,9 @@ class _GUI(object):
File and explain what's wrong with it.
"""
# TODO put this closer to what it looks at, i.e. core.DataValidationOutput
- msg = f"\t{Path(output_path).name}:\n" # They all have the same directory so
+ msg = (
+ f"\t{Path(output_path).name}:\n" # They all have the same directory so
+ )
if validation_output.latency.manual is None:
if validation_output.latency.calibration.warnings.matches_lookahead:
msg += (
diff --git a/tests/test_nam/test_train/test_gui.py b/tests/test_nam/test_train/test_gui.py
@@ -2,7 +2,10 @@
# Created Date: Friday May 24th 2024
# Author: Steven Atkinson (steven@atkinson.mn)
+import importlib
+import os
import tkinter as tk
+from pathlib import Path
import pytest
@@ -19,5 +22,64 @@ class TestPathButton(object):
label.pack()
+def test_extensions():
+ """
+ Test that we can use a simple extension.
+ """
+ # DRY: Make sure this matches the code!
+ extensions_path = Path(
+ os.path.join(os.environ["HOME"], ".neural-amp-modeler", "extensions")
+ )
+
+ def get_name():
+ i = 0
+ while True:
+ basename = f"test_extension_{i}.py"
+ path = Path(extensions_path, basename)
+ if not path.exists():
+ return path
+ else:
+ i += 1
+
+ path = get_name()
+ path.parent.mkdir(parents=True, exist_ok=True)
+
+ try:
+ # Make the extension
+ # It's going to set an attribute inside nam.core. We'll know the extension worked if
+ # that attr is set.
+ attr_name = "my_test_attr"
+ attr_val = "THIS IS A TEST ATTRIBUTE I SHOULDN'T BE HERE"
+ with open(path, "w") as f:
+ f.writelines(
+ [
+ 'print("RUNNING TEST!")\n',
+ "from nam.train import core\n",
+ f'name = "{attr_name}"\n',
+ "assert not hasattr(core, name)\n"
+ f'setattr(core, name, "{attr_val}")\n',
+ ]
+ )
+
+ # Now trigger the extension by importing from the GUI module:
+ from nam.train import gui # noqa F401
+
+ # If some other test already imported this, then we need to trigger a re-load or
+ # else the extension won't get picked up!
+ importlib.reload(gui)
+
+ # Now let's have a look:
+ from nam.train import core
+
+ assert hasattr(core, attr_name)
+ assert getattr(core, attr_name) == attr_val
+ finally:
+ if path.exists():
+ path.unlink()
+ # You might want to comment that .unlink() and uncomment this if this test isn't
+ # passing and you're struggling:
+ # pass
+
+
if __name__ == "__main__":
pytest.main()