neural-amp-modeler

Neural network emulator for guitar amplifiers

commit facaf95a678a987e47f5753f185ed19df66a6477
parent e3f2402d1432065b35b8395e49a2a9d72761e167
Author: sdatkinson <steven@atkinson.mn>
Date:   Sun, 21 Feb 2021 15:32:01 -0500

Update reamp.py for PyTorch

Diffstat:
M reamp.py | 79 +++++++++++++++++++++++++++++++++++++++++++++++--------------------------------
1 file changed, 47 insertions(+), 32 deletions(-)

diff --git a/reamp.py b/reamp.py
@@ -4,11 +4,13 @@ Reamp a .wav file
 Assumes 24-bit WAV files
 """
-from argparse import ArgumentParser
-import tensorflow as tf
+
 import os
-import wavio
+from argparse import ArgumentParser
+
 import matplotlib.pyplot as plt
+import torch
+import wavio
 
 import models
 
 
@@ -19,41 +21,56 @@ def _sampwidth_to_bits(x):
 
 if __name__ == "__main__":
     parser = ArgumentParser()
-    parser.add_argument("architecture", type=str,
-                        help="JSON filename containing NN architecture")
-    parser.add_argument("checkpoint_dir", type=str,
-                        help="directory holding model checkpoint to use")
-    parser.add_argument("input_file", type=str,
-                        help="Input .wav file to convert")
-    parser.add_argument("--output_file", type=str, default=None,
-                        help="Where to save the output")
-    parser.add_argument("--batch_size", type=int, default=8192,
-                        help="How many samples to process at a time. " +
-                             "Reduce if there are out-of-memory issues.")
-    parser.add_argument("--target_file", type=str, default=None,
-                        help=".wav file of the true output (if you want to compare)")
+    parser.add_argument(
+        "architecture", type=str, help="JSON filename containing NN architecture"
+    )
+    parser.add_argument(
+        "params", type=str, help="directory holding model checkpoint to use"
+    )
+    parser.add_argument("input_file", type=str, help="Input .wav file to convert")
+    parser.add_argument(
+        "--output_file", type=str, default=None, help="Where to save the output"
+    )
+    parser.add_argument(
+        "--batch_size",
+        type=int,
+        default=8192,
+        help="How many samples to process at a time. "
+        + "Reduce if there are out-of-memory issues.",
+    )
+    parser.add_argument(
+        "--target_file",
+        type=str,
+        default=None,
+        help=".wav file of the true output (if you want to compare)",
+    )
     args = parser.parse_args()
 
-    if args.output_file is None:
-        args.output_file = args.input_file.rstrip(".wav") + "_reamped.wav"
-
+    output_file = (
+        args.output_file
+        if args.output_file is not None
+        else args.input_file.rstrip(".wav") + "_reamped.wav"
+    )
+
     if os.path.isfile(args.output_file):
         print("Output file exists; skip")
         exit(1)
-
+
     x = wavio.read(args.input_file)
     rate, sampwidth = x.rate, x.sampwidth
     bits = _sampwidth_to_bits(sampwidth)
-    x_data = x.data.flatten() / 2 ** (bits - 1)
-
-    with tf.Session() as sess:
-        model = models.from_json(args.architecture,
-                                 checkpoint_path=args.checkpoint_dir)
-        model.load()
-        y = model.predict(x_data, batch_size=args.batch_size, verbose=True)
-        wavio.write(args.output_file, y * 2 ** (bits - 1), rate, scale="none",
-                    sampwidth=sampwidth)
-
+    x_data = torch.Tensor(x.data.flatten() / 2 ** (bits - 1))
+
+    model = models.from_json(args.architecture)
+    model.load_state_dict(torch.load(args.params))
+    with torch.no_grad():
+        y = model.predict_sequence(
+            x_data, batch_size=args.batch_size, verbose=True
+        ).numpy()
+    wavio.write(
+        output_file, y * 2 ** (bits - 1), rate, scale="none", sampwidth=sampwidth
+    )
+
    if args.target_file is not None and os.path.isfile(args.target_file):
         t = wavio.read(args.target_file)
         t_data = t.data.flatten() / 2 ** (_sampwidth_to_bits(t.sampwidth) - 1)
@@ -63,4 +80,3 @@ if __name__ == "__main__":
         plt.plot(y)
         plt.legend(["Input", "Target", "Prediction"])
         plt.show()
-
\ No newline at end of file
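
The change swaps the TensorFlow session and Keras-style model.predict for a plain torch.nn.Module: the architecture is built from JSON, the weights are loaded with load_state_dict, and inference runs under torch.no_grad() via predict_sequence. The repository's predict_sequence implementation is not part of this diff; the sketch below is only an illustration of what a batched sequence predictor matching that call signature could look like (the chunking scheme, the forward contract, and the progress printing are assumptions, not the project's code):

import torch
from torch import nn


def predict_sequence_sketch(
    model: nn.Module, x: torch.Tensor, batch_size: int = 8192, verbose: bool = False
) -> torch.Tensor:
    """Run a long 1-D signal through `model` in chunks to bound memory use.

    Illustration only: chunks are processed independently here, whereas a real
    amp model would need to carry its receptive field across chunk borders.
    """
    outputs = []
    with torch.no_grad():
        for start in range(0, len(x), batch_size):
            chunk = x[start : start + batch_size]
            outputs.append(model(chunk))  # assumes model maps (N,) -> (N,)
            if verbose:
                print(f"{min(start + batch_size, len(x))} / {len(x)} samples")
    return torch.cat(outputs)

With the new positional arguments an invocation looks like `python reamp.py architecture.json params.pt di.wav --batch_size 4096` (file names here are placeholders); note that the params argument now points at a serialized PyTorch state dict rather than a TensorFlow checkpoint directory.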