GuitarLSTM

Deep learning models for guitar amp/pedal emulation using LSTM with Keras
Log | Files | Refs | README

train.py (8565B)


      1 import tensorflow as tf
      2 from tensorflow.keras import Sequential
      3 from tensorflow.keras.layers import LSTM, Conv1D, Dense
      4 from tensorflow.keras.optimizers import Adam
      5 from tensorflow.keras.backend import clear_session
      6 from tensorflow.keras.activations import tanh, elu, relu
      7 from tensorflow.keras.models import load_model
      8 import tensorflow.keras.backend as K
      9 from tensorflow.keras.utils import Sequence
     10 
     11 import os
     12 from scipy import signal
     13 from scipy.io import wavfile
     14 import numpy as np
     15 import matplotlib.pyplot as plt
     16 import math
     17 import h5py
     18 import argparse
     19 
     20    
     21 def pre_emphasis_filter(x, coeff=0.95):
     22     return tf.concat([x, x - coeff * x], 1)
     23     
     24 def error_to_signal(y_true, y_pred): 
     25     """
     26     Error to signal ratio with pre-emphasis filter:
     27     """
     28     y_true, y_pred = pre_emphasis_filter(y_true), pre_emphasis_filter(y_pred)
     29     return K.sum(tf.pow(y_true - y_pred, 2), axis=0) / K.sum(tf.pow(y_true, 2), axis=0) + 1e-10
     30     
     31 def save_wav(name, data):
     32     wavfile.write(name, 44100, data.flatten().astype(np.float32))
     33 
     34 def normalize(data):
     35     data_max = max(data)
     36     data_min = min(data)
     37     data_norm = max(data_max,abs(data_min))
     38     return data / data_norm
     39 
     40 def main(args):
     41     '''Ths is a similar Tensorflow/Keras implementation of the LSTM model from the paper:
     42         "Real-Time Guitar Amplifier Emulation with Deep Learning"
     43         https://www.mdpi.com/2076-3417/10/3/766/htm
     44 
     45         Uses a stack of two 1-D Convolutional layers, followed by LSTM, followed by 
     46         a Dense (fully connected) layer. Three preset training modes are available, 
     47         with further customization by editing the code. A Sequential tf.keras model 
     48         is implemented here.
     49 
     50         Note: RAM may be a limiting factor for the parameter "input_size". The wav data
     51           is preprocessed and stored in RAM, which improves training speed but quickly runs out
     52           if using a large number for "input_size".  Reduce this if you are experiencing
     53           RAM issues. Also, you can use the "--split_data" option to divide the data by the
     54           specified amount and train the model on each set. Doing this will allow for a higher
     55           input_size setting (more accurate results).
     56         
     57         --training_mode=0   Speed training (default)
     58         --training_mode=1   Accuracy training
     59         --training_mode=2   Extended training (set max_epochs as desired, for example 50+)
     60     '''
     61 
     62     name = args.name
     63     if not os.path.exists('models/'+name):
     64         os.makedirs('models/'+name)
     65     else:
     66         print("A model folder with the same name already exists. Please choose a new name.")
     67         return
     68 
     69     train_mode = args.training_mode     # 0 = speed training, 
     70                                         # 1 = accuracy training 
     71                                         # 2 = extended training
     72     batch_size = args.batch_size 
     73     test_size = 0.2
     74     epochs = args.max_epochs
     75     input_size = args.input_size
     76 
     77     # TRAINING MODE
     78     if train_mode == 0:         # Speed Training
     79         learning_rate = 0.01 
     80         conv1d_strides = 12    
     81         conv1d_filters = 16
     82         hidden_units = 36
     83     elif train_mode == 1:       # Accuracy Training (~10x longer than Speed Training)
     84         learning_rate = 0.01 
     85         conv1d_strides = 4
     86         conv1d_filters = 36
     87         hidden_units= 64
     88     else:                       # Extended Training (~60x longer than Accuracy Training)
     89         learning_rate = 0.0005 
     90         conv1d_strides = 3
     91         conv1d_filters = 36
     92         hidden_units= 96
     93 
     94     # Create Sequential Model ###########################################
     95     clear_session()
     96     model = Sequential()
     97     model.add(Conv1D(conv1d_filters, 12,strides=conv1d_strides, activation=None, padding='same',input_shape=(input_size,1)))
     98     model.add(Conv1D(conv1d_filters, 12,strides=conv1d_strides, activation=None, padding='same'))
     99     model.add(LSTM(hidden_units))
    100     model.add(Dense(1, activation=None))
    101     model.compile(optimizer=Adam(learning_rate=learning_rate), loss=error_to_signal, metrics=[error_to_signal])
    102     print(model.summary())
    103 
    104     # Load and Preprocess Data ###########################################
    105     in_rate, in_data = wavfile.read(args.in_file)
    106     out_rate, out_data = wavfile.read(args.out_file)
    107 
    108     X_all = in_data.astype(np.float32).flatten()  
    109     X_all = normalize(X_all).reshape(len(X_all),1)   
    110     y_all = out_data.astype(np.float32).flatten() 
    111     y_all = normalize(y_all).reshape(len(y_all),1)   
    112 
    113     # If splitting the data for training, do this part
    114     if args.split_data > 1:
    115         num_split = len(X_all) // args.split_data
    116         X = X_all[0:num_split*args.split_data]
    117         y = y_all[0:num_split*args.split_data]
    118         X_data = np.split(X, args.split_data)
    119         y_data = np.split(y, args.split_data)
    120 
    121         # Perform training on each split dataset
    122         for i in range(len(X_data)):
    123             print("\nTraining on split data " + str(i+1) + "/" +str(len(X_data)))
    124             X_split = X_data[i]
    125             y_split = y_data[i]
    126 
    127             y_ordered = y_split[input_size-1:] 
    128 
    129             indices = np.arange(input_size) + np.arange(len(X_split)-input_size+1)[:,np.newaxis] 
    130             X_ordered = tf.gather(X_split,indices) 
    131 
    132             shuffled_indices = np.random.permutation(len(X_ordered)) 
    133             X_random = tf.gather(X_ordered,shuffled_indices)
    134             y_random = tf.gather(y_ordered, shuffled_indices)
    135 
    136             # Train Model ###################################################
    137             model.fit(X_random,y_random, epochs=epochs, batch_size=batch_size, validation_split=0.2)  
    138  
    139 
    140         model.save('models/'+name+'/'+name+'.h5')
    141 
    142     # If training on the full set of input data in one run, do this part
    143     else:
    144         y_ordered = y_all[input_size-1:] 
    145 
    146         indices = np.arange(input_size) + np.arange(len(X_all)-input_size+1)[:,np.newaxis] 
    147         X_ordered = tf.gather(X_all,indices) 
    148 
    149         shuffled_indices = np.random.permutation(len(X_ordered)) 
    150         X_random = tf.gather(X_ordered,shuffled_indices)
    151         y_random = tf.gather(y_ordered, shuffled_indices)
    152 
    153         # Train Model ###################################################
    154         model.fit(X_random,y_random, epochs=epochs, batch_size=batch_size, validation_split=test_size)    
    155 
    156         model.save('models/'+name+'/'+name+'.h5')
    157 
    158     # Run Prediction #################################################
    159     print("Running prediction..")
    160 
    161     # Get the last 20% of the wav data to run prediction and plot results
    162     y_the_rest, y_last_part = np.split(y_all, [int(len(y_all)*.8)])
    163     x_the_rest, x_last_part = np.split(X_all, [int(len(X_all)*.8)])
    164     y_test = y_last_part[input_size-1:] 
    165     indices = np.arange(input_size) + np.arange(len(x_last_part)-input_size+1)[:,np.newaxis] 
    166     X_test = tf.gather(x_last_part,indices) 
    167 
    168     prediction = model.predict(X_test, batch_size=batch_size)
    169 
    170     save_wav('models/'+name+'/y_pred.wav', prediction)
    171     save_wav('models/'+name+'/x_test.wav', x_last_part)
    172     save_wav('models/'+name+'/y_test.wav', y_test)
    173 
    174     # Add additional data to the saved model (like input_size)
    175     filename = 'models/'+name+'/'+name+'.h5'
    176     f = h5py.File(filename, 'a')
    177     grp = f.create_group("info")
    178     dset = grp.create_dataset("input_size", (1,), dtype='int16')
    179     dset[0] = input_size
    180     f.close()
    181 
    182     # Create Analysis Plots ###########################################
    183     if args.create_plots == 1:
    184         print("Plotting results..")
    185         import plot
    186 
    187         plot.analyze_pred_vs_actual({   'output_wav':'models/'+name+'/y_test.wav',
    188                                             'pred_wav':'models/'+name+'/y_pred.wav', 
    189                                             'input_wav':'models/'+name+'/x_test.wav',
    190                                             'model_name':name,
    191                                             'show_plots':1,
    192                                             'path':'models/'+name
    193                                         })
    194 
    195 if __name__ == "__main__":
    196     parser = argparse.ArgumentParser()
    197     parser.add_argument("in_file")
    198     parser.add_argument("out_file")
    199     parser.add_argument("name")
    200     parser.add_argument("--training_mode", type=int, default=0)
    201     parser.add_argument("--batch_size", type=int, default=4096)
    202     parser.add_argument("--max_epochs", type=int, default=1)
    203     parser.add_argument("--create_plots", type=int, default=1)
    204     parser.add_argument("--input_size", type=int, default=100)
    205     parser.add_argument("--split_data", type=int, default=1)
    206     args = parser.parse_args()
    207     main(args)