commit 4fb64087c2030607b8f1820bfdf62cd0c59c5310
parent 3899ccd703c73d159c9442e24d743dd3b4219558
Author: zghannam <124007389+zghannam@users.noreply.github.com>
Date: Wed, 18 Dec 2024 11:44:52 -0500
[BREAKING] Change module imports to private scope (#515)
* Local scope import changes for first and third party imports (not in-package imports, though)
* nam/models refactored and versions changed
* Refactored the gui
* Refactored nam/train/
* Formatted with black
* Accidentally changed TODO note
* Removed accidental addition of settings.json
* Didn't delete _version.py
* Uneccessary imports in local scopes removed, some missing names added.
* Change name of import reference in test to reflect name change.
Diffstat:
22 files changed, 1118 insertions(+), 1012 deletions(-)
diff --git a/docs/source/conf.py b/docs/source/conf.py
@@ -12,8 +12,8 @@ copyright = "2024 Steven Atkinson"
author = "Steven Atkinson"
# TODO update this automatically from nam.__version__!
-release = "0.11"
-version = "0.11.1"
+release = "0.12"
+version = "0.12.0"
# -- General configuration
diff --git a/nam/_core.py b/nam/_core.py
@@ -2,7 +2,7 @@
# Created Date: Saturday February 5th 2022
# Author: Steven Atkinson (steven@atkinson.mn)
-from copy import deepcopy
+from copy import deepcopy as _deepcopy
class InitializableFromConfig(object):
@@ -12,4 +12,4 @@ class InitializableFromConfig(object):
@classmethod
def parse_config(cls, config):
- return deepcopy(config)
+ return _deepcopy(config)
diff --git a/nam/cli.py b/nam/cli.py
@@ -77,17 +77,17 @@ def _apply_extensions():
_apply_extensions()
-import json
-from argparse import ArgumentParser
-from pathlib import Path
+import json as _json
+from argparse import ArgumentParser as _ArgumentParser
+from pathlib import Path as _Path
from nam.train.full import main as _nam_full
-from nam.train.gui import run as nam_gui # noqa F401 Used as an entry point
-from nam.util import timestamp
+from nam.train.gui import run as _nam_gui # noqa F401 Used as an entry point
+from nam.util import timestamp as _timestamp
def nam_full():
- parser = ArgumentParser()
+ parser = _ArgumentParser()
parser.add_argument("data_config_path", type=str)
parser.add_argument("model_config_path", type=str)
parser.add_argument("learning_config_path", type=str)
@@ -96,17 +96,17 @@ def nam_full():
args = parser.parse_args()
- def ensure_outdir(outdir: str) -> Path:
- outdir = Path(outdir, timestamp())
+ def ensure_outdir(outdir: str) -> _Path:
+ outdir = _Path(outdir, _timestamp())
outdir.mkdir(parents=True, exist_ok=False)
return outdir
outdir = ensure_outdir(args.outdir)
# Read
with open(args.data_config_path, "r") as fp:
- data_config = json.load(fp)
+ data_config = _json.load(fp)
with open(args.model_config_path, "r") as fp:
- model_config = json.load(fp)
+ model_config = _json.load(fp)
with open(args.learning_config_path, "r") as fp:
- learning_config = json.load(fp)
+ learning_config = _json.load(fp)
_nam_full(data_config, model_config, learning_config, outdir, args.no_show)
diff --git a/nam/data.py b/nam/data.py
@@ -6,35 +6,43 @@
Functions and classes for working with audio data with NAM
"""
-import abc
-import logging
-from collections import namedtuple
-from copy import deepcopy
-from dataclasses import dataclass
-from enum import Enum
-from pathlib import Path
-from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
-
-import numpy as np
-import torch
-import wavio
-from scipy.interpolate import interp1d
+import abc as _abc
+import logging as _logging
+from collections import namedtuple as _namedtuple
+from copy import deepcopy as _deepcopy
+from dataclasses import dataclass as _dataclass
+from enum import Enum as _Enum
+from pathlib import Path as _Path
+from typing import (
+ Any as _Any,
+ Callable as _Callable,
+ Dict as _Dict,
+ Optional as _Optional,
+ Sequence as _Sequence,
+ Tuple as _Tuple,
+ Union as _Union,
+)
+
+import numpy as _np
+import torch as _torch
+import wavio as _wavio
+from scipy.interpolate import interp1d as _interp1d
from torch.utils.data import Dataset as _Dataset
-from tqdm import tqdm
+from tqdm import tqdm as _tqdm
-from ._core import InitializableFromConfig
+from ._core import InitializableFromConfig as _InitializableFromConfig
-logger = logging.getLogger(__name__)
+logger = _logging.getLogger(__name__)
_REQUIRED_CHANNELS = 1 # Mono
-class Split(Enum):
+class Split(_Enum):
TRAIN = "train"
VALIDATION = "validation"
-@dataclass
+@_dataclass
class WavInfo:
sampwidth: int
rate: int
@@ -69,14 +77,14 @@ class AudioShapeMismatchError(ValueError, DataError):
def wav_to_np(
- filename: Union[str, Path],
- rate: Optional[int] = None,
- require_match: Optional[Union[str, Path]] = None,
- required_shape: Optional[Tuple[int, ...]] = None,
- required_wavinfo: Optional[WavInfo] = None,
- preroll: Optional[int] = None,
+ filename: _Union[str, _Path],
+ rate: _Optional[int] = None,
+ require_match: _Optional[_Union[str, _Path]] = None,
+ required_shape: _Optional[_Tuple[int, ...]] = None,
+ required_wavinfo: _Optional[WavInfo] = None,
+ preroll: _Optional[int] = None,
info: bool = False,
-) -> Union[np.ndarray, Tuple[np.ndarray, WavInfo]]:
+) -> _Union[_np.ndarray, _Tuple[_np.ndarray, WavInfo]]:
"""
:param filename: Where to load from
:param rate: Expected sample rate. `None` allows for anything.
@@ -89,7 +97,7 @@ def wav_to_np(
:param preroll: Drop this many samples off the front
:param info: If `True`, also return the WAV info of this file.
"""
- x_wav = wavio.read(str(filename))
+ x_wav = _wavio.read(str(filename))
assert x_wav.data.shape[1] == _REQUIRED_CHANNELS, "Mono"
if rate is not None and x_wav.rate != rate:
raise RuntimeError(
@@ -100,7 +108,7 @@ def wav_to_np(
if require_match is not None:
assert required_shape is None
assert required_wavinfo is None
- y_wav = wavio.read(str(require_match))
+ y_wav = _wavio.read(str(require_match))
required_shape = y_wav.data.shape
required_wavinfo = WavInfo(y_wav.sampwidth, y_wav.rate)
if required_wavinfo is not None:
@@ -124,33 +132,33 @@ def wav_to_np(
def wav_to_tensor(
*args, info: bool = False, **kwargs
-) -> Union[torch.Tensor, Tuple[torch.Tensor, WavInfo]]:
+) -> _Union[_torch.Tensor, _Tuple[_torch.Tensor, WavInfo]]:
out = wav_to_np(*args, info=info, **kwargs)
if info:
arr, info = out
- return torch.Tensor(arr), info
+ return _torch.Tensor(arr), info
else:
arr = out
- return torch.Tensor(arr)
+ return _torch.Tensor(arr)
-def tensor_to_wav(x: torch.Tensor, *args, **kwargs):
+def tensor_to_wav(x: _torch.Tensor, *args, **kwargs):
np_to_wav(x.detach().cpu().numpy(), *args, **kwargs)
def np_to_wav(
- x: np.ndarray,
- filename: Union[str, Path],
+ x: _np.ndarray,
+ filename: _Union[str, _Path],
rate: int = 48_000,
sampwidth: int = 3,
scale=None,
**kwargs,
):
- if wavio.__version__ <= "0.0.4" and scale is None:
+ if _wavio.__version__ <= "0.0.4" and scale is None:
scale = "none"
- wavio.write(
+ _wavio.write(
str(filename),
- (np.clip(x, -1.0, 1.0) * (2 ** (8 * sampwidth - 1))).astype(np.int32),
+ (_np.clip(x, -1.0, 1.0) * (2 ** (8 * sampwidth - 1))).astype(_np.int32),
rate,
scale=scale,
sampwidth=sampwidth,
@@ -158,8 +166,8 @@ def np_to_wav(
)
-class AbstractDataset(_Dataset, abc.ABC):
- @abc.abstractmethod
+class AbstractDataset(_Dataset, _abc.ABC):
+ @_abc.abstractmethod
def __getitem__(self, idx: int):
"""
Get input and output audio segment for training / evaluation.
@@ -168,7 +176,7 @@ class AbstractDataset(_Dataset, abc.ABC):
pass
-class _DelayInterpolationMethod(Enum):
+class _DelayInterpolationMethod(_Enum):
"""
:param LINEAR: Linear interpolation
:param CUBIC: Cubic spline interpolation
@@ -180,22 +188,22 @@ class _DelayInterpolationMethod(Enum):
def _interpolate_delay(
- x: torch.Tensor, delay: float, method: _DelayInterpolationMethod
-) -> np.ndarray:
+ x: _torch.Tensor, delay: float, method: _DelayInterpolationMethod
+) -> _np.ndarray:
"""
NOTE: This breaks the gradient tape!
"""
if delay == 0.0:
return x
- t_in = np.arange(len(x))
- n_out = len(x) - int(np.ceil(np.abs(delay)))
+ t_in = _np.arange(len(x))
+ n_out = len(x) - int(_np.ceil(_np.abs(delay)))
if delay > 0:
- t_out = np.arange(n_out) + delay
+ t_out = _np.arange(n_out) + delay
elif delay < 0:
- t_out = np.arange(len(x) - n_out, len(x)) - np.abs(delay)
+ t_out = _np.arange(len(x) - n_out, len(x)) - _np.abs(delay)
- return torch.Tensor(
- interp1d(t_in, x.detach().cpu().numpy(), kind=method.value)(t_out)
+ return _torch.Tensor(
+ _interp1d(t_in, x.detach().cpu().numpy(), kind=method.value)(t_out)
)
@@ -242,33 +250,35 @@ def _sample_to_time(s, rate):
return f"{hours}:{minutes:02d}:{seconds:02d} and {remainder} samples"
-class Dataset(AbstractDataset, InitializableFromConfig):
+class Dataset(AbstractDataset, _InitializableFromConfig):
"""
Take a pair of matched audio files and serve input + output pairs.
"""
def __init__(
self,
- x: torch.Tensor,
- y: torch.Tensor,
+ x: _torch.Tensor,
+ y: _torch.Tensor,
nx: int,
- ny: Optional[int],
- start: Optional[int] = None,
- stop: Optional[int] = None,
- start_samples: Optional[int] = None,
- stop_samples: Optional[int] = None,
- start_seconds: Optional[Union[int, float]] = None,
- stop_seconds: Optional[Union[int, float]] = None,
- delay: Optional[Union[int, float]] = None,
- delay_interpolation_method: Union[
+ ny: _Optional[int],
+ start: _Optional[int] = None,
+ stop: _Optional[int] = None,
+ start_samples: _Optional[int] = None,
+ stop_samples: _Optional[int] = None,
+ start_seconds: _Optional[_Union[int, float]] = None,
+ stop_seconds: _Optional[_Union[int, float]] = None,
+ delay: _Optional[_Union[int, float]] = None,
+ delay_interpolation_method: _Union[
str, _DelayInterpolationMethod
] = _DelayInterpolationMethod.CUBIC,
y_scale: float = 1.0,
- x_path: Optional[Union[str, Path]] = None,
- y_path: Optional[Union[str, Path]] = None,
+ x_path: _Optional[_Union[str, _Path]] = None,
+ y_path: _Optional[_Union[str, _Path]] = None,
input_gain: float = 0.0,
- sample_rate: Optional[float] = None,
- require_input_pre_silence: Optional[float] = _DEFAULT_REQUIRE_INPUT_PRE_SILENCE,
+ sample_rate: _Optional[float] = None,
+ require_input_pre_silence: _Optional[
+ float
+ ] = _DEFAULT_REQUIRE_INPUT_PRE_SILENCE,
):
"""
:param x: The input signal. A 1D array.
@@ -347,7 +357,7 @@ class Dataset(AbstractDataset, InitializableFromConfig):
self._nx = nx
self._ny = ny if ny is not None else len(x) - nx + 1
- def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
+ def __getitem__(self, idx: int) -> _Tuple[_torch.Tensor, _torch.Tensor]:
"""
:return:
Input (NX+NY-1,)
@@ -370,11 +380,11 @@ class Dataset(AbstractDataset, InitializableFromConfig):
return self._ny
@property
- def sample_rate(self) -> Optional[float]:
+ def sample_rate(self) -> _Optional[float]:
return self._sample_rate
@property
- def x(self) -> torch.Tensor:
+ def x(self) -> _torch.Tensor:
"""
The input audio data
@@ -383,7 +393,7 @@ class Dataset(AbstractDataset, InitializableFromConfig):
return self._x
@property
- def y(self) -> torch.Tensor:
+ def y(self) -> _torch.Tensor:
"""
The output audio data
@@ -411,7 +421,7 @@ class Dataset(AbstractDataset, InitializableFromConfig):
y (torch.Tensor) - loaded from y_path
Everything else is passed on to __init__
"""
- config = deepcopy(config)
+ config = _deepcopy(config)
sample_rate = config.pop("sample_rate", None)
x, x_wavinfo = wav_to_tensor(config.pop("x_path"), info=True, rate=sample_rate)
sample_rate = x_wavinfo.rate
@@ -460,11 +470,11 @@ class Dataset(AbstractDataset, InitializableFromConfig):
@classmethod
def _apply_delay(
cls,
- x: torch.Tensor,
- y: torch.Tensor,
- delay: Union[int, float],
+ x: _torch.Tensor,
+ y: _torch.Tensor,
+ delay: _Union[int, float],
method: _DelayInterpolationMethod,
- ) -> Tuple[torch.Tensor, torch.Tensor]:
+ ) -> _Tuple[_torch.Tensor, _torch.Tensor]:
# Check for floats that could be treated like ints (simpler algorithm)
if isinstance(delay, float) and int(delay) == delay:
delay = int(delay)
@@ -477,8 +487,8 @@ class Dataset(AbstractDataset, InitializableFromConfig):
@classmethod
def _apply_delay_int(
- cls, x: torch.Tensor, y: torch.Tensor, delay: int
- ) -> Tuple[torch.Tensor, torch.Tensor]:
+ cls, x: _torch.Tensor, y: _torch.Tensor, delay: int
+ ) -> _Tuple[_torch.Tensor, _torch.Tensor]:
if delay > 0:
x = x[:-delay]
y = y[delay:]
@@ -490,12 +500,12 @@ class Dataset(AbstractDataset, InitializableFromConfig):
@classmethod
def _apply_delay_float(
cls,
- x: torch.Tensor,
- y: torch.Tensor,
+ x: _torch.Tensor,
+ y: _torch.Tensor,
delay: float,
method: _DelayInterpolationMethod,
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- n_out = len(y) - int(np.ceil(np.abs(delay)))
+ ) -> _Tuple[_torch.Tensor, _torch.Tensor]:
+ n_out = len(y) - int(_np.ceil(_np.abs(delay)))
if delay > 0:
x = x[:n_out]
elif delay < 0:
@@ -506,16 +516,16 @@ class Dataset(AbstractDataset, InitializableFromConfig):
@classmethod
def _validate_start_stop(
cls,
- x: torch.Tensor,
- y: torch.Tensor,
- start: Optional[int] = None,
- stop: Optional[int] = None,
- start_samples: Optional[int] = None,
- stop_samples: Optional[int] = None,
- start_seconds: Optional[Union[int, float]] = None,
- stop_seconds: Optional[Union[int, float]] = None,
- sample_rate: Optional[int] = None,
- ) -> Tuple[Optional[int], Optional[int]]:
+ x: _torch.Tensor,
+ y: _torch.Tensor,
+ start: _Optional[int] = None,
+ stop: _Optional[int] = None,
+ start_samples: _Optional[int] = None,
+ stop_samples: _Optional[int] = None,
+ start_seconds: _Optional[_Union[int, float]] = None,
+ stop_seconds: _Optional[_Union[int, float]] = None,
+ sample_rate: _Optional[int] = None,
+ ) -> _Tuple[_Optional[int], _Optional[int]]:
"""
Parse the requested start and stop trim points.
@@ -639,7 +649,7 @@ class Dataset(AbstractDataset, InitializableFromConfig):
)
if ny is not None:
assert ny <= len(y) - nx + 1
- if torch.abs(y).max() >= 1.0:
+ if _torch.abs(y).max() >= 1.0:
msg = "Output clipped."
if self._y_path is not None:
msg += f"Source is {self._y_path}"
@@ -648,10 +658,10 @@ class Dataset(AbstractDataset, InitializableFromConfig):
@classmethod
def _validate_preceding_silence(
cls,
- x: torch.Tensor,
- start: Optional[int],
+ x: _torch.Tensor,
+ start: _Optional[int],
silent_seconds: float,
- sample_rate: Optional[float],
+ sample_rate: _Optional[float],
):
"""
Make sure that the input is silent before the starting index.
@@ -677,7 +687,7 @@ class Dataset(AbstractDataset, InitializableFromConfig):
raw_check_start = start - silent_samples
check_start = max(raw_check_start, 0) if start >= 0 else min(raw_check_start, 0)
check_end = start
- if not torch.all(x[check_start:check_end] == 0.0):
+ if not _torch.all(x[check_start:check_end] == 0.0):
raise XYError(
f"Input provided isn't silent for at least {silent_samples} samples "
"before the starting index. Responses to this non-silent input may "
@@ -685,15 +695,15 @@ class Dataset(AbstractDataset, InitializableFromConfig):
)
-class ConcatDataset(AbstractDataset, InitializableFromConfig):
- def __init__(self, datasets: Sequence[Dataset], flatten=True):
+class ConcatDataset(AbstractDataset, _InitializableFromConfig):
+ def __init__(self, datasets: _Sequence[Dataset], flatten=True):
if flatten:
datasets = self._flatten_datasets(datasets)
self._validate_datasets(datasets)
self._datasets = datasets
self._lookup = self._make_lookup()
- def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
+ def __getitem__(self, idx: int) -> _Tuple[_torch.Tensor, _torch.Tensor]:
i, j = self._lookup[idx]
return self.datasets[i][j]
@@ -712,7 +722,7 @@ class ConcatDataset(AbstractDataset, InitializableFromConfig):
init = _dataset_init_registry[config.get("type", "dataset")]
return {
"datasets": tuple(
- init(c) for c in tqdm(config["dataset_configs"], desc="Loading data")
+ init(c) for c in _tqdm(config["dataset_configs"], desc="Loading data")
)
}
@@ -756,8 +766,8 @@ class ConcatDataset(AbstractDataset, InitializableFromConfig):
return lookup
@classmethod
- def _validate_datasets(cls, datasets: Sequence[Dataset]):
- Reference = namedtuple("Reference", ("index", "val"))
+ def _validate_datasets(cls, datasets: _Sequence[Dataset]):
+ Reference = _namedtuple("Reference", ("index", "val"))
ref_keys, ref_ny = None, None
for i, d in enumerate(datasets):
ref_ny = Reference(i, d.ny) if ref_ny is None else ref_ny
@@ -771,7 +781,7 @@ _dataset_init_registry = {"dataset": Dataset.init_from_config}
def register_dataset_initializer(
- name: str, constructor: Callable[[Any], AbstractDataset], overwrite=False
+ name: str, constructor: _Callable[[_Any], AbstractDataset], overwrite=False
):
"""
If you have other data set types, you can register their initializer by name using
diff --git a/nam/models/_activations.py b/nam/models/_activations.py
@@ -2,8 +2,8 @@
# Created Date: Friday July 29th 2022
# Author: Steven Atkinson (steven@atkinson.mn)
-import torch.nn as nn
+import torch.nn as _nn
-def get_activation(name: str) -> nn.Module:
- return getattr(nn, name)()
+def get_activation(name: str) -> _nn.Module:
+ return getattr(_nn, name)()
diff --git a/nam/models/base.py b/nam/models/base.py
@@ -7,57 +7,64 @@ The foundation of the model without the PyTorch Lightning attributes (losses, tr
steps)
"""
-import abc
-import math
-import pkg_resources
-from typing import Any, Dict, Optional, Tuple, Union
+import abc as _abc
+import math as _math
+import pkg_resources as _pkg_resources
+from typing import (
+ Any as _Any,
+ Dict as _Dict,
+ Optional as _Optional,
+ Tuple as _Tuple,
+ Union as _Union,
+)
-import numpy as np
-import torch
-import torch.nn as nn
+import numpy as _np
+import torch as _torch
+import torch.nn as _nn
-from .._core import InitializableFromConfig
-from ..data import wav_to_tensor
-from .exportable import Exportable
+from .._core import InitializableFromConfig as _InitializableFromConfig
+from ..data import wav_to_tensor as _wav_to_tensor
+from .exportable import Exportable as _Exportable
-class _Base(nn.Module, InitializableFromConfig, Exportable):
- def __init__(self, sample_rate: Optional[float] = None):
+class _Base(_nn.Module, _InitializableFromConfig, _Exportable):
+ def __init__(self, sample_rate: _Optional[float] = None):
super().__init__()
self.register_buffer(
- "_has_sample_rate", torch.tensor(sample_rate is not None, dtype=torch.bool)
+ "_has_sample_rate",
+ _torch.tensor(sample_rate is not None, dtype=_torch.bool),
)
self.register_buffer(
- "_sample_rate", torch.tensor(0.0 if sample_rate is None else sample_rate)
+ "_sample_rate", _torch.tensor(0.0 if sample_rate is None else sample_rate)
)
@property
- @abc.abstractmethod
+ @_abc.abstractmethod
def pad_start_default(self) -> bool:
pass
@property
- @abc.abstractmethod
+ @_abc.abstractmethod
def receptive_field(self) -> int:
"""
Receptive field of the model
"""
pass
- @abc.abstractmethod
- def forward(self, *args, **kwargs) -> torch.Tensor:
+ @_abc.abstractmethod
+ def forward(self, *args, **kwargs) -> _torch.Tensor:
pass
@classmethod
- def _metadata_loudness_x(cls) -> torch.Tensor:
- return wav_to_tensor(
- pkg_resources.resource_filename(
+ def _metadata_loudness_x(cls) -> _torch.Tensor:
+ return _wav_to_tensor(
+ _pkg_resources.resource_filename(
"nam", "models/_resources/loudness_input.wav"
)
)
@property
- def device(self) -> Optional[torch.device]:
+ def device(self) -> _Optional[_torch.device]:
"""
Helpful property, where the parameters of the model live.
"""
@@ -69,13 +76,13 @@ class _Base(nn.Module, InitializableFromConfig, Exportable):
return None
@property
- def sample_rate(self) -> Optional[float]:
+ def sample_rate(self) -> _Optional[float]:
return self._sample_rate.item() if self._has_sample_rate else None
@sample_rate.setter
- def sample_rate(self, val: Optional[float]):
- self._has_sample_rate = torch.tensor(val is not None, dtype=torch.bool)
- self._sample_rate = torch.tensor(0.0 if val is None else val)
+ def sample_rate(self, val: _Optional[float]):
+ self._has_sample_rate = _torch.tensor(val is not None, dtype=_torch.bool)
+ self._sample_rate = _torch.tensor(0.0 if val is None else val)
def _get_export_dict(self):
d = super()._get_export_dict()
@@ -97,17 +104,17 @@ class _Base(nn.Module, InitializableFromConfig, Exportable):
"""
x = self._metadata_loudness_x().to(self.device)
y = self._at_nominal_settings(gain * x)
- loudness = torch.sqrt(torch.mean(torch.square(y)))
+ loudness = _torch.sqrt(_torch.mean(_torch.square(y)))
if db:
- loudness = 20.0 * torch.log10(loudness)
+ loudness = 20.0 * _torch.log10(loudness)
return loudness.item()
def _metadata_gain(self) -> float:
"""
Between 0 and 1, how much gain / compression does the model seem to have?
"""
- x = np.linspace(0.0, 1.0, 11)
- y = np.array([self._metadata_loudness(gain=gain, db=False) for gain in x])
+ x = _np.linspace(0.0, 1.0, 11)
+ y = _np.array([self._metadata_loudness(gain=gain, db=False) for gain in x])
#
# O ^ o o o o o o
# u | o x +-------------------------------------+
@@ -123,14 +130,14 @@ class _Base(nn.Module, InitializableFromConfig, Exportable):
gain_range = max_gain - min_gain
this_gain = y.sum()
normalized_gain = (this_gain - min_gain) / gain_range
- return np.clip(normalized_gain, 0.0, 1.0)
+ return _np.clip(normalized_gain, 0.0, 1.0)
- def _at_nominal_settings(self, x: torch.Tensor) -> torch.Tensor:
+ def _at_nominal_settings(self, x: _torch.Tensor) -> _torch.Tensor:
# parametric?...
raise NotImplementedError()
- @abc.abstractmethod
- def _forward(self, *args) -> torch.Tensor:
+ @_abc.abstractmethod
+ def _forward(self, *args) -> _torch.Tensor:
"""
The true forward method.
@@ -139,27 +146,27 @@ class _Base(nn.Module, InitializableFromConfig, Exportable):
"""
pass
- def _export_input_output_args(self) -> Tuple[Any]:
+ def _export_input_output_args(self) -> _Tuple[_Any]:
"""
Create any other args necessesary (e.g. params to eval at)
"""
return ()
- def _export_input_output(self) -> Tuple[np.ndarray, np.ndarray]:
+ def _export_input_output(self) -> _Tuple[_np.ndarray, _np.ndarray]:
args = self._export_input_output_args()
rate = self.sample_rate
if rate is None:
raise RuntimeError(
"Cannot export model's input and output without a sample rate."
)
- x = torch.cat(
+ x = _torch.cat(
[
- torch.zeros((rate,)),
+ _torch.zeros((rate,)),
0.5
- * torch.sin(
- 2.0 * math.pi * 220.0 * torch.linspace(0.0, 1.0, rate + 1)[:-1]
+ * _torch.sin(
+ 2.0 * _math.pi * 220.0 * _torch.linspace(0.0, 1.0, rate + 1)[:-1]
),
- torch.zeros((rate,)),
+ _torch.zeros((rate,)),
]
)
# Use pad start to ensure same length as requested by ._export_input_output()
@@ -174,14 +181,15 @@ class BaseNet(_Base):
super().__init__(*args, **kwargs)
self._mps_65536_fallback = False
- def forward(self, x: torch.Tensor, pad_start: Optional[bool] = None, **kwargs):
+ def forward(self, x: _torch.Tensor, pad_start: _Optional[bool] = None, **kwargs):
pad_start = self.pad_start_default if pad_start is None else pad_start
scalar = x.ndim == 1
if scalar:
x = x[None]
if pad_start:
- x = torch.cat(
- (torch.zeros((len(x), self.receptive_field - 1)).to(x.device), x), dim=1
+ x = _torch.cat(
+ (_torch.zeros((len(x), self.receptive_field - 1)).to(x.device), x),
+ dim=1,
)
if x.shape[1] < self.receptive_field:
raise ValueError(
@@ -193,10 +201,10 @@ class BaseNet(_Base):
y = y[0]
return y
- def _at_nominal_settings(self, x: torch.Tensor) -> torch.Tensor:
+ def _at_nominal_settings(self, x: _torch.Tensor) -> _torch.Tensor:
return self(x)
- def _forward_mps_safe(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
+ def _forward_mps_safe(self, x: _torch.Tensor, **kwargs) -> _torch.Tensor:
"""
Wrap `._forward()` to protect against MPS-unsupported input lengths
beyond 65,536 samples.
@@ -213,7 +221,7 @@ class BaseNet(_Base):
"===WARNING===\n"
"NAM encountered a bug in PyTorch's MPS backend and will "
"switch to a fallback.\n"
- f"Your version of PyTorch is {torch.__version__}.\n"
+ f"Your version of PyTorch is {_torch.__version__}.\n"
"Please report this in an Issue at:\n"
"https://github.com/sdatkinson/neural-amp-modeler/issues/new/choose"
"\n"
@@ -236,10 +244,10 @@ class BaseNet(_Base):
# Bit hacky, but correct.
if j == x.shape[1]:
break
- return torch.cat(out_list, dim=1)
+ return _torch.cat(out_list, dim=1)
- @abc.abstractmethod
- def _forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
+ @_abc.abstractmethod
+ def _forward(self, x: _torch.Tensor, **kwargs) -> _torch.Tensor:
"""
The true forward method.
@@ -248,7 +256,7 @@ class BaseNet(_Base):
"""
pass
- def _get_non_user_metadata(self) -> Dict[str, Union[str, int, float]]:
+ def _get_non_user_metadata(self) -> _Dict[str, _Union[str, int, float]]:
d = super()._get_non_user_metadata()
d["loudness"] = self._metadata_loudness()
d["gain"] = self._metadata_gain()
diff --git a/nam/models/conv_net.py b/nam/models/conv_net.py
@@ -2,28 +2,37 @@
# Created Date: Saturday February 5th 2022
# Author: Steven Atkinson (steven@atkinson.mn)
-import json
-import math
-from enum import Enum
-from functools import partial
-from pathlib import Path
-from tempfile import TemporaryDirectory
-from typing import Optional, Sequence, Tuple, Union
-
-import numpy as np
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
+import json as _json
+import math as _math
+from enum import Enum as _Enum
+from functools import partial as _partial
+from pathlib import Path as _Path
+from tempfile import TemporaryDirectory as _TemporaryDirectory
+from typing import (
+ Optional as _Optional,
+ Sequence as _Sequence,
+ Tuple as _Tuple,
+ Union as _Union,
+)
+
+import numpy as _np
+import torch as _torch
+import torch.nn as _nn
+import torch.nn.functional as _F
from .. import __version__
-from ..data import wav_to_tensor
-from ._activations import get_activation
-from .base import BaseNet
-from ._names import ACTIVATION_NAME, BATCHNORM_NAME, CONV_NAME
+from ..data import wav_to_tensor as _wav_to_tensor
+from ._activations import get_activation as _get_activation
+from .base import BaseNet as _BaseNet
+from ._names import (
+ ACTIVATION_NAME as _ACTIVATION_NAME,
+ BATCHNORM_NAME as _BATCHNORM_NAME,
+ CONV_NAME as _CONV_NAME,
+)
-class TrainStrategy(Enum):
+class TrainStrategy(_Enum):
STRIDE = "stride"
DILATE = "dilate"
@@ -31,7 +40,7 @@ class TrainStrategy(Enum):
default_train_strategy = TrainStrategy.DILATE
-class _Functional(nn.Module):
+class _Functional(_nn.Module):
"""
Define a layer by a function w/ no params
"""
@@ -44,37 +53,37 @@ class _Functional(nn.Module):
return self._op(*args, **kwargs)
-class _IR(nn.Module):
- def __init__(self, filename: Union[str, Path]):
+class _IR(_nn.Module):
+ def __init__(self, filename: _Union[str, _Path]):
super().__init__()
- self.register_buffer("_weight", reversed(wav_to_tensor(filename))[None, None])
+ self.register_buffer("_weight", reversed(_wav_to_tensor(filename))[None, None])
@property
def length(self) -> int:
return self._weight.shape[-1]
- def forward(self, x: torch.Tensor) -> torch.Tensor:
+ def forward(self, x: _torch.Tensor) -> _torch.Tensor:
"""
:param x: (N,D)
:return: (N,D-length+1)
"""
- return F.conv1d(x[:, None], self._weight)[:, 0]
+ return _F.conv1d(x[:, None], self._weight)[:, 0]
def _conv_net(
channels: int = 32,
- dilations: Sequence[int] = None,
+ dilations: _Sequence[int] = None,
batchnorm: bool = False,
activation: str = "Tanh",
-) -> nn.Sequential:
+) -> _nn.Sequential:
def block(cin, cout, dilation):
- net = nn.Sequential()
+ net = _nn.Sequential()
net.add_module(
- CONV_NAME, nn.Conv1d(cin, cout, 2, dilation=dilation, bias=not batchnorm)
+ _CONV_NAME, _nn.Conv1d(cin, cout, 2, dilation=dilation, bias=not batchnorm)
)
if batchnorm:
- net.add_module(BATCHNORM_NAME, nn.BatchNorm1d(cout))
- net.add_module(ACTIVATION_NAME, get_activation(activation))
+ net.add_module(_BATCHNORM_NAME, _nn.BatchNorm1d(cout))
+ net.add_module(_ACTIVATION_NAME, _get_activation(activation))
return net
def check_and_expand(n, x):
@@ -86,19 +95,19 @@ def _conv_net(
dilations = [1, 2, 4, 8] if dilations is None else dilations
receptive_field = sum(dilations) + 1
- net = nn.Sequential()
- net.add_module("expand", _Functional(partial(check_and_expand, receptive_field)))
+ net = _nn.Sequential()
+ net.add_module("expand", _Functional(_partial(check_and_expand, receptive_field)))
cin = 1
cout = channels
for i, dilation in enumerate(dilations):
net.add_module(f"block_{i}", block(cin, cout, dilation))
cin = cout
- net.add_module("head", nn.Conv1d(channels, 1, 1))
- net.add_module("flatten", nn.Flatten())
+ net.add_module("head", _nn.Conv1d(channels, 1, 1))
+ net.add_module("flatten", _nn.Flatten())
return net
-class ConvNet(BaseNet):
+class ConvNet(_BaseNet):
"""
A straightforward convolutional neural network.
@@ -109,8 +118,8 @@ class ConvNet(BaseNet):
self,
*args,
train_strategy: TrainStrategy = default_train_strategy,
- ir: Optional[_IR] = None,
- sample_rate: Optional[float] = None,
+ ir: _Optional[_IR] = None,
+ sample_rate: _Optional[float] = None,
**kwargs,
):
super().__init__(sample_rate=sample_rate)
@@ -149,12 +158,12 @@ class ConvNet(BaseNet):
@property
def _activation(self):
return (
- self._net._modules["block_0"]._modules[ACTIVATION_NAME].__class__.__name__
+ self._net._modules["block_0"]._modules[_ACTIVATION_NAME].__class__.__name__
)
@property
def _channels(self) -> int:
- return self._net._modules["block_0"]._modules[CONV_NAME].weight.shape[0]
+ return self._net._modules["block_0"]._modules[_CONV_NAME].weight.shape[0]
@property
def _num_layers(self) -> int:
@@ -162,14 +171,14 @@ class ConvNet(BaseNet):
@property
def _batchnorm(self) -> bool:
- return BATCHNORM_NAME in self._net._modules["block_0"]._modules
-
- def export_cpp_header(self, filename: Path):
- with TemporaryDirectory() as tmpdir:
- tmpdir = Path(tmpdir)
- self.export(Path(tmpdir))
- with open(Path(tmpdir, "config.json"), "r") as fp:
- _c = json.load(fp)
+ return _BATCHNORM_NAME in self._net._modules["block_0"]._modules
+
+ def export_cpp_header(self, filename: _Path):
+ with _TemporaryDirectory() as tmpdir:
+ tmpdir = _Path(tmpdir)
+ self.export(_Path(tmpdir))
+ with open(_Path(tmpdir, "config.json"), "r") as fp:
+ _c = _json.load(fp)
version = _c["version"]
config = _c["config"]
with open(filename, "w") as f:
@@ -187,7 +196,10 @@ class ConvNet(BaseNet):
f"const std::string ACTIVATION = \"{config['activation']}\";\n",
"std::vector<float> PARAMS{"
+ ",".join(
- [f"{w:.16f}" for w in np.load(Path(tmpdir, "weights.npy"))]
+ [
+ f"{w:.16f}"
+ for w in _np.load(_Path(tmpdir, "weights.npy"))
+ ]
)
+ "};\n",
)
@@ -201,11 +213,11 @@ class ConvNet(BaseNet):
"activation": self._activation,
}
- def _export_input_output(self, x=None) -> Tuple[np.ndarray, np.ndarray]:
+ def _export_input_output(self, x=None) -> _Tuple[_np.ndarray, _np.ndarray]:
"""
:return: (L,), (L,)
"""
- with torch.no_grad():
+ with _torch.no_grad():
training = self.training
self.eval()
x = self._export_input_signal() if x is None else x
@@ -222,18 +234,18 @@ class ConvNet(BaseNet):
raise RuntimeError(
"Cannot export model's input and output without a sample rate."
)
- return torch.cat(
+ return _torch.cat(
[
- torch.zeros((rate,)),
+ _torch.zeros((rate,)),
0.5
- * torch.sin(
- 2.0 * math.pi * 220.0 * torch.linspace(0.0, 1.0, rate + 1)[:-1]
+ * _torch.sin(
+ 2.0 * _math.pi * 220.0 * _torch.linspace(0.0, 1.0, rate + 1)[:-1]
),
- torch.zeros((rate,)),
+ _torch.zeros((rate,)),
]
)
- def _export_weights(self) -> np.ndarray:
+ def _export_weights(self) -> _np.ndarray:
"""
weights are serialized to weights.npy in the following order:
* (expand: no params)
@@ -256,21 +268,21 @@ class ConvNet(BaseNet):
for i in range(self._num_layers):
block_name = f"block_{i}"
block = self._net._modules[block_name]
- conv = block._modules[CONV_NAME]
+ conv = block._modules[_CONV_NAME]
params.append(conv.weight.flatten())
if conv.bias is not None:
params.append(conv.bias.flatten())
if self._batchnorm:
- bn = block._modules[BATCHNORM_NAME]
+ bn = block._modules[_BATCHNORM_NAME]
params.append(bn.running_mean.flatten())
params.append(bn.running_var.flatten())
params.append(bn.weight.flatten())
params.append(bn.bias.flatten())
- params.append(torch.Tensor([bn.eps]).to(bn.weight.device))
+ params.append(_torch.Tensor([bn.eps]).to(bn.weight.device))
head = self._net._modules["head"]
params.append(head.weight.flatten())
params.append(head.bias.flatten())
- params = torch.cat(params).detach().cpu().numpy()
+ params = _torch.cat(params).detach().cpu().numpy()
return params
def _forward(self, x):
@@ -279,13 +291,13 @@ class ConvNet(BaseNet):
y = self._ir(y)
return y
- def _get_dilations(self) -> Tuple[int]:
+ def _get_dilations(self) -> _Tuple[int]:
return tuple(
- self._net._modules[f"block_{i}"]._modules[CONV_NAME].dilation[0]
+ self._net._modules[f"block_{i}"]._modules[_CONV_NAME].dilation[0]
for i in range(self._num_blocks)
)
- def _get_num_blocks(self, net: nn.Sequential):
+ def _get_num_blocks(self, net: _nn.Sequential):
i = 0
while True:
if f"block_{i}" not in net._modules:
diff --git a/nam/models/exportable.py b/nam/models/exportable.py
@@ -2,32 +2,39 @@
# Created Date: Tuesday February 8th 2022
# Author: Steven Atkinson (steven@atkinson.mn)
-import abc
-import json
-import logging
-from datetime import datetime
-from enum import Enum
-from pathlib import Path
-from typing import Any, Dict, Optional, Sequence, Tuple, Union
-
-import numpy as np
-
-from .metadata import Date, UserMetadata
-
-logger = logging.getLogger(__name__)
+import abc as _abc
+import json as _json
+import logging as _logging
+from datetime import datetime as _datetime
+from enum import Enum as _Enum
+from pathlib import Path as _Path
+from typing import (
+ Any as _Any,
+ Dict as _Dict,
+ Optional as _Optional,
+ Sequence as _Sequence,
+ Tuple as _Tuple,
+ Union as _Union,
+)
+
+import numpy as _np
+
+from .metadata import Date as _Date, UserMetadata as _UserMetadata
+
+logger = _logging.getLogger(__name__)
# Model version is independent from package version as of package version 0.5.2 so that
# the API of the package can iterate at a different pace from that of the model files.
_MODEL_VERSION = "0.5.4"
-def _cast_enums(d: Dict[Any, Any]) -> Dict[Any, Any]:
+def _cast_enums(d: _Dict[_Any, _Any]) -> _Dict[_Any, _Any]:
"""
Casts enum-type keys to their values
"""
out = {}
for key, val in d.items():
- if isinstance(val, Enum):
+ if isinstance(val, _Enum):
val = val.value
if isinstance(val, dict):
val = _cast_enums(val)
@@ -35,7 +42,7 @@ def _cast_enums(d: Dict[Any, Any]) -> Dict[Any, Any]:
return out
-class Exportable(abc.ABC):
+class Exportable(_abc.ABC):
"""
Interface for my custon export format for use in the plugin.
"""
@@ -44,11 +51,11 @@ class Exportable(abc.ABC):
def export(
self,
- outdir: Path,
+ outdir: _Path,
include_snapshot: bool = False,
basename: str = "model",
- user_metadata: Optional[UserMetadata] = None,
- other_metadata: Optional[dict] = None,
+ user_metadata: _Optional[_UserMetadata] = None,
+ other_metadata: _Optional[dict] = None,
):
"""
Interface for exporting.
@@ -81,29 +88,29 @@ class Exportable(abc.ABC):
training = self.training
self.eval()
- with open(Path(outdir, f"{basename}{self.FILE_EXTENSION}"), "w") as fp:
- json.dump(model_dict, fp)
+ with open(_Path(outdir, f"{basename}{self.FILE_EXTENSION}"), "w") as fp:
+ _json.dump(model_dict, fp)
if include_snapshot:
x, y = self._export_input_output()
- x_path = Path(outdir, "test_inputs.npy")
- y_path = Path(outdir, "test_outputs.npy")
+ x_path = _Path(outdir, "test_inputs.npy")
+ y_path = _Path(outdir, "test_outputs.npy")
logger.debug(f"Saving snapshot input to {x_path}")
- np.save(x_path, x)
+ _np.save(x_path, x)
logger.debug(f"Saving snapshot output to {y_path}")
- np.save(y_path, y)
+ _np.save(y_path, y)
# And resume training state
self.train(training)
- @abc.abstractmethod
- def export_cpp_header(self, filename: Path):
+ @_abc.abstractmethod
+ def export_cpp_header(self, filename: _Path):
"""
Export a .h file to compile into the plugin with the weights written right out
as text
"""
pass
- def export_onnx(self, filename: Path):
+ def export_onnx(self, filename: _Path):
"""
Export model in format for ONNX Runtime
"""
@@ -112,7 +119,7 @@ class Exportable(abc.ABC):
f"{self.__class__.__name__}"
)
- def import_weights(self, weights: Sequence[float]):
+ def import_weights(self, weights: _Sequence[float]):
"""
Inverse of `._export_weights()
"""
@@ -121,7 +128,7 @@ class Exportable(abc.ABC):
"implemented yet."
)
- @abc.abstractmethod
+ @_abc.abstractmethod
def _export_config(self):
"""
Creates the JSON of the model's archtecture hyperparameters (number of layers,
@@ -131,8 +138,8 @@ class Exportable(abc.ABC):
"""
pass
- @abc.abstractmethod
- def _export_input_output(self) -> Tuple[np.ndarray, np.ndarray]:
+ @_abc.abstractmethod
+ def _export_input_output(self) -> _Tuple[_np.ndarray, _np.ndarray]:
"""
Create an input and corresponding output signal to verify its behavior.
@@ -141,8 +148,8 @@ class Exportable(abc.ABC):
"""
pass
- @abc.abstractmethod
- def _export_weights(self) -> np.ndarray:
+ @_abc.abstractmethod
+ def _export_weights(self) -> _np.ndarray:
"""
Flatten the weights out to a 1D array
"""
@@ -157,13 +164,13 @@ class Exportable(abc.ABC):
"weights": self._export_weights().tolist(),
}
- def _get_non_user_metadata(self) -> Dict[str, Union[str, int, float]]:
+ def _get_non_user_metadata(self) -> _Dict[str, _Union[str, int, float]]:
"""
Get any metadata that's non-user-provided (date, loudness, gain)
"""
- t = datetime.now()
+ t = _datetime.now()
return {
- "date": Date(
+ "date": _Date(
year=t.year,
month=t.month,
day=t.day,
diff --git a/nam/models/linear.py b/nam/models/linear.py
@@ -6,18 +6,18 @@
Linear model
"""
-import numpy as np
-import torch
-import torch.nn as nn
+import numpy as _np
+import torch as _torch
+import torch.nn as _nn
from .._version import __version__
-from .base import BaseNet
+from .base import BaseNet as _BaseNet
-class Linear(BaseNet):
+class Linear(_BaseNet):
def __init__(self, receptive_field: int, *args, bias: bool = False, **kwargs):
super().__init__(*args, **kwargs)
- self._net = nn.Conv1d(1, 1, receptive_field, bias=bias)
+ self._net = _nn.Conv1d(1, 1, receptive_field, bias=bias)
@property
def pad_start_default(self) -> bool:
@@ -34,7 +34,7 @@ class Linear(BaseNet):
def _bias(self) -> bool:
return self._net.bias is not None
- def _forward(self, x: torch.Tensor) -> torch.Tensor:
+ def _forward(self, x: _torch.Tensor) -> _torch.Tensor:
return self._net(x[:, None])[:, 0]
def _export_config(self):
@@ -43,9 +43,9 @@ class Linear(BaseNet):
"bias": self._bias,
}
- def _export_weights(self) -> np.ndarray:
+ def _export_weights(self) -> _np.ndarray:
params_list = [self._net.weight.flatten()]
if self._bias:
params_list.append(self._net.bias.flatten())
- params = torch.cat(params_list).detach().cpu().numpy()
+ params = _torch.cat(params_list).detach().cpu().numpy()
return params
diff --git a/nam/models/losses.py b/nam/models/losses.py
@@ -6,13 +6,13 @@
Loss functions
"""
-from typing import Optional
+from typing import Optional as _Optional
-import torch
-from auraloss.freq import MultiResolutionSTFTLoss
+import torch as _torch
+from auraloss.freq import MultiResolutionSTFTLoss as _MultiResolutionSTFTLoss
-def apply_pre_emphasis_filter(x: torch.Tensor, coef: float) -> torch.Tensor:
+def apply_pre_emphasis_filter(x: _torch.Tensor, coef: float) -> _torch.Tensor:
"""
Apply first-order pre-emphsis filter
@@ -24,7 +24,7 @@ def apply_pre_emphasis_filter(x: torch.Tensor, coef: float) -> torch.Tensor:
return x[..., 1:] - coef * x[..., :-1]
-def esr(preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
+def esr(preds: _torch.Tensor, targets: _torch.Tensor) -> _torch.Tensor:
"""
ESR of (a batch of) predictions & targets
@@ -42,18 +42,18 @@ def esr(preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
raise ValueError(
f"Expect 2D targets (batch_size, num_samples). Got {targets.shape}"
)
- return torch.mean(
- torch.mean(torch.square(preds - targets), dim=1)
- / torch.mean(torch.square(targets), dim=1)
+ return _torch.mean(
+ _torch.mean(_torch.square(preds - targets), dim=1)
+ / _torch.mean(_torch.square(targets), dim=1)
)
def multi_resolution_stft_loss(
- preds: torch.Tensor,
- targets: torch.Tensor,
- loss_func: Optional[MultiResolutionSTFTLoss] = None,
- device: Optional[torch.device] = None,
-) -> torch.Tensor:
+ preds: _torch.Tensor,
+ targets: _torch.Tensor,
+ loss_func: _Optional[_MultiResolutionSTFTLoss] = None,
+ device: _Optional[_torch.device] = None,
+) -> _torch.Tensor:
"""
Experimental Multi Resolution Short Time Fourier Transform Loss using auraloss implementation.
B: Batch size
@@ -66,13 +66,13 @@ def multi_resolution_stft_loss(
:param device: If provided, send the preds and targets to the provided device.
:return: ()
"""
- loss_func = MultiResolutionSTFTLoss() if loss_func is None else loss_func
+ loss_func = _MultiResolutionSTFTLoss() if loss_func is None else loss_func
if device is not None:
preds, targets = [z.to(device) for z in (preds, targets)]
return loss_func(preds, targets)
-def mse_fft(preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
+def mse_fft(preds: _torch.Tensor, targets: _torch.Tensor) -> _torch.Tensor:
"""
Fourier loss
@@ -80,7 +80,7 @@ def mse_fft(preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
:param targets: Same as preds
:return: ()
"""
- fp = torch.fft.fft(preds)
- ft = torch.fft.fft(targets)
+ fp = _torch.fft.fft(preds)
+ ft = _torch.fft.fft(targets)
e = fp - ft
- return torch.mean(torch.square(e.abs()))
+ return _torch.mean(_torch.square(e.abs()))
diff --git a/nam/models/metadata.py b/nam/models/metadata.py
@@ -6,14 +6,14 @@
Metadata about models
"""
-from enum import Enum
-from typing import Optional
+from enum import Enum as _Enum
+from typing import Optional as _Optional
-from pydantic import BaseModel
+from pydantic import BaseModel as _BaseModel
# Note: if you change this enum, you need to update the options in easy_colab.ipynb!
-class GearType(Enum):
+class GearType(_Enum):
AMP = "amp"
PEDAL = "pedal"
PEDAL_AMP = "pedal_amp"
@@ -24,7 +24,7 @@ class GearType(Enum):
# Note: if you change this enum, you need to update the options in easy_colab.ipynb!
-class ToneType(Enum):
+class ToneType(_Enum):
CLEAN = "clean"
OVERDRIVE = "overdrive"
CRUNCH = "crunch"
@@ -32,7 +32,7 @@ class ToneType(Enum):
FUZZ = "fuzz"
-class Date(BaseModel):
+class Date(_BaseModel):
year: int
month: int
day: int
@@ -41,7 +41,7 @@ class Date(BaseModel):
second: int
-class UserMetadata(BaseModel):
+class UserMetadata(_BaseModel):
"""
Metadata that users provide for a NAM model
@@ -57,11 +57,11 @@ class UserMetadata(BaseModel):
the model.
"""
- name: Optional[str] = None
- modeled_by: Optional[str] = None
- gear_type: Optional[GearType] = None
- gear_make: Optional[str] = None
- gear_model: Optional[str] = None
- tone_type: Optional[ToneType] = None
- input_level_dbu: Optional[float] = None
- output_level_dbu: Optional[float] = None
+ name: _Optional[str] = None
+ modeled_by: _Optional[str] = None
+ gear_type: _Optional[GearType] = None
+ gear_make: _Optional[str] = None
+ gear_model: _Optional[str] = None
+ tone_type: _Optional[ToneType] = None
+ input_level_dbu: _Optional[float] = None
+ output_level_dbu: _Optional[float] = None
diff --git a/nam/models/recurrent.py b/nam/models/recurrent.py
@@ -8,20 +8,20 @@ Recurrent models (LSTM)
TODO batch_first=False (I get it...)
"""
-import abc
-import json
-from pathlib import Path
-from tempfile import TemporaryDirectory
-from typing import Optional, Tuple
+import abc as _abc
+import json as _json
+from pathlib import Path as _Path
+from tempfile import TemporaryDirectory as _TemporaryDirectory
+from typing import Optional as Optional, Tuple as _Tuple
-import numpy as np
-import torch
-import torch.nn as nn
+import numpy as _np
+import torch as _torch
+import torch.nn as _nn
-from .base import BaseNet
+from .base import BaseNet as _BaseNet
-class _L(nn.LSTM):
+class _L(_nn.LSTM):
"""
Tweaks to PyTorch LSTM module
* Up the remembering
@@ -47,24 +47,24 @@ class _L(nn.LSTM):
# DH: Hidden state dimension
# [0]: hidden (L,DH)
# [1]: cell (L,DH)
-_LSTMHiddenType = torch.Tensor
-_LSTMCellType = torch.Tensor
-_LSTMHiddenCellType = Tuple[_LSTMHiddenType, _LSTMCellType]
+_LSTMHiddenType = _torch.Tensor
+_LSTMCellType = _torch.Tensor
+_LSTMHiddenCellType = _Tuple[_LSTMHiddenType, _LSTMCellType]
# TODO get this somewhere more core-ish
-class _ExportsWeights(abc.ABC):
- @abc.abstractmethod
- def export_weights(self) -> np.ndarray:
+class _ExportsWeights(_abc.ABC):
+ @_abc.abstractmethod
+ def export_weights(self) -> _np.ndarray:
"""
:return: a 1D array of weights
"""
pass
-class _Linear(nn.Linear, _ExportsWeights):
+class _Linear(_nn.Linear, _ExportsWeights):
def export_weights(self):
- return np.concatenate(
+ return _np.concatenate(
[
self.weight.data.detach().cpu().numpy().flatten(),
self.bias.data.detach().cpu().numpy().flatten(),
@@ -72,7 +72,7 @@ class _Linear(nn.Linear, _ExportsWeights):
)
-class LSTM(BaseNet):
+class LSTM(_BaseNet):
"""
ABC for recurrent architectures
"""
@@ -105,16 +105,16 @@ class LSTM(BaseNet):
self._head = self._init_head(hidden_size)
self._train_burn_in = train_burn_in
self._train_truncate = train_truncate
- self._initial_cell = nn.Parameter(
- torch.zeros((lstm_kwargs.get("num_layers", 1), hidden_size))
+ self._initial_cell = _nn.Parameter(
+ _torch.zeros((lstm_kwargs.get("num_layers", 1), hidden_size))
)
- self._initial_hidden = nn.Parameter(
- torch.zeros((lstm_kwargs.get("num_layers", 1), hidden_size))
+ self._initial_hidden = _nn.Parameter(
+ _torch.zeros((lstm_kwargs.get("num_layers", 1), hidden_size))
)
self._get_initial_state_burn_in = 48_000
@property
- def input_device(self) -> torch.device:
+ def input_device(self) -> _torch.device:
"""
What device does the input need to be on?
"""
@@ -129,12 +129,12 @@ class LSTM(BaseNet):
# I should simplify this...
return True
- def export_cpp_header(self, filename: Path):
- with TemporaryDirectory() as tmpdir:
- tmpdir = Path(tmpdir)
- LSTM.export(self, Path(tmpdir)) # Hacky...need to work w/ CatLSTM
- with open(Path(tmpdir, "model.nam"), "r") as fp:
- _c = json.load(fp)
+ def export_cpp_header(self, filename: _Path):
+ with _TemporaryDirectory() as tmpdir:
+ tmpdir = _Path(tmpdir)
+ LSTM.export(self, _Path(tmpdir)) # Hacky...need to work w/ CatLSTM
+ with open(_Path(tmpdir, "model.nam"), "r") as fp:
+ _c = _json.load(fp)
version = _c["version"]
config = _c["config"]
s_parametric = self._export_cpp_header_parametric(config.get("parametric"))
@@ -159,7 +159,7 @@ class LSTM(BaseNet):
)
)
- def _apply_head(self, features: torch.Tensor) -> torch.Tensor:
+ def _apply_head(self, features: _torch.Tensor) -> _torch.Tensor:
"""
:param features: (B,S,DH)
:return: (B,S)
@@ -167,8 +167,8 @@ class LSTM(BaseNet):
return self._head(features)[:, :, 0]
def _forward(
- self, x: torch.Tensor, initial_state: Optional[_LSTMHiddenCellType] = None
- ) -> torch.Tensor:
+ self, x: _torch.Tensor, initial_state: Optional[_LSTMHiddenCellType] = None
+ ) -> _torch.Tensor:
"""
:param x: (B,L) or (B,L,D)
:return: (B,L)
@@ -183,7 +183,7 @@ class LSTM(BaseNet):
x[:, i : i + BLOCK_SIZE, :], hidden_state
)
outputs.append(out)
- return torch.cat(outputs, dim=1), hidden_state # assert batch_first
+ return _torch.cat(outputs, dim=1), hidden_state # assert batch_first
last_hidden_state = (
self._initial_state(len(x)) if initial_state is None else initial_state
@@ -208,12 +208,12 @@ class LSTM(BaseNet):
x[:, i : i + self._train_truncate, :], last_hidden_state
)
output_features_list.append(last_output_features)
- output_features = torch.cat(output_features_list, dim=1)
+ output_features = _torch.cat(output_features_list, dim=1)
return self._apply_head(output_features)
def _export_cell_weights(
- self, i: int, hidden_state: torch.Tensor, cell_state: torch.Tensor
- ) -> np.ndarray:
+ self, i: int, hidden_state: _torch.Tensor, cell_state: _torch.Tensor
+ ) -> _np.ndarray:
"""
* weight matrix (xh -> ifco)
* bias vector
@@ -222,7 +222,7 @@ class LSTM(BaseNet):
"""
tensors = [
- torch.cat(
+ _torch.cat(
[
getattr(self._core, f"weight_ih_l{i}").data,
getattr(self._core, f"weight_hh_l{i}").data,
@@ -234,7 +234,7 @@ class LSTM(BaseNet):
hidden_state,
cell_state,
]
- return np.concatenate([z.detach().cpu().numpy().flatten() for z in tensors])
+ return _np.concatenate([z.detach().cpu().numpy().flatten() for z in tensors])
def _export_config(self):
return {
@@ -259,7 +259,7 @@ class LSTM(BaseNet):
* Head weights
* Head bias
"""
- return np.concatenate(
+ return _np.concatenate(
[
self._export_cell_weights(i, h, c)
for i, (h, c) in enumerate(zip(*self._get_initial_state()))
@@ -279,7 +279,7 @@ class LSTM(BaseNet):
:return: (L,DH), (L,DH)
"""
inputs = (
- torch.zeros((1, self._get_initial_state_burn_in, 1))
+ _torch.zeros((1, self._get_initial_state_burn_in, 1))
if inputs is None
else inputs
).to(self.input_device)
@@ -298,7 +298,7 @@ class LSTM(BaseNet):
(self._initial_hidden, self._initial_cell)
if n is None
else (
- torch.tile(self._initial_hidden[:, None], (1, n, 1)),
- torch.tile(self._initial_cell[:, None], (1, n, 1)),
+ _torch.tile(self._initial_hidden[:, None], (1, n, 1)),
+ _torch.tile(self._initial_cell[:, None], (1, n, 1)),
)
)
diff --git a/nam/models/wavenet.py b/nam/models/wavenet.py
@@ -7,34 +7,39 @@ WaveNet implementation
https://arxiv.org/abs/1609.03499
"""
-import json
-from copy import deepcopy
-from pathlib import Path
-from tempfile import TemporaryDirectory
-from typing import Dict, Optional, Sequence, Tuple
-
-import numpy as np
-import torch
-import torch.nn as nn
-
-from ._activations import get_activation
-from .base import BaseNet
-from ._names import ACTIVATION_NAME, CONV_NAME
-
-
-class Conv1d(nn.Conv1d):
- def export_weights(self) -> torch.Tensor:
+import json as _json
+from copy import deepcopy as _deepcopy
+from pathlib import Path as _Path
+from tempfile import TemporaryDirectory as _TemporaryDirectory
+from typing import (
+ Dict as _Dict,
+ Optional as _Optional,
+ Sequence as _Sequence,
+ Tuple as _Tuple,
+)
+
+import numpy as _np
+import torch as _torch
+import torch.nn as _nn
+
+from ._activations import get_activation as _get_activation
+from .base import BaseNet as _BaseNet
+from ._names import ACTIVATION_NAME as _ACTIVATION_NAME, CONV_NAME as _CONV_NAME
+
+
+class Conv1d(_nn.Conv1d):
+ def export_weights(self) -> _torch.Tensor:
tensors = []
if self.weight is not None:
tensors.append(self.weight.data.flatten())
if self.bias is not None:
tensors.append(self.bias.data.flatten())
if len(tensors) == 0:
- return torch.zeros((0,))
+ return _torch.zeros((0,))
else:
- return torch.cat(tensors)
+ return _torch.cat(tensors)
- def import_weights(self, weights: torch.Tensor, i: int) -> int:
+ def import_weights(self, weights: _torch.Tensor, i: int) -> int:
if self.weight is not None:
n = self.weight.numel()
self.weight.data = (
@@ -50,7 +55,7 @@ class Conv1d(nn.Conv1d):
return i
-class _Layer(nn.Module):
+class _Layer(_nn.Module):
def __init__(
self,
condition_size: int,
@@ -67,7 +72,7 @@ class _Layer(nn.Module):
# Custom init: favors direct input-output
# self._conv.weight.data.zero_()
self._input_mixer = Conv1d(condition_size, mid_channels, 1, bias=False)
- self._activation = get_activation(activation)
+ self._activation = _get_activation(activation)
self._activation_name = activation
self._1x1 = Conv1d(channels, channels, 1)
self._gated = gated
@@ -88,8 +93,8 @@ class _Layer(nn.Module):
def kernel_size(self) -> int:
return self._conv.kernel_size[0]
- def export_weights(self) -> torch.Tensor:
- return torch.cat(
+ def export_weights(self) -> _torch.Tensor:
+ return _torch.cat(
[
self.conv.export_weights(),
self._input_mixer.export_weights(),
@@ -98,8 +103,8 @@ class _Layer(nn.Module):
)
def forward(
- self, x: torch.Tensor, h: Optional[torch.Tensor], out_length: int
- ) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
+ self, x: _torch.Tensor, h: _Optional[_torch.Tensor], out_length: int
+ ) -> _Tuple[_Optional[_torch.Tensor], _torch.Tensor]:
"""
:param x: (B,C,L1) From last layer
:param h: (B,DX,L2) Conditioning. If first, ignored.
@@ -117,7 +122,7 @@ class _Layer(nn.Module):
if not self._gated
else (
self._activation(z1[:, : self._channels])
- * torch.sigmoid(z1[:, self._channels :])
+ * _torch.sigmoid(z1[:, self._channels :])
)
)
return (
@@ -125,7 +130,7 @@ class _Layer(nn.Module):
post_activation[:, :, -out_length:],
)
- def import_weights(self, weights: torch.Tensor, i: int) -> int:
+ def import_weights(self, weights: _torch.Tensor, i: int) -> int:
i = self.conv.import_weights(weights, i)
i = self._input_mixer.import_weights(weights, i)
return self._1x1.import_weights(weights, i)
@@ -135,7 +140,7 @@ class _Layer(nn.Module):
return self._1x1.in_channels
-class _Layers(nn.Module):
+class _Layers(_nn.Module):
"""
Takes in the input and condition (and maybe the head input so far); outputs the
layer output and head input.
@@ -152,14 +157,14 @@ class _Layers(nn.Module):
head_size,
channels: int,
kernel_size: int,
- dilations: Sequence[int],
+ dilations: _Sequence[int],
activation: str = "Tanh",
gated: bool = True,
head_bias: bool = True,
):
super().__init__()
self._rechannel = Conv1d(input_size, channels, 1, bias=False)
- self._layers = nn.ModuleList(
+ self._layers = _nn.ModuleList(
[
_Layer(
condition_size, channels, kernel_size, dilation, activation, gated
@@ -187,16 +192,16 @@ class _Layers(nn.Module):
return 1 + (self._kernel_size - 1) * sum(self._dilations)
def export_config(self):
- return deepcopy(self._config)
+ return _deepcopy(self._config)
- def export_weights(self) -> torch.Tensor:
- return torch.cat(
+ def export_weights(self) -> _torch.Tensor:
+ return _torch.cat(
[self._rechannel.export_weights()]
+ [layer.export_weights() for layer in self._layers]
+ [self._head_rechannel.export_weights()]
)
- def import_weights(self, weights: torch.Tensor, i: int) -> int:
+ def import_weights(self, weights: _torch.Tensor, i: int) -> int:
i = self._rechannel.import_weights(weights, i)
for layer in self._layers:
i = layer.import_weights(weights, i)
@@ -204,10 +209,10 @@ class _Layers(nn.Module):
def forward(
self,
- x: torch.Tensor,
- c: torch.Tensor,
- head_input: Optional[torch.Tensor] = None,
- ) -> Tuple[torch.Tensor, torch.Tensor]:
+ x: _torch.Tensor,
+ c: _torch.Tensor,
+ head_input: _Optional[_torch.Tensor] = None,
+ ) -> _Tuple[_torch.Tensor, _torch.Tensor]:
"""
:param x: (B,Dx,L) layer input
:param c: (B,Dc,L) condition
@@ -228,7 +233,7 @@ class _Layers(nn.Module):
return self._head_rechannel(head_input), x
@property
- def _dilations(self) -> Sequence[int]:
+ def _dilations(self) -> _Sequence[int]:
return self._config["dilations"]
@property
@@ -236,7 +241,7 @@ class _Layers(nn.Module):
return self._layers[0].kernel_size
-class _Head(nn.Module):
+class _Head(_nn.Module):
def __init__(
self,
in_channels: int,
@@ -248,14 +253,14 @@ class _Head(nn.Module):
super().__init__()
def block(cx, cy):
- net = nn.Sequential()
- net.add_module(ACTIVATION_NAME, get_activation(activation))
- net.add_module(CONV_NAME, Conv1d(cx, cy, 1))
+ net = _nn.Sequential()
+ net.add_module(_ACTIVATION_NAME, _get_activation(activation))
+ net.add_module(_CONV_NAME, Conv1d(cx, cy, 1))
return net
assert num_layers > 0
- layers = nn.Sequential()
+ layers = _nn.Sequential()
cin = in_channels
for i in range(num_layers):
layers.add_module(
@@ -273,30 +278,30 @@ class _Head(nn.Module):
}
def export_config(self):
- return deepcopy(self._config)
+ return _deepcopy(self._config)
- def export_weights(self) -> torch.Tensor:
- return torch.cat([layer[1].export_weights() for layer in self._layers])
+ def export_weights(self) -> _torch.Tensor:
+ return _torch.cat([layer[1].export_weights() for layer in self._layers])
def forward(self, *args, **kwargs):
return self._layers(*args, **kwargs)
- def import_weights(self, weights: torch.Tensor, i: int) -> int:
+ def import_weights(self, weights: _torch.Tensor, i: int) -> int:
for layer in self._layers:
i = layer[1].import_weights(weights, i)
return i
-class _WaveNet(nn.Module):
+class _WaveNet(_nn.Module):
def __init__(
self,
- layers_configs: Sequence[Dict],
- head_config: Optional[Dict] = None,
+ layers_configs: _Sequence[_Dict],
+ head_config: _Optional[_Dict] = None,
head_scale: float = 1.0,
):
super().__init__()
- self._layers = nn.ModuleList([_Layers(**lc) for lc in layers_configs])
+ self._layers = _nn.ModuleList([_Layers(**lc) for lc in layers_configs])
self._head = None if head_config is None else _Head(**head_config)
self._head_scale = head_scale
@@ -311,22 +316,22 @@ class _WaveNet(nn.Module):
"head_scale": self._head_scale,
}
- def export_weights(self) -> np.ndarray:
+ def export_weights(self) -> _np.ndarray:
"""
:return: 1D array
"""
- weights = torch.cat([layer.export_weights() for layer in self._layers])
+ weights = _torch.cat([layer.export_weights() for layer in self._layers])
if self._head is not None:
- weights = torch.cat([weights, self._head.export_weights()])
- weights = torch.cat([weights.cpu(), torch.Tensor([self._head_scale])])
+ weights = _torch.cat([weights, self._head.export_weights()])
+ weights = _torch.cat([weights.cpu(), _torch.Tensor([self._head_scale])])
return weights.detach().cpu().numpy()
- def import_weights(self, weights: torch.Tensor):
+ def import_weights(self, weights: _torch.Tensor):
i = 0
for layer in self._layers:
i = layer.import_weights(weights, i)
- def forward(self, x: torch.Tensor) -> torch.Tensor:
+ def forward(self, x: _torch.Tensor) -> _torch.Tensor:
"""
:param x: (B,Cx,L)
:return: (B,Cy,L-R)
@@ -338,8 +343,8 @@ class _WaveNet(nn.Module):
return head_input if self._head is None else self._head(head_input)
-class WaveNet(BaseNet):
- def __init__(self, *args, sample_rate: Optional[float] = None, **kwargs):
+class WaveNet(_BaseNet):
+ def __init__(self, *args, sample_rate: _Optional[float] = None, **kwargs):
super().__init__(sample_rate=sample_rate)
self._net = _WaveNet(*args, **kwargs)
@@ -351,12 +356,12 @@ class WaveNet(BaseNet):
def receptive_field(self) -> int:
return self._net.receptive_field
- def export_cpp_header(self, filename: Path):
- with TemporaryDirectory() as tmpdir:
- tmpdir = Path(tmpdir)
- WaveNet.export(self, Path(tmpdir)) # Hacky...need to work w/ CatWaveNet
- with open(Path(tmpdir, "model.nam"), "r") as fp:
- _c = json.load(fp)
+ def export_cpp_header(self, filename: _Path):
+ with _TemporaryDirectory() as tmpdir:
+ tmpdir = _Path(tmpdir)
+ WaveNet.export(self, _Path(tmpdir)) # Hacky...need to work w/ CatWaveNet
+ with open(_Path(tmpdir, "model.nam"), "r") as fp:
+ _c = _json.load(fp)
version = _c["version"]
config = _c["config"]
@@ -412,9 +417,9 @@ class WaveNet(BaseNet):
)
)
- def import_weights(self, weights: Sequence[float]):
- if not isinstance(weights, torch.Tensor):
- weights = torch.Tensor(weights)
+ def import_weights(self, weights: _Sequence[float]):
+ if not isinstance(weights, _torch.Tensor):
+ weights = _torch.Tensor(weights)
self._net.import_weights(weights)
def _export_config(self):
@@ -425,7 +430,7 @@ class WaveNet(BaseNet):
raise ValueError("Got non-None parametric config")
return ("nlohmann::json PARAMETRIC {};\n",)
- def _export_weights(self) -> np.ndarray:
+ def _export_weights(self) -> _np.ndarray:
return self._net.export_weights()
def _forward(self, x):
diff --git a/nam/train/_names.py b/nam/train/_names.py
@@ -2,15 +2,15 @@
# Created Date: Monday November 6th 2023
# Author: Steven Atkinson (steven@atkinson.mn)
-from typing import NamedTuple, Optional, Set
+from typing import NamedTuple as _NamedTuple, Optional as _Optional, Set as _Set
-from ._version import PROTEUS_VERSION, Version
+from ._version import PROTEUS_VERSION as _PROTEUS_VERSION, Version
-class VersionAndName(NamedTuple):
+class VersionAndName(_NamedTuple):
version: Version
name: str
- other_names: Optional[Set[str]]
+ other_names: _Optional[_Set[str]]
# From most- to the least-recently-released:
@@ -22,7 +22,7 @@ INPUT_BASENAMES = (
VersionAndName(Version(2, 0, 0), "v2_0_0.wav", None),
VersionAndName(Version(1, 1, 1), "v1_1_1.wav", None),
VersionAndName(Version(1, 0, 0), "v1.wav", None),
- VersionAndName(PROTEUS_VERSION, "Proteus_Capture.wav", None),
+ VersionAndName(_PROTEUS_VERSION, "Proteus_Capture.wav", None),
# ==================================================================================
)
diff --git a/nam/train/colab.py b/nam/train/colab.py
@@ -6,14 +6,18 @@
Hide the mess in Colab to make things look pretty for users.
"""
-from pathlib import Path
-from typing import Optional, Tuple
+from pathlib import Path as _Path
+from typing import Optional as _Optional, Tuple as _Tuple
-from ..models.metadata import UserMetadata
-from ._names import INPUT_BASENAMES, LATEST_VERSION, Version
-from ._version import PROTEUS_VERSION, Version
-from .core import TrainOutput, train
-from .metadata import TRAINING_KEY
+from ..models.metadata import UserMetadata as _UserMetadata
+from ._names import (
+ INPUT_BASENAMES as _INPUT_BASENAMES,
+ LATEST_VERSION as _LATEST_VERSION,
+ Version as _Version,
+)
+from ._version import PROTEUS_VERSION as _PROTEUS_VERSION, Version as _Version
+from .core import TrainOutput as _TrainOutput, train as _train
+from .metadata import TRAINING_KEY as _TRAINING_KEY
_BUGGY_INPUT_BASENAMES = {
# 1.1.0 has the spikes at the wrong spots.
@@ -23,41 +27,41 @@ _OUTPUT_BASENAME = "output.wav"
_TRAIN_PATH = "."
-def _check_for_files() -> Tuple[Version, str]:
+def _check_for_files() -> _Tuple[_Version, str]:
# TODO use hash logic as in GUI trainer!
print("Checking that we have all of the required audio files...")
for name in _BUGGY_INPUT_BASENAMES:
- if Path(name).exists():
+ if _Path(name).exists():
raise RuntimeError(
- f"Detected input signal {name} that has known bugs. Please download the latest input signal, {LATEST_VERSION[1]}"
+ f"Detected input signal {name} that has known bugs. Please download the latest input signal, {_LATEST_VERSION[1]}"
)
- for input_version, input_basename, other_names in INPUT_BASENAMES:
- if Path(input_basename).exists():
- if input_version == PROTEUS_VERSION:
+ for input_version, input_basename, other_names in _INPUT_BASENAMES:
+ if _Path(input_basename).exists():
+ if input_version == _PROTEUS_VERSION:
print(f"Using Proteus input file...")
- elif input_version != LATEST_VERSION.version:
+ elif input_version != _LATEST_VERSION.version:
print(
f"WARNING: Using out-of-date input file {input_basename}. "
"Recommend downloading and using the latest version, "
- f"{LATEST_VERSION.name}."
+ f"{_LATEST_VERSION.name}."
)
break
if other_names is not None:
for other_name in other_names:
- if Path(other_name).exists():
+ if _Path(other_name).exists():
raise RuntimeError(
f"Found out-of-date input file {other_name}. Rename it to {input_basename} and re-run."
)
else:
raise FileNotFoundError(
- f"Didn't find NAM's input audio file. Please upload {LATEST_VERSION.name}"
+ f"Didn't find NAM's input audio file. Please upload {_LATEST_VERSION.name}"
)
# We found it
- if not Path(_OUTPUT_BASENAME).exists():
+ if not _Path(_OUTPUT_BASENAME).exists():
raise FileNotFoundError(
f"Didn't find your reamped output audio file. Please upload {_OUTPUT_BASENAME}."
)
- if input_version != PROTEUS_VERSION:
+ if input_version != _PROTEUS_VERSION:
print(f"Found {input_basename}, version {input_version}")
else:
print(f"Found Proteus input {input_basename}.")
@@ -66,7 +70,7 @@ def _check_for_files() -> Tuple[Version, str]:
def _get_valid_export_directory():
def get_path(version):
- return Path("exported_models", f"version_{version}")
+ return _Path("exported_models", f"version_{version}")
version = 0
while get_path(version).exists():
@@ -76,13 +80,13 @@ def _get_valid_export_directory():
def run(
epochs: int = 100,
- delay: Optional[int] = None,
+ delay: _Optional[int] = None,
model_type: str = "WaveNet",
architecture: str = "standard",
lr: float = 0.004,
lr_decay: float = 0.007,
- seed: Optional[int] = 0,
- user_metadata: Optional[UserMetadata] = None,
+ seed: _Optional[int] = 0,
+ user_metadata: _Optional[_UserMetadata] = None,
ignore_checks: bool = False,
fit_mrstft: bool = True,
):
@@ -101,7 +105,7 @@ def run(
input_version, input_basename = _check_for_files()
- train_output: TrainOutput = train(
+ train_output: _TrainOutput = _train(
input_basename,
_OUTPUT_BASENAME,
_TRAIN_PATH,
@@ -129,6 +133,6 @@ def run(
model.net.export(
model_export_outdir,
user_metadata=user_metadata,
- other_metadata={TRAINING_KEY: training_metadata.model_dump()},
+ other_metadata={_TRAINING_KEY: training_metadata.model_dump()},
)
print(f"Model exported to {model_export_outdir}. Enjoy!")
diff --git a/nam/train/core.py b/nam/train/core.py
@@ -8,31 +8,46 @@ The core of the "simplified trainer"
Used by the GUI and Colab trainers.
"""
-import hashlib
-import tkinter as tk
-from copy import deepcopy
-from enum import Enum
-from functools import partial
-from pathlib import Path
-from time import time
-from typing import Dict, NamedTuple, Optional, Sequence, Tuple, Union
-
-import matplotlib.pyplot as plt
-import numpy as np
-import pytorch_lightning as pl
-import torch
-from pydantic import BaseModel
-from pytorch_lightning.utilities.warnings import PossibleUserWarning
-from torch.utils.data import DataLoader
-
-from ..data import DataError, Split, init_dataset, wav_to_np, wav_to_tensor
-from ..models.exportable import Exportable
-from ..models.losses import esr
-from ..models.metadata import UserMetadata
-from ..util import filter_warnings
-from ._version import PROTEUS_VERSION, Version
-from .lightning_module import LightningModule
-from . import metadata
+import hashlib as _hashlib
+import tkinter as _tk
+from copy import deepcopy as _deepcopy
+from enum import Enum as _Enum
+from functools import partial as _partial
+from pathlib import Path as _Path
+from time import time as _time
+from typing import (
+ Dict as _Dict,
+ NamedTuple as _NamedTuple,
+ Optional as _Optional,
+ Sequence as _Sequence,
+ Tuple as _Tuple,
+ Union as _Union,
+)
+
+import matplotlib.pyplot as _plt
+import numpy as _np
+import pytorch_lightning as _pl
+import torch as _torch
+from pydantic import BaseModel as _BaseModel
+from pytorch_lightning.utilities.warnings import (
+ PossibleUserWarning as _PossibleUserWarning,
+)
+from torch.utils.data import DataLoader as _DataLoader
+
+from ..data import (
+ DataError as _DataError,
+ Split as _Split,
+ init_dataset as _init_dataset,
+ wav_to_np as _wav_to_np,
+ wav_to_tensor as _wav_to_tensor,
+)
+from ..models.exportable import Exportable as _Exportable
+from ..models.losses import esr as _ESR
+from ..models.metadata import UserMetadata as _UserMetadata
+from ..util import filter_warnings as _filter_warnings
+from ._version import PROTEUS_VERSION as _PROTEUS_VERSION, Version as _Version
+from .lightning_module import LightningModule as _LightningModule
+from . import metadata as _metadata
# Training using the simplified trainers in NAM is done at 48k.
STANDARD_SAMPLE_RATE = 48_000.0
@@ -40,7 +55,7 @@ STANDARD_SAMPLE_RATE = 48_000.0
_NY_DEFAULT = 8192
-class Architecture(Enum):
+class Architecture(_Enum):
STANDARD = "standard"
LITE = "lite"
FEATHER = "feather"
@@ -51,17 +66,17 @@ class _InputValidationError(ValueError):
pass
-def _detect_input_version(input_path) -> Tuple[Version, bool]:
+def _detect_input_version(input_path) -> _Tuple[_Version, bool]:
"""
Check to see if the input matches any of the known inputs
:return: version, strong match
"""
- def detect_strong(input_path) -> Optional[Version]:
+ def detect_strong(input_path) -> _Optional[_Version]:
def assign_hash(path):
# Use this to create hashes for new files
- md5 = hashlib.md5()
+ md5 = _hashlib.md5()
buffer_size = 65536
with open(path, "rb") as f:
while True:
@@ -76,11 +91,11 @@ def _detect_input_version(input_path) -> Tuple[Version, bool]:
print(f"Strong hash: {file_hash}")
version = {
- "4d54a958861bf720ec4637f43d44a7ef": Version(1, 0, 0),
- "7c3b6119c74465f79d96c761a0e27370": Version(1, 1, 1),
- "ede3b9d82135ce10c7ace3bb27469422": Version(2, 0, 0),
- "36cd1af62985c2fac3e654333e36431e": Version(3, 0, 0),
- "80e224bd5622fd6153ff1fd9f34cb3bd": PROTEUS_VERSION,
+ "4d54a958861bf720ec4637f43d44a7ef": _Version(1, 0, 0),
+ "7c3b6119c74465f79d96c761a0e27370": _Version(1, 1, 1),
+ "ede3b9d82135ce10c7ace3bb27469422": _Version(2, 0, 0),
+ "36cd1af62985c2fac3e654333e36431e": _Version(3, 0, 0),
+ "80e224bd5622fd6153ff1fd9f34cb3bd": _PROTEUS_VERSION,
}.get(file_hash)
if version is None:
print(
@@ -89,17 +104,17 @@ def _detect_input_version(input_path) -> Tuple[Version, bool]:
)
return version
- def detect_weak(input_path) -> Optional[Version]:
+ def detect_weak(input_path) -> _Optional[_Version]:
def assign_hash(path):
- Hash = Optional[str]
- Hashes = Tuple[Hash, Hash]
+ Hash = _Optional[str]
+ Hashes = _Tuple[Hash, Hash]
- def _hash(x: np.ndarray) -> str:
- return hashlib.md5(x).hexdigest()
+ def _hash(x: _np.ndarray) -> str:
+ return _hashlib.md5(x).hexdigest()
def assign_hashes_v1(path) -> Hashes:
# Use this to create recognized hashes for new files
- x, info = wav_to_np(path, info=True)
+ x, info = _wav_to_np(path, info=True)
rate = info.rate
if rate != _V1_DATA_INFO.rate:
return None, None
@@ -116,7 +131,7 @@ def _detect_input_version(input_path) -> Tuple[Version, bool]:
def assign_hashes_v2(path) -> Hashes:
# Use this to create recognized hashes for new files
- x, info = wav_to_np(path, info=True)
+ x, info = _wav_to_np(path, info=True)
rate = info.rate
if rate != _V2_DATA_INFO.rate:
return None, None
@@ -133,7 +148,7 @@ def _detect_input_version(input_path) -> Tuple[Version, bool]:
def assign_hashes_v3(path) -> Hashes:
# Use this to create recognized hashes for new files
- x, info = wav_to_np(path, info=True)
+ x, info = _wav_to_np(path, info=True)
rate = info.rate
if rate != _V3_DATA_INFO.rate:
return None, None
@@ -147,7 +162,7 @@ def _detect_input_version(input_path) -> Tuple[Version, bool]:
def assign_hash_v4(path) -> Hash:
# Use this to create recognized hashes for new files
- x, info = wav_to_np(path, info=True)
+ x, info = _wav_to_np(path, info=True)
rate = info.rate
if rate != _V4_DATA_INFO.rate:
return None
@@ -195,7 +210,7 @@ def _detect_input_version(input_path) -> Tuple[Version, bool]:
(
"dadb5d62f6c3973a59bf01439799809b",
"8458126969a3f9d8e19a53554eb1fd52",
- ): Version(3, 0, 0)
+ ): _Version(3, 0, 0)
}.get((start_hash_v3, end_hash_v3))
if version is not None:
return version
@@ -203,7 +218,7 @@ def _detect_input_version(input_path) -> Tuple[Version, bool]:
(
"1c4d94fbcb47e4d820bef611c1d4ae65",
"28694e7bf9ab3f8ae6ef86e9545d4663",
- ): Version(2, 0, 0)
+ ): _Version(2, 0, 0)
}.get((start_hash_v2, end_hash_v2))
if version is not None:
return version
@@ -211,17 +226,17 @@ def _detect_input_version(input_path) -> Tuple[Version, bool]:
(
"bb4e140c9299bae67560d280917eb52b",
"9b2468fcb6e9460a399fc5f64389d353",
- ): Version(
+ ): _Version(
1, 0, 0
), # FIXME!
(
"9f20c6b5f7fef68dd88307625a573a14",
"8458126969a3f9d8e19a53554eb1fd52",
- ): Version(1, 1, 1),
+ ): _Version(1, 1, 1),
}.get((start_hash_v1, end_hash_v1))
if version is not None:
return version
- version = {"46151c8030798081acc00a725325a07d": PROTEUS_VERSION}.get(hash_v4)
+ version = {"46151c8030798081acc00a725325a07d": _PROTEUS_VERSION}.get(hash_v4)
return version
version = detect_strong(input_path)
@@ -239,20 +254,20 @@ def _detect_input_version(input_path) -> Tuple[Version, bool]:
return version, strong_match
-class _DataInfo(BaseModel):
+class _DataInfo(_BaseModel):
"""
:param major_version: Data major version
"""
major_version: int
- rate: Optional[float]
+ rate: _Optional[float]
t_blips: int
first_blips_start: int
t_validate: int
train_start: int
validation_start: int
- noise_interval: Tuple[int, int]
- blip_locations: Sequence[Sequence[int]]
+ noise_interval: _Tuple[int, int]
+ blip_locations: _Sequence[_Sequence[int]]
_V1_DATA_INFO = _DataInfo(
@@ -336,7 +351,7 @@ _DELAY_CALIBRATION_REL_THRESHOLD = 0.001
_DELAY_CALIBRATION_SAFETY_FACTOR = 1 # Might be able to make this zero...
-def _warn_lookaheads(indices: Sequence[int]) -> str:
+def _warn_lookaheads(indices: _Sequence[int]) -> str:
return (
f"WARNING: delays from some blips ({','.join([str(i) for i in indices])}) are "
"at the minimum value possible. This usually means that something is "
@@ -350,7 +365,7 @@ def _calibrate_latency_v_all(
abs_threshold=_DELAY_CALIBRATION_ABS_THRESHOLD,
rel_threshold=_DELAY_CALIBRATION_REL_THRESHOLD,
safety_factor=_DELAY_CALIBRATION_SAFETY_FACTOR,
-) -> metadata.LatencyCalibration:
+) -> _metadata.LatencyCalibration:
"""
Calibrate the delay in teh input-output pair based on blips.
This only uses the blips in the first set of blip locations!
@@ -359,8 +374,8 @@ def _calibrate_latency_v_all(
"""
def report_any_latency_warnings(
- delays: Sequence[int],
- ) -> metadata.LatencyCalibrationWarnings:
+ delays: _Sequence[int],
+ ) -> _metadata.LatencyCalibrationWarnings:
# Warnings associated with any single delay:
# "Lookahead warning": if the delay is equal to the lookahead, then it's
@@ -375,7 +390,7 @@ def _calibrate_latency_v_all(
# If they're _really_ different, then something might be wrong.
max_disagreement_threshold = 20
max_disagreement_too_high = (
- np.max(delays) - np.min(delays) >= max_disagreement_threshold
+ _np.max(delays) - _np.min(delays) >= max_disagreement_threshold
)
if max_disagreement_too_high:
print(
@@ -384,7 +399,7 @@ def _calibrate_latency_v_all(
"badly, then you might need to provide the latency manually."
)
- return metadata.LatencyCalibrationWarnings(
+ return _metadata.LatencyCalibrationWarnings(
matches_lookahead=matches_lookahead,
disagreement_too_high=max_disagreement_too_high,
)
@@ -393,8 +408,8 @@ def _calibrate_latency_v_all(
lookback = 10_000
# Calibrate the level for the trigger:
y = y[data_info.first_blips_start : data_info.first_blips_start + data_info.t_blips]
- background_level = np.max(
- np.abs(
+ background_level = _np.max(
+ _np.abs(
y[
data_info.noise_interval[0]
- data_info.first_blips_start : data_info.noise_interval[1]
@@ -414,8 +429,8 @@ def _calibrate_latency_v_all(
start_looking = i_rel - lookahead
stop_looking = i_rel + lookback
y_scans.append(y[start_looking:stop_looking])
- y_scan_average = np.mean(np.stack(y_scans), axis=0)
- triggered = np.where(np.abs(y_scan_average) > trigger_threshold)[0]
+ y_scan_average = _np.mean(_np.stack(y_scans), axis=0)
+ triggered = _np.where(_np.abs(y_scan_average) > trigger_threshold)[0]
if len(triggered) == 0:
msg = (
"No response activated the trigger in response to input spikes. "
@@ -423,24 +438,24 @@ def _calibrate_latency_v_all(
)
print(msg)
print("SHARE THIS PLOT IF YOU ASK FOR HELP")
- plt.figure()
- plt.plot(
- np.arange(-lookahead, lookback),
+ _plt.figure()
+ _plt.plot(
+ _np.arange(-lookahead, lookback),
y_scan_average,
color="C0",
label="Signal average",
)
for y_scan in y_scans:
- plt.plot(np.arange(-lookahead, lookback), y_scan, color="C0", alpha=0.2)
- plt.axvline(x=0, color="C1", linestyle="--", label="Trigger")
- plt.axhline(y=-trigger_threshold, color="k", linestyle="--", label="Threshold")
- plt.axhline(y=trigger_threshold, color="k", linestyle="--")
- plt.xlim((-lookahead, lookback))
- plt.xlabel("Samples")
- plt.ylabel("Response")
- plt.legend()
- plt.title("SHARE THIS PLOT IF YOU ASK FOR HELP")
- plt.show()
+ _plt.plot(_np.arange(-lookahead, lookback), y_scan, color="C0", alpha=0.2)
+ _plt.axvline(x=0, color="C1", linestyle="--", label="Trigger")
+ _plt.axhline(y=-trigger_threshold, color="k", linestyle="--", label="Threshold")
+ _plt.axhline(y=trigger_threshold, color="k", linestyle="--")
+ _plt.xlim((-lookahead, lookback))
+ _plt.xlabel("Samples")
+ _plt.ylabel("Response")
+ _plt.legend()
+ _plt.title("SHARE THIS PLOT IF YOU ASK FOR HELP")
+ _plt.show()
raise RuntimeError(msg)
else:
j = triggered[0]
@@ -454,7 +469,7 @@ def _calibrate_latency_v_all(
f"After aplying safety factor of {safety_factor}, the final delay is "
f"{delay_post_safety_factor}"
)
- return metadata.LatencyCalibration(
+ return _metadata.LatencyCalibration(
algorithm_version=1,
delays=[delay],
safety_factor=safety_factor,
@@ -463,72 +478,72 @@ def _calibrate_latency_v_all(
)
-_calibrate_latency_v1 = partial(_calibrate_latency_v_all, _V1_DATA_INFO)
-_calibrate_latency_v2 = partial(_calibrate_latency_v_all, _V2_DATA_INFO)
-_calibrate_latency_v3 = partial(_calibrate_latency_v_all, _V3_DATA_INFO)
-_calibrate_latency_v4 = partial(_calibrate_latency_v_all, _V4_DATA_INFO)
+_calibrate_latency_v1 = _partial(_calibrate_latency_v_all, _V1_DATA_INFO)
+_calibrate_latency_v2 = _partial(_calibrate_latency_v_all, _V2_DATA_INFO)
+_calibrate_latency_v3 = _partial(_calibrate_latency_v_all, _V3_DATA_INFO)
+_calibrate_latency_v4 = _partial(_calibrate_latency_v_all, _V4_DATA_INFO)
def _plot_latency_v_all(
data_info: _DataInfo, latency: int, input_path: str, output_path: str, _nofail=True
):
print("Plotting the latency for manual inspection...")
- x = wav_to_np(input_path)[
+ x = _wav_to_np(input_path)[
data_info.first_blips_start : data_info.first_blips_start + data_info.t_blips
]
- y = wav_to_np(output_path)[
+ y = _wav_to_np(output_path)[
data_info.first_blips_start : data_info.first_blips_start + data_info.t_blips
]
# Only get the blips we really want.
- i = np.where(np.abs(x) > 0.5 * np.abs(x).max())[0]
+ i = _np.where(_np.abs(x) > 0.5 * _np.abs(x).max())[0]
if len(i) == 0:
print("Failed to find the spike in the input file.")
print(
"Plotting the input and output; there should be spikes at around the "
"marked locations."
)
- t = np.arange(
+ t = _np.arange(
data_info.first_blips_start, data_info.first_blips_start + data_info.t_blips
)
expected_spikes = data_info.blip_locations[0] # For v1 specifically
- fig, axs = plt.subplots(len((x, y)), 1)
+ fig, axs = _plt.subplots(len((x, y)), 1)
for ax, curve in zip(axs, (x, y)):
ax.plot(t, curve)
[ax.axvline(x=es, color="C1", linestyle="--") for es in expected_spikes]
- plt.show()
+ _plt.show()
if _nofail:
raise RuntimeError("Failed to plot delay")
else:
- plt.figure()
+ _plt.figure()
di = 20
# V1's got not a spike but a longer plateau; take the front of it.
if data_info.major_version == 1:
i = [i[0]]
for e, ii in enumerate(i, 1):
- plt.plot(
- np.arange(-di, di),
+ _plt.plot(
+ _np.arange(-di, di),
y[ii - di + latency : ii + di + latency],
".-",
label=f"Output {e}",
)
- plt.axvline(x=0, linestyle="--", color="k")
- plt.legend()
- plt.show() # This doesn't freeze the notebook
+ _plt.axvline(x=0, linestyle="--", color="k")
+ _plt.legend()
+ _plt.show() # This doesn't freeze the notebook
-_plot_latency_v1 = partial(_plot_latency_v_all, _V1_DATA_INFO)
-_plot_latency_v2 = partial(_plot_latency_v_all, _V2_DATA_INFO)
-_plot_latency_v3 = partial(_plot_latency_v_all, _V3_DATA_INFO)
-_plot_latency_v4 = partial(_plot_latency_v_all, _V4_DATA_INFO)
+_plot_latency_v1 = _partial(_plot_latency_v_all, _V1_DATA_INFO)
+_plot_latency_v2 = _partial(_plot_latency_v_all, _V2_DATA_INFO)
+_plot_latency_v3 = _partial(_plot_latency_v_all, _V3_DATA_INFO)
+_plot_latency_v4 = _partial(_plot_latency_v_all, _V4_DATA_INFO)
def _analyze_latency(
- user_latency: Optional[int],
- input_version: Version,
+ user_latency: _Optional[int],
+ input_version: _Version,
input_path: str,
output_path: str,
silent: bool = False,
-) -> metadata.Latency:
+) -> _metadata.Latency:
"""
:param is_proteus: Forget the version; d
"""
@@ -546,14 +561,14 @@ def _analyze_latency(
)
if user_latency is not None:
print(f"Delay is specified as {user_latency}")
- calibration_output = calibrate(wav_to_np(output_path))
+ calibration_output = calibrate(_wav_to_np(output_path))
latency = (
user_latency if user_latency is not None else calibration_output.recommended
)
if not silent:
plot(latency, input_path, output_path)
- return metadata.Latency(manual=user_latency, calibration=calibration_output)
+ return _metadata.Latency(manual=user_latency, calibration=calibration_output)
def get_lstm_config(architecture):
@@ -585,8 +600,8 @@ def get_lstm_config(architecture):
}[architecture]
-def _check_v1(*args, **kwargs) -> metadata.DataChecks:
- return metadata.DataChecks(version=1, passed=True)
+def _check_v1(*args, **kwargs) -> _metadata.DataChecks:
+ return _metadata.DataChecks(version=1, passed=True)
def _esr_validation_replicate_msg(threshold: float) -> str:
@@ -601,16 +616,18 @@ def _esr_validation_replicate_msg(threshold: float) -> str:
)
-def _check_v2(input_path, output_path, delay: int, silent: bool) -> metadata.DataChecks:
- with torch.no_grad():
+def _check_v2(
+ input_path, output_path, delay: int, silent: bool
+) -> _metadata.DataChecks:
+ with _torch.no_grad():
print("V2 checks...")
rate = _V2_DATA_INFO.rate
- y = wav_to_tensor(output_path, rate=rate)
+ y = _wav_to_tensor(output_path, rate=rate)
t_blips = _V2_DATA_INFO.t_blips
t_validate = _V2_DATA_INFO.t_validate
y_val_1 = y[-(t_blips + 2 * t_validate) : -(t_blips + t_validate)]
y_val_2 = y[-(t_blips + t_validate) : -t_blips]
- esr_replicate = esr(y_val_1, y_val_2).item()
+ esr_replicate = _ESR(y_val_1, y_val_2).item()
print(f"Replicate ESR is {esr_replicate:.8f}.")
esr_replicate_threshold = 0.01
if esr_replicate > esr_replicate_threshold:
@@ -630,19 +647,19 @@ def _check_v2(input_path, output_path, delay: int, silent: bool) -> metadata.Dat
i0, i1, j0, j1 = [i + delay for i in (i0, i1, j0, j1)]
start = -10
end = 1000
- blips = torch.stack(
+ blips = _torch.stack(
[
- torch.stack([y[i0 + start : i0 + end], y[i1 + start : i1 + end]]),
- torch.stack([y[j0 + start : j0 + end], y[j1 + start : j1 + end]]),
+ _torch.stack([y[i0 + start : i0 + end], y[i1 + start : i1 + end]]),
+ _torch.stack([y[j0 + start : j0 + end], y[j1 + start : j1 + end]]),
]
)
return blips
blips = get_blips(y)
- esr_0 = esr(blips[0][0], blips[0][1]).item() # Within start
- esr_1 = esr(blips[1][0], blips[1][1]).item() # Within end
- esr_cross_0 = esr(blips[0][0], blips[1][0]).item() # 1st repeat, start vs end
- esr_cross_1 = esr(blips[0][1], blips[1][1]).item() # 2nd repeat, start vs end
+ esr_0 = _ESR(blips[0][0], blips[0][1]).item() # Within start
+ esr_1 = _ESR(blips[1][0], blips[1][1]).item() # Within end
+ esr_cross_0 = _ESR(blips[0][0], blips[1][0]).item() # 1st repeat, start vs end
+ esr_cross_1 = _ESR(blips[0][1], blips[1][1]).item() # 2nd repeat, start vs end
print(" ESRs:")
print(f" Start : {esr_0}")
@@ -655,22 +672,22 @@ def _check_v2(input_path, output_path, delay: int, silent: bool) -> metadata.Dat
def plot_esr_blip_error(
show_plot: bool,
msg: str,
- arrays: Sequence[Sequence[float]],
- labels: Sequence[str],
+ arrays: _Sequence[_Sequence[float]],
+ labels: _Sequence[str],
):
"""
:param silent: Whether to make and show a plot about it
"""
if show_plot:
- plt.figure()
- [plt.plot(array, label=label) for array, label in zip(arrays, labels)]
- plt.xlabel("Sample")
- plt.ylabel("Output")
- plt.legend()
- plt.grid()
+ _plt.figure()
+ [_plt.plot(array, label=label) for array, label in zip(arrays, labels)]
+ _plt.xlabel("Sample")
+ _plt.ylabel("Output")
+ _plt.legend()
+ _plt.grid()
print(msg)
if show_plot:
- plt.show()
+ _plt.show()
print(
"This is known to be a very sensitive test, so training will continue. "
"If the model doesn't look good, then this may be why!"
@@ -693,7 +710,7 @@ def _check_v2(input_path, output_path, delay: int, silent: bool) -> metadata.Dat
blip_pair,
("Replicate 1", "Replicate 2"),
)
- return metadata.DataChecks(version=2, passed=False)
+ return _metadata.DataChecks(version=2, passed=False)
# Check blips between start & end of train signal
for e, blip_pair, replicate in zip(
(esr_cross_0, esr_cross_1), blips.permute(1, 0, 2), (1, 2)
@@ -707,46 +724,46 @@ def _check_v2(input_path, output_path, delay: int, silent: bool) -> metadata.Dat
blip_pair,
(f"Start, replicate {replicate}", f"End, replicate {replicate}"),
)
- return metadata.DataChecks(version=2, passed=False)
- return metadata.DataChecks(version=2, passed=True)
+ return _metadata.DataChecks(version=2, passed=False)
+ return _metadata.DataChecks(version=2, passed=True)
def _check_v3(
input_path, output_path, silent: bool, *args, **kwargs
-) -> metadata.DataChecks:
- with torch.no_grad():
+) -> _metadata.DataChecks:
+ with _torch.no_grad():
print("V3 checks...")
rate = _V3_DATA_INFO.rate
- y = wav_to_tensor(output_path, rate=rate)
- n = len(wav_to_tensor(input_path)) # to End-crop output
+ y = _wav_to_tensor(output_path, rate=rate)
+ n = len(_wav_to_tensor(input_path)) # to End-crop output
y_val_1 = y[: _V3_DATA_INFO.t_validate]
y_val_2 = y[n - _V3_DATA_INFO.t_validate : n]
- esr_replicate = esr(y_val_1, y_val_2).item()
+ esr_replicate = _ESR(y_val_1, y_val_2).item()
print(f"Replicate ESR is {esr_replicate:.8f}.")
esr_replicate_threshold = 0.01
if esr_replicate > esr_replicate_threshold:
print(_esr_validation_replicate_msg(esr_replicate_threshold))
if not silent:
- plt.figure()
- t = np.arange(len(y_val_1)) / rate
- plt.plot(t, y_val_1, label="Validation 1")
- plt.plot(t, y_val_2, label="Validation 2")
- plt.xlabel("Time (sec)")
- plt.legend()
- plt.title("V3 check: Validation replicate FAILURE")
- plt.show()
- return metadata.DataChecks(version=3, passed=False)
- return metadata.DataChecks(version=3, passed=True)
+ _plt.figure()
+ t = _np.arange(len(y_val_1)) / rate
+ _plt.plot(t, y_val_1, label="Validation 1")
+ _plt.plot(t, y_val_2, label="Validation 2")
+ _plt.xlabel("Time (sec)")
+ _plt.legend()
+ _plt.title("V3 check: Validation replicate FAILURE")
+ _plt.show()
+ return _metadata.DataChecks(version=3, passed=False)
+ return _metadata.DataChecks(version=3, passed=True)
def _check_v4(
input_path, output_path, silent: bool, *args, **kwargs
-) -> metadata.DataChecks:
+) -> _metadata.DataChecks:
# Things we can't check:
# Latency compensation agreement
# Data replicability
print("Using Proteus audio file. Standard data checks aren't possible!")
- signal, info = wav_to_np(output_path, info=True)
+ signal, info = _wav_to_np(output_path, info=True)
passed = True
if info.rate != _V4_DATA_INFO.rate:
print(
@@ -761,12 +778,12 @@ def _check_v4(
"File doesn't meet the minimum length requirements for latency compensation and validation signal!"
)
passed = False
- return metadata.DataChecks(version=4, passed=passed)
+ return _metadata.DataChecks(version=4, passed=passed)
def _check_data(
- input_path: str, output_path: str, input_version: Version, delay: int, silent: bool
-) -> Optional[metadata.DataChecks]:
+ input_path: str, output_path: str, input_version: _Version, delay: int, silent: bool
+) -> _Optional[_metadata.DataChecks]:
"""
Ensure that everything should go smoothly
@@ -912,7 +929,11 @@ _CAB_MRSTFT_PRE_EMPH_COEF = 0.85
def _get_data_config(
- input_version: Version, input_path: Path, output_path: Path, ny: int, latency: int
+ input_version: _Version,
+ input_path: _Path,
+ output_path: _Path,
+ ny: int,
+ latency: int,
) -> dict:
def get_split_kwargs(data_info: _DataInfo):
if data_info.major_version == 1:
@@ -976,7 +997,7 @@ def _get_data_config(
def _get_configs(
- input_version: Version,
+ input_version: _Version,
input_path: str,
output_path: str,
latency: int,
@@ -1031,9 +1052,9 @@ def _get_configs(
model_config["loss"]["pre_emph_mrstft_weight"] = _CAB_MRSTFT_PRE_EMPH_WEIGHT
model_config["loss"]["pre_emph_mrstft_coef"] = _CAB_MRSTFT_PRE_EMPH_COEF
- if torch.cuda.is_available():
+ if _torch.cuda.is_available():
device_config = {"accelerator": "gpu", "devices": 1}
- elif torch.backends.mps.is_available():
+ elif _torch.backends.mps.is_available():
device_config = {"accelerator": "mps", "devices": 1}
else:
print("WARNING: No GPU was found. Training will be very slow!")
@@ -1053,45 +1074,49 @@ def _get_configs(
def _get_dataloaders(
- data_config: Dict, learning_config: Dict, model: LightningModule
-) -> Tuple[DataLoader, DataLoader]:
- data_config, learning_config = [deepcopy(c) for c in (data_config, learning_config)]
+ data_config: _Dict, learning_config: _Dict, model: _LightningModule
+) -> _Tuple[_DataLoader, _DataLoader]:
+ data_config, learning_config = [
+ _deepcopy(c) for c in (data_config, learning_config)
+ ]
data_config["common"]["nx"] = model.net.receptive_field
- dataset_train = init_dataset(data_config, Split.TRAIN)
- dataset_validation = init_dataset(data_config, Split.VALIDATION)
- train_dataloader = DataLoader(dataset_train, **learning_config["train_dataloader"])
- val_dataloader = DataLoader(dataset_validation, **learning_config["val_dataloader"])
+ dataset_train = _init_dataset(data_config, _Split.TRAIN)
+ dataset_validation = _init_dataset(data_config, _Split.VALIDATION)
+ train_dataloader = _DataLoader(dataset_train, **learning_config["train_dataloader"])
+ val_dataloader = _DataLoader(
+ dataset_validation, **learning_config["val_dataloader"]
+ )
return train_dataloader, val_dataloader
-def _esr(pred: torch.Tensor, target: torch.Tensor) -> float:
+def _esr(pred: _torch.Tensor, target: _torch.Tensor) -> float:
return (
- torch.mean(torch.square(pred - target)).item()
- / torch.mean(torch.square(target)).item()
+ _torch.mean(_torch.square(pred - target)).item()
+ / _torch.mean(_torch.square(target)).item()
)
def _plot(
model,
ds,
- window_start: Optional[int] = None,
- window_end: Optional[int] = None,
- filepath: Optional[str] = None,
+ window_start: _Optional[int] = None,
+ window_end: _Optional[int] = None,
+ filepath: _Optional[str] = None,
silent: bool = False,
) -> float:
"""
:return: The ESR
"""
print("Plotting a comparison of your model with the target output...")
- with torch.no_grad():
+ with _torch.no_grad():
tx = len(ds.x) / 48_000
print(f"Run (t={tx:.2f} sec)")
- t0 = time()
+ t0 = _time()
output = model(ds.x).flatten().cpu().numpy()
- t1 = time()
+ t1 = _time()
print(f"Took {t1 - t0:.2f} sec ({tx / (t1 - t0):.2f}x)")
- esr = _esr(torch.Tensor(output), ds.y)
+ esr = _esr(_torch.Tensor(output), ds.y)
# Trying my best to put numbers to it...
if esr < 0.01:
esr_comment = "Great!"
@@ -1106,15 +1131,15 @@ def _plot(
print(f"Error-signal ratio = {esr:.4g}")
print(esr_comment)
- plt.figure(figsize=(16, 5))
- plt.plot(output[window_start:window_end], label="Prediction")
- plt.plot(ds.y[window_start:window_end], linestyle="--", label="Target")
- plt.title(f"ESR={esr:.4g}")
- plt.legend()
+ _plt.figure(figsize=(16, 5))
+ _plt.plot(output[window_start:window_end], label="Prediction")
+ _plt.plot(ds.y[window_start:window_end], linestyle="--", label="Target")
+ _plt.title(f"ESR={esr:.4g}")
+ _plt.legend()
if filepath is not None:
- plt.savefig(filepath + ".png")
+ _plt.savefig(filepath + ".png")
if not silent:
- plt.show()
+ _plt.show()
return esr
@@ -1139,14 +1164,14 @@ def _print_nasty_checks_warning():
def _nasty_checks_modal():
msg = "You are ignoring the checks!\nYour model might turn out bad!"
- root = tk.Tk()
+ root = _tk.Tk()
root.withdraw() # hide the root window
- modal = tk.Toplevel(root)
+ modal = _tk.Toplevel(root)
modal.geometry("300x100")
modal.title("Warning!")
- label = tk.Label(modal, text=msg)
+ label = _tk.Label(modal, text=msg)
label.pack(pady=10)
- ok_button = tk.Button(
+ ok_button = _tk.Button(
modal,
text="I can only blame myself!",
command=lambda: [modal.destroy(), root.quit()],
@@ -1156,7 +1181,7 @@ def _nasty_checks_modal():
modal.mainloop()
-class _ValidationStopping(pl.callbacks.EarlyStopping):
+class _ValidationStopping(_pl.callbacks.EarlyStopping):
"""
Callback to indicate to stop training if the validation metric is good enough,
without the other conditions that EarlyStopping usually forces like patience.
@@ -1164,10 +1189,10 @@ class _ValidationStopping(pl.callbacks.EarlyStopping):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
- self.patience = np.inf
+ self.patience = _np.inf
-class _ModelCheckpoint(pl.callbacks.model_checkpoint.ModelCheckpoint):
+class _ModelCheckpoint(_pl.callbacks.model_checkpoint.ModelCheckpoint):
"""
Extension to model checkpoint to save a .nam file as well as the .ckpt file.
"""
@@ -1175,9 +1200,9 @@ class _ModelCheckpoint(pl.callbacks.model_checkpoint.ModelCheckpoint):
def __init__(
self,
*args,
- user_metadata: Optional[UserMetadata] = None,
- settings_metadata: Optional[metadata.Settings] = None,
- data_metadata: Optional[metadata.Data] = None,
+ user_metadata: _Optional[_UserMetadata] = None,
+ settings_metadata: _Optional[_metadata.Settings] = None,
+ data_metadata: _Optional[_metadata.Data] = None,
**kwargs,
):
super().__init__(*args, **kwargs)
@@ -1185,10 +1210,10 @@ class _ModelCheckpoint(pl.callbacks.model_checkpoint.ModelCheckpoint):
self._settings_metadata = settings_metadata
self._data_metadata = data_metadata
- _NAM_FILE_EXTENSION = Exportable.FILE_EXTENSION
+ _NAM_FILE_EXTENSION = _Exportable.FILE_EXTENSION
@classmethod
- def _get_nam_filepath(cls, filepath: str) -> Path:
+ def _get_nam_filepath(cls, filepath: str) -> _Path:
"""
Given a .ckpt filepath, figure out a .nam for it.
"""
@@ -1197,18 +1222,18 @@ class _ModelCheckpoint(pl.callbacks.model_checkpoint.ModelCheckpoint):
f"Checkpoint filepath {filepath} doesn't end in expected extension "
f"{cls.FILE_EXTENSION}"
)
- return Path(filepath[: -len(cls.FILE_EXTENSION)] + cls._NAM_FILE_EXTENSION)
+ return _Path(filepath[: -len(cls.FILE_EXTENSION)] + cls._NAM_FILE_EXTENSION)
@property
def _include_other_metadata(self) -> bool:
return self._settings_metadata is not None and self._data_metadata is not None
- def _save_checkpoint(self, trainer: pl.Trainer, filepath: str):
+ def _save_checkpoint(self, trainer: _pl.Trainer, filepath: str):
# Save the .ckpt:
super()._save_checkpoint(trainer, filepath)
# Save the .nam:
nam_filepath = self._get_nam_filepath(filepath)
- pl_model: LightningModule = trainer.model
+ pl_model: _LightningModule = trainer.model
nam_model = pl_model.net
outdir = nam_filepath.parent
# HACK: Assume the extension
@@ -1217,7 +1242,7 @@ class _ModelCheckpoint(pl.callbacks.model_checkpoint.ModelCheckpoint):
None
if not self._include_other_metadata
else {
- metadata.TRAINING_KEY: metadata.TrainingMetadata(
+ _metadata.TRAINING_KEY: _metadata.TrainingMetadata(
settings=self._settings_metadata,
data=self._data_metadata,
validation_esr=None, # TODO how to get this?
@@ -1231,7 +1256,7 @@ class _ModelCheckpoint(pl.callbacks.model_checkpoint.ModelCheckpoint):
other_metadata=other_metadata,
)
- def _remove_checkpoint(self, trainer: pl.Trainer, filepath: str) -> None:
+ def _remove_checkpoint(self, trainer: _pl.Trainer, filepath: str) -> None:
super()._remove_checkpoint(trainer, filepath)
nam_path = self._get_nam_filepath(filepath)
if nam_path.exists():
@@ -1239,10 +1264,10 @@ class _ModelCheckpoint(pl.callbacks.model_checkpoint.ModelCheckpoint):
def get_callbacks(
- threshold_esr: Optional[float],
- user_metadata: Optional[UserMetadata] = None,
- settings_metadata: Optional[metadata.Settings] = None,
- data_metadata: Optional[metadata.Data] = None,
+ threshold_esr: _Optional[float],
+ user_metadata: _Optional[_UserMetadata] = None,
+ settings_metadata: _Optional[_metadata.Settings] = None,
+ data_metadata: _Optional[_metadata.Data] = None,
):
callbacks = [
_ModelCheckpoint(
@@ -1269,18 +1294,18 @@ def get_callbacks(
return callbacks
-class TrainOutput(NamedTuple):
+class TrainOutput(_NamedTuple):
"""
:param model: The trained model
:param simpliifed_trianer_metadata: The metadata summarizing training with the
simplified trainer.
"""
- model: Optional[LightningModule]
- metadata: metadata.TrainingMetadata
+ model: _Optional[_LightningModule]
+ metadata: _metadata.TrainingMetadata
-def _get_final_latency(latency_analysis: metadata.Latency) -> int:
+def _get_final_latency(latency_analysis: _metadata.Latency) -> int:
if latency_analysis.manual is not None:
latency = latency_analysis.manual
print(f"Latency provided as {latency_analysis.manual}; override calibration")
@@ -1294,27 +1319,27 @@ def train(
input_path: str,
output_path: str,
train_path: str,
- input_version: Optional[Version] = None, # Deprecate?
+ input_version: _Optional[_Version] = None, # Deprecate?
epochs=100,
- delay: Optional[int] = None,
- latency: Optional[int] = None,
+ delay: _Optional[int] = None,
+ latency: _Optional[int] = None,
model_type: str = "WaveNet",
- architecture: Union[Architecture, str] = Architecture.STANDARD,
+ architecture: _Union[Architecture, str] = Architecture.STANDARD,
batch_size: int = 16,
ny: int = _NY_DEFAULT,
lr=0.004,
lr_decay=0.007,
- seed: Optional[int] = 0,
+ seed: _Optional[int] = 0,
save_plot: bool = False,
silent: bool = False,
modelname: str = "model",
ignore_checks: bool = False,
local: bool = False,
fit_mrstft: bool = True,
- threshold_esr: Optional[bool] = None,
- user_metadata: Optional[UserMetadata] = None,
- fast_dev_run: Union[bool, int] = False,
-) -> Optional[TrainOutput]:
+ threshold_esr: _Optional[bool] = None,
+ user_metadata: _Optional[_UserMetadata] = None,
+ fast_dev_run: _Union[bool, int] = False,
+) -> _Optional[TrainOutput]:
"""
:param lr_decay: =1-gamma for Exponential learning rate decay.
:param threshold_esr: Stop training if ESR is better than this. Ignore if `None`.
@@ -1322,8 +1347,8 @@ def train(
"""
def parse_user_latency(
- delay: Optional[int], latency: Optional[int]
- ) -> Optional[int]:
+ delay: _Optional[int], latency: _Optional[int]
+ ) -> _Optional[int]:
if delay is not None:
if latency is not None:
raise ValueError("Both delay and latency are provided; use latency!")
@@ -1332,7 +1357,7 @@ def train(
return latency
if seed is not None:
- torch.manual_seed(seed)
+ _torch.manual_seed(seed)
# HACK: We need to check the sample rates and lengths of the audio here or else
# It will look like a bad self-ESR (Issue 473)
@@ -1384,9 +1409,9 @@ def train(
print("Exiting core training...")
return TrainOutput(
model=None,
- metadata=metadata.TrainingMetadata(
- settings=metadata.Settings(ignore_checks=ignore_checks),
- data=metadata.Data(
+ metadata=_metadata.TrainingMetadata(
+ settings=_metadata.Settings(ignore_checks=ignore_checks),
+ data=_metadata.Data(
latency=latency_analysis, checks=data_check_output
),
validation_esr=None,
@@ -1417,7 +1442,7 @@ def train(
# * Model is re-instantiated after training anyways.
# (Hacky) solution: set sample rate in model from dataloader after second
# instantiation from final checkpoint.
- model = LightningModule.init_from_config(model_config)
+ model = _LightningModule.init_from_config(model_config)
train_dataloader, val_dataloader = _get_dataloaders(
data_config, learning_config, model
)
@@ -1431,10 +1456,10 @@ def train(
model.net.sample_rate = sample_rate
# Put together the metadata that's needed in checkpoints:
- settings_metadata = metadata.Settings(ignore_checks=ignore_checks)
- data_metadata = metadata.Data(latency=latency_analysis, checks=data_check_output)
+ settings_metadata = _metadata.Settings(ignore_checks=ignore_checks)
+ data_metadata = _metadata.Data(latency=latency_analysis, checks=data_check_output)
- trainer = pl.Trainer(
+ trainer = _pl.Trainer(
callbacks=get_callbacks(
threshold_esr,
user_metadata=user_metadata,
@@ -1446,21 +1471,21 @@ def train(
**learning_config["trainer"],
)
# Suppress the PossibleUserWarning about num_workers (Issue 345)
- with filter_warnings("ignore", category=PossibleUserWarning):
+ with _filter_warnings("ignore", category=_PossibleUserWarning):
trainer.fit(model, train_dataloader, val_dataloader)
# Go to best checkpoint
best_checkpoint = trainer.checkpoint_callback.best_model_path
if best_checkpoint != "":
- model = LightningModule.load_from_checkpoint(
+ model = _LightningModule.load_from_checkpoint(
trainer.checkpoint_callback.best_model_path,
- **LightningModule.parse_config(model_config),
+ **_LightningModule.parse_config(model_config),
)
model.cpu()
model.eval()
model.net.sample_rate = sample_rate # Hack, part 2
- def window_kwargs(version: Version):
+ def window_kwargs(version: _Version):
if version.major == 1:
return dict(
window_start=100_000, # Start of the plotting window, in samples
@@ -1487,7 +1512,7 @@ def train(
)
return TrainOutput(
model=model,
- metadata=metadata.TrainingMetadata(
+ metadata=_metadata.TrainingMetadata(
settings=settings_metadata,
data=data_metadata,
validation_esr=validation_esr,
@@ -1495,7 +1520,7 @@ def train(
)
-class DataInputValidation(BaseModel):
+class DataInputValidation(_BaseModel):
passed: bool
@@ -1512,49 +1537,49 @@ def validate_input(input_path) -> DataInputValidation:
return DataInputValidation(passed=False)
-class _PyTorchDataSplitValidation(BaseModel):
+class _PyTorchDataSplitValidation(_BaseModel):
"""
:param msg: On exception, catch and assign. Otherwise None
"""
passed: bool
- msg: Optional[str]
+ msg: _Optional[str]
-class _PyTorchDataValidation(BaseModel):
+class _PyTorchDataValidation(_BaseModel):
passed: bool
train: _PyTorchDataSplitValidation # cf Split.TRAIN
validation: _PyTorchDataSplitValidation # Split.VALIDATION
-class _SampleRateValidation(BaseModel):
+class _SampleRateValidation(_BaseModel):
passed: bool
input: int
output: int
-class _LengthValidation(BaseModel):
+class _LengthValidation(_BaseModel):
passed: bool
delta_seconds: float
-class DataValidationOutput(BaseModel):
+class DataValidationOutput(_BaseModel):
passed: bool
passed_critical: bool
sample_rate: _SampleRateValidation
length: _LengthValidation
input_version: str
- latency: metadata.Latency
- checks: metadata.DataChecks
+ latency: _metadata.Latency
+ checks: _metadata.DataChecks
pytorch: _PyTorchDataValidation
def _check_audio_sample_rates(
- input_path: Path,
- output_path: Path,
+ input_path: _Path,
+ output_path: _Path,
) -> _SampleRateValidation:
- _, x_info = wav_to_np(input_path, info=True)
- _, y_info = wav_to_np(output_path, info=True)
+ _, x_info = _wav_to_np(input_path, info=True)
+ _, y_info = _wav_to_np(output_path, info=True)
return _SampleRateValidation(
passed=x_info.rate == y_info.rate,
@@ -1564,10 +1589,10 @@ def _check_audio_sample_rates(
def _check_audio_lengths(
- input_path: Path,
- output_path: Path,
- max_under_seconds: Optional[float] = 0.0,
- max_over_seconds: Optional[float] = 1.0,
+ input_path: _Path,
+ output_path: _Path,
+ max_under_seconds: _Optional[float] = 0.0,
+ max_over_seconds: _Optional[float] = 1.0,
) -> _LengthValidation:
"""
Check that the input and output have the right lengths compared to each
@@ -1584,8 +1609,8 @@ def _check_audio_lengths(
value of 1.0 means that the output can't be more than a second longer
than the input.
"""
- x, x_info = wav_to_np(input_path, info=True)
- y, y_info = wav_to_np(output_path, info=True)
+ x, x_info = _wav_to_np(input_path, info=True)
+ y, y_info = _wav_to_np(output_path, info=True)
length_input = len(x) / x_info.rate
length_output = len(y) / y_info.rate
@@ -1601,9 +1626,9 @@ def _check_audio_lengths(
def validate_data(
- input_path: Path,
- output_path: Path,
- user_latency: Optional[int],
+ input_path: _Path,
+ output_path: _Path,
+ user_latency: _Optional[int],
num_output_samples_per_datum: int = _NY_DEFAULT,
):
"""
@@ -1660,14 +1685,14 @@ def validate_data(
# be unlikely to make a difference. Still, would be nice to fix.
data_config["common"]["nx"] = 4096
- pytorch_data_split_validation_dict: Dict[str, _PyTorchDataSplitValidation] = {}
- for split in Split:
+ pytorch_data_split_validation_dict: _Dict[str, _PyTorchDataSplitValidation] = {}
+ for split in _Split:
try:
- init_dataset(data_config, split)
+ _init_dataset(data_config, split)
pytorch_data_split_validation_dict[split.value] = (
_PyTorchDataSplitValidation(passed=True, msg=None)
)
- except DataError as e:
+ except _DataError as e:
pytorch_data_split_validation_dict[split.value] = (
_PyTorchDataSplitValidation(passed=False, msg=str(e))
)
diff --git a/nam/train/full.py b/nam/train/full.py
@@ -2,31 +2,37 @@
# Created Date: Tuesday March 26th 2024
# Author: Enrico Schifano (eraz1997@live.it)
-import json
-from pathlib import Path
-from time import time
-from typing import Optional, Union
-from warnings import warn
-
-import matplotlib.pyplot as plt
-import numpy as np
-import pytorch_lightning as pl
-from pytorch_lightning.utilities.warnings import PossibleUserWarning
-import torch
-from torch.utils.data import DataLoader
-
-from nam.data import ConcatDataset, Split, init_dataset
-from nam.train.lightning_module import LightningModule
-from nam.util import filter_warnings
-
-torch.manual_seed(0)
-
-
-def _rms(x: Union[np.ndarray, torch.Tensor]) -> float:
- if isinstance(x, np.ndarray):
- return np.sqrt(np.mean(np.square(x)))
- elif isinstance(x, torch.Tensor):
- return torch.sqrt(torch.mean(torch.square(x))).item()
+import json as _json
+from pathlib import Path as _Path
+from time import time as _time
+from typing import Optional as _Optional, Union as _Union
+from warnings import warn as _warn
+
+import matplotlib.pyplot as _plt
+import numpy as _np
+import pytorch_lightning as _pl
+from pytorch_lightning.utilities.warnings import (
+ PossibleUserWarning as _PossibleUserWarning,
+)
+import torch as _torch
+from torch.utils.data import DataLoader as _DataLoader
+
+from nam.data import (
+ ConcatDataset as _ConcatDataset,
+ Split as _Split,
+ init_dataset as _init_dataset,
+)
+from nam.train.lightning_module import LightningModule as _LightningModule
+from nam.util import filter_warnings as _filter_warnings
+
+_torch.manual_seed(0)
+
+
+def _rms(x: _Union[_np.ndarray, _torch.Tensor]) -> float:
+ if isinstance(x, _np.ndarray):
+ return _np.sqrt(_np.mean(_np.square(x)))
+ elif isinstance(x, _torch.Tensor):
+ return _torch.sqrt(_torch.mean(_torch.square(x))).item()
else:
raise TypeError(type(x))
@@ -36,18 +42,18 @@ def _plot(
ds,
savefig=None,
show=True,
- window_start: Optional[int] = None,
- window_end: Optional[int] = None,
+ window_start: _Optional[int] = None,
+ window_end: _Optional[int] = None,
):
- if isinstance(ds, ConcatDataset):
+ if isinstance(ds, _ConcatDataset):
def extend_savefig(i, savefig):
if savefig is None:
return None
- savefig = Path(savefig)
+ savefig = _Path(savefig)
extension = savefig.name.split(".")[-1]
stem = savefig.name[: -len(extension) - 1]
- return Path(savefig.parent, f"{stem}_{i}.{extension}")
+ return _Path(savefig.parent, f"{stem}_{i}.{extension}")
for i, ds_i in enumerate(ds.datasets):
_plot(
@@ -59,29 +65,29 @@ def _plot(
window_end=window_end,
)
return
- with torch.no_grad():
+ with _torch.no_grad():
tx = len(ds.x) / 48_000
print(f"Run (t={tx:.2f})")
- t0 = time()
+ t0 = _time()
output = model(ds.x).flatten().cpu().numpy()
- t1 = time()
+ t1 = _time()
try:
rt = f"{tx / (t1 - t0):.2f}"
except ZeroDivisionError as e:
rt = "???"
print(f"Took {t1 - t0:.2f} ({rt}x)")
- plt.figure(figsize=(16, 5))
- plt.plot(output[window_start:window_end], label="Prediction")
- plt.plot(ds.y[window_start:window_end], linestyle="--", label="Target")
- nrmse = _rms(torch.Tensor(output) - ds.y) / _rms(ds.y)
+ _plt.figure(figsize=(16, 5))
+ _plt.plot(output[window_start:window_end], label="Prediction")
+ _plt.plot(ds.y[window_start:window_end], linestyle="--", label="Target")
+ nrmse = _rms(_torch.Tensor(output) - ds.y) / _rms(ds.y)
esr = nrmse**2
- plt.title(f"ESR={esr:.3f}")
- plt.legend()
+ _plt.title(f"ESR={esr:.3f}")
+ _plt.legend()
if savefig is not None:
- plt.savefig(savefig)
+ _plt.savefig(savefig)
if show:
- plt.show()
+ _plt.show()
def _create_callbacks(learning_config):
@@ -102,7 +108,7 @@ def _create_callbacks(learning_config):
)
}
- checkpoint_best = pl.callbacks.model_checkpoint.ModelCheckpoint(
+ checkpoint_best = _pl.callbacks.model_checkpoint.ModelCheckpoint(
filename="{epoch:04d}_{step}_{ESR:.3e}_{MSE:.3e}",
save_top_k=3,
monitor="val_loss",
@@ -111,14 +117,14 @@ def _create_callbacks(learning_config):
# return [checkpoint_best, checkpoint_last]
# The last epoch that was finished.
- checkpoint_epoch = pl.callbacks.model_checkpoint.ModelCheckpoint(
+ checkpoint_epoch = _pl.callbacks.model_checkpoint.ModelCheckpoint(
filename="checkpoint_epoch_{epoch:04d}", every_n_epochs=1
)
if not validate_inside_epoch:
return [checkpoint_best, checkpoint_epoch]
else:
# The last validation pass, whether at the end of an epoch or not
- checkpoint_last = pl.callbacks.model_checkpoint.ModelCheckpoint(
+ checkpoint_last = _pl.callbacks.model_checkpoint.ModelCheckpoint(
filename="checkpoint_last_{epoch:04d}_{step}", **kwargs
)
return [checkpoint_best, checkpoint_last, checkpoint_epoch]
@@ -128,7 +134,7 @@ def main(
data_config,
model_config,
learning_config,
- outdir: Path,
+ outdir: _Path,
no_show: bool = False,
make_plots=True,
):
@@ -140,35 +146,37 @@ def main(
("model", model_config),
("learning", learning_config),
):
- with open(Path(outdir, f"config_{basename}.json"), "w") as fp:
- json.dump(config, fp, indent=4)
+ with open(_Path(outdir, f"config_{basename}.json"), "w") as fp:
+ _json.dump(config, fp, indent=4)
- model = LightningModule.init_from_config(model_config)
+ model = _LightningModule.init_from_config(model_config)
# Add receptive field to data config:
data_config["common"] = data_config.get("common", {})
if "nx" in data_config["common"]:
- warn(
+ _warn(
f"Overriding data nx={data_config['common']['nx']} with model requried {model.net.receptive_field}"
)
data_config["common"]["nx"] = model.net.receptive_field
- dataset_train = init_dataset(data_config, Split.TRAIN)
- dataset_validation = init_dataset(data_config, Split.VALIDATION)
+ dataset_train = _init_dataset(data_config, _Split.TRAIN)
+ dataset_validation = _init_dataset(data_config, _Split.VALIDATION)
if dataset_train.sample_rate != dataset_validation.sample_rate:
raise RuntimeError(
"Train and validation data loaders have different data set sample rates: "
f"{dataset_train.sample_rate}, {dataset_validation.sample_rate}"
)
model.net.sample_rate = dataset_train.sample_rate
- train_dataloader = DataLoader(dataset_train, **learning_config["train_dataloader"])
- val_dataloader = DataLoader(dataset_validation, **learning_config["val_dataloader"])
+ train_dataloader = _DataLoader(dataset_train, **learning_config["train_dataloader"])
+ val_dataloader = _DataLoader(
+ dataset_validation, **learning_config["val_dataloader"]
+ )
- trainer = pl.Trainer(
+ trainer = _pl.Trainer(
callbacks=_create_callbacks(learning_config),
default_root_dir=outdir,
**learning_config["trainer"],
)
- with filter_warnings("ignore", category=PossibleUserWarning):
+ with _filter_warnings("ignore", category=_PossibleUserWarning):
trainer.fit(
model,
train_dataloader,
@@ -178,9 +186,9 @@ def main(
# Go to best checkpoint
best_checkpoint = trainer.checkpoint_callback.best_model_path
if best_checkpoint != "":
- model = LightningModule.load_from_checkpoint(
+ model = _LightningModule.load_from_checkpoint(
trainer.checkpoint_callback.best_model_path,
- **LightningModule.parse_config(model_config),
+ **_LightningModule.parse_config(model_config),
)
model.cpu()
model.eval()
@@ -188,7 +196,7 @@ def main(
_plot(
model,
dataset_validation,
- savefig=Path(outdir, "comparison.png"),
+ savefig=_Path(outdir, "comparison.png"),
window_start=100_000,
window_end=110_000,
show=False,
diff --git a/nam/train/gui/__init__.py b/nam/train/gui/__init__.py
@@ -10,16 +10,16 @@ Usage:
>>> run()
"""
-import abc
-import re
-import requests
-import tkinter as tk
-import subprocess
-import sys
-import webbrowser
-from dataclasses import dataclass
-from enum import Enum
-from functools import partial
+import abc as _abc
+import re as _re
+import requests as _requests
+import tkinter as _tk
+import subprocess as _subprocess
+import sys as _sys
+import webbrowser as _webbrowser
+from dataclasses import dataclass as _dataclass
+from enum import Enum as _Enum
+from functools import partial as _partial
try: # Not supported in Colab
from idlelib.tooltip import Hovertip
@@ -34,26 +34,43 @@ except ModuleNotFoundError:
pass
-from pathlib import Path
-from tkinter import filedialog
-from typing import Any, Callable, Dict, NamedTuple, Optional, Sequence
+from pathlib import Path as _Path
+from tkinter import filedialog as _filedialog
+from typing import (
+ Any as _Any,
+ Callable as _Callable,
+ Dict as _Dict,
+ NamedTuple as _NamedTuple,
+ Optional as _Optional,
+ Sequence as _Sequence,
+)
try: # 3rd-party and 1st-party imports
- import torch
+ import torch as _torch
from nam import __version__
- from nam.data import Split
- from nam.train import core
- from nam.train.gui._resources import settings
- from nam.models.metadata import GearType, UserMetadata, ToneType
+ from nam.data import Split as _Split
+ from nam.train import core as _core
+ from nam.train.gui._resources import settings as _settings
+ from nam.models.metadata import (
+ GearType as _GearType,
+ UserMetadata as _UserMetadata,
+ ToneType as _ToneType,
+ )
# Ok private access here--this is technically allowed access
- from nam.train import metadata
- from nam.train._names import INPUT_BASENAMES, LATEST_VERSION
- from nam.train._version import Version, get_current_version
+ from nam.train import metadata as _metadata
+ from nam.train._names import (
+ INPUT_BASENAMES as _INPUT_BASENAMES,
+ LATEST_VERSION as _LATEST_VERSION,
+ )
+ from nam.train._version import (
+ Version as _Version,
+ get_current_version as _get_current_version,
+ )
_install_is_valid = True
- _HAVE_ACCELERATOR = torch.cuda.is_available() or torch.backends.mps.is_available()
+ _HAVE_ACCELERATOR = _torch.cuda.is_available() or _torch.backends.mps.is_available()
except ImportError:
_install_is_valid = False
_HAVE_ACCELERATOR = False
@@ -81,13 +98,13 @@ _METADATA_RIGHT_WIDTH = 60
def _is_mac() -> bool:
- return sys.platform == "darwin"
+ return _sys.platform == "darwin"
_SYSTEM_TEXT_COLOR = "systemTextColor" if _is_mac() else "black"
-@dataclass
+@_dataclass
class AdvancedOptions(object):
"""
:param architecture: Which architecture to use.
@@ -99,14 +116,14 @@ class AdvancedOptions(object):
stop.
"""
- architecture: core.Architecture
+ architecture: _core.Architecture
num_epochs: int
- latency: Optional[int]
+ latency: _Optional[int]
ignore_checks: bool
- threshold_esr: Optional[float]
+ threshold_esr: _Optional[float]
-class _PathType(Enum):
+class _PathType(_Enum):
FILE = "file"
DIRECTORY = "directory"
MULTIFILE = "multifile"
@@ -119,42 +136,42 @@ class _PathButton(object):
def __init__(
self,
- frame: tk.Frame,
+ frame: _tk.Frame,
button_text: str,
info_str: str,
path_type: _PathType,
- path_key: settings.PathKey,
- hooks: Optional[Sequence[Callable[[], None]]] = None,
+ path_key: _settings.PathKey,
+ hooks: _Optional[_Sequence[_Callable[[], None]]] = None,
color_when_not_set: str = "#EF0000", # Darker red
color_when_set: str = _SYSTEM_TEXT_COLOR,
- default: Optional[Path] = None,
+ default: _Optional[_Path] = None,
):
"""
:param hooks: Callables run at the end of setting the value.
"""
self._button_text = button_text
self._info_str = info_str
- self._path: Optional[Path] = default
+ self._path: _Optional[_Path] = default
self._path_type = path_type
self._path_key = path_key
self._frame = frame
self._widgets = {}
- self._widgets["button"] = tk.Button(
+ self._widgets["button"] = _tk.Button(
self._frame,
text=button_text,
width=_BUTTON_WIDTH,
height=_BUTTON_HEIGHT,
command=self._set_val,
)
- self._widgets["button"].pack(side=tk.LEFT)
- self._widgets["label"] = tk.Label(
+ self._widgets["button"].pack(side=_tk.LEFT)
+ self._widgets["label"] = _tk.Label(
self._frame,
width=_TEXT_WIDTH,
height=_BUTTON_HEIGHT,
bg=None,
anchor="w",
)
- self._widgets["label"].pack(side=tk.LEFT)
+ self._widgets["label"].pack(side=_tk.LEFT)
self._hooks = hooks
self._color_when_not_set = color_when_not_set
self._color_when_set = color_when_set
@@ -173,7 +190,7 @@ class _PathButton(object):
)
@property
- def val(self) -> Optional[Path]:
+ def val(self) -> _Optional[_Path]:
return self._path
def _set_text(self):
@@ -189,7 +206,7 @@ class _PathButton(object):
] = f"{self._button_text.capitalize()} set to {val}"
def _set_val(self):
- last_path = settings.get_last_path(self._path_key)
+ last_path = _settings.get_last_path(self._path_key)
if last_path is None:
initial_dir = None
elif not last_path.is_dir():
@@ -197,15 +214,15 @@ class _PathButton(object):
else:
initial_dir = last_path
result = {
- _PathType.FILE: filedialog.askopenfilename,
- _PathType.DIRECTORY: filedialog.askdirectory,
- _PathType.MULTIFILE: filedialog.askopenfilenames,
+ _PathType.FILE: _filedialog.askopenfilename,
+ _PathType.DIRECTORY: _filedialog.askdirectory,
+ _PathType.MULTIFILE: _filedialog.askopenfilenames,
}[self._path_type](initialdir=str(initial_dir))
if result != "":
self._path = result
- settings.set_last_path(
+ _settings.set_last_path(
self._path_key,
- Path(result[0] if self._path_type == _PathType.MULTIFILE else result),
+ _Path(result[0] if self._path_type == _PathType.MULTIFILE else result),
)
self._set_text()
@@ -218,14 +235,14 @@ class _InputPathButton(_PathButton):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Download the training file!
- self._widgets["button_download_input"] = tk.Button(
+ self._widgets["button_download_input"] = _tk.Button(
self._frame,
text="Download input file",
width=_BUTTON_WIDTH,
height=_BUTTON_HEIGHT,
command=self._download_input_file,
)
- self._widgets["button_download_input"].pack(side=tk.RIGHT)
+ self._widgets["button_download_input"].pack(side=_tk.RIGHT)
@classmethod
def _download_input_file(cls):
@@ -237,20 +254,20 @@ class _InputPathButton(_PathButton):
"v1.wav": "",
}
# Pick the most recent file.
- for input_basename in INPUT_BASENAMES:
+ for input_basename in _INPUT_BASENAMES:
name = input_basename.name
url = file_urls.get(name)
if url:
- if name != LATEST_VERSION.name:
+ if name != _LATEST_VERSION.name:
print(
f"WARNING: File {name} is out of date. "
"This needs to be updated!"
)
- webbrowser.open(url)
+ _webbrowser.open(url)
return
-class _CheckboxKeys(Enum):
+class _CheckboxKeys(_Enum):
"""
Keys for checkboxes
"""
@@ -259,13 +276,13 @@ class _CheckboxKeys(Enum):
SAVE_PLOT = "save_plot"
-class _TopLevelWithOk(tk.Toplevel):
+class _TopLevelWithOk(_tk.Toplevel):
"""
Toplevel with an Ok button (provide yourself!)
"""
def __init__(
- self, on_ok: Callable[[None], None], resume_main: Callable[[None], None]
+ self, on_ok: _Callable[[None], None], resume_main: _Callable[[None], None]
):
"""
:param on_ok: What to do when "Ok" button is pressed
@@ -281,17 +298,17 @@ class _TopLevelWithOk(tk.Toplevel):
super().destroy()
-class _TopLevelWithYesNo(tk.Toplevel):
+class _TopLevelWithYesNo(_tk.Toplevel):
"""
Toplevel holding functions for yes/no buttons to close
"""
def __init__(
self,
- on_yes: Callable[[None], None],
- on_no: Callable[[None], None],
- on_close: Optional[Callable[[None], None]],
- resume_main: Callable[[None], None],
+ on_yes: _Callable[[None], None],
+ on_no: _Callable[[None], None],
+ on_close: _Optional[_Callable[[None], None]],
+ resume_main: _Callable[[None], None],
):
"""
:param on_yes: What to do when "Yes" button is pressed.
@@ -321,13 +338,13 @@ class _OkModal(object):
Message and OK button
"""
- def __init__(self, resume_main, msg: str, label_kwargs: Optional[dict] = None):
+ def __init__(self, resume_main, msg: str, label_kwargs: _Optional[dict] = None):
label_kwargs = {} if label_kwargs is None else label_kwargs
self._root = _TopLevelWithOk((lambda: None), resume_main)
- self._text = tk.Label(self._root, text=msg, **label_kwargs)
+ self._text = _tk.Label(self._root, text=msg, **label_kwargs)
self._text.pack()
- self._ok = tk.Button(
+ self._ok = _tk.Button(
self._root,
text="Ok",
width=_BUTTON_WIDTH,
@@ -344,38 +361,38 @@ class _YesNoModal(object):
def __init__(
self,
- on_yes: Callable[[None], None],
- on_no: Callable[[None], None],
+ on_yes: _Callable[[None], None],
+ on_no: _Callable[[None], None],
resume_main,
msg: str,
- on_close: Optional[Callable[[None], None]] = None,
- label_kwargs: Optional[dict] = None,
+ on_close: _Optional[_Callable[[None], None]] = None,
+ label_kwargs: _Optional[dict] = None,
):
label_kwargs = {} if label_kwargs is None else label_kwargs
self._root = _TopLevelWithYesNo(on_yes, on_no, on_close, resume_main)
- self._text = tk.Label(self._root, text=msg, **label_kwargs)
+ self._text = _tk.Label(self._root, text=msg, **label_kwargs)
self._text.pack()
- self._buttons_frame = tk.Frame(self._root)
+ self._buttons_frame = _tk.Frame(self._root)
self._buttons_frame.pack()
- self._yes = tk.Button(
+ self._yes = _tk.Button(
self._buttons_frame,
text="Yes",
width=_BUTTON_WIDTH,
height=_BUTTON_HEIGHT,
command=lambda: self._root.destroy(pressed_yes=True),
)
- self._yes.pack(side=tk.LEFT)
- self._no = tk.Button(
+ self._yes.pack(side=_tk.LEFT)
+ self._no = _tk.Button(
self._buttons_frame,
text="No",
width=_BUTTON_WIDTH,
height=_BUTTON_HEIGHT,
command=lambda: self._root.destroy(pressed_no=True),
)
- self._no.pack(side=tk.RIGHT)
+ self._no.pack(side=_tk.RIGHT)
-class _GUIWidgets(Enum):
+class _GUIWidgets(_Enum):
INPUT_PATH = "input_path"
OUTPUT_PATH = "output_path"
TRAINING_DESTINATION = "training_destination"
@@ -385,57 +402,57 @@ class _GUIWidgets(Enum):
UPDATE = "update"
-@dataclass
+@_dataclass
class Checkbox(object):
- variable: tk.BooleanVar
- check_button: tk.Checkbutton
+ variable: _tk.BooleanVar
+ check_button: _tk.Checkbutton
class GUI(object):
def __init__(self):
- self._root = tk.Tk()
+ self._root = _tk.Tk()
self._root.title(f"NAM Trainer - v{__version__}")
self._widgets = {}
# Buttons for paths:
- self._frame_input = tk.Frame(self._root)
+ self._frame_input = _tk.Frame(self._root)
self._frame_input.pack(anchor="w")
self._widgets[_GUIWidgets.INPUT_PATH] = _InputPathButton(
self._frame_input,
"Input Audio",
- f"Select input (DI) file (e.g. {LATEST_VERSION.name})",
+ f"Select input (DI) file (e.g. {_LATEST_VERSION.name})",
_PathType.FILE,
- settings.PathKey.INPUT_FILE,
+ _settings.PathKey.INPUT_FILE,
hooks=[self._check_button_states],
)
- self._frame_output_path = tk.Frame(self._root)
+ self._frame_output_path = _tk.Frame(self._root)
self._frame_output_path.pack(anchor="w")
self._widgets[_GUIWidgets.OUTPUT_PATH] = _PathButton(
self._frame_output_path,
"Output Audio",
"Select output (reamped) file - (Choose MULTIPLE FILES to enable BATCH TRAINING)",
_PathType.MULTIFILE,
- settings.PathKey.OUTPUT_FILE,
+ _settings.PathKey.OUTPUT_FILE,
hooks=[self._check_button_states],
)
- self._frame_train_destination = tk.Frame(self._root)
+ self._frame_train_destination = _tk.Frame(self._root)
self._frame_train_destination.pack(anchor="w")
self._widgets[_GUIWidgets.TRAINING_DESTINATION] = _PathButton(
self._frame_train_destination,
"Train Destination",
"Select training output directory",
_PathType.DIRECTORY,
- settings.PathKey.TRAINING_DESTINATION,
+ _settings.PathKey.TRAINING_DESTINATION,
hooks=[self._check_button_states],
)
# Metadata
- self.user_metadata = UserMetadata()
- self._frame_metadata = tk.Frame(self._root)
+ self.user_metadata = _UserMetadata()
+ self._frame_metadata = _tk.Frame(self._root)
self._frame_metadata.pack(anchor="w")
- self._widgets["metadata"] = tk.Button(
+ self._widgets["metadata"] = _tk.Button(
self._frame_metadata,
text="Metadata...",
width=_BUTTON_WIDTH,
@@ -449,16 +466,16 @@ class GUI(object):
self._get_additional_options_frame()
# Last frames: avdanced options & train in the SE corner:
- self._frame_advanced_options = tk.Frame(self._root)
- self._frame_train = tk.Frame(self._root)
- self._frame_update = tk.Frame(self._root)
+ self._frame_advanced_options = _tk.Frame(self._root)
+ self._frame_train = _tk.Frame(self._root)
+ self._frame_update = _tk.Frame(self._root)
# Pack must be in reverse order
- self._frame_update.pack(side=tk.BOTTOM, anchor="e")
- self._frame_train.pack(side=tk.BOTTOM, anchor="e")
- self._frame_advanced_options.pack(side=tk.BOTTOM, anchor="e")
+ self._frame_update.pack(side=_tk.BOTTOM, anchor="e")
+ self._frame_train.pack(side=_tk.BOTTOM, anchor="e")
+ self._frame_advanced_options.pack(side=_tk.BOTTOM, anchor="e")
# Advanced options for training
- default_architecture = core.Architecture.STANDARD
+ default_architecture = _core.Architecture.STANDARD
self.advanced_options = AdvancedOptions(
default_architecture,
_DEFAULT_NUM_EPOCHS,
@@ -468,7 +485,7 @@ class GUI(object):
)
# Window to edit them:
- self._widgets[_GUIWidgets.ADVANCED_OPTIONS] = tk.Button(
+ self._widgets[_GUIWidgets.ADVANCED_OPTIONS] = _tk.Button(
self._frame_advanced_options,
text="Advanced options...",
width=_BUTTON_WIDTH,
@@ -479,7 +496,7 @@ class GUI(object):
# Train button
- self._widgets[_GUIWidgets.TRAIN] = tk.Button(
+ self._widgets[_GUIWidgets.TRAIN] = _tk.Button(
self._frame_train,
text="Train",
width=_BUTTON_WIDTH,
@@ -492,7 +509,7 @@ class GUI(object):
self._check_button_states()
- def core_train_kwargs(self) -> Dict[str, Any]:
+ def core_train_kwargs(self) -> _Dict[str, _Any]:
"""
Get any additional kwargs to provide to `core.train`
"""
@@ -528,29 +545,29 @@ class GUI(object):
self._widgets[_GUIWidgets.TRAINING_DESTINATION],
)
):
- self._widgets[_GUIWidgets.TRAIN]["state"] = tk.DISABLED
+ self._widgets[_GUIWidgets.TRAIN]["state"] = _tk.DISABLED
return
- self._widgets[_GUIWidgets.TRAIN]["state"] = tk.NORMAL
+ self._widgets[_GUIWidgets.TRAIN]["state"] = _tk.NORMAL
def _get_additional_options_frame(self):
# Checkboxes
# TODO get these definitions into __init__()
- self._frame_checkboxes = tk.Frame(self._root)
- self._frame_checkboxes.pack(side=tk.LEFT)
+ self._frame_checkboxes = _tk.Frame(self._root)
+ self._frame_checkboxes.pack(side=_tk.LEFT)
row = 1
def make_checkbox(
key: _CheckboxKeys, text: str, default_value: bool
) -> Checkbox:
- variable = tk.BooleanVar()
+ variable = _tk.BooleanVar()
variable.set(default_value)
- check_button = tk.Checkbutton(
+ check_button = _tk.Checkbutton(
self._frame_checkboxes, text=text, variable=variable
)
self._checkboxes[key] = Checkbox(variable, check_button)
self._widgets[key] = check_button # For tracking in set-all-widgets ops
- self._checkboxes: Dict[_CheckboxKeys, Checkbox] = dict()
+ self._checkboxes: _Dict[_CheckboxKeys, Checkbox] = dict()
make_checkbox(
_CheckboxKeys.SILENT_TRAINING,
"Silent run (suggested for batch training)",
@@ -568,7 +585,7 @@ class GUI(object):
self._root.mainloop()
def _disable(self):
- self._set_all_widget_states_to(tk.DISABLED)
+ self._set_all_widget_states_to(_tk.DISABLED)
def _open_advanced_options(self):
"""
@@ -584,15 +601,15 @@ class GUI(object):
self._wait_while_func(lambda resume: UserMetadataGUI(resume, self))
- def _pack_update_button(self, version_from: Version, version_to: Version):
+ def _pack_update_button(self, version_from: _Version, version_to: _Version):
"""
Pack a button that a user can click to update
"""
def update_nam():
- result = subprocess.run(
+ result = _subprocess.run(
[
- f"{sys.executable}",
+ f"{_sys.executable}",
"-m",
"pip",
"install",
@@ -611,7 +628,7 @@ class GUI(object):
"Update failed! See logs.",
)
- self._widgets[_GUIWidgets.UPDATE] = tk.Button(
+ self._widgets[_GUIWidgets.UPDATE] = _tk.Button(
self._frame_update,
text=f"Update ({str(version_from)} -> {str(version_to)})",
width=_BUTTON_WIDTH,
@@ -621,18 +638,18 @@ class GUI(object):
self._widgets[_GUIWidgets.UPDATE].pack()
def _pack_update_button_if_update_is_available(self):
- class UpdateInfo(NamedTuple):
+ class UpdateInfo(_NamedTuple):
available: bool
- current_version: Version
- new_version: Optional[Version]
+ current_version: _Version
+ new_version: _Optional[_Version]
def get_info() -> UpdateInfo:
# TODO error handling
url = f"https://api.github.com/repos/sdatkinson/neural-amp-modeler/releases"
- current_version = get_current_version()
+ current_version = _get_current_version()
try:
- response = requests.get(url)
- except requests.exceptions.ConnectionError:
+ response = _requests.get(url)
+ except _requests.exceptions.ConnectionError:
print("WARNING: Failed to reach the server to check for updates")
return UpdateInfo(
available=False, current_version=current_version, new_version=None
@@ -651,7 +668,7 @@ class GUI(object):
if not tag.startswith("v"):
print(f"Found invalid version {tag}")
else:
- this_version = Version.from_string(tag[1:])
+ this_version = _Version.from_string(tag[1:])
if latest_version is None or this_version > latest_version:
latest_version = this_version
else:
@@ -672,7 +689,7 @@ class GUI(object):
)
def _resume(self):
- self._set_all_widget_states_to(tk.NORMAL)
+ self._set_all_widget_states_to(_tk.NORMAL)
self._check_button_states()
def _set_all_widget_states_to(self, state):
@@ -700,12 +717,12 @@ class GUI(object):
# Run it
for file in file_list:
print(f"Now training {file}")
- basename = re.sub(r"\.wav$", "", file.split("/")[-1])
+ basename = _re.sub(r"\.wav$", "", file.split("/")[-1])
user_metadata = (
- self.user_metadata if self.user_metadata_flag else UserMetadata()
+ self.user_metadata if self.user_metadata_flag else _UserMetadata()
)
- train_output = core.train(
+ train_output = _core.train(
input_path,
file,
self._widgets[_GUIWidgets.TRAINING_DESTINATION].val,
@@ -735,7 +752,7 @@ class GUI(object):
basename=basename,
user_metadata=user_metadata,
other_metadata={
- metadata.TRAINING_KEY: train_output.metadata.model_dump()
+ _metadata.TRAINING_KEY: train_output.metadata.model_dump()
},
)
print("Done!")
@@ -745,7 +762,7 @@ class GUI(object):
self.user_metadata_flag = False
def _validate_all_data(
- self, input_path: Path, output_paths: Sequence[Path]
+ self, input_path: _Path, output_paths: _Sequence[_Path]
) -> bool:
"""
Validate all the data.
@@ -757,14 +774,14 @@ class GUI(object):
"""
def make_message_for_file(
- output_path: str, validation_output: core.DataValidationOutput
+ output_path: str, validation_output: _core.DataValidationOutput
) -> str:
"""
State the file and explain what's wrong with it.
"""
# TODO put this closer to what it looks at, i.e. core.DataValidationOutput
msg = (
- f"\t{Path(output_path).name}:\n" # They all have the same directory so
+ f"\t{_Path(output_path).name}:\n" # They all have the same directory so
)
if not validation_output.sample_rate.passed:
msg += (
@@ -798,14 +815,14 @@ class GUI(object):
msg += "\t\t* A data check failed (TODO in more detail).\n"
if not validation_output.pytorch.passed:
msg += "\t\t* PyTorch data set errors:\n"
- for split in Split:
+ for split in _Split:
split_validation = getattr(validation_output.pytorch, split.value)
if not split_validation.passed:
msg += f" * {split.value:10s}: {split_validation.msg}\n"
return msg
# Validate input
- input_validation = core.validate_input(input_path)
+ input_validation = _core.validate_input(input_path)
if not input_validation.passed:
self._wait_while_func(
(lambda resume, *args, **kwargs: _OkModal(resume, *args, **kwargs)),
@@ -816,7 +833,7 @@ class GUI(object):
user_latency = self.advanced_options.latency
file_validation_outputs = {
- output_path: core.validate_data(
+ output_path: _core.validate_data(
input_path,
output_path,
user_latency,
@@ -920,12 +937,12 @@ def _rstripped_str(val):
return str(val).rstrip()
-class _SettingWidget(abc.ABC):
+class _SettingWidget(_abc.ABC):
"""
A widget for the user to interact with to set something
"""
- @abc.abstractmethod
+ @_abc.abstractmethod
def get(self):
pass
@@ -936,7 +953,11 @@ class LabeledOptionMenu(_SettingWidget):
"""
def __init__(
- self, frame: tk.Frame, label: str, choices: Enum, default: Optional[Enum] = None
+ self,
+ frame: _tk.Frame,
+ label: str,
+ choices: _Enum,
+ default: _Optional[_Enum] = None,
):
"""
:param command: Called to propagate option selection. Is provided with the
@@ -946,7 +967,7 @@ class LabeledOptionMenu(_SettingWidget):
self._choices = choices
height = _BUTTON_HEIGHT
bg = None
- self._label = tk.Label(
+ self._label = _tk.Label(
frame,
width=_ADVANCED_OPTIONS_LEFT_WIDTH,
height=height,
@@ -954,26 +975,26 @@ class LabeledOptionMenu(_SettingWidget):
anchor="w",
text=label,
)
- self._label.pack(side=tk.LEFT)
+ self._label.pack(side=_tk.LEFT)
- frame_menu = tk.Frame(frame)
- frame_menu.pack(side=tk.RIGHT)
+ frame_menu = _tk.Frame(frame)
+ frame_menu.pack(side=_tk.RIGHT)
self._selected_value = None
default = (list(choices)[0] if default is None else default).value
- self._menu = tk.OptionMenu(
+ self._menu = _tk.OptionMenu(
frame_menu,
- tk.StringVar(master=frame, value=default, name=label),
+ _tk.StringVar(master=frame, value=default, name=label),
# default,
*[choice.value for choice in choices], # if choice.value!=default],
command=self._set,
)
self._menu.config(width=_ADVANCED_OPTIONS_RIGHT_WIDTH)
- self._menu.pack(side=tk.RIGHT)
+ self._menu.pack(side=_tk.RIGHT)
# Initialize
self._set(default)
- def get(self) -> Enum:
+ def get(self) -> _Enum:
return self._selected_value
def _set(self, val: str):
@@ -992,12 +1013,12 @@ class _Hovertip(Hovertip):
def showcontents(self):
# Override
- label = tk.Label(
+ label = _tk.Label(
self.tipwindow,
text=self.text,
- justify=tk.LEFT,
+ justify=_tk.LEFT,
background="#ffffe0",
- relief=tk.SOLID,
+ relief=_tk.SOLID,
borderwidth=1,
fg="black",
)
@@ -1011,7 +1032,7 @@ class LabeledText(_SettingWidget):
def __init__(
self,
- frame: tk.Frame,
+ frame: _tk.Frame,
label: str,
default=None,
type=None,
@@ -1028,7 +1049,7 @@ class LabeledText(_SettingWidget):
self._frame = frame
label_height = 2
text_height = 1
- self._label = tk.Label(
+ self._label = _tk.Label(
frame,
width=left_width,
height=label_height,
@@ -1036,15 +1057,15 @@ class LabeledText(_SettingWidget):
anchor="e",
text=label,
)
- self._label.pack(side=tk.LEFT)
+ self._label.pack(side=_tk.LEFT)
- self._text = tk.Text(
+ self._text = _tk.Text(
frame,
width=right_width,
height=text_height,
bg=None,
)
- self._text.pack(side=tk.RIGHT)
+ self._text.pack(side=_tk.RIGHT)
self._type = (lambda x: x) if type is None else type
@@ -1052,10 +1073,10 @@ class LabeledText(_SettingWidget):
self._text.insert("1.0", str(default))
# You can assign a tooltip for the label if you'd like.
- self.label_tooltip: Optional[_Hovertip] = None
+ self.label_tooltip: _Optional[_Hovertip] = None
@property
- def label(self) -> tk.Label:
+ def label(self) -> _tk.Label:
return self._label
def get(self):
@@ -1064,7 +1085,7 @@ class LabeledText(_SettingWidget):
May throw a tk.TclError indicating something went wrong getting the value.
"""
# "1.0" means Line 1, character zero (wat)
- return self._type(self._text.get("1.0", tk.END))
+ return self._type(self._text.get("1.0", _tk.END))
class AdvancedOptionsGUI(object):
@@ -1080,9 +1101,9 @@ class AdvancedOptionsGUI(object):
self.pack()
# "Ok": apply and destroy
- self._frame_ok = tk.Frame(self._root)
+ self._frame_ok = _tk.Frame(self._root)
self._frame_ok.pack()
- self._button_ok = tk.Button(
+ self._button_ok = _tk.Button(
self._frame_ok,
text="Ok",
width=_BUTTON_WIDTH,
@@ -1113,17 +1134,17 @@ class AdvancedOptionsGUI(object):
# easier to work with.
# Architecture: radio buttons
- self._frame_architecture = tk.Frame(self._root)
+ self._frame_architecture = _tk.Frame(self._root)
self._frame_architecture.pack()
self._architecture = LabeledOptionMenu(
self._frame_architecture,
"Architecture",
- core.Architecture,
+ _core.Architecture,
default=self._parent.advanced_options.architecture,
)
# Number of epochs: text box
- self._frame_epochs = tk.Frame(self._root)
+ self._frame_epochs = _tk.Frame(self._root)
self._frame_epochs.pack()
self._num_epochs = LabeledText(
@@ -1134,7 +1155,7 @@ class AdvancedOptionsGUI(object):
)
# Delay: text box
- self._frame_latency = tk.Frame(self._root)
+ self._frame_latency = _tk.Frame(self._root)
self._frame_latency.pack()
self._latency = LabeledText(
@@ -1145,7 +1166,7 @@ class AdvancedOptionsGUI(object):
)
# Threshold ESR
- self._frame_threshold_esr = tk.Frame(self._root)
+ self._frame_threshold_esr = _tk.Frame(self._root)
self._frame_threshold_esr.pack()
self._threshold_esr = LabeledText(
self._frame_threshold_esr,
@@ -1168,9 +1189,9 @@ class UserMetadataGUI(object):
self.pack()
# "Ok": apply and destroy
- self._frame_ok = tk.Frame(self._root)
+ self._frame_ok = _tk.Frame(self._root)
self._frame_ok.pack()
- self._button_ok = tk.Button(
+ self._button_ok = _tk.Button(
self._frame_ok,
text="Ok",
width=_BUTTON_WIDTH,
@@ -1210,7 +1231,7 @@ class UserMetadataGUI(object):
# TODO things that are `_SettingWidget`s are named carefully, need to make this
# easier to work with.
- LabeledText_ = partial(
+ LabeledText_ = _partial(
LabeledText,
left_width=_METADATA_LEFT_WIDTH,
right_width=_METADATA_RIGHT_WIDTH,
@@ -1218,7 +1239,7 @@ class UserMetadataGUI(object):
parent = self._parent
# Name
- self._frame_name = tk.Frame(self._root)
+ self._frame_name = _tk.Frame(self._root)
self._frame_name.pack()
self._name = LabeledText_(
self._frame_name,
@@ -1227,7 +1248,7 @@ class UserMetadataGUI(object):
type=_rstripped_str,
)
# Modeled by
- self._frame_modeled_by = tk.Frame(self._root)
+ self._frame_modeled_by = _tk.Frame(self._root)
self._frame_modeled_by.pack()
self._modeled_by = LabeledText_(
self._frame_modeled_by,
@@ -1236,7 +1257,7 @@ class UserMetadataGUI(object):
type=_rstripped_str,
)
# Gear make
- self._frame_gear_make = tk.Frame(self._root)
+ self._frame_gear_make = _tk.Frame(self._root)
self._frame_gear_make.pack()
self._gear_make = LabeledText_(
self._frame_gear_make,
@@ -1245,7 +1266,7 @@ class UserMetadataGUI(object):
type=_rstripped_str,
)
# Gear model
- self._frame_gear_model = tk.Frame(self._root)
+ self._frame_gear_model = _tk.Frame(self._root)
self._frame_gear_model.pack()
self._gear_model = LabeledText_(
self._frame_gear_model,
@@ -1254,7 +1275,7 @@ class UserMetadataGUI(object):
type=_rstripped_str,
)
# Calibration: input & output dBu
- self._frame_input_dbu = tk.Frame(self._root)
+ self._frame_input_dbu = _tk.Frame(self._root)
self._frame_input_dbu.pack()
self._input_level_dbu = LabeledText_(
self._frame_input_dbu,
@@ -1272,7 +1293,7 @@ class UserMetadataGUI(object):
"Record the value here."
),
)
- self._frame_output_dbu = tk.Frame(self._root)
+ self._frame_output_dbu = _tk.Frame(self._root)
self._frame_output_dbu.pack()
self._output_level_dbu = LabeledText_(
self._frame_output_dbu,
@@ -1293,36 +1314,36 @@ class UserMetadataGUI(object):
),
)
# Gear type
- self._frame_gear_type = tk.Frame(self._root)
+ self._frame_gear_type = _tk.Frame(self._root)
self._frame_gear_type.pack()
self._gear_type = LabeledOptionMenu(
self._frame_gear_type,
"Gear type",
- GearType,
+ _GearType,
default=parent.user_metadata.gear_type,
)
# Tone type
- self._frame_tone_type = tk.Frame(self._root)
+ self._frame_tone_type = _tk.Frame(self._root)
self._frame_tone_type.pack()
self._tone_type = LabeledOptionMenu(
self._frame_tone_type,
"Tone type",
- ToneType,
+ _ToneType,
default=parent.user_metadata.tone_type,
)
def _install_error():
- window = tk.Tk()
+ window = _tk.Tk()
window.title("ERROR")
- label = tk.Label(
+ label = _tk.Label(
window,
width=45,
height=2,
text="The NAM training software has not been installed correctly.",
)
label.pack()
- button = tk.Button(window, width=10, height=2, text="Quit", command=window.destroy)
+ button = _tk.Button(window, width=10, height=2, text="Quit", command=window.destroy)
button.pack()
window.mainloop()
diff --git a/nam/train/lightning_module.py b/nam/train/lightning_module.py
@@ -10,32 +10,38 @@ along with loss function boilerplate.
For the base *PyTorch* model containing the actual architecture, see `..models.base`.
"""
-from dataclasses import dataclass
-from enum import Enum
-from typing import Any, Dict, NamedTuple, Optional, Tuple
-
-import auraloss
-import logging
-import pytorch_lightning as pl
-import torch
-import torch.nn as nn
-
-from .._core import InitializableFromConfig
-from ..models.conv_net import ConvNet
-from ..models.linear import Linear
+from dataclasses import dataclass as _dataclass
+from enum import Enum as _Enum
+from typing import (
+ Any as _Any,
+ Dict as _Dict,
+ NamedTuple as _NamedTuple,
+ Optional as _Optional,
+ Tuple as _Tuple,
+)
+
+import auraloss as _auraloss
+import logging as _logging
+import pytorch_lightning as _pl
+import torch as _torch
+import torch.nn as _nn
+
+from .._core import InitializableFromConfig as _InitializableFromConfig
+from ..models.conv_net import ConvNet as _ConvNet
+from ..models.linear import Linear as _Linear
from ..models.losses import (
- apply_pre_emphasis_filter,
- esr,
- multi_resolution_stft_loss,
- mse_fft,
+ apply_pre_emphasis_filter as _apply_pre_emphasis_filter,
+ esr as _esr,
+ multi_resolution_stft_loss as _multi_resolution_stft_loss,
+ mse_fft as _mse_fft,
)
-from ..models.recurrent import LSTM
-from ..models.wavenet import WaveNet
+from ..models.recurrent import LSTM as _LSTM
+from ..models.wavenet import WaveNet as _WaveNet
-logger = logging.getLogger(__name__)
+logger = _logging.getLogger(__name__)
-class ValidationLoss(Enum):
+class ValidationLoss(_Enum):
"""
mse: mean squared error
esr: error signal ratio (Eq. (10) from
@@ -51,8 +57,8 @@ class ValidationLoss(Enum):
ESR = "esr"
-@dataclass
-class LossConfig(InitializableFromConfig):
+@_dataclass
+class LossConfig(_InitializableFromConfig):
"""
:param mrstft_weight: Multi-resolution short-time Fourier transform loss
coefficient. None means to skip; 2e-4 works pretty well if one wants to use it.
@@ -64,15 +70,15 @@ class LossConfig(InitializableFromConfig):
:param pre_
"""
- mrstft_weight: Optional[float] = None
+ mrstft_weight: _Optional[float] = None
fourier: bool = False
mask_first: int = 0
dc_weight: float = None
val_loss: ValidationLoss = ValidationLoss.MSE
- pre_emph_weight: Optional[float] = None
- pre_emph_coef: Optional[float] = None
- pre_emph_mrstft_weight: Optional[float] = None
- pre_emph_mrstft_coef: Optional[float] = None
+ pre_emph_weight: _Optional[float] = None
+ pre_emph_coef: _Optional[float] = None
+ pre_emph_mrstft_weight: _Optional[float] = None
+ pre_emph_mrstft_coef: _Optional[float] = None
@classmethod
def parse_config(cls, config):
@@ -97,7 +103,7 @@ class LossConfig(InitializableFromConfig):
return tuple(a[..., self.mask_first :] for a in args)
@classmethod
- def _get_mrstft_weight(cls, config) -> Optional[float]:
+ def _get_mrstft_weight(cls, config) -> _Optional[float]:
key = "mrstft_weight"
wrong_key = "mstft_key" # Backward compatibility
if key in config:
@@ -117,20 +123,20 @@ class LossConfig(InitializableFromConfig):
return None
-class _LossItem(NamedTuple):
- weight: Optional[float]
- value: Optional[torch.Tensor]
+class _LossItem(_NamedTuple):
+ weight: _Optional[float]
+ value: _Optional[_torch.Tensor]
_model_net_init_registry = {
- "ConvNet": ConvNet.init_from_config,
- "Linear": Linear.init_from_config,
- "LSTM": LSTM.init_from_config,
- "WaveNet": WaveNet.init_from_config,
+ "ConvNet": _ConvNet.init_from_config,
+ "Linear": _Linear.init_from_config,
+ "LSTM": _LSTM.init_from_config,
+ "WaveNet": _WaveNet.init_from_config,
}
-class LightningModule(pl.LightningModule, InitializableFromConfig):
+class LightningModule(_pl.LightningModule, _InitializableFromConfig):
"""
The PyTorch Lightning Module that unites the model with its loss and
optimization recipe.
@@ -139,9 +145,9 @@ class LightningModule(pl.LightningModule, InitializableFromConfig):
def __init__(
self,
net,
- optimizer_config: Optional[dict] = None,
- scheduler_config: Optional[dict] = None,
- loss_config: Optional[LossConfig] = None,
+ optimizer_config: _Optional[dict] = None,
+ scheduler_config: _Optional[dict] = None,
+ loss_config: _Optional[LossConfig] = None,
):
"""
:param scheduler_config: contains
@@ -162,7 +168,7 @@ class LightningModule(pl.LightningModule, InitializableFromConfig):
# Where to compute the MRSTFT.
# Keeping it on-device is preferable, but if that fails, then remember to drop
# it to cpu from then on.
- self._mrstft_device: Optional[torch.device] = None
+ self._mrstft_device: _Optional[_torch.device] = None
@classmethod
def init_from_config(cls, config):
@@ -223,16 +229,16 @@ class LightningModule(pl.LightningModule, InitializableFromConfig):
_model_net_init_registry[name] = constructor
@property
- def net(self) -> nn.Module:
+ def net(self) -> _nn.Module:
return self._net
def configure_optimizers(self):
- optimizer = torch.optim.Adam(self.parameters(), **self._optimizer_config)
+ optimizer = _torch.optim.Adam(self.parameters(), **self._optimizer_config)
if self._scheduler_config is None:
return optimizer
else:
lr_scheduler = getattr(
- torch.optim.lr_scheduler, self._scheduler_config["class"]
+ _torch.optim.lr_scheduler, self._scheduler_config["class"]
)(optimizer, **self._scheduler_config["kwargs"])
lr_scheduler_config = {"scheduler": lr_scheduler}
for key in ("interval", "frequency", "monitor"):
@@ -243,17 +249,17 @@ class LightningModule(pl.LightningModule, InitializableFromConfig):
def forward(self, *args, **kwargs):
return self.net(*args, **kwargs) # TODO deprecate--use self.net() instead.
- def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
+ def on_load_checkpoint(self, checkpoint: _Dict[str, _Any]) -> None:
# Resolves https://github.com/sdatkinson/neural-amp-modeler/issues/351
self.net.sample_rate = checkpoint["sample_rate"]
- def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
+ def on_save_checkpoint(self, checkpoint: _Dict[str, _Any]) -> None:
# Resolves https://github.com/sdatkinson/neural-amp-modeler/issues/351
checkpoint["sample_rate"] = self.net.sample_rate
def _shared_step(
self, batch
- ) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, _LossItem]]:
+ ) -> _Tuple[_torch.Tensor, _torch.Tensor, _Dict[str, _LossItem]]:
"""
B: Batch size
L: Sequence length
@@ -267,7 +273,7 @@ class LightningModule(pl.LightningModule, InitializableFromConfig):
loss_dict = {} # Mind keys versus validation loss requested...
# Prediction aka MSE loss
if self._loss_config.fourier:
- loss_dict["MSE_FFT"] = _LossItem(1.0, mse_fft(preds, targets))
+ loss_dict["MSE_FFT"] = _LossItem(1.0, _mse_fft(preds, targets))
else:
loss_dict["MSE"] = _LossItem(1.0, self._mse_loss(preds, targets))
# Pre-emphasized MSE
@@ -300,8 +306,8 @@ class LightningModule(pl.LightningModule, InitializableFromConfig):
if dc_weight is not None and dc_weight > 0.0:
# Denominator could be a bad idea. I'm going to omit it esp since I'm
# using mini batches
- mean_dims = torch.arange(1, preds.ndim).tolist()
- dc_loss = nn.MSELoss()(
+ mean_dims = _torch.arange(1, preds.ndim).tolist()
+ dc_loss = _nn.MSELoss()(
preds.mean(dim=mean_dims), targets.mean(dim=mean_dims)
)
loss_dict["DC MSE"] = _LossItem(dc_weight, dc_loss)
@@ -344,7 +350,7 @@ class LightningModule(pl.LightningModule, InitializableFromConfig):
)
return val_loss
- def _esr_loss(self, preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
+ def _esr_loss(self, preds: _torch.Tensor, targets: _torch.Tensor) -> _torch.Tensor:
"""
Error signal ratio aka ESR loss.
@@ -358,21 +364,21 @@ class LightningModule(pl.LightningModule, InitializableFromConfig):
:param targets: (B,L)
:return: ()
"""
- return esr(preds, targets)
+ return _esr(preds, targets)
- def _mse_loss(self, preds, targets, pre_emph_coef: Optional[float] = None):
+ def _mse_loss(self, preds, targets, pre_emph_coef: _Optional[float] = None):
if pre_emph_coef is not None:
preds, targets = [
- apply_pre_emphasis_filter(z, pre_emph_coef) for z in (preds, targets)
+ _apply_pre_emphasis_filter(z, pre_emph_coef) for z in (preds, targets)
]
- return nn.MSELoss()(preds, targets)
+ return _nn.MSELoss()(preds, targets)
def _mrstft_loss(
self,
- preds: torch.Tensor,
- targets: torch.Tensor,
- pre_emph_coef: Optional[float] = None,
- ) -> torch.Tensor:
+ preds: _torch.Tensor,
+ targets: _torch.Tensor,
+ pre_emph_coef: _Optional[float] = None,
+ ) -> _torch.Tensor:
"""
Experimental Multi Resolution Short Time Fourier Transform Loss using auraloss implementation.
B: Batch size
@@ -383,16 +389,16 @@ class LightningModule(pl.LightningModule, InitializableFromConfig):
:return: ()
"""
if self._mrstft is None:
- self._mrstft = auraloss.freq.MultiResolutionSTFTLoss()
+ self._mrstft = _auraloss.freq.MultiResolutionSTFTLoss()
backup_device = "cpu"
if pre_emph_coef is not None:
preds, targets = [
- apply_pre_emphasis_filter(z, pre_emph_coef) for z in (preds, targets)
+ _apply_pre_emphasis_filter(z, pre_emph_coef) for z in (preds, targets)
]
try:
- return multi_resolution_stft_loss(
+ return _multi_resolution_stft_loss(
preds, targets, self._mrstft, device=self._mrstft_device
)
except Exception as e:
@@ -400,6 +406,6 @@ class LightningModule(pl.LightningModule, InitializableFromConfig):
raise e
logger.warning("MRSTFT failed on device; falling back to CPU")
self._mrstft_device = backup_device
- return multi_resolution_stft_loss(
+ return _multi_resolution_stft_loss(
preds, targets, self._mrstft, device=self._mrstft_device
)
diff --git a/nam/train/metadata.py b/nam/train/metadata.py
@@ -9,15 +9,15 @@ Information from the simplified trainers that is good to know about.
# This isn't part of ../metadata because it's not necessarily worth knowning about--only
# if you're using the simplified trainers!
-from typing import List, Optional
+from typing import List as _List, Optional as _Optional
-from pydantic import BaseModel
+from pydantic import BaseModel as _BaseModel
# The key under which the metadata are saved in the .nam:
TRAINING_KEY = "training"
-class Settings(BaseModel):
+class Settings(_BaseModel):
"""
User-provided settings
"""
@@ -25,7 +25,7 @@ class Settings(BaseModel):
ignore_checks: bool
-class LatencyCalibrationWarnings(BaseModel):
+class LatencyCalibrationWarnings(_BaseModel):
"""
Things that aren't necessarily wrong with the latency calibration but are
worth looking into.
@@ -42,34 +42,34 @@ class LatencyCalibrationWarnings(BaseModel):
disagreement_too_high: bool
-class LatencyCalibration(BaseModel):
+class LatencyCalibration(_BaseModel):
algorithm_version: int
- delays: List[int]
+ delays: _List[int]
safety_factor: int
recommended: int
warnings: LatencyCalibrationWarnings
-class Latency(BaseModel):
+class Latency(_BaseModel):
"""
Information about the latency
"""
- manual: Optional[int]
+ manual: _Optional[int]
calibration: LatencyCalibration
-class DataChecks(BaseModel):
+class DataChecks(_BaseModel):
version: int
passed: bool
-class Data(BaseModel):
+class Data(_BaseModel):
latency: Latency
checks: DataChecks
-class TrainingMetadata(BaseModel):
+class TrainingMetadata(_BaseModel):
settings: Settings
data: Data
- validation_esr: Optional[float]
+ validation_esr: _Optional[float]
diff --git a/nam/util.py b/nam/util.py
@@ -6,12 +6,12 @@
Helpful utilities
"""
-import warnings
-from datetime import datetime
+import warnings as _warnings
+from datetime import datetime as _datetime
def timestamp() -> str:
- t = datetime.now()
+ t = _datetime.now()
return f"{t.year:04d}-{t.month:02d}-{t.day:02d}-{t.hour:02d}-{t.minute:02d}-{t.second:02d}"
@@ -28,10 +28,10 @@ class _FilterWarnings(object):
self._kwargs = kwargs
def __enter__(self):
- warnings.filterwarnings(*self._args, **self._kwargs)
+ _warnings.filterwarnings(*self._args, **self._kwargs)
def __exit__(self, exc_type, exc_val, exc_tb):
- warnings.resetwarnings()
+ _warnings.resetwarnings()
def filter_warnings(*args, **kwargs):
diff --git a/tests/test_nam/test_train/test_lightning_module.py b/tests/test_nam/test_train/test_lightning_module.py
@@ -48,7 +48,7 @@ def test_mrstft_loss_cpu_fallback(mocker):
raise RuntimeError("Trigger fallback")
return _torch.tensor(1.0)
- mocker.patch("nam.train.lightning_module.multi_resolution_stft_loss", mocked_loss)
+ mocker.patch("nam.train.lightning_module._multi_resolution_stft_loss", mocked_loss)
batch_size = 3
sequence_length = 4096