commit bdeed78ff1a52ea9c639f03dc6ef1d316ea47a4b
parent aec1c69bc26d20aa64b65463c11894eed03ec87b
Author: Steven Atkinson <steven@atkinson.mn>
Date: Sun, 14 Jan 2024 12:11:17 -0800
[FEATURE] start and stop trimming for datasets in seconds instead of samples (#363)
* Add more start/stop options for datasets
* Simplify nam.data.Dataset.parse_config()
Diffstat:
2 files changed, 167 insertions(+), 40 deletions(-)
diff --git a/nam/data.py b/nam/data.py
@@ -229,6 +229,10 @@ class Dataset(AbstractDataset, InitializableFromConfig):
ny: Optional[int],
start: Optional[int] = None,
stop: Optional[int] = None,
+ start_samples: Optional[int] = None,
+ stop_samples: Optional[int] = None,
+ start_seconds: Optional[Union[int, float]] = None,
+ stop_seconds: Optional[Union[int, float]] = None,
delay: Optional[Union[int, float]] = None,
delay_interpolation_method: Union[
str, _DelayInterpolationMethod
@@ -252,8 +256,18 @@ class Dataset(AbstractDataset, InitializableFromConfig):
shouldn't be too large or else you won't be able to provide a large batch
size (where each input-output pair could be something substantially
different and improve batch diversity).
- :param start: In samples; clip x and y up to this point.
- :param stop: In samples; clip x and y past this point.
+ :param start: [DEPRECATED; use start_samples instead.] In samples; clip x and y
+ at this point. Negative values are taken from the end of the audio.
+ :param stop: [DEPRECATED; use stop_samples instead.] In samples; clip x and y at
+ this point. Negative values are taken from the end of the audio.
+ :param start_samples: Clip x and y at this point. Negative values are taken from
+ the end of the audio.
+ :param stop: Clip x and y at this point. Negative values are taken from the end
+ of the audio.
+ :param start_seconds: Clip x and y at this point. Negative values are taken from
+ the end of the audio. Requires providing `sample_rate`.
+ :param stop_seconds: Clip x and y at this point. Negative values are taken from
+ the end of the audio. Requires providing `sample_rate`.
:param delay: In samples. Positive means we get rid of the start of x, end of y
(i.e. we are correcting for an alignment error in which y is delayed behind
x). If a non-integer delay is provided, then y is interpolated, with
@@ -275,8 +289,20 @@ class Dataset(AbstractDataset, InitializableFromConfig):
into the data set that we're trying to use. If `None`, don't assert.
"""
self._validate_x_y(x, y)
- self._validate_start_stop(x, y, start, stop)
- self._sample_rate = self._validate_sample_rate(sample_rate, rate)
+ self._sample_rate = self._validate_sample_rate(
+ sample_rate, rate, default=_DEFAULT_RATE
+ )
+ start, stop = self._validate_start_stop(
+ x,
+ y,
+ start,
+ stop,
+ start_samples,
+ stop_samples,
+ start_seconds,
+ stop_seconds,
+ self._sample_rate,
+ )
if not isinstance(delay_interpolation_method, _DelayInterpolationMethod):
delay_interpolation_method = _DelayInterpolationMethod(
delay_interpolation_method
@@ -349,15 +375,17 @@ class Dataset(AbstractDataset, InitializableFromConfig):
@classmethod
def parse_config(cls, config):
- x, x_wavinfo = wav_to_tensor(
- config["x_path"], info=True, rate=config.get("rate")
+ config = deepcopy(config)
+ sample_rate = cls._validate_sample_rate(
+ config.pop("sample_rate", None), config.pop("rate", None)
)
- rate = x_wavinfo.rate
+ x, x_wavinfo = wav_to_tensor(config.pop("x_path"), info=True, rate=sample_rate)
+ sample_rate = x_wavinfo.rate
try:
y = wav_to_tensor(
- config["y_path"],
- rate=rate,
- preroll=config.get("y_preroll"),
+ config.pop("y_path"),
+ rate=sample_rate,
+ preroll=config.pop("y_preroll", None),
required_shape=(len(x), 1),
required_wavinfo=x_wavinfo,
)
@@ -367,8 +395,8 @@ class Dataset(AbstractDataset, InitializableFromConfig):
y_samples, y_channels = e.shape_actual
msg = "Your audio files aren't the same shape as each other!"
if x_channels != y_channels:
- ctosm = {1: "mono", 2: "stereo"}
- msg += f"\n * The input is {ctosm[x_channels]}, but the output is {ctosm[y_channels]}!"
+ channels_to_stereo_mono = {1: "mono", 2: "stereo"}
+ msg += f"\n * The input is {channels_to_stereo_mono[x_channels]}, but the output is {channels_to_stereo_mono[y_channels]}!"
if x_samples != y_samples:
def sample_to_time(s, rate):
@@ -387,28 +415,14 @@ class Dataset(AbstractDataset, InitializableFromConfig):
f"{hours}:{minutes:02d}:{seconds:02d} and {remainder} samples"
)
- msg += f"\n * The input is {sample_to_time(x_samples, rate)} long"
- msg += f"\n * The output is {sample_to_time(y_samples, rate)} long"
+ msg += (
+ f"\n * The input is {sample_to_time(x_samples, sample_rate)} long"
+ )
+ msg += (
+ f"\n * The output is {sample_to_time(y_samples, sample_rate)} long"
+ )
raise ValueError(msg)
- return {
- "x": x,
- "y": y,
- "nx": config["nx"],
- "ny": config["ny"],
- "start": config.get("start"),
- "stop": config.get("stop"),
- "delay": config.get("delay"),
- "delay_interpolation_method": config.get(
- "delay_interpolation_method", _DelayInterpolationMethod.CUBIC.value
- ),
- "y_scale": config.get("y_scale", 1.0),
- "x_path": config["x_path"],
- "y_path": config["y_path"],
- "sample_rate": rate,
- "require_input_pre_silence": config.get(
- "require_input_pre_silence", _DEFAULT_REQUIRE_INPUT_PRE_SILENCE
- ),
- }
+ return {"x": x, "y": y, "sample_rate": sample_rate, **config}
@classmethod
def _apply_delay(
@@ -456,10 +470,10 @@ class Dataset(AbstractDataset, InitializableFromConfig):
@classmethod
def _validate_sample_rate(
- cls, sample_rate: Optional[float], rate: Optional[int]
+ cls, sample_rate: Optional[float], rate: Optional[int], default=None
) -> float:
if sample_rate is None and rate is None: # Default value
- return _DEFAULT_RATE
+ return default
if rate is not None:
if sample_rate is not None:
raise ValueError(
@@ -475,23 +489,79 @@ class Dataset(AbstractDataset, InitializableFromConfig):
@classmethod
def _validate_start_stop(
- self,
+ cls,
x: torch.Tensor,
y: torch.Tensor,
start: Optional[int] = None,
stop: Optional[int] = None,
- ):
+ start_samples: Optional[int] = None,
+ stop_samples: Optional[int] = None,
+ start_seconds: Optional[Union[int, float]] = None,
+ stop_seconds: Optional[Union[int, float]] = None,
+ sample_rate: Optional[int] = None,
+ ) -> Tuple[Optional[int], Optional[int]]:
"""
- Check for potential input errors.
+ Parse the requested start and stop trim points.
These may be valid indices in Python, but probably point to invalid usage, so
we will raise an exception if something fishy is going on (e.g. starting after
the end of the file, etc)
+
+ :return: parsed start/stop (if valid).
"""
+
+ def parse_start_stop(s, samples, seconds, rate):
+ # Assumes validated inputs
+ if s is not None:
+ return s
+ if samples is not None:
+ return samples
+ if seconds is not None:
+ return int(seconds * rate)
+ # else
+ return None
+
+ # Resolve different ways of asking for start/stop...
+ if start is not None:
+ logger.warning("Using `start` is deprecated; use `start_samples` instead.")
+ if start is not None:
+ logger.warning("Using `stop` is deprecated; use `start_samples` instead.")
+ if (
+ int(start is not None)
+ + int(start_samples is not None)
+ + int(start_seconds is not None)
+ >= 2
+ ):
+ raise ValueError(
+ "More than one start provided. Use only one of `start`, `start_samples`, or `start_seconds`!"
+ )
+ if (
+ int(stop is not None)
+ + int(stop_samples is not None)
+ + int(stop_seconds is not None)
+ >= 2
+ ):
+ raise ValueError(
+ "More than one stop provided. Use only one of `stop`, `stop_samples`, or `stop_seconds`!"
+ )
+ if start_seconds is not None and sample_rate is None:
+ raise ValueError(
+ "Provided `start_seconds` without sample rate; cannot resolve into samples!"
+ )
+ if stop_seconds is not None and sample_rate is None:
+ raise ValueError(
+ "Provided `stop_seconds` without sample rate; cannot resolve into samples!"
+ )
+
+ # By this point, we should have a valid, unambiguous way of asking.
+ start = parse_start_stop(start, start_samples, start_seconds, sample_rate)
+ stop = parse_start_stop(stop, stop_samples, stop_seconds, sample_rate)
+ # And only use start/stop from this point.
+
# We could do this whole thing with `if len(x[start: stop]==0`, but being more
# explicit makes the error messages better for users.
if start is None and stop is None:
- return
+ return start, stop
if len(x) != len(y):
raise ValueError(
f"Input and output are different length. Input has {len(x)} samples, "
@@ -530,6 +600,7 @@ class Dataset(AbstractDataset, InitializableFromConfig):
f"Array length {n} with start={start} and stop={stop} would get "
"rid of all of the data!"
)
+ return start, stop
@classmethod
def _validate_x_y(self, x, y):
diff --git a/tests/test_nam/test_data.py b/tests/test_nam/test_data.py
@@ -7,7 +7,7 @@ import os
from enum import Enum
from pathlib import Path
from tempfile import TemporaryDirectory
-from typing import Tuple
+from typing import Optional, Tuple, Union
import numpy as np
import pytest
@@ -170,6 +170,62 @@ class TestDataset(object):
init()
@pytest.mark.parametrize(
+ "start,start_samples,start_seconds,stop,stop_samples,stop_seconds,sample_rate,raises",
+ (
+ # Nones across the board (valid)
+ (None, None, None, None, None, None, None, None),
+ # start and stop (valid)
+ (1, None, None, -1, None, None, None, None),
+ # start_samples and stop_samples (valid)
+ (None, 1, None, None, -1, None, None, None),
+ # start_seconds and stop_seconds with sample_rate (valid)
+ (None, None, 0.5, None, None, -0.5, 2, None),
+ # Multiple start-like, even if they agree (invalid)
+ (1, 1, None, None, None, None, None, ValueError),
+ # Multiple stop-like, even if they agree (invalid)
+ (None, None, None, -1, -1, None, None, ValueError),
+ # seconds w/o sample rate (invalid)
+ (None, None, 1.0, None, None, None, None, ValueError),
+ ),
+ )
+ def test_validate_start_stop(
+ self,
+ start: Optional[int],
+ start_samples: Optional[int],
+ start_seconds: Optional[Union[int, float]],
+ stop: Optional[int],
+ stop_samples: Optional[int],
+ stop_seconds: Optional[Union[int, float]],
+ sample_rate: Optional[int],
+ raises: Optional[Exception],
+ ):
+ """
+ Assert correct behavior of `._validate_start_stop()` class method.
+ """
+
+ def f():
+ # Don't provide start/stop that are too large for the fake data plz.
+ x, y = torch.zeros((2, 32))
+ data.Dataset._validate_start_stop(
+ x,
+ y,
+ start,
+ stop,
+ start_samples,
+ stop_samples,
+ start_seconds,
+ stop_seconds,
+ sample_rate,
+ )
+ assert True
+
+ if raises is None:
+ f()
+ else:
+ with pytest.raises(raises):
+ f()
+
+ @pytest.mark.parametrize(
"n,stop,valid",
(
(13, None, True), # No stop restrictions; nothing wrong