commit bd4dab08e65c1da803d0d6d18aac7cfb516dd96f
parent 4c71ff40429034de868a49b271a679a83f8ca350
Author: Steven Atkinson <steven@atkinson.mn>
Date: Thu, 9 Nov 2023 08:25:48 -0800
[BUGFIX] Silent mode for V3 (#338)
Silent mode for V3
Diffstat:
2 files changed, 42 insertions(+), 9 deletions(-)
diff --git a/nam/train/core.py b/nam/train/core.py
@@ -622,7 +622,7 @@ def _check_v2(input_path, output_path, delay: int, silent: bool) -> bool:
return True
-def _check_v3(input_path, output_path, *args, **kwargs) -> bool:
+def _check_v3(input_path, output_path, silent: bool, *args, **kwargs) -> bool:
with torch.no_grad():
print("V3 checks...")
rate = _V3_DATA_INFO.rate
@@ -634,14 +634,15 @@ def _check_v3(input_path, output_path, *args, **kwargs) -> bool:
esr_replicate_threshold = 0.01
if esr_replicate > esr_replicate_threshold:
print(_esr_validation_replicate_msg(esr_replicate_threshold))
- plt.figure()
- t = np.arange(len(y_val_1)) / rate
- plt.plot(t, y_val_1, label="Validation 1")
- plt.plot(t, y_val_2, label="Validation 2")
- plt.xlabel("Time (sec)")
- plt.legend()
- plt.title("V3 check: Validation replicate FAILURE")
- plt.show()
+ if not silent:
+ plt.figure()
+ t = np.arange(len(y_val_1)) / rate
+ plt.plot(t, y_val_1, label="Validation 1")
+ plt.plot(t, y_val_2, label="Validation 2")
+ plt.xlabel("Time (sec)")
+ plt.legend()
+ plt.title("V3 check: Validation replicate FAILURE")
+ plt.show()
return False
return True
diff --git a/tests/test_nam/test_train/test_core.py b/tests/test_nam/test_train/test_core.py
@@ -206,5 +206,37 @@ TestValidationDatasetV3_0_0 = _make_t_validation_dataset_class(
)
+def test_v3_check_doesnt_make_figure_if_silent(mocker):
+ """
+ Issue 337
+
+ :param mocker: Provided by pytest-mock
+ """
+ import matplotlib.pyplot
+
+ class MadeFigureError(RuntimeError):
+ """
+ For this test, detect if a figure was made, and raise an exception if so
+ """
+
+ pass
+
+ def figure_mock(*args, **kwargs):
+ raise MadeFigureError("The test tried to make a figure")
+
+ mocker.patch("matplotlib.pyplot.figure", figure_mock)
+
+ # Make some data that's totally going to biff it
+ # [:-1] won't match [1:]
+ x = np.random.rand(core._V3_DATA_INFO.t_validate + 1) - 0.5
+
+ with TemporaryDirectory() as tmpdir:
+ output_path = Path(tmpdir, "output.wav")
+ np_to_wav(x, output_path)
+ input_path = None # Isn't used right now.
+ # If this makes a figure, then it wasn't silent!
+ core._check_v3(input_path, output_path, silent=True)
+
+
if __name__ == "__main__":
pytest.main()