convolution-impl.cpp (10187B)
1 /** @addtogroup dft 2 * @{ 3 */ 4 /* 5 Copyright (C) 2016-2023 Dan Cazarin (https://www.kfrlib.com) 6 This file is part of KFR 7 8 KFR is free software: you can redistribute it and/or modify 9 it under the terms of the GNU General Public License as published by 10 the Free Software Foundation, either version 2 of the License, or 11 (at your option) any later version. 12 13 KFR is distributed in the hope that it will be useful, 14 but WITHOUT ANY WARRANTY; without even the implied warranty of 15 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 16 GNU General Public License for more details. 17 18 You should have received a copy of the GNU General Public License 19 along with KFR. 20 21 If GPL is not suitable for your project, you must purchase a commercial license to use KFR. 22 Buying a commercial license is mandatory as soon as you develop commercial activities without 23 disclosing the source code of your own applications. 24 See https://www.kfrlib.com for details. 25 */ 26 #include <kfr/base/simd_expressions.hpp> 27 #include <kfr/dft/convolution.hpp> 28 #include <kfr/simd/complex.hpp> 29 #include <kfr/multiarch.h> 30 31 namespace kfr 32 { 33 34 template <typename T> 35 convolve_filter<T>::convolve_filter(size_t size_, size_t block_size_) 36 : data_size(size_), block_size(next_poweroftwo(block_size_)), fft(2 * block_size), temp(fft.temp_size), 37 segments((data_size + block_size - 1) / block_size), position(0), ir_segments(segments.size()), 38 saved_input(block_size), input_position(0), premul(fft.csize()), cscratch(fft.csize()), 39 scratch1(fft.size), scratch2(fft.size), overlap(block_size) 40 { 41 } 42 43 template <typename T> 44 convolve_filter<T>::convolve_filter(const univector_ref<const T>& data, size_t block_size_) 45 : convolve_filter(data.size(), block_size_) 46 { 47 set_data(data); 48 } 49 50 template <typename T> 51 void convolve_filter<T>::set_data(const univector_ref<const T>& data) 52 { 53 data_size = data.size(); 54 segments.resize((data_size + block_size - 1) / block_size); 55 ir_segments.resize(segments.size()); 56 univector<T> input(fft.size); 57 const ST ifftsize = reciprocal(static_cast<ST>(fft.size)); 58 for (size_t i = 0; i < ir_segments.size(); i++) 59 { 60 segments[i].resize(fft.csize()); 61 ir_segments[i].resize(fft.csize()); 62 input = padded(data.slice(i * block_size, block_size)); 63 64 fft.execute(ir_segments[i], input, temp); 65 process(ir_segments[i], ir_segments[i] * ifftsize); 66 } 67 reset(); 68 } 69 70 template <typename T> 71 void convolve_filter<T>::reset() 72 { 73 for (auto& segment : segments) 74 { 75 process(segment, zeros()); 76 } 77 position = 0; 78 process(saved_input, zeros()); 79 input_position = 0; 80 process(overlap, zeros()); 81 } 82 83 //------------------------------------------------------------------------------------- 84 85 CMT_MULTI_PROTO(namespace impl { 86 template <typename T> 87 univector<T> convolve(const univector_ref<const T>&, const univector_ref<const T>&, bool); 88 89 template <typename T> 90 class convolve_filter : public kfr::convolve_filter<T> 91 { 92 public: 93 void process_buffer_impl(T* output, const T* input, size_t size); 94 }; 95 }) 96 97 inline namespace CMT_ARCH_NAME 98 { 99 100 namespace impl 101 { 102 103 template <typename T> 104 univector<T> convolve(const univector_ref<const T>& src1, const univector_ref<const T>& src2, bool correlate) 105 { 106 using ST = subtype<T>; 107 const size_t size = next_poweroftwo(src1.size() + src2.size() - 1); 108 univector<complex<ST>> src1padded = src1; 109 univector<complex<ST>> src2padded; 110 if (correlate) 111 src2padded = reverse(src2); 112 else 113 src2padded = src2; 114 src1padded.resize(size); 115 src2padded.resize(size); 116 117 dft_plan_ptr<ST> dft = dft_cache::instance().get(ctype_t<ST>(), size); 118 univector<u8> temp(dft->temp_size); 119 dft->execute(src1padded, src1padded, temp); 120 dft->execute(src2padded, src2padded, temp); 121 src1padded = src1padded * src2padded; 122 dft->execute(src1padded, src1padded, temp, true); 123 const ST invsize = reciprocal<ST>(static_cast<ST>(size)); 124 return truncate(real(src1padded), src1.size() + src2.size() - 1) * invsize; 125 } 126 template univector<f32> convolve<f32>(const univector_ref<const f32>&, const univector_ref<const f32>&, bool); 127 template univector<f64> convolve<f64>(const univector_ref<const f64>&, const univector_ref<const f64>&, bool); 128 template univector<c32> convolve<c32>(const univector_ref<const c32>&, const univector_ref<const c32>&, bool); 129 template univector<c64> convolve<c64>(const univector_ref<const c64>&, const univector_ref<const c64>&, bool); 130 131 template <typename T> 132 void convolve_filter<T>::process_buffer_impl(T* output, const T* input, size_t size) 133 { 134 // Note that the conditionals in the following algorithm are meant to 135 // reduce complexity in the common cases of either processing complete 136 // blocks (processing == block_size) or only one segment. 137 138 // For complex filtering, use CCs pack format to omit special processing in fft_multiply[_accumulate]. 139 const dft_pack_format fft_multiply_pack = this->real_fft ? dft_pack_format::Perm : dft_pack_format::CCs; 140 141 size_t processed = 0; 142 while (processed < size) 143 { 144 // Calculate how many samples to process this iteration. 145 auto const processing = std::min(size - processed, this->block_size - this->input_position); 146 147 // Prepare input to forward FFT: 148 if (processing == this->block_size) 149 { 150 // No need to work with saved_input. 151 builtin_memcpy(this->scratch1.data(), input + processed, processing * sizeof(T)); 152 } 153 else 154 { 155 // Append this iteration's input to the saved_input current block. 156 builtin_memcpy(this->saved_input.data() + this->input_position, input + processed, 157 processing * sizeof(T)); 158 builtin_memcpy(this->scratch1.data(), this->saved_input.data(), this->block_size * sizeof(T)); 159 } 160 161 // Forward FFT saved_input block. 162 this->fft.execute(this->segments[this->position], this->scratch1, this->temp); 163 164 if (this->segments.size() == 1) 165 { 166 // Just one segment/block of history. 167 // Y_k = H * X_k 168 fft_multiply(this->cscratch, this->ir_segments[0], this->segments[0], fft_multiply_pack); 169 } 170 else 171 { 172 // More than one segment/block of history so this is more involved. 173 if (this->input_position == 0) 174 { 175 // At the start of an input block, we premultiply the history from 176 // previous input blocks with the extended filter blocks. 177 178 // Y_(k-i,i) = H_i * X_(k-i) 179 // premul += Y_(k-i,i) for i=1,...,N 180 181 fft_multiply(this->premul, this->ir_segments[1], 182 this->segments[(this->position + 1) % this->segments.size()], fft_multiply_pack); 183 for (size_t i = 2; i < this->segments.size(); i++) 184 { 185 const size_t n = (this->position + i) % this->segments.size(); 186 fft_multiply_accumulate(this->premul, this->ir_segments[i], this->segments[n], 187 fft_multiply_pack); 188 } 189 } 190 // Y_(k,0) = H_0 * X_k 191 // Y_k = premul + Y_(k,0) 192 fft_multiply_accumulate(this->cscratch, this->premul, this->ir_segments[0], 193 this->segments[this->position], fft_multiply_pack); 194 } 195 // y_k = IFFT( Y_k ) 196 this->fft.execute(this->scratch2, this->cscratch, this->temp, cinvert_t{}); 197 198 // z_k = y_k + overlap 199 process(make_univector(output + processed, processing), 200 this->scratch2.slice(this->input_position, processing) + 201 this->overlap.slice(this->input_position, processing)); 202 203 this->input_position += processing; 204 processed += processing; 205 206 // If a whole block was processed, prepare for next block. 207 if (this->input_position == this->block_size) 208 { 209 // Input block k is complete. Move to (k+1)-th input block. 210 this->input_position = 0; 211 212 // Zero out the saved_input if it will be used in the next iteration. 213 auto const remaining = size - processed; 214 if (remaining < this->block_size && remaining > 0) 215 { 216 process(this->saved_input, zeros()); 217 } 218 219 builtin_memcpy(this->overlap.data(), this->scratch2.data() + this->block_size, 220 this->block_size * sizeof(T)); 221 222 this->position = this->position > 0 ? this->position - 1 : this->segments.size() - 1; 223 } 224 } 225 } 226 227 template class convolve_filter<float>; 228 template class convolve_filter<double>; 229 template class convolve_filter<complex<float>>; 230 template class convolve_filter<complex<double>>; 231 232 } // namespace impl 233 234 } // namespace CMT_ARCH_NAME 235 236 #ifdef CMT_MULTI_NEEDS_GATE 237 namespace internal_generic 238 { 239 template <typename T> 240 univector<T> convolve(const univector_ref<const T>& src1, const univector_ref<const T>& src2, bool correlate) 241 { 242 CMT_MULTI_GATE(return ns::impl::convolve(src1, src2, correlate)); 243 } 244 245 template univector<f32> convolve<f32>(const univector_ref<const f32>&, const univector_ref<const f32>&, bool); 246 template univector<f64> convolve<f64>(const univector_ref<const f64>&, const univector_ref<const f64>&, bool); 247 template univector<c32> convolve<c32>(const univector_ref<const c32>&, const univector_ref<const c32>&, bool); 248 template univector<c64> convolve<c64>(const univector_ref<const c64>&, const univector_ref<const c64>&, bool); 249 250 } // namespace internal_generic 251 252 template <typename T> 253 void convolve_filter<T>::process_buffer(T* output, const T* input, size_t size) 254 { 255 CMT_MULTI_GATE( 256 reinterpret_cast<ns::impl::convolve_filter<T>*>(this)->process_buffer_impl(output, input, size)); 257 } 258 259 template class convolve_filter<float>; 260 template class convolve_filter<double>; 261 template class convolve_filter<complex<float>>; 262 template class convolve_filter<complex<double>>; 263 #endif 264 265 } // namespace kfr