commit 9a1c72e8c4559f278cc1bb616e6543afe6f55aca
parent 067077812a0b23f77090bb64d75bedeb7e7f0a2b
Author: Steven Atkinson <steven@atkinson.mn>
Date: Sat, 23 Nov 2024 20:40:02 -0800
[BUGFIX] Workaround for PyTorch MPS bug with sequences longer than 65,536 samples (#506)
* Add testing to show failure cases for models using convolutions
* Fix imports
* Remove unused imports
* Fix
* Fix bug
* Remove debug statements
* Reason
* Skip condition
* Fix bug: decorator
Diffstat:
6 files changed, 131 insertions(+), 52 deletions(-)
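The heart of the fix is the `_forward_mps_safe` wrapper added to `BaseNet` below: the first time the MPS backend raises its "Output channels > 65536 not supported" NotImplementedError, the model sets `self._mps_65536_fallback` and retries, from then on processing the (already padded) input in overlapping windows of at most 65,536 samples and concatenating the results. A minimal, self-contained sketch of the same windowing arithmetic follows; the receptive field value and the plain Conv1d are made up for illustration (the real code uses `self.receptive_field` and dispatches to `self._forward`):

    import torch
    import torch.nn as nn

    RECEPTIVE_FIELD = 8_192   # illustrative value only
    MAX_LEN = 65_536          # longest segment the MPS conv kernels accept

    conv = nn.Conv1d(1, 1, kernel_size=RECEPTIVE_FIELD)

    def chunked_forward(x: torch.Tensor) -> torch.Tensor:
        # x: (batch, length), already left-padded by (RECEPTIVE_FIELD - 1) samples.
        # A segment of length L yields L - (RECEPTIVE_FIELD - 1) output samples, so
        # stepping by MAX_LEN - (RECEPTIVE_FIELD - 1) tiles the output with no gaps
        # or overlaps, and the final segment always retains enough history.
        stride = MAX_LEN - (RECEPTIVE_FIELD - 1)
        out_list = []
        for i in range(0, x.shape[1], stride):
            j = min(i + MAX_LEN, x.shape[1])
            out_list.append(conv(x[:, None, i:j])[:, 0])
            if j == x.shape[1]:
                break
        return torch.cat(out_list, dim=1)

    y = chunked_forward(torch.zeros((1, 100_000)))
    assert y.shape == (1, 100_000 - (RECEPTIVE_FIELD - 1))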
diff --git a/nam/models/_base.py b/nam/models/_base.py
@@ -170,6 +170,10 @@ class _Base(nn.Module, InitializableFromConfig, Exportable):
class BaseNet(_Base):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._mps_65536_fallback = False
+
def forward(self, x: torch.Tensor, pad_start: Optional[bool] = None, **kwargs):
pad_start = self.pad_start_default if pad_start is None else pad_start
scalar = x.ndim == 1
@@ -179,16 +183,53 @@ class BaseNet(_Base):
x = torch.cat(
(torch.zeros((len(x), self.receptive_field - 1)).to(x.device), x), dim=1
)
- y = self._forward(x, **kwargs)
+ if x.shape[1] < self.receptive_field:
+ raise ValueError(
+ f"Input has {x.shape[1]} samples, which is too few for this model with "
+ f"receptive field {self.receptive_field}!"
+ )
+ y = self._forward_mps_safe(x, **kwargs)
if scalar:
y = y[0]
return y
def _at_nominal_settings(self, x: torch.Tensor) -> torch.Tensor:
return self(x)
+
+ def _forward_mps_safe(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
+ """
+ Wrap `._forward()` to protect against MPS-unsupported input lengths
+ beyond 65,536 samples.
+
+ Check this again when PyTorch 2.5.2 is released--hopefully it's fixed
+ then.
+ """
+ if not self._mps_65536_fallback:
+ try:
+ return self._forward(x, **kwargs)
+ except NotImplementedError as e:
+ if "Output channels > 65536 not supported at the MPS device." in str(e):
+ self._mps_65536_fallback = True
+ return self._forward_mps_safe(x, **kwargs)
+ else:
+ raise e
+ else:
+ # Stitch together the output one segment at a time to avoid the MPS error.
+ # The stride leaves an overlap of (receptive_field - 1) samples so that every
+ # segment, including the last, carries the history the receptive field needs.
+ stride = 65_536 - (self.receptive_field - 1)
+ out_list = []
+ for i in range(0, x.shape[1], stride):
+ j = min(i + 65_536, x.shape[1])
+ xi = x[:, i:j]
+ out_list.append(self._forward(xi, **kwargs))
+ # A bit hacky, but correct: stop once the final segment reaches the end of the input.
+ if j == x.shape[1]:
+ break
+ return torch.cat(out_list, dim=1)
+
@abc.abstractmethod
- def _forward(self, x: torch.Tensor) -> torch.Tensor:
+ def _forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
"""
The true forward method.
diff --git a/nam/models/linear.py b/nam/models/linear.py
@@ -6,9 +6,6 @@
Linear model
"""
-import json
-from pathlib import Path
-
import numpy as np
import torch
import torch.nn as nn
@@ -30,38 +27,6 @@ class Linear(BaseNet):
def receptive_field(self) -> int:
return self._net.weight.shape[2]
- def export(self, outdir: Path):
- training = self.training
- self.eval()
- with open(Path(outdir, "config.json"), "w") as fp:
- json.dump(
- {
- "version": __version__,
- "architecture": self.__class__.__name__,
- "config": {
- "receptive_field": self.receptive_field,
- "bias": self._bias,
- },
- },
- fp,
- indent=4,
- )
-
- params = [self._net.weight.flatten()]
- if self._bias:
- params.append(self._net.bias.flatten())
- params = torch.cat(params).detach().cpu().numpy()
- # Hope I don't regret using np.save...
- np.save(Path(outdir, "weights.npy"), params)
-
- # And an input/output to verify correct computation:
- x, y = self._export_input_output()
- np.save(Path(outdir, "input.npy"), x.detach().cpu().numpy())
- np.save(Path(outdir, "output.npy"), y.detach().cpu().numpy())
-
- # And resume training state
- self.train(training)
-
def export_cpp_header(self):
raise NotImplementedError()
@@ -73,7 +38,14 @@ class Linear(BaseNet):
return self._net(x[:, None])[:, 0]
def _export_config(self):
- raise NotImplementedError()
+ return {
+ "receptive_field": self.receptive_field,
+ "bias": self._bias,
+ }
def _export_weights(self) -> np.ndarray:
- raise NotImplementedError()
+ params_list = [self._net.weight.flatten()]
+ if self._bias:
+ params_list.append(self._net.bias.flatten())
+ params = torch.cat(params_list).detach().cpu().numpy()
+ return params
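With the hand-rolled `export` override removed, `Linear` relies on the shared export flow and only supplies its config dict and flattened weights. A quick sanity-check sketch; the constructor kwargs are taken from the new test_linear.py below, and the single-channel kernel shape is an assumption, since neither is spelled out in this diff:

    from nam.models.linear import Linear

    model = Linear(receptive_field=2, sample_rate=44100)
    config = model._export_config()
    weights = model._export_weights()

    # The (assumed 1-in, 1-out) conv kernel flattens to `receptive_field` values;
    # a bias term, when present, appends one more.
    assert config["receptive_field"] == 2
    assert weights.shape == (2 + int(config["bias"]),)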
diff --git a/tests/test_nam/test_models/_convolutional.py b/tests/test_nam/test_models/_convolutional.py
@@ -0,0 +1,30 @@
+# File: _convolutional.py
+# Created Date: Saturday November 23rd 2024
+# Author: Steven Atkinson (steven@atkinson.mn)
+
+"""
+Mix-in tests for models with a convolution layer
+"""
+
+import pytest as _pytest
+import torch as _torch
+
+from .base import Base as _Base
+
+
+class Convolutional(_Base):
+ @_pytest.mark.skipif(not _torch.backends.mps.is_available(), reason="MPS-specific test")
+ def test_process_input_longer_than_65536(self):
+ """
+ Processing inputs longer than 65,536 samples using the MPS backend can
+ cause problems.
+
+ See: https://github.com/sdatkinson/neural-amp-modeler/issues/505
+
+ Assert that precautions are taken.
+ """
+
+ x = _torch.zeros((65_536 + 1,)).to("mps")
+
+ model = self._construct().to("mps")
+ model(x)
diff --git a/tests/test_nam/test_models/test_conv_net.py b/tests/test_nam/test_models/test_conv_net.py
@@ -2,17 +2,14 @@
# Created Date: Friday May 6th 2022
# Author: Steven Atkinson (steven@atkinson.mn)
-from pathlib import Path
-from tempfile import TemporaryDirectory
-
-import pytest
+import pytest as _pytest
from nam.models import conv_net
-from .base import Base
+from ._convolutional import Convolutional as _Convolutional
-class TestConvNet(Base):
+class TestConvNet(_Convolutional):
@classmethod
def setup_class(cls):
channels = 3
@@ -23,13 +20,13 @@ class TestConvNet(Base):
{"batchnorm": False, "activation": "Tanh"},
)
- @pytest.mark.parametrize(
+ @_pytest.mark.parametrize(
("batchnorm,activation"), ((False, "ReLU"), (True, "Tanh"))
)
def test_init(self, batchnorm, activation):
super().test_init(kwargs={"batchnorm": batchnorm, "activation": activation})
- @pytest.mark.parametrize(
+ @_pytest.mark.parametrize(
("batchnorm,activation"), ((False, "ReLU"), (True, "Tanh"))
)
def test_export(self, batchnorm, activation):
@@ -37,4 +34,4 @@ class TestConvNet(Base):
if __name__ == "__main__":
- pytest.main()
+ _pytest.main()
diff --git a/tests/test_nam/test_models/test_linear.py b/tests/test_nam/test_models/test_linear.py
@@ -0,0 +1,18 @@
+# File: test_linear.py
+# Created Date: Saturday November 23rd 2024
+# Author: Steven Atkinson (steven@atkinson.mn)
+
+import pytest as _pytest
+
+from nam.models import linear as _linear
+
+from ._convolutional import Convolutional as _Convolutional
+
+
+class TestLinear(_Convolutional):
+ @classmethod
+ def setup_class(cls):
+ C = _linear.Linear
+ args = ()
+ kwargs = {"receptive_field": 2, "sample_rate": 44100}
+ super().setup_class(C, args, kwargs)
diff --git a/tests/test_nam/test_models/test_wavenet.py b/tests/test_nam/test_models/test_wavenet.py
@@ -8,11 +8,28 @@ import torch
from nam.models.wavenet import WaveNet
from nam.train.core import Architecture, get_wavenet_config
+from ._convolutional import Convolutional as _Convolutional
+
+
+class TestWaveNet(_Convolutional):
+ @classmethod
+ def setup_class(cls):
+ C = WaveNet
+ args = ()
+ kwargs = {
+ "layers_configs": [
+ {
+ "input_size": 1,
+ "condition_size": 1,
+ "head_size": 1,
+ "channels": 1,
+ "kernel_size": 1,
+ "dilations": [1]
+ }
+ ]
+ }
+ super().setup_class(C, args, kwargs)
-# from .base import Base
-
-
-class TestWaveNet(object):
def test_import_weights(self):
config = get_wavenet_config(Architecture.FEATHER)
model_1 = WaveNet.init_from_config(config)
@@ -29,3 +46,7 @@ class TestWaveNet(object):
assert not torch.allclose(y2_before, y1)
assert torch.allclose(y2_after, y1)
+
+
+if __name__ == "__main__":
+ pytest.main()
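For a rough end-to-end check of the workaround on an Apple-silicon machine, something along these lines should exercise the new code path; the layer config mirrors the WaveNet test above, and the input length and device handling are illustrative:

    import torch
    from nam.models.wavenet import WaveNet

    if torch.backends.mps.is_available():
        model = WaveNet(
            layers_configs=[
                {
                    "input_size": 1,
                    "condition_size": 1,
                    "head_size": 1,
                    "channels": 1,
                    "kernel_size": 1,
                    "dilations": [1],
                }
            ]
        ).to("mps")
        # Inputs longer than 65,536 samples used to raise NotImplementedError on
        # MPS; BaseNet now falls back to chunked processing when that error appears.
        y = model(torch.zeros(100_000).to("mps"))
        print(y.shape)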