neural-amp-modeler

Neural network emulator for guitar amplifiers
Log | Files | Refs | README | LICENSE

cli.py (3829B)


      1 # File: cli.py
      2 # Created Date: Saturday July 27th 2024
      3 # Author: Steven Atkinson (steven@atkinson.mn)
      4 
      5 """
      6 Command line interface entry points (GUI trainer, full trainer)
      7 """
      8 
      9 
     10 # This must happen first
     11 def _ensure_graceful_shutdowns():
     12     """
     13     Hack to recover graceful shutdowns in Windows.
     14     This has to happen ASAP
     15     See:
     16     https://github.com/sdatkinson/neural-amp-modeler/issues/105
     17     https://stackoverflow.com/a/44822794
     18     """
     19     import os
     20 
     21     if os.name == "nt":  # OS is Windows
     22         os.environ["FOR_DISABLE_CONSOLE_CTRL_HANDLER"] = "1"
     23 
     24 
     25 _ensure_graceful_shutdowns()
     26 
     27 
     28 # This must happen ASAP but not before the graceful shutdown hack
     29 def _apply_extensions():
     30     """
     31     Find and apply extensions to NAM
     32     """
     33 
     34     def removesuffix(s: str, suffix: str) -> str:
     35         # Remove once 3.8 is dropped
     36         if len(suffix) == 0:
     37             return s
     38         return s[: -len(suffix)] if s.endswith(suffix) else s
     39 
     40     import importlib
     41     import os
     42     import sys
     43 
     44     # DRY: Make sure this matches the test!
     45     home_path = os.environ["HOMEPATH"] if os.name == "nt" else os.environ["HOME"]
     46     extensions_path = os.path.join(home_path, ".neural-amp-modeler", "extensions")
     47     if not os.path.exists(extensions_path):
     48         return
     49     if not os.path.isdir(extensions_path):
     50         print(
     51             f"WARNING: non-directory object found at expected extensions path {extensions_path}; skip"
     52         )
     53     print("Applying extensions...")
     54     if extensions_path not in sys.path:
     55         sys.path.append(extensions_path)
     56         extensions_path_not_in_sys_path = True
     57     else:
     58         extensions_path_not_in_sys_path = False
     59     for name in os.listdir(extensions_path):
     60         if name in {"__pycache__", ".DS_Store"}:
     61             continue
     62         try:
     63             importlib.import_module(removesuffix(name, ".py"))  # Runs it
     64             print(f"  {name} [SUCCESS]")
     65         except Exception as e:
     66             print(f"  {name} [FAILED]")
     67             print(e)
     68     if extensions_path_not_in_sys_path:
     69         for i, p in enumerate(sys.path):
     70             if p == extensions_path:
     71                 sys.path = sys.path[:i] + sys.path[i + 1 :]
     72                 break
     73         else:
     74             raise RuntimeError("Failed to remove extensions path from sys.path?")
     75     print("Done!")
     76 
     77 
     78 _apply_extensions()
     79 
     80 import json as _json
     81 from argparse import ArgumentParser as _ArgumentParser
     82 from pathlib import Path as _Path
     83 
     84 from nam.train.full import main as _nam_full
     85 from nam.train.gui import run as nam_gui  # noqa F401 Used as an entry point
     86 from nam.util import timestamp as _timestamp
     87 
     88 
     89 def nam_hello_world():
     90     """
     91     This is a minimal CLI entry point that's meant to be used to ensure that NAM
     92     was installed successfully
     93     """
     94     from nam import __version__
     95     msg = f"""
     96     Neural Amp Modeler
     97 
     98     by Steven Atkinson
     99 
    100     Version {__version__}
    101     """
    102     print(msg)
    103 
    104 
    105 def nam_full():
    106     parser = _ArgumentParser()
    107     parser.add_argument("data_config_path", type=str)
    108     parser.add_argument("model_config_path", type=str)
    109     parser.add_argument("learning_config_path", type=str)
    110     parser.add_argument("outdir")
    111     parser.add_argument("--no-show", action="store_true", help="Don't show plots")
    112 
    113     args = parser.parse_args()
    114 
    115     def ensure_outdir(outdir: str) -> _Path:
    116         outdir = _Path(outdir, _timestamp())
    117         outdir.mkdir(parents=True, exist_ok=False)
    118         return outdir
    119 
    120     outdir = ensure_outdir(args.outdir)
    121     # Read
    122     with open(args.data_config_path, "r") as fp:
    123         data_config = _json.load(fp)
    124     with open(args.model_config_path, "r") as fp:
    125         model_config = _json.load(fp)
    126     with open(args.learning_config_path, "r") as fp:
    127         learning_config = _json.load(fp)
    128     _nam_full(data_config, model_config, learning_config, outdir, args.no_show)