neural-amp-modeler

Neural network emulator for guitar amplifiers
Log | Files | Refs | README | LICENSE

commit 35494ade96c6f9819bbf78064fb907df84895092
parent 64f79fdc4c68b8735b879f495d4e5099cd39aa6d
Author: Steven Atkinson <steven@atkinson.mn>
Date:   Sat, 28 Sep 2024 10:53:26 -0700

[FEATURE] Allow unequal-length input and output (#476)

* Raise DataError instead of ValueError

* Add sample rate and length validation for data (require exact match for now)

* Fix docstring

* Define critical checks

* Don't allow ignoring critical checks

* allow_unequal_lengths

* Fix test
Diffstat:
Mnam/data.py | 103++++++++++++++++++++++++++++++++++++++++++++++++-------------------------------
Mnam/train/core.py | 12+++++++-----
Mtests/test_nam/test_train/test_core.py | 3++-
3 files changed, 72 insertions(+), 46 deletions(-)

diff --git a/nam/data.py b/nam/data.py
@@ -227,6 +227,21 @@ class StopError(StartStopError):
 _DEFAULT_REQUIRE_INPUT_PRE_SILENCE = 0.4
 
 
+def _sample_to_time(s, rate):
+    seconds = s // rate
+    remainder = s % rate
+    hours, minutes = 0, 0
+    seconds_per_hour = 3600
+    while seconds >= seconds_per_hour:
+        hours += 1
+        seconds -= seconds_per_hour
+    seconds_per_minute = 60
+    while seconds >= seconds_per_minute:
+        minutes += 1
+        seconds -= seconds_per_minute
+    return f"{hours}:{minutes:02d}:{seconds:02d} and {remainder} samples"
+
+
 class Dataset(AbstractDataset, InitializableFromConfig):
     """
     Take a pair of matched audio files and serve input + output pairs.
@@ -382,57 +397,65 @@ class Dataset(AbstractDataset, InitializableFromConfig):
     @classmethod
     def parse_config(cls, config):
+        """
+        :param config:
+            Must contain:
+                x_path (path-like)
+                y_path (path-like)
+            May contain:
+                sample_rate (int)
+                y_preroll (int)
+                allow_unequal_lengths (bool)
+            Must NOT contain:
+                x (torch.Tensor) - loaded from x_path
+                y (torch.Tensor) - loaded from y_path
+            Everything else is passed on to __init__
+        """
         config = deepcopy(config)
-        if "rate" in config:
-            raise ValueError(
-                "use of `rate` was deprecated in version 0.8. Use `sample_rate` "
-                "instead."
-            )
         sample_rate = config.pop("sample_rate", None)
         x, x_wavinfo = wav_to_tensor(config.pop("x_path"), info=True, rate=sample_rate)
         sample_rate = x_wavinfo.rate
-        try:
+        if config.pop("allow_unequal_lengths", False):
             y = wav_to_tensor(
                 config.pop("y_path"),
                 rate=sample_rate,
                 preroll=config.pop("y_preroll", None),
-                required_shape=(len(x), 1),
                 required_wavinfo=x_wavinfo,
             )
-        except AudioShapeMismatchError as e:
-            # Really verbose message since users see this.
-            x_samples, x_channels = e.shape_expected
-            y_samples, y_channels = e.shape_actual
-            msg = "Your audio files aren't the same shape as each other!"
-            if x_channels != y_channels:
-                channels_to_stereo_mono = {1: "mono", 2: "stereo"}
-                msg += f"\n * The input is {channels_to_stereo_mono[x_channels]}, but the output is {channels_to_stereo_mono[y_channels]}!"
-            if x_samples != y_samples:
-
-                def sample_to_time(s, rate):
-                    seconds = s // rate
-                    remainder = s % rate
-                    hours, minutes = 0, 0
-                    seconds_per_hour = 3600
-                    while seconds >= seconds_per_hour:
-                        hours += 1
-                        seconds -= seconds_per_hour
-                    seconds_per_minute = 60
-                    while seconds >= seconds_per_minute:
-                        minutes += 1
-                        seconds -= seconds_per_minute
-                    return (
-                        f"{hours}:{minutes:02d}:{seconds:02d} and {remainder} samples"
-                    )
-
-                msg += (
-                    f"\n * The input is {sample_to_time(x_samples, sample_rate)} long"
-                )
-                msg += (
-                    f"\n * The output is {sample_to_time(y_samples, sample_rate)} long"
+            # Truncate to the shorter of the two
+            if len(x) == 0:
+                raise DataError("Input is zero-length!")
+            if len(y) == 0:
+                raise DataError("Output is zero-length!")
+            n = min(len(x), len(y))
+            if n < len(x):
+                print(f"Truncating input to {_sample_to_time(n, sample_rate)}")
+            if n < len(y):
+                print(f"Truncating output to {_sample_to_time(n, sample_rate)}")
+            x, y = [z[:n] for z in (x, y)]
+        else:
+            try:
+                y = wav_to_tensor(
+                    config.pop("y_path"),
+                    rate=sample_rate,
+                    preroll=config.pop("y_preroll", None),
+                    required_shape=(len(x), 1),
+                    required_wavinfo=x_wavinfo,
                 )
-            msg += f"\n\nOriginal exception:\n{e}"
-            raise DataError(msg)
+            except AudioShapeMismatchError as e:
+                # Really verbose message since users see this.
+                x_samples, x_channels = e.shape_expected
+                y_samples, y_channels = e.shape_actual
+                msg = "Your audio files aren't the same shape as each other!"
+                if x_channels != y_channels:
+                    channels_to_stereo_mono = {1: "mono", 2: "stereo"}
+                    msg += f"\n * The input is {channels_to_stereo_mono[x_channels]}, but the output is {channels_to_stereo_mono[y_channels]}!"
+                if x_samples != y_samples:
+
+                    msg += f"\n * The input is {_sample_to_time(x_samples, sample_rate)} long"
+                    msg += f"\n * The output is {_sample_to_time(y_samples, sample_rate)} long"
+                msg += f"\n\nOriginal exception:\n{e}"
+                raise DataError(msg)
 
         return {"x": x, "y": y, "sample_rate": sample_rate, **config}
 
     @classmethod
diff --git a/nam/train/core.py b/nam/train/core.py
@@ -713,8 +713,9 @@ def _check_v3(
     print("V3 checks...")
     rate = _V3_DATA_INFO.rate
     y = wav_to_tensor(output_path, rate=rate)
+    n = len(wav_to_tensor(input_path))  # to End-crop output
     y_val_1 = y[: _V3_DATA_INFO.t_validate]
-    y_val_2 = y[-_V3_DATA_INFO.t_validate :]
+    y_val_2 = y[n - _V3_DATA_INFO.t_validate : n]
     esr_replicate = esr(y_val_1, y_val_2).item()
     print(f"Replicate ESR is {esr_replicate:.8f}.")
     esr_replicate_threshold = 0.01
@@ -908,7 +909,7 @@ _CAB_MRSTFT_PRE_EMPH_COEF = 0.85
 def _get_data_config(
     input_version: Version, input_path: Path, output_path: Path, ny: int, latency: int
 ) -> dict:
-    def get_kwargs(data_info: _DataInfo):
+    def get_split_kwargs(data_info: _DataInfo):
         if data_info.major_version == 1:
             train_val_split = data_info.validation_start
             train_kwargs = {"stop_samples": train_val_split}
@@ -955,7 +956,7 @@ def _get_data_config(
         3: _V3_DATA_INFO,
         4: _V4_DATA_INFO,
     }[input_version.major]
-    train_kwargs, validation_kwargs = get_kwargs(data_info)
+    train_kwargs, validation_kwargs = get_split_kwargs(data_info)
     data_config = {
         "train": {"ny": ny, **train_kwargs},
         "validation": {"ny": None, **validation_kwargs},
@@ -963,6 +964,7 @@
             "x_path": input_path,
            "y_path": output_path,
            "delay": latency,
+            "allow_unequal_lengths": True,
        },
    }
    return data_config
@@ -1560,7 +1562,7 @@ def _check_audio_lengths(
     input_path: Path,
     output_path: Path,
     max_under_seconds: Optional[float] = 0.0,
-    max_over_seconds: Optional[float] = 0.0,
+    max_over_seconds: Optional[float] = 1.0,
 ) -> _LengthValidation:
     """
     Check that the input and output have the right lengths compared to each
@@ -1587,7 +1589,7 @@ def _check_audio_lengths(
     passed = True
     if max_under_seconds is not None and delta_seconds < -max_under_seconds:
         passed = False
-    if max_over_seconds is not None and delta_seconds > max_under_seconds:
+    if max_over_seconds is not None and delta_seconds > max_over_seconds:
         passed = False
 
     return _LengthValidation(passed=passed, delta_seconds=delta_seconds)
diff --git a/tests/test_nam/test_train/test_core.py b/tests/test_nam/test_train/test_core.py
@@ -266,9 +266,10 @@ def test_v3_check_doesnt_make_figure_if_silent(mocker):
 
     x = np.random.rand(core._V3_DATA_INFO.t_validate + 1) - 0.5
     with TemporaryDirectory() as tmpdir:
+        input_path = Path(tmpdir, "input.wav")
         output_path = Path(tmpdir, "output.wav")
+        np_to_wav(x, input_path)  # Doesn't need to be the actual thing for now
         np_to_wav(x, output_path)
-        input_path = None  # Isn't used right now.
 
         # If this makes a figure, then it wasn't silent!
         core._check_v3(input_path, output_path, silent=True)