kfr

Fast, modern C++ DSP framework, FFT, Sample Rate Conversion, FIR/IIR/Biquad Filters (SSE, AVX, AVX-512, ARM NEON)
Log | Files | Refs | README

convolution-impl.cpp (10187B)


      1 /** @addtogroup dft
      2  *  @{
      3  */
      4 /*
      5   Copyright (C) 2016-2023 Dan Cazarin (https://www.kfrlib.com)
      6   This file is part of KFR
      7 
      8   KFR is free software: you can redistribute it and/or modify
      9   it under the terms of the GNU General Public License as published by
     10   the Free Software Foundation, either version 2 of the License, or
     11   (at your option) any later version.
     12 
     13   KFR is distributed in the hope that it will be useful,
     14   but WITHOUT ANY WARRANTY; without even the implied warranty of
     15   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
     16   GNU General Public License for more details.
     17 
     18   You should have received a copy of the GNU General Public License
     19   along with KFR.
     20 
     21   If GPL is not suitable for your project, you must purchase a commercial license to use KFR.
     22   Buying a commercial license is mandatory as soon as you develop commercial activities without
     23   disclosing the source code of your own applications.
     24   See https://www.kfrlib.com for details.
     25  */
#include <kfr/base/simd_expressions.hpp>
#include <kfr/dft/convolution.hpp>
#include <kfr/simd/complex.hpp>
#include <kfr/multiarch.h>

