commit f8509336e28eb292557c9abb8af019cb31892543
parent 148ca8f7c7a197df18b91551f10fb9a215dfc252
Author: Steven Atkinson <steven@atkinson.mn>
Date: Thu, 16 Feb 2023 20:13:21 -0800
ONNX support for LSTMs (#93)
* Implement ONNX exporting for LSTMs
* ONNX export option in export script
* Fix some bugs
* typing fix
Diffstat:
7 files changed, 203 insertions(+), 61 deletions(-)
diff --git a/bin/export/main.py b/bin/export/main.py
@@ -40,8 +40,13 @@ def main(args):
export_args = (outdir, param_config)
net.eval()
outdir.mkdir(parents=True, exist_ok=True)
- net.export(*export_args, include_snapshot=args.include_snapshot)
- net.export_cpp_header(Path(export_args[0], "HardCodedModel.h"), *export_args[1:])
+ net.export(*export_args, include_snapshot=args.snapshot)
+ if args.cpp:
+ net.export_cpp_header(
+ Path(export_args[0], "HardCodedModel.h"), *export_args[1:]
+ )
+ if args.onnx:
+ net.export_onnx(Path(outdir, "model.onnx"))
if __name__ == "__main__":
@@ -50,12 +55,18 @@ if __name__ == "__main__":
parser.add_argument("checkpoint", type=str)
parser.add_argument("outdir")
parser.add_argument(
- "--include-snapshot",
+ "--param-config", type=str, help="Configuration for a parametric model"
+ )
+ parser.add_argument("--onnx", action="store_true", help="Export an ONNX model")
+ parser.add_argument(
+ "--cpp", action="store_true", help="Export a CPP header for hard-coding a model"
+ )
+ parser.add_argument(
+ "--snapshot",
"-s",
+ action="store_true",
help="Computes an example input-output pair for the model for debugging "
"purposes",
)
- parser.add_argument(
- "--param-config", type=str, help="Configuration for a parametric model"
- )
+
main(parser.parse_args())
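With these flags in place, a typical export invocation looks like the following (checkpoint and output paths are placeholders):

    python bin/export/main.py path/to/checkpoint.ckpt path/to/outdir --onnx --snapshot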
diff --git a/environment_cpu.yml b/environment_cpu.yml
@@ -21,6 +21,8 @@ dependencies:
- tqdm
- wheel
- pip:
+ - onnx
+ - onnxruntime
- pre-commit
- pytorch_lightning
- sounddevice
diff --git a/environment_gpu.yml b/environment_gpu.yml
@@ -23,6 +23,8 @@ dependencies:
- tqdm
- wheel
- pip:
+ - onnx
+ - onnxruntime # TODO GPU...
- pre-commit
- pytorch_lightning
- sounddevice
diff --git a/nam/models/_exportable.py b/nam/models/_exportable.py
@@ -68,6 +68,15 @@ class Exportable(abc.ABC):
"""
pass
+ def export_onnx(self, filename: Path):
+ """
+ Export the model to ONNX format for use with ONNX Runtime
+ """
+ raise NotImplementedError(
+ "Exporting to ONNX is not supported for models of type "
+ f"{self.__class__.__name__}"
+ )
+
@abc.abstractmethod
def _export_config(self):
"""
diff --git a/nam/models/recurrent.py b/nam/models/recurrent.py
@@ -25,9 +25,10 @@ from ._base import BaseNet
class _L(nn.LSTM):
"""
- Tweaks to LSTM
+ Tweaks to PyTorch LSTM module
* Up the remembering
"""
+
def reset_parameters(self) -> None:
super().reset_parameters()
# https://danijar.com/tips-for-training-recurrent-neural-networks/
@@ -39,21 +40,27 @@ class _L(nn.LSTM):
for layer in range(self.num_layers):
for input in ("i", "h"):
# Balance out the scale of the cell w/ a -=1
- getattr(self, f"bias_{input}h_l{layer}").data[
- idx_input
- ] -= value
- getattr(self, f"bias_{input}h_l{layer}").data[
- idx_forget
- ] += value
+ getattr(self, f"bias_{input}h_l{layer}").data[idx_input] -= value
+ getattr(self, f"bias_{input}h_l{layer}").data[idx_forget] += value
+
+
+# State:
+# L: Number of LSTM layers
+# DH: Hidden state dimension
+# [0]: hidden (L,DH)
+# [1]: cell (L,DH)
+_LSTMHiddenType = torch.Tensor
+_LSTMCellType = torch.Tensor
+_LSTMHiddenCellType = Tuple[_LSTMHiddenType, _LSTMCellType]
class LSTMCore(_L):
def __init__(
- self,
- *args,
+ self,
+ *args,
train_burn_in: Optional[int] = None,
- train_truncate: Optional[int] = None,
- **kwargs
+ train_truncate: Optional[int] = None,
+ **kwargs,
):
super().__init__(*args, **kwargs)
if not self.batch_first:
@@ -77,8 +84,9 @@ class LSTMCore(_L):
if x.ndim != 3:
raise NotImplementedError("Need (B,L,D)")
last_hidden_state = (
- self._initial_state(None if x.ndim == 2 else len(x))
- if hidden_state is None else hidden_state
+ self._initial_state(None if x.ndim == 2 else len(x))
+ if hidden_state is None
+ else hidden_state
)
if not self.training or self._train_truncate is None:
output_features = super().forward(x, last_hidden_state)[0]
@@ -95,21 +103,23 @@ class LSTMCore(_L):
# Don't detach the burn-in state so that we can learn it.
last_hidden_state = tuple(z.detach() for z in last_hidden_state)
last_output_features, last_hidden_state = super().forward(
- x[:, i : i + self._train_truncate, :,],
- last_hidden_state,
+ x[:, i : i + self._train_truncate, :], last_hidden_state
)
output_features_list.append(last_output_features)
output_features = torch.cat(output_features_list, dim=1)
return output_features
- def _initial_state(self, n: Optional[int]) -> Tuple[torch.Tensor, torch.Tensor]:
- return (self._initial_hidden, self._initial_cell) if n is None else (
- torch.tile(self._initial_hidden[:, None], (1, n, 1)),
- torch.tile(self._initial_cell[:, None], (1, n, 1))
+ def _initial_state(self, n: Optional[int]) -> _LSTMHiddenCellType:
+ return (
+ (self._initial_hidden, self._initial_cell)
+ if n is None
+ else (
+ torch.tile(self._initial_hidden[:, None], (1, n, 1)),
+ torch.tile(self._initial_cell[:, None], (1, n, 1)),
+ )
)
-
class LSTM(BaseNet):
"""
ABC for recurrent architectures
@@ -125,9 +135,9 @@ class LSTM(BaseNet):
):
"""
:param hidden_size: for LSTM
- :param train_burn_in: Detach calculations from first (this many) samples when
+ :param train_burn_in: Detach calculations from first (this many) samples when
training to burn in the hidden state.
- :param train_truncate: detach the hidden & cell states every this many steps
+ :param train_truncate: detach the hidden & cell states every this many steps
during training so that backpropagation through time is faster + to simulate
better starting states for h(t0)&c(t0) (instead of zeros)
TODO recognition head to start the hidden state in a good place?
@@ -138,9 +148,7 @@ class LSTM(BaseNet):
if "batch_first" in lstm_kwargs:
raise ValueError("batch_first cannot be set.")
self._input_size = input_size
- self._core = _L(
- self._input_size, hidden_size, batch_first=True, **lstm_kwargs
- )
+ self._core = _L(self._input_size, hidden_size, batch_first=True, **lstm_kwargs)
self._head = nn.Linear(hidden_size, 1)
self._train_burn_in = train_burn_in
self._train_truncate = train_truncate
@@ -164,7 +172,7 @@ class LSTM(BaseNet):
with TemporaryDirectory() as tmpdir:
tmpdir = Path(tmpdir)
LSTM.export(self, Path(tmpdir)) # Hacky...need to work w/ CatLSTM
- with open(Path(tmpdir, "config.json"), "r") as fp:
+ with open(Path(tmpdir, "model.nam"), "r") as fp:
_c = json.load(fp)
version = _c["version"]
config = _c["config"]
@@ -185,20 +193,61 @@ class LSTM(BaseNet):
+ s_parametric
+ (
"std::vector<float> PARAMS{"
- + ", ".join(
- [f"{w:.16f}f" for w in np.load(Path(tmpdir, "weights.npy"))]
- )
+ + ", ".join([f"{w:.16f}f" for w in _c["weights"]])
+ "};\n",
)
)
- def _forward(self, x):
+ def export_onnx(self, filename: Path):
+ if self._input_size != 1:
+ raise NotImplementedError("Multi-dimensional inputs not supported yet")
+ o = _ONNXWrapped(self)
+ x = torch.randn((64,)) # (S,)
+ h, c = [z[:, 0, :] for z in self._initial_state(1)] # (L,DH), (L,DH)
+ torch.onnx.export(
+ o,
+ (x, h, c),
+ filename,
+ input_names=["x", "hin", "cin"],
+ output_names=["y", "hout", "cout"],
+ dynamic_axes={"x": {0: "num_frames"}, "y": {0: "num_frames"}},
+ )
+
+ def forward_onnx(
+ self, x: torch.Tensor, h: _LSTMHiddenType, c: _LSTMCellType
+ ) -> Tuple[torch.Tensor, _LSTMHiddenType, _LSTMCellType]:
+ """
+ Forward pass used by ONNX export
+ Only supports scalar inputs right now.
+
+ N: Sequence length
+ L: Number of layers
+ DH: Hidden state dimension
+
+ :param x: (N,)
+ :param h: hidden state, (L, DH)
+ :param c: cell state, (L, DH)
+
+ :return: (N,), (L, DH), (L, DH)
+ """
+ features, (h, c) = self._core(x[None, :, None], (h[:, None, :], c[:, None, :]))
+ y = self._apply_head(features) # (1,S)
+ return y[0, :], h[:, 0, :], c[:, 0, :]
+
+ def _apply_head(self, features: torch.Tensor) -> torch.Tensor:
+ """
+ :param features: (B,S,DH)
+ :return: (B,S)
+ """
+ return self._head(features)[:, :, 0]
+
+ def _forward(self, x: torch.Tensor) -> torch.Tensor:
"""
:param x: (B,L) or (B,L,D)
:return: (B,L)
"""
last_hidden_state = self._initial_state(len(x))
- if x.ndim==2:
+ if x.ndim == 2:
x = x[:, :, None]
if not self.training or self._train_truncate is None:
output_features = self._core(x, last_hidden_state)[0]
@@ -215,12 +264,11 @@ class LSTM(BaseNet):
# Don't detach the burn-in state so that we can learn it.
last_hidden_state = tuple(z.detach() for z in last_hidden_state)
last_output_features, last_hidden_state = self._core(
- x[:, i : i + self._train_truncate, :,],
- last_hidden_state,
+ x[:, i : i + self._train_truncate, :], last_hidden_state
)
output_features_list.append(last_output_features)
output_features = torch.cat(output_features_list, dim=1)
- return self._head(output_features)[:, :, 0]
+ return self._apply_head(output_features)
def _export_cell_weights(
self, i: int, hidden_state: torch.Tensor, cell_state: torch.Tensor
@@ -281,7 +329,7 @@ class LSTM(BaseNet):
]
)
- def _get_initial_state(self, inputs=None) -> Tuple[torch.Tensor, torch.Tensor]:
+ def _get_initial_state(self, inputs=None) -> _LSTMHiddenCellType:
"""
Convenience function to find a good hidden state to start the plugin at
@@ -296,31 +344,60 @@ class LSTM(BaseNet):
_, (h, c) = self._core(inputs)
return h, c
- def _initial_state(self, n: Optional[int]) -> Tuple[torch.Tensor, torch.Tensor]:
+ def _initial_state(self, n: Optional[int]) -> _LSTMHiddenCellType:
"""
Literally what the forward pass starts with.
Default is zeroes; this should be better since it can be learned.
"""
- return (self._initial_hidden, self._initial_cell) if n is None else (
- torch.tile(self._initial_hidden[:, None], (1, n, 1)),
- torch.tile(self._initial_cell[:, None], (1, n, 1))
+ return (
+ (self._initial_hidden, self._initial_cell)
+ if n is None
+ else (
+ torch.tile(self._initial_hidden[:, None], (1, n, 1)),
+ torch.tile(self._initial_cell[:, None], (1, n, 1)),
+ )
)
+class _ONNXWrapped(nn.Module):
+ def __init__(self, net: LSTM):
+ super().__init__()
+ self._net = net
+
+ def forward(
+ self, x: torch.Tensor, hidden: _LSTMHiddenType, cell: _LSTMCellType
+ ) -> Tuple[torch.Tensor, _LSTMHiddenType, _LSTMCellType]:
+ """
+ N: Sequence length
+ L: Number of layers
+ DH: Hidden state dimension
+
+ :param x: (N,)
+ :param hidden: (L, DH)
+ :param cell: (L, DH)
+
+ :return: (N,), (L, DH), (L, DH)
+ """
+ return self._net.forward_onnx(x, hidden, cell)
+
+
# TODO refactor together
+
class _SkippyLSTM(nn.Module):
- def __init__(self, input_size, hidden_size, skip_in: bool=False, num_layers=1, **kwargs):
+ def __init__(
+ self, input_size, hidden_size, skip_in: bool = False, num_layers=1, **kwargs
+ ):
super().__init__()
layers_per_lstm = 1
self._skip_in = skip_in
self._lstms = nn.ModuleList(
[
_L(
- self._layer_input_size(input_size, hidden_size, i),
- hidden_size,
- layers_per_lstm,
- batch_first=True
+ self._layer_input_size(input_size, hidden_size, i),
+ hidden_size,
+ layers_per_lstm,
+ batch_first=True,
)
for i in range(num_layers)
]
@@ -329,9 +406,9 @@ class _SkippyLSTM(nn.Module):
torch.zeros((self.num_layers, layers_per_lstm, self.hidden_size))
)
self._initial_cell = nn.Parameter(
- torch.zeros((self.num_layers, layers_per_lstm , self.hidden_size))
+ torch.zeros((self.num_layers, layers_per_lstm, self.hidden_size))
)
-
+
@property
def hidden_size(self):
return self._lstms[0].hidden_size
@@ -360,16 +437,16 @@ class _SkippyLSTM(nn.Module):
for layer, h0i, c0i in zip(self._lstms, h0, c0):
if self._skip_in:
# TODO dense-block
- layer_input = input if hidden is None else torch.cat([input, hidden], dim=2)
+ layer_input = (
+ input if hidden is None else torch.cat([input, hidden], dim=2)
+ )
else:
layer_input = input if hidden is None else hidden
hidden, (hi, ci) = layer(layer_input, (h0i, c0i))
hiddens.append(hidden)
h_arr.append(hi)
c_arr.append(ci)
- return (
- torch.cat(hiddens, dim=2), (torch.stack(h_arr), torch.stack(c_arr))
- )
+ return (torch.cat(hiddens, dim=2), (torch.stack(h_arr), torch.stack(c_arr)))
def initial_state(self, input: torch.Tensor):
"""
@@ -377,11 +454,11 @@ class _SkippyLSTM(nn.Module):
:return: (L,B,Li,DH)
"""
- assert input.ndim==3, "Batch only for now"
+ assert input.ndim == 3, "Batch only for now"
batch_size = len(input) # Assume batch_first
return (
torch.tile(self._initial_hidden[:, :, None], (1, 1, batch_size, 1)),
- torch.tile(self._initial_cell[:, :, None], (1, 1, batch_size, 1))
+ torch.tile(self._initial_cell[:, :, None], (1, 1, batch_size, 1)),
)
def _layer_input_size(self, input_size, hidden_size, i) -> int:
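Since export_onnx exposes the hidden and cell states as explicit inputs and outputs ("hin"/"cin", "hout"/"cout") and marks "num_frames" as a dynamic axis on "x", the exported model can be run in a streaming fashion, carrying the state across chunks. A minimal sketch with ONNX Runtime, assuming a model already exported to model.onnx (the num_layers, hidden_size, and chunk-size values are illustrative, not part of this commit):

    import numpy as np
    import onnxruntime

    session = onnxruntime.InferenceSession("model.onnx")
    num_layers, hidden_size = 1, 16  # must match the exported network
    h = np.zeros((num_layers, hidden_size), dtype=np.float32)
    c = np.zeros((num_layers, hidden_size), dtype=np.float32)
    audio = np.zeros(4096, dtype=np.float32)  # placeholder input signal
    ys = []
    for chunk in np.split(audio, 8):  # any chunk size works: "num_frames" is dynamic
        # Outputs come back in the order declared in export_onnx: y, hout, cout
        y, h, c = session.run([], {"x": chunk, "hin": h, "cin": c})
        ys.append(y)
    y = np.concatenate(ys)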
diff --git a/requirements.txt b/requirements.txt
@@ -6,6 +6,8 @@ black
flake8
matplotlib
numpy
+onnx
+onnxruntime
pip
pre-commit
pytest
diff --git a/tests/test_nam/test_models/test_recurrent.py b/tests/test_nam/test_models/test_recurrent.py
@@ -2,18 +2,57 @@
# Created Date: Sunday July 17th 2022
# Author: Steven Atkinson (steven@atkinson.mn)
-from .base import Base
+from pathlib import Path
+from tempfile import TemporaryDirectory
+
+import onnx
+import onnxruntime
+import pytest
+import torch
from nam.models import recurrent
+from .base import Base
+
class TestLSTM(Base):
@classmethod
def setup_class(cls):
+ num_layers = 2
hidden_size = 3
- return super().setup_class(
+ super().setup_class(
recurrent.LSTM,
args=(hidden_size,),
- kwargs={"train_burn_in": 3, "train_truncate": 5, "num_layers": 2},
+ kwargs={"train_burn_in": 3, "train_truncate": 5, "num_layers": num_layers},
)
+ cls._num_layers = num_layers
+ cls._hidden_size = hidden_size
+
+ def test_export_onnx(self):
+ model = self._construct()
+ with TemporaryDirectory() as tmpdir:
+ filename = Path(tmpdir, "model.onnx")
+ model.export_onnx(filename)
+ onnx_model = onnx.load(filename)
+ session = onnxruntime.InferenceSession(str(filename))
+ onnx.checker.check_model(onnx_model)
+ wrapped_model = recurrent._ONNXWrapped(model)
+ x = torch.Tensor([0.5, -0.5, 0.4, -0.4, 0.3, -0.3, 0.2])
+ hin = torch.zeros((self._num_layers, self._hidden_size))
+ cin = torch.zeros((self._num_layers, self._hidden_size))
+
+ with torch.no_grad():
+ y_expected, hout_expected, cout_expected = [
+ z.detach().cpu().numpy() for z in wrapped_model(x, hin, cin)
+ ]
+
+ input_names = [z.name for z in session.get_inputs()]
+ onnx_inputs = {
+ i: z.detach().cpu().numpy() for i, z in zip(input_names, (x, hin, cin))
+ }
+ y_actual, hout_actual, cout_actual = session.run([], onnx_inputs)
+
+ def approx(val):
+ return pytest.approx(val, rel=1.0e-6, abs=1.0e-6)
+
+ assert y_expected == approx(y_actual)
+ assert hout_expected == approx(hout_actual)
+ assert cout_expected == approx(cout_actual)
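The ONNX round-trip test can be run on its own with standard pytest selection:

    pytest tests/test_nam/test_models/test_recurrent.py -k test_export_onnx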