neural-amp-modeler

Neural network emulator for guitar amplifiers
Log | Files | Refs | README | LICENSE

commit 16a2108fe64602661daa55303e6a06bf348b59d9
parent 2959e5303e60d8d2ca1924ac630cf818c3c1dd0c
Author: Steven Atkinson <steven@atkinson.mn>
Date:   Wed,  7 Feb 2024 00:13:34 -0800

[FEATURE] Add support for Proteus training files (#376)

Add support for Proteus training files
Diffstat:
Mnam/train/core.py | 121++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-------------
Mtests/resources/__init__.py | 2++
Mtests/test_nam/test_train/test_core.py | 50++++++++++++++++++++++++++++++++++++--------------
3 files changed, 139 insertions(+), 34 deletions(-)

diff --git a/nam/train/core.py b/nam/train/core.py @@ -70,6 +70,7 @@ def _detect_input_version(input_path) -> Tuple[Version, bool]: "7c3b6119c74465f79d96c761a0e27370": Version(1, 1, 1), "ede3b9d82135ce10c7ace3bb27469422": Version(2, 0, 0), "36cd1af62985c2fac3e654333e36431e": Version(3, 0, 0), + "80e224bd5622fd6153ff1fd9f34cb3bd": Version(4, 0, 0), }.get(file_hash) if version is None: print( @@ -80,7 +81,8 @@ def _detect_input_version(input_path) -> Tuple[Version, bool]: def detect_weak(input_path) -> Optional[Version]: def assign_hash(path): - Hashes = Tuple[Optional[str], Optional[str]] + Hash = Optional[str] + Hashes = Tuple[Hash, Hash] def _hash(x: np.ndarray) -> str: return hashlib.md5(x).hexdigest() @@ -133,9 +135,20 @@ def _detect_input_version(input_path) -> Tuple[Version, bool]: end_hash = _hash(x[start_of_end_interval:]) return start_hash, end_hash + def assign_hash_v4(path) -> Hash: + # Use this to create recognized hashes for new files + x, info = wav_to_np(path, info=True) + rate = info.rate + if rate != _V4_DATA_INFO.rate: + return None + # I don't care about anything in the file except the starting blip and + start_hash = _hash(x[: int(1 * _V4_DATA_INFO.rate)]) + return start_hash + start_hash_v1, end_hash_v1 = assign_hashes_v1(path) start_hash_v2, end_hash_v2 = assign_hashes_v2(path) start_hash_v3, end_hash_v3 = assign_hashes_v3(path) + hash_v4 = assign_hash_v4(path) return ( start_hash_v1, end_hash_v1, @@ -143,6 +156,7 @@ def _detect_input_version(input_path) -> Tuple[Version, bool]: end_hash_v2, start_hash_v3, end_hash_v3, + hash_v4, ) ( @@ -152,6 +166,7 @@ def _detect_input_version(input_path) -> Tuple[Version, bool]: end_hash_v2, start_hash_v3, end_hash_v3, + hash_v4, ) = assign_hash(input_path) print( "Weak hashes:\n" @@ -161,9 +176,11 @@ def _detect_input_version(input_path) -> Tuple[Version, bool]: f" End (v2) : {end_hash_v2}\n" f" Start (v3) : {start_hash_v3}\n" f" End (v3) : {end_hash_v3}\n" + f" Proteus : {hash_v4}\n" ) - # Check for matches, starting with most recent + # Check for matches, starting with most recent. Proteus last since its match is + # the most permissive. version = { ( "dadb5d62f6c3973a59bf01439799809b", @@ -192,6 +209,9 @@ def _detect_input_version(input_path) -> Tuple[Version, bool]: "8458126969a3f9d8e19a53554eb1fd52", ): Version(1, 1, 1), }.get((start_hash_v1, end_hash_v1)) + if version is not None: + return version + version = {"46151c8030798081acc00a725325a07d": Version(4, 0, 0)}.get(hash_v4) return version version = detect_strong(input_path) @@ -211,17 +231,6 @@ def _detect_input_version(input_path) -> Tuple[Version, bool]: class _DataInfo(BaseModel): """ :param major_version: Data major version - :param rate: Sample rate, in Hz - :param t_blips: How long the blips are, in samples - :param first_blips_start: When the first blips section starts, in samples - :param t_validate: Validation signal length, in samples - :param train_start: Where training signal starts, in samples. - :param validation_start: Where validation signal starts, in samples. Less than zero - (from the end of the array). - :param noise_interval: Inside which we quantify the noise level - :param blip_locations: In samples, absolute location in the file. Negative values - mean from the end instead of from the start (typical "Python" negastive - indexing). """ major_version: int @@ -286,6 +295,30 @@ _V3_DATA_INFO = _DataInfo( noise_interval=(492_000, 498_000), blip_locations=((504_000, 552_000),), ) +# V4 (aka GuitarML Proteus) +# https://github.com/GuitarML/Releases/releases/download/v1.0.0/Proteus_Capture_Utility.zip +# * 44.1k +# * Odd length... +# * There's a blip on sample zero. This has to be ignored or else over-compensated +# latencies will come out wrong! +# (0:00-0:01) Blips at 0:00.0 and 0:00.5 +# (0:01-0:09) Sine sweeps +# (0:09-0:17) White noise +# (0:17:0.20) Rising white noise (to 0:20.333 appx) +# (0:20-3:30.858) General training data (ends on sample 9,298,872) +# I'm arbitrarily assigning the last 10 seconds as validation data. +_V4_DATA_INFO = _DataInfo( + major_version=4, + rate=44_100.0, + t_blips=44_099, # Need to ignore the first blip! + first_blips_start=1, # Need to ignore the first blip! + t_validate=441_000, + # Blips are problematic for training because they don't have preceding silence + train_start=44_100, + validation_start=-441_000, + noise_interval=(6_000, 12_000), + blip_locations=((22_050,),), +) _DELAY_CALIBRATION_ABS_THRESHOLD = 0.0003 _DELAY_CALIBRATION_REL_THRESHOLD = 0.001 @@ -393,6 +426,7 @@ def _calibrate_delay_v_all( _calibrate_delay_v1 = partial(_calibrate_delay_v_all, _V1_DATA_INFO) _calibrate_delay_v2 = partial(_calibrate_delay_v_all, _V2_DATA_INFO) _calibrate_delay_v3 = partial(_calibrate_delay_v_all, _V3_DATA_INFO) +_calibrate_delay_v4 = partial(_calibrate_delay_v_all, _V4_DATA_INFO) def _plot_delay_v_all( @@ -445,6 +479,7 @@ def _plot_delay_v_all( _plot_delay_v1 = partial(_plot_delay_v_all, _V1_DATA_INFO) _plot_delay_v2 = partial(_plot_delay_v_all, _V2_DATA_INFO) _plot_delay_v3 = partial(_plot_delay_v_all, _V3_DATA_INFO) +_plot_delay_v4 = partial(_plot_delay_v_all, _V4_DATA_INFO) def _calibrate_delay( @@ -454,12 +489,16 @@ def _calibrate_delay( output_path: str, silent: bool = False, ) -> int: + """ + :param is_proteus: Forget the version; do""" if input_version.major == 1: calibrate, plot = _calibrate_delay_v1, _plot_delay_v1 elif input_version.major == 2: calibrate, plot = _calibrate_delay_v2, _plot_delay_v2 elif input_version.major == 3: calibrate, plot = _calibrate_delay_v3, _plot_delay_v3 + elif input_version.major == 4: + calibrate, plot = _calibrate_delay_v4, _plot_delay_v4 else: raise NotImplementedError( f"Input calibration not implemented for input version {input_version}" @@ -654,6 +693,29 @@ def _check_v3(input_path, output_path, silent: bool, *args, **kwargs) -> bool: return True +def _check_v4(input_path, output_path, silent: bool, *args, **kwargs) -> bool: + # Things we can't check: + # Latency compensation agreement + # Data replicability + print("Using Proteus audio file. Standard data checks aren't possible!") + signal, info = wav_to_np(output_path, info=True) + passed = True + if info.rate != _V4_DATA_INFO.rate: + print( + f"Output signal has sample rate {info.rate}; expected {_V4_DATA_INFO.rate}!" + ) + passed = False + # I don't care what's in the files except that they're long enough to hold the blip + # and the last 10 seconds I decided to use as validation + required_length = int((1.0 + 10.0) * _V4_DATA_INFO.rate) + if len(signal) < required_length: + print( + "File doesn't meet the minimum length requirements for latency compensation and validation signal!" + ) + passed = False + return passed + + def _check( input_path: str, output_path: str, input_version: Version, delay: int, silent: bool ) -> bool: @@ -668,6 +730,8 @@ def _check( f = _check_v2 elif input_version.major == 3: f = _check_v3 + elif input_version.major == 4: + f = _check_v4 else: print(f"Checks not implemented for input version {input_version}; skip") return True @@ -821,13 +885,34 @@ def _get_configs( train_stop = validation_start train_kwargs = {"start": 480_000, "stop": train_stop} validation_kwargs = {"start": validation_start} + elif data_info.major_version == 4: + validation_start = data_info.validation_start + train_stop = validation_start + train_kwargs = {"stop": train_stop} + # Proteus doesn't have silence to get a clean split. Bite the bullet. + print( + "Using Proteus files:\n" + " * There isn't a silent point to split the validation set, so some of " + "your gear's response from the train set will leak into the start of " + "the validation set and impact validation accuracy (Bypassing data " + "quality check)\n" + " * Since the validation set is different, the ESRs reported for this " + "model aren't comparable to those from the other 'NAM' training files." + ) + validation_kwargs = { + "start": validation_start, + "require_input_pre_silence": False, + } else: raise NotImplementedError(f"kwargs for input version {input_version}") return train_kwargs, validation_kwargs - data_info = {1: _V1_DATA_INFO, 2: _V2_DATA_INFO, 3: _V3_DATA_INFO}[ - input_version.major - ] + data_info = { + 1: _V1_DATA_INFO, + 2: _V2_DATA_INFO, + 3: _V3_DATA_INFO, + 4: _V4_DATA_INFO, + }[input_version.major] train_kwargs, validation_kwargs = get_kwargs(data_info) data_config = { "train": {"ny": ny, **train_kwargs}, @@ -994,10 +1079,6 @@ def _nasty_checks_modal(): modal.mainloop() -# Example usage: -# show_modal("Hello, World!") - - def train( input_path: str, output_path: str, diff --git a/tests/resources/__init__.py b/tests/resources/__init__.py @@ -13,6 +13,7 @@ from pathlib import Path import pytest __all__ = [ + "requires_proteus", "requires_v1_0_0", "requires_v1_1_1", "requires_v2_0_0", @@ -33,6 +34,7 @@ requires_v1_0_0 = _requires_v("v1.wav") requires_v1_1_1 = _requires_v("v1_1_1.wav") requires_v2_0_0 = _requires_v("v2_0_0.wav") requires_v3_0_0 = _requires_v("v3_0_0.wav") +requires_proteus = _requires_v("Proteus_Capture.wav") def resource_path(name: str) -> Path: diff --git a/tests/test_nam/test_train/test_core.py b/tests/test_nam/test_train/test_core.py @@ -21,6 +21,7 @@ from nam.train import core from nam.train._version import Version from ...resources import ( + requires_proteus, requires_v1_0_0, requires_v1_1_1, requires_v2_0_0, @@ -34,6 +35,8 @@ __all__ = [] def _resource_path(version: Version) -> Path: if version == Version(1, 0, 0): name = "v1.wav" + elif version == Version(4, 0, 0): + name = "Proteus_Capture.wav" else: name = f'v{str(version).replace(".", "_")}.wav' return resource_path(name) @@ -167,24 +170,37 @@ class TestCalibrateDelayV3(_TCalibrateDelay): _data_info = core._V3_DATA_INFO +class TestCalibrateDelayV4(_TCalibrateDelay): + _calibrate_delay = core._calibrate_delay_v4 + _data_info = core._V4_DATA_INFO + + def _make_t_validation_dataset_class( version: Version, decorator, data_info: core._DataInfo ): class C(object): - @decorator - def test_validation_preceded_by_silence(self): - """ - Validate that the datasets that we've made are valid - """ - x = wav_to_tensor(_resource_path(version)) - Dataset._validate_preceding_silence( - x, - data_info.validation_start, - _DEFAULT_REQUIRE_INPUT_PRE_SILENCE, - data_info.rate, - ) - - return C + pass + + # Proteus has a bad validation split; don't define the silence test for it. + if version == Version(4, 0, 0): + return C + else: + + class C2(C): + @decorator + def test_validation_preceded_by_silence(self): + """ + Validate that the datasets that we've made are valid + """ + x = wav_to_tensor(_resource_path(version)) + Dataset._validate_preceding_silence( + x, + data_info.validation_start, + _DEFAULT_REQUIRE_INPUT_PRE_SILENCE, + data_info.rate, + ) + + return C2 TestValidationDatasetV1_0_0 = _make_t_validation_dataset_class( @@ -207,6 +223,12 @@ TestValidationDatasetV3_0_0 = _make_t_validation_dataset_class( ) +# Aka Proteus +TestValidationDatasetV4_0_0 = _make_t_validation_dataset_class( + Version(4, 0, 0), requires_proteus, core._V4_DATA_INFO +) + + def test_v3_check_doesnt_make_figure_if_silent(mocker): """ Issue 337