commit 35494ade96c6f9819bbf78064fb907df84895092
parent 64f79fdc4c68b8735b879f495d4e5099cd39aa6d
Author: Steven Atkinson <steven@atkinson.mn>
Date: Sat, 28 Sep 2024 10:53:26 -0700
[FEATURE] Allow unequal-length input and output (#476)
* Raise DataError instead of ValueError
* Add sample rate and length validation for data (require exact match for now)
* Fix docstring
* Define critical checks
* Don't allow ignoring critical checks
* allow_unequal_lengths
* Fix test
Diffstat:
3 files changed, 72 insertions(+), 46 deletions(-)
diff --git a/nam/data.py b/nam/data.py
@@ -227,6 +227,21 @@ class StopError(StartStopError):
_DEFAULT_REQUIRE_INPUT_PRE_SILENCE = 0.4
+def _sample_to_time(s, rate):
+ seconds = s // rate
+ remainder = s % rate
+ hours, minutes = 0, 0
+ seconds_per_hour = 3600
+ while seconds >= seconds_per_hour:
+ hours += 1
+ seconds -= seconds_per_hour
+ seconds_per_minute = 60
+ while seconds >= seconds_per_minute:
+ minutes += 1
+ seconds -= seconds_per_minute
+ return f"{hours}:{minutes:02d}:{seconds:02d} and {remainder} samples"
+
+
class Dataset(AbstractDataset, InitializableFromConfig):
"""
Take a pair of matched audio files and serve input + output pairs.
@@ -382,57 +397,65 @@ class Dataset(AbstractDataset, InitializableFromConfig):
@classmethod
def parse_config(cls, config):
+ """
+ :param config:
+ Must contain:
+ x_path (path-like)
+ y_path (path-like)
+ May contain:
+ sample_rate (int)
+ y_preroll (int)
+ allow_unequal_lengths (bool)
+ Must NOT contain:
+ x (torch.Tensor) - loaded from x_path
+ y (torch.Tensor) - loaded from y_path
+ Everything else is passed on to __init__
+ """
config = deepcopy(config)
- if "rate" in config:
- raise ValueError(
- "use of `rate` was deprecated in version 0.8. Use `sample_rate` "
- "instead."
- )
sample_rate = config.pop("sample_rate", None)
x, x_wavinfo = wav_to_tensor(config.pop("x_path"), info=True, rate=sample_rate)
sample_rate = x_wavinfo.rate
- try:
+ if config.pop("allow_unequal_lengths", False):
y = wav_to_tensor(
config.pop("y_path"),
rate=sample_rate,
preroll=config.pop("y_preroll", None),
- required_shape=(len(x), 1),
required_wavinfo=x_wavinfo,
)
- except AudioShapeMismatchError as e:
- # Really verbose message since users see this.
- x_samples, x_channels = e.shape_expected
- y_samples, y_channels = e.shape_actual
- msg = "Your audio files aren't the same shape as each other!"
- if x_channels != y_channels:
- channels_to_stereo_mono = {1: "mono", 2: "stereo"}
- msg += f"\n * The input is {channels_to_stereo_mono[x_channels]}, but the output is {channels_to_stereo_mono[y_channels]}!"
- if x_samples != y_samples:
-
- def sample_to_time(s, rate):
- seconds = s // rate
- remainder = s % rate
- hours, minutes = 0, 0
- seconds_per_hour = 3600
- while seconds >= seconds_per_hour:
- hours += 1
- seconds -= seconds_per_hour
- seconds_per_minute = 60
- while seconds >= seconds_per_minute:
- minutes += 1
- seconds -= seconds_per_minute
- return (
- f"{hours}:{minutes:02d}:{seconds:02d} and {remainder} samples"
- )
-
- msg += (
- f"\n * The input is {sample_to_time(x_samples, sample_rate)} long"
- )
- msg += (
- f"\n * The output is {sample_to_time(y_samples, sample_rate)} long"
+ # Truncate to the shorter of the two
+ if len(x) == 0:
+ raise DataError("Input is zero-length!")
+ if len(y) == 0:
+ raise DataError("Output is zero-length!")
+ n = min(len(x), len(y))
+ if n < len(x):
+ print(f"Truncating input to {_sample_to_time(n, sample_rate)}")
+ if n < len(y):
+ print(f"Truncating output to {_sample_to_time(n, sample_rate)}")
+ x, y = [z[:n] for z in (x, y)]
+ else:
+ try:
+ y = wav_to_tensor(
+ config.pop("y_path"),
+ rate=sample_rate,
+ preroll=config.pop("y_preroll", None),
+ required_shape=(len(x), 1),
+ required_wavinfo=x_wavinfo,
)
- msg += f"\n\nOriginal exception:\n{e}"
- raise DataError(msg)
+ except AudioShapeMismatchError as e:
+ # Really verbose message since users see this.
+ x_samples, x_channels = e.shape_expected
+ y_samples, y_channels = e.shape_actual
+ msg = "Your audio files aren't the same shape as each other!"
+ if x_channels != y_channels:
+ channels_to_stereo_mono = {1: "mono", 2: "stereo"}
+ msg += f"\n * The input is {channels_to_stereo_mono[x_channels]}, but the output is {channels_to_stereo_mono[y_channels]}!"
+ if x_samples != y_samples:
+ msg += f"\n * The input is {_sample_to_time(x_samples, sample_rate)} long"
+ msg += f"\n * The output is {_sample_to_time(y_samples, sample_rate)} long"
+ msg += f"\n\nOriginal exception:\n{e}"
+ raise DataError(msg)
return {"x": x, "y": y, "sample_rate": sample_rate, **config}
@classmethod
diff --git a/nam/train/core.py b/nam/train/core.py
@@ -713,8 +713,9 @@ def _check_v3(
print("V3 checks...")
rate = _V3_DATA_INFO.rate
y = wav_to_tensor(output_path, rate=rate)
+ n = len(wav_to_tensor(input_path)) # to end-crop output
y_val_1 = y[: _V3_DATA_INFO.t_validate]
- y_val_2 = y[-_V3_DATA_INFO.t_validate :]
+ y_val_2 = y[n - _V3_DATA_INFO.t_validate : n]
esr_replicate = esr(y_val_1, y_val_2).item()
print(f"Replicate ESR is {esr_replicate:.8f}.")
esr_replicate_threshold = 0.01
@@ -908,7 +909,7 @@ _CAB_MRSTFT_PRE_EMPH_COEF = 0.85
def _get_data_config(
input_version: Version, input_path: Path, output_path: Path, ny: int, latency: int
) -> dict:
- def get_kwargs(data_info: _DataInfo):
+ def get_split_kwargs(data_info: _DataInfo):
if data_info.major_version == 1:
train_val_split = data_info.validation_start
train_kwargs = {"stop_samples": train_val_split}
@@ -955,7 +956,7 @@ def _get_data_config(
3: _V3_DATA_INFO,
4: _V4_DATA_INFO,
}[input_version.major]
- train_kwargs, validation_kwargs = get_kwargs(data_info)
+ train_kwargs, validation_kwargs = get_split_kwargs(data_info)
data_config = {
"train": {"ny": ny, **train_kwargs},
"validation": {"ny": None, **validation_kwargs},
@@ -963,6 +964,7 @@ def _get_data_config(
"x_path": input_path,
"y_path": output_path,
"delay": latency,
+ "allow_unequal_lengths": True,
},
}
return data_config
@@ -1560,7 +1562,7 @@ def _check_audio_lengths(
input_path: Path,
output_path: Path,
max_under_seconds: Optional[float] = 0.0,
- max_over_seconds: Optional[float] = 0.0,
+ max_over_seconds: Optional[float] = 1.0,
) -> _LengthValidation:
"""
Check that the input and output have the right lengths compared to each
@@ -1587,7 +1589,7 @@ def _check_audio_lengths(
passed = True
if max_under_seconds is not None and delta_seconds < -max_under_seconds:
passed = False
- if max_over_seconds is not None and delta_seconds > max_under_seconds:
+ if max_over_seconds is not None and delta_seconds > max_over_seconds:
passed = False
return _LengthValidation(passed=passed, delta_seconds=delta_seconds)
diff --git a/tests/test_nam/test_train/test_core.py b/tests/test_nam/test_train/test_core.py
@@ -266,9 +266,10 @@ def test_v3_check_doesnt_make_figure_if_silent(mocker):
x = np.random.rand(core._V3_DATA_INFO.t_validate + 1) - 0.5
with TemporaryDirectory() as tmpdir:
+ input_path = Path(tmpdir, "input.wav")
output_path = Path(tmpdir, "output.wav")
+ np_to_wav(x, input_path) # Doesn't need to be the actual thing for now
np_to_wav(x, output_path)
- input_path = None # Isn't used right now.
# If this makes a figure, then it wasn't silent!
core._check_v3(input_path, output_path, silent=True)