commit dd51f16f1e1760e4ed07efc3932d59c72ad55002
parent f743c037305f60b4e9800866ece4948e7cc508a9
Author: Steven Atkinson <steven@atkinson.mn>
Date: Sun, 25 Jun 2023 09:46:44 -0700
pass through wildcard kwargs to model forward pass
Diffstat:
  M nam/models/_base.py      | 4 ++--
  M nam/models/recurrent.py  | 8 ++++++--

2 files changed, 8 insertions(+), 4 deletions(-)
diff --git a/nam/models/_base.py b/nam/models/_base.py
@@ -138,14 +138,14 @@ class _Base(nn.Module, InitializableFromConfig, Exportable):
class BaseNet(_Base):
- def forward(self, x: torch.Tensor, pad_start: Optional[bool] = None):
+ def forward(self, x: torch.Tensor, pad_start: Optional[bool] = None, **kwargs):
pad_start = self.pad_start_default if pad_start is None else pad_start
scalar = x.ndim == 1
if scalar:
x = x[None]
if pad_start:
x = torch.cat((torch.zeros((len(x), self.receptive_field - 1)), x), dim=1)
- y = self._forward(x)
+ y = self._forward(x, **kwargs)
if scalar:
y = y[0]
return y
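
With this change, any keyword argument passed to forward() beyond pad_start flows straight through to the subclass's _forward() hook. A minimal sketch of the mechanism, using a hypothetical Gain subclass and a trimmed stand-in for _Base (the real class also mixes in InitializableFromConfig and Exportable):

import torch
import torch.nn as nn
from typing import Optional


class _Base(nn.Module):
    # Trimmed stand-in for nam.models._base._Base.
    def __init__(self, receptive_field: int = 1):
        super().__init__()
        self.receptive_field = receptive_field
        self.pad_start_default = True


class BaseNet(_Base):
    def forward(self, x: torch.Tensor, pad_start: Optional[bool] = None, **kwargs):
        pad_start = self.pad_start_default if pad_start is None else pad_start
        scalar = x.ndim == 1
        if scalar:
            x = x[None]
        if pad_start:
            x = torch.cat((torch.zeros((len(x), self.receptive_field - 1)), x), dim=1)
        # Wildcard kwargs reach the subclass hook untouched:
        y = self._forward(x, **kwargs)
        if scalar:
            y = y[0]
        return y


class Gain(BaseNet):
    # Hypothetical subclass whose _forward accepts an extra keyword.
    def _forward(self, x: torch.Tensor, gain: float = 1.0) -> torch.Tensor:
        return gain * x


net = Gain()
x = torch.randn(16)
assert torch.allclose(net(x, gain=2.0), 2.0 * x)  # the kwarg reached _forward
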
diff --git a/nam/models/recurrent.py b/nam/models/recurrent.py
@@ -243,12 +243,16 @@ class LSTM(BaseNet):
"""
return self._head(features)[:, :, 0]
- def _forward(self, x: torch.Tensor) -> torch.Tensor:
+ def _forward(
+ self, x: torch.Tensor, initial_state: Optional[_LSTMHiddenCellType] = None
+ ) -> torch.Tensor:
"""
:param x: (B,L) or (B,L,D)
:return: (B,L)
"""
- last_hidden_state = self._initial_state(len(x))
+ last_hidden_state = (
+ self._initial_state(len(x)) if initial_state is None else initial_state
+ )
if x.ndim == 2:
x = x[:, :, None]
if not self.training or self._train_truncate is None:
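
The motivating use shows up in the LSTM change above: forward(x, initial_state=...) can now seed the recurrent state, e.g. to process long audio in chunks without resetting between them. A sketch of that pattern with a bare torch.nn.LSTM standing in for the wrapped core, assuming _LSTMHiddenCellType is the usual (hidden, cell) tensor pair that torch.nn.LSTM returns:

import torch
import torch.nn as nn

# Hypothetical sizes; a bare nn.LSTM stands in for the model's recurrent core.
core = nn.LSTM(input_size=1, hidden_size=8, batch_first=True)
core.eval()

audio = torch.randn(1, 4096, 1)  # (B, L, D)

with torch.no_grad():
    # One pass over the whole signal...
    full, _ = core(audio)
    # ...matches two chunked passes when the (hidden, cell) state is carried
    # over, which is exactly what threading initial_state through enables.
    first, state = core(audio[:, :2048])      # default (zero) initial state
    second, _ = core(audio[:, 2048:], state)  # resume from the carried state

assert torch.allclose(full, torch.cat((first, second), dim=1), atol=1e-6)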