plot.py (5816B)
import matplotlib.pyplot as plt
import numpy as np

from scipy.io import wavfile
from scipy import signal  # used by the optional spectrogram block below
import argparse


def error_to_signal(y, y_pred, use_filter=1):
    """
    Error to signal ratio with pre-emphasis filter:
    https://www.mdpi.com/2076-3417/10/3/766/htm
    """
    if use_filter == 1:
        y, y_pred = pre_emphasis_filter(y), pre_emphasis_filter(y_pred)
    # Small constant in the denominator guards against division by zero
    return np.sum(np.power(y - y_pred, 2)) / (np.sum(np.power(y, 2)) + 1e-10)


def pre_emphasis_filter(x, coeff=0.95):
    # First-order high-pass pre-emphasis: y[n] = x[n] - coeff * x[n-1]
    # (the first sample is passed through unchanged)
    return np.concatenate([x[0:1], x[1:] - coeff * x[:-1]])


def read_wave(wav_file):
    # Extract audio samples and framerate from the wav file
    fs, audio = wavfile.read(wav_file)
    return audio, fs


def analyze_pred_vs_actual(args):
    """Generate plots to analyze the predicted signal vs the actual
    signal.

    Inputs:
        output_wav : The actual signal, by default will use y_test.wav from the test.py output
        pred_wav   : The predicted signal, by default will use y_pred.wav from the test.py output
        input_wav  : The pre-effect signal, by default will use x_test.wav from the test.py output
        model_name : Used to add the model name to the plot .png filename
        path       : The save path for generated .png figures
        show_plots : Default is 1 to show plots, 0 to only generate .png files and suppress plots

    1. Plots the two signals
    2. Calculates the error to signal ratio the same way PedalNet evaluates the model during training
    3. Plots the absolute value of pred_signal - actual_signal (to visualize absolute error over time)
    4. (Currently commented out) Plots the spectrogram of (pred_signal - actual_signal)
       to show problem frequencies from the model training
    """
    # args may be an argparse.Namespace (when run from the command line)
    # or a plain dict (when called from another script)
    try:
        output_wav = args.path + '/' + args.output_wav
        pred_wav = args.path + '/' + args.pred_wav
        input_wav = args.path + '/' + args.input_wav
        model_name = args.model_name
        show_plots = args.show_plots
        path = args.path
    except AttributeError:
        output_wav = args['output_wav']
        pred_wav = args['pred_wav']
        input_wav = args['input_wav']
        model_name = args['model_name']
        show_plots = args['show_plots']
        path = args['path']

    # Read the input wav file
    signal3, fs3 = read_wave(input_wav)

    # Read the output wav file
    signal1, fs = read_wave(output_wav)

    Time = np.linspace(0, len(signal1) / fs, num=len(signal1))
    fig, (ax3, ax1, ax2) = plt.subplots(3, sharex=True, figsize=(13, 8))
    fig.suptitle("Predicted vs Actual Signal")
    ax1.plot(Time, signal1, label=output_wav, color="red")

    # Read the predicted wav file
    signal2, fs2 = read_wave(pred_wav)

    Time2 = np.linspace(0, len(signal2) / fs2, num=len(signal2))
    ax1.plot(Time2, signal2, label=pred_wav, color="green")
    ax1.legend(loc="upper right")
    ax1.set_xlabel("Time (s)")
    ax1.set_ylabel("Amplitude")
    ax1.set_title("Wav File Comparison")
    ax1.grid("on")

    # Absolute sample-by-sample error between the actual and predicted signals
    error_list = []
    for s1, s2 in zip(signal1, signal2):
        error_list.append(abs(s2 - s1))

    # Calculate error to signal ratio with pre-emphasis filter as
    # used to train the model
    e2s = error_to_signal(signal1, signal2)
    e2s_no_filter = error_to_signal(signal1, signal2, use_filter=0)
    print("Error to signal (with pre-emphasis filter): ", e2s)
    print("Error to signal (no pre-emphasis filter): ", e2s_no_filter)
    fig.suptitle("Predicted vs Actual Signal (error to signal: " + str(round(e2s, 4)) + ")")

    # Plot signal difference
    signal_diff = signal2 - signal1
    ax2.plot(Time2, error_list, label="signal diff", color="blue")
    ax2.set_xlabel("Time (s)")
    ax2.set_ylabel("Amplitude")
    ax2.set_title("abs(pred_signal - actual_signal)")
    ax2.grid("on")

    # Plot the original signal
    Time3 = np.linspace(0, len(signal3) / fs3, num=len(signal3))
    ax3.plot(Time3, signal3, label=input_wav, color="purple")
    ax3.legend(loc="upper right")
    ax3.set_xlabel("Time (s)")
    ax3.set_ylabel("Amplitude")
    ax3.set_title("Original Input")
    ax3.grid("on")

    # Save the plot
    plt.savefig(path + '/' + model_name + "_signal_comparison_e2s_" + str(round(e2s, 4)) + ".png", bbox_inches="tight")

    # Create a zoomed in plot of 0.01 seconds centered at the maximum of the actual signal
    sig_temp = signal1.tolist()
    plt.axis(
        [
            Time3[sig_temp.index(max(sig_temp))] - 0.005,
            Time3[sig_temp.index(max(sig_temp))] + 0.005,
            min(signal2),
            max(signal2),
        ]
    )
    plt.savefig(path + '/' + model_name + "_Detail_signal_comparison_e2s_" + str(round(e2s, 4)) + ".png", bbox_inches="tight")

    # Reset the axis
    plt.axis([0, Time3[-1], min(signal2), max(signal2)])

    # Plot spectrogram of the signal difference (disabled)
    # plt.figure(figsize=(12, 8))
    # print("Creating spectrogram data..")
    # frequencies, times, spectrogram = signal.spectrogram(signal_diff, 44100)
    # plt.pcolormesh(times, frequencies, 10 * np.log10(spectrogram))
    # plt.colorbar()
    # plt.title("Diff Spectrogram")
    # plt.ylabel("Frequency [Hz]")
    # plt.xlabel("Time [sec]")
    # plt.savefig(path + '/' + model_name + "_diff_spectrogram.png", bbox_inches="tight")

    if show_plots == 1:
        plt.show()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--output_wav", default="y_test.wav")
    parser.add_argument("--pred_wav", default="y_pred.wav")
    parser.add_argument("--input_wav", default="x_test.wav")
    parser.add_argument("--model_name", default="plot")
    parser.add_argument("--path", default=".")
    parser.add_argument("--show_plots", type=int, default=1)
    args = parser.parse_args()
    analyze_pred_vs_actual(args)