kfr

Fast, modern C++ DSP framework, FFT, Sample Rate Conversion, FIR/IIR/Biquad Filters (SSE, AVX, AVX-512, ARM NEON)
Log | Files | Refs | README

commit 079f33b2824a62c461360937e18c31de16176225
parent 70fbcc93879e9f091203e669eef8f75f4ad09462
Author: d.levin256@gmail.com <d.levin256@gmail.com>
Date:   Fri, 16 Nov 2018 04:36:47 +0300

FFT: support for AVX-512

Diffstat:
M include/kfr/dft/fft.hpp | 205 +++++++++++++++++++++++++++++++++++++------------------------------------------
M include/kfr/testo/assert.hpp | 8 +++++---
2 files changed, 101 insertions(+), 112 deletions(-)

diff --git a/include/kfr/dft/fft.hpp b/include/kfr/dft/fft.hpp @@ -31,6 +31,7 @@ #include "../base/read_write.hpp" #include "../base/small_buffer.hpp" #include "../base/vec.hpp" +#include "../testo/assert.hpp" #include "bitrev.hpp" #include "ft.hpp" @@ -46,6 +47,11 @@ CMT_PRAGMA_MSVC(warning(disable : 4100)) namespace kfr { +#define DFT_ASSERT TESTO_ASSERT_ACTIVE + +template <typename T> +constexpr size_t fft_vector_width = platform<T>::vector_width; + template <typename T> struct dft_stage { @@ -83,11 +89,11 @@ KFR_SINTRIN cvec<T, width> radix4_apply_twiddle(csize_t<width>, cfalse_t /*split cvec<T, width> ww = w; cvec<T, width> tw_ = tw; cvec<T, width> b1 = ww * dupeven(tw_); - ww = swap<2>(ww); + ww = swap<2>(ww); if (inverse) tw_ = -(tw_); - ww = subadd(b1, ww * dupodd(tw_)); + ww = subadd(b1, ww * dupodd(tw_)); return ww; } @@ -235,8 +241,8 @@ KFR_SINTRIN void initialize_twiddles_impl(complex<T>*& twiddle, size_t nn, size_ for (size_t i = 0; i < width; i++) { const cvec<T, 1> r = calculate_twiddle<T>(nn + nnstep * i, size); - result[i * 2] = r[0]; - result[i * 2 + 1] = r[1]; + result[i * 2] = r[0]; + result[i * 2 + 1] = r[1]; } if (split_format) ref_cast<cvec<T, width>>(twiddle[0]) = splitpairs(result); @@ -248,9 +254,11 @@ KFR_SINTRIN void initialize_twiddles_impl(complex<T>*& twiddle, size_t nn, size_ template <typename T, size_t width> CMT_NOINLINE void initialize_twiddles(complex<T>*& twiddle, size_t stage_size, size_t size, bool split_format) { + const size_t count = stage_size / 4; + // DFT_ASSERT(width <= count); size_t nnstep = size / stage_size; CMT_LOOP_NOUNROLL - for (size_t n = 0; n < stage_size / 4; n += width) + for (size_t n = 0; n < count; n += width) { initialize_twiddles_impl<T, width>(twiddle, n * nnstep * 1, nnstep * 1, size, split_format); initialize_twiddles_impl<T, width>(twiddle, n * nnstep * 2, nnstep * 2, size, split_format); @@ -295,6 +303,7 @@ KFR_SINTRIN cfalse_t radix4_pass(Ntype N, size_t blocks, csize_t<width>, cbool_t 
CMT_ASSUME(blocks > 0); CMT_ASSUME(N > 0); CMT_ASSUME(N4 > 0); + DFT_ASSERT(width <= N4); CMT_LOOP_NOUNROLL for (size_t b = 0; b < blocks; b++) { CMT_PRAGMA_CLANG(clang loop unroll_count(2)) @@ -359,6 +368,7 @@ KFR_SINTRIN ctrue_t radix4_pass(csize_t<8>, size_t blocks, csize_t<width>, cfals { CMT_ASSUME(blocks > 0); constexpr static size_t prefetch_offset = width * 16; + DFT_ASSERT(2 <= blocks); for (size_t b = 0; b < blocks; b += 2) { if (prefetch) @@ -384,6 +394,7 @@ KFR_SINTRIN ctrue_t radix4_pass(csize_t<16>, size_t blocks, csize_t<width>, cfal { CMT_ASSUME(blocks > 0); constexpr static size_t prefetch_offset = width * 4; + DFT_ASSERT(2 <= blocks); CMT_PRAGMA_CLANG(clang loop unroll_count(2)) for (size_t b = 0; b < blocks; b += 2) { @@ -415,6 +426,7 @@ KFR_SINTRIN ctrue_t radix4_pass(csize_t<4>, size_t blocks, csize_t<width>, cfals { constexpr static size_t prefetch_offset = width * 4; CMT_ASSUME(blocks > 0); + DFT_ASSERT(4 <= blocks); CMT_LOOP_NOUNROLL for (size_t b = 0; b < blocks; b += 4) { @@ -445,7 +457,7 @@ struct fft_stage_impl : dft_stage<T> protected: constexpr static bool prefetch = true; constexpr static bool aligned = false; - constexpr static size_t width = platform<T>::vector_width; + constexpr static size_t width = fft_vector_width<T>; virtual void do_initialize(size_t size) override final { @@ -457,7 +469,7 @@ protected: { const complex<T>* twiddle = ptr_cast<complex<T>>(this->data); if (splitin) - in = out; + in = out; const size_t stg_size = this->stage_size; CMT_ASSUME(stg_size >= 2048); CMT_ASSUME(stg_size % 2048 == 0); @@ -479,77 +491,77 @@ struct fft_final_stage_impl : dft_stage<T> } protected: - constexpr static size_t width = platform<T>::vector_width; - constexpr static bool is_even = cometa::is_even(ilog2(size)); - constexpr static bool use_br2 = !is_even; - constexpr static bool aligned = false; - constexpr static bool prefetch = splitin; + constexpr static size_t width = fft_vector_width<T>; + constexpr static bool is_even = 
cometa::is_even(ilog2(size)); + constexpr static bool use_br2 = !is_even; + constexpr static bool aligned = false; + constexpr static bool prefetch = splitin; - virtual void do_initialize(size_t total_size) override final + KFR_INTRIN void init_twiddles(csize_t<8>, size_t, cfalse_t, complex<T>*&) {} + KFR_INTRIN void init_twiddles(csize_t<4>, size_t, cfalse_t, complex<T>*&) {} + + template <size_t N, bool pass_splitin> + KFR_INTRIN void init_twiddles(csize_t<N>, size_t total_size, cbool_t<pass_splitin>, complex<T>*& twiddle) { - complex<T>* twiddle = ptr_cast<complex<T>>(this->data); - size_t stg_size = this->stage_size; - while (stg_size > 4) - { - initialize_twiddles<T, width>(twiddle, stg_size, total_size, true); - stg_size /= 4; - } + constexpr bool pass_split = N / 4 > 8 && N / 4 / 4 >= width; + constexpr size_t pass_width = const_min(width, N / 4); + initialize_twiddles<T, pass_width>(twiddle, N, total_size, pass_split || pass_splitin); + init_twiddles(csize<N / 4>, total_size, cbool<pass_split>, twiddle); } - virtual void do_execute(complex<T>* out, const complex<T>* in, u8* /*temp*/) override final + virtual void do_initialize(size_t total_size) override final { - constexpr bool is_double = sizeof(T) == 8; - constexpr size_t final_size = is_even ? (is_double ? 4 : 16) : (is_double ? 
8 : 32); - const complex<T>* twiddle = ptr_cast<complex<T>>(this->data); - final_pass(csize_t<final_size>(), out, in, twiddle); + complex<T>* twiddle = ptr_cast<complex<T>>(this->data); + init_twiddles(csize<size>, total_size, cbool<splitin>, twiddle); } - KFR_INTRIN void final_pass(csize_t<8>, complex<T>* out, const complex<T>* in, const complex<T>* twiddle) + virtual void do_execute(complex<T>* out, const complex<T>* in, u8* /*temp*/) override { - radix4_pass(512, 1, csize_t<width>(), ctrue, cbool_t<splitin>(), cbool_t<use_br2>(), - cbool_t<prefetch>(), cbool_t<inverse>(), cbool_t<aligned>(), out, in, twiddle); - radix4_pass(128, 4, csize_t<width>(), ctrue, ctrue, cbool_t<use_br2>(), cbool_t<prefetch>(), - cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle); - radix4_pass(32, 16, csize_t<width>(), cfalse, ctrue, cbool_t<use_br2>(), cbool_t<prefetch>(), - cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle); - radix4_pass(csize_t<8>(), 64, csize_t<width>(), cfalse, cfalse, cbool_t<use_br2>(), - cbool_t<prefetch>(), cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle); + const complex<T>* twiddle = ptr_cast<complex<T>>(this->data); + final_stage(csize<size>, 1, cbool<splitin>, out, in, twiddle); } - KFR_INTRIN void final_pass(csize_t<32>, complex<T>* out, const complex<T>* in, const complex<T>* twiddle) + // KFR_INTRIN void final_stage(csize_t<32>, size_t invN, cfalse_t, complex<T>* out, const complex<T>*, + // const complex<T>*& twiddle) + // { + // radix4_pass(csize_t<32>(), invN, csize_t<width>(), cfalse, cfalse, cbool_t<use_br2>(), + // cbool_t<prefetch>(), cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle); + // } + // + // KFR_INTRIN void final_stage(csize_t<16>, size_t invN, cfalse_t, complex<T>* out, const complex<T>*, + // const complex<T>*& twiddle) + // { + // radix4_pass(csize_t<16>(), invN, csize_t<width>(), cfalse, cfalse, cbool_t<use_br2>(), + // cbool_t<prefetch>(), cbool_t<inverse>(), cbool_t<aligned>(), out, out, 
twiddle); + // } + + KFR_INTRIN void final_stage(csize_t<8>, size_t invN, cfalse_t, complex<T>* out, const complex<T>*, + const complex<T>*& twiddle) { - radix4_pass(512, 1, csize_t<width>(), ctrue, cbool_t<splitin>(), cbool_t<use_br2>(), - cbool_t<prefetch>(), cbool_t<inverse>(), cbool_t<aligned>(), out, in, twiddle); - radix4_pass(128, 4, csize_t<width>(), cfalse, ctrue, cbool_t<use_br2>(), cbool_t<prefetch>(), - cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle); - radix4_pass(csize_t<32>(), 16, csize_t<width>(), cfalse, cfalse, cbool_t<use_br2>(), + radix4_pass(csize_t<8>(), invN, csize_t<width>(), cfalse, cfalse, cbool_t<use_br2>(), cbool_t<prefetch>(), cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle); } - KFR_INTRIN void final_pass(csize_t<4>, complex<T>* out, const complex<T>* in, const complex<T>* twiddle) + KFR_INTRIN void final_stage(csize_t<4>, size_t invN, cfalse_t, complex<T>* out, const complex<T>*, + const complex<T>*& twiddle) { - radix4_pass(1024, 1, csize_t<width>(), ctrue, cbool_t<splitin>(), cbool_t<use_br2>(), - cbool_t<prefetch>(), cbool_t<inverse>(), cbool_t<aligned>(), out, in, twiddle); - radix4_pass(256, 4, csize_t<width>(), ctrue, ctrue, cbool_t<use_br2>(), cbool_t<prefetch>(), - cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle); - radix4_pass(64, 16, csize_t<width>(), ctrue, ctrue, cbool_t<use_br2>(), cbool_t<prefetch>(), - cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle); - radix4_pass(16, 64, csize_t<width>(), cfalse, ctrue, cbool_t<use_br2>(), cbool_t<prefetch>(), - cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle); - radix4_pass(csize_t<4>(), 256, csize_t<width>(), cfalse, cfalse, cbool_t<use_br2>(), + radix4_pass(csize_t<4>(), invN, csize_t<width>(), cfalse, cfalse, cbool_t<use_br2>(), cbool_t<prefetch>(), cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle); } - KFR_INTRIN void final_pass(csize_t<16>, complex<T>* out, const complex<T>* in, const complex<T>* twiddle) + template 
<size_t N, bool pass_splitin> + KFR_INTRIN void final_stage(csize_t<N>, size_t invN, cbool_t<pass_splitin>, complex<T>* out, + const complex<T>* in, const complex<T>*& twiddle) { - radix4_pass(1024, 1, csize_t<width>(), ctrue, cbool_t<splitin>(), cbool_t<use_br2>(), - cbool_t<prefetch>(), cbool_t<inverse>(), cbool_t<aligned>(), out, in, twiddle); - radix4_pass(256, 4, csize_t<width>(), ctrue, ctrue, cbool_t<use_br2>(), cbool_t<prefetch>(), - cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle); - radix4_pass(64, 16, csize_t<width>(), cfalse, ctrue, cbool_t<use_br2>(), cbool_t<prefetch>(), - cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle); - radix4_pass(csize_t<16>(), 64, csize_t<width>(), cfalse, cfalse, cbool_t<use_br2>(), - cbool_t<prefetch>(), cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle); + static_assert(N > 8, ""); + constexpr bool pass_split = N / 4 > 8 && N / 4 / 4 >= width; + constexpr size_t pass_width = const_min(width, N / 4); + static_assert(pass_width == width || (pass_split == pass_splitin), ""); + static_assert(pass_width <= N / 4, ""); + radix4_pass(N, invN, csize_t<pass_width>(), cbool<pass_split>, cbool_t<pass_splitin>(), + cbool_t<use_br2>(), cbool_t<prefetch>(), cbool_t<inverse>(), cbool_t<aligned>(), out, in, + twiddle); + final_stage(csize<N / 4>, invN * 4, cbool<pass_split>, out, out, twiddle); } }; @@ -581,6 +593,7 @@ template <typename T, bool inverse> struct fft_specialization<T, 1, inverse> : dft_stage<T> { fft_specialization(size_t) {} + protected: constexpr static bool aligned = false; virtual void do_execute(complex<T>* out, const complex<T>* in, u8*) override final @@ -595,6 +608,7 @@ template <typename T, bool inverse> struct fft_specialization<T, 2, inverse> : dft_stage<T> { fft_specialization(size_t) {} + protected: constexpr static bool aligned = false; virtual void do_execute(complex<T>* out, const complex<T>* in, u8*) override final @@ -610,6 +624,7 @@ template <typename T, bool inverse> struct 
fft_specialization<T, 3, inverse> : dft_stage<T> { fft_specialization(size_t) {} + protected: constexpr static bool aligned = false; virtual void do_execute(complex<T>* out, const complex<T>* in, u8*) override final @@ -624,6 +639,7 @@ template <typename T, bool inverse> struct fft_specialization<T, 4, inverse> : dft_stage<T> { fft_specialization(size_t) {} + protected: constexpr static bool aligned = false; virtual void do_execute(complex<T>* out, const complex<T>* in, u8*) override final @@ -638,6 +654,7 @@ template <typename T, bool inverse> struct fft_specialization<T, 5, inverse> : dft_stage<T> { fft_specialization(size_t) {} + protected: constexpr static bool aligned = false; virtual void do_execute(complex<T>* out, const complex<T>* in, u8*) override final @@ -652,6 +669,7 @@ template <typename T, bool inverse> struct fft_specialization<T, 6, inverse> : dft_stage<T> { fft_specialization(size_t) {} + protected: constexpr static bool aligned = false; virtual void do_execute(complex<T>* out, const complex<T>* in, u8*) override final @@ -671,7 +689,7 @@ struct fft_specialization<T, 7, inverse> : dft_stage<T> protected: constexpr static bool aligned = false; - constexpr static size_t width = platform<T>::vector_width; + constexpr static size_t width = fft_vector_width<T>; constexpr static bool use_br2 = true; constexpr static bool prefetch = false; constexpr static bool is_double = sizeof(T) == 8; @@ -716,6 +734,7 @@ template <bool inverse> struct fft_specialization<float, 8, inverse> : dft_stage<float> { fft_specialization(size_t) { this->temp_size = sizeof(complex<float>) * 256; } + protected: virtual void do_execute(complex<float>* out, const complex<float>* in, u8* temp) override final { @@ -748,48 +767,16 @@ protected: }; template <bool inverse> -struct fft_specialization<double, 8, inverse> : dft_stage<double> +struct fft_specialization<double, 8, inverse> : fft_final_stage_impl<double, false, 256, inverse> { using T = double; - fft_specialization(size_t) - 
{ - this->stage_size = 256; - this->data_size = align_up(sizeof(complex<T>) * 256 * 3 / 2, platform<>::native_cache_alignment); - } - -protected: - constexpr static bool aligned = false; - constexpr static size_t width = platform<T>::vector_width; - constexpr static bool use_br2 = false; - constexpr static bool prefetch = false; - constexpr static size_t split_format = true; - - virtual void do_initialize(size_t total_size) override final - { - complex<T>* twiddle = ptr_cast<complex<T>>(this->data); - initialize_twiddles<T, width>(twiddle, 256, total_size, split_format); - initialize_twiddles<T, width>(twiddle, 64, total_size, split_format); - initialize_twiddles<T, width>(twiddle, 16, total_size, split_format); - } + using fft_final_stage_impl<double, false, 256, inverse>::fft_final_stage_impl; virtual void do_execute(complex<T>* out, const complex<T>* in, u8* /*temp*/) override final { - const complex<T>* twiddle = ptr_cast<complex<T>>(this->data); - final_pass(csize_t<4>(), out, in, twiddle); + fft_final_stage_impl<double, false, 256, inverse>::do_execute(out, in, nullptr); fft_reorder(out, csize_t<8>()); } - - KFR_INTRIN void final_pass(csize_t<4>, complex<T>* out, const complex<T>* in, const complex<T>* twiddle) - { - radix4_pass(256, 1, csize_t<width>(), ctrue, cfalse, cbool_t<use_br2>(), cbool_t<prefetch>(), - cbool_t<inverse>(), cbool_t<aligned>(), out, in, twiddle); - radix4_pass(64, 4, csize_t<width>(), ctrue, ctrue, cbool_t<use_br2>(), cbool_t<prefetch>(), - cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle); - radix4_pass(16, 16, csize_t<width>(), cfalse, ctrue, cbool_t<use_br2>(), cbool_t<prefetch>(), - cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle); - radix4_pass(csize_t<4>(), 64, csize_t<width>(), cfalse, cfalse, cbool_t<use_br2>(), - cbool_t<prefetch>(), cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle); - } }; template <typename T, bool splitin, bool is_even> @@ -816,14 +803,14 @@ struct fft_specialization_t template 
<bool inverse> using type = internal::fft_specialization<T, log2n, inverse>; }; -} +} // namespace internal namespace dft_type { constexpr cbools_t<true, true> both{}; constexpr cbools_t<true, false> direct{}; constexpr cbools_t<false, true> inverse{}; -} +} // namespace dft_type template <typename T> struct dft_plan @@ -1029,7 +1016,7 @@ struct dft_plan_real : dft_plan<T> { using namespace internal; - constexpr size_t width = platform<T>::vector_width * 2; + constexpr size_t width = fft_vector_width<T> * 2; block_process(size / 4, csizes_t<width, 1>(), [=](size_t i, auto w) { constexpr size_t width = val_of(decltype(w)()); @@ -1078,13 +1065,13 @@ private: size_t csize = this->size / 2; // const size_t causes internal compiler error: in tsubst_copy in GCC 5.2 - constexpr size_t width = platform<T>::vector_width * 2; - const cvec<T, 1> dc = cread<1>(out); - const size_t count = csize / 2; + constexpr size_t width = fft_vector_width<T> * 2; + const cvec<T, 1> dc = cread<1>(out); + const size_t count = csize / 2; block_process(count, csizes_t<width, 1>(), [=](size_t i, auto w) { - constexpr size_t width = val_of(decltype(w)()); - constexpr size_t widthm1 = width - 1; + constexpr size_t width = val_of(decltype(w)()); + constexpr size_t widthm1 = width - 1; const cvec<T, width> tw = cread<width>(rtwiddle.data() + i); const cvec<T, width> fpk = cread<width>(out + i); const cvec<T, width> fpnk = reverse<2>(negodd(cread<width>(out + csize - i - widthm1))); @@ -1097,7 +1084,7 @@ private: }); { - size_t k = csize / 2; + size_t k = csize / 2; const cvec<T, 1> fpk = cread<1>(out + k); const cvec<T, 1> fpnk = negodd(fpk); cwrite<1>(out + k, fpnk); @@ -1129,13 +1116,13 @@ private: dc = pack(in[0].real() + in[0].imag(), in[0].real() - in[0].imag()); } - constexpr size_t width = platform<T>::vector_width * 2; + constexpr size_t width = fft_vector_width<T> * 2; const size_t count = csize / 2; block_process(count, csizes_t<width, 1>(), [=](size_t i, auto w) { i++; - constexpr size_t 
width = val_of(decltype(w)()); - constexpr size_t widthm1 = width - 1; + constexpr size_t width = val_of(decltype(w)()); + constexpr size_t widthm1 = width - 1; const cvec<T, width> tw = cread<width>(rtwiddle.data() + i); const cvec<T, width> fpk = cread<width>(in + i); const cvec<T, width> fpnk = reverse<2>(negodd(cread<width>(in + csize - i - widthm1))); @@ -1148,7 +1135,7 @@ private: }); { - size_t k = csize / 2; + size_t k = csize / 2; const cvec<T, 1> fpk = cread<1>(in + k); const cvec<T, 1> fpnk = 2 * negodd(fpk); cwrite<1>(out + k, fpnk); @@ -1195,7 +1182,7 @@ void fft_multiply_accumulate(univector<complex<T>, Tag1>& dest, const univector< if (fmt == dft_pack_format::Perm) dest[0] = f0; } -} +} // namespace kfr CMT_PRAGMA_GNU(GCC diagnostic pop) diff --git a/include/kfr/testo/assert.hpp b/include/kfr/testo/assert.hpp @@ -50,9 +50,7 @@ bool check_assertion(const half_comparison<L>& comparison, const char* expr, con return result; } -#if defined(TESTO_ASSERTION_ON) || !(defined(NDEBUG) || defined(TESTO_ASSERTION_OFF)) - -#define TESTO_ASSERT(...) \ +#define TESTO_ASSERT_ACTIVE(...) \ do \ { \ if (!::testo::check_assertion(::testo::make_comparison() <= __VA_ARGS__, #__VA_ARGS__, __FILE__, \ @@ -60,6 +58,10 @@ bool check_assertion(const half_comparison<L>& comparison, const char* expr, con TESTO_BREAKPOINT; \ } while (0) +#if defined(TESTO_ASSERTION_ON) || !(defined(NDEBUG) || defined(TESTO_ASSERTION_OFF)) + +#define TESTO_ASSERT TESTO_ASSERT_ACTIVE + #else #define TESTO_ASSERT(...) \ do \