full.py (6772B)
# File: full.py
# Created Date: Tuesday March 26th 2024
# Author: Enrico Schifano (eraz1997@live.it)

import json as _json
from pathlib import Path as _Path
from time import time as _time
from typing import Optional as _Optional, Union as _Union
from warnings import warn as _warn

import matplotlib.pyplot as _plt
import numpy as _np
import pytorch_lightning as _pl
from pytorch_lightning.utilities.warnings import (
    PossibleUserWarning as _PossibleUserWarning,
)
import torch as _torch
from torch.utils.data import DataLoader as _DataLoader

from nam.data import (
    ConcatDataset as _ConcatDataset,
    Split as _Split,
    init_dataset as _init_dataset,
)
from nam.train.lightning_module import LightningModule as _LightningModule
from nam.util import filter_warnings as _filter_warnings

# Deterministic initialization so repeated runs are reproducible.
_torch.manual_seed(0)

# Fallback sample rate for the realtime-factor printout in _plot() when the
# dataset doesn't expose a sample_rate of its own.
_DEFAULT_SAMPLE_RATE = 48_000


def _rms(x: _Union[_np.ndarray, _torch.Tensor]) -> float:
    """
    Root-mean-square of the input values.

    :param x: NumPy array or torch tensor.
    :return: RMS as a plain Python float.
    :raise TypeError: If ``x`` is neither a NumPy array nor a torch tensor.
    """
    if isinstance(x, _np.ndarray):
        # float() so the annotated return type holds (np.sqrt yields np.floating).
        return float(_np.sqrt(_np.mean(_np.square(x))))
    elif isinstance(x, _torch.Tensor):
        return _torch.sqrt(_torch.mean(_torch.square(x))).item()
    else:
        raise TypeError(type(x))


def _plot(
    model,
    ds,
    savefig=None,
    show=True,
    window_start: _Optional[int] = None,
    window_end: _Optional[int] = None,
):
    """
    Plot the model's prediction against the dataset's target and report the
    error-to-signal ratio (ESR) in the figure title.

    For a ``ConcatDataset``, recurses over the member datasets, writing one
    figure per member (savefig name gets an ``_{i}`` suffix) and only showing
    the last one.

    :param model: Callable mapping ``ds.x`` to a prediction tensor.
    :param ds: Dataset with ``.x`` (input) and ``.y`` (target), or a
        ``ConcatDataset`` of such datasets.
    :param savefig: Optional path to write the figure to.
    :param show: Whether to display the figure interactively.
    :param window_start: Optional start index of the plotted sample window.
    :param window_end: Optional end index of the plotted sample window.
    """
    if isinstance(ds, _ConcatDataset):

        def extend_savefig(i, savefig):
            # Insert the dataset index before the extension, e.g.
            # "comparison.png" -> "comparison_0.png".
            if savefig is None:
                return None
            savefig = _Path(savefig)
            return savefig.with_name(f"{savefig.stem}_{i}{savefig.suffix}")

        for i, ds_i in enumerate(ds.datasets):
            _plot(
                model,
                ds_i,
                savefig=extend_savefig(i, savefig),
                # Only block on the final member's figure.
                show=show and i == len(ds.datasets) - 1,
                window_start=window_start,
                window_end=window_end,
            )
        return
    with _torch.no_grad():
        # Realtime factor is only a diagnostic printout; fall back to 48 kHz
        # when the dataset doesn't carry a sample rate.
        sample_rate = getattr(ds, "sample_rate", None) or _DEFAULT_SAMPLE_RATE
        tx = len(ds.x) / sample_rate
        print(f"Run (t={tx:.2f})")
        t0 = _time()
        output = model(ds.x).flatten().cpu().numpy()
        t1 = _time()
        try:
            rt = f"{tx / (t1 - t0):.2f}"
        except ZeroDivisionError:
            # Timer resolution can make t1 == t0 on very short runs.
            rt = "???"
        print(f"Took {t1 - t0:.2f} ({rt}x)")

        _plt.figure(figsize=(16, 5))
        _plt.plot(output[window_start:window_end], label="Prediction")
        _plt.plot(ds.y[window_start:window_end], linestyle="--", label="Target")
        # ESR = squared normalized RMSE of prediction vs target.
        nrmse = _rms(_torch.Tensor(output) - ds.y) / _rms(ds.y)
        esr = nrmse**2
        _plt.title(f"ESR={esr:.3f}")
        _plt.legend()
        if savefig is not None:
            _plt.savefig(savefig)
        if show:
            _plt.show()