#include <type_traits>
     30 
     31 namespace kfr
     32 {
     33 
     34 template <typename T>
     35 convolve_filter<T>::convolve_filter(size_t size_, size_t block_size_)
     36     : data_size(size_), block_size(next_poweroftwo(block_size_)), fft(2 * block_size), temp(fft.temp_size),
     37       segments((data_size + block_size - 1) / block_size), position(0), ir_segments(segments.size()),
     38       saved_input(block_size), input_position(0), premul(fft.csize()), cscratch(fft.csize()),
     39       scratch1(fft.size), scratch2(fft.size), overlap(block_size)
     40 {
     41 }
     42 
     43 template <typename T>
     44 convolve_filter<T>::convolve_filter(const univector_ref<const T>& data, size_t block_size_)
     45     : convolve_filter(data.size(), block_size_)
     46 {
     47     set_data(data);
     48 }
     49 
     50 template <typename T>
     51 void convolve_filter<T>::set_data(const univector_ref<const T>& data)
     52 {
     53     data_size = data.size();
     54     segments.resize((data_size + block_size - 1) / block_size);
     55     ir_segments.resize(segments.size());
     56     univector<T> input(fft.size);
     57     const ST ifftsize = reciprocal(static_cast<ST>(fft.size));
     58     for (size_t i = 0; i < ir_segments.size(); i++)
     59     {
     60         segments[i].resize(fft.csize());
     61         ir_segments[i].resize(fft.csize());
     62         input = padded(data.slice(i * block_size, block_size));
     63 
     64         fft.execute(ir_segments[i], input, temp);
     65         process(ir_segments[i], ir_segments[i] * ifftsize);
     66     }
     67     reset();
     68 }
     69 
     70 template <typename T>
     71 void convolve_filter<T>::reset()
     72 {
     73     for (auto& segment : segments)
     74     {
     75         process(segment, zeros());
     76     }
     77     position = 0;
     78     process(saved_input, zeros());
     79     input_position = 0;
     80     process(overlap, zeros());
     81 }
     82 
     83 //-------------------------------------------------------------------------------------
     84 
     85 CMT_MULTI_PROTO(namespace impl {
     86     template <typename T>
     87     univector<T> convolve(const univector_ref<const T>&, const univector_ref<const T>&, bool);
     88 
     89     template <typename T>
     90     class convolve_filter : public kfr::convolve_filter<T>
     91     {
     92     public:
     93         void process_buffer_impl(T* output, const T* input, size_t size);
     94     };
     95 })
     96 
     97 inline namespace CMT_ARCH_NAME
     98 {
     99 
    100 namespace impl
    101 {
    102 
    103 template <typename T>
    104 univector<T> convolve(const univector_ref<const T>& src1, const univector_ref<const T>& src2, bool correlate)
    105 {
    106     using ST                          = subtype<T>;
    107     const size_t size                 = next_poweroftwo(src1.size() + src2.size() - 1);
    108     univector<complex<ST>> src1padded = src1;
    109     univector<complex<ST>> src2padded;
    110     if (correlate)
    111         src2padded = reverse(src2);
    112     else
    113         src2padded = src2;
    114     src1padded.resize(size);
    115     src2padded.resize(size);
    116 
    117     dft_plan_ptr<ST> dft = dft_cache::instance().get(ctype_t<ST>(), size);
    118     univector<u8> temp(dft->temp_size);
    119     dft->execute(src1padded, src1padded, temp);
    120     dft->execute(src2padded, src2padded, temp);
    121     src1padded = src1padded * src2padded;
    122     dft->execute(src1padded, src1padded, temp, true);
    123     const ST invsize = reciprocal<ST>(static_cast<ST>(size));
    124     return truncate(real(src1padded), src1.size() + src2.size() - 1) * invsize;
    125 }
    126 template univector<f32> convolve<f32>(const univector_ref<const f32>&, const univector_ref<const f32>&, bool);
    127 template univector<f64> convolve<f64>(const univector_ref<const f64>&, const univector_ref<const f64>&, bool);
    128 template univector<c32> convolve<c32>(const univector_ref<const c32>&, const univector_ref<const c32>&, bool);
    129 template univector<c64> convolve<c64>(const univector_ref<const c64>&, const univector_ref<const c64>&, bool);
    130 
    131 template <typename T>
    132 void convolve_filter<T>::process_buffer_impl(T* output, const T* input, size_t size)
    133 {
    134     // Note that the conditionals in the following algorithm are meant to
    135     // reduce complexity in the common cases of either processing complete
    136     // blocks (processing == block_size) or only one segment.
    137 
    138     // For complex filtering, use CCs pack format to omit special processing in fft_multiply[_accumulate].
    139     const dft_pack_format fft_multiply_pack = this->real_fft ? dft_pack_format::Perm : dft_pack_format::CCs;
    140 
    141     size_t processed = 0;
    142     while (processed < size)
    143     {
    144         // Calculate how many samples to process this iteration.
    145         auto const processing = std::min(size - processed, this->block_size - this->input_position);
    146 
    147         // Prepare input to forward FFT:
    148         if (processing == this->block_size)
    149         {
    150             // No need to work with saved_input.
    151             builtin_memcpy(this->scratch1.data(), input + processed, processing * sizeof(T));
    152         }
    153         else
    154         {
    155             // Append this iteration's input to the saved_input current block.
    156             builtin_memcpy(this->saved_input.data() + this->input_position, input + processed,
    157                            processing * sizeof(T));
    158             builtin_memcpy(this->scratch1.data(), this->saved_input.data(), this->block_size * sizeof(T));
    159         }
    160 
    161         // Forward FFT saved_input block.
    162         this->fft.execute(this->segments[this->position], this->scratch1, this->temp);
    163 
    164         if (this->segments.size() == 1)
    165         {
    166             // Just one segment/block of history.
    167             // Y_k = H * X_k
    168             fft_multiply(this->cscratch, this->ir_segments[0], this->segments[0], fft_multiply_pack);
    169         }
    170         else
    171         {
    172             // More than one segment/block of history so this is more involved.
    173             if (this->input_position == 0)
    174             {
    175                 // At the start of an input block, we premultiply the history from
    176                 // previous input blocks with the extended filter blocks.
    177 
    178                 // Y_(k-i,i) = H_i * X_(k-i)
    179                 // premul += Y_(k-i,i) for i=1,...,N
    180 
    181                 fft_multiply(this->premul, this->ir_segments[1],
    182                              this->segments[(this->position + 1) % this->segments.size()], fft_multiply_pack);
    183                 for (size_t i = 2; i < this->segments.size(); i++)
    184                 {
    185                     const size_t n = (this->position + i) % this->segments.size();
    186                     fft_multiply_accumulate(this->premul, this->ir_segments[i], this->segments[n],
    187                                             fft_multiply_pack);
    188                 }
    189             }
    190             // Y_(k,0) = H_0 * X_k
    191             // Y_k = premul + Y_(k,0)
    192             fft_multiply_accumulate(this->cscratch, this->premul, this->ir_segments[0],
    193                                     this->segments[this->position], fft_multiply_pack);
    194         }
    195         // y_k = IFFT( Y_k )
    196         this->fft.execute(this->scratch2, this->cscratch, this->temp, cinvert_t{});
    197 
    198         // z_k = y_k + overlap
    199         process(make_univector(output + processed, processing),
    200                 this->scratch2.slice(this->input_position, processing) +
    201                     this->overlap.slice(this->input_position, processing));
    202 
    203         this->input_position += processing;
    204         processed += processing;
    205 
    206         // If a whole block was processed, prepare for next block.
    207         if (this->input_position == this->block_size)
    208         {
    209             // Input block k is complete. Move to (k+1)-th input block.
    210             this->input_position = 0;
    211 
    212             // Zero out the saved_input if it will be used in the next iteration.
    213             auto const remaining = size - processed;
    214             if (remaining < this->block_size && remaining > 0)
    215             {
    216                 process(this->saved_input, zeros());
    217             }
    218 
    219             builtin_memcpy(this->overlap.data(), this->scratch2.data() + this->block_size,
    220                            this->block_size * sizeof(T));
    221 
    222             this->position = this->position > 0 ? this->position - 1 : this->segments.size() - 1;
    223         }
    224     }
    225 }
    226 
    227 template class convolve_filter<float>;
    228 template class convolve_filter<double>;
    229 template class convolve_filter<complex<float>>;
    230 template class convolve_filter<complex<double>>;
    231 
    232 } // namespace impl
    233 
    234 } // namespace CMT_ARCH_NAME
    235 
    236 #ifdef CMT_MULTI_NEEDS_GATE
    237 namespace internal_generic
    238 {
    239 template <typename T>
    240 univector<T> convolve(const univector_ref<const T>& src1, const univector_ref<const T>& src2, bool correlate)
    241 {
    242     CMT_MULTI_GATE(return ns::impl::convolve(src1, src2, correlate));
    243 }
    244 
    245 template univector<f32> convolve<f32>(const univector_ref<const f32>&, const univector_ref<const f32>&, bool);
    246 template univector<f64> convolve<f64>(const univector_ref<const f64>&, const univector_ref<const f64>&, bool);
    247 template univector<c32> convolve<c32>(const univector_ref<const c32>&, const univector_ref<const c32>&, bool);
    248 template univector<c64> convolve<c64>(const univector_ref<const c64>&, const univector_ref<const c64>&, bool);
    249 
    250 } // namespace internal_generic
    251 
    252 template <typename T>
    253 void convolve_filter<T>::process_buffer(T* output, const T* input, size_t size)
    254 {
    255     CMT_MULTI_GATE(
    256         reinterpret_cast<ns::impl::convolve_filter<T>*>(this)->process_buffer_impl(output, input, size));
    257 }
    258 
    259 template class convolve_filter<float>;
    260 template class convolve_filter<double>;
    261 template class convolve_filter<complex<float>>;
    262 template class convolve_filter<complex<double>>;
    263 #endif
    264 
    265 } // namespace kfr