neural-amp-modeler

Neural network emulator for guitar amplifiers
Log | Files | Refs | README | LICENSE

commit 90c2add4b399d8088c788d512f73baf02288e8bb
parent bd4dab08e65c1da803d0d6d18aac7cfb516dd96f
Author: Steven Atkinson <steven@atkinson.mn>
Date:   Thu,  9 Nov 2023 18:19:53 -0800

[FEATURE] Dataset class registry (#339)

Dataset class registry
Diffstat:
Mnam/data.py | 56+++++++++++++++++++++++++++++++++++++++++++-------------
Mtests/test_nam/test_data.py | 34++++++++++++++++++++++++++++++++++
2 files changed, 77 insertions(+), 13 deletions(-)

diff --git a/nam/data.py b/nam/data.py @@ -9,7 +9,7 @@ from copy import deepcopy from dataclasses import dataclass from enum import Enum from pathlib import Path -from typing import Dict, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union import numpy as np import torch @@ -679,11 +679,7 @@ class ConcatDataset(AbstractDataset, InitializableFromConfig): @classmethod def parse_config(cls, config): - init = ( - ParametricDataset.init_from_config - if config["parametric"] - else Dataset.init_from_config - ) + init = _dataset_init_registry[config.get("type", "dataset")] return { "datasets": tuple( init(c) for c in tqdm(config["dataset_configs"], desc="Loading data") @@ -750,21 +746,55 @@ class ConcatDataset(AbstractDataset, InitializableFromConfig): ) +_dataset_init_registry = { + "dataset": Dataset.init_from_config, + "parametric": ParametricDataset.init_from_config, # To be removed in v0.8 +} + + +def register_dataset_initializer( + name: str, constructor: Callable[[Any], AbstractDataset] +): + """ + If you have otehr data set types, you can register their initializer by name using + this. + + For example, the basic NAM is registered by default under the name "default", but if + it weren't, you could register it like this: + + >>> from nam import data + >>> data.register_dataset_initializer("parametric", data.Dataset.init_from_config) + + :param name: The name that'll be used in the config to ask for the data set type + :param constructor: The constructor that'll be fed the config. + """ + if name in _dataset_init_registry: + raise KeyError( + f"A constructor for dataset name '{name}' is already registered!" + ) + _dataset_init_registry[name] = constructor + + def init_dataset(config, split: Split) -> AbstractDataset: - parametric = config.get("parametric", False) + if "parametric" in config: + logger.warning( + "Using the 'parametric' keyword is deprecated and will be removed in next " + "version. Instead, register the parametric dataset type using " + "`nam.data.register_dataset_initializer()` and then specify " + '`"type": "name"` in the config, using the name you registered.' + ) + name = "parametric" if config["parametric"] else "dataset" + else: + name = config.get("type", "dataset") base_config = config[split.value] common = config.get("common", {}) if isinstance(base_config, dict): - init = ( - ParametricDataset.init_from_config - if parametric - else Dataset.init_from_config - ) + init = _dataset_init_registry[name] return init({**common, **base_config}) elif isinstance(base_config, list): return ConcatDataset.init_from_config( { - "parametric": parametric, + "type": name, "dataset_configs": [{**common, **c} for c in base_config], } ) diff --git a/tests/test_nam/test_data.py b/tests/test_nam/test_data.py @@ -352,5 +352,39 @@ def test_audio_mismatch_shapes_in_order(): assert e.shape_actual == (y_samples, num_channels) +def test_register_dataset_initializer(): + """ + Assert that you can add and use new data sets + """ + + class MyDataset(data.Dataset): + pass + + name = "my_dataset" + + data.register_dataset_initializer(name, MyDataset.init_from_config) + + x = np.random.rand(32) - 0.5 + y = x + split = data.Split.TRAIN + + with TemporaryDirectory() as tmpdir: + x_path = Path(tmpdir, "x.wav") + y_path = Path(tmpdir, "y.wav") + data.np_to_wav(x, x_path) + data.np_to_wav(y, y_path) + config = { + "type": name, + split.value: { + "x_path": str(x_path), + "y_path": str(y_path), + "nx": 3, + "ny": 2, + }, + } + dataset = data.init_dataset(config, split) + assert isinstance(dataset, MyDataset) + + if __name__ == "__main__": pytest.main()