neural-amp-modeler

Neural network emulator for guitar amplifiers

commit d04fa0ca76e30c2f558f2208b36d8542e9c0290b
parent 6d33b41cf9f361b774e25c4e1e6340503ae7410e
Author: Steven Atkinson <steven@atkinson.mn>
Date:   Sat, 17 Aug 2024 00:12:03 -0700

[BUGFIX] Fix LSTM cuDNN bug for PyTorch 2.4 (#454)

* Fix LSTM cuDNN bug

* Add tests
Diffstat:
 M nam/data.py                                  |  2 ++
 M nam/models/recurrent.py                      | 58 +++++++++++++++++++++++++++++++++++++++++++++++++---------
 M tests/test_nam/test_models/test_recurrent.py | 13 +++++++++++++
3 files changed, 64 insertions(+), 9 deletions(-)
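
A note for context (not part of the commit): the cuDNN bug referenced in the title shows up when the LSTM core is fed a very long sequence in a single call, which is why the patch below adds a process_in_blocks helper that walks the time axis in blocks of at most 65,535 samples and carries the hidden state across blocks. The standalone sketch here reproduces that idea with a plain nn.LSTM; the helper name lstm_in_blocks, the layer sizes, and the small block size used in the demo are illustrative choices, not taken from the library.

import torch
import torch.nn as nn


def lstm_in_blocks(core: nn.LSTM, x: torch.Tensor, hidden=None, block_size: int = 65_535):
    # Run a batch_first LSTM over x in chunks along the time axis,
    # carrying the (h, c) state from one chunk to the next.
    outputs = []
    for i in range(0, x.shape[1], block_size):
        out, hidden = core(x[:, i : i + block_size, :], hidden)
        outputs.append(out)
    return torch.cat(outputs, dim=1), hidden


core = nn.LSTM(input_size=1, hidden_size=8, batch_first=True)
x = torch.randn(1, 10_000, 1)  # (B, L, D)
# A small block size keeps this demo quick; the library uses 65,535.
y_blocked, _ = lstm_in_blocks(core, x, block_size=4_096)
y_direct, _ = core(x)  # the single long call is what hit the cuDNN bug on GPU
assert torch.allclose(y_blocked, y_direct, atol=1e-6)

Because the hidden state is threaded from one block to the next, the chunked result matches what a single full-length call would produce.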

diff --git a/nam/data.py b/nam/data.py
@@ -449,6 +449,8 @@ class Dataset(AbstractDataset, InitializableFromConfig):
             return cls._apply_delay_int(x, y, delay)
         elif isinstance(delay, float):
             return cls._apply_delay_float(x, y, delay, method)
+        else:
+            raise TypeError(type(delay))
 
     @classmethod
     def _apply_delay_int(
diff --git a/nam/models/recurrent.py b/nam/models/recurrent.py
@@ -8,6 +8,7 @@ Recurrent models (LSTM)
 TODO batch_first=False (I get it...)
 """
 
+import abc
 import json
 from pathlib import Path
 from tempfile import TemporaryDirectory
@@ -120,6 +121,26 @@ class LSTMCore(_L):
         )
 
 
+# TODO get this somewhere more core-ish
+class _ExportsWeights(abc.ABC):
+    @abc.abstractmethod
+    def export_weights(self) -> np.ndarray:
+        """
+        :return: a 1D array of weights
+        """
+        pass
+
+
+class _Linear(nn.Linear, _ExportsWeights):
+    def export_weights(self):
+        return np.concatenate(
+            [
+                self.weight.data.detach().cpu().numpy().flatten(),
+                self.bias.data.detach().cpu().numpy().flatten(),
+            ]
+        )
+
+
 class LSTM(BaseNet):
     """
     ABC for recurrent architectures
@@ -150,7 +171,7 @@ class LSTM(BaseNet):
             raise ValueError("batch_first cannot be set.")
         self._input_size = input_size
         self._core = _L(self._input_size, hidden_size, batch_first=True, **lstm_kwargs)
-        self._head = nn.Linear(hidden_size, 1)
+        self._head = self._init_head(hidden_size)
         self._train_burn_in = train_burn_in
         self._train_truncate = train_truncate
         self._initial_cell = nn.Parameter(
@@ -162,6 +183,13 @@ class LSTM(BaseNet):
         self._get_initial_state_burn_in = 48_000
 
     @property
+    def input_device(self) -> torch.device:
+        """
+        What device does the input need to be on?
+        """
+        return self._core.bias_ih_l0.device
+
+    @property
     def receptive_field(self) -> int:
         return 1
 
@@ -250,17 +278,29 @@ class LSTM(BaseNet):
         :param x: (B,L) or (B,L,D)
         :return: (B,L)
         """
+
+        def process_in_blocks(x, hidden_state=None):
+            # See: https://github.com/sdatkinson/neural-amp-modeler/issues/450
+            BLOCK_SIZE = 65_535
+            outputs = []
+            for i in range(0, x.shape[1], BLOCK_SIZE):
+                out, hidden_state = self._core(
+                    x[:, i : i + BLOCK_SIZE, :], hidden_state
+                )
+                outputs.append(out)
+            return torch.cat(outputs, dim=1), hidden_state
+
         # assert batch_first
         last_hidden_state = (
             self._initial_state(len(x)) if initial_state is None else initial_state
         )
         if x.ndim == 2:
             x = x[:, :, None]
         if not self.training or self._train_truncate is None:
-            output_features = self._core(x, last_hidden_state)[0]
+            output_features = process_in_blocks(x, last_hidden_state)[0]
         else:
             output_features_list = []
             if self._train_burn_in is not None:
-                last_output_features, last_hidden_state = self._core(
+                last_output_features, last_hidden_state = process_in_blocks(
                     x[:, : self._train_burn_in, :], last_hidden_state
                 )
                 output_features_list.append(last_output_features.detach())
@@ -269,7 +309,7 @@
                 if i > burn_in_offset:
                     # Don't detach the burn-in state so that we can learn it.
                     last_hidden_state = tuple(z.detach() for z in last_hidden_state)
-                last_output_features, last_hidden_state = self._core(
+                last_output_features, last_hidden_state = process_in_blocks(
                     x[:, i : i + self._train_truncate, :], last_hidden_state
                 )
                 output_features_list.append(last_output_features)
@@ -329,10 +369,7 @@ class LSTM(BaseNet):
                 self._export_cell_weights(i, h, c)
                 for i, (h, c) in enumerate(zip(*self._get_initial_state()))
            ]
-            + [
-                self._head.weight.data.detach().cpu().numpy().flatten(),
-                self._head.bias.data.detach().cpu().numpy().flatten(),
-            ]
+            + [self._head.export_weights()]
         )
 
     def _get_initial_state(self, inputs=None) -> _LSTMHiddenCellType:
@@ -350,10 +387,13 @@ class LSTM(BaseNet):
             torch.zeros((1, self._get_initial_state_burn_in, 1))
             if inputs is None
             else inputs
-        )
+        ).to(self.input_device)
         _, (h, c) = self._core(inputs)
         return h, c
 
+    def _init_head(self, hidden_size: int) -> _ExportsWeights:
+        return _Linear(hidden_size, 1)
+
     def _initial_state(self, n: Optional[int]) -> _LSTMHiddenCellType:
         """
         Literally what the forward pass starts with.
diff --git a/tests/test_nam/test_models/test_recurrent.py b/tests/test_nam/test_models/test_recurrent.py
@@ -69,3 +69,16 @@ class TestLSTM(Base):
         assert y_expected == approx(y_actual)
         assert hout_expected == approx(hout_actual)
         assert cout_expected == approx(cout_actual)
+
+    def test_get_initial_state_cpu(self):
+        return self._t_initial_state("cpu")
+
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU test")
+    def test_get_initial_state_gpu(self):
+        self._t_initial_state("cuda")
+
+    def _t_initial_state(self, device):
+        model = self._construct().to(device)
+        h, c = model._get_initial_state()
+        assert isinstance(h, torch.Tensor)
+        assert isinstance(c, torch.Tensor)
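
The other part of the change, which the new tests exercise, is device handling in _get_initial_state: the burn-in input is now moved onto the LSTM's device via the new input_device property, presumably so that a model living on a GPU no longer receives CPU-resident zeros. A generic sketch of that pattern, using a hypothetical TinyRecurrent module rather than the library's LSTM class:

import torch
import torch.nn as nn


class TinyRecurrent(nn.Module):
    # Illustrative module (not the library's LSTM) showing the device-placement pattern.

    def __init__(self, hidden_size: int = 8):
        super().__init__()
        self.core = nn.LSTM(1, hidden_size, batch_first=True)

    @property
    def input_device(self) -> torch.device:
        # Read the device off a parameter of the core so inputs follow the
        # module wherever .to(...) puts it.
        return self.core.bias_ih_l0.device

    def initial_state(self, burn_in: int = 48_000):
        # torch.zeros(...) lands on the CPU by default; without .to(...) this
        # would raise a device-mismatch error once the LSTM is on CUDA.
        x = torch.zeros((1, burn_in, 1)).to(self.input_device)
        _, (h, c) = self.core(x)
        return h, c


device = "cuda" if torch.cuda.is_available() else "cpu"
model = TinyRecurrent().to(device)
h, c = model.initial_state(burn_in=1_000)
assert h.device.type == device and c.device.type == device

Querying a parameter's .device keeps the property correct after any .to(...) call, which mirrors what the CPU and CUDA variants of the new test check.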