commit 3d70b8ebff659b8adce726f214c33044c47cc25a
parent a790c8736525f942abf5487357514447a090a4f3
Author: jmiller656 <joshxmiller656@gmail.com>
Date: Tue, 2 Feb 2021 20:31:51 -0500
Use Keras Sequence class for data loader
Diffstat:
1 file changed, 31 insertions(+), 49 deletions(-)
diff --git a/guitar_lstm_colab.ipynb b/guitar_lstm_colab.ipynb
@@ -57,22 +57,7 @@
" # 1 = accuracy training \n",
" # 2 = extended training\n",
"\n",
- "input_size = 150 # !!!IMPORTANT !!!: The input_size is set at 150 for Colab notebook. \n",
- " # A higher setting may result in crashing due to\n",
- " # memory limitation of 8GB for the free version\n",
- " # of Colab. This setting limits the accuracy of\n",
- " # the training, especially for complex guitar signals\n",
- " # such as high distortion.\n",
- " # \n",
- " # !!!IMPORTANT!!!: You will most likely need to cycle the runtime to \n",
- " # free up RAM between training sessions.\n",
- " #\n",
- " # Increase the \"split_data\" parameter to reduce the RAM used and\n",
- " # still allow for a higher \"input_size\" setting. \n",
- " #\n",
- " # Future dev note: Using a custom dataloader may be a good\n",
- " # workaround for this limitation, at the cost\n",
- " # of slower training.\n",
+ "input_size = 150 \n",
"\n",
"if not os.path.exists('models/'+name):\n",
" os.makedirs('models/'+name)\n",
@@ -84,6 +69,29 @@
{
"cell_type": "code",
"execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class WindowArray(Sequence):\n",
+ "    \"\"\"Keras Sequence serving batches of sliding windows over x, each paired with the y sample at the window's last index.\"\"\"\n",
+ "    def __init__(self, x, y, window_len, batch_size=32):\n",
+ "        self.x = x\n",
+ "        self.y = y[window_len-1:]  # after the trim, y[i] is the target for window x[i:i+window_len]\n",
+ "        self.window_len = window_len\n",
+ "        self.batch_size = batch_size\n",
+ "        self.n_windows = max(len(x) - window_len + 1, 0)  # 0 when x is shorter than one window\n",
+ "\n",
+ "    def __len__(self):\n",
+ "        return -(-self.n_windows // self.batch_size)  # ceiling division: keep the final partial batch\n",
+ "\n",
+ "    def __getitem__(self, index):\n",
+ "        start, stop = index*self.batch_size, min((index+1)*self.batch_size, self.n_windows)  # clamp the last batch\n",
+ "        return np.stack([self.x[idx: idx+self.window_len] for idx in range(start, stop)]), self.y[start:stop]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
"metadata": {
"id": "WqI-cGt1jaG2"
},
@@ -97,7 +105,7 @@
" Error to signal ratio with pre-emphasis filter:\n",
" \"\"\"\n",
" y_true, y_pred = pre_emphasis_filter(y_true), pre_emphasis_filter(y_pred)\n",
- " return K.sum(tf.pow(y_true - y_pred, 2), axis=0) / K.sum(tf.pow(y_true, 2), axis=0) + 1e-10\n",
+ " return K.sum(tf.pow(y_true - y_pred, 2), axis=0) / (K.sum(tf.pow(y_true, 2), axis=0) + 1e-10)\n",
" \n",
"def save_wav(name, data):\n",
" wavfile.write(name, 44100, data.flatten().astype(np.float32))\n",
@@ -167,12 +175,12 @@
"y_all = out_data.astype(np.float32).flatten() \n",
"y_all = normalize(y_all).reshape(len(y_all),1)\n",
"\n",
- "y_ordered = y_all[input_size-1:] \n",
- "indices = np.arange(input_size) + np.arange(len(X_all)-input_size+1)[:,np.newaxis] \n",
- "x_ordered = np.take(X_all, indices)[:,:, np.newaxis]\n",
+ "train_examples = int(len(X_all)*0.8)\n",
+ "train_arr = WindowArray(X_all[:train_examples], y_all[:train_examples], input_size, batch_size=batch_size)\n",
+ "val_arr = WindowArray(X_all[train_examples:], y_all[train_examples:], input_size, batch_size=batch_size)\n",
"\n",
"# Train Model ###################################################\n",
- "model.fit(x_ordered,y_ordered, epochs=epochs, batch_size=batch_size, validation_split=test_size, shuffle=True) \n",
+ "history = model.fit(train_arr, validation_data=val_arr, epochs=epochs, shuffle=True) \n",
"model.save('models/'+name+'/'+name+'.h5')\n",
"\n",
"# Run Prediction #################################################\n",
@@ -182,10 +190,9 @@
"y_the_rest, y_last_part = np.split(y_all, [int(len(y_all)*.8)])\n",
"x_the_rest, x_last_part = np.split(X_all, [int(len(X_all)*.8)])\n",
"y_test = y_last_part[input_size-1:] \n",
- "indices = np.arange(input_size) + np.arange(len(x_last_part)-input_size+1)[:,np.newaxis] \n",
- "X_test = np.take(x_last_part,indices)[:, :, np.newaxis]\n",
+ "test_arr = WindowArray(x_last_part, y_last_part, input_size, batch_size = batch_size)\n",
"\n",
- "prediction = model.predict(X_test, batch_size=batch_size)\n",
+ "prediction = model.predict(test_arr)\n",
"\n",
"save_wav('models/'+name+'/y_pred.wav', prediction)\n",
"save_wav('models/'+name+'/x_test.wav', x_last_part)\n",
@@ -199,31 +206,6 @@
"dset[0] = input_size\n",
"f.close()"
]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "import IPython.display as ipd"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "ipd.Audio('models/'+name+'/y_pred.wav')"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
}
],
"metadata": {