ConvolutionLayer.cpp (2763B)
1 /* 2 ============================================================================== 3 4 ConvolutionLayer.cpp 5 Created: 10 Jan 2019 5:04:39pm 6 Author: Damskägg Eero-Pekka 7 8 ============================================================================== 9 */ 10 11 #include "ConvolutionLayer.h" 12 13 ConvolutionLayer::ConvolutionLayer(size_t inputChannels, 14 size_t outputChannels, 15 int filterWidth, 16 int dilation, 17 bool residual, 18 std::string activationName): 19 conv(inputChannels, 20 Activations::isGated(activationName) ? outputChannels * 2 : outputChannels, 21 filterWidth, 22 dilation), 23 out1x1(outputChannels, outputChannels, 1, 1), 24 residual(residual), 25 usesGating(Activations::isGated(activationName)), 26 activation(Activations::getActivationFuncArray(activationName)) 27 { 28 } 29 30 void ConvolutionLayer::process(float* data, int numSamples) 31 { 32 conv.process(data, numSamples); 33 activation(data, conv.getNumOutputChannels(), numSamples); 34 if (residual) { 35 out1x1.process(data, numSamples); 36 } 37 } 38 39 void ConvolutionLayer::process(float* data, float* skipData, int numSamples) 40 { 41 conv.process(data, numSamples); 42 activation(data, conv.getNumOutputChannels(), numSamples); 43 copySkipData(data, skipData, numSamples); 44 if (residual) { 45 out1x1.process(data, numSamples); 46 } 47 } 48 49 void ConvolutionLayer::copySkipData(float *data, float *skipData, int numSamples) 50 { 51 size_t skipChannels = usesGating ? conv.getNumOutputChannels()/2 : conv.getNumOutputChannels(); 52 for (size_t i = 0; i < (size_t)numSamples*skipChannels; ++i) 53 skipData[i] = data[i]; 54 } 55 56 void ConvolutionLayer::setParams(size_t newInputChannels, size_t newOutputChannels, 57 int newFilterWidth, int newDilation, bool newResidual, 58 std::string newActivationName) 59 { 60 activation = Activations::getActivationFuncArray(newActivationName); 61 usesGating = Activations::isGated(newActivationName); 62 size_t internalChannels = usesGating ? newOutputChannels * 2 : newOutputChannels; 63 conv.setParams(newInputChannels, internalChannels, newFilterWidth, newDilation); 64 out1x1.setParams(newOutputChannels, newOutputChannels, 1, 1); 65 residual = newResidual; 66 } 67 68 void ConvolutionLayer::setWeight(std::vector<float> W, std::string name) 69 { 70 if ((name == "W_conv") || (name == "W")) 71 conv.setWeight(W, "W"); 72 else if ((name == "b_conv") || (name == "b")) 73 conv.setWeight(W, "b"); 74 else if (name == "W_out") 75 out1x1.setWeight(W, "W"); 76 else if (name == "b_out") 77 out1x1.setWeight(W, "b"); 78 }