commit 6ecb5e86970c22c3a69325c8fc01b9932707d32c
parent e0c73b26277936e3da8f42bb28114a181a8b4bd9
Author: Steven Atkinson <steven@atkinson.mn>
Date: Tue, 31 Jan 2023 19:24:07 -0800
Add different pre-made architectures to easy mode (#81)
Diffstat:
2 files changed, 103 insertions(+), 39 deletions(-)
diff --git a/bin/train/easy_colab.ipynb b/bin/train/easy_colab.ipynb
@@ -98,7 +98,10 @@
"outputs": [],
"source": [
"%tensorboard --logdir /content/lightning_logs\n",
- "run(epochs=100)\n",
+ "run(\n",
+ " epochs=100,\n",
+ " architecture=\"standard\" # standard, lite, feather\n",
+ ")\n",
"# Psst! Curious how it's going?\n",
"# You can look under lightning_logs/version_0/checkpoints to see it saving its progress.\n",
"# Look for the number after \"ESR\" to go down. It will start around 1.0. 0.1 is ok, and \n",
diff --git a/nam/train/colab.py b/nam/train/colab.py
@@ -6,6 +6,7 @@
Hide the mess in Colab to make things look pretty for users.
"""
+from enum import Enum
from pathlib import Path
from time import time
from typing import Optional, Tuple
@@ -20,6 +21,12 @@ from nam.data import REQUIRED_RATE, Split, init_dataset, wav_to_np
from nam.models import Model
+class _Architecture(Enum):
+    """
+    Pre-made WaveNet sizes that easy mode can train.
+
+    Each member's value is the string a user passes as ``architecture``.
+    Lookup by value ignores surrounding whitespace and letter case, so
+    ``_Architecture(" Lite ")`` is ``_Architecture.LITE``.
+    """
+
+    STANDARD = "standard"
+    LITE = "lite"
+    FEATHER = "feather"
+
+    @classmethod
+    def _missing_(cls, value):
+        # Enum calls this hook only when `value` is not an exact member value.
+        # The name is typed by hand in a notebook cell, so accept harmless
+        # variations instead of failing the whole run on "Standard" or "lite ".
+        if isinstance(value, str):
+            normalized = value.strip().lower()
+            for member in cls:
+                if member.value == normalized:
+                    return member
+        # Returning None makes Enum raise its usual ValueError.
+        return None
+
+
class _Version:
def __init__(self, major: int, minor: int, patch: int):
self.major = major
@@ -133,13 +140,97 @@ def _calibrate_delay(
return delay
+def _get_wavenet_config(architecture: _Architecture) -> dict:
+    """
+    Build the WaveNet "config" dictionary for one pre-made architecture.
+
+    All architectures share the same two-block layout; they differ only in the
+    channel counts and dilations of each block, so only those are tabulated.
+
+    :param architecture: Which pre-made architecture to build.
+    :return: A new dictionary with "layers_configs" (two blocks) and
+        "head_scale". Nothing in it is shared between calls.
+    :raises KeyError: If `architecture` is not a member of `_Architecture`.
+    """
+    full_dilations = (1, 2, 4, 8, 16, 32, 64, 128, 256, 512)
+    short_dilations = (1, 2, 4, 8, 16, 32, 64)
+    # The smaller models' second block starts at the large dilations and then
+    # sweeps the full range.
+    wrapped_dilations = (128, 256, 512) + full_dilations
+
+    # architecture -> (block 1 channels, block 2 channels,
+    #                  block 1 dilations, block 2 dilations)
+    presets = {
+        _Architecture.STANDARD: (16, 8, full_dilations, full_dilations),
+        _Architecture.LITE: (12, 6, short_dilations, wrapped_dilations),
+        _Architecture.FEATHER: (8, 4, short_dilations, wrapped_dilations),
+    }
+    channels_1, channels_2, dilations_1, dilations_2 = presets[architecture]
+
+    return {
+        "layers_configs": [
+            {
+                "input_size": 1,
+                "condition_size": 1,
+                "channels": channels_1,
+                # The first block's head feeds the second block's channels.
+                "head_size": channels_2,
+                "kernel_size": 3,
+                # Copy so the returned config owns its own list.
+                "dilations": list(dilations_1),
+                "activation": "Tanh",
+                "gated": False,
+                "head_bias": False,
+            },
+            {
+                "condition_size": 1,
+                # The second block consumes the first block's channels.
+                "input_size": channels_1,
+                "channels": channels_2,
+                "head_size": 1,
+                "kernel_size": 3,
+                "dilations": list(dilations_2),
+                "activation": "Tanh",
+                "gated": False,
+                "head_bias": True,
+            },
+        ],
+        "head_scale": 0.02,
+    }
+
+
def _get_configs(
input_basename: str,
delay: int,
epochs: int,
- stage_1_channels: int,
- stage_2_channels: int,
- head_scale: float,
+ architecture: _Architecture,
lr: float,
lr_decay: float,
):
@@ -159,33 +250,7 @@ def _get_configs(
"name": "WaveNet",
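-# This should do decently. If you really want a nice model, try turning up
-# "channels" in the first block and "input_size" in the second from 12 to 16.
+# Layer sizes come from the pre-made architecture chosen by the caller.
+# "standard" (16 channels in the first block) is the largest; "lite" (12) and
+# "feather" (8) are progressively smaller.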
- "config": {
- "layers_configs": [
- {
- "input_size": 1,
- "condition_size": 1,
- "head_size": stage_2_channels,
- "channels": stage_1_channels,
- "kernel_size": 3,
- "dilations": [1, 2, 4, 8, 16, 32, 64, 128, 256, 512],
- "activation": "Tanh",
- "gated": False,
- "head_bias": False,
- },
- {
- "input_size": stage_1_channels,
- "condition_size": 1,
- "head_size": 1,
- "channels": stage_2_channels,
- "kernel_size": 3,
- "dilations": [1, 2, 4, 8, 16, 32, 64, 128, 256, 512],
- "activation": "Tanh",
- "gated": False,
- "head_bias": True,
- },
- ],
- "head_scale": head_scale,
- },
+ "config": _get_wavenet_config(architecture),
},
"loss": {"val_loss": "esr"},
"optimizer": {"lr": lr},
@@ -258,11 +323,9 @@ def _get_valid_export_directory():
def run(
- epochs=100,
- delay=None,
- stage_1_channels=16,
- stage_2_channels=8,
- head_scale: float = 0.02,
+ epochs: int = 100,
+ delay: Optional[int] = None,
+ architecture: str = "standard",
lr=0.004,
lr_decay=0.007,
seed=0,
@@ -284,9 +347,7 @@ def run(
input_basename,
delay,
epochs,
- stage_1_channels,
- stage_2_channels,
- head_scale,
+ _Architecture(architecture),
lr,
lr_decay,
)