neural-amp-modeler

Neural network emulator for guitar amplifiers
Log | Files | Refs | README | LICENSE

commit bd4dab08e65c1da803d0d6d18aac7cfb516dd96f
parent 4c71ff40429034de868a49b271a679a83f8ca350
Author: Steven Atkinson <steven@atkinson.mn>
Date:   Thu,  9 Nov 2023 08:25:48 -0800

[BUGFIX] Silent mode for V3 (#338)

Silent mode for V3
Diffstat:
Mnam/train/core.py | 19++++++++++---------
Mtests/test_nam/test_train/test_core.py | 32++++++++++++++++++++++++++++++++
2 files changed, 42 insertions(+), 9 deletions(-)

diff --git a/nam/train/core.py b/nam/train/core.py @@ -622,7 +622,7 @@ def _check_v2(input_path, output_path, delay: int, silent: bool) -> bool: return True -def _check_v3(input_path, output_path, *args, **kwargs) -> bool: +def _check_v3(input_path, output_path, silent: bool, *args, **kwargs) -> bool: with torch.no_grad(): print("V3 checks...") rate = _V3_DATA_INFO.rate @@ -634,14 +634,15 @@ def _check_v3(input_path, output_path, *args, **kwargs) -> bool: esr_replicate_threshold = 0.01 if esr_replicate > esr_replicate_threshold: print(_esr_validation_replicate_msg(esr_replicate_threshold)) - plt.figure() - t = np.arange(len(y_val_1)) / rate - plt.plot(t, y_val_1, label="Validation 1") - plt.plot(t, y_val_2, label="Validation 2") - plt.xlabel("Time (sec)") - plt.legend() - plt.title("V3 check: Validation replicate FAILURE") - plt.show() + if not silent: + plt.figure() + t = np.arange(len(y_val_1)) / rate + plt.plot(t, y_val_1, label="Validation 1") + plt.plot(t, y_val_2, label="Validation 2") + plt.xlabel("Time (sec)") + plt.legend() + plt.title("V3 check: Validation replicate FAILURE") + plt.show() return False return True diff --git a/tests/test_nam/test_train/test_core.py b/tests/test_nam/test_train/test_core.py @@ -206,5 +206,37 @@ TestValidationDatasetV3_0_0 = _make_t_validation_dataset_class( ) +def test_v3_check_doesnt_make_figure_if_silent(mocker): + """ + Issue 337 + + :param mocker: Provided by pytest-mock + """ + import matplotlib.pyplot + + class MadeFigureError(RuntimeError): + """ + For this test, detect if a figure was made, and raise an exception if so + """ + + pass + + def figure_mock(*args, **kwargs): + raise MadeFigureError("The test tried to make a figure") + + mocker.patch("matplotlib.pyplot.figure", figure_mock) + + # Make some data that's totally going to biff it + # [:-1] won't match [1:] + x = np.random.rand(core._V3_DATA_INFO.t_validate + 1) - 0.5 + + with TemporaryDirectory() as tmpdir: + output_path = Path(tmpdir, "output.wav") + np_to_wav(x, output_path) + input_path = None # Isn't used right now. + # If this makes a figure, then it wasn't silent! + core._check_v3(input_path, output_path, silent=True) + + if __name__ == "__main__": pytest.main()