neural-amp-modeler

Neural network emulator for guitar amplifiers

full.py (6772B)


# File: full.py
# Created Date: Tuesday March 26th 2024
# Author: Enrico Schifano (eraz1997@live.it)

import json as _json
from pathlib import Path as _Path
from time import time as _time
from typing import Optional as _Optional, Union as _Union
from warnings import warn as _warn

import matplotlib.pyplot as _plt
import numpy as _np
import pytorch_lightning as _pl
from pytorch_lightning.utilities.warnings import (
    PossibleUserWarning as _PossibleUserWarning,
)
import torch as _torch
from torch.utils.data import DataLoader as _DataLoader

from nam.data import (
    ConcatDataset as _ConcatDataset,
    Split as _Split,
    init_dataset as _init_dataset,
)
from nam.train.lightning_module import LightningModule as _LightningModule
from nam.util import filter_warnings as _filter_warnings

_torch.manual_seed(0)


def _rms(x: _Union[_np.ndarray, _torch.Tensor]) -> float:
    """Root-mean-square of a NumPy array or torch tensor."""
    if isinstance(x, _np.ndarray):
        return _np.sqrt(_np.mean(_np.square(x)))
    elif isinstance(x, _torch.Tensor):
        return _torch.sqrt(_torch.mean(_torch.square(x))).item()
    else:
        raise TypeError(type(x))


def _plot(
    model,
    ds,
    savefig=None,
    show=True,
    window_start: _Optional[int] = None,
    window_end: _Optional[int] = None,
):
    """
    Plot the model's prediction against the target for a dataset, recursing
    over the members of a ConcatDataset. Prints inference speed and reports
    the error-to-signal ratio (ESR) in the plot title.
    """
    if isinstance(ds, _ConcatDataset):

        def extend_savefig(i, savefig):
            if savefig is None:
                return None
            savefig = _Path(savefig)
            extension = savefig.name.split(".")[-1]
            stem = savefig.name[: -len(extension) - 1]
            return _Path(savefig.parent, f"{stem}_{i}.{extension}")

        for i, ds_i in enumerate(ds.datasets):
            _plot(
                model,
                ds_i,
                savefig=extend_savefig(i, savefig),
                show=show and i == len(ds.datasets) - 1,
                window_start=window_start,
                window_end=window_end,
            )
        return
    with _torch.no_grad():
        # Audio duration in seconds (assumes 48 kHz audio for the timing printout).
        tx = len(ds.x) / 48_000
        print(f"Run (t={tx:.2f})")
        t0 = _time()
        output = model(ds.x).flatten().cpu().numpy()
        t1 = _time()
        try:
            rt = f"{tx / (t1 - t0):.2f}"
        except ZeroDivisionError:
            rt = "???"
        print(f"Took {t1 - t0:.2f} ({rt}x)")

    _plt.figure(figsize=(16, 5))
    _plt.plot(output[window_start:window_end], label="Prediction")
    _plt.plot(ds.y[window_start:window_end], linestyle="--", label="Target")
    # ESR is the squared, normalized RMSE.
    nrmse = _rms(_torch.Tensor(output) - ds.y) / _rms(ds.y)
    esr = nrmse**2
    _plt.title(f"ESR={esr:.3f}")
    _plt.legend()
    if savefig is not None:
        _plt.savefig(savefig)
    if show:
        _plt.show()


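# Note on the metric above: the ESR shown in the plot title is the squared,
# normalized RMSE. With NumPy arrays `pred` and `target` (hypothetical names),
# an equivalent computation would be:
#
#   esr = np.mean((pred - target) ** 2) / np.mean(target ** 2)
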
def _create_callbacks(learning_config):
    """
    Create the checkpoint callbacks for the trainer (best / last / end-of-epoch).
    """
    # Checkpoints should be run every time the validation check is run.
    # So base it off of learning_config["trainer"]["val_check_interval"] if it's there.
    validate_inside_epoch = "val_check_interval" in learning_config["trainer"]
    if validate_inside_epoch:
        kwargs = {
            "every_n_train_steps": learning_config["trainer"]["val_check_interval"]
        }
    else:
        kwargs = {
            "every_n_epochs": learning_config["trainer"].get(
                "check_val_every_n_epoch", 1
            )
        }

    checkpoint_best = _pl.callbacks.model_checkpoint.ModelCheckpoint(
        filename="{epoch:04d}_{step}_{ESR:.3e}_{MSE:.3e}",
        save_top_k=3,
        monitor="val_loss",
        **kwargs,
    )

    # The last epoch that was finished.
    checkpoint_epoch = _pl.callbacks.model_checkpoint.ModelCheckpoint(
        filename="checkpoint_epoch_{epoch:04d}", every_n_epochs=1
    )
    if not validate_inside_epoch:
        return [checkpoint_best, checkpoint_epoch]
    else:
        # The last validation pass, whether at the end of an epoch or not
        checkpoint_last = _pl.callbacks.model_checkpoint.ModelCheckpoint(
            filename="checkpoint_last_{epoch:04d}_{step}", **kwargs
        )
        return [checkpoint_best, checkpoint_last, checkpoint_epoch]


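# For reference, two hypothetical learning_config["trainer"] sections (keys and
# values are examples, not requirements) and the checkpoint cadence they produce:
#
#   {"max_epochs": 100, "val_check_interval": 500}
#       -> validation-based checkpoints every 500 training steps
#   {"max_epochs": 100, "check_val_every_n_epoch": 1}
#       -> validation-based checkpoints once per epoch
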
def main(
    data_config,
    model_config,
    learning_config,
    outdir: _Path,
    no_show: bool = False,
    make_plots=True,
):
    """
    Train a model from the given configs, plot the validation comparison, and
    export the trained network to `outdir`.
    """
    if not outdir.exists():
        raise RuntimeError(f"No output location found at {outdir}")
    # Write the configs to the output directory for reproducibility.
    for basename, config in (
        ("data", data_config),
        ("model", model_config),
        ("learning", learning_config),
    ):
        with open(_Path(outdir, f"config_{basename}.json"), "w") as fp:
            _json.dump(config, fp, indent=4)

    model = _LightningModule.init_from_config(model_config)
    # Add receptive field to data config:
    data_config["common"] = data_config.get("common", {})
    if "nx" in data_config["common"]:
        _warn(
            f"Overriding data nx={data_config['common']['nx']} with model-required value {model.net.receptive_field}"
        )
    data_config["common"]["nx"] = model.net.receptive_field

    dataset_train = _init_dataset(data_config, _Split.TRAIN)
    dataset_validation = _init_dataset(data_config, _Split.VALIDATION)
    if dataset_train.sample_rate != dataset_validation.sample_rate:
        raise RuntimeError(
            "Train and validation data loaders have different data set sample rates: "
            f"{dataset_train.sample_rate}, {dataset_validation.sample_rate}"
        )
    model.net.sample_rate = dataset_train.sample_rate
    train_dataloader = _DataLoader(dataset_train, **learning_config["train_dataloader"])
    val_dataloader = _DataLoader(
        dataset_validation, **learning_config["val_dataloader"]
    )

    trainer = _pl.Trainer(
        callbacks=_create_callbacks(learning_config),
        default_root_dir=outdir,
        **learning_config["trainer"],
    )
    with _filter_warnings("ignore", category=_PossibleUserWarning):
        trainer.fit(
            model,
            train_dataloader,
            val_dataloader,
            **learning_config.get("trainer_fit_kwargs", {}),
        )
    # Go to best checkpoint
    best_checkpoint = trainer.checkpoint_callback.best_model_path
    if best_checkpoint != "":
        model = _LightningModule.load_from_checkpoint(
            best_checkpoint,
            **_LightningModule.parse_config(model_config),
        )
    model.cpu()
    model.eval()
    if make_plots:
        _plot(
            model,
            dataset_validation,
            savefig=_Path(outdir, "comparison.png"),
            window_start=100_000,
            window_end=110_000,
            show=False,
        )
        _plot(model, dataset_validation, show=not no_show)
    # Export!
    model.net.export(outdir)
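

# A minimal, hypothetical sketch of how main() might be driven. The JSON file
# names below are examples only; whatever script calls this module must supply
# the three configs and an existing output directory:
#
#   import json
#   from pathlib import Path
#
#   configs = {}
#   for name in ("data", "model", "learning"):
#       with open(f"{name}.json") as fp:
#           configs[name] = json.load(fp)
#   outdir = Path("outputs")
#   outdir.mkdir(exist_ok=True)
#   main(configs["data"], configs["model"], configs["learning"], outdir, no_show=True)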