commit d04fa0ca76e30c2f558f2208b36d8542e9c0290b
parent 6d33b41cf9f361b774e25c4e1e6340503ae7410e
Author: Steven Atkinson <steven@atkinson.mn>
Date: Sat, 17 Aug 2024 00:12:03 -0700
[BUGFIX] Fix LSTM cuDNN bug for PyTorch 2.4 (#454)
* Fix LSTM cuDNN bug: run the LSTM core over long inputs in blocks of at most 65,535 samples so cuDNN no longer fails on GPU under PyTorch 2.4 (see issue #450); a sketch of the approach follows below
* Add tests
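
A minimal sketch of the block-wise idea used in nam/models/recurrent.py, assuming a plain torch.nn.LSTM with batch_first=True; the function name lstm_in_blocks and the toy shapes are illustrative, not from the repo:

    # Illustrative only: process a long sequence through an LSTM in fixed-size
    # blocks, carrying the hidden state across blocks. Mirrors the BLOCK_SIZE
    # workaround in the diff below; see
    # https://github.com/sdatkinson/neural-amp-modeler/issues/450
    import torch
    import torch.nn as nn

    def lstm_in_blocks(core: nn.LSTM, x: torch.Tensor, hidden=None, block_size: int = 65_535):
        """Apply `core` (batch_first=True) to a (B, L, D) input block by block."""
        outputs = []
        for i in range(0, x.shape[1], block_size):
            out, hidden = core(x[:, i : i + block_size, :], hidden)
            outputs.append(out)
        return torch.cat(outputs, dim=1), hidden  # (B, L, hidden_size), (h, c)

    core = nn.LSTM(1, 8, batch_first=True)
    x = torch.randn(2, 100_000, 1)  # longer than one cuDNN-safe block
    y, (h, c) = lstm_in_blocks(core, x)
    print(y.shape)  # torch.Size([2, 100000, 8])
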
Diffstat:
3 files changed, 64 insertions(+), 9 deletions(-)
diff --git a/nam/data.py b/nam/data.py
@@ -449,6 +449,8 @@ class Dataset(AbstractDataset, InitializableFromConfig):
return cls._apply_delay_int(x, y, delay)
elif isinstance(delay, float):
return cls._apply_delay_float(x, y, delay, method)
+ else:
+ raise TypeError(type(delay))
@classmethod
def _apply_delay_int(
diff --git a/nam/models/recurrent.py b/nam/models/recurrent.py
@@ -8,6 +8,7 @@ Recurrent models (LSTM)
TODO batch_first=False (I get it...)
"""
+import abc
import json
from pathlib import Path
from tempfile import TemporaryDirectory
@@ -120,6 +121,26 @@ class LSTMCore(_L):
)
+# TODO get this somewhere more core-ish
+class _ExportsWeights(abc.ABC):
+ @abc.abstractmethod
+ def export_weights(self) -> np.ndarray:
+ """
+ :return: a 1D array of weights
+ """
+ pass
+
+
+class _Linear(nn.Linear, _ExportsWeights):
+ def export_weights(self):
+ return np.concatenate(
+ [
+ self.weight.data.detach().cpu().numpy().flatten(),
+ self.bias.data.detach().cpu().numpy().flatten(),
+ ]
+ )
+
+
class LSTM(BaseNet):
"""
ABC for recurrent architectures
@@ -150,7 +171,7 @@ class LSTM(BaseNet):
raise ValueError("batch_first cannot be set.")
self._input_size = input_size
self._core = _L(self._input_size, hidden_size, batch_first=True, **lstm_kwargs)
- self._head = nn.Linear(hidden_size, 1)
+ self._head = self._init_head(hidden_size)
self._train_burn_in = train_burn_in
self._train_truncate = train_truncate
self._initial_cell = nn.Parameter(
@@ -162,6 +183,13 @@ class LSTM(BaseNet):
self._get_initial_state_burn_in = 48_000
@property
+ def input_device(self) -> torch.device:
+ """
+ What device does the input need to be on?
+ """
+ return self._core.bias_ih_l0.device
+
+ @property
def receptive_field(self) -> int:
return 1
@@ -250,17 +278,29 @@ class LSTM(BaseNet):
:param x: (B,L) or (B,L,D)
:return: (B,L)
"""
+
+ def process_in_blocks(x, hidden_state=None):
+ # See: https://github.com/sdatkinson/neural-amp-modeler/issues/450
+ BLOCK_SIZE = 65_535
+ outputs = []
+ for i in range(0, x.shape[1], BLOCK_SIZE):
+ out, hidden_state = self._core(
+ x[:, i : i + BLOCK_SIZE, :], hidden_state
+ )
+ outputs.append(out)
+ return torch.cat(outputs, dim=1), hidden_state # assert batch_first
+
last_hidden_state = (
self._initial_state(len(x)) if initial_state is None else initial_state
)
if x.ndim == 2:
x = x[:, :, None]
if not self.training or self._train_truncate is None:
- output_features = self._core(x, last_hidden_state)[0]
+ output_features = process_in_blocks(x, last_hidden_state)[0]
else:
output_features_list = []
if self._train_burn_in is not None:
- last_output_features, last_hidden_state = self._core(
+ last_output_features, last_hidden_state = process_in_blocks(
x[:, : self._train_burn_in, :], last_hidden_state
)
output_features_list.append(last_output_features.detach())
@@ -269,7 +309,7 @@ class LSTM(BaseNet):
if i > burn_in_offset:
# Don't detach the burn-in state so that we can learn it.
last_hidden_state = tuple(z.detach() for z in last_hidden_state)
- last_output_features, last_hidden_state = self._core(
+ last_output_features, last_hidden_state = process_in_blocks(
x[:, i : i + self._train_truncate, :], last_hidden_state
)
output_features_list.append(last_output_features)
@@ -329,10 +369,7 @@ class LSTM(BaseNet):
self._export_cell_weights(i, h, c)
for i, (h, c) in enumerate(zip(*self._get_initial_state()))
]
- + [
- self._head.weight.data.detach().cpu().numpy().flatten(),
- self._head.bias.data.detach().cpu().numpy().flatten(),
- ]
+ + [self._head.export_weights()]
)
def _get_initial_state(self, inputs=None) -> _LSTMHiddenCellType:
@@ -350,10 +387,13 @@ class LSTM(BaseNet):
torch.zeros((1, self._get_initial_state_burn_in, 1))
if inputs is None
else inputs
- )
+ ).to(self.input_device)
_, (h, c) = self._core(inputs)
return h, c
+ def _init_head(self, hidden_size: int) -> _ExportsWeights:
+ return _Linear(hidden_size, 1)
+
def _initial_state(self, n: Optional[int]) -> _LSTMHiddenCellType:
"""
Literally what the forward pass starts with.
diff --git a/tests/test_nam/test_models/test_recurrent.py b/tests/test_nam/test_models/test_recurrent.py
@@ -69,3 +69,16 @@ class TestLSTM(Base):
assert y_expected == approx(y_actual)
assert hout_expected == approx(hout_actual)
assert cout_expected == approx(cout_actual)
+
+ def test_get_initial_state_cpu(self):
+ self._t_initial_state("cpu")
+
+ @pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU test")
+ def test_get_initial_state_gpu(self):
+ self._t_initial_state("cuda")
+
+ def _t_initial_state(self, device):
+ model = self._construct().to(device)
+ h, c = model._get_initial_state()
+ assert isinstance(h, torch.Tensor)
+ assert isinstance(c, torch.Tensor)
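
Note on the second half of the fix, which the new CPU/GPU tests exercise: _get_initial_state now moves its burn-in input onto the same device as the LSTM weights. A standalone sketch of the pattern, assuming a hypothetical toy module (ToyLSTM is not the repo's class):

    import torch
    import torch.nn as nn

    class ToyLSTM(nn.Module):
        """Hypothetical stand-in illustrating the device-aware initial state."""

        def __init__(self, hidden_size: int = 8):
            super().__init__()
            self._core = nn.LSTM(1, hidden_size, batch_first=True)
            self._burn_in = 48_000

        @property
        def input_device(self) -> torch.device:
            # Same trick as the commit: read the device off an LSTM parameter.
            return self._core.bias_ih_l0.device

        def get_initial_state(self, inputs=None):
            x = (
                torch.zeros((1, self._burn_in, 1)) if inputs is None else inputs
            ).to(self.input_device)  # the fix: match the core's device
            _, (h, c) = self._core(x)
            return h, c

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = ToyLSTM().to(device)
    h, c = model.get_initial_state()  # no device-mismatch error on GPU
    print(h.device, c.device)
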