commit f3bc9bb02773324731c59ea811b8ca4c616b0212
parent 8a969d06cab0633552d6014a65ee5aa0abb0640c
Author: Stephen Larew <stephen@slarew.net>
Date: Thu, 20 Feb 2020 12:28:41 -0800
add complex support to convolve_filter
Diffstat:
6 files changed, 356 insertions(+), 58 deletions(-)
diff --git a/docs/docs/convolution.md b/docs/docs/convolution.md
@@ -68,3 +68,48 @@ reverb_RR.apply(tmp4, audio[1]);
audio[0] = tmp1 + tmp2;
audio[1] = tmp3 + tmp4;
```
+
+# Implementation Details
+
+The convolution filter efficiently computes the convolution of two signals.
+The efficiency is achieved by employing the FFT and the circular convolution
+theorem. The algorithm is a variant of the [overlap-add
+method](https://en.wikipedia.org/wiki/Overlap%E2%80%93add_method). It works on
+a fixed block size \(B\) for arbitrarily long input signals. Thus, the
+convolution of a streaming input signal with a long FIR filter \(h[n]\) (where
+the length of \(h[n]\) may exceed the block size \(B\)) is computed with a
+fixed complexity \(O(B \log B)\).
+
+More formally, the convolution filter computes \(y[n] = (x * h)[n]\) by
+partitioning the input \(x\) and filter \(h\) into blocks and applies the
+overlap-add method. Let \(x[n]\) be an input signal of arbitrary length. Often,
+\(x[n]\) is a streaming input with unknown length. Let \(h[n]\) be an FIR
+filter with \(M\) taps. The convolution filter works on a fixed block size
+\(B=2^b\).
+
+First, the input and filter are windowed and shifted to the origin to give the
+\(k\)-th block input \(x_k[n] = x[n + kB] , n=\{0,1,\ldots,B-1\},\forall
+k\in\mathbb{Z}\) and \(j\)-th block filter \(h_j[n] = h[n + jB] ,
+n=\{0,1,\ldots,B-1\},j=\{0,1,\ldots,\lfloor M/B \rfloor\}\). The convolution
+\(y_{k,j}[n] = (x_k * h_j)[n]\) is efficiently computed with length \(2B\) FFTs
+as
+\[
+y_{k,j}[n] = \mathrm{IFFT}(\mathrm{FFT}(x_k[n])\cdot\mathrm{FFT}(h_j[n]))
+.
+\]
+
+The overlap-add method sums the "overlap" from the previous block with the current block.
+To complete the \(k\)-th block, the contribution of all blocks of the filter
+are summed together to give
+\[ y_{k}[n] = \sum_j y_{k-j,j}[n] . \]
+The final convolution is then the sum of the shifted blocks
+\[ y[n] = \sum_k y_{k}[n - kB] . \]
+Note that \(y_k[n]\) is of length \(2B\) so its second half overlaps and adds
+into the first half of the \(y_{k+1}[n]\) block.
+
+## Maximum efficiency criterion
+
+To avoid excess computation or maximize throughput, the convolution filter
+should be given input samples in multiples of the block size \(B\). Otherwise,
+the FFT of a block is computed twice as many times as would be necessary and
+hence throughput is reduced.
diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt
@@ -50,4 +50,6 @@ if (ENABLE_DFT)
target_link_libraries(dft kfr_multidft)
target_compile_definitions(dft PRIVATE -DKFR_DFT_MULTI=1)
endif ()
+ add_executable(ccv ccv.cpp)
+ target_link_libraries(ccv kfr kfr_dft use_arch)
endif ()
diff --git a/examples/ccv.cpp b/examples/ccv.cpp
@@ -0,0 +1,71 @@
+/*
+ * ccv, part of KFR (https://www.kfr.dev)
+ * Copyright (C) 2019 D Levin
+ * See LICENSE.txt for details
+ */
+
+// Complex convolution filter examples
+
+#define CMT_BASETYPE_F32
+
+#include <chrono>
+#include <kfr/base.hpp>
+#include <kfr/dft.hpp>
+#include <kfr/dsp.hpp>
+
+using namespace kfr;
+
+int main()
+{
+ println(library_version());
+
+ // low-pass filter
+ univector<fbase, 1023> taps127;
+ expression_pointer<fbase> kaiser = to_pointer(window_kaiser(taps127.size(), 3.0));
+ fir_lowpass(taps127, 0.2, kaiser, true);
+
+ // Create filters.
+ size_t const block_size = 256;
+ convolve_filter<complex<fbase>> conv_filter_complex(univector<complex<fbase>>(make_complex(taps127, zeros())),
+ block_size);
+ convolve_filter<fbase> conv_filter_real(taps127, block_size);
+
+ // Create noise to filter.
+ auto const size = 1024 * 100 + 33; // not a multiple of block_size
+ univector<complex<fbase>> cnoise =
+ make_complex(truncate(gen_random_range(random_bit_generator{ 1, 2, 3, 4 }, -1.f, +1.f), size),
+ truncate(gen_random_range(random_bit_generator{ 3, 4, 9, 8 }, -1.f, +1.f), size));
+ univector<fbase> noise =
+ truncate(gen_random_range(random_bit_generator{ 3, 4, 9, 8 }, -1.f, +1.f), size);
+
+ // Filter results.
+ univector<complex<fbase>> filtered_cnoise_ccv(size), filtered_cnoise_fir(size);
+ univector<fbase> filtered_noise_ccv(size), filtered_noise_fir(size);
+
+ // Complex filtering (time and compare).
+ auto tic = std::chrono::high_resolution_clock::now();
+ conv_filter_complex.apply(filtered_cnoise_ccv, cnoise);
+ auto toc = std::chrono::high_resolution_clock::now();
+ auto const ccv_time_complex = std::chrono::duration_cast<std::chrono::duration<float>>(toc - tic);
+ tic = toc;
+ filtered_cnoise_fir = kfr::fir(cnoise, taps127);
+ toc = std::chrono::high_resolution_clock::now();
+ auto const fir_time_complex = std::chrono::duration_cast<std::chrono::duration<float>>(toc - tic);
+ auto const cdiff = rms(cabs(filtered_cnoise_fir - filtered_cnoise_ccv));
+
+ // Real filtering (time and compare).
+ tic = std::chrono::high_resolution_clock::now();
+ conv_filter_real.apply(filtered_noise_ccv, noise);
+ toc = std::chrono::high_resolution_clock::now();
+ auto const ccv_time_real = std::chrono::duration_cast<std::chrono::duration<float>>(toc - tic);
+ tic = toc;
+ filtered_noise_fir = kfr::fir(noise, taps127);
+ toc = std::chrono::high_resolution_clock::now();
+ auto const fir_time_real = std::chrono::duration_cast<std::chrono::duration<float>>(toc - tic);
+ auto const diff = rms(filtered_noise_fir - filtered_noise_ccv);
+
+ println("complex: convolution_filter ", ccv_time_complex.count(), " fir ", fir_time_complex.count(), " diff=", cdiff);
+ println("real: convolution_filter ", ccv_time_real.count(), " fir ", fir_time_real.count(), " diff=", diff);
+
+ return 0;
+}
diff --git a/include/kfr/dft/convolution.hpp b/include/kfr/dft/convolution.hpp
@@ -84,6 +84,9 @@ public:
explicit convolve_filter(size_t size, size_t block_size = 1024);
explicit convolve_filter(const univector_ref<const T>& data, size_t block_size = 1024);
void set_data(const univector_ref<const T>& data);
+ void reset() final;
+ /// Apply filter to multiples of returned block size for optimal processing efficiency.
+ size_t input_block_size() const { return block_size; }
protected:
void process_expression(T* dest, const expression_pointer<T>& src, size_t size) final
@@ -93,19 +96,36 @@ protected:
}
void process_buffer(T* output, const T* input, size_t size) final;
- const size_t size;
+ using ST = subtype<T>;
+ static constexpr auto real_fft = !std::is_same<T, complex<ST>>::value;
+ using plan_t = std::conditional_t<real_fft, dft_plan_real<T>, dft_plan<ST>>;
+
+ // Length of filter data.
+ size_t data_size;
+ // Size of block to process.
const size_t block_size;
- const dft_plan_real<T> fft;
+ // FFT plan for circular convolution.
+ const plan_t fft;
+ // Temp storage for FFT.
univector<u8> temp;
- std::vector<univector<complex<T>>> segments;
- std::vector<univector<complex<T>>> ir_segments;
- size_t input_position;
+ // History of input segments after fwd DFT. History is circular relative to position below.
+ std::vector<univector<complex<ST>>> segments;
+ // Index into segments of current block.
+ size_t position;
+ // Blocks of filter/data after fwd DFT.
+ std::vector<univector<complex<ST>>> ir_segments;
+ // Saved input for current block.
univector<T> saved_input;
- univector<complex<T>> premul;
- univector<complex<T>> cscratch;
- univector<T> scratch;
+ // Index into saved_input for next input to begin.
+ size_t input_position;
+ // Pre-multiplied products of input history and delayed filter blocks.
+ univector<complex<ST>> premul;
+ // Scratch buffer for product of filter and input for processing by reverse DFT.
+ univector<complex<ST>> cscratch;
+ // Scratch buffers for input and output of fwd and rev DFTs.
+ univector<T> scratch1, scratch2;
+ // Overlap saved from previous block to add into current block.
univector<T> overlap;
- size_t position;
};
} // namespace CMT_ARCH_NAME
diff --git a/include/kfr/dft/impl/convolution-impl.cpp b/include/kfr/dft/impl/convolution-impl.cpp
@@ -36,37 +36,39 @@ namespace intrinsics
template <typename T>
univector<T> convolve(const univector_ref<const T>& src1, const univector_ref<const T>& src2)
{
- const size_t size = next_poweroftwo(src1.size() + src2.size() - 1);
- univector<complex<T>> src1padded = src1;
- univector<complex<T>> src2padded = src2;
- src1padded.resize(size, 0);
- src2padded.resize(size, 0);
-
- dft_plan_ptr<T> dft = dft_cache::instance().get(ctype_t<T>(), size);
+ using ST = subtype<T>;
+ const size_t size = next_poweroftwo(src1.size() + src2.size() - 1);
+ univector<complex<ST>> src1padded = src1;
+ univector<complex<ST>> src2padded = src2;
+ src1padded.resize(size);
+ src2padded.resize(size);
+
+ dft_plan_ptr<ST> dft = dft_cache::instance().get(ctype_t<ST>(), size);
univector<u8> temp(dft->temp_size);
dft->execute(src1padded, src1padded, temp);
dft->execute(src2padded, src2padded, temp);
src1padded = src1padded * src2padded;
dft->execute(src1padded, src1padded, temp, true);
- const T invsize = reciprocal<T>(size);
+ const ST invsize = reciprocal<ST>(size);
return truncate(real(src1padded), src1.size() + src2.size() - 1) * invsize;
}
template <typename T>
univector<T> correlate(const univector_ref<const T>& src1, const univector_ref<const T>& src2)
{
- const size_t size = next_poweroftwo(src1.size() + src2.size() - 1);
- univector<complex<T>> src1padded = src1;
- univector<complex<T>> src2padded = reverse(src2);
- src1padded.resize(size, 0);
- src2padded.resize(size, 0);
- dft_plan_ptr<T> dft = dft_cache::instance().get(ctype_t<T>(), size);
+ using ST = subtype<T>;
+ const size_t size = next_poweroftwo(src1.size() + src2.size() - 1);
+ univector<complex<ST>> src1padded = src1;
+ univector<complex<ST>> src2padded = reverse(src2);
+ src1padded.resize(size);
+ src2padded.resize(size);
+ dft_plan_ptr<ST> dft = dft_cache::instance().get(ctype_t<ST>(), size);
univector<u8> temp(dft->temp_size);
dft->execute(src1padded, src1padded, temp);
dft->execute(src2padded, src2padded, temp);
src1padded = src1padded * src2padded;
dft->execute(src1padded, src1padded, temp, true);
- const T invsize = reciprocal<T>(size);
+ const ST invsize = reciprocal<ST>(size);
return truncate(real(src1padded), src1.size() + src2.size() - 1) * invsize;
}
@@ -80,21 +82,55 @@ univector<T> autocorrelate(const univector_ref<const T>& src1)
} // namespace intrinsics
+// Create a helper template struct to handle the differences between real and complex FFT.
+template <typename T, typename ST = subtype<T>,
+ typename plan_t =
+ std::conditional_t<std::is_same<T, complex<ST>>::value, dft_plan<ST>, dft_plan_real<T>>>
+struct convolve_filter_fft
+{
+ static plan_t make(size_t size);
+ static inline void ifft(plan_t const& plan, univector<T>& out, const univector<complex<T>>& in,
+ univector<u8>& temp);
+ static size_t csize(plan_t const& plan);
+};
+// Partial template specializations for complex and real cases:
+template <typename ST>
+struct convolve_filter_fft<complex<ST>, ST, dft_plan<ST>>
+{
+ static dft_plan<ST> make(size_t size) { return dft_plan<ST>(size); }
+ static inline void ifft(dft_plan<ST> const& plan, univector<complex<ST>>& out,
+ const univector<complex<ST>>& in, univector<u8>& temp)
+ {
+ plan.execute(out, in, temp, ctrue);
+ }
+ static size_t csize(dft_plan<ST> const& plan) { return plan.size; }
+};
template <typename T>
-convolve_filter<T>::convolve_filter(size_t size, size_t block_size)
- : size(size), block_size(block_size), fft(2 * next_poweroftwo(block_size), dft_pack_format::Perm),
- temp(fft.temp_size), segments((size + block_size - 1) / block_size)
+struct convolve_filter_fft<T, T, dft_plan_real<T>>
+{
+ static dft_plan_real<T> make(size_t size) { return dft_plan_real<T>(size, dft_pack_format::Perm); }
+ static inline void ifft(dft_plan_real<T> const& plan, univector<T>& out, const univector<complex<T>>& in,
+ univector<u8>& temp)
+ {
+ plan.execute(out, in, temp);
+ }
+ static size_t csize(dft_plan_real<T> const& plan) { return plan.size / 2; }
+};
+template <typename T>
+convolve_filter<T>::convolve_filter(size_t size_, size_t block_size_)
+ : data_size(size_), block_size(next_poweroftwo(block_size_)),
+ fft(convolve_filter_fft<T>::make(2 * block_size)), temp(fft.temp_size),
+ segments((data_size + block_size - 1) / block_size), ir_segments(segments.size()), input_position(0),
+ saved_input(block_size), premul(convolve_filter_fft<T>::csize(fft)),
+ cscratch(convolve_filter_fft<T>::csize(fft)), scratch1(fft.size), scratch2(fft.size),
+ overlap(block_size), position(0)
{
}
template <typename T>
-convolve_filter<T>::convolve_filter(const univector_ref<const T>& data, size_t block_size)
- : size(data.size()), block_size(next_poweroftwo(block_size)),
- fft(2 * next_poweroftwo(block_size), dft_pack_format::Perm), temp(fft.temp_size),
- segments((data.size() + next_poweroftwo(block_size) - 1) / next_poweroftwo(block_size)),
- ir_segments((data.size() + next_poweroftwo(block_size) - 1) / next_poweroftwo(block_size)),
- input_position(0), position(0)
+convolve_filter<T>::convolve_filter(const univector_ref<const T>& data, size_t block_size_)
+ : convolve_filter(data.size(), block_size_)
{
set_data(data);
}
@@ -102,65 +138,125 @@ convolve_filter<T>::convolve_filter(const univector_ref<const T>& data, size_t b
template <typename T>
void convolve_filter<T>::set_data(const univector_ref<const T>& data)
{
+ data_size = data.size();
+ segments.resize((data_size + block_size - 1) / block_size);
+ ir_segments.resize(segments.size());
univector<T> input(fft.size);
- const T ifftsize = reciprocal(T(fft.size));
+ const ST ifftsize = reciprocal(ST(fft.size));
for (size_t i = 0; i < ir_segments.size(); i++)
{
- segments[i].resize(block_size);
- ir_segments[i].resize(block_size, 0);
+ segments[i].resize(convolve_filter_fft<T>::csize(fft));
+ ir_segments[i].resize(convolve_filter_fft<T>::csize(fft));
input = padded(data.slice(i * block_size, block_size));
fft.execute(ir_segments[i], input, temp);
process(ir_segments[i], ir_segments[i] * ifftsize);
}
- saved_input.resize(block_size, 0);
- scratch.resize(block_size * 2);
- premul.resize(block_size, 0);
- cscratch.resize(block_size);
- overlap.resize(block_size, 0);
+ reset();
}
template <typename T>
void convolve_filter<T>::process_buffer(T* output, const T* input, size_t size)
{
+ // Note that the conditionals in the following algorithm are meant to
+ // reduce complexity in the common cases of either processing complete
+ // blocks (processing == block_size) or only one segment.
+
+ // For complex filtering, use CCs pack format to omit special processing in fft_multiply[_accumulate].
+ static constexpr auto fft_multiply_pack = real_fft ? dft_pack_format::Perm : dft_pack_format::CCs;
+
size_t processed = 0;
while (processed < size)
{
- const size_t processing = std::min(size - processed, block_size - input_position);
- builtin_memcpy(saved_input.data() + input_position, input + processed, processing * sizeof(T));
+ // Calculate how many samples to process this iteration.
+ auto const processing = std::min(size - processed, block_size - input_position);
+
+ // Prepare input to forward FFT:
+ if (processing == block_size)
+ {
+ // No need to work with saved_input.
+ builtin_memcpy(scratch1.data(), input + processed, processing * sizeof(T));
+ }
+ else
+ {
+ // Append this iteration's input to the saved_input current block.
+ builtin_memcpy(saved_input.data() + input_position, input + processed, processing * sizeof(T));
+ builtin_memcpy(scratch1.data(), saved_input.data(), block_size * sizeof(T));
+ }
- process(scratch, padded(saved_input));
- fft.execute(segments[position], scratch, temp);
+ // Forward FFT saved_input block.
+ fft.execute(segments[position], scratch1, temp);
- if (input_position == 0)
+ if (segments.size() == 1)
{
- process(premul, zeros());
- for (size_t i = 1; i < segments.size(); i++)
+ // Just one segment/block of history.
+ // Y_k = H * X_k
+ fft_multiply(cscratch, ir_segments[0], segments[0], fft_multiply_pack);
+ }
+ else
+ {
+ // More than one segment/block of history so this is more involved.
+ if (input_position == 0)
{
- const size_t n = (position + i) % segments.size();
- fft_multiply_accumulate(premul, ir_segments[i], segments[n], dft_pack_format::Perm);
+ // At the start of an input block, we premultiply the history from
+ // previous input blocks with the extended filter blocks.
+
+ // Y_(k-i,i) = H_i * X_(k-i)
+ // premul += Y_(k-i,i) for i=1,...,N
+
+ fft_multiply(premul, ir_segments[1], segments[(position + 1) % segments.size()],
+ fft_multiply_pack);
+ for (size_t i = 2; i < segments.size(); i++)
+ {
+ const size_t n = (position + i) % segments.size();
+ fft_multiply_accumulate(premul, ir_segments[i], segments[n], fft_multiply_pack);
+ }
}
+ // Y_(k,0) = H_0 * X_k
+ // Y_k = premul + Y_(k,0)
+ fft_multiply_accumulate(cscratch, premul, ir_segments[0], segments[position], fft_multiply_pack);
}
- fft_multiply_accumulate(cscratch, premul, ir_segments[0], segments[position], dft_pack_format::Perm);
-
- fft.execute(scratch, cscratch, temp);
+ // y_k = IFFT( Y_k )
+ convolve_filter_fft<T>::ifft(fft, scratch2, cscratch, temp);
+ // z_k = y_k + overlap
process(make_univector(output + processed, processing),
- scratch.slice(input_position) + overlap.slice(input_position));
+ scratch2.slice(input_position) + overlap.slice(input_position));
input_position += processing;
+ processed += processing;
+
+ // If a whole block was processed, prepare for next block.
if (input_position == block_size)
{
+ // Input block k is complete. Move to (k+1)-th input block.
input_position = 0;
- process(saved_input, zeros());
- builtin_memcpy(overlap.data(), scratch.data() + block_size, block_size * sizeof(T));
+ // Zero out the saved_input if it will be used in the next iteration.
+ auto const remaining = size - processed;
+ if (remaining < block_size && remaining > 0)
+ {
+ process(saved_input, zeros());
+ }
+
+ builtin_memcpy(overlap.data(), scratch2.data() + block_size, block_size * sizeof(T));
position = position > 0 ? position - 1 : segments.size() - 1;
}
+ }
+}
- processed += processing;
+template <typename T>
+void convolve_filter<T>::reset()
+{
+ for (auto& segment : segments)
+ {
+ process(segment, zeros());
}
+ position = 0;
+ process(saved_input, zeros());
+ input_position = 0;
+ process(overlap, zeros());
}
namespace intrinsics
@@ -168,40 +264,68 @@ namespace intrinsics
template univector<float> convolve<float>(const univector_ref<const float>&,
const univector_ref<const float>&);
+template univector<complex<float>> convolve<complex<float>>(const univector_ref<const complex<float>>&,
+ const univector_ref<const complex<float>>&);
template univector<float> correlate<float>(const univector_ref<const float>&,
const univector_ref<const float>&);
+template univector<complex<float>> correlate<complex<float>>(const univector_ref<const complex<float>>&,
+ const univector_ref<const complex<float>>&);
template univector<float> autocorrelate<float>(const univector_ref<const float>&);
+template univector<complex<float>> autocorrelate<complex<float>>(const univector_ref<const complex<float>>&);
} // namespace intrinsics
template convolve_filter<float>::convolve_filter(size_t, size_t);
+template convolve_filter<complex<float>>::convolve_filter(size_t, size_t);
template convolve_filter<float>::convolve_filter(const univector_ref<const float>&, size_t);
+template convolve_filter<complex<float>>::convolve_filter(const univector_ref<const complex<float>>&, size_t);
template void convolve_filter<float>::set_data(const univector_ref<const float>&);
+template void convolve_filter<complex<float>>::set_data(const univector_ref<const complex<float>>&);
template void convolve_filter<float>::process_buffer(float* output, const float* input, size_t size);
+template void convolve_filter<complex<float>>::process_buffer(complex<float>* output,
+ const complex<float>* input, size_t size);
+
+template void convolve_filter<float>::reset();
+template void convolve_filter<complex<float>>::reset();
namespace intrinsics
{
template univector<double> convolve<double>(const univector_ref<const double>&,
const univector_ref<const double>&);
+template univector<complex<double>> convolve<complex<double>>(const univector_ref<const complex<double>>&,
+ const univector_ref<const complex<double>>&);
template univector<double> correlate<double>(const univector_ref<const double>&,
const univector_ref<const double>&);
+template univector<complex<double>> correlate<complex<double>>(const univector_ref<const complex<double>>&,
+ const univector_ref<const complex<double>>&);
template univector<double> autocorrelate<double>(const univector_ref<const double>&);
+template univector<complex<double>> autocorrelate<complex<double>>(
+ const univector_ref<const complex<double>>&);
} // namespace intrinsics
template convolve_filter<double>::convolve_filter(size_t, size_t);
+template convolve_filter<complex<double>>::convolve_filter(size_t, size_t);
template convolve_filter<double>::convolve_filter(const univector_ref<const double>&, size_t);
+template convolve_filter<complex<double>>::convolve_filter(const univector_ref<const complex<double>>&,
+ size_t);
template void convolve_filter<double>::set_data(const univector_ref<const double>&);
+template void convolve_filter<complex<double>>::set_data(const univector_ref<const complex<double>>&);
template void convolve_filter<double>::process_buffer(double* output, const double* input, size_t size);
+template void convolve_filter<complex<double>>::process_buffer(complex<double>* output,
+ const complex<double>* input, size_t size);
+
+template void convolve_filter<double>::reset();
+template void convolve_filter<complex<double>>::reset();
template <typename T>
filter<T>* make_convolve_filter(const univector_ref<const T>& taps, size_t block_size)
@@ -210,7 +334,9 @@ filter<T>* make_convolve_filter(const univector_ref<const T>& taps, size_t block
}
template filter<float>* make_convolve_filter(const univector_ref<const float>&, size_t);
+template filter<complex<float>>* make_convolve_filter(const univector_ref<const complex<float>>&, size_t);
template filter<double>* make_convolve_filter(const univector_ref<const double>&, size_t);
+template filter<complex<double>>* make_convolve_filter(const univector_ref<const complex<double>>&, size_t);
} // namespace CMT_ARCH_NAME
} // namespace kfr
diff --git a/tests/dft_test.cpp b/tests/dft_test.cpp
@@ -32,7 +32,17 @@ TEST(test_convolve)
CHECK(rms(c - univector<fbase>({ 0.25, 1., 2.75, 2.5, 3.75, 3.5, 1.5, -4., 7.5 })) < 0.0001);
}
-TEST(test_fft_convolve)
+TEST(test_complex_convolve)
+{
+ univector<complex<fbase>, 5> a({ 1, 2, 3, 4, 5 });
+ univector<complex<fbase>, 5> b({ 0.25, 0.5, 1.0, -2.0, 1.5 });
+ univector<complex<fbase>> c = convolve(a, b);
+ CHECK(c.size() == 9u);
+ CHECK(rms(cabs(c - univector<complex<fbase>>({ 0.25, 1., 2.75, 2.5, 3.75, 3.5, 1.5, -4., 7.5 }))) <
+ 0.0001);
+}
+
+TEST(test_convolve_filter)
{
univector<fbase, 5> a({ 1, 2, 3, 4, 5 });
univector<fbase, 5> b({ 0.25, 0.5, 1.0, -2.0, 1.5 });
@@ -42,6 +52,21 @@ TEST(test_fft_convolve)
CHECK(rms(dest - univector<fbase>({ 0.25, 1., 2.75, 2.5, 3.75 })) < 0.0001);
}
+TEST(test_complex_convolve_filter)
+{
+ univector<complex<fbase>, 5> a({ 1, 2, 3, 4, 5 });
+ univector<complex<fbase>, 5> b({ 0.25, 0.5, 1.0, -2.0, 1.5 });
+ univector<complex<fbase>, 5> dest;
+ convolve_filter<complex<fbase>> filter(a);
+ filter.apply(dest, b);
+ CHECK(rms(cabs(dest - univector<complex<fbase>>({ 0.25, 1., 2.75, 2.5, 3.75 }))) < 0.0001);
+ filter.apply(dest, b);
+ CHECK(rms(cabs(dest - univector<complex<fbase>>({ 0.25, 1., 2.75, 2.5, 3.75 }))) > 0.0001);
+ filter.reset();
+ filter.apply(dest, b);
+ CHECK(rms(cabs(dest - univector<complex<fbase>>({ 0.25, 1., 2.75, 2.5, 3.75 }))) < 0.0001);
+}
+
TEST(test_correlate)
{
univector<fbase, 5> a({ 1, 2, 3, 4, 5 });
@@ -51,6 +76,15 @@ TEST(test_correlate)
CHECK(rms(c - univector<fbase>({ 1.5, 1., 1.5, 2.5, 3.75, -4., 7.75, 3.5, 1.25 })) < 0.0001);
}
+TEST(test_complex_correlate)
+{
+ univector<complex<fbase>, 5> a({ 1, 2, 3, 4, 5 });
+ univector<complex<fbase>, 5> b({ 0.25, 0.5, 1.0, -2.0, 1.5 });
+ univector<complex<fbase>> c = correlate(a, b);
+ CHECK(c.size() == 9u);
+ CHECK(rms(cabs(c - univector<fbase>({ 1.5, 1., 1.5, 2.5, 3.75, -4., 7.75, 3.5, 1.25 }))) < 0.0001);
+}
+
#if defined CMT_ARCH_ARM || !defined NDEBUG
constexpr size_t fft_stopsize = 12;
constexpr size_t dft_stopsize = 101;