commit 16a2108fe64602661daa55303e6a06bf348b59d9
parent 2959e5303e60d8d2ca1924ac630cf818c3c1dd0c
Author: Steven Atkinson <steven@atkinson.mn>
Date: Wed, 7 Feb 2024 00:13:34 -0800
[FEATURE] Add support for Proteus training files (#376)
Add support for Proteus training files
Diffstat:
3 files changed, 139 insertions(+), 34 deletions(-)
diff --git a/nam/train/core.py b/nam/train/core.py
@@ -70,6 +70,7 @@ def _detect_input_version(input_path) -> Tuple[Version, bool]:
"7c3b6119c74465f79d96c761a0e27370": Version(1, 1, 1),
"ede3b9d82135ce10c7ace3bb27469422": Version(2, 0, 0),
"36cd1af62985c2fac3e654333e36431e": Version(3, 0, 0),
+ "80e224bd5622fd6153ff1fd9f34cb3bd": Version(4, 0, 0),
}.get(file_hash)
if version is None:
print(
@@ -80,7 +81,8 @@ def _detect_input_version(input_path) -> Tuple[Version, bool]:
def detect_weak(input_path) -> Optional[Version]:
def assign_hash(path):
- Hashes = Tuple[Optional[str], Optional[str]]
+ Hash = Optional[str]
+ Hashes = Tuple[Hash, Hash]
def _hash(x: np.ndarray) -> str:
return hashlib.md5(x).hexdigest()
@@ -133,9 +135,20 @@ def _detect_input_version(input_path) -> Tuple[Version, bool]:
end_hash = _hash(x[start_of_end_interval:])
return start_hash, end_hash
+ def assign_hash_v4(path) -> Hash:
+ # Use this to create recognized hashes for new files
+ x, info = wav_to_np(path, info=True)
+ rate = info.rate
+ if rate != _V4_DATA_INFO.rate:
+ return None
+ # I don't care about anything in the file except the first second (the starting blips)
+ start_hash = _hash(x[: int(1 * _V4_DATA_INFO.rate)])
+ return start_hash
+
start_hash_v1, end_hash_v1 = assign_hashes_v1(path)
start_hash_v2, end_hash_v2 = assign_hashes_v2(path)
start_hash_v3, end_hash_v3 = assign_hashes_v3(path)
+ hash_v4 = assign_hash_v4(path)
return (
start_hash_v1,
end_hash_v1,
@@ -143,6 +156,7 @@ def _detect_input_version(input_path) -> Tuple[Version, bool]:
end_hash_v2,
start_hash_v3,
end_hash_v3,
+ hash_v4,
)
(
@@ -152,6 +166,7 @@ def _detect_input_version(input_path) -> Tuple[Version, bool]:
end_hash_v2,
start_hash_v3,
end_hash_v3,
+ hash_v4,
) = assign_hash(input_path)
print(
"Weak hashes:\n"
@@ -161,9 +176,11 @@ def _detect_input_version(input_path) -> Tuple[Version, bool]:
f" End (v2) : {end_hash_v2}\n"
f" Start (v3) : {start_hash_v3}\n"
f" End (v3) : {end_hash_v3}\n"
+ f" Proteus : {hash_v4}\n"
)
- # Check for matches, starting with most recent
+ # Check for matches, starting with most recent. Proteus last since its match is
+ # the most permissive.
version = {
(
"dadb5d62f6c3973a59bf01439799809b",
@@ -192,6 +209,9 @@ def _detect_input_version(input_path) -> Tuple[Version, bool]:
"8458126969a3f9d8e19a53554eb1fd52",
): Version(1, 1, 1),
}.get((start_hash_v1, end_hash_v1))
+ if version is not None:
+ return version
+ version = {"46151c8030798081acc00a725325a07d": Version(4, 0, 0)}.get(hash_v4)
return version
version = detect_strong(input_path)
@@ -211,17 +231,6 @@ def _detect_input_version(input_path) -> Tuple[Version, bool]:
class _DataInfo(BaseModel):
"""
:param major_version: Data major version
- :param rate: Sample rate, in Hz
- :param t_blips: How long the blips are, in samples
- :param first_blips_start: When the first blips section starts, in samples
- :param t_validate: Validation signal length, in samples
- :param train_start: Where training signal starts, in samples.
- :param validation_start: Where validation signal starts, in samples. Less than zero
- (from the end of the array).
- :param noise_interval: Inside which we quantify the noise level
- :param blip_locations: In samples, absolute location in the file. Negative values
- mean from the end instead of from the start (typical "Python" negastive
- indexing).
"""
major_version: int
@@ -286,6 +295,30 @@ _V3_DATA_INFO = _DataInfo(
noise_interval=(492_000, 498_000),
blip_locations=((504_000, 552_000),),
)
+# V4 (aka GuitarML Proteus)
+# https://github.com/GuitarML/Releases/releases/download/v1.0.0/Proteus_Capture_Utility.zip
+# * 44.1k
+# * Odd length...
+# * There's a blip on sample zero. This has to be ignored or else over-compensated
+# latencies will come out wrong!
+# (0:00-0:01) Blips at 0:00.0 and 0:00.5
+# (0:01-0:09) Sine sweeps
+# (0:09-0:17) White noise
+# (0:17-0:20) Rising white noise (to 0:20.333 approx.)
+# (0:20-3:30.858) General training data (ends on sample 9,298,872)
+# I'm arbitrarily assigning the last 10 seconds as validation data.
+_V4_DATA_INFO = _DataInfo(
+ major_version=4,
+ rate=44_100.0,
+ t_blips=44_099, # Need to ignore the first blip!
+ first_blips_start=1, # Need to ignore the first blip!
+ t_validate=441_000,
+ # Blips are problematic for training because they don't have preceding silence
+ train_start=44_100,
+ validation_start=-441_000,
+ noise_interval=(6_000, 12_000),
+ blip_locations=((22_050,),),
+)
_DELAY_CALIBRATION_ABS_THRESHOLD = 0.0003
_DELAY_CALIBRATION_REL_THRESHOLD = 0.001
@@ -393,6 +426,7 @@ def _calibrate_delay_v_all(
_calibrate_delay_v1 = partial(_calibrate_delay_v_all, _V1_DATA_INFO)
_calibrate_delay_v2 = partial(_calibrate_delay_v_all, _V2_DATA_INFO)
_calibrate_delay_v3 = partial(_calibrate_delay_v_all, _V3_DATA_INFO)
+_calibrate_delay_v4 = partial(_calibrate_delay_v_all, _V4_DATA_INFO)
def _plot_delay_v_all(
@@ -445,6 +479,7 @@ def _plot_delay_v_all(
_plot_delay_v1 = partial(_plot_delay_v_all, _V1_DATA_INFO)
_plot_delay_v2 = partial(_plot_delay_v_all, _V2_DATA_INFO)
_plot_delay_v3 = partial(_plot_delay_v_all, _V3_DATA_INFO)
+_plot_delay_v4 = partial(_plot_delay_v_all, _V4_DATA_INFO)
def _calibrate_delay(
@@ -454,12 +489,16 @@ def _calibrate_delay(
output_path: str,
silent: bool = False,
) -> int:
+ """
+ Dispatch on ``input_version.major`` to the matching calibrate/plot routines."""
if input_version.major == 1:
calibrate, plot = _calibrate_delay_v1, _plot_delay_v1
elif input_version.major == 2:
calibrate, plot = _calibrate_delay_v2, _plot_delay_v2
elif input_version.major == 3:
calibrate, plot = _calibrate_delay_v3, _plot_delay_v3
+ elif input_version.major == 4:
+ calibrate, plot = _calibrate_delay_v4, _plot_delay_v4
else:
raise NotImplementedError(
f"Input calibration not implemented for input version {input_version}"
@@ -654,6 +693,29 @@ def _check_v3(input_path, output_path, silent: bool, *args, **kwargs) -> bool:
return True
+def _check_v4(input_path, output_path, silent: bool, *args, **kwargs) -> bool:
+ # Things we can't check:
+ # Latency compensation agreement
+ # Data replicability
+ print("Using Proteus audio file. Standard data checks aren't possible!")
+ signal, info = wav_to_np(output_path, info=True)
+ passed = True
+ if info.rate != _V4_DATA_INFO.rate:
+ print(
+ f"Output signal has sample rate {info.rate}; expected {_V4_DATA_INFO.rate}!"
+ )
+ passed = False
+ # I don't care what's in the files except that they're long enough to hold the blip
+ # and the last 10 seconds I decided to use as validation
+ required_length = int((1.0 + 10.0) * _V4_DATA_INFO.rate)
+ if len(signal) < required_length:
+ print(
+ "File doesn't meet the minimum length requirements for latency compensation and validation signal!"
+ )
+ passed = False
+ return passed
+
+
def _check(
input_path: str, output_path: str, input_version: Version, delay: int, silent: bool
) -> bool:
@@ -668,6 +730,8 @@ def _check(
f = _check_v2
elif input_version.major == 3:
f = _check_v3
+ elif input_version.major == 4:
+ f = _check_v4
else:
print(f"Checks not implemented for input version {input_version}; skip")
return True
@@ -821,13 +885,34 @@ def _get_configs(
train_stop = validation_start
train_kwargs = {"start": 480_000, "stop": train_stop}
validation_kwargs = {"start": validation_start}
+ elif data_info.major_version == 4:
+ validation_start = data_info.validation_start
+ train_stop = validation_start
+ train_kwargs = {"stop": train_stop}
+ # Proteus doesn't have silence to get a clean split. Bite the bullet.
+ print(
+ "Using Proteus files:\n"
+ " * There isn't a silent point to split the validation set, so some of "
+ "your gear's response from the train set will leak into the start of "
+ "the validation set and impact validation accuracy (Bypassing data "
+ "quality check)\n"
+ " * Since the validation set is different, the ESRs reported for this "
+ "model aren't comparable to those from the other 'NAM' training files."
+ )
+ validation_kwargs = {
+ "start": validation_start,
+ "require_input_pre_silence": False,
+ }
else:
raise NotImplementedError(f"kwargs for input version {input_version}")
return train_kwargs, validation_kwargs
- data_info = {1: _V1_DATA_INFO, 2: _V2_DATA_INFO, 3: _V3_DATA_INFO}[
- input_version.major
- ]
+ data_info = {
+ 1: _V1_DATA_INFO,
+ 2: _V2_DATA_INFO,
+ 3: _V3_DATA_INFO,
+ 4: _V4_DATA_INFO,
+ }[input_version.major]
train_kwargs, validation_kwargs = get_kwargs(data_info)
data_config = {
"train": {"ny": ny, **train_kwargs},
@@ -994,10 +1079,6 @@ def _nasty_checks_modal():
modal.mainloop()
-# Example usage:
-# show_modal("Hello, World!")
-
-
def train(
input_path: str,
output_path: str,
diff --git a/tests/resources/__init__.py b/tests/resources/__init__.py
@@ -13,6 +13,7 @@ from pathlib import Path
import pytest
__all__ = [
+ "requires_proteus",
"requires_v1_0_0",
"requires_v1_1_1",
"requires_v2_0_0",
@@ -33,6 +34,7 @@ requires_v1_0_0 = _requires_v("v1.wav")
requires_v1_1_1 = _requires_v("v1_1_1.wav")
requires_v2_0_0 = _requires_v("v2_0_0.wav")
requires_v3_0_0 = _requires_v("v3_0_0.wav")
+requires_proteus = _requires_v("Proteus_Capture.wav")
def resource_path(name: str) -> Path:
diff --git a/tests/test_nam/test_train/test_core.py b/tests/test_nam/test_train/test_core.py
@@ -21,6 +21,7 @@ from nam.train import core
from nam.train._version import Version
from ...resources import (
+ requires_proteus,
requires_v1_0_0,
requires_v1_1_1,
requires_v2_0_0,
@@ -34,6 +35,8 @@ __all__ = []
def _resource_path(version: Version) -> Path:
if version == Version(1, 0, 0):
name = "v1.wav"
+ elif version == Version(4, 0, 0):
+ name = "Proteus_Capture.wav"
else:
name = f'v{str(version).replace(".", "_")}.wav'
return resource_path(name)
@@ -167,24 +170,37 @@ class TestCalibrateDelayV3(_TCalibrateDelay):
_data_info = core._V3_DATA_INFO
+class TestCalibrateDelayV4(_TCalibrateDelay):
+ _calibrate_delay = core._calibrate_delay_v4
+ _data_info = core._V4_DATA_INFO
+
+
def _make_t_validation_dataset_class(
version: Version, decorator, data_info: core._DataInfo
):
class C(object):
- @decorator
- def test_validation_preceded_by_silence(self):
- """
- Validate that the datasets that we've made are valid
- """
- x = wav_to_tensor(_resource_path(version))
- Dataset._validate_preceding_silence(
- x,
- data_info.validation_start,
- _DEFAULT_REQUIRE_INPUT_PRE_SILENCE,
- data_info.rate,
- )
-
- return C
+ pass
+
+ # Proteus has a bad validation split; don't define the silence test for it.
+ if version == Version(4, 0, 0):
+ return C
+ else:
+
+ class C2(C):
+ @decorator
+ def test_validation_preceded_by_silence(self):
+ """
+ Validate that the datasets that we've made are valid
+ """
+ x = wav_to_tensor(_resource_path(version))
+ Dataset._validate_preceding_silence(
+ x,
+ data_info.validation_start,
+ _DEFAULT_REQUIRE_INPUT_PRE_SILENCE,
+ data_info.rate,
+ )
+
+ return C2
TestValidationDatasetV1_0_0 = _make_t_validation_dataset_class(
@@ -207,6 +223,12 @@ TestValidationDatasetV3_0_0 = _make_t_validation_dataset_class(
)
+# Aka Proteus
+TestValidationDatasetV4_0_0 = _make_t_validation_dataset_class(
+ Version(4, 0, 0), requires_proteus, core._V4_DATA_INFO
+)
+
+
def test_v3_check_doesnt_make_figure_if_silent(mocker):
"""
Issue 337