neural-amp-modeler

Neural network emulator for guitar amplifiers
Log | Files | Refs | README | LICENSE

commit dd51f16f1e1760e4ed07efc3932d59c72ad55002
parent f743c037305f60b4e9800866ece4948e7cc508a9
Author: Steven Atkinson <steven@atkinson.mn>
Date:   Sun, 25 Jun 2023 09:46:44 -0700

pass through wildcard kwargs to model forward pass

Diffstat:
Mnam/models/_base.py | 4++--
Mnam/models/recurrent.py | 8++++++--
2 files changed, 8 insertions(+), 4 deletions(-)

diff --git a/nam/models/_base.py b/nam/models/_base.py @@ -138,14 +138,14 @@ class _Base(nn.Module, InitializableFromConfig, Exportable): class BaseNet(_Base): - def forward(self, x: torch.Tensor, pad_start: Optional[bool] = None): + def forward(self, x: torch.Tensor, pad_start: Optional[bool] = None, **kwargs): pad_start = self.pad_start_default if pad_start is None else pad_start scalar = x.ndim == 1 if scalar: x = x[None] if pad_start: x = torch.cat((torch.zeros((len(x), self.receptive_field - 1)), x), dim=1) - y = self._forward(x) + y = self._forward(x, **kwargs) if scalar: y = y[0] return y diff --git a/nam/models/recurrent.py b/nam/models/recurrent.py @@ -243,12 +243,16 @@ class LSTM(BaseNet): """ return self._head(features)[:, :, 0] - def _forward(self, x: torch.Tensor) -> torch.Tensor: + def _forward( + self, x: torch.Tensor, initial_state: Optional[_LSTMHiddenCellType] = None + ) -> torch.Tensor: """ :param x: (B,L) or (B,L,D) :return: (B,L) """ - last_hidden_state = self._initial_state(len(x)) + last_hidden_state = ( + self._initial_state(len(x)) if initial_state is None else initial_state + ) if x.ndim == 2: x = x[:, :, None] if not self.training or self._train_truncate is None: