guitar_lstm_colab.ipynb (8988B)
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# GuitarLSTM: train a guitar amp/pedal model in Colab\n",
    "\n",
    "**To use:**\n",
    "1. Upload your input and output wav files to the current directory in Colab.\n",
    "2. Edit the USER INPUTS cell to point to your wav files, and choose a model name and number of epochs for training.\n",
    "3. Run each cell in order. The trained model and the output wav files will be added to the `models/{name}` directory.\n",
    "\n",
    "Notes:\n",
    "- Tested on CPU and GPU runtimes.\n",
    "- Uses MSE for the training loss instead of Error to Signal with Pre-emphasis filter; the error-to-signal ratio is only reported as a metric.\n",
    "- The input and output files must have the same sample rate. Only the first channel of a multi-channel file is used."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "RF2uyPfxgi8H"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import math\n",
    "\n",
    "import h5py\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy import signal\n",
    "from scipy.io import wavfile\n",
    "\n",
    "import tensorflow as tf\n",
    "import tensorflow.keras.backend as K\n",
    "from tensorflow.keras import Sequential\n",
    "from tensorflow.keras.layers import LSTM, Conv1D, Dense\n",
    "from tensorflow.keras.optimizers import Adam\n",
    "from tensorflow.keras.backend import clear_session\n",
    "from tensorflow.keras.activations import tanh, elu, relu\n",
    "from tensorflow.keras.models import load_model\n",
    "from tensorflow.keras.utils import Sequence"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "U22mDBe4jaf2"
   },
   "outputs": [],
   "source": [
    "# EDIT THIS SECTION FOR USER INPUTS\n",
    "#\n",
    "name = 'test'\n",
    "in_file = 'data/ts9_test1_in_FP32.wav'\n",
    "out_file = 'data/ts9_test1_out_FP32.wav'\n",
    "epochs = 1\n",
    "\n",
    "train_mode = 0     # 0 = speed training,\n",
    "                   # 1 = accuracy training\n",
    "                   # 2 = extended training\n",
    "\n",
    "input_size = 150   # number of input samples in each window fed to the model\n",
    "\n",
    "seed = 42          # TensorFlow global random seed; set to None to skip seeding\n",
    "\n",
    "# Refuse to overwrite an existing model. Raise (rather than evaluating a bare `exit`)\n",
    "# so execution reliably stops here, including under Run All.\n",
    "model_dir = os.path.join('models', name)\n",
    "if os.path.exists(model_dir):\n",
    "    raise FileExistsError('A model with the same name already exists. Please choose a new name.')\n",
    "os.makedirs(model_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class WindowArray(Sequence):\n",
    "    \"\"\"Keras Sequence serving sliding windows of `x` paired with targets from `y`.\n",
    "\n",
    "    Window i is x[i : i + window_len] and its target is y[i + window_len - 1], the\n",
    "    output sample aligned with the last input sample of the window. Windows are\n",
    "    built per batch, so the full windowed dataset is never held in memory.\n",
    "\n",
    "    Args:\n",
    "        x: input signal indexed along axis 0 (shape (n_samples, 1) in this notebook).\n",
    "        y: target signal, sample-aligned with x and at least as long as x.\n",
    "        window_len: number of input samples per window.\n",
    "        batch_size: windows per batch; the last batch may be smaller.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, x, y, window_len, batch_size=32):\n",
    "        self.x = x\n",
    "        self.y = y[window_len-1:]\n",
    "        self.window_len = window_len\n",
    "        self.batch_size = batch_size\n",
    "        # Number of full windows that fit in x (0 if x is shorter than one window).\n",
    "        self.num_windows = max(0, len(x) - window_len + 1)\n",
    "\n",
    "    def __len__(self):\n",
    "        # Round up so the trailing partial batch is kept (flooring would drop up to\n",
    "        # batch_size - 1 windows and leave the prediction shorter than its target).\n",
    "        return math.ceil(self.num_windows / self.batch_size)\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        \"\"\"Return (windows, targets) for batch `index`, shaped (n, window_len, ...) and (n, ...).\"\"\"\n",
    "        start = index * self.batch_size\n",
    "        stop = min(start + self.batch_size, self.num_windows)\n",
    "        x_out = np.stack([self.x[idx: idx+self.window_len] for idx in range(start, stop)])\n",
    "        y_out = self.y[start:stop]\n",
    "        return x_out, y_out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "WqI-cGt1jaG2"
   },
   "outputs": [],
   "source": [
    "def pre_emphasis_filter(x, coeff=0.95):\n",
    "    \"\"\"First-order pre-emphasis (high-pass) filter: y[n] = x[n] - coeff * x[n-1].\n",
    "\n",
    "    Applied along axis 0. Each WindowArray batch holds consecutive target samples,\n",
    "    so within a batch this is the time axis. The result is one sample shorter than x.\n",
    "    (Note: x - coeff * x would merely scale the signal by 1 - coeff, which an\n",
    "    error-to-signal ratio ignores.)\n",
    "    \"\"\"\n",
    "    return x[1:] - coeff * x[:-1]\n",
    "\n",
    "def error_to_signal(y_true, y_pred):\n",
    "    \"\"\"\n",
    "    Error to signal ratio with pre-emphasis filter:\n",
    "    sum((y_true - y_pred)^2) / sum(y_true^2) over the pre-emphasized samples of a batch.\n",
    "    The 1e-10 term avoids division by zero on silent batches.\n",
    "    \"\"\"\n",
    "    y_true, y_pred = pre_emphasis_filter(y_true), pre_emphasis_filter(y_pred)\n",
    "    return K.sum(tf.pow(y_true - y_pred, 2), axis=0) / (K.sum(tf.pow(y_true, 2), axis=0) + 1e-10)\n",
    "\n",
    "def save_wav(name, data, rate=44100):\n",
    "    \"\"\"Write `data` (flattened) as a 32-bit float wav file sampled at `rate` Hz.\"\"\"\n",
    "    wavfile.write(name, rate, data.flatten().astype(np.float32))\n",
    "\n",
    "def normalize(data):\n",
    "    \"\"\"Scale `data` so its peak absolute value is 1.0; all-zero input is returned as-is.\"\"\"\n",
    "    peak = np.max(np.abs(data))\n",
    "    if peak == 0:\n",
    "        # Silent input: avoid dividing by zero (which would produce NaNs).\n",
    "        return data\n",
    "    return data / peak"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model\n",
    "\n",
    "This is a similar Tensorflow/Keras implementation of the LSTM model from the paper\n",
    "*Real-Time Guitar Amplifier Emulation with Deep Learning*\n",
    "(https://www.mdpi.com/2076-3417/10/3/766/htm).\n",
    "\n",
    "It uses a stack of two 1-D convolutional layers, followed by an LSTM, followed by a\n",
    "Dense (fully connected) layer, implemented as a Sequential tf.keras model. Three preset\n",
    "training modes are selected with `train_mode` in the user inputs cell; further\n",
    "customization is possible by editing the code below.\n",
    "\n",
    "- `train_mode = 0`: speed training (default)\n",
    "- `train_mode = 1`: accuracy training (~10x longer than speed training)\n",
    "- `train_mode = 2`: extended training (~60x longer than accuracy training; set `epochs` as desired, for example 50+)\n",
    "\n",
    "Note: RAM may be a limiting factor for `input_size`. Input windows are built one batch at a\n",
    "time (see `WindowArray`), so each batch holds `batch_size` x `input_size` samples; reduce\n",
    "either if you are experiencing RAM issues."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 4096\n",
    "test_size = 0.2   # fraction of the audio, taken from the end, held out for validation and prediction\n",
    "\n",
    "if train_mode == 0:      # Speed Training\n",
    "    learning_rate = 0.01\n",
    "    conv1d_strides = 12\n",
    "    conv1d_filters = 16\n",
    "    hidden_units = 36\n",
    "elif train_mode == 1:    # Accuracy Training (~10x longer than Speed Training)\n",
    "    learning_rate = 0.01\n",
    "    conv1d_strides = 4\n",
    "    conv1d_filters = 36\n",
    "    hidden_units = 64\n",
    "else:                    # Extended Training (~60x longer than Accuracy Training)\n",
    "    learning_rate = 0.0005\n",
    "    conv1d_strides = 3\n",
    "    conv1d_filters = 36\n",
    "    hidden_units = 96"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create Sequential Model ###########################################\n",
    "clear_session()\n",
    "if seed is not None:\n",
    "    tf.random.set_seed(seed)\n",
    "\n",
    "model = Sequential()\n",
    "model.add(Conv1D(conv1d_filters, 12, strides=conv1d_strides, activation=None, padding='same', input_shape=(input_size, 1)))\n",
    "model.add(Conv1D(conv1d_filters, 12, strides=conv1d_strides, activation=None, padding='same'))\n",
    "model.add(LSTM(hidden_units))\n",
    "model.add(Dense(1, activation=None))\n",
    "model.compile(optimizer=Adam(learning_rate=learning_rate), loss='mse', metrics=[error_to_signal])\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load and Preprocess Data ###########################################\n",
    "in_rate, in_data = wavfile.read(in_file)\n",
    "out_rate, out_data = wavfile.read(out_file)\n",
    "\n",
    "if in_rate != out_rate:\n",
    "    raise ValueError('Input and output wav files must have the same sample rate '\n",
    "                     '(got {} Hz and {} Hz).'.format(in_rate, out_rate))\n",
    "\n",
    "# flatten() would interleave the channels of a multi-channel file; keep only the first one.\n",
    "if in_data.ndim > 1:\n",
    "    in_data = in_data[:, 0]\n",
    "if out_data.ndim > 1:\n",
    "    out_data = out_data[:, 0]\n",
    "\n",
    "# Use a common length so input and target stay sample-aligned and are split at the same index.\n",
    "num_samples = min(len(in_data), len(out_data))\n",
    "if len(in_data) != len(out_data):\n",
    "    print('Warning: input has {} samples and output has {}; using the first {} of each.'.format(\n",
    "        len(in_data), len(out_data), num_samples))\n",
    "\n",
    "X_all = in_data[:num_samples].astype(np.float32).flatten()\n",
    "X_all = normalize(X_all).reshape(len(X_all), 1)\n",
    "y_all = out_data[:num_samples].astype(np.float32).flatten()\n",
    "y_all = normalize(y_all).reshape(len(y_all), 1)\n",
    "\n",
    "# Train on the first part of the audio; hold out the last `test_size` fraction for validation.\n",
    "split_idx = int(len(X_all) * (1 - test_size))\n",
    "train_arr = WindowArray(X_all[:split_idx], y_all[:split_idx], input_size, batch_size=batch_size)\n",
    "val_arr = WindowArray(X_all[split_idx:], y_all[split_idx:], input_size, batch_size=batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train Model ###################################################\n",
    "history = model.fit(train_arr, validation_data=val_arr, epochs=epochs, shuffle=True)\n",
    "model.save(os.path.join(model_dir, name + '.h5'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run Prediction #################################################\n",
    "print('Running prediction..')\n",
    "\n",
    "# Predict on the held-out last `test_size` fraction of the audio and save wav files for comparison.\n",
    "x_last_part = X_all[split_idx:]\n",
    "y_last_part = y_all[split_idx:]\n",
    "y_test = y_last_part[input_size-1:]   # target for each input window (see WindowArray)\n",
    "test_arr = WindowArray(x_last_part, y_last_part, input_size, batch_size=batch_size)\n",
    "\n",
    "prediction = model.predict(test_arr)\n",
    "\n",
    "save_wav(os.path.join(model_dir, 'y_pred.wav'), prediction, in_rate)\n",
    "save_wav(os.path.join(model_dir, 'x_test.wav'), x_last_part, in_rate)\n",
    "save_wav(os.path.join(model_dir, 'y_test.wav'), y_test, in_rate)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Add additional data to the saved model (like input_size)\n",
    "filename = os.path.join(model_dir, name + '.h5')\n",
    "with h5py.File(filename, 'a') as f:\n",
    "    grp = f.create_group('info')\n",
    "    dset = grp.create_dataset('input_size', (1,), dtype='int16')\n",
    "    dset[0] = input_size"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "name": "guitar_lstm_colab.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}