commit 90c2add4b399d8088c788d512f73baf02288e8bb
parent bd4dab08e65c1da803d0d6d18aac7cfb516dd96f
Author: Steven Atkinson <steven@atkinson.mn>
Date: Thu, 9 Nov 2023 18:19:53 -0800
[FEATURE] Dataset class registry (#339)
Dataset class registry
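Adds a registry mapping names to data set initializers so that new data set types can be plugged in without editing `init_dataset()`. A rough sketch of the intended usage (the `MyDataset` class and the "my_dataset" name are illustrative, mirroring the new test):

    from nam import data

    # A custom data set type (hypothetical); any class exposing an
    # init_from_config classmethod that accepts the config dict will do.
    class MyDataset(data.Dataset):
        pass

    # Register it under a name...
    data.register_dataset_initializer("my_dataset", MyDataset.init_from_config)

    # ...then request it from a data config via the new "type" key, e.g.
    # {"type": "my_dataset", "<split>": {"x_path": ..., "y_path": ..., ...}}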
Diffstat:
2 files changed, 77 insertions(+), 13 deletions(-)
diff --git a/nam/data.py b/nam/data.py
@@ -9,7 +9,7 @@ from copy import deepcopy
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
-from typing import Dict, Optional, Sequence, Tuple, Union
+from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
import numpy as np
import torch
@@ -679,11 +679,7 @@ class ConcatDataset(AbstractDataset, InitializableFromConfig):
@classmethod
def parse_config(cls, config):
- init = (
- ParametricDataset.init_from_config
- if config["parametric"]
- else Dataset.init_from_config
- )
+ init = _dataset_init_registry[config.get("type", "dataset")]
return {
"datasets": tuple(
init(c) for c in tqdm(config["dataset_configs"], desc="Loading data")
@@ -750,21 +746,55 @@ class ConcatDataset(AbstractDataset, InitializableFromConfig):
)
+_dataset_init_registry = {
+ "dataset": Dataset.init_from_config,
+ "parametric": ParametricDataset.init_from_config, # To be removed in v0.8
+}
+
+
+def register_dataset_initializer(
+ name: str, constructor: Callable[[Any], AbstractDataset]
+):
+ """
+ If you have other data set types, you can register their initializers by name using
+ this function.
+
+ For example, the basic NAM data set is registered by default under the name
+ "dataset", but if it weren't, you could register it like this:
+
+ >>> from nam import data
+ >>> data.register_dataset_initializer("dataset", data.Dataset.init_from_config)
+
+ :param name: The name that'll be used in the config to ask for the data set type
+ :param constructor: The constructor that'll be fed the config.
+ """
+ if name in _dataset_init_registry:
+ raise KeyError(
+ f"A constructor for dataset name '{name}' is already registered!"
+ )
+ _dataset_init_registry[name] = constructor
+
+
def init_dataset(config, split: Split) -> AbstractDataset:
- parametric = config.get("parametric", False)
+ if "parametric" in config:
+ logger.warning(
+ "Using the 'parametric' keyword is deprecated and will be removed in next "
+ "version. Instead, register the parametric dataset type using "
+ "`nam.data.register_dataset_initializer()` and then specify "
+ '`"type": "name"` in the config, using the name you registered.'
+ )
+ name = "parametric" if config["parametric"] else "dataset"
+ else:
+ name = config.get("type", "dataset")
base_config = config[split.value]
common = config.get("common", {})
if isinstance(base_config, dict):
- init = (
- ParametricDataset.init_from_config
- if parametric
- else Dataset.init_from_config
- )
+ init = _dataset_init_registry[name]
return init({**common, **base_config})
elif isinstance(base_config, list):
return ConcatDataset.init_from_config(
{
- "parametric": parametric,
+ "type": name,
"dataset_configs": [{**common, **c} for c in base_config],
}
)
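A sketch of what migrating an existing data config looks like under the deprecation path above (the split key and the train-split contents are placeholders; nothing is read from disk here):

    from nam import data

    # Placeholder per-split config; real configs carry whatever keys the chosen
    # data set class expects (paths, nx, ny, ...).
    train_cfg = {"x_path": "x.wav", "y_path": "y.wav", "nx": 8192, "ny": 32}

    # Before (deprecated): a boolean flag chose between Dataset and ParametricDataset.
    old_config = {"parametric": False, data.Split.TRAIN.value: train_cfg}

    # After: request a registered data set type by name. "dataset" and "parametric"
    # are registered by default (the latter slated for removal in v0.8).
    new_config = {"type": "dataset", data.Split.TRAIN.value: train_cfg}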
diff --git a/tests/test_nam/test_data.py b/tests/test_nam/test_data.py
@@ -352,5 +352,39 @@ def test_audio_mismatch_shapes_in_order():
assert e.shape_actual == (y_samples, num_channels)
+def test_register_dataset_initializer():
+ """
+ Assert that you can register and use new data set types
+ """
+
+ class MyDataset(data.Dataset):
+ pass
+
+ name = "my_dataset"
+
+ data.register_dataset_initializer(name, MyDataset.init_from_config)
+
+ x = np.random.rand(32) - 0.5
+ y = x
+ split = data.Split.TRAIN
+
+ with TemporaryDirectory() as tmpdir:
+ x_path = Path(tmpdir, "x.wav")
+ y_path = Path(tmpdir, "y.wav")
+ data.np_to_wav(x, x_path)
+ data.np_to_wav(y, y_path)
+ config = {
+ "type": name,
+ split.value: {
+ "x_path": str(x_path),
+ "y_path": str(y_path),
+ "nx": 3,
+ "ny": 2,
+ },
+ }
+ dataset = data.init_dataset(config, split)
+ assert isinstance(dataset, MyDataset)
+
+
if __name__ == "__main__":
pytest.main()