commit 8ed7f988273d61001202094f9ac8335c6eea878d
parent f181f1e4f52022400390633d0c5dc94aac1d54ae
Author: sdatkinson <steven@atkinson.mn>
Date: Sun, 21 Feb 2021 14:45:09 -0500
Refactor model & training scripts for PyTorch
Diffstat:
M | models.py | | | 158 | +++++++++++++++++++++++++++++-------------------------------------------------- |
M | train.py | | | 259 | ++++++++++++++++++++++++++++++++++++++++++++++++------------------------------- |
2 files changed, 216 insertions(+), 201 deletions(-)
diff --git a/models.py b/models.py
@@ -2,84 +2,35 @@
# File Created: Sunday, 30th December 2018 9:42:29 pm
# Author: Steven Atkinson (steven@atkinson.mn)
-import numpy as np
-import tensorflow as tf
import abc
-from tempfile import mkdtemp
-import os
import json
+import os
+from tempfile import mkdtemp
+
+import numpy as np
+import torch
+import torch.nn as nn
-def from_json(f, n_train=1, checkpoint_path=None):
+def from_json(f):
if isinstance(f, str):
with open(f, "r") as json_file:
f = json.load(json_file)
-
+
if f["type"] == "FullyConnected":
- return FullyConnected(n_train, f["input_length"],
- layer_sizes=f["layer_sizes"], checkpoint_path=checkpoint_path)
+ return mlp(f["input_length"], 1, layer_sizes=f["layer_sizes"])
else:
- raise NotImplementedError("Model type {} unrecognized".format(
- f["type"]))
+ raise NotImplementedError("Model type {} unrecognized".format(f["type"]))
-class Model(object):
+class Model(nn.Module):
"""
Model parent class
"""
- def __init__(self, n_train, sess=None, checkpoint_path=None):
- """
- Make sure child classes call _build() after this!
- """
- if sess is None:
- sess = tf.get_default_session()
- self.sess = sess
-
- if checkpoint_path is None:
- checkpoint_path = os.path.join(mkdtemp(), "model.ckpt")
- if not os.path.isdir(os.path.dirname(checkpoint_path)):
- os.makedirs(os.path.dirname(checkpoint_path))
- self.checkpoint_path = checkpoint_path
-
- # self._batch_size = batch_size
- self._n_train = n_train
- self.target = tf.placeholder(tf.float32, shape=(None, 1))
- self.prediction = None
- self.total_prediction_loss = None
- self.rmse = None
- self.loss = None # Includes any regularization
-
- # @property
- # def batch_size(self):
- # return self._batch_size
-
- @property
- def n_train(self):
- return self._n_train
-
- def load(self, checkpoint_path=None):
- checkpoint_path = checkpoint_path or self.checkpoint_path
-
- try:
- ckpt = tf.train.get_checkpoint_state(checkpoint_path)
- print("Loading model: {}".format(ckpt.model_checkpoint_path))
- self.saver.restore(self.sess, ckpt.model_checkpoint_path)
- except Exception as e:
- print("Error while attempting to load model: {}".format(e))
-
- @abc.abstractclassmethod
- def predict(self, x):
- """
- A nice function for prediction.
- :param x: input array (length=n)
- :type x: array-like
- :return: (array-like) corresponding predicted outputs (length=n)
- """
- raise NotImplementedError("Implement predict()")
- def save(self, iter, checkpoint_path=None):
- checkpoint_path = checkpoint_path or self.checkpoint_path
- self.saver.save(self.sess, checkpoint_path, global_step=iter)
+ @abc.abstractmethod
+ def predict_sequence(self, x):
+ raise NotImplementedError()
def _build(self):
self.prediction = self._build_prediction()
@@ -90,18 +41,18 @@ class Model(object):
self.saver = tf.train.Saver(tf.global_variables())
def _build_loss(self):
- self.total_prediction_loss = tf.losses.mean_squared_error(self.target,
- self.prediction, weights=self.n_train)
-
+ self.total_prediction_loss = tf.losses.mean_squared_error(
+ self.target, self.prediction, weights=self.n_train
+ )
+
# Don't count this as a loss!
- self.rmse = tf.sqrt(
- self.total_prediction_loss / self.n_train)
+ self.rmse = tf.sqrt(self.total_prediction_loss / self.n_train)
return tf.losses.get_total_loss()
@abc.abstractclassmethod
def _build_prediction(self):
- raise NotImplementedError('Implement prediction for model')
+ raise NotImplementedError("Implement prediction for model")
class Autoregressive(Model):
@@ -109,57 +60,62 @@ class Autoregressive(Model):
Autoregressive models that take in a few of the most recent input samples
and predict the output at the last time point.
"""
- def __init__(self, n_train, input_length, sess=None,
- checkpoint_path=None):
- super().__init__(n_train, sess=sess, checkpoint_path=checkpoint_path)
- self._input_length = input_length
- self.x = tf.placeholder(tf.float32, shape=(None, self.input_length))
-
- @property
+
+ @abc.abstractproperty
def input_length(self):
- return self._input_length
+ raise NotImplementedError()
- def predict(self, x, batch_size=None, verbose=False):
+ def predict_sequence(self, x: torch.Tensor, batch_size=None, verbose=False):
"""
Return 1D array of predictions same length as x
"""
- n = x.size
+ n = x.numel()
batch_size = batch_size or n
# Pad x with leading zeros:
- x = np.concatenate((np.zeros(self.input_length - 1), x))
+ x = torch.cat((x.new_zeros((self.input_length - 1,)), x))  # pad shares x's dtype and device
i = 0
y = []
while i < n:
- if verbose:
- print("model.predict {}/{}".format(i, n))
this_batch_size = np.min([batch_size, n - i])
# Reshape into a batch:
- x_mtx = np.stack([x[j: j + self.input_length]
- for j in range(i, i + this_batch_size)])
+ x_mtx = torch.stack(
+ [x[j : j + self.input_length] for j in range(i, i + this_batch_size)]
+ )
# Predict and flatten.
- y.append(self.sess.run(self.prediction, feed_dict={self.x: x_mtx}) \
- .flatten())
+ y.append(self(x_mtx).flatten())  # flatten, not squeeze: a 1-sample batch must stay 1-D for torch.cat
i += this_batch_size
- return np.concatenate(y)
+ return torch.cat(y)
class FullyConnected(Autoregressive):
"""
Autoregressive model taking in a sequence of the most recent inputs, putting
- them through a series of FC layers, and outputting the single output at the
+ them through a series of FC layers, and outputting the single output at the
last time step.
"""
- def __init__(self, n_train, input_length, layer_sizes=(512,),
- sess=None, checkpoint_path=None):
- super().__init__(n_train, input_length, sess=sess,
- checkpoint_path=checkpoint_path)
- self._layer_sizes = layer_sizes
- self._build()
- def _build_prediction(self):
- h = self.x
- for m in self._layer_sizes:
- h = tf.contrib.layers.fully_connected(h, m)
- y = -1.0 + 2.0 * tf.contrib.layers.fully_connected(h, 1,
- activation_fn=tf.nn.sigmoid)
- return y
+ def __init__(self, net):
+ super().__init__()
+ self._net = net
+
+ @property
+ def input_length(self):
+ # Width of the first block's first layer (second weight dimension).
+ # NOTE(review): assumes the net is laid out as mlp() builds it, i.e. _net[0][0] is an nn.Linear.
+ return self._net[0][0].weight.data.shape[1]
+
+ def forward(self, inputs):
+ return self._net(inputs)
+
+
+def mlp(dx, dy, layer_sizes=None):
+ """
+ Build a FullyConnected model wrapping a stack of Linear + activation blocks.
+
+ :param dx: number of input features of the first Linear layer
+ :param dy: number of output features of the final ("head") Linear layer
+ :param layer_sizes: hidden layer widths; [256, 256] is used when None
+ :return: (FullyConnected) hidden blocks use ReLU, the head uses Tanh
+ """
+ def block(dx, dy, Activation=nn.ReLU):
+ return nn.Sequential(nn.Linear(dx, dy), Activation())
+
+ layer_sizes = [256, 256] if layer_sizes is None else layer_sizes
+
+ net = nn.Sequential()
+ in_features = dx
+ for i, out_features in enumerate(layer_sizes):
+ net.add_module("layer_%i" % i, block(in_features, out_features))
+ in_features = out_features
+ # Tanh head keeps predictions in (-1, 1), like the -1 + 2 * sigmoid output it replaces.
+ net.add_module("head", block(in_features, dy, Activation=nn.Tanh))
+ return FullyConnected(net)
diff --git a/train.py b/train.py
@@ -4,42 +4,56 @@
"""
Here's a script for training new models.
+
+TODO
+* Device
+* Lightning?
"""
-from argparse import ArgumentParser
-import numpy as np
import abc
-import tensorflow as tf
-import matplotlib.pyplot as plt
-from time import time
-import wavio
import json
import os
+from argparse import ArgumentParser
+from time import time
+
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+import torch.nn as nn
+import wavio
import models
# Parameters for training
-check_training_at = np.concatenate([10 ** pwr * np.array([1, 2, 5])
- for pwr in np.arange(0, 7)])
-plot_kwargs = {
- "window": (30000, 40000)
-}
+check_training_at = np.concatenate(
+ [10 ** pwr * np.array([1, 2, 5]) for pwr in np.arange(0, 7)]
+)
+plot_kwargs = {"window": (30000, 40000)}
wav_kwargs = {"window": (0, 5 * 44100)}
+save_dir = "output"
+
+torch.manual_seed(0)
+
+
+def ensure_save_dir():
+ """Create the module-level ``save_dir`` directory (and parents); a no-op if it already exists."""
+ os.makedirs(save_dir, exist_ok=True)
class Data(object):
"""
- Object for holding data and spitting out minibatches for training/segments
+ Object for holding data and spitting out minibatches for training/segments
for testing.
"""
+
def __init__(self, fname, input_length, batch_size=None):
xy = np.load(fname)
- self.x = xy[0]
- self.y = xy[1]
- self._n = self.x.size - input_length + 1
+ self.x = torch.Tensor(xy[0])
+ self.y = torch.Tensor(xy[1])
+ self._n = self.x.numel() - input_length + 1
self.input_length = input_length
self.batch_size = batch_size
-
+
@property
def n(self):
return self._n
@@ -50,70 +64,101 @@ class Data(object):
"""
if ilist is None:
n = n or self.batch_size
- ilist = np.random.randint(0, self.n, size=(n,))
- x = np.stack([self.x[i: i + self.input_length] for i in ilist])
- y = np.array([self.y[i + self.input_length - 1] for i in ilist]) \
- [:, np.newaxis]
+ ilist = torch.randint(0, self.n, size=(n,))
+ x = torch.stack([self.x[i : i + self.input_length] for i in ilist])
+ y = torch.stack([self.y[i + self.input_length - 1] for i in ilist])[:, None]
return x, y
def sequence(self, start, end):
end += self.input_length - 1
- return self.x[start: end], self.y[start + self.input_length - 1: end]
+ return self.x[start:end], self.y[start + self.input_length - 1 : end]
+
+
+def train_step(model, optimizer, batch):
+ """
+ Run one optimization step on a single minibatch.
+
+ :param model: module called on the batch inputs to produce predictions
+ :param optimizer: optimizer whose ``step()`` is called after backprop
+ :param batch: ``(inputs, targets)`` pair
+ :return: (float) MSE loss on this batch, evaluated before the parameter update
+ """
+ model.train()
+ model.zero_grad()
+ inputs, targets = batch
+ preds = model(inputs)
+ loss = nn.MSELoss()(preds, targets)
+ loss.backward()
+ optimizer.step()
+ return loss.item()
+
+def validation_step(model, batch):
+ """
+ Compute the MSE loss on one batch without tracking gradients.
+
+ :param model: module called on the batch inputs to produce predictions
+ :param batch: ``(inputs, targets)`` pair
+ :return: (float) MSE loss on this batch
+ """
+ with torch.no_grad():
+ # Leaves the model in eval mode; train_step() switches it back with model.train().
+ model.eval()
+ inputs, targets = batch
+ return nn.MSELoss()(model(inputs), targets).item()
-def train(model, train_data, batch_size=None, n_minibatches=10,
- validation_data=None, plot_at=(), wav_at=(), plot_kwargs={},
- wav_kwargs={}, save_every=100, validate_every=100):
- save_dir = os.path.dirname(model.checkpoint_path)
- sess = model.sess
- opt = tf.train.AdamOptimizer().minimize(model.loss)
- sess.run(tf.global_variables_initializer()) # For opt
+
+def train(
+ model,
+ train_data,
+ batch_size=None,
+ n_minibatches=10,
+ validation_data=None,
+ plot_at=(),
+ wav_at=(),
+ plot_kwargs={},
+ wav_kwargs={},
+ save_every=100,
+ validate_every=100,
+):
+ optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
t_loss_list, v_loss_list = [], []
t0 = time()
for i in range(n_minibatches):
- x, y = train_data.minibatch(batch_size)
- t_loss, _ = sess.run((model.rmse, opt),
- feed_dict={model.x: x, model.target: y})
+ batch = train_data.minibatch(batch_size)
+ t_loss = train_step(model, optimizer, batch)
t_loss_list.append([i, t_loss])
- print("t={:7} | MB {:>7} / {:>7} | TLoss={:8}".format(int(time() - t0),
- i + 1, n_minibatches, t_loss))
-
+ print(
+ "t={:7} | MB {:>7} / {:>7} | TLoss={:8}".format(
+ int(time() - t0), i + 1, n_minibatches, t_loss
+ )
+ )
+
# Callbacks, basically...
if i + 1 in plot_at:
- plot_predictions(model,
- validation_data if validation_data is not None else train_data,
- title="Minibatch {}".format(i + 1),
- fname="{}/mb_{}.png".format(save_dir, i + 1),
- **plot_kwargs)
+ plot_predictions(
+ model,
+ validation_data if validation_data is not None else train_data,
+ title="Minibatch {}".format(i + 1),
+ fname="{}/mb_{}.png".format(save_dir, i + 1),
+ **plot_kwargs
+ )
if i + 1 in wav_at:
print("Making wav for mb {}".format(i + 1))
- predict(model, validation_data,
- save_wav_file="{}/predict_{}.wav".format(save_dir, i + 1),
- **wav_kwargs)
+ predict(
+ model,
+ validation_data,
+ save_wav_file="{}/predict_{}.wav".format(save_dir, i + 1),
+ **wav_kwargs
+ )
if (i + 1) % save_every == 0:
- model.save(iter=i + 1)
+ torch.save(model.state_dict(), os.path.join(save_dir, "model_%i.pt" % (i + 1)))  # 1-based, like the mb_/predict_ files
if i == 0 or (i + 1) % validate_every == 0:
- v_loss, _ = sess.run((model.rmse, opt),
- feed_dict={model.x: x, model.target: y})
+ # Score held-out data when it is available; fall back to the training batch otherwise.
+ v_batch = batch if validation_data is None else validation_data.minibatch(batch_size or train_data.batch_size)
+ v_loss = validation_step(model, v_batch)
print("VLoss={:8}".format(v_loss))
v_loss_list.append([i, v_loss])
-
+
# After training loop...
if validation_data is not None:
- x, y = validation_data.minibatch(train_data.batch_size)
- v_loss = sess.run(model.rmse,
- feed_dict={model.x: x, model.target: y})
+ batch = validation_data.minibatch(train_data.batch_size)
+ v_loss = validation_step(model, batch)
print("Validation loss={}".format(v_loss))
+ torch.save(model.state_dict(), os.path.join(save_dir, "model.pt"))
return np.array(t_loss_list).T, np.array(v_loss_list).T
def plot_predictions(model, data, title=None, fname=None, window=None):
- x, y, t = predict(model, data, window=window)
- plt.figure(figsize=(12, 4))
- plt.plot(x)
- plt.plot(t)
- plt.plot(y)
- plt.legend(('Input', 'Target', 'Prediction'))
+ with torch.no_grad():
+ x, y, t = predict(model, data, window=window)
+ plt.figure(figsize=(12, 4))
+ plt.plot(x)
+ plt.plot(t)
+ plt.plot(y)
+ plt.legend(("Input", "Target", "Prediction"))
if title is not None:
plt.title(title)
if fname is not None:
@@ -133,21 +178,27 @@ def plot_loss(t_loss, v_loss, fname):
plt.legend(("Training", "Validation"))
plt.savefig(fname)
plt.close()
-
+
def predict(model, data, window=None, save_wav_file=None):
- x, t = data.x, data.y
- if window is not None:
- x, t = x[window[0]: window[1]], t[window[0]: window[1]]
- y = model.predict(x).flatten()
+ with torch.no_grad():
+ x, t = data.x, data.y
+ if window is not None:
+ x, t = x[window[0] : window[1]], t[window[0] : window[1]]
+ y = model.predict_sequence(x).squeeze()
- if save_wav_file is not None:
- rate = 44100 # TODO from data
- sampwidth = 3 # 24-bit
- wavio.write(save_wav_file, y * 2 ** 23, rate, scale="none",
- sampwidth=sampwidth)
+ if save_wav_file is not None:
+ rate = 44100 # TODO from data
+ sampwidth = 3 # 24-bit
+ wavio.write(
+ save_wav_file,
+ y.numpy() * 2 ** 23,
+ rate,
+ scale="none",
+ sampwidth=sampwidth,
+ )
- return x, y, t
+ return x, y, t
def _get_input_length(archfile):
@@ -156,43 +207,51 @@ def _get_input_length(archfile):
if __name__ == "__main__":
parser = ArgumentParser()
- parser.add_argument("model_arch", type=str,
- help="JSON containing model architecture")
- parser.add_argument("train_data", type=str,
- help="Filename for training data")
- parser.add_argument("validation_data", type=str,
- help="Filename for validation data")
- parser.add_argument("--save_dir", type=str, default=None,
- help="Where to save the run data (checkpoints, prediction...)")
- parser.add_argument("--batch_size", type=str, default=4096,
- help="Number of data per minibatch")
- parser.add_argument("--minibatches", type=int, default=10,
- help="Number of minibatches to train for")
+ parser.add_argument(
+ "model_arch", type=str, help="JSON containing model architecture"
+ )
+ parser.add_argument("train_data", type=str, help="Filename for training data")
+ parser.add_argument(
+ "validation_data", type=str, help="Filename for validation data"
+ )
+ parser.add_argument(
+ "--save_dir",
+ type=str,
+ default=None,
+ help="Where to save the run data (checkpoints, prediction...)",
+ )
+ parser.add_argument(
+ "--batch_size", type=int, default=4096, help="Number of data per minibatch"
+ )
+ parser.add_argument(
+ "--minibatches", type=int, default=10, help="Number of minibatches to train for"
+ )
args = parser.parse_args()
+ save_dir = args.save_dir or save_dir  # honour --save_dir; otherwise keep the module-level default
+ ensure_save_dir()
input_length = _get_input_length(args.model_arch) # Ugh, kludge
# Load the data
- train_data = Data(args.train_data, input_length,
- batch_size=args.batch_size)
+ train_data = Data(args.train_data, input_length, batch_size=args.batch_size)
validate_data = Data(args.validation_data, input_length)
-
+
# Training
- with tf.Session() as sess:
- model = models.from_json(args.model_arch, train_data.n,
- checkpoint_path=os.path.join(args.save_dir, "model.ckpt"))
- t_loss_list, v_loss_list = train(
- model,
- train_data,
- validation_data=validate_data,
- n_minibatches=args.minibatches,
- plot_at=check_training_at,
- wav_at=check_training_at,
- plot_kwargs=plot_kwargs,
- wav_kwargs=wav_kwargs)
- plot_predictions(model, validate_data, window=(0, 44100))
- print("Predict the full output")
- predict(model, validate_data,
- save_wav_file="{}/predict.wav".format(
- os.path.dirname(model.checkpoint_path)))
-
- plot_loss(t_loss_list, v_loss_list, "{}/loss.png".format(args.save_dir))
+ model = models.from_json(args.model_arch)
+ t_loss_list, v_loss_list = train(
+ model,
+ train_data,
+ validation_data=validate_data,
+ n_minibatches=args.minibatches,
+ plot_at=check_training_at,
+ wav_at=check_training_at,
+ plot_kwargs=plot_kwargs,
+ wav_kwargs=wav_kwargs,
+ )
+ plot_predictions(model, validate_data, window=(0, 44100))
+ print("Predict the full output")
+ predict(
+ model,
+ validate_data,
+ save_wav_file="{}/predict.wav".format(save_dir),
+ )
+
+ plot_loss(t_loss_list, v_loss_list, "{}/loss.png".format(save_dir))