commit a5ad4b79c83ec24d63b04678356b58bfe9b0cf12
parent 58c6f746264e1d976646fb694d1fabc49c2be0d7
Author: Steven Atkinson <steven@atkinson.mn>
Date: Sun, 11 Feb 2024 08:21:30 -0800
Clean up some code (#381)
* Some cleanup
* Black
Diffstat:
3 files changed, 6 insertions(+), 41 deletions(-)
diff --git a/nam/models/_base.py b/nam/models/_base.py
@@ -188,41 +188,3 @@ class BaseNet(_Base):
d["loudness"] = self._metadata_loudness()
d["gain"] = self._metadata_gain()
return d
-
-
-class ParametricBaseNet(_Base):
- """
- Parametric inputs
- """
-
- def forward(
- self,
- params: torch.Tensor,
- x: torch.Tensor,
- pad_start: Optional[bool] = None,
- **kwargs
- ):
- pad_start = self.pad_start_default if pad_start is None else pad_start
- scalar = x.ndim == 1
- if scalar:
- x = x[None]
- params = params[None]
- if pad_start:
- x = torch.cat(
- (torch.zeros((len(x), self.receptive_field - 1)).to(x.device), x), dim=1
- )
- y = self._forward(params, x, **kwargs)
- if scalar:
- y = y[0]
- return y
-
- @abc.abstractmethod
- def _forward(self, params: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
- """
- The true forward method.
-
- :param params: (N,D)
- :param x: (N,L1)
- :return: (N,L1-RF+1)
- """
- pass
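For reference, the non-parametric BaseNet keeps the same padding and batching behavior that the removed class duplicated. Below is a minimal sketch of that surviving forward path, assuming BaseNet.forward mirrors the removed method with the params tensor dropped; the class name and the receptive_field/pad_start_default values are illustrative, not copied from the repo.

    import abc
    from typing import Optional

    import torch


    class BaseNetSketch(abc.ABC):
        # Illustrative defaults; real values live on the concrete net classes.
        pad_start_default: bool = True
        receptive_field: int = 8192

        def forward(self, x: torch.Tensor, pad_start: Optional[bool] = None, **kwargs):
            pad_start = self.pad_start_default if pad_start is None else pad_start
            scalar = x.ndim == 1
            if scalar:
                # Promote a single (L,) signal to a (1, L) batch.
                x = x[None]
            if pad_start:
                # Zero-pad the start so the output is as long as the unpadded input.
                x = torch.cat(
                    (torch.zeros((len(x), self.receptive_field - 1)).to(x.device), x),
                    dim=1,
                )
            y = self._forward(x, **kwargs)
            if scalar:
                y = y[0]
            return y

        @abc.abstractmethod
        def _forward(self, x: torch.Tensor) -> torch.Tensor:
            """
            The true forward method.

            :param x: (N,L1)
            :return: (N,L1-RF+1)
            """
            ...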
diff --git a/nam/models/base.py b/nam/models/base.py
@@ -4,8 +4,8 @@
"""
Implements the base PyTorch Lightning model.
-This is meant to combine an acutal model (subclassed from `._base.BaseNet` or
-`._base.ParametricBaseNet`) along with loss function boilerplate.
+This is meant to combine an actual model (subclassed from `._base.BaseNet`)
+along with loss function boilerplate.
For the base *PyTorch* model containing the actual architecture, see `._base`.
"""
diff --git a/setup.py b/setup.py
@@ -5,17 +5,20 @@
from distutils.util import convert_path
from setuptools import setup, find_packages
+
def get_additional_requirements():
# Issue 294
try:
import transformers
+
# This may be an unnecessarily strict requirement, but I'd rather
- # fix this promptly than leave a chance that it wouldn't be fixed
+ # fix this promptly than leave a chance that it wouldn't be fixed
# properly.
return ["transformers>=4"]
except ModuleNotFoundError:
return []
+
main_ns = {}
ver_path = convert_path("nam/_version.py")
with open(ver_path) as ver_file:
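The hunk cuts off before the setup() call; a rough sketch of how the helper's result is typically folded into the install requirements follows. It assumes the imports and get_additional_requirements() from the hunk above are in scope; the base requirement list and package metadata are placeholders, not copied from the repo.

    # Placeholder base requirements; the real list is maintained in setup.py itself.
    requirements = ["torch", "pytorch_lightning"]
    # Only pins transformers>=4 when transformers is already importable (Issue 294).
    requirements += get_additional_requirements()

    setup(
        name="neural-amp-modeler",  # Placeholder metadata for illustration.
        version="0.0.0",
        packages=find_packages(),
        install_requires=requirements,
    )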