def _create_callbacks(learning_config):
    """
    Build the Lightning checkpoint callbacks.

    Checkpoints are timed to coincide with validation: if
    ``learning_config["trainer"]["val_check_interval"]`` is set, checkpointing
    is step-based; otherwise it follows ``check_val_every_n_epoch``
    (default 1).

    :param learning_config: Dict with a ``"trainer"`` sub-dict of
        ``pl.Trainer`` keyword arguments.
    :return: List of ``ModelCheckpoint`` callbacks: best-k by ``val_loss``,
        an every-epoch checkpoint, and (for step-based validation) a
        last-validation checkpoint.
    """
    validate_inside_epoch = "val_check_interval" in learning_config["trainer"]
    if validate_inside_epoch:
        kwargs = {
            "every_n_train_steps": learning_config["trainer"]["val_check_interval"]
        }
    else:
        kwargs = {
            "every_n_epochs": learning_config["trainer"].get(
                "check_val_every_n_epoch", 1
            )
        }

    # Keep the top 3 checkpoints as measured by validation loss.
    checkpoint_best = _pl.callbacks.model_checkpoint.ModelCheckpoint(
        filename="{epoch:04d}_{step}_{ESR:.3e}_{MSE:.3e}",
        save_top_k=3,
        monitor="val_loss",
        **kwargs,
    )

    # The last epoch that was finished.
    checkpoint_epoch = _pl.callbacks.model_checkpoint.ModelCheckpoint(
        filename="checkpoint_epoch_{epoch:04d}", every_n_epochs=1
    )
    if not validate_inside_epoch:
        return [checkpoint_best, checkpoint_epoch]
    else:
        # The last validation pass, whether at the end of an epoch or not
        checkpoint_last = _pl.callbacks.model_checkpoint.ModelCheckpoint(
            filename="checkpoint_last_{epoch:04d}_{step}", **kwargs
        )
        return [checkpoint_best, checkpoint_last, checkpoint_epoch]


def main(
    data_config,
    model_config,
    learning_config,
    outdir: _Path,
    no_show: bool = False,
    make_plots=True,
):
    """
    Run a full training session.

    Writes the three configs to ``outdir``, builds the model and datasets,
    trains with PyTorch Lightning, reloads the best checkpoint (if any),
    optionally plots prediction-vs-target comparisons, and exports the
    trained network to ``outdir``.

    :param data_config: Dataset configuration dict (``nam.data.init_dataset``).
    :param model_config: Model configuration dict
        (``LightningModule.init_from_config``).
    :param learning_config: Dict with ``"trainer"``, ``"train_dataloader"``,
        ``"val_dataloader"``, and optional ``"trainer_fit_kwargs"`` entries.
    :param outdir: Existing output directory for configs, checkpoints, plots,
        and the exported model.
    :param no_show: If True, don't display the final comparison plot.
    :param make_plots: If True, produce the comparison plots.
    :raise RuntimeError: If ``outdir`` doesn't exist or the train/validation
        datasets have different sample rates.
    """
    if not outdir.exists():
        raise RuntimeError(f"No output location found at {outdir}")
    # Persist the configs alongside the run's outputs.
    for basename, config in (
        ("data", data_config),
        ("model", model_config),
        ("learning", learning_config),
    ):
        with open(_Path(outdir, f"config_{basename}.json"), "w") as fp:
            _json.dump(config, fp, indent=4)

    model = _LightningModule.init_from_config(model_config)
    # Add receptive field to data config:
    data_config["common"] = data_config.get("common", {})
    if "nx" in data_config["common"]:
        _warn(
            f"Overriding data nx={data_config['common']['nx']} with model required {model.net.receptive_field}"
        )
    data_config["common"]["nx"] = model.net.receptive_field

    dataset_train = _init_dataset(data_config, _Split.TRAIN)
    dataset_validation = _init_dataset(data_config, _Split.VALIDATION)
    if dataset_train.sample_rate != dataset_validation.sample_rate:
        raise RuntimeError(
            "Train and validation data loaders have different data set sample rates: "
            f"{dataset_train.sample_rate}, {dataset_validation.sample_rate}"
        )
    model.net.sample_rate = dataset_train.sample_rate
    train_dataloader = _DataLoader(dataset_train, **learning_config["train_dataloader"])
    val_dataloader = _DataLoader(
        dataset_validation, **learning_config["val_dataloader"]
    )

    trainer = _pl.Trainer(
        callbacks=_create_callbacks(learning_config),
        default_root_dir=outdir,
        **learning_config["trainer"],
    )
    with _filter_warnings("ignore", category=_PossibleUserWarning):
        trainer.fit(
            model,
            train_dataloader,
            val_dataloader,
            **learning_config.get("trainer_fit_kwargs", {}),
        )
    # Go to best checkpoint
    best_checkpoint = trainer.checkpoint_callback.best_model_path
    if best_checkpoint != "":
        model = _LightningModule.load_from_checkpoint(
            trainer.checkpoint_callback.best_model_path,
            **_LightningModule.parse_config(model_config),
        )
    # Evaluate and export on CPU.
    model.cpu()
    model.eval()
    if make_plots:
        _plot(
            model,
            dataset_validation,
            savefig=_Path(outdir, "comparison.png"),
            window_start=100_000,
            window_end=110_000,
            show=False,
        )
        _plot(model, dataset_validation, show=not no_show)
    # Export!
    model.net.export(outdir)