commit 52640d6293612d5a946c3bd9122676f6905ebd66
parent 4fb64087c2030607b8f1820bfdf62cd0c59c5310
Author: Steven Atkinson <steven@atkinson.mn>
Date: Wed, 18 Dec 2024 18:49:39 -0800
Improve warning message when catching PyTorch MPS convolution bug (#522)
Better warning message
Diffstat:
1 file changed, 22 insertions(+), 10 deletions(-)
diff --git a/nam/models/base.py b/nam/models/base.py
@@ -176,6 +176,10 @@ class _Base(_nn.Module, _InitializableFromConfig, _Exportable):
)
+def _get_torch_version() -> str:
+ return _torch.__version__
+
+
class BaseNet(_Base):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -217,17 +221,25 @@ class BaseNet(_Base):
return self._forward(x, **kwargs)
except NotImplementedError as e:
if "Output channels > 65536 not supported at the MPS device." in str(e):
- print(
- "===WARNING===\n"
- "NAM encountered a bug in PyTorch's MPS backend and will "
- "switch to a fallback.\n"
- f"Your version of PyTorch is {_torch.__version__}.\n"
- "Please report this in an Issue at:\n"
- "https://github.com/sdatkinson/neural-amp-modeler/issues/new/choose"
- "\n"
- "so that NAM's dependencies can avoid buggy versions of "
- "PyTorch and the associated performance hit."
+ msg = (
+ "Warning: NAM encountered a bug in PyTorch's MPS backend and "
+ "will switch to a fallback."
)
+ known_bad_versions = {"2.5.0", "2.5.1"}
+ torch_version = _get_torch_version()
+ if torch_version not in known_bad_versions:
+ msg += (
+ "\n"
+ f"Your version of PyTorch is {torch_version}, which "
+ "wasn't known to have this problem.\n"
+ "Please open an Issue at:\n"
+ "https://github.com/sdatkinson/neural-amp-modeler/issues/507"
+ "\n"
+ f"and report your PyTorch version ({torch_version}) "
+ "so that we can keep track of versions of PyTorch that "
+ "might be avoided."
+ )
+ print(msg)
self._mps_65536_fallback = True
return self._forward_mps_safe(x, **kwargs)
else: