GuitarLSTM

Deep learning models for guitar amp/pedal emulation using LSTM with Keras
Log | Files | Refs | README

predict.py (2320B)


      1 import tensorflow as tf
      2 from tensorflow.keras.models import load_model
      3 import tensorflow.keras.backend as K
      4 from tensorflow.keras.optimizers import Adam
      5 
      6 import matplotlib.pyplot as plt
      7 import os
      8 from scipy import signal
      9 from scipy.io import wavfile
     10 import numpy as np
     11 import matplotlib.pyplot as plt
     12 import math
     13 import h5py
     14 import argparse
     15 
     16 
     17 def save_wav(name, data):
     18     if name.endswith('.wav') == False:
     19         name = name + '.wav'
     20     wavfile.write(name, 44100, data.flatten().astype(np.float32))
     21     print("Predicted wav file generated: "+name)
     22 
     23 def pre_emphasis_filter(x, coeff=0.95):
     24     return tf.concat([x, x - coeff * x], 1)
     25     
     26 def error_to_signal(y_true, y_pred): 
     27     y_true, y_pred = pre_emphasis_filter(y_true), pre_emphasis_filter(y_pred)
     28     return K.sum(tf.pow(y_true - y_pred, 2), axis=0) / K.sum(tf.pow(y_true, 2), axis=0) + 1e-10
     29 
     30 def normalize(data):
     31     data_max = max(data)
     32     data_min = min(data)
     33     data_norm = max(data_max,abs(data_min))
     34     return data / data_norm
     35 
     36 def predict(args):
     37     '''
     38     Predicts the output wav given an input wav file, trained GuitarLSTM model, 
     39     and output wav filename.
     40     '''
     41     # Read the input_size from the .h5 model file
     42     f = h5py.File(args.model, 'a')
     43     input_size = f["info"]["input_size"][0]
     44     f.close()
     45 
     46     # Load model from .h5 model file
     47     name = args.out_filename
     48     model = load_model(args.model, custom_objects={'error_to_signal' : error_to_signal})
     49     
     50     # Load and Preprocess Data
     51     print("Processing input wav..")
     52     in_rate, in_data = wavfile.read(args.in_file)
     53 
     54     X = in_data.astype(np.float32).flatten()  
     55     X = normalize(X).reshape(len(X),1)   
     56 
     57     indices = np.arange(input_size) + np.arange(len(X)-input_size+1)[:,np.newaxis] 
     58     X_ordered = tf.gather(X,indices) 
     59 
     60     # Run prediction and save output audio as a wav file
     61     print("Running prediction..")
     62     prediction = model.predict(X_ordered, batch_size=args.batch_size)
     63     save_wav(name, prediction)
     64 
     65 if __name__ == "__main__":
     66     parser = argparse.ArgumentParser()
     67     parser.add_argument("in_file")
     68     parser.add_argument("out_filename")
     69     parser.add_argument("model")
     70     parser.add_argument("--train_data", default="data.pickle")
     71     parser.add_argument("--batch_size", type=int, default=4096)
     72     args = parser.parse_args()
     73     predict(args)