commit 079f33b2824a62c461360937e18c31de16176225
parent 70fbcc93879e9f091203e669eef8f75f4ad09462
Author: d.levin256@gmail.com <d.levin256@gmail.com>
Date: Fri, 16 Nov 2018 04:36:47 +0300
FFT: support for AVX-512
Diffstat:
2 files changed, 101 insertions(+), 112 deletions(-)
diff --git a/include/kfr/dft/fft.hpp b/include/kfr/dft/fft.hpp
@@ -31,6 +31,7 @@
#include "../base/read_write.hpp"
#include "../base/small_buffer.hpp"
#include "../base/vec.hpp"
+#include "../testo/assert.hpp"
#include "bitrev.hpp"
#include "ft.hpp"
@@ -46,6 +47,11 @@ CMT_PRAGMA_MSVC(warning(disable : 4100))
namespace kfr
{
+#define DFT_ASSERT TESTO_ASSERT_ACTIVE // NOTE(review): maps to the always-on assert; unlike TESTO_ASSERT it appears not to be gated by NDEBUG / TESTO_ASSERTION_OFF - confirm this is intended
+
+template <typename T>
+constexpr size_t fft_vector_width = platform<T>::vector_width; // vector width used by the FFT code in this header in place of reading platform<T>::vector_width directly
+
template <typename T>
struct dft_stage
{
@@ -83,11 +89,11 @@ KFR_SINTRIN cvec<T, width> radix4_apply_twiddle(csize_t<width>, cfalse_t /*split
cvec<T, width> ww = w;
cvec<T, width> tw_ = tw;
cvec<T, width> b1 = ww * dupeven(tw_);
- ww = swap<2>(ww);
+ ww = swap<2>(ww);
if (inverse)
tw_ = -(tw_);
- ww = subadd(b1, ww * dupodd(tw_));
+ ww = subadd(b1, ww * dupodd(tw_));
return ww;
}
@@ -235,8 +241,8 @@ KFR_SINTRIN void initialize_twiddles_impl(complex<T>*& twiddle, size_t nn, size_
for (size_t i = 0; i < width; i++)
{
const cvec<T, 1> r = calculate_twiddle<T>(nn + nnstep * i, size);
- result[i * 2] = r[0];
- result[i * 2 + 1] = r[1];
+ result[i * 2] = r[0];
+ result[i * 2 + 1] = r[1];
}
if (split_format)
ref_cast<cvec<T, width>>(twiddle[0]) = splitpairs(result);
@@ -248,9 +254,11 @@ KFR_SINTRIN void initialize_twiddles_impl(complex<T>*& twiddle, size_t nn, size_
template <typename T, size_t width>
CMT_NOINLINE void initialize_twiddles(complex<T>*& twiddle, size_t stage_size, size_t size, bool split_format)
{
+ const size_t count = stage_size / 4;
+ // DFT_ASSERT(width <= count);
size_t nnstep = size / stage_size;
CMT_LOOP_NOUNROLL
- for (size_t n = 0; n < stage_size / 4; n += width)
+ for (size_t n = 0; n < count; n += width)
{
initialize_twiddles_impl<T, width>(twiddle, n * nnstep * 1, nnstep * 1, size, split_format);
initialize_twiddles_impl<T, width>(twiddle, n * nnstep * 2, nnstep * 2, size, split_format);
@@ -295,6 +303,7 @@ KFR_SINTRIN cfalse_t radix4_pass(Ntype N, size_t blocks, csize_t<width>, cbool_t
CMT_ASSUME(blocks > 0);
CMT_ASSUME(N > 0);
CMT_ASSUME(N4 > 0);
+ DFT_ASSERT(width <= N4);
CMT_LOOP_NOUNROLL for (size_t b = 0; b < blocks; b++)
{
CMT_PRAGMA_CLANG(clang loop unroll_count(2))
@@ -359,6 +368,7 @@ KFR_SINTRIN ctrue_t radix4_pass(csize_t<8>, size_t blocks, csize_t<width>, cfals
{
CMT_ASSUME(blocks > 0);
constexpr static size_t prefetch_offset = width * 16;
+ DFT_ASSERT(2 <= blocks);
for (size_t b = 0; b < blocks; b += 2)
{
if (prefetch)
@@ -384,6 +394,7 @@ KFR_SINTRIN ctrue_t radix4_pass(csize_t<16>, size_t blocks, csize_t<width>, cfal
{
CMT_ASSUME(blocks > 0);
constexpr static size_t prefetch_offset = width * 4;
+ DFT_ASSERT(2 <= blocks);
CMT_PRAGMA_CLANG(clang loop unroll_count(2))
for (size_t b = 0; b < blocks; b += 2)
{
@@ -415,6 +426,7 @@ KFR_SINTRIN ctrue_t radix4_pass(csize_t<4>, size_t blocks, csize_t<width>, cfals
{
constexpr static size_t prefetch_offset = width * 4;
CMT_ASSUME(blocks > 0);
+ DFT_ASSERT(4 <= blocks);
CMT_LOOP_NOUNROLL
for (size_t b = 0; b < blocks; b += 4)
{
@@ -445,7 +457,7 @@ struct fft_stage_impl : dft_stage<T>
protected:
constexpr static bool prefetch = true;
constexpr static bool aligned = false;
- constexpr static size_t width = platform<T>::vector_width;
+ constexpr static size_t width = fft_vector_width<T>;
virtual void do_initialize(size_t size) override final
{
@@ -457,7 +469,7 @@ protected:
{
const complex<T>* twiddle = ptr_cast<complex<T>>(this->data);
if (splitin)
- in = out;
+ in = out;
const size_t stg_size = this->stage_size;
CMT_ASSUME(stg_size >= 2048);
CMT_ASSUME(stg_size % 2048 == 0);
@@ -479,77 +491,77 @@ struct fft_final_stage_impl : dft_stage<T>
}
protected:
- constexpr static size_t width = platform<T>::vector_width;
- constexpr static bool is_even = cometa::is_even(ilog2(size));
- constexpr static bool use_br2 = !is_even;
- constexpr static bool aligned = false;
- constexpr static bool prefetch = splitin;
+ constexpr static size_t width = fft_vector_width<T>;
+ constexpr static bool is_even = cometa::is_even(ilog2(size));
+ constexpr static bool use_br2 = !is_even;
+ constexpr static bool aligned = false;
+ constexpr static bool prefetch = splitin;
- virtual void do_initialize(size_t total_size) override final
+ KFR_INTRIN void init_twiddles(csize_t<8>, size_t, cfalse_t, complex<T>*&) {} // end of recursion: nothing is written for size 8
+ KFR_INTRIN void init_twiddles(csize_t<4>, size_t, cfalse_t, complex<T>*&) {} // end of recursion: nothing is written for size 4
+
+ template <size_t N, bool pass_splitin> // N: size of the current pass; pass_splitin: split flag handed down from the previous level
+ KFR_INTRIN void init_twiddles(csize_t<N>, size_t total_size, cbool_t<pass_splitin>, complex<T>*& twiddle) // writes the twiddles for size N, then recurses with N / 4
{
- complex<T>* twiddle = ptr_cast<complex<T>>(this->data);
- size_t stg_size = this->stage_size;
- while (stg_size > 4)
- {
- initialize_twiddles<T, width>(twiddle, stg_size, total_size, true);
- stg_size /= 4;
- }
+ constexpr bool pass_split = N / 4 > 8 && N / 4 / 4 >= width; // same expression as pass_split in final_stage
+ constexpr size_t pass_width = const_min(width, N / 4); // NOTE(review): presumably min(width, N / 4) - confirm const_min
+ initialize_twiddles<T, pass_width>(twiddle, N, total_size, pass_split || pass_splitin); // last argument is initialize_twiddles' split_format
+ init_twiddles(csize<N / 4>, total_size, cbool<pass_split>, twiddle); // pass_split becomes the next level's pass_splitin; ends at the csize_t<8> / csize_t<4> overloads
}
- virtual void do_execute(complex<T>* out, const complex<T>* in, u8* /*temp*/) override final
+ virtual void do_initialize(size_t total_size) override final // sets up the twiddle data stored in this->data
{
- constexpr bool is_double = sizeof(T) == 8;
- constexpr size_t final_size = is_even ? (is_double ? 4 : 16) : (is_double ? 8 : 32);
- const complex<T>* twiddle = ptr_cast<complex<T>>(this->data);
- final_pass(csize_t<final_size>(), out, in, twiddle);
+ complex<T>* twiddle = ptr_cast<complex<T>>(this->data); // passed by reference through the init_twiddles recursion
+ init_twiddles(csize<size>, total_size, cbool<splitin>, twiddle); // starts the compile-time recursion at the full stage size
}
- KFR_INTRIN void final_pass(csize_t<8>, complex<T>* out, const complex<T>* in, const complex<T>* twiddle)
+ virtual void do_execute(complex<T>* out, const complex<T>* in, u8* /*temp*/) override // no longer final: fft_specialization<double, 8, inverse> overrides it
{
- radix4_pass(512, 1, csize_t<width>(), ctrue, cbool_t<splitin>(), cbool_t<use_br2>(),
- cbool_t<prefetch>(), cbool_t<inverse>(), cbool_t<aligned>(), out, in, twiddle);
- radix4_pass(128, 4, csize_t<width>(), ctrue, ctrue, cbool_t<use_br2>(), cbool_t<prefetch>(),
- cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle);
- radix4_pass(32, 16, csize_t<width>(), cfalse, ctrue, cbool_t<use_br2>(), cbool_t<prefetch>(),
- cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle);
- radix4_pass(csize_t<8>(), 64, csize_t<width>(), cfalse, cfalse, cbool_t<use_br2>(),
- cbool_t<prefetch>(), cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle);
+ const complex<T>* twiddle = ptr_cast<complex<T>>(this->data); // read cursor over the data prepared by do_initialize
+ final_stage(csize<size>, 1, cbool<splitin>, out, in, twiddle); // invN starts at 1 and is multiplied by 4 at each recursion level
}
- KFR_INTRIN void final_pass(csize_t<32>, complex<T>* out, const complex<T>* in, const complex<T>* twiddle)
+ // KFR_INTRIN void final_stage(csize_t<32>, size_t invN, cfalse_t, complex<T>* out, const complex<T>*,
+ // const complex<T>*& twiddle)
+ // {
+ // radix4_pass(csize_t<32>(), invN, csize_t<width>(), cfalse, cfalse, cbool_t<use_br2>(),
+ // cbool_t<prefetch>(), cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle);
+ // }
+ //
+ // KFR_INTRIN void final_stage(csize_t<16>, size_t invN, cfalse_t, complex<T>* out, const complex<T>*,
+ // const complex<T>*& twiddle)
+ // {
+ // radix4_pass(csize_t<16>(), invN, csize_t<width>(), cfalse, cfalse, cbool_t<use_br2>(),
+ // cbool_t<prefetch>(), cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle);
+ // }
+
+ KFR_INTRIN void final_stage(csize_t<8>, size_t invN, cfalse_t, complex<T>* out, const complex<T>*,
+ const complex<T>*& twiddle) // end of recursion: a single size-8 pass in place on out; the input pointer is unused
{
- radix4_pass(512, 1, csize_t<width>(), ctrue, cbool_t<splitin>(), cbool_t<use_br2>(),
- cbool_t<prefetch>(), cbool_t<inverse>(), cbool_t<aligned>(), out, in, twiddle);
- radix4_pass(128, 4, csize_t<width>(), cfalse, ctrue, cbool_t<use_br2>(), cbool_t<prefetch>(),
- cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle);
- radix4_pass(csize_t<32>(), 16, csize_t<width>(), cfalse, cfalse, cbool_t<use_br2>(),
+ radix4_pass(csize_t<8>(), invN, csize_t<width>(), cfalse, cfalse, cbool_t<use_br2>(), // invN is forwarded as the block count
cbool_t<prefetch>(), cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle);
}
- KFR_INTRIN void final_pass(csize_t<4>, complex<T>* out, const complex<T>* in, const complex<T>* twiddle)
+ KFR_INTRIN void final_stage(csize_t<4>, size_t invN, cfalse_t, complex<T>* out, const complex<T>*,
+ const complex<T>*& twiddle) // end of recursion: a single size-4 pass in place on out; the input pointer is unused
{
- radix4_pass(1024, 1, csize_t<width>(), ctrue, cbool_t<splitin>(), cbool_t<use_br2>(),
- cbool_t<prefetch>(), cbool_t<inverse>(), cbool_t<aligned>(), out, in, twiddle);
- radix4_pass(256, 4, csize_t<width>(), ctrue, ctrue, cbool_t<use_br2>(), cbool_t<prefetch>(),
- cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle);
- radix4_pass(64, 16, csize_t<width>(), ctrue, ctrue, cbool_t<use_br2>(), cbool_t<prefetch>(),
- cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle);
- radix4_pass(16, 64, csize_t<width>(), cfalse, ctrue, cbool_t<use_br2>(), cbool_t<prefetch>(),
- cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle);
- radix4_pass(csize_t<4>(), 256, csize_t<width>(), cfalse, cfalse, cbool_t<use_br2>(),
+ radix4_pass(csize_t<4>(), invN, csize_t<width>(), cfalse, cfalse, cbool_t<use_br2>(), // invN is forwarded as the block count
cbool_t<prefetch>(), cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle);
}
- KFR_INTRIN void final_pass(csize_t<16>, complex<T>* out, const complex<T>* in, const complex<T>* twiddle)
+ template <size_t N, bool pass_splitin>
+ KFR_INTRIN void final_stage(csize_t<N>, size_t invN, cbool_t<pass_splitin>, complex<T>* out,
+ const complex<T>* in, const complex<T>*& twiddle) // one radix-4 pass of size N, then recursion with N / 4
{
- radix4_pass(1024, 1, csize_t<width>(), ctrue, cbool_t<splitin>(), cbool_t<use_br2>(),
- cbool_t<prefetch>(), cbool_t<inverse>(), cbool_t<aligned>(), out, in, twiddle);
- radix4_pass(256, 4, csize_t<width>(), ctrue, ctrue, cbool_t<use_br2>(), cbool_t<prefetch>(),
- cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle);
- radix4_pass(64, 16, csize_t<width>(), cfalse, ctrue, cbool_t<use_br2>(), cbool_t<prefetch>(),
- cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle);
- radix4_pass(csize_t<16>(), 64, csize_t<width>(), cfalse, cfalse, cbool_t<use_br2>(),
- cbool_t<prefetch>(), cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle);
+ static_assert(N > 8, ""); // sizes 8 and 4 are handled by the non-template overloads
+ constexpr bool pass_split = N / 4 > 8 && N / 4 / 4 >= width; // same expression as in init_twiddles
+ constexpr size_t pass_width = const_min(width, N / 4); // NOTE(review): presumably min(width, N / 4) - confirm const_min
+ static_assert(pass_width == width || (pass_split == pass_splitin), ""); // at reduced width the split flag must not change between passes
+ static_assert(pass_width <= N / 4, "");
+ radix4_pass(N, invN, csize_t<pass_width>(), cbool<pass_split>, cbool_t<pass_splitin>(), // invN is the block count of this pass
+ cbool_t<use_br2>(), cbool_t<prefetch>(), cbool_t<inverse>(), cbool_t<aligned>(), out, in,
+ twiddle);
+ final_stage(csize<N / 4>, invN * 4, cbool<pass_split>, out, out, twiddle); // later passes run in place on out with four times the blocks
}
};
@@ -581,6 +593,7 @@ template <typename T, bool inverse>
struct fft_specialization<T, 1, inverse> : dft_stage<T>
{
fft_specialization(size_t) {}
+
protected:
constexpr static bool aligned = false;
virtual void do_execute(complex<T>* out, const complex<T>* in, u8*) override final
@@ -595,6 +608,7 @@ template <typename T, bool inverse>
struct fft_specialization<T, 2, inverse> : dft_stage<T>
{
fft_specialization(size_t) {}
+
protected:
constexpr static bool aligned = false;
virtual void do_execute(complex<T>* out, const complex<T>* in, u8*) override final
@@ -610,6 +624,7 @@ template <typename T, bool inverse>
struct fft_specialization<T, 3, inverse> : dft_stage<T>
{
fft_specialization(size_t) {}
+
protected:
constexpr static bool aligned = false;
virtual void do_execute(complex<T>* out, const complex<T>* in, u8*) override final
@@ -624,6 +639,7 @@ template <typename T, bool inverse>
struct fft_specialization<T, 4, inverse> : dft_stage<T>
{
fft_specialization(size_t) {}
+
protected:
constexpr static bool aligned = false;
virtual void do_execute(complex<T>* out, const complex<T>* in, u8*) override final
@@ -638,6 +654,7 @@ template <typename T, bool inverse>
struct fft_specialization<T, 5, inverse> : dft_stage<T>
{
fft_specialization(size_t) {}
+
protected:
constexpr static bool aligned = false;
virtual void do_execute(complex<T>* out, const complex<T>* in, u8*) override final
@@ -652,6 +669,7 @@ template <typename T, bool inverse>
struct fft_specialization<T, 6, inverse> : dft_stage<T>
{
fft_specialization(size_t) {}
+
protected:
constexpr static bool aligned = false;
virtual void do_execute(complex<T>* out, const complex<T>* in, u8*) override final
@@ -671,7 +689,7 @@ struct fft_specialization<T, 7, inverse> : dft_stage<T>
protected:
constexpr static bool aligned = false;
- constexpr static size_t width = platform<T>::vector_width;
+ constexpr static size_t width = fft_vector_width<T>;
constexpr static bool use_br2 = true;
constexpr static bool prefetch = false;
constexpr static bool is_double = sizeof(T) == 8;
@@ -716,6 +734,7 @@ template <bool inverse>
struct fft_specialization<float, 8, inverse> : dft_stage<float>
{
fft_specialization(size_t) { this->temp_size = sizeof(complex<float>) * 256; }
+
protected:
virtual void do_execute(complex<float>* out, const complex<float>* in, u8* temp) override final
{
@@ -748,48 +767,16 @@ protected:
};
template <bool inverse>
-struct fft_specialization<double, 8, inverse> : dft_stage<double>
+struct fft_specialization<double, 8, inverse> : fft_final_stage_impl<double, false, 256, inverse> // twiddle setup and the radix-4 passes now come from fft_final_stage_impl
{
using T = double;
- fft_specialization(size_t)
- {
- this->stage_size = 256;
- this->data_size = align_up(sizeof(complex<T>) * 256 * 3 / 2, platform<>::native_cache_alignment);
- }
-
-protected:
- constexpr static bool aligned = false;
- constexpr static size_t width = platform<T>::vector_width;
- constexpr static bool use_br2 = false;
- constexpr static bool prefetch = false;
- constexpr static size_t split_format = true;
-
- virtual void do_initialize(size_t total_size) override final
- {
- complex<T>* twiddle = ptr_cast<complex<T>>(this->data);
- initialize_twiddles<T, width>(twiddle, 256, total_size, split_format);
- initialize_twiddles<T, width>(twiddle, 64, total_size, split_format);
- initialize_twiddles<T, width>(twiddle, 16, total_size, split_format);
- }
+ using fft_final_stage_impl<double, false, 256, inverse>::fft_final_stage_impl; // inherits the base constructors in place of the removed fft_specialization(size_t)
virtual void do_execute(complex<T>* out, const complex<T>* in, u8* /*temp*/) override final
{
- const complex<T>* twiddle = ptr_cast<complex<T>>(this->data);
- final_pass(csize_t<4>(), out, in, twiddle);
+ fft_final_stage_impl<double, false, 256, inverse>::do_execute(out, in, nullptr); // base runs the passes; its temp parameter is unnamed, so nullptr is passed
fft_reorder(out, csize_t<8>());
}
-
- KFR_INTRIN void final_pass(csize_t<4>, complex<T>* out, const complex<T>* in, const complex<T>* twiddle)
- {
- radix4_pass(256, 1, csize_t<width>(), ctrue, cfalse, cbool_t<use_br2>(), cbool_t<prefetch>(),
- cbool_t<inverse>(), cbool_t<aligned>(), out, in, twiddle);
- radix4_pass(64, 4, csize_t<width>(), ctrue, ctrue, cbool_t<use_br2>(), cbool_t<prefetch>(),
- cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle);
- radix4_pass(16, 16, csize_t<width>(), cfalse, ctrue, cbool_t<use_br2>(), cbool_t<prefetch>(),
- cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle);
- radix4_pass(csize_t<4>(), 64, csize_t<width>(), cfalse, cfalse, cbool_t<use_br2>(),
- cbool_t<prefetch>(), cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle);
- }
};
template <typename T, bool splitin, bool is_even>
@@ -816,14 +803,14 @@ struct fft_specialization_t
template <bool inverse>
using type = internal::fft_specialization<T, log2n, inverse>;
};
-}
+} // namespace internal
namespace dft_type
{
constexpr cbools_t<true, true> both{};
constexpr cbools_t<true, false> direct{};
constexpr cbools_t<false, true> inverse{};
-}
+} // namespace dft_type
template <typename T>
struct dft_plan
@@ -1029,7 +1016,7 @@ struct dft_plan_real : dft_plan<T>
{
using namespace internal;
- constexpr size_t width = platform<T>::vector_width * 2;
+ constexpr size_t width = fft_vector_width<T> * 2;
block_process(size / 4, csizes_t<width, 1>(), [=](size_t i, auto w) {
constexpr size_t width = val_of(decltype(w)());
@@ -1078,13 +1065,13 @@ private:
size_t csize =
this->size / 2; // const size_t causes internal compiler error: in tsubst_copy in GCC 5.2
- constexpr size_t width = platform<T>::vector_width * 2;
- const cvec<T, 1> dc = cread<1>(out);
- const size_t count = csize / 2;
+ constexpr size_t width = fft_vector_width<T> * 2;
+ const cvec<T, 1> dc = cread<1>(out);
+ const size_t count = csize / 2;
block_process(count, csizes_t<width, 1>(), [=](size_t i, auto w) {
- constexpr size_t width = val_of(decltype(w)());
- constexpr size_t widthm1 = width - 1;
+ constexpr size_t width = val_of(decltype(w)());
+ constexpr size_t widthm1 = width - 1;
const cvec<T, width> tw = cread<width>(rtwiddle.data() + i);
const cvec<T, width> fpk = cread<width>(out + i);
const cvec<T, width> fpnk = reverse<2>(negodd(cread<width>(out + csize - i - widthm1)));
@@ -1097,7 +1084,7 @@ private:
});
{
- size_t k = csize / 2;
+ size_t k = csize / 2;
const cvec<T, 1> fpk = cread<1>(out + k);
const cvec<T, 1> fpnk = negodd(fpk);
cwrite<1>(out + k, fpnk);
@@ -1129,13 +1116,13 @@ private:
dc = pack(in[0].real() + in[0].imag(), in[0].real() - in[0].imag());
}
- constexpr size_t width = platform<T>::vector_width * 2;
+ constexpr size_t width = fft_vector_width<T> * 2;
const size_t count = csize / 2;
block_process(count, csizes_t<width, 1>(), [=](size_t i, auto w) {
i++;
- constexpr size_t width = val_of(decltype(w)());
- constexpr size_t widthm1 = width - 1;
+ constexpr size_t width = val_of(decltype(w)());
+ constexpr size_t widthm1 = width - 1;
const cvec<T, width> tw = cread<width>(rtwiddle.data() + i);
const cvec<T, width> fpk = cread<width>(in + i);
const cvec<T, width> fpnk = reverse<2>(negodd(cread<width>(in + csize - i - widthm1)));
@@ -1148,7 +1135,7 @@ private:
});
{
- size_t k = csize / 2;
+ size_t k = csize / 2;
const cvec<T, 1> fpk = cread<1>(in + k);
const cvec<T, 1> fpnk = 2 * negodd(fpk);
cwrite<1>(out + k, fpnk);
@@ -1195,7 +1182,7 @@ void fft_multiply_accumulate(univector<complex<T>, Tag1>& dest, const univector<
if (fmt == dft_pack_format::Perm)
dest[0] = f0;
}
-}
+} // namespace kfr
CMT_PRAGMA_GNU(GCC diagnostic pop)
diff --git a/include/kfr/testo/assert.hpp b/include/kfr/testo/assert.hpp
@@ -50,9 +50,7 @@ bool check_assertion(const half_comparison<L>& comparison, const char* expr, con
return result;
}
-#if defined(TESTO_ASSERTION_ON) || !(defined(NDEBUG) || defined(TESTO_ASSERTION_OFF))
-
-#define TESTO_ASSERT(...) \
+#define TESTO_ASSERT_ACTIVE(...) \
do \
{ \
if (!::testo::check_assertion(::testo::make_comparison() <= __VA_ARGS__, #__VA_ARGS__, __FILE__, \
@@ -60,6 +58,10 @@ bool check_assertion(const half_comparison<L>& comparison, const char* expr, con
TESTO_BREAKPOINT; \
} while (0)
+#if defined(TESTO_ASSERTION_ON) || !(defined(NDEBUG) || defined(TESTO_ASSERTION_OFF))
+
+#define TESTO_ASSERT TESTO_ASSERT_ACTIVE // assertions enabled: forward to TESTO_ASSERT_ACTIVE, which is defined above this #if
+
#else
#define TESTO_ASSERT(...) \
do \