commit 09792b75aeca5232dc60aac8e5bf96c9258a72cc
parent facaf95a678a987e47f5753f185ed19df66a6477
Author: sdatkinson <steven@atkinson.mn>
Date: Sun, 21 Feb 2021 15:44:23 -0500
Fix hard-coded save_dir
Diffstat:
1 file changed, 8 insertions(+), 3 deletions(-)
diff --git a/train.py b/train.py
@@ -30,15 +30,20 @@ check_training_at = np.concatenate(
)
plot_kwargs = {"window": (30000, 40000)}
wav_kwargs = {"window": (0, 5 * 44100)}
-save_dir = "output2"
torch.manual_seed(0)
_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-def ensure_save_dir():
+def ensure_save_dir(args):
+ save_dir = (
+ os.path.join(os.path.dirname(__file__), "output")
+ if args.save_dir is None
+ else args.save_dir
+ )
if not os.path.isdir(save_dir):
os.makedirs(save_dir)
+ return save_dir
class Data(object):
@@ -228,7 +233,7 @@ if __name__ == "__main__":
"--minibatches", type=int, default=10, help="Number of minibatches to train for"
)
args = parser.parse_args()
- ensure_save_dir()
+ save_dir = ensure_save_dir(args)
input_length = _get_input_length(args.model_arch) # Ugh, kludge
# Load the data