neural-amp-modeler

Neural network emulator for guitar amplifiers
Log | Files | Refs | README | LICENSE

commit f2c3ff91c94bd06b19ed3e0052ade1e4dea02391
parent 9a1c72e8c4559f278cc1bb616e6543afe6f55aca
Author: Steven Atkinson <steven@atkinson.mn>
Date:   Sun, 24 Nov 2024 13:18:21 -0800

Better messaging around PyTorch workaround (#509)

Priority is compatibility w/ PyTorch; left a link for folks to learn more if they want to customize.
Diffstat:
M.github/workflows/python-package.yml | 2+-
Denvironment_cpu.yml | 34----------------------------------
Aenvironments/environment_cpu_apple.yml | 38++++++++++++++++++++++++++++++++++++++
Renvironment_gpu.yml -> environments/environment_gpu.yml | 0
Aenvironments/requirements.txt | 25+++++++++++++++++++++++++
Mnam/models/_base.py | 11+++++++++++
Mnam/models/base.py | 2+-
Drequirements.txt | 24------------------------
8 files changed, 76 insertions(+), 60 deletions(-)

diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml @@ -28,7 +28,7 @@ jobs: run: | python -m pip install --upgrade pip python -m pip install flake8 pytest - if [ -f requirements.txt ]; then pip install -r requirements.txt; fi + if [ -f environments/requirements.txt ]; then pip install -r environments/requirements.txt; fi python -m pip install . - name: Lint with flake8 run: | diff --git a/environment_cpu.yml b/environment_cpu.yml @@ -1,34 +0,0 @@ -# File: environment.yml -# Created Date: Saturday February 13th 2021 -# Author: Steven Atkinson (steven@atkinson.mn) - -name: nam -channels: - - conda-forge # pytest-mock - - pytorch -dependencies: - - python>=3.9 - - black - - flake8 - - h5py - - jupyter - - matplotlib - - numpy<2 - - pip - - pre-commit - - pydantic - - pytest - - pytest-mock - - pytorch - - scipy - - semver - - tensorboard - - tqdm - - wheel - - pip: - - auraloss==0.3.0 - - pytorch_lightning - - sounddevice - - transformers>=4 # See requirements.txt - - wavio >=0.0.5 - - -e . diff --git a/environments/environment_cpu_apple.yml b/environments/environment_cpu_apple.yml @@ -0,0 +1,38 @@ +# File: environment.yml +# Created Date: Saturday February 13th 2021 +# Author: Steven Atkinson (steven@atkinson.mn) + +# Environment for CPU and macOS (Intel and Apple Silicon) + +name: nam +channels: + - conda-forge # pytest-mock + - pytorch +dependencies: + - python>=3.9 + - black + - flake8 + - h5py + - jupyter + - matplotlib + - numpy<2 + - pip + - pre-commit + - pydantic + - pytest + - pytest-mock + # Performance note: + # https://github.com/sdatkinson/neural-amp-modeler/issues/505 + - pytorch + - scipy + - semver + - tensorboard + - tqdm + - wheel + - pip: + - auraloss==0.3.0 + - pytorch_lightning + - sounddevice + - transformers>=4 # See requirements.txt + - wavio >=0.0.5 + - -e . diff --git a/environment_gpu.yml b/environments/environment_gpu.yml diff --git a/environments/requirements.txt b/environments/requirements.txt @@ -0,0 +1,25 @@ +# File: requirements.txt +# Created Date: 2021-01-24 +# Author: Steven Atkinson (steven@atkinson.mn) + +auraloss==0.3.0 # 0.4.0 changes API for MRSTFT loss +black +flake8 +matplotlib +numpy<2 +pip +pre-commit +pydantic>=2.0.0 +pytest +pytest-mock +pytorch_lightning +scipy +sounddevice +# Performance note: https://github.com/sdatkinson/neural-amp-modeler/issues/505 +torch +# `transformers` is not required, but if you have it, it needs to be recent +# enough so I'm adding it. +transformers>=4 +tqdm +wavio +wheel diff --git a/nam/models/_base.py b/nam/models/_base.py @@ -209,6 +209,17 @@ class BaseNet(_Base): return self._forward(x, **kwargs) except NotImplementedError as e: if "Output channels > 65536 not supported at the MPS device." in str(e): + print( + "===WARNING===\n" + "NAM encountered a bug in PyTorch's MPS backend and will " + "switch to a fallback.\n" + f"Your version of PyTorch is {torch.__version__}.\n" + "Please report this in an Issue at:\n" + "https://github.com/sdatkinson/neural-amp-modeler/issues/new/choose" + "\n" + "so that NAM's dependencies can avoid buggy versions of " + "PyTorch and the associated performance hit." + ) self._mps_65536_fallback = True return self._forward_mps_safe(x, **kwargs) else: diff --git a/nam/models/base.py b/nam/models/base.py @@ -3,7 +3,7 @@ # Author: Steven Atkinson (steven@atkinson.mn) """ -Implements the base PyTorch Lightning model. +Implements the base PyTorch Lightning module. This is meant to combine an actual model (subclassed from `._base.BaseNet`) along with loss function boilerplate. diff --git a/requirements.txt b/requirements.txt @@ -1,24 +0,0 @@ -# File: requirements.txt -# Created Date: 2021-01-24 -# Author: Steven Atkinson (steven@atkinson.mn) - -auraloss==0.3.0 # 0.4.0 changes API for MRSTFT loss -black -flake8 -matplotlib -numpy<2 -pip -pre-commit -pydantic>=2.0.0 -pytest -pytest-mock -pytorch_lightning -scipy -sounddevice -torch -# Not required, but if you have it, it needs to be recent enough so I'm adding -# it. -transformers>=4 -tqdm -wavio -wheel