commit dde0baeb3b51cea4ba6fb5941d899f3448c6ed5f
parent 4e97b46cc24417a550d92a82bce2e1d012f10ed5
Author: Steven Atkinson <steven@atkinson.mn>
Date: Sun, 30 Apr 2023 11:52:58 -0700
Fix model device handling (#225)
* Models to CPU
* Update GPU kwargs
* End-to-end tests of CLI trainer
Diffstat:
 bin/export/main.py                      |   1 +
 bin/train/inputs/data/single_pair.json  |   4 ++++
 bin/train/inputs/data/two_pairs.json    |   4 ++++
 bin/train/inputs/learning/default.json  |   7 ++++++-
 bin/train/inputs/learning/demo.json     |   7 ++++++-
 bin/train/inputs/models/catlstm.json    |   5 ++++-
 bin/train/inputs/models/convnet.json    |   4 ++++
 bin/train/inputs/models/demonet.json    |   4 ++++
 bin/train/inputs/models/lstm.json       |   5 ++++-
 bin/train/inputs/models/wavenet.json    |   4 ++++
 bin/train/main.py                       |   2 ++
 tests/test_bin/__init__.py              |   0
 tests/test_bin/test_train/__init__.py   |   0
 tests/test_bin/test_train/test_main.py  | 203 +++++++++++++++++++++++++++++++
 14 files changed, 246 insertions(+), 4 deletions(-)
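Context for the "Update GPU kwargs" bullet: PyTorch Lightning deprecated the gpus= Trainer argument in favor of the accelerator=/devices= pair, which is why the learning configs below change shape. A minimal sketch of how a config's "trainer" dict maps onto a Trainer, assuming a recent pytorch_lightning; the device-picking helper is illustrative and not part of this patch:

import pytorch_lightning as pl
import torch

def pick_trainer_kwargs() -> dict:
    # Mirror the per-device kwargs used in the configs and tests below.
    if torch.cuda.is_available():
        return {"accelerator": "gpu", "devices": 1}
    if torch.backends.mps.is_available():
        return {"accelerator": "mps", "devices": 1}
    return {}  # Lightning defaults to CPU

# Equivalent of json.load()-ing a learning config and then calling
# pl.Trainer(**learning_config["trainer"]):
trainer = pl.Trainer(max_epochs=10, **pick_trainer_kwargs())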
diff --git a/bin/export/main.py b/bin/export/main.py
@@ -38,6 +38,7 @@ def main(args):
k: Param.init_from_config(v) for k, v in json.load(fp).items()
}
export_args = (outdir, param_config)
+ net.cpu()
net.eval()
outdir.mkdir(parents=True, exist_ok=True)
net.export(*export_args, include_snapshot=args.snapshot)
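The net.cpu() call above is the core of the "Models to CPU" bullet: export reads parameter tensors, and moving them back to host memory first keeps the exported weights device-agnostic when training ran on CUDA or MPS. A standalone sketch of the pattern with a toy module, not the project's actual exporter:

import torch

net = torch.nn.Linear(1, 1)
if torch.cuda.is_available():
    net = net.cuda()  # training may leave the model on an accelerator

net.cpu()   # move all parameters and buffers back to host memory
net.eval()  # inference mode before serializing

# Safe now; .numpy() raises a TypeError on CUDA tensors, which is the
# kind of failure the .cpu() call above guards against.
weights = [p.detach().numpy() for p in net.parameters()]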
diff --git a/bin/train/inputs/data/single_pair.json b/bin/train/inputs/data/single_pair.json
@@ -1,4 +1,8 @@
{
+ "_notes": [
+ "Dev note: Ensure that tests/test_bin/test_train/test_main.py's data is ",
+ "representative of this!"
+ ],
"train": {
"start": null,
"stop": -432000,
diff --git a/bin/train/inputs/data/two_pairs.json b/bin/train/inputs/data/two_pairs.json
@@ -1,4 +1,8 @@
{
+ "_notes": [
+ "Dev note: Ensure that tests/test_bin/test_train/test_main.py's data is ",
+ "representative of this!"
+ ],
"train": {
"x_path": "C:\\path\\to\\train\\source.wav",
"y_path": "C:\\path\\to\\train\\target.wav",
diff --git a/bin/train/inputs/learning/default.json b/bin/train/inputs/learning/default.json
@@ -1,4 +1,8 @@
{
+ "_notes": [
+ "Dev note: Ensure that tests/test_bin/test_train/test_main.py's data is ",
+ "representative of this!"
+ ],
"train_dataloader": {
"batch_size": 16,
"shuffle": true,
@@ -8,7 +12,8 @@
},
"val_dataloader": {},
"trainer": {
- "gpus": 1,
+ "accelerator": "gpu",
+ "devices": 1,
"max_epochs": 100
},
"trainer_fit_kwargs": {}
diff --git a/bin/train/inputs/learning/demo.json b/bin/train/inputs/learning/demo.json
@@ -1,4 +1,8 @@
{
+ "_notes": [
+ "Dev note: Ensure that tests/test_bin/test_train/test_main.py's data is ",
+ "representative of this!"
+ ],
"train_dataloader": {
"batch_size": 16,
"shuffle": true,
@@ -8,7 +12,8 @@
},
"val_dataloader": {},
"trainer": {
- "gpus": 0,
+ "accelerator": "gpu",
+ "devices": 1,
"max_epochs": 10
},
"trainer_fit_kwargs": {}
diff --git a/bin/train/inputs/models/catlstm.json b/bin/train/inputs/models/catlstm.json
@@ -7,7 +7,10 @@
" non-parametric version, even if you're modeling a fair number of knobs.",
" * You'll probably have a much larger dataset, so validating every so often ",
" in steps instead of epochs helps. Make sure to also set val_check_interval",
- " under the trainer dict in your learning config JSON."
+ " under the trainer dict in your learning config JSON.",
+ "",
+ "Dev note: Ensure that tests/test_bin/test_train/test_main.py's data is ",
+ "representative of this!"
],
"net": {
"name": "CatLSTM",
diff --git a/bin/train/inputs/models/convnet.json b/bin/train/inputs/models/convnet.json
@@ -1,4 +1,8 @@
{
+ "_notes": [
+ "Dev note: Ensure that tests/test_bin/test_train/test_main.py's data is ",
+ "representative of this!"
+ ],
"net": {
"name": "ConvNet",
"config": {
diff --git a/bin/train/inputs/models/demonet.json b/bin/train/inputs/models/demonet.json
@@ -1,4 +1,8 @@
{
+ "_notes": [
+ "Dev note: Ensure that tests/test_bin/test_train/test_main.py's data is ",
+ "representative of this!"
+ ],
"net": {
"name": "WaveNet",
"config": {
diff --git a/bin/train/inputs/models/lstm.json b/bin/train/inputs/models/lstm.json
@@ -9,7 +9,10 @@
" 1e-4 after 1000 epochs. I've found LSTMs to work with a pretty aggressive",
" learning rate that would be out of the question for other architectures.",
" * Number of units between 8 and 96, layers from 1 to 5 all seem to be ok",
- " depending on the dataset, though bigger models might not make real-time."
+ " depending on the dataset, though bigger models might not make real-time.",
+ "",
+ "Dev note: Ensure that tests/test_bin/test_train/test_main.py's data is ",
+ "representative of this!"
],
"net": {
"name": "LSTM",
diff --git a/bin/train/inputs/models/wavenet.json b/bin/train/inputs/models/wavenet.json
@@ -1,4 +1,8 @@
{
+ "_notes": [
+ "Dev note: Ensure that tests/test_bin/test_train/test_main.py's data is ",
+ "representative of this!"
+ ],
"net": {
"name": "WaveNet",
"config": {
diff --git a/bin/train/main.py b/bin/train/main.py
@@ -2,6 +2,7 @@
# Created Date: Saturday February 5th 2022
# Author: Steven Atkinson (steven@atkinson.mn)
+
# Hack to recover graceful shutdowns in Windows.
# This has to happen ASAP
# See:
@@ -201,6 +202,7 @@ def main_inner(
trainer.checkpoint_callback.best_model_path,
**Model.parse_config(model_config),
)
+ model.cpu()
model.eval()
if make_plots:
plot(
diff --git a/tests/test_bin/__init__.py b/tests/test_bin/__init__.py
diff --git a/tests/test_bin/test_train/__init__.py b/tests/test_bin/test_train/__init__.py
diff --git a/tests/test_bin/test_train/test_main.py b/tests/test_bin/test_train/test_main.py
@@ -0,0 +1,203 @@
+# File: test_main.py
+# Created Date: Sunday April 30th 2023
+# Author: Steven Atkinson (steven@atkinson.mn)
+
+import json
+from enum import Enum
+from pathlib import Path
+from subprocess import check_call
+from tempfile import TemporaryDirectory
+from typing import Dict, Tuple, Union
+
+import numpy as np
+import pytest
+import torch
+
+from nam.data import np_to_wav
+
+_BIN_TRAIN_MAIN_PY_PATH = Path(__file__).absolute().parent.parent.parent.parent / Path(
+ "bin", "train", "main.py"
+)
+
+
+class _Device(Enum):
+ CPU = "cpu"
+ GPU = "gpu"
+ MPS = "mps"
+
+
+class Test(object):
+ @classmethod
+ def setup_class(cls):
+ cls._num_samples = 128
+ cls._num_samples_validation = 15
+ cls._ny = 2
+ cls._batch_size = 2
+
+ def test_cpu(self):
+ self._t_main(_Device.CPU)
+
+ @pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU test")
+ def test_gpu(self):
+ self._t_main(_Device.GPU)
+
+ @pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS test")
+ def test_mps(self):
+ self._t_main(_Device.MPS)
+
+ @classmethod
+ def _data_config_path(cls, root_path: Path) -> Path:
+ return Path(cls._input_path(root_path), "data_config.json")
+
+ def _get_configs(
+ self, root_path: Path, device: _Device
+ ) -> Tuple[Dict, Dict, Dict]: # TODO pydantic models
+ data_config = {
+ "train": {
+ "start": None,
+ "stop": -self._num_samples_validation,
+ "ny": self._ny,
+ },
+ "validation": {
+ "start": -self._num_samples_validation,
+ "stop": None,
+ "ny": None,
+ },
+ "common": {
+ "x_path": str(self._x_path(root_path)),
+ "y_path": str(self._y_path(root_path)),
+ "delay": 0,
+ },
+ }
+ stage_channels = (3, 2)
+ model_config = {
+ "net": {
+ "name": "WaveNet",
+ "config": {
+ "layers_configs": [
+ {
+ "condition_size": 1,
+ "input_size": 1,
+ "channels": stage_channels[0],
+ "head_size": stage_channels[1],
+ "kernel_size": 3,
+ "dilations": [1],
+ "activation": "Tanh",
+ "gated": False,
+ "head_bias": False,
+ },
+ {
+ "condition_size": 1,
+ "input_size": stage_channels[0],
+ "channels": stage_channels[1],
+ "head_size": 1,
+ "kernel_size": 3,
+ "dilations": [2],
+ "activation": "Tanh",
+ "gated": False,
+ "head_bias": False,
+ },
+ ],
+ "head_scale": 0.02,
+ },
+ },
+ "optimizer": {"lr": 0.004},
+ "lr_scheduler": {"class": "ExponentialLR", "kwargs": {"gamma": 0.993}},
+ }
+
+ def extra_trainer_kwargs(device) -> Dict[str, Union[int, str]]:
+ return {
+ _Device.GPU: {"accelerator": "gpu", "devices": 1},
+ _Device.MPS: {"accelerator": "mps", "devices": 1},
+ }.get(device, {})
+
+ learning_config = {
+ "train_dataloader": {
+ "batch_size": 3,
+ "shuffle": True,
+ "pin_memory": True,
+ "drop_last": True,
+ "num_workers": 0,
+ },
+ "val_dataloader": {},
+ "trainer": {"max_epochs": 2, **extra_trainer_kwargs(device)},
+ "trainer_fit_kwargs": {},
+ }
+
+ return data_config, model_config, learning_config
+
+ def _get_data(self) -> Tuple[np.ndarray, np.ndarray]:
+ """
+ :return: (N,), (N,)
+ """
+ x = np.random.rand(self._num_samples) - 0.5
+ y = 1.1 * x
+ return x, y
+
+ @classmethod
+ def _input_path(cls, root_path: Path, ensure: bool = False) -> Path:
+ p = Path(root_path, "inputs")
+ if ensure:
+ p.mkdir()
+ return p
+
+ @classmethod
+ def _learning_config_path(cls, root_path: Path) -> Path:
+ return Path(cls._input_path(root_path), "learning_config.json")
+
+ @classmethod
+ def _model_config_path(cls, root_path: Path) -> Path:
+ return Path(cls._input_path(root_path), "model_config.json")
+
+ @classmethod
+ def _output_path(cls, root_path: Path, ensure: bool = False) -> Path:
+ p = Path(root_path, "outputs")
+ if ensure:
+ p.mkdir()
+ return p
+
+ def _setup_files(self, root_path: Path, device: _Device):
+ x, y = self._get_data()
+ np_to_wav(x, self._x_path(root_path))
+ np_to_wav(y, self._y_path(root_path))
+ data_config, model_config, learning_config = self._get_configs(
+ root_path, device
+ )
+ with open(self._data_config_path(root_path), "w") as fp:
+ json.dump(data_config, fp)
+ with open(self._model_config_path(root_path), "w") as fp:
+ json.dump(model_config, fp)
+ with open(self._learning_config_path(root_path), "w") as fp:
+ json.dump(learning_config, fp)
+
+ def _t_main(self, device: _Device):
+ """
+ End-to-end test of bin/train/main.py
+ """
+ with TemporaryDirectory() as tempdir:
+ tempdir = Path(tempdir)
+ self._input_path(tempdir, ensure=True)
+ self._setup_files(tempdir, device)
+ check_call(
+ [
+ "python",
+ str(_BIN_TRAIN_MAIN_PY_PATH),
+ str(self._data_config_path(tempdir)),
+ str(self._model_config_path(tempdir)),
+ str(self._learning_config_path(tempdir)),
+ str(self._output_path(tempdir, ensure=True)),
+ "--no-show",
+ ]
+ )
+
+ @classmethod
+ def _x_path(cls, root_path: Path) -> Path:
+ return Path(cls._input_path(root_path), "input.wav")
+
+ @classmethod
+ def _y_path(cls, root_path: Path) -> Path:
+ return Path(cls._input_path(root_path), "output.wav")
+
+
+if __name__ == "__main__":
+ pytest.main()
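
The new suite exercises the CLI trainer end to end: it writes a short input/output wav pair plus the three config JSONs into a temporary directory, then invokes bin/train/main.py as a subprocess with --no-show. It can be run directly with pytest tests/test_bin/test_train/test_main.py; the GPU and MPS cases skip automatically on machines without the corresponding hardware.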