GuitarLSTM

Deep learning models for guitar amp/pedal emulation using LSTM with Keras
Log | Files | Refs | README

plot.py (5816B)


import argparse
import os
import struct
import sys

import matplotlib.pyplot as plt
import numpy as np
from scipy import signal
from scipy.io import wavfile
      9 
     10 
     11 def error_to_signal(y, y_pred, use_filter=1):
     12     """
     13     Error to signal ratio with pre-emphasis filter:
     14     https://www.mdpi.com/2076-3417/10/3/766/htm
     15     """
     16     if use_filter == 1:
     17         y, y_pred = pre_emphasis_filter(y), pre_emphasis_filter(y_pred)
     18     return np.sum(np.power(y - y_pred, 2)) / (np.sum(np.power(y, 2) + 1e-10))
     19 
     20 
     21 def pre_emphasis_filter(x, coeff=0.95):
     22     return np.concatenate([x, np.subtract(x, np.multiply(x, coeff))])
     23 
     24 
     25 def read_wave(wav_file):
     26     # Extract Audio and framerate from Wav File
     27     fs, signal = wavfile.read(wav_file)
     28     return signal, fs
     29 
     30 
     31 def analyze_pred_vs_actual(args):
     32     """Generate plots to analyze the predicted signal vs the actual
     33     signal.
     34 
     35     Inputs:
     36         output_wav : The actual signal, by default will use y_test.wav from the test.py output
     37         pred_wav : The predicted signal, by default will use y_pred.wav from the test.py output
     38         input_wav : The pre effect signal, by default will use x_test.wav from the test.py output
     39         model_name : Used to add the model name to the plot .png filename
     40         path   :   The save path for generated .png figures
     41         show_plots : Default is 1 to show plots, 0 to only generate .png files and suppress plots
     42 
     43     1. Plots the two signals
     44     2. Calculates Error to signal ratio the same way Pedalnet evauluates the model for training
     45     3. Plots the absolute value of pred_signal - actual_signal  (to visualize abs error over time)
     46     4. Plots the spectrogram of (pred_signal - actual signal)
     47          The idea here is to show problem frequencies from the model training
     48     """
     49     try:
     50         output_wav = args.path + '/' + args.output_wav
     51         pred_wav = args.path + '/' + args.pred_wav
     52         input_wav = args.path + '/' + args.input_wav
     53         model_name = args.model_name
     54         show_plots = args.show_plots
     55         path = args.path
     56     except:
     57         output_wav = args['output_wav']
     58         pred_wav = args['pred_wav']
     59         input_wav = args['input_wav']
     60         model_name = args['model_name']
     61         show_plots = args['show_plots']
     62         path = args['path']
     63 
     64     # Read the input wav file
     65     signal3, fs3 = read_wave(input_wav)
     66 
     67     # Read the output wav file
     68     signal1, fs = read_wave(output_wav)
     69 
     70     Time = np.linspace(0, len(signal1) / fs, num=len(signal1))
     71     fig, (ax3, ax1, ax2) = plt.subplots(3, sharex=True, figsize=(13, 8))
     72     fig.suptitle("Predicted vs Actual Signal")
     73     ax1.plot(Time, signal1, label=output_wav, color="red")
     74 
     75     # Read the predicted wav file
     76     signal2, fs2 = read_wave(pred_wav)
     77 
     78     Time2 = np.linspace(0, len(signal2) / fs2, num=len(signal2))
     79     ax1.plot(Time2, signal2, label=pred_wav, color="green")
     80     ax1.legend(loc="upper right")
     81     ax1.set_xlabel("Time (s)")
     82     ax1.set_ylabel("Amplitude")
     83     ax1.set_title("Wav File Comparison")
     84     ax1.grid("on")
     85 
     86     error_list = []
     87     for s1, s2 in zip(signal1, signal2):
     88         error_list.append(abs(s2 - s1))
     89 
     90     # Calculate error to signal ratio with pre-emphasis filter as
     91     #    used to train the model
     92     e2s = error_to_signal(signal1, signal2)
     93     e2s_no_filter = error_to_signal(signal1, signal2, use_filter=0)
     94     print("Error to signal (with pre-emphasis filter): ", e2s)
     95     print("Error to signal (no pre-emphasis filter): ", e2s_no_filter)
     96     fig.suptitle("Predicted vs Actual Signal (error to signal: " + str(round(e2s, 4)) + ")")
     97     # Plot signal difference
     98     signal_diff = signal2 - signal1
     99     ax2.plot(Time2, error_list, label="signal diff", color="blue")
    100     ax2.set_xlabel("Time (s)")
    101     ax2.set_ylabel("Amplitude")
    102     ax2.set_title("abs(pred_signal-actual_signal)")
    103     ax2.grid("on")
    104 
    105     # Plot the original signal
    106     Time3 = np.linspace(0, len(signal3) / fs3, num=len(signal3))
    107     ax3.plot(Time3, signal3, label=input_wav, color="purple")
    108     ax3.legend(loc="upper right")
    109     ax3.set_xlabel("Time (s)")
    110     ax3.set_ylabel("Amplitude")
    111     ax3.set_title("Original Input")
    112     ax3.grid("on")
    113 
    114     # Save the plot
    115     plt.savefig(path+'/'+model_name + "_signal_comparison_e2s_" + str(round(e2s, 4)) + ".png", bbox_inches="tight")
    116 
    117     # Create a zoomed in plot of 0.01 seconds centered at the max input signal value
    118     sig_temp = signal1.tolist()
    119     plt.axis(
    120         [
    121             Time3[sig_temp.index((max(sig_temp)))] - 0.005,
    122             Time3[sig_temp.index((max(sig_temp)))] + 0.005,
    123             min(signal2),
    124             max(signal2),
    125         ]
    126     )
    127     plt.savefig(path+'/'+model_name + "_Detail_signal_comparison_e2s_" + str(round(e2s, 4)) + ".png", bbox_inches="tight")
    128 
    129     # Reset the axis
    130     plt.axis([0, Time3[-1], min(signal2), max(signal2)])
    131 
    132     # Plot spectrogram difference
    133     # plt.figure(figsize=(12, 8))
    134     # print("Creating spectrogram data..")
    135     # frequencies, times, spectrogram = signal.spectrogram(signal_diff, 44100)
    136     # plt.pcolormesh(times, frequencies, 10 * np.log10(spectrogram))
    137     # plt.colorbar()
    138     # plt.title("Diff Spectrogram")
    139     # plt.ylabel("Frequency [Hz]")
    140     # plt.xlabel("Time [sec]")
    141     # plt.savefig(path+'/'+model_name + "_diff_spectrogram.png", bbox_inches="tight")
    142 
    143     if show_plots == 1:
    144         plt.show()
    145 
    146 
    147 if __name__ == "__main__":
    148     parser = argparse.ArgumentParser()
    149     #parser.add_argument("--path", default=".")
    150     parser.add_argument("--output_wav", default="y_test.wav")
    151     parser.add_argument("--pred_wav", default="y_pred.wav")
    152     parser.add_argument("--input_wav", default="x_test.wav")
    153     parser.add_argument("--model_name", default="plot")
    154     parser.add_argument("--path", default="")
    155     parser.add_argument("--show_plots", default=1)
    156     args = parser.parse_args()
    157     analyze_pred_vs_actual(args)