neural-amp-modeler

Neural network emulator for guitar amplifiers

commit 9a1c72e8c4559f278cc1bb616e6543afe6f55aca
parent 067077812a0b23f77090bb64d75bedeb7e7f0a2b
Author: Steven Atkinson <steven@atkinson.mn>
Date:   Sat, 23 Nov 2024 20:40:02 -0800

[BUGFIX] Workaround for PyTorch MPS bug with sequences longer than 65,536 samples (#506)

* Add testing to show failure cases for models using convolutions

* Fix imports

* Remove unused imports

* Fix

* Fix bug

* Remove debug statements

* Reason

* Skip condition

* Fix bug: decorator
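
The failure being worked around is a limitation in PyTorch's MPS convolution kernels: on affected builds, a convolution forward pass over more than 65,536 samples raises a NotImplementedError (see issue #505, linked from the new test). A hypothetical minimal reproduction, assuming an Apple Silicon machine and an affected PyTorch release (the exact trigger may vary by version):

    import torch

    # Any conv-based model hits this once its input grows past 65,536 samples.
    conv = torch.nn.Conv1d(1, 1, kernel_size=3).to("mps")
    x = torch.zeros((1, 1, 65_536 + 1), device="mps")
    # On affected builds this raises:
    #   NotImplementedError: Output channels > 65536 not supported at the MPS device.
    conv(x)
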
Diffstat:
M nam/models/_base.py                          | 45 +++++++++++++++++++++++++++++++++++++++++++++--
M nam/models/linear.py                         | 46 +++++++++-------------------------------------
A tests/test_nam/test_models/_convolutional.py | 30 ++++++++++++++++++++++++++++++
M tests/test_nam/test_models/test_conv_net.py  | 15 ++++++---------
A tests/test_nam/test_models/test_linear.py    | 18 ++++++++++++++++++
M tests/test_nam/test_models/test_wavenet.py   | 29 +++++++++++++++++++++++++----
6 files changed, 131 insertions(+), 52 deletions(-)

diff --git a/nam/models/_base.py b/nam/models/_base.py
@@ -170,6 +170,10 @@ class _Base(nn.Module, InitializableFromConfig, Exportable):
 
 
 class BaseNet(_Base):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._mps_65536_fallback = False
+
     def forward(self, x: torch.Tensor, pad_start: Optional[bool] = None, **kwargs):
         pad_start = self.pad_start_default if pad_start is None else pad_start
         scalar = x.ndim == 1
@@ -179,16 +183,53 @@ class BaseNet(_Base):
             x = torch.cat(
                 (torch.zeros((len(x), self.receptive_field - 1)).to(x.device), x),
                 dim=1,
             )
-        y = self._forward(x, **kwargs)
+        if x.shape[1] < self.receptive_field:
+            raise ValueError(
+                f"Input has {x.shape[1]} samples, which is too few for this model with "
+                f"receptive field {self.receptive_field}!"
+            )
+        y = self._forward_mps_safe(x, **kwargs)
         if scalar:
             y = y[0]
         return y
 
     def _at_nominal_settings(self, x: torch.Tensor) -> torch.Tensor:
         return self(x)
+
+    def _forward_mps_safe(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
+        """
+        Wrap `._forward()` to protect against MPS-unsupported input lengths
+        beyond 65,536 samples.
+
+        Check this again when PyTorch 2.5.2 is released--hopefully it's fixed
+        then.
+        """
+        if not self._mps_65536_fallback:
+            try:
+                return self._forward(x, **kwargs)
+            except NotImplementedError as e:
+                if "Output channels > 65536 not supported at the MPS device." in str(e):
+                    self._mps_65536_fallback = True
+                    return self._forward_mps_safe(x, **kwargs)
+                else:
+                    raise e
+        else:
+            # Stitch together the output one piece at a time to avoid the MPS error
+            stride = 65_536 - (self.receptive_field - 1)
+            # We need to make sure that the last segment is big enough that we have the required history for the receptive field.
+            out_list = []
+            for i in range(0, x.shape[1], stride):
+                j = min(i + 65_536, x.shape[1])
+                xi = x[:, i:j]
+                out_list.append(self._forward(xi, **kwargs))
+                # Bit hacky, but correct.
+                if j == x.shape[1]:
+                    break
+            return torch.cat(out_list, dim=1)
+
     @abc.abstractmethod
-    def _forward(self, x: torch.Tensor) -> torch.Tensor:
+    def _forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
         """
         The true forward method.
diff --git a/nam/models/linear.py b/nam/models/linear.py
@@ -6,9 +6,6 @@
 Linear model
 """
 
-import json
-from pathlib import Path
-
 import numpy as np
 import torch
 import torch.nn as nn
@@ -30,38 +27,6 @@ class Linear(BaseNet):
     def receptive_field(self) -> int:
         return self._net.weight.shape[2]
 
-    def export(self, outdir: Path):
-        training = self.training
-        self.eval()
-        with open(Path(outdir, "config.json"), "w") as fp:
-            json.dump(
-                {
-                    "version": __version__,
-                    "architecture": self.__class__.__name__,
-                    "config": {
-                        "receptive_field": self.receptive_field,
-                        "bias": self._bias,
-                    },
-                },
-                fp,
-                indent=4,
-            )
-
-        params = [self._net.weight.flatten()]
-        if self._bias:
-            params.append(self._net.bias.flatten())
-        params = torch.cat(params).detach().cpu().numpy()
-        # Hope I don't regret using np.save...
-        np.save(Path(outdir, "weights.npy"), params)
-
-        # And an input/output to verify correct computation:
-        x, y = self._export_input_output()
-        np.save(Path(outdir, "input.npy"), x.detach().cpu().numpy())
-        np.save(Path(outdir, "output.npy"), y.detach().cpu().numpy())
-
-        # And resume training state
-        self.train(training)
-
     def export_cpp_header(self):
         raise NotImplementedError()
 
@@ -73,7 +38,14 @@ class Linear(BaseNet):
         return self._net(x[:, None])[:, 0]
 
     def _export_config(self):
-        raise NotImplementedError()
+        return {
+            "receptive_field": self.receptive_field,
+            "bias": self._bias,
+        }
 
     def _export_weights(self) -> np.ndarray:
-        raise NotImplementedError()
+        params_list = [self._net.weight.flatten()]
+        if self._bias:
+            params_list.append(self._net.bias.flatten())
+        params = torch.cat(params_list).detach().cpu().numpy()
+        return params
diff --git a/tests/test_nam/test_models/_convolutional.py b/tests/test_nam/test_models/_convolutional.py
@@ -0,0 +1,30 @@
+# File: _convolutional.py
+# Created Date: Saturday November 23rd 2024
+# Author: Steven Atkinson (steven@atkinson.mn)
+
+"""
+Mix-in tests for models with a convolution layer
+"""
+
+import pytest as _pytest
+import torch as _torch
+
+from .base import Base as _Base
+
+
+class Convolutional(_Base):
+    @_pytest.mark.skipif(not _torch.backends.mps.is_available(), reason="MPS-specific test")
+    def test_process_input_longer_than_65536(self):
+        """
+        Processing inputs longer than 65,536 samples using the MPS backend can
+        cause problems.
+
+        See: https://github.com/sdatkinson/neural-amp-modeler/issues/505
+
+        Assert that precautions are taken.
+        """
+
+        x = _torch.zeros((65_536 + 1,)).to("mps")
+
+        model = self._construct().to("mps")
+        model(x)
diff --git a/tests/test_nam/test_models/test_conv_net.py b/tests/test_nam/test_models/test_conv_net.py
@@ -2,17 +2,14 @@
 # Created Date: Friday May 6th 2022
 # Author: Steven Atkinson (steven@atkinson.mn)
 
-from pathlib import Path
-from tempfile import TemporaryDirectory
-
-import pytest
+import pytest as _pytest
 
 from nam.models import conv_net
 
-from .base import Base
+from ._convolutional import Convolutional as _Convolutional
 
 
-class TestConvNet(Base):
+class TestConvNet(_Convolutional):
     @classmethod
     def setup_class(cls):
         channels = 3
@@ -23,13 +20,13 @@ class TestConvNet(Base):
             {"batchnorm": False, "activation": "Tanh"},
         )
 
-    @pytest.mark.parametrize(
+    @_pytest.mark.parametrize(
         ("batchnorm,activation"), ((False, "ReLU"), (True, "Tanh"))
     )
     def test_init(self, batchnorm, activation):
         super().test_init(kwargs={"batchnorm": batchnorm, "activation": activation})
 
-    @pytest.mark.parametrize(
+    @_pytest.mark.parametrize(
         ("batchnorm,activation"), ((False, "ReLU"), (True, "Tanh"))
     )
     def test_export(self, batchnorm, activation):
@@ -37,4 +34,4 @@ class TestConvNet(Base):
 
 
 if __name__ == "__main__":
-    pytest.main()
+    _pytest.main()
diff --git a/tests/test_nam/test_models/test_linear.py b/tests/test_nam/test_models/test_linear.py
@@ -0,0 +1,18 @@
+# File: test_linear.py
+# Created Date: Saturday November 23rd 2024
+# Author: Steven Atkinson (steven@atkinson.mn)
+
+import pytest as _pytest
+
+from nam.models import linear as _linear
+
+from ._convolutional import Convolutional as _Convolutional
+
+
+class TestLinear(_Convolutional):
+    @classmethod
+    def setup_class(cls):
+        C = _linear.Linear
+        args = ()
+        kwargs = {"receptive_field": 2, "sample_rate": 44100}
+        super().setup_class(C, args, kwargs)
diff --git a/tests/test_nam/test_models/test_wavenet.py b/tests/test_nam/test_models/test_wavenet.py
@@ -8,11 +8,28 @@
 import torch
 
 from nam.models.wavenet import WaveNet
 from nam.train.core import Architecture, get_wavenet_config
 
+from ._convolutional import Convolutional as _Convolutional
+
+
+class TestWaveNet(_Convolutional):
+    @classmethod
+    def setup_class(cls):
+        C = WaveNet
+        args = ()
+        kwargs = {
+            "layers_configs": [
+                {
+                    "input_size": 1,
+                    "condition_size": 1,
+                    "head_size": 1,
+                    "channels": 1,
+                    "kernel_size": 1,
+                    "dilations": [1]
+                }
+            ]
+        }
+        super().setup_class(C, args, kwargs)
-# from .base import Base
-
-
-class TestWaveNet(object):
     def test_import_weights(self):
         config = get_wavenet_config(Architecture.FEATHER)
         model_1 = WaveNet.init_from_config(config)
@@ -29,3 +46,7 @@ class TestWaveNet(object):
 
         assert not torch.allclose(y2_before, y1)
         assert torch.allclose(y2_after, y1)
+
+
+if __name__ == "__main__":
+    pytest.main()
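
Why the stitching in `_forward_mps_safe` is exact: a model with receptive field R maps n input samples to n - (R - 1) output samples, so advancing each 65,536-sample window by 65_536 - (R - 1) makes consecutive windows' outputs abut with no gap or overlap. A standalone sketch of the same idea (the `chunked_forward` helper is illustrative, not part of the library):

    import torch

    def chunked_forward(f, x, receptive_field, chunk=65_536):
        # f maps (batch, n) -> (batch, n - receptive_field + 1), like `_forward`.
        # Consecutive windows overlap by receptive_field - 1 samples so each
        # window carries the history its first output sample needs.
        stride = chunk - (receptive_field - 1)
        outs = []
        for i in range(0, x.shape[1], stride):
            j = min(i + chunk, x.shape[1])
            outs.append(f(x[:, i:j]))
            if j == x.shape[1]:  # final (possibly short) window reached
                break
        return torch.cat(outs, dim=1)

    # Sanity check against an unchunked forward pass:
    R = 3
    conv = torch.nn.Conv1d(1, 1, kernel_size=R)
    f = lambda t: conv(t[:, None, :])[:, 0, :]
    x = torch.randn(1, 10_000)
    assert torch.allclose(chunked_forward(f, x, R, chunk=1_000), f(x))

Each window sees exactly the same samples per output position as the unchunked pass, so the chunked result matches it sample for sample.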