neural-amp-modeler

Neural network emulator for guitar amplifiers

commit aa2b2b8b85d09d79353ceae0daed208a9ebb51bc
parent 72d01c15bd14556e8e85a1ddca2df2d087ea1bc4
Author: Steven Atkinson <steven@atkinson.mn>
Date:   Sun, 12 Nov 2023 14:17:07 -0600

Ignore PyTorch Lightning `PossibleUserWarning`s during `trainer.fit()` (#346)

Diffstat:
M bin/train/main.py | 16 +++++++++-------
M nam/train/core.py |  8 +++++++-
M nam/util.py       | 31 +++++++++++++++++++++++++++++++
3 files changed, 47 insertions(+), 8 deletions(-)

diff --git a/bin/train/main.py b/bin/train/main.py
@@ -27,12 +27,13 @@ from warnings import warn
 import matplotlib.pyplot as plt
 import numpy as np
 import pytorch_lightning as pl
+from pytorch_lightning.utilities.warnings import PossibleUserWarning
 import torch
 from torch.utils.data import DataLoader
 
 from nam.data import ConcatDataset, ParametricDataset, Split, init_dataset
 from nam.models import Model
-from nam.util import timestamp
+from nam.util import filter_warnings, timestamp
 
 
 torch.manual_seed(0)
@@ -198,12 +199,13 @@ def main_inner(
         default_root_dir=outdir,
         **learning_config["trainer"],
     )
-    trainer.fit(
-        model,
-        train_dataloader,
-        val_dataloader,
-        **learning_config.get("trainer_fit_kwargs", {}),
-    )
+    with filter_warnings("ignore", category=PossibleUserWarning):
+        trainer.fit(
+            model,
+            train_dataloader,
+            val_dataloader,
+            **learning_config.get("trainer_fit_kwargs", {}),
+        )
     # Go to best checkpoint
     best_checkpoint = trainer.checkpoint_callback.best_model_path
     if best_checkpoint != "":
diff --git a/nam/train/core.py b/nam/train/core.py
@@ -19,13 +19,17 @@ import numpy as np
 import pytorch_lightning as pl
 import torch
 from pydantic import BaseModel
+from pytorch_lightning.utilities.warnings import PossibleUserWarning
 from torch.utils.data import DataLoader
 
 from ..data import REQUIRED_RATE, Split, init_dataset, wav_to_np, wav_to_tensor
 from ..models import Model
 from ..models.losses import esr
+from ..util import filter_warnings
 from ._version import Version
 
+__all__ = ["train"]
+
 
 class Architecture(Enum):
     STANDARD = "standard"
@@ -1090,7 +1094,9 @@ def train(
         default_root_dir=train_path,
         **learning_config["trainer"],
     )
-    trainer.fit(model, train_dataloader, val_dataloader)
+    # Suppress the PossibleUserWarning about num_workers (Issue 345)
+    with filter_warnings("ignore", category=PossibleUserWarning):
+        trainer.fit(model, train_dataloader, val_dataloader)
 
     # Go to best checkpoint
     best_checkpoint = trainer.checkpoint_callback.best_model_path
diff --git a/nam/util.py b/nam/util.py
@@ -6,9 +6,40 @@
 Helpful utilities
 """
 
+import warnings
 from datetime import datetime
 
+__all__ = ["filter_warnings", "timestamp"]
+
 
 def timestamp() -> str:
     t = datetime.now()
     return f"{t.year:04d}-{t.month:02d}-{t.day:02d}-{t.hour:02d}-{t.minute:02d}-{t.second:02d}"
+
+
+class _FilterWarnings(object):
+    """
+    Context manager.
+
+    Kinda hacky since it doesn't restore to what it was before, but to what the
+    global default is.
+    """
+
+    def __init__(self, *args, **kwargs):
+        self._args = args
+        self._kwargs = kwargs
+
+    def __enter__(self):
+        warnings.filterwarnings(*self._args, **self._kwargs)
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        warnings.resetwarnings()
+
+
+def filter_warnings(*args, **kwargs):
+    """
+    Simple-but-kinda-hacky context manager that allows you to use
+    `warnings.filterwarnings()` / `warnings.resetwarnings()` as if it were a
+    context manager.
+    """
+    return _FilterWarnings(*args, **kwargs)
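
The heart of the change is the new `filter_warnings` helper in nam/util.py, which both
training entry points now wrap around `trainer.fit()`. A minimal standalone sketch of
its behavior, not part of this commit; `Noisy` is a made-up warning category standing
in for `PossibleUserWarning`:

    import warnings

    from nam.util import filter_warnings


    class Noisy(UserWarning):
        """Made-up warning category for illustration."""


    # __enter__ installs an "ignore" filter, so this warning is suppressed.
    with filter_warnings("ignore", category=Noisy):
        warnings.warn("hidden inside the block", Noisy)

    # __exit__ calls warnings.resetwarnings(), so this one is printed again.
    warnings.warn("visible after the block", Noisy)

As the docstring admits, `resetwarnings()` restores Python's global defaults rather
than whatever filters were active before entering the block. The stdlib's
`warnings.catch_warnings()` context manager would save and restore the prior filter
state exactly; this helper trades that precision for a terser call site.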