predict.py (2320B)
import tensorflow as tf
from tensorflow.keras.models import load_model
import tensorflow.keras.backend as K
from tensorflow.keras.optimizers import Adam

import matplotlib.pyplot as plt
import os
from scipy import signal
from scipy.io import wavfile
import numpy as np
import math
import h5py
import argparse


def save_wav(name, data, rate=44100):
    """Write ``data`` to ``name`` as a mono 32-bit float WAV file.

    Args:
        name: Output path; ``.wav`` is appended when missing.
        data: Array-like of samples; it is flattened to 1-D before writing.
        rate: Sample rate in Hz stored in the WAV header (default 44100).
    """
    if not name.endswith('.wav'):
        name = name + '.wav'
    wavfile.write(name, rate, np.asarray(data).flatten().astype(np.float32))
    print("Predicted wav file generated: " + name)


def pre_emphasis_filter(x, coeff=0.95):
    """Return ``x`` concatenated with ``x - coeff * x`` along axis 1.

    NOTE(review): ``x - coeff * x`` equals ``(1 - coeff) * x``. A conventional
    first-order pre-emphasis filter is ``y[n] = x[n] - coeff * x[n - 1]``.
    The definition is left unchanged here because it is only needed so that
    ``load_model`` can resolve the ``error_to_signal`` loss; it presumably
    mirrors the training-time definition. Confirm against the training
    script before changing it, since doing so alters the loss.
    """
    return tf.concat([x, x - coeff * x], 1)


def error_to_signal(y_true, y_pred):
    """Error-to-signal ratio loss, registered as a custom object for loading.

    Both tensors are first passed through ``pre_emphasis_filter``; the result
    is ``sum((y_true - y_pred)**2) / (sum(y_true**2) + 1e-10)`` over axis 0.
    """
    y_true, y_pred = pre_emphasis_filter(y_true), pre_emphasis_filter(y_pred)
    numerator = K.sum(tf.pow(y_true - y_pred, 2), axis=0)
    # The epsilon belongs in the denominator so an all-zero target cannot
    # cause a division by zero (it was previously added to the quotient).
    denominator = K.sum(tf.pow(y_true, 2), axis=0) + 1e-10
    return numerator / denominator


def normalize(data):
    """Scale ``data`` so that its peak absolute value becomes 1.

    Args:
        data: NumPy array of samples.

    Returns:
        ``data / max(|data|)``, or an unscaled copy of ``data`` when every
        sample is zero (avoids a 0/0 that would yield NaN).

    Raises:
        ValueError: if ``data`` is empty.
    """
    # Vectorized peak search; the builtin max/min iterated element-by-element
    # in Python, which is slow for audio-length arrays.
    peak = np.max(np.abs(data))
    if peak == 0:
        return np.array(data, copy=True)
    return data / peak


def predict(args):
    '''
    Predicts the output wav given an input wav file, trained GuitarLSTM model,
    and output wav filename.
    '''
    # Read the input_size from the .h5 model file. Open read-only and close
    # deterministically; 'a' mode needlessly opened the model for writing.
    with h5py.File(args.model, 'r') as f:
        input_size = int(f["info"]["input_size"][0])

    # Load model from .h5 model file
    name = args.out_filename
    model = load_model(args.model, custom_objects={'error_to_signal': error_to_signal})

    # Load and Preprocess Data
    print("Processing input wav..")
    in_rate, in_data = wavfile.read(args.in_file)

    # NOTE(review): flatten() interleaves channels for multi-channel files;
    # this assumes a mono input wav — confirm.
    X = in_data.astype(np.float32).flatten()
    if len(X) < input_size:
        raise ValueError(
            "Input wav has %d samples; the model needs at least input_size=%d"
            % (len(X), input_size)
        )
    X = normalize(X).reshape(len(X), 1)

    # Row i of `indices` is [i, i+1, ..., i+input_size-1], i.e. one sliding
    # window per output sample; the gathered tensor has shape
    # (len(X) - input_size + 1, input_size, 1).
    indices = np.arange(input_size) + np.arange(len(X) - input_size + 1)[:, np.newaxis]
    X_ordered = tf.gather(X, indices)

    # Run prediction and save output audio as a wav file. Each output sample
    # aligns with an input sample, so reuse the input's sample rate.
    print("Running prediction..")
    prediction = model.predict(X_ordered, batch_size=args.batch_size)
    save_wav(name, prediction, rate=in_rate)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("in_file")
    parser.add_argument("out_filename")
    parser.add_argument("model")
    # Accepted for CLI compatibility; predict() does not read it.
    parser.add_argument("--train_data", default="data.pickle")
    parser.add_argument("--batch_size", type=int, default=4096)
    args = parser.parse_args()
    predict(args)