kfr

Fast, modern C++ DSP framework, FFT, Sample Rate Conversion, FIR/IIR/Biquad Filters (SSE, AVX, AVX-512, ARM NEON)
Log | Files | Refs | README

commit f3bc9bb02773324731c59ea811b8ca4c616b0212
parent 8a969d06cab0633552d6014a65ee5aa0abb0640c
Author: Stephen Larew <stephen@slarew.net>
Date:   Thu, 20 Feb 2020 12:28:41 -0800

add complex support to convolve_filter

Diffstat:
Mdocs/docs/convolution.md | 45+++++++++++++++++++++++++++++++++++++++++++++
Mexamples/CMakeLists.txt | 2++
Aexamples/ccv.cpp | 71+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Minclude/kfr/dft/convolution.hpp | 38+++++++++++++++++++++++++++++---------
Minclude/kfr/dft/impl/convolution-impl.cpp | 222++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-----------------
Mtests/dft_test.cpp | 36+++++++++++++++++++++++++++++++++++-
6 files changed, 356 insertions(+), 58 deletions(-)

diff --git a/docs/docs/convolution.md b/docs/docs/convolution.md @@ -68,3 +68,48 @@ reverb_RR.apply(tmp4, audio[1]); audio[0] = tmp1 + tmp2; audio[1] = tmp3 + tmp4; ``` + +# Implementation Details + +The convolution filter efficiently computes the convolution of two signals. +The efficiency is achieved by employing the FFT and the circular convolution +theorem. The algorithm is a variant of the [overlap-add +method](https://en.wikipedia.org/wiki/Overlap%E2%80%93add_method). It works on +a fixed block size \(B\) for arbitrarily long input signals. Thus, the +convolution of a streaming input signal with a long FIR filter \(h[n]\) (where +the length of \(h[n]\) may exceed the block size \(B\)) is computed with a +fixed complexity \(O(B \log B)\). + +More formally, the convolution filter computes \(y[n] = (x * h)[n]\) by +partitioning the input \(x\) and filter \(h\) into blocks and applies the +overlap-add method. Let \(x[n]\) be an input signal of arbitrary length. Often, +\(x[n]\) is a streaming input with unknown length. Let \(h[n]\) be an FIR +filter with \(M\) taps. The convolution filter works on a fixed block size +\(B=2^b\). + +First, the input and filter are windowed and shifted to the origin to give the +\(k\)-th block input \(x_k[n] = x[n + kB] , n=\{0,1,\ldots,B-1\},\forall +k\in\mathbb{Z}\) and \(j\)-th block filter \(h_j[n] = h[n + jB] , +n=\{0,1,\ldots,B-1\},j=\{0,1,\ldots,\lfloor M/B \rfloor\}\). The convolution +\(y_{k,j}[n] = (x_k * h_j)[n]\) is efficiently computed with length \(2B\) FFTs +as +\[ +y_{k,j}[n] = \mathrm{IFFT}(\mathrm{FFT}(x_k[n])\cdot\mathrm{FFT}(h_j[n])) +. +\] + +The overlap-add method sums the "overlap" from the previous block with the current block. +To complete the \(k\)-th block, the contribution of all blocks of the filter +are summed together to give +\[ y_{k}[n] = \sum_j y_{k-j,j}[n] . \] +The final convolution is then the sum of the shifted blocks +\[ y[n] = \sum_k y_{k}[n - kB] . \] +Note that \(y_k[n]\) is of length \(2B\) so its second half overlaps and adds +into the first half of the \(y_{k+1}[n]\) block. + +## Maximum efficiency criterion + +To avoid excess computation or maximize throughput, the convolution filter +should be given input samples in multiples of the block size \(B\). Otherwise, +the FFT of a block is computed twice as many times as would be necessary and +hence throughput is reduced. diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt @@ -50,4 +50,6 @@ if (ENABLE_DFT) target_link_libraries(dft kfr_multidft) target_compile_definitions(dft PRIVATE -DKFR_DFT_MULTI=1) endif () + add_executable(ccv ccv.cpp) + target_link_libraries(ccv kfr kfr_dft use_arch) endif () diff --git a/examples/ccv.cpp b/examples/ccv.cpp @@ -0,0 +1,71 @@ +/* + * ccv, part of KFR (https://www.kfr.dev) + * Copyright (C) 2019 D Levin + * See LICENSE.txt for details + */ + +// Complex convolution filter examples + +#define CMT_BASETYPE_F32 + +#include <chrono> +#include <kfr/base.hpp> +#include <kfr/dft.hpp> +#include <kfr/dsp.hpp> + +using namespace kfr; + +int main() +{ + println(library_version()); + + // low-pass filter + univector<fbase, 1023> taps127; + expression_pointer<fbase> kaiser = to_pointer(window_kaiser(taps127.size(), 3.0)); + fir_lowpass(taps127, 0.2, kaiser, true); + + // Create filters. + size_t const block_size = 256; + convolve_filter<complex<fbase>> conv_filter_complex(univector<complex<fbase>>(make_complex(taps127, zeros())), + block_size); + convolve_filter<fbase> conv_filter_real(taps127, block_size); + + // Create noise to filter. + auto const size = 1024 * 100 + 33; // not a multiple of block_size + univector<complex<fbase>> cnoise = + make_complex(truncate(gen_random_range(random_bit_generator{ 1, 2, 3, 4 }, -1.f, +1.f), size), + truncate(gen_random_range(random_bit_generator{ 3, 4, 9, 8 }, -1.f, +1.f), size)); + univector<fbase> noise = + truncate(gen_random_range(random_bit_generator{ 3, 4, 9, 8 }, -1.f, +1.f), size); + + // Filter results. + univector<complex<fbase>> filtered_cnoise_ccv(size), filtered_cnoise_fir(size); + univector<fbase> filtered_noise_ccv(size), filtered_noise_fir(size); + + // Complex filtering (time and compare). + auto tic = std::chrono::high_resolution_clock::now(); + conv_filter_complex.apply(filtered_cnoise_ccv, cnoise); + auto toc = std::chrono::high_resolution_clock::now(); + auto const ccv_time_complex = std::chrono::duration_cast<std::chrono::duration<float>>(toc - tic); + tic = toc; + filtered_cnoise_fir = kfr::fir(cnoise, taps127); + toc = std::chrono::high_resolution_clock::now(); + auto const fir_time_complex = std::chrono::duration_cast<std::chrono::duration<float>>(toc - tic); + auto const cdiff = rms(cabs(filtered_cnoise_fir - filtered_cnoise_ccv)); + + // Real filtering (time and compare). + tic = std::chrono::high_resolution_clock::now(); + conv_filter_real.apply(filtered_noise_ccv, noise); + toc = std::chrono::high_resolution_clock::now(); + auto const ccv_time_real = std::chrono::duration_cast<std::chrono::duration<float>>(toc - tic); + tic = toc; + filtered_noise_fir = kfr::fir(noise, taps127); + toc = std::chrono::high_resolution_clock::now(); + auto const fir_time_real = std::chrono::duration_cast<std::chrono::duration<float>>(toc - tic); + auto const diff = rms(filtered_noise_fir - filtered_noise_ccv); + + println("complex: convolution_filter ", ccv_time_complex.count(), " fir ", fir_time_complex.count(), " diff=", cdiff); + println("real: convolution_filter ", ccv_time_real.count(), " fir ", fir_time_real.count(), " diff=", diff); + + return 0; +} diff --git a/include/kfr/dft/convolution.hpp b/include/kfr/dft/convolution.hpp @@ -84,6 +84,9 @@ public: explicit convolve_filter(size_t size, size_t block_size = 1024); explicit convolve_filter(const univector_ref<const T>& data, size_t block_size = 1024); void set_data(const univector_ref<const T>& data); + void reset() final; + /// Apply filter to multiples of returned block size for optimal processing efficiency. + size_t input_block_size() const { return block_size; } protected: void process_expression(T* dest, const expression_pointer<T>& src, size_t size) final @@ -93,19 +96,36 @@ protected: } void process_buffer(T* output, const T* input, size_t size) final; - const size_t size; + using ST = subtype<T>; + static constexpr auto real_fft = !std::is_same<T, complex<ST>>::value; + using plan_t = std::conditional_t<real_fft, dft_plan_real<T>, dft_plan<ST>>; + + // Length of filter data. + size_t data_size; + // Size of block to process. const size_t block_size; - const dft_plan_real<T> fft; + // FFT plan for circular convolution. + const plan_t fft; + // Temp storage for FFT. univector<u8> temp; - std::vector<univector<complex<T>>> segments; - std::vector<univector<complex<T>>> ir_segments; - size_t input_position; + // History of input segments after fwd DFT. History is circular relative to position below. + std::vector<univector<complex<ST>>> segments; + // Index into segments of current block. + size_t position; + // Blocks of filter/data after fwd DFT. + std::vector<univector<complex<ST>>> ir_segments; + // Saved input for current block. univector<T> saved_input; - univector<complex<T>> premul; - univector<complex<T>> cscratch; - univector<T> scratch; + // Index into saved_input for next input to begin. + size_t input_position; + // Pre-multiplied products of input history and delayed filter blocks. + univector<complex<ST>> premul; + // Scratch buffer for product of filter and input for processing by reverse DFT. + univector<complex<ST>> cscratch; + // Scratch buffers for input and output of fwd and rev DFTs. + univector<T> scratch1, scratch2; + // Overlap saved from previous block to add into current block. univector<T> overlap; - size_t position; }; } // namespace CMT_ARCH_NAME diff --git a/include/kfr/dft/impl/convolution-impl.cpp b/include/kfr/dft/impl/convolution-impl.cpp @@ -36,37 +36,39 @@ namespace intrinsics template <typename T> univector<T> convolve(const univector_ref<const T>& src1, const univector_ref<const T>& src2) { - const size_t size = next_poweroftwo(src1.size() + src2.size() - 1); - univector<complex<T>> src1padded = src1; - univector<complex<T>> src2padded = src2; - src1padded.resize(size, 0); - src2padded.resize(size, 0); - - dft_plan_ptr<T> dft = dft_cache::instance().get(ctype_t<T>(), size); + using ST = subtype<T>; + const size_t size = next_poweroftwo(src1.size() + src2.size() - 1); + univector<complex<ST>> src1padded = src1; + univector<complex<ST>> src2padded = src2; + src1padded.resize(size); + src2padded.resize(size); + + dft_plan_ptr<ST> dft = dft_cache::instance().get(ctype_t<ST>(), size); univector<u8> temp(dft->temp_size); dft->execute(src1padded, src1padded, temp); dft->execute(src2padded, src2padded, temp); src1padded = src1padded * src2padded; dft->execute(src1padded, src1padded, temp, true); - const T invsize = reciprocal<T>(size); + const ST invsize = reciprocal<ST>(size); return truncate(real(src1padded), src1.size() + src2.size() - 1) * invsize; } template <typename T> univector<T> correlate(const univector_ref<const T>& src1, const univector_ref<const T>& src2) { - const size_t size = next_poweroftwo(src1.size() + src2.size() - 1); - univector<complex<T>> src1padded = src1; - univector<complex<T>> src2padded = reverse(src2); - src1padded.resize(size, 0); - src2padded.resize(size, 0); - dft_plan_ptr<T> dft = dft_cache::instance().get(ctype_t<T>(), size); + using ST = subtype<T>; + const size_t size = next_poweroftwo(src1.size() + src2.size() - 1); + univector<complex<ST>> src1padded = src1; + univector<complex<ST>> src2padded = reverse(src2); + src1padded.resize(size); + src2padded.resize(size); + dft_plan_ptr<ST> dft = dft_cache::instance().get(ctype_t<ST>(), size); univector<u8> temp(dft->temp_size); dft->execute(src1padded, src1padded, temp); dft->execute(src2padded, src2padded, temp); src1padded = src1padded * src2padded; dft->execute(src1padded, src1padded, temp, true); - const T invsize = reciprocal<T>(size); + const ST invsize = reciprocal<ST>(size); return truncate(real(src1padded), src1.size() + src2.size() - 1) * invsize; } @@ -80,21 +82,55 @@ univector<T> autocorrelate(const univector_ref<const T>& src1) } // namespace intrinsics +// Create a helper template struct to handle the differences between real and complex FFT. +template <typename T, typename ST = subtype<T>, + typename plan_t = + std::conditional_t<std::is_same<T, complex<ST>>::value, dft_plan<ST>, dft_plan_real<T>>> +struct convolve_filter_fft +{ + static plan_t make(size_t size); + static inline void ifft(plan_t const& plan, univector<T>& out, const univector<complex<T>>& in, + univector<u8>& temp); + static size_t csize(plan_t const& plan); +}; +// Partial template specializations for complex and real cases: +template <typename ST> +struct convolve_filter_fft<complex<ST>, ST, dft_plan<ST>> +{ + static dft_plan<ST> make(size_t size) { return dft_plan<ST>(size); } + static inline void ifft(dft_plan<ST> const& plan, univector<complex<ST>>& out, + const univector<complex<ST>>& in, univector<u8>& temp) + { + plan.execute(out, in, temp, ctrue); + } + static size_t csize(dft_plan<ST> const& plan) { return plan.size; } +}; template <typename T> -convolve_filter<T>::convolve_filter(size_t size, size_t block_size) - : size(size), block_size(block_size), fft(2 * next_poweroftwo(block_size), dft_pack_format::Perm), - temp(fft.temp_size), segments((size + block_size - 1) / block_size) +struct convolve_filter_fft<T, T, dft_plan_real<T>> +{ + static dft_plan_real<T> make(size_t size) { return dft_plan_real<T>(size, dft_pack_format::Perm); } + static inline void ifft(dft_plan_real<T> const& plan, univector<T>& out, const univector<complex<T>>& in, + univector<u8>& temp) + { + plan.execute(out, in, temp); + } + static size_t csize(dft_plan_real<T> const& plan) { return plan.size / 2; } +}; +template <typename T> +convolve_filter<T>::convolve_filter(size_t size_, size_t block_size_) + : data_size(size_), block_size(next_poweroftwo(block_size_)), + fft(convolve_filter_fft<T>::make(2 * block_size)), temp(fft.temp_size), + segments((data_size + block_size - 1) / block_size), ir_segments(segments.size()), input_position(0), + saved_input(block_size), premul(convolve_filter_fft<T>::csize(fft)), + cscratch(convolve_filter_fft<T>::csize(fft)), scratch1(fft.size), scratch2(fft.size), + overlap(block_size), position(0) { } template <typename T> -convolve_filter<T>::convolve_filter(const univector_ref<const T>& data, size_t block_size) - : size(data.size()), block_size(next_poweroftwo(block_size)), - fft(2 * next_poweroftwo(block_size), dft_pack_format::Perm), temp(fft.temp_size), - segments((data.size() + next_poweroftwo(block_size) - 1) / next_poweroftwo(block_size)), - ir_segments((data.size() + next_poweroftwo(block_size) - 1) / next_poweroftwo(block_size)), - input_position(0), position(0) +convolve_filter<T>::convolve_filter(const univector_ref<const T>& data, size_t block_size_) + : convolve_filter(data.size(), block_size_) { set_data(data); } @@ -102,65 +138,125 @@ convolve_filter<T>::convolve_filter(const univector_ref<const T>& data, size_t b template <typename T> void convolve_filter<T>::set_data(const univector_ref<const T>& data) { + data_size = data.size(); + segments.resize((data_size + block_size - 1) / block_size); + ir_segments.resize(segments.size()); univector<T> input(fft.size); - const T ifftsize = reciprocal(T(fft.size)); + const ST ifftsize = reciprocal(ST(fft.size)); for (size_t i = 0; i < ir_segments.size(); i++) { - segments[i].resize(block_size); - ir_segments[i].resize(block_size, 0); + segments[i].resize(convolve_filter_fft<T>::csize(fft)); + ir_segments[i].resize(convolve_filter_fft<T>::csize(fft)); input = padded(data.slice(i * block_size, block_size)); fft.execute(ir_segments[i], input, temp); process(ir_segments[i], ir_segments[i] * ifftsize); } - saved_input.resize(block_size, 0); - scratch.resize(block_size * 2); - premul.resize(block_size, 0); - cscratch.resize(block_size); - overlap.resize(block_size, 0); + reset(); } template <typename T> void convolve_filter<T>::process_buffer(T* output, const T* input, size_t size) { + // Note that the conditionals in the following algorithm are meant to + // reduce complexity in the common cases of either processing complete + // blocks (processing == block_size) or only one segment. + + // For complex filtering, use CCs pack format to omit special processing in fft_multiply[_accumulate]. + static constexpr auto fft_multiply_pack = real_fft ? dft_pack_format::Perm : dft_pack_format::CCs; + size_t processed = 0; while (processed < size) { - const size_t processing = std::min(size - processed, block_size - input_position); - builtin_memcpy(saved_input.data() + input_position, input + processed, processing * sizeof(T)); + // Calculate how many samples to process this iteration. + auto const processing = std::min(size - processed, block_size - input_position); + + // Prepare input to forward FFT: + if (processing == block_size) + { + // No need to work with saved_input. + builtin_memcpy(scratch1.data(), input + processed, processing * sizeof(T)); + } + else + { + // Append this iteration's input to the saved_input current block. + builtin_memcpy(saved_input.data() + input_position, input + processed, processing * sizeof(T)); + builtin_memcpy(scratch1.data(), saved_input.data(), block_size * sizeof(T)); + } - process(scratch, padded(saved_input)); - fft.execute(segments[position], scratch, temp); + // Forward FFT saved_input block. + fft.execute(segments[position], scratch1, temp); - if (input_position == 0) + if (segments.size() == 1) { - process(premul, zeros()); - for (size_t i = 1; i < segments.size(); i++) + // Just one segment/block of history. + // Y_k = H * X_k + fft_multiply(cscratch, ir_segments[0], segments[0], fft_multiply_pack); + } + else + { + // More than one segment/block of history so this is more involved. + if (input_position == 0) { - const size_t n = (position + i) % segments.size(); - fft_multiply_accumulate(premul, ir_segments[i], segments[n], dft_pack_format::Perm); + // At the start of an input block, we premultiply the history from + // previous input blocks with the extended filter blocks. + + // Y_(k-i,i) = H_i * X_(k-i) + // premul += Y_(k-i,i) for i=1,...,N + + fft_multiply(premul, ir_segments[1], segments[(position + 1) % segments.size()], + fft_multiply_pack); + for (size_t i = 2; i < segments.size(); i++) + { + const size_t n = (position + i) % segments.size(); + fft_multiply_accumulate(premul, ir_segments[i], segments[n], fft_multiply_pack); + } } + // Y_(k,0) = H_0 * X_k + // Y_k = premul + Y_(k,0) + fft_multiply_accumulate(cscratch, premul, ir_segments[0], segments[position], fft_multiply_pack); } - fft_multiply_accumulate(cscratch, premul, ir_segments[0], segments[position], dft_pack_format::Perm); - - fft.execute(scratch, cscratch, temp); + // y_k = IFFT( Y_k ) + convolve_filter_fft<T>::ifft(fft, scratch2, cscratch, temp); + // z_k = y_k + overlap process(make_univector(output + processed, processing), - scratch.slice(input_position) + overlap.slice(input_position)); + scratch2.slice(input_position) + overlap.slice(input_position)); input_position += processing; + processed += processing; + + // If a whole block was processed, prepare for next block. if (input_position == block_size) { + // Input block k is complete. Move to (k+1)-th input block. input_position = 0; - process(saved_input, zeros()); - builtin_memcpy(overlap.data(), scratch.data() + block_size, block_size * sizeof(T)); + // Zero out the saved_input if it will be used in the next iteration. + auto const remaining = size - processed; + if (remaining < block_size && remaining > 0) + { + process(saved_input, zeros()); + } + + builtin_memcpy(overlap.data(), scratch2.data() + block_size, block_size * sizeof(T)); position = position > 0 ? position - 1 : segments.size() - 1; } + } +} - processed += processing; +template <typename T> +void convolve_filter<T>::reset() +{ + for (auto& segment : segments) + { + process(segment, zeros()); } + position = 0; + process(saved_input, zeros()); + input_position = 0; + process(overlap, zeros()); } namespace intrinsics @@ -168,40 +264,68 @@ namespace intrinsics template univector<float> convolve<float>(const univector_ref<const float>&, const univector_ref<const float>&); +template univector<complex<float>> convolve<complex<float>>(const univector_ref<const complex<float>>&, + const univector_ref<const complex<float>>&); template univector<float> correlate<float>(const univector_ref<const float>&, const univector_ref<const float>&); +template univector<complex<float>> correlate<complex<float>>(const univector_ref<const complex<float>>&, + const univector_ref<const complex<float>>&); template univector<float> autocorrelate<float>(const univector_ref<const float>&); +template univector<complex<float>> autocorrelate<complex<float>>(const univector_ref<const complex<float>>&); } // namespace intrinsics template convolve_filter<float>::convolve_filter(size_t, size_t); +template convolve_filter<complex<float>>::convolve_filter(size_t, size_t); template convolve_filter<float>::convolve_filter(const univector_ref<const float>&, size_t); +template convolve_filter<complex<float>>::convolve_filter(const univector_ref<const complex<float>>&, size_t); template void convolve_filter<float>::set_data(const univector_ref<const float>&); +template void convolve_filter<complex<float>>::set_data(const univector_ref<const complex<float>>&); template void convolve_filter<float>::process_buffer(float* output, const float* input, size_t size); +template void convolve_filter<complex<float>>::process_buffer(complex<float>* output, + const complex<float>* input, size_t size); + +template void convolve_filter<float>::reset(); +template void convolve_filter<complex<float>>::reset(); namespace intrinsics { template univector<double> convolve<double>(const univector_ref<const double>&, const univector_ref<const double>&); +template univector<complex<double>> convolve<complex<double>>(const univector_ref<const complex<double>>&, + const univector_ref<const complex<double>>&); template univector<double> correlate<double>(const univector_ref<const double>&, const univector_ref<const double>&); +template univector<complex<double>> correlate<complex<double>>(const univector_ref<const complex<double>>&, + const univector_ref<const complex<double>>&); template univector<double> autocorrelate<double>(const univector_ref<const double>&); +template univector<complex<double>> autocorrelate<complex<double>>( + const univector_ref<const complex<double>>&); } // namespace intrinsics template convolve_filter<double>::convolve_filter(size_t, size_t); +template convolve_filter<complex<double>>::convolve_filter(size_t, size_t); template convolve_filter<double>::convolve_filter(const univector_ref<const double>&, size_t); +template convolve_filter<complex<double>>::convolve_filter(const univector_ref<const complex<double>>&, + size_t); template void convolve_filter<double>::set_data(const univector_ref<const double>&); +template void convolve_filter<complex<double>>::set_data(const univector_ref<const complex<double>>&); template void convolve_filter<double>::process_buffer(double* output, const double* input, size_t size); +template void convolve_filter<complex<double>>::process_buffer(complex<double>* output, + const complex<double>* input, size_t size); + +template void convolve_filter<double>::reset(); +template void convolve_filter<complex<double>>::reset(); template <typename T> filter<T>* make_convolve_filter(const univector_ref<const T>& taps, size_t block_size) @@ -210,7 +334,9 @@ filter<T>* make_convolve_filter(const univector_ref<const T>& taps, size_t block } template filter<float>* make_convolve_filter(const univector_ref<const float>&, size_t); +template filter<complex<float>>* make_convolve_filter(const univector_ref<const complex<float>>&, size_t); template filter<double>* make_convolve_filter(const univector_ref<const double>&, size_t); +template filter<complex<double>>* make_convolve_filter(const univector_ref<const complex<double>>&, size_t); } // namespace CMT_ARCH_NAME } // namespace kfr diff --git a/tests/dft_test.cpp b/tests/dft_test.cpp @@ -32,7 +32,17 @@ TEST(test_convolve) CHECK(rms(c - univector<fbase>({ 0.25, 1., 2.75, 2.5, 3.75, 3.5, 1.5, -4., 7.5 })) < 0.0001); } -TEST(test_fft_convolve) +TEST(test_complex_convolve) +{ + univector<complex<fbase>, 5> a({ 1, 2, 3, 4, 5 }); + univector<complex<fbase>, 5> b({ 0.25, 0.5, 1.0, -2.0, 1.5 }); + univector<complex<fbase>> c = convolve(a, b); + CHECK(c.size() == 9u); + CHECK(rms(cabs(c - univector<complex<fbase>>({ 0.25, 1., 2.75, 2.5, 3.75, 3.5, 1.5, -4., 7.5 }))) < + 0.0001); +} + +TEST(test_convolve_filter) { univector<fbase, 5> a({ 1, 2, 3, 4, 5 }); univector<fbase, 5> b({ 0.25, 0.5, 1.0, -2.0, 1.5 }); @@ -42,6 +52,21 @@ TEST(test_fft_convolve) CHECK(rms(dest - univector<fbase>({ 0.25, 1., 2.75, 2.5, 3.75 })) < 0.0001); } +TEST(test_complex_convolve_filter) +{ + univector<complex<fbase>, 5> a({ 1, 2, 3, 4, 5 }); + univector<complex<fbase>, 5> b({ 0.25, 0.5, 1.0, -2.0, 1.5 }); + univector<complex<fbase>, 5> dest; + convolve_filter<complex<fbase>> filter(a); + filter.apply(dest, b); + CHECK(rms(cabs(dest - univector<complex<fbase>>({ 0.25, 1., 2.75, 2.5, 3.75 }))) < 0.0001); + filter.apply(dest, b); + CHECK(rms(cabs(dest - univector<complex<fbase>>({ 0.25, 1., 2.75, 2.5, 3.75 }))) > 0.0001); + filter.reset(); + filter.apply(dest, b); + CHECK(rms(cabs(dest - univector<complex<fbase>>({ 0.25, 1., 2.75, 2.5, 3.75 }))) < 0.0001); +} + TEST(test_correlate) { univector<fbase, 5> a({ 1, 2, 3, 4, 5 }); @@ -51,6 +76,15 @@ TEST(test_correlate) CHECK(rms(c - univector<fbase>({ 1.5, 1., 1.5, 2.5, 3.75, -4., 7.75, 3.5, 1.25 })) < 0.0001); } +TEST(test_complex_correlate) +{ + univector<complex<fbase>, 5> a({ 1, 2, 3, 4, 5 }); + univector<complex<fbase>, 5> b({ 0.25, 0.5, 1.0, -2.0, 1.5 }); + univector<complex<fbase>> c = correlate(a, b); + CHECK(c.size() == 9u); + CHECK(rms(cabs(c - univector<fbase>({ 1.5, 1., 1.5, 2.5, 3.75, -4., 7.75, 3.5, 1.25 }))) < 0.0001); +} + #if defined CMT_ARCH_ARM || !defined NDEBUG constexpr size_t fft_stopsize = 12; constexpr size_t dft_stopsize = 101;