commit aa2b2b8b85d09d79353ceae0daed208a9ebb51bc
parent 72d01c15bd14556e8e85a1ddca2df2d087ea1bc4
Author: Steven Atkinson <steven@atkinson.mn>
Date: Sun, 12 Nov 2023 14:17:07 -0600
Ignore PyTorch Lightning `PossibleUserWarning`s during `trainer.fit()` (#346)
Ignore PyTorch Lightning PossibleUserWarnings during trainer fit
Diffstat:
3 files changed, 47 insertions(+), 8 deletions(-)
diff --git a/bin/train/main.py b/bin/train/main.py
@@ -27,12 +27,13 @@ from warnings import warn
import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
+from pytorch_lightning.utilities.warnings import PossibleUserWarning
import torch
from torch.utils.data import DataLoader
from nam.data import ConcatDataset, ParametricDataset, Split, init_dataset
from nam.models import Model
-from nam.util import timestamp
+from nam.util import filter_warnings, timestamp
torch.manual_seed(0)
@@ -198,12 +199,13 @@ def main_inner(
default_root_dir=outdir,
**learning_config["trainer"],
)
- trainer.fit(
- model,
- train_dataloader,
- val_dataloader,
- **learning_config.get("trainer_fit_kwargs", {}),
- )
+ with filter_warnings("ignore", category=PossibleUserWarning):
+ trainer.fit(
+ model,
+ train_dataloader,
+ val_dataloader,
+ **learning_config.get("trainer_fit_kwargs", {}),
+ )
# Go to best checkpoint
best_checkpoint = trainer.checkpoint_callback.best_model_path
if best_checkpoint != "":
diff --git a/nam/train/core.py b/nam/train/core.py
@@ -19,13 +19,17 @@ import numpy as np
import pytorch_lightning as pl
import torch
from pydantic import BaseModel
+from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader
from ..data import REQUIRED_RATE, Split, init_dataset, wav_to_np, wav_to_tensor
from ..models import Model
from ..models.losses import esr
+from ..util import filter_warnings
from ._version import Version
+__all__ = ["train"]
+
class Architecture(Enum):
STANDARD = "standard"
@@ -1090,7 +1094,9 @@ def train(
default_root_dir=train_path,
**learning_config["trainer"],
)
- trainer.fit(model, train_dataloader, val_dataloader)
+ # Suppress the PossibleUserWarning about num_workers (Issue 345)
+ with filter_warnings("ignore", category=PossibleUserWarning):
+ trainer.fit(model, train_dataloader, val_dataloader)
# Go to best checkpoint
best_checkpoint = trainer.checkpoint_callback.best_model_path
diff --git a/nam/util.py b/nam/util.py
@@ -6,9 +6,40 @@
Helpful utilities
"""
+import warnings
from datetime import datetime
+__all__ = ["filter_warnings", "timestamp"]
+
def timestamp() -> str:
t = datetime.now()
return f"{t.year:04d}-{t.month:02d}-{t.day:02d}-{t.hour:02d}-{t.minute:02d}-{t.second:02d}"
+
+
+class _FilterWarnings(object):
+ """
+ Context manager.
+
+ Kinda hacky since it doesn't restore to what it was before, but to what the
+ global default is.
+ """
+
+ def __init__(self, *args, **kwargs):
+ self._args = args
+ self._kwargs = kwargs
+
+ def __enter__(self):
+ warnings.filterwarnings(*self._args, **self._kwargs)
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ warnings.resetwarnings()
+
+
+def filter_warnings(*args, **kwargs):
+ """
+ Simple-but-kinda-hacky context manager that allows you to use
+ `warnings.filterwarnings()` / `warnings.resetwarnings()` as if it were a
+ context manager.
+ """
+ return _FilterWarnings(*args, **kwargs)