WaveNetLoader.cpp (2333B)
1 /* 2 ============================================================================== 3 4 WaveNetLoader.cpp 5 Created: 3 Feb 2019 8:55:31pm 6 Author: Eero-Pekka Damskägg 7 8 Modified by keyth72 9 10 ============================================================================== 11 */ 12 13 #include "WaveNetLoader.h" 14 15 WaveNetLoader::WaveNetLoader(var jsonFile) 16 { 17 // Edit this line to point to your binary json file in project resources 18 config = JSON::parse(jsonFile); 19 if (config.hasProperty("level_adjust")) 20 { 21 levelAdjust = config["level_adjust"]; 22 } 23 numChannels = config["residual_channels"]; 24 inputChannels = config["input_channels"]; 25 outputChannels = config["output_channels"]; 26 filterWidth = config["filter_width"]; 27 activation = config["activation"].toString().toStdString(); 28 dilations = readDilations(); 29 } 30 31 WaveNetLoader::WaveNetLoader(var jsonFile, File configFile) 32 { 33 // Edit this line to point to your binary json file in project resources 34 config = JSON::parse(configFile); 35 36 if (config.hasProperty("level_adjust")) 37 { 38 levelAdjust = config["level_adjust"]; 39 } 40 numChannels = config["residual_channels"]; 41 inputChannels = config["input_channels"]; 42 outputChannels = config["output_channels"]; 43 filterWidth = config["filter_width"]; 44 activation = config["activation"].toString().toStdString(); 45 dilations = readDilations(); 46 } 47 48 std::vector<int> WaveNetLoader::readDilations() 49 { 50 std::vector<int> newDilations; 51 if (auto dilationsArray = config.getProperty("dilations", var()).getArray()) 52 { 53 for (int dil : *dilationsArray) 54 newDilations.push_back(dil); 55 } 56 return newDilations; 57 } 58 59 void WaveNetLoader::loadVariables(WaveNet &model) 60 { 61 if (auto variablesArray = config.getProperty("variables", var()).getArray()) 62 { 63 for (auto& variable : *variablesArray) 64 { 65 int layerIdx = variable["layer_idx"]; 66 std::string name = variable["name"].toString().toStdString(); 67 std::vector<float> data; 68 if (auto dataArray = variable.getProperty("data", var()).getArray()) 69 { 70 for (float value : *dataArray) 71 data.push_back(value); 72 } 73 model.setWeight(data, layerIdx, name); 74 } 75 } 76 }