commit f2c3ff91c94bd06b19ed3e0052ade1e4dea02391
parent 9a1c72e8c4559f278cc1bb616e6543afe6f55aca
Author: Steven Atkinson <steven@atkinson.mn>
Date: Sun, 24 Nov 2024 13:18:21 -0800
Better messaging around PyTorch workaround (#509)
Priority is compatibility w/ PyTorch; left a link for folks to learn more if they want to customize.
Diffstat:
8 files changed, 76 insertions(+), 60 deletions(-)
diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml
@@ -28,7 +28,7 @@ jobs:
run: |
python -m pip install --upgrade pip
python -m pip install flake8 pytest
- if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
+ if [ -f environments/requirements.txt ]; then pip install -r environments/requirements.txt; fi
python -m pip install .
- name: Lint with flake8
run: |
diff --git a/environment_cpu.yml b/environment_cpu.yml
@@ -1,34 +0,0 @@
-# File: environment.yml
-# Created Date: Saturday February 13th 2021
-# Author: Steven Atkinson (steven@atkinson.mn)
-
-name: nam
-channels:
- - conda-forge # pytest-mock
- - pytorch
-dependencies:
- - python>=3.9
- - black
- - flake8
- - h5py
- - jupyter
- - matplotlib
- - numpy<2
- - pip
- - pre-commit
- - pydantic
- - pytest
- - pytest-mock
- - pytorch
- - scipy
- - semver
- - tensorboard
- - tqdm
- - wheel
- - pip:
- - auraloss==0.3.0
- - pytorch_lightning
- - sounddevice
- - transformers>=4 # See requirements.txt
- - wavio >=0.0.5
- - -e .
diff --git a/environments/environment_cpu_apple.yml b/environments/environment_cpu_apple.yml
@@ -0,0 +1,38 @@
+# File: environment.yml
+# Created Date: Saturday February 13th 2021
+# Author: Steven Atkinson (steven@atkinson.mn)
+
+# Environment for CPU and macOS (Intel and Apple Silicon)
+
+name: nam
+channels:
+ - conda-forge # pytest-mock
+ - pytorch
+dependencies:
+ - python>=3.9
+ - black
+ - flake8
+ - h5py
+ - jupyter
+ - matplotlib
+ - numpy<2
+ - pip
+ - pre-commit
+ - pydantic
+ - pytest
+ - pytest-mock
+ # Performance note:
+ # https://github.com/sdatkinson/neural-amp-modeler/issues/505
+ - pytorch
+ - scipy
+ - semver
+ - tensorboard
+ - tqdm
+ - wheel
+ - pip:
+ - auraloss==0.3.0
+ - pytorch_lightning
+ - sounddevice
+ - transformers>=4 # See requirements.txt
+ - wavio >=0.0.5
+ - -e .
diff --git a/environment_gpu.yml b/environments/environment_gpu.yml
diff --git a/environments/requirements.txt b/environments/requirements.txt
@@ -0,0 +1,25 @@
+# File: requirements.txt
+# Created Date: 2021-01-24
+# Author: Steven Atkinson (steven@atkinson.mn)
+
+auraloss==0.3.0 # 0.4.0 changes API for MRSTFT loss
+black
+flake8
+matplotlib
+numpy<2
+pip
+pre-commit
+pydantic>=2.0.0
+pytest
+pytest-mock
+pytorch_lightning
+scipy
+sounddevice
+# Performance note: https://github.com/sdatkinson/neural-amp-modeler/issues/505
+torch
+# `transformers` is not required, but if you have it, it needs to be recent
+# enough so I'm adding it.
+transformers>=4
+tqdm
+wavio
+wheel
diff --git a/nam/models/_base.py b/nam/models/_base.py
@@ -209,6 +209,17 @@ class BaseNet(_Base):
return self._forward(x, **kwargs)
except NotImplementedError as e:
if "Output channels > 65536 not supported at the MPS device." in str(e):
+ print(
+ "===WARNING===\n"
+ "NAM encountered a bug in PyTorch's MPS backend and will "
+ "switch to a fallback.\n"
+ f"Your version of PyTorch is {torch.__version__}.\n"
+ "Please report this in an Issue at:\n"
+ "https://github.com/sdatkinson/neural-amp-modeler/issues/new/choose"
+ "\n"
+ "so that NAM's dependencies can avoid buggy versions of "
+ "PyTorch and the associated performance hit."
+ )
self._mps_65536_fallback = True
return self._forward_mps_safe(x, **kwargs)
else:
diff --git a/nam/models/base.py b/nam/models/base.py
@@ -3,7 +3,7 @@
# Author: Steven Atkinson (steven@atkinson.mn)
"""
-Implements the base PyTorch Lightning model.
+Implements the base PyTorch Lightning module.
This is meant to combine an actual model (subclassed from `._base.BaseNet`)
along with loss function boilerplate.
diff --git a/requirements.txt b/requirements.txt
@@ -1,24 +0,0 @@
-# File: requirements.txt
-# Created Date: 2021-01-24
-# Author: Steven Atkinson (steven@atkinson.mn)
-
-auraloss==0.3.0 # 0.4.0 changes API for MRSTFT loss
-black
-flake8
-matplotlib
-numpy<2
-pip
-pre-commit
-pydantic>=2.0.0
-pytest
-pytest-mock
-pytorch_lightning
-scipy
-sounddevice
-torch
-# Not required, but if you have it, it needs to be recent enough so I'm adding
-# it.
-transformers>=4
-tqdm
-wavio
-wheel