neural-amp-modeler

Neural network emulator for guitar amplifiers
Log | Files | Refs | README | LICENSE

commit a5ad4b79c83ec24d63b04678356b58bfe9b0cf12
parent 58c6f746264e1d976646fb694d1fabc49c2be0d7
Author: Steven Atkinson <steven@atkinson.mn>
Date:   Sun, 11 Feb 2024 08:21:30 -0800

Clean up some code (#381)

* Some cleanup

* Black
Diffstat:
Mnam/models/_base.py | 38--------------------------------------
Mnam/models/base.py | 4++--
Msetup.py | 5++++-
3 files changed, 6 insertions(+), 41 deletions(-)

diff --git a/nam/models/_base.py b/nam/models/_base.py @@ -188,41 +188,3 @@ class BaseNet(_Base): d["loudness"] = self._metadata_loudness() d["gain"] = self._metadata_gain() return d - - -class ParametricBaseNet(_Base): - """ - Parametric inputs - """ - - def forward( - self, - params: torch.Tensor, - x: torch.Tensor, - pad_start: Optional[bool] = None, - **kwargs - ): - pad_start = self.pad_start_default if pad_start is None else pad_start - scalar = x.ndim == 1 - if scalar: - x = x[None] - params = params[None] - if pad_start: - x = torch.cat( - (torch.zeros((len(x), self.receptive_field - 1)).to(x.device), x), dim=1 - ) - y = self._forward(params, x, **kwargs) - if scalar: - y = y[0] - return y - - @abc.abstractmethod - def _forward(self, params: torch.Tensor, x: torch.Tensor) -> torch.Tensor: - """ - The true forward method. - - :param params: (N,D) - :param x: (N,L1) - :return: (N,L1-RF+1) - """ - pass diff --git a/nam/models/base.py b/nam/models/base.py @@ -4,8 +4,8 @@ """ Implements the base PyTorch Lightning model. -This is meant to combine an acutal model (subclassed from `._base.BaseNet` or -`._base.ParametricBaseNet`) along with loss function boilerplate. +This is meant to combine an actual model (subclassed from `._base.BaseNet`) +along with loss function boilerplate. For the base *PyTorch* model containing the actual architecture, see `._base`. """ diff --git a/setup.py b/setup.py @@ -5,17 +5,20 @@ from distutils.util import convert_path from setuptools import setup, find_packages + def get_additional_requirements(): # Issue 294 try: import transformers + # This may not be unnecessarily straict a requirement, but I'd rather - # fix this promptly than leave a chance that it wouldn't be fixed + # fix this promptly than leave a chance that it wouldn't be fixed # properly. return ["transformers>=4"] except ModuleNotFoundError: return [] + main_ns = {} ver_path = convert_path("nam/_version.py") with open(ver_path) as ver_file: