neural-amp-modeler

Neural network emulator for guitar amplifiers

commit f8509336e28eb292557c9abb8af019cb31892543
parent 148ca8f7c7a197df18b91551f10fb9a215dfc252
Author: Steven Atkinson <steven@atkinson.mn>
Date:   Thu, 16 Feb 2023 20:13:21 -0800

ONNX support for LSTMs (#93)

* Implement ONNX exporting for LSTMs

* ONNX export option in export script

* Fix some bugs

* typing fix
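
With this change the export script is invoked as, e.g., python bin/export/main.py <checkpoint> <outdir> --onnx (add --cpp for the hard-coded C++ header, and --snapshot/-s for a debugging input-output pair). The same export is available programmatically; a minimal sketch (not part of this commit), using an untrained stand-in model shaped like the one the test suite constructs, where a real use would load trained weights from a checkpoint first:

    from pathlib import Path

    from nam.models import recurrent

    net = recurrent.LSTM(3, num_layers=2)  # hidden_size=3; stand-in for a trained model
    net.eval()
    net.export_onnx(Path("model.onnx"))  # new in this commit; LSTM-only so far
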
Diffstat:
M bin/export/main.py                           |  23 +++++++++++++++++------
M environment_cpu.yml                          |   2 ++
M environment_gpu.yml                          |   2 ++
M nam/models/_exportable.py                    |   9 +++++++++
M nam/models/recurrent.py                      | 181 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++-----------------------
M requirements.txt                             |   2 ++
M tests/test_nam/test_models/test_recurrent.py |  45 ++++++++++++++++++++++++++++++++++++++++++---
7 files changed, 203 insertions(+), 61 deletions(-)
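
For orientation before the diff: the heart of the change is LSTM.forward_onnx plus the _ONNXWrapped module, which adapt the batch-first LSTM interface (inputs (B,S,D), states (L,B,DH)) to the flat (N,) audio signal and (L,DH) states that the exported graph exposes. A minimal sketch of that reshaping in plain PyTorch (names and sizes here are illustrative, not from the repo):

    import torch

    N, L, DH = 64, 2, 3  # frames, LSTM layers, hidden size
    lstm = torch.nn.LSTM(1, DH, num_layers=L, batch_first=True)
    head = torch.nn.Linear(DH, 1)

    x = torch.randn(N)      # (N,) mono audio, as the ONNX graph sees it
    h = torch.zeros(L, DH)  # (L,DH) hidden state
    c = torch.zeros(L, DH)  # (L,DH) cell state

    # Add the batch dims the LSTM wants: x -> (1,N,1); h, c -> (L,1,DH)
    features, (h, c) = lstm(x[None, :, None], (h[:, None, :], c[:, None, :]))
    y = head(features)[:, :, 0]  # (1,N)
    # Strip the batch dims again for the graph outputs: (N,), (L,DH), (L,DH)
    y, h, c = y[0, :], h[:, 0, :], c[:, 0, :]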

diff --git a/bin/export/main.py b/bin/export/main.py
@@ -40,8 +40,13 @@ def main(args):
     export_args = (outdir, param_config)
     net.eval()
     outdir.mkdir(parents=True, exist_ok=True)
-    net.export(*export_args, include_snapshot=args.include_snapshot)
-    net.export_cpp_header(Path(export_args[0], "HardCodedModel.h"), *export_args[1:])
+    net.export(*export_args, include_snapshot=args.snapshot)
+    if args.cpp:
+        net.export_cpp_header(
+            Path(export_args[0], "HardCodedModel.h"), *export_args[1:]
+        )
+    if args.onnx:
+        net.export_onnx(Path(outdir, "model.onnx"))
 
 
 if __name__ == "__main__":
@@ -50,12 +55,18 @@ if __name__ == "__main__":
     parser.add_argument("checkpoint", type=str)
     parser.add_argument("outdir")
     parser.add_argument(
-        "--include-snapshot",
+        "--param-config", type=str, help="Configuration for a parametric model"
+    )
+    parser.add_argument("--onnx", action="store_true", help="Export an ONNX model")
+    parser.add_argument(
+        "--cpp", action="store_true", help="Export a CPP header for hard-coding a model"
+    )
+    parser.add_argument(
+        "--snapshot",
         "-s",
+        action="store_true",
         help="Computes an example input-output pair for the model for debugging "
         "purposes",
     )
-    parser.add_argument(
-        "--param-config", type=str, help="Configuration for a parametric model"
-    )
+
     main(parser.parse_args())
diff --git a/environment_cpu.yml b/environment_cpu.yml
@@ -21,6 +21,8 @@ dependencies:
   - tqdm
   - wheel
   - pip:
+    - onnx
+    - onnxruntime
     - pre-commit
     - pytorch_lightning
     - sounddevice
diff --git a/environment_gpu.yml b/environment_gpu.yml
@@ -23,6 +23,8 @@ dependencies:
   - tqdm
   - wheel
   - pip:
+    - onnx
+    - onnxruntime  # TODO GPU...
    - pre-commit
    - pytorch_lightning
    - sounddevice
diff --git a/nam/models/_exportable.py b/nam/models/_exportable.py
@@ -68,6 +68,15 @@ class Exportable(abc.ABC):
         """
         pass
 
+    def export_onnx(self, filename: Path):
+        """
+        Export model in format for ONNX Runtime
+        """
+        raise NotImplementedError(
+            "Exporting to ONNX is not supported for models of type "
+            f"{self.__class__.__name__}"
+        )
+
     @abc.abstractmethod
     def _export_config(self):
         """
diff --git a/nam/models/recurrent.py b/nam/models/recurrent.py
@@ -25,9 +25,10 @@ from ._base import BaseNet
 
 class _L(nn.LSTM):
     """
-    Tweaks to LSTM
+    Tweaks to PyTorch LSTM module
 
     * Up the remembering
     """
+
     def reset_parameters(self) -> None:
         super().reset_parameters()
         # https://danijar.com/tips-for-training-recurrent-neural-networks/
@@ -39,21 +40,27 @@ class _L(nn.LSTM):
         for layer in range(self.num_layers):
             for input in ("i", "h"):
                 # Balance out the scale of the cell w/ a -=1
-                getattr(self, f"bias_{input}h_l{layer}").data[
-                    idx_input
-                ] -= value
-                getattr(self, f"bias_{input}h_l{layer}").data[
-                    idx_forget
-                ] += value
+                getattr(self, f"bias_{input}h_l{layer}").data[idx_input] -= value
+                getattr(self, f"bias_{input}h_l{layer}").data[idx_forget] += value
+
+
+# State:
+# L: Number of LSTM layers
+# DH: Hidden state dimension
+# [0]: hidden (L,DH)
+# [1]: cell (L,DH)
+_LSTMHiddenType = torch.Tensor
+_LSTMCellType = torch.Tensor
+_LSTMHiddenCellType = Tuple[_LSTMHiddenType, _LSTMCellType]
 
 
 class LSTMCore(_L):
     def __init__(
-        self, 
-        *args, 
+        self,
+        *args,
         train_burn_in: Optional[int] = None,
-        train_truncate: Optional[int] = None,
-        **kwargs
+        train_truncate: Optional[int] = None,
+        **kwargs,
     ):
         super().__init__(*args, **kwargs)
         if not self.batch_first:
@@ -77,8 +84,9 @@ class LSTMCore(_L):
         if x.ndim != 3:
             raise NotImplementedError("Need (B,L,D)")
         last_hidden_state = (
-            self._initial_state(None if x.ndim == 2 else len(x)) 
-            if hidden_state is None else hidden_state
+            self._initial_state(None if x.ndim == 2 else len(x))
+            if hidden_state is None
+            else hidden_state
         )
         if not self.training or self._train_truncate is None:
             output_features = super().forward(x, last_hidden_state)[0]
@@ -95,21 +103,23 @@ class LSTMCore(_L):
                 # Don't detach the burn-in state so that we can learn it.
                 last_hidden_state = tuple(z.detach() for z in last_hidden_state)
                 last_output_features, last_hidden_state = super().forward(
-                    x[:, i : i + self._train_truncate, :,],
-                    last_hidden_state,
+                    x[:, i : i + self._train_truncate, :], last_hidden_state
                 )
                 output_features_list.append(last_output_features)
             output_features = torch.cat(output_features_list, dim=1)
         return output_features
 
-    def _initial_state(self, n: Optional[int]) -> Tuple[torch.Tensor, torch.Tensor]:
-        return (self._initial_hidden, self._initial_cell) if n is None else (
-            torch.tile(self._initial_hidden[:, None], (1, n, 1)),
-            torch.tile(self._initial_cell[:, None], (1, n, 1))
+    def _initial_state(self, n: Optional[int]) -> _LSTMHiddenCellType:
+        return (
+            (self._initial_hidden, self._initial_cell)
+            if n is None
+            else (
+                torch.tile(self._initial_hidden[:, None], (1, n, 1)),
+                torch.tile(self._initial_cell[:, None], (1, n, 1)),
+            )
         )
 
-
 class LSTM(BaseNet):
     """
     ABC for recurrent architectures
@@ -125,9 +135,9 @@ class LSTM(BaseNet):
     ):
         """
         :param hidden_size: for LSTM
-        :param train_burn_in: Detach calculations from first (this many) samples when 
+        :param train_burn_in: Detach calculations from first (this many) samples when
            training to burn in the hidden state.
-        :param train_truncate: detach the hidden & cell states every this many steps 
+        :param train_truncate: detach the hidden & cell states every this many steps
            during training so that backpropagation through time is faster + to
            simulate better starting states for h(t0)&c(t0) (instead of zeros)
            TODO recognition head to start the hidden state in a good place?
@@ -138,9 +148,7 @@ class LSTM(BaseNet):
         if "batch_first" in lstm_kwargs:
             raise ValueError("batch_first cannot be set.")
         self._input_size = input_size
-        self._core = _L(
-            self._input_size, hidden_size, batch_first=True, **lstm_kwargs
-        )
+        self._core = _L(self._input_size, hidden_size, batch_first=True, **lstm_kwargs)
         self._head = nn.Linear(hidden_size, 1)
         self._train_burn_in = train_burn_in
         self._train_truncate = train_truncate
@@ -164,7 +172,7 @@ class LSTM(BaseNet):
         with TemporaryDirectory() as tmpdir:
             tmpdir = Path(tmpdir)
             LSTM.export(self, Path(tmpdir))  # Hacky...need to work w/ CatLSTM
-            with open(Path(tmpdir, "config.json"), "r") as fp:
+            with open(Path(tmpdir, "model.nam"), "r") as fp:
                 _c = json.load(fp)
             version = _c["version"]
             config = _c["config"]
@@ -185,20 +193,61 @@ class LSTM(BaseNet):
                 + s_parametric
                 + (
                     "std::vector<float> PARAMS{"
-                    + ", ".join(
-                        [f"{w:.16f}f" for w in np.load(Path(tmpdir, "weights.npy"))]
-                    )
+                    + ", ".join([f"{w:.16f}f" for w in _c["weights"]])
                     + "};\n",
                 )
             )
 
-    def _forward(self, x):
+    def export_onnx(self, filename: Path):
+        if self._input_size != 1:
+            raise NotImplementedError("Multi-dimensional inputs not supported yet")
+        o = _ONNXWrapped(self)
+        x = torch.randn((64,))  # (S,)
+        h, c = [z[:, 0, :] for z in self._initial_state(1)]  # (L,DH), (L,DH)
+        torch.onnx.export(
+            o,
+            (x, h, c),
+            filename,
+            input_names=["x", "hin", "cin"],
+            output_names=["y", "hout", "cout"],
+            dynamic_axes={"x": {0: "num_frames"}, "y": {0: "num_frames"}},
+        )
+
+    def forward_onnx(
+        self, x: torch.Tensor, h: _LSTMHiddenType, c: _LSTMCellType
+    ) -> Tuple[torch.Tensor, _LSTMHiddenType, _LSTMCellType]:
+        """
+        Forward pass used by ONNX export
+        Only supports scalar inputs right now.
+
+        N: Sequence length
+        L: Number of layers
+        DH: Hidden state dimension
+
+        :param x: (N,)
+        :param h: (L, DH)
+        :param c: (L, DH)
+
+        :return: (N,), (L, DH), (L, DH)
+        """
+        features, (h, c) = self._core(x[None, :, None], (h[:, None, :], c[:, None, :]))
+        y = self._apply_head(features)  # (1,S)
+        return y[0, :], h[:, 0, :], c[:, 0, :]
+
+    def _apply_head(self, features: torch.Tensor) -> torch.Tensor:
+        """
+        :param features: (B,S,DH)
+        :return: (B,S)
+        """
+        return self._head(features)[:, :, 0]
+
+    def _forward(self, x: torch.Tensor) -> torch.Tensor:
         """
         :param x: (B,L) or (B,L,D)
         :return: (B,L)
         """
         last_hidden_state = self._initial_state(len(x))
-        if x.ndim==2:
+        if x.ndim == 2:
             x = x[:, :, None]
         if not self.training or self._train_truncate is None:
             output_features = self._core(x, last_hidden_state)[0]
@@ -215,12 +264,11 @@ class LSTM(BaseNet):
                 # Don't detach the burn-in state so that we can learn it.
                 last_hidden_state = tuple(z.detach() for z in last_hidden_state)
                 last_output_features, last_hidden_state = self._core(
-                    x[:, i : i + self._train_truncate, :,],
-                    last_hidden_state,
+                    x[:, i : i + self._train_truncate, :], last_hidden_state
                 )
                 output_features_list.append(last_output_features)
             output_features = torch.cat(output_features_list, dim=1)
-        return self._head(output_features)[:, :, 0]
+        return self._apply_head(output_features)
 
     def _export_cell_weights(
         self, i: int, hidden_state: torch.Tensor, cell_state: torch.Tensor
@@ -281,7 +329,7 @@ class LSTM(BaseNet):
             ]
         )
 
-    def _get_initial_state(self, inputs=None) -> Tuple[torch.Tensor, torch.Tensor]:
+    def _get_initial_state(self, inputs=None) -> _LSTMHiddenCellType:
         """
         Convenience function to find a good hidden state to start the plugin at
@@ -296,31 +344,60 @@ class LSTM(BaseNet):
             _, (h, c) = self._core(inputs)
         return h, c
 
-    def _initial_state(self, n: Optional[int]) -> Tuple[torch.Tensor, torch.Tensor]:
+    def _initial_state(self, n: Optional[int]) -> _LSTMHiddenCellType:
         """
         Literally what the forward pass starts with.
         Default is zeroes; this should be better since it can be learned.
         """
-        return (self._initial_hidden, self._initial_cell) if n is None else (
-            torch.tile(self._initial_hidden[:, None], (1, n, 1)),
-            torch.tile(self._initial_cell[:, None], (1, n, 1))
+        return (
+            (self._initial_hidden, self._initial_cell)
+            if n is None
+            else (
+                torch.tile(self._initial_hidden[:, None], (1, n, 1)),
+                torch.tile(self._initial_cell[:, None], (1, n, 1)),
+            )
         )
 
 
+class _ONNXWrapped(nn.Module):
+    def __init__(self, net: LSTM):
+        super().__init__()
+        self._net = net
+
+    def forward(
+        self, x: torch.Tensor, hidden: _LSTMHiddenType, cell: _LSTMCellType
+    ) -> Tuple[torch.Tensor, _LSTMHiddenType, _LSTMCellType]:
+        """
+        N: Sequence length
+        L: Number of layers
+        DH: Hidden state dimension
+
+        :param x: (N,)
+        :param hidden: (L, DH)
+        :param cell: (L, DH)
+
+        :return: (N,), (L, DH), (L, DH)
+        """
+        return self._net.forward_onnx(x, hidden, cell)
+
+
 # TODO refactor together
 class _SkippyLSTM(nn.Module):
-    def __init__(self, input_size, hidden_size, skip_in: bool=False, num_layers=1, **kwargs):
+    def __init__(
+        self, input_size, hidden_size, skip_in: bool = False, num_layers=1, **kwargs
+    ):
         super().__init__()
         layers_per_lstm = 1
         self._skip_in = skip_in
         self._lstms = nn.ModuleList(
             [
                 _L(
-                    self._layer_input_size(input_size, hidden_size, i), 
-                    hidden_size, 
-                    layers_per_lstm, 
-                    batch_first=True
+                    self._layer_input_size(input_size, hidden_size, i),
+                    hidden_size,
+                    layers_per_lstm,
+                    batch_first=True,
                 )
                 for i in range(num_layers)
             ]
@@ -329,9 +406,9 @@ class _SkippyLSTM(nn.Module):
             torch.zeros((self.num_layers, layers_per_lstm, self.hidden_size))
         )
         self._initial_cell = nn.Parameter(
-            torch.zeros((self.num_layers, layers_per_lstm , self.hidden_size))
+            torch.zeros((self.num_layers, layers_per_lstm, self.hidden_size))
        )
-    
+
     @property
     def hidden_size(self):
         return self._lstms[0].hidden_size
@@ -360,16 +437,16 @@ class _SkippyLSTM(nn.Module):
         for layer, h0i, c0i in zip(self._lstms, h0, c0):
             if self._skip_in:
                 # TODO dense-block
-                layer_input = input if hidden is None else torch.cat([input, hidden], dim=2)
+                layer_input = (
+                    input if hidden is None else torch.cat([input, hidden], dim=2)
+                )
             else:
                 layer_input = input if hidden is None else hidden
             hidden, (hi, ci) = layer(layer_input, (h0i, c0i))
             hiddens.append(hidden)
             h_arr.append(hi)
             c_arr.append(ci)
-        return (
-            torch.cat(hiddens, dim=2), (torch.stack(h_arr), torch.stack(c_arr))
-        )
+        return (torch.cat(hiddens, dim=2), (torch.stack(h_arr), torch.stack(c_arr)))
 
     def initial_state(self, input: torch.Tensor):
         """
@@ -377,11 +454,11 @@ class _SkippyLSTM(nn.Module):
 
         :return: (L,B,Li,DH)
         """
-        assert input.ndim==3, "Batch only for now"
+        assert input.ndim == 3, "Batch only for now"
         batch_size = len(input)  # Assume batch_first
         return (
             torch.tile(self._initial_hidden[:, :, None], (1, 1, batch_size, 1)),
-            torch.tile(self._initial_cell[:, :, None], (1, 1, batch_size, 1))
+            torch.tile(self._initial_cell[:, :, None], (1, 1, batch_size, 1)),
        )
 
     def _layer_input_size(self, input_size, hidden_size, i) -> int:
diff --git a/requirements.txt b/requirements.txt
@@ -6,6 +6,8 @@ black
 flake8
 matplotlib
 numpy
+onnx
+onnxruntime
 pip
 pre-commit
 pytest
diff --git a/tests/test_nam/test_models/test_recurrent.py b/tests/test_nam/test_models/test_recurrent.py
@@ -2,18 +2,57 @@
 # Created Date: Sunday July 17th 2022
 # Author: Steven Atkinson (steven@atkinson.mn)
 
-from .base import Base
+from pathlib import Path
+from tempfile import TemporaryDirectory
+
+import onnx
+import onnxruntime
+import pytest
+import torch
 
 from nam.models import recurrent
 
+from .base import Base
+
 
 class TestLSTM(Base):
     @classmethod
     def setup_class(cls):
+        num_layers = 2
         hidden_size = 3
-        return super().setup_class(
+        super().setup_class(
             recurrent.LSTM,
             args=(hidden_size,),
-            kwargs={"train_burn_in": 3, "train_truncate": 5, "num_layers": 2},
+            kwargs={"train_burn_in": 3, "train_truncate": 5, "num_layers": num_layers},
         )
+        cls._num_layers = num_layers
+        cls._hidden_size = hidden_size
+
+    def test_export_onnx(self):
+        model = self._construct()
+        with TemporaryDirectory() as tmpdir:
+            filename = Path(tmpdir, "model.onnx")
+            model.export_onnx(filename)
+            onnx_model = onnx.load(filename)
+            session = onnxruntime.InferenceSession(str(filename))
+            onnx.checker.check_model(onnx_model)
+            wrapped_model = recurrent._ONNXWrapped(model)
+            x = torch.Tensor([0.5, -0.5, 0.4, -0.4, 0.3, -0.3, 0.2])
+            hin = torch.zeros((self._num_layers, self._hidden_size))
+            cin = torch.zeros((self._num_layers, self._hidden_size))
+
+            with torch.no_grad():
+                y_expected, hout_expected, cout_expected = [
+                    z.detach().cpu().numpy() for z in wrapped_model(x, hin, cin)
+                ]
+
+            input_names = [z.name for z in session.get_inputs()]
+            onnx_inputs = {
+                i: z.detach().cpu().numpy() for i, z in zip(input_names, (x, hin, cin))
+            }
+            y_actual, hout_actual, cout_actual = session.run([], onnx_inputs)
+
+            def approx(val):
+                return pytest.approx(val, rel=1.0e-6, abs=1.0e-6)
+
+            assert y_expected == approx(y_actual)
+            assert hout_expected == approx(hout_actual)
+            assert cout_expected == approx(cout_actual)
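
For reference, a minimal inference sketch (not part of this commit) against the exported model.onnx, mirroring the test above. The input and output names ("x", "hin", "cin"; "y", "hout", "cout") come from export_onnx; only "x" and "y" have a dynamic frame count, so the state shape can be read off the graph rather than hard-coded:

    import numpy as np
    import onnxruntime

    session = onnxruntime.InferenceSession("model.onnx")
    state_shape = session.get_inputs()[1].shape  # (num_layers, hidden_size)
    h = np.zeros(state_shape, dtype=np.float32)
    c = np.zeros(state_shape, dtype=np.float32)
    for _ in range(4):  # stream four buffers, carrying the state across calls
        x = np.random.randn(64).astype(np.float32)  # stand-in for audio frames
        y, h, c = session.run([], {"x": x, "hin": h, "cin": c})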