train.py (8565B)
1 import tensorflow as tf 2 from tensorflow.keras import Sequential 3 from tensorflow.keras.layers import LSTM, Conv1D, Dense 4 from tensorflow.keras.optimizers import Adam 5 from tensorflow.keras.backend import clear_session 6 from tensorflow.keras.activations import tanh, elu, relu 7 from tensorflow.keras.models import load_model 8 import tensorflow.keras.backend as K 9 from tensorflow.keras.utils import Sequence 10 11 import os 12 from scipy import signal 13 from scipy.io import wavfile 14 import numpy as np 15 import matplotlib.pyplot as plt 16 import math 17 import h5py 18 import argparse 19 20 21 def pre_emphasis_filter(x, coeff=0.95): 22 return tf.concat([x, x - coeff * x], 1) 23 24 def error_to_signal(y_true, y_pred): 25 """ 26 Error to signal ratio with pre-emphasis filter: 27 """ 28 y_true, y_pred = pre_emphasis_filter(y_true), pre_emphasis_filter(y_pred) 29 return K.sum(tf.pow(y_true - y_pred, 2), axis=0) / K.sum(tf.pow(y_true, 2), axis=0) + 1e-10 30 31 def save_wav(name, data): 32 wavfile.write(name, 44100, data.flatten().astype(np.float32)) 33 34 def normalize(data): 35 data_max = max(data) 36 data_min = min(data) 37 data_norm = max(data_max,abs(data_min)) 38 return data / data_norm 39 40 def main(args): 41 '''Ths is a similar Tensorflow/Keras implementation of the LSTM model from the paper: 42 "Real-Time Guitar Amplifier Emulation with Deep Learning" 43 https://www.mdpi.com/2076-3417/10/3/766/htm 44 45 Uses a stack of two 1-D Convolutional layers, followed by LSTM, followed by 46 a Dense (fully connected) layer. Three preset training modes are available, 47 with further customization by editing the code. A Sequential tf.keras model 48 is implemented here. 49 50 Note: RAM may be a limiting factor for the parameter "input_size". The wav data 51 is preprocessed and stored in RAM, which improves training speed but quickly runs out 52 if using a large number for "input_size". Reduce this if you are experiencing 53 RAM issues. Also, you can use the "--split_data" option to divide the data by the 54 specified amount and train the model on each set. Doing this will allow for a higher 55 input_size setting (more accurate results). 56 57 --training_mode=0 Speed training (default) 58 --training_mode=1 Accuracy training 59 --training_mode=2 Extended training (set max_epochs as desired, for example 50+) 60 ''' 61 62 name = args.name 63 if not os.path.exists('models/'+name): 64 os.makedirs('models/'+name) 65 else: 66 print("A model folder with the same name already exists. Please choose a new name.") 67 return 68 69 train_mode = args.training_mode # 0 = speed training, 70 # 1 = accuracy training 71 # 2 = extended training 72 batch_size = args.batch_size 73 test_size = 0.2 74 epochs = args.max_epochs 75 input_size = args.input_size 76 77 # TRAINING MODE 78 if train_mode == 0: # Speed Training 79 learning_rate = 0.01 80 conv1d_strides = 12 81 conv1d_filters = 16 82 hidden_units = 36 83 elif train_mode == 1: # Accuracy Training (~10x longer than Speed Training) 84 learning_rate = 0.01 85 conv1d_strides = 4 86 conv1d_filters = 36 87 hidden_units= 64 88 else: # Extended Training (~60x longer than Accuracy Training) 89 learning_rate = 0.0005 90 conv1d_strides = 3 91 conv1d_filters = 36 92 hidden_units= 96 93 94 # Create Sequential Model ########################################### 95 clear_session() 96 model = Sequential() 97 model.add(Conv1D(conv1d_filters, 12,strides=conv1d_strides, activation=None, padding='same',input_shape=(input_size,1))) 98 model.add(Conv1D(conv1d_filters, 12,strides=conv1d_strides, activation=None, padding='same')) 99 model.add(LSTM(hidden_units)) 100 model.add(Dense(1, activation=None)) 101 model.compile(optimizer=Adam(learning_rate=learning_rate), loss=error_to_signal, metrics=[error_to_signal]) 102 print(model.summary()) 103 104 # Load and Preprocess Data ########################################### 105 in_rate, in_data = wavfile.read(args.in_file) 106 out_rate, out_data = wavfile.read(args.out_file) 107 108 X_all = in_data.astype(np.float32).flatten() 109 X_all = normalize(X_all).reshape(len(X_all),1) 110 y_all = out_data.astype(np.float32).flatten() 111 y_all = normalize(y_all).reshape(len(y_all),1) 112 113 # If splitting the data for training, do this part 114 if args.split_data > 1: 115 num_split = len(X_all) // args.split_data 116 X = X_all[0:num_split*args.split_data] 117 y = y_all[0:num_split*args.split_data] 118 X_data = np.split(X, args.split_data) 119 y_data = np.split(y, args.split_data) 120 121 # Perform training on each split dataset 122 for i in range(len(X_data)): 123 print("\nTraining on split data " + str(i+1) + "/" +str(len(X_data))) 124 X_split = X_data[i] 125 y_split = y_data[i] 126 127 y_ordered = y_split[input_size-1:] 128 129 indices = np.arange(input_size) + np.arange(len(X_split)-input_size+1)[:,np.newaxis] 130 X_ordered = tf.gather(X_split,indices) 131 132 shuffled_indices = np.random.permutation(len(X_ordered)) 133 X_random = tf.gather(X_ordered,shuffled_indices) 134 y_random = tf.gather(y_ordered, shuffled_indices) 135 136 # Train Model ################################################### 137 model.fit(X_random,y_random, epochs=epochs, batch_size=batch_size, validation_split=0.2) 138 139 140 model.save('models/'+name+'/'+name+'.h5') 141 142 # If training on the full set of input data in one run, do this part 143 else: 144 y_ordered = y_all[input_size-1:] 145 146 indices = np.arange(input_size) + np.arange(len(X_all)-input_size+1)[:,np.newaxis] 147 X_ordered = tf.gather(X_all,indices) 148 149 shuffled_indices = np.random.permutation(len(X_ordered)) 150 X_random = tf.gather(X_ordered,shuffled_indices) 151 y_random = tf.gather(y_ordered, shuffled_indices) 152 153 # Train Model ################################################### 154 model.fit(X_random,y_random, epochs=epochs, batch_size=batch_size, validation_split=test_size) 155 156 model.save('models/'+name+'/'+name+'.h5') 157 158 # Run Prediction ################################################# 159 print("Running prediction..") 160 161 # Get the last 20% of the wav data to run prediction and plot results 162 y_the_rest, y_last_part = np.split(y_all, [int(len(y_all)*.8)]) 163 x_the_rest, x_last_part = np.split(X_all, [int(len(X_all)*.8)]) 164 y_test = y_last_part[input_size-1:] 165 indices = np.arange(input_size) + np.arange(len(x_last_part)-input_size+1)[:,np.newaxis] 166 X_test = tf.gather(x_last_part,indices) 167 168 prediction = model.predict(X_test, batch_size=batch_size) 169 170 save_wav('models/'+name+'/y_pred.wav', prediction) 171 save_wav('models/'+name+'/x_test.wav', x_last_part) 172 save_wav('models/'+name+'/y_test.wav', y_test) 173 174 # Add additional data to the saved model (like input_size) 175 filename = 'models/'+name+'/'+name+'.h5' 176 f = h5py.File(filename, 'a') 177 grp = f.create_group("info") 178 dset = grp.create_dataset("input_size", (1,), dtype='int16') 179 dset[0] = input_size 180 f.close() 181 182 # Create Analysis Plots ########################################### 183 if args.create_plots == 1: 184 print("Plotting results..") 185 import plot 186 187 plot.analyze_pred_vs_actual({ 'output_wav':'models/'+name+'/y_test.wav', 188 'pred_wav':'models/'+name+'/y_pred.wav', 189 'input_wav':'models/'+name+'/x_test.wav', 190 'model_name':name, 191 'show_plots':1, 192 'path':'models/'+name 193 }) 194 195 if __name__ == "__main__": 196 parser = argparse.ArgumentParser() 197 parser.add_argument("in_file") 198 parser.add_argument("out_file") 199 parser.add_argument("name") 200 parser.add_argument("--training_mode", type=int, default=0) 201 parser.add_argument("--batch_size", type=int, default=4096) 202 parser.add_argument("--max_epochs", type=int, default=1) 203 parser.add_argument("--create_plots", type=int, default=1) 204 parser.add_argument("--input_size", type=int, default=100) 205 parser.add_argument("--split_data", type=int, default=1) 206 args = parser.parse_args() 207 main(args)