commit fd75e6440a2659a4991d2690c5c7677494a505c5
parent 2e9ef22bc9777f665c067ce25f99e5a02ca64d32
Author: d.levin256@gmail.com <d.levin256@gmail.com>
Date: Fri, 16 Nov 2018 04:56:30 +0300
Merge branch 'building_dft' into 3.0
Diffstat:
6 files changed, 1132 insertions(+), 1005 deletions(-)
diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt
@@ -1,16 +1,16 @@
# Copyright (C) 2016 D Levin (http://www.kfrlib.com)
# This file is part of KFR
-#
+#
# KFR is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
-#
+#
# KFR is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
-#
+#
# You should have received a copy of the GNU General Public License
# along with KFR.
@@ -23,4 +23,4 @@ add_executable(biquads biquads.cpp ${KFR_SRC})
add_executable(window window.cpp ${KFR_SRC})
add_executable(fir fir.cpp ${KFR_SRC})
add_executable(sample_rate_conversion sample_rate_conversion.cpp ${KFR_SRC})
-add_executable(dft dft.cpp ${KFR_SRC} ${DFT_SRC})
+add_executable(dft dft.cpp ${KFR_SRC} ${DFT_SRC} ../include/kfr/dft/dft-src.cpp)
diff --git a/include/kfr/dft.hpp b/include/kfr/dft.hpp
@@ -24,8 +24,6 @@
#include "base.hpp"
-#include "dft/bitrev.hpp"
#include "dft/convolution.hpp"
#include "dft/fft.hpp"
-#include "dft/ft.hpp"
#include "dft/reference_dft.hpp"
diff --git a/include/kfr/dft/dft-src.cpp b/include/kfr/dft/dft-src.cpp
@@ -0,0 +1,1101 @@
+/** @addtogroup dft
+ * @{
+ */
+/*
+ Copyright (C) 2016 D Levin (https://www.kfrlib.com)
+ This file is part of KFR
+
+ KFR is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ KFR is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with KFR.
+
+ If GPL is not suitable for your project, you must purchase a commercial license to use KFR.
+ Buying a commercial license is mandatory as soon as you develop commercial activities without
+ disclosing the source code of your own applications.
+ See https://www.kfrlib.com for details.
+ */
+
+#include "bitrev.hpp"
+#include "fft.hpp"
+#include "ft.hpp"
+#include "../testo/assert.hpp"
+
+CMT_PRAGMA_GNU(GCC diagnostic push)
+#if CMT_HAS_WARNING("-Wshadow")
+CMT_PRAGMA_GNU(GCC diagnostic ignored "-Wshadow")
+#endif
+
+CMT_PRAGMA_MSVC(warning(push))
+CMT_PRAGMA_MSVC(warning(disable : 4100))
+
+namespace kfr
+{
+
+#define DFT_ASSERT TESTO_ASSERT_INACTIVE
+
+template <typename T>
+constexpr size_t fft_vector_width = platform<T>::vector_width;
+
+template <typename T>
+struct dft_stage
+{
+ size_t stage_size = 0;
+ size_t data_size = 0;
+ size_t temp_size = 0;
+ u8* data = nullptr;
+ size_t repeats = 1;
+ size_t out_offset = 0;
+ const char* name;
+ bool recursion = false;
+
+ void initialize(size_t size) { do_initialize(size); }
+
+ KFR_INTRIN void execute(complex<T>* out, const complex<T>* in, u8* temp) { do_execute(out, in, temp); }
+ virtual ~dft_stage() {}
+
+protected:
+ virtual void do_initialize(size_t) {}
+ virtual void do_execute(complex<T>*, const complex<T>*, u8* temp) = 0;
+};
+
+CMT_PRAGMA_GNU(GCC diagnostic push)
+#if CMT_HAS_WARNING("-Wassume")
+CMT_PRAGMA_GNU(GCC diagnostic ignored "-Wassume")
+#endif
+
+namespace internal
+{
+
+template <size_t width, bool inverse, typename T>
+KFR_SINTRIN cvec<T, width> radix4_apply_twiddle(csize_t<width>, cfalse_t /*split_format*/, cbool_t<inverse>,
+ const cvec<T, width>& w, const cvec<T, width>& tw)
+{
+ cvec<T, width> ww = w;
+ cvec<T, width> tw_ = tw;
+ cvec<T, width> b1 = ww * dupeven(tw_);
+ ww = swap<2>(ww);
+
+ if (inverse)
+ tw_ = -(tw_);
+ ww = subadd(b1, ww * dupodd(tw_));
+ return ww;
+}
+
+template <size_t width, bool use_br2, bool inverse, bool aligned, typename T>
+KFR_SINTRIN void radix4_body(size_t N, csize_t<width>, cfalse_t, cfalse_t, cfalse_t, cbool_t<use_br2>,
+ cbool_t<inverse>, cbool_t<aligned>, complex<T>* out, const complex<T>* in,
+ const complex<T>* twiddle)
+{
+ const size_t N4 = N / 4;
+ cvec<T, width> w1, w2, w3;
+
+ cvec<T, width> sum02, sum13, diff02, diff13;
+
+ cvec<T, width> a0, a1, a2, a3;
+ a0 = cread<width, aligned>(in + 0);
+ a2 = cread<width, aligned>(in + N4 * 2);
+ sum02 = a0 + a2;
+
+ a1 = cread<width, aligned>(in + N4);
+ a3 = cread<width, aligned>(in + N4 * 3);
+ sum13 = a1 + a3;
+
+ cwrite<width, aligned>(out, sum02 + sum13);
+ w2 = sum02 - sum13;
+ cwrite<width, aligned>(out + N4 * (use_br2 ? 1 : 2),
+ radix4_apply_twiddle(csize_t<width>(), cfalse, cbool_t<inverse>(), w2,
+ cread<width, true>(twiddle + width)));
+ diff02 = a0 - a2;
+ diff13 = a1 - a3;
+ if (inverse)
+ {
+ diff13 = (diff13 ^ broadcast<width * 2, T>(T(), -T()));
+ diff13 = swap<2>(diff13);
+ }
+ else
+ {
+ diff13 = swap<2>(diff13);
+ diff13 = (diff13 ^ broadcast<width * 2, T>(T(), -T()));
+ }
+
+ w1 = diff02 + diff13;
+
+ cwrite<width, aligned>(out + N4 * (use_br2 ? 2 : 1),
+ radix4_apply_twiddle(csize_t<width>(), cfalse, cbool_t<inverse>(), w1,
+ cread<width, true>(twiddle + 0)));
+ w3 = diff02 - diff13;
+ cwrite<width, aligned>(out + N4 * 3, radix4_apply_twiddle(csize_t<width>(), cfalse, cbool_t<inverse>(),
+ w3, cread<width, true>(twiddle + width * 2)));
+}
+
+template <size_t width, bool inverse, typename T>
+KFR_SINTRIN cvec<T, width> radix4_apply_twiddle(csize_t<width>, ctrue_t /*split_format*/, cbool_t<inverse>,
+ const cvec<T, width>& w, const cvec<T, width>& tw)
+{
+ vec<T, width> re1, im1, twre, twim;
+ split(w, re1, im1);
+ split(tw, twre, twim);
+
+ const vec<T, width> b1re = re1 * twre;
+ const vec<T, width> b1im = im1 * twre;
+ if (inverse)
+ return concat(b1re + im1 * twim, b1im - re1 * twim);
+ else
+ return concat(b1re - im1 * twim, b1im + re1 * twim);
+}
+
+template <size_t width, bool splitout, bool splitin, bool use_br2, bool inverse, bool aligned, typename T>
+KFR_SINTRIN void radix4_body(size_t N, csize_t<width>, ctrue_t, cbool_t<splitout>, cbool_t<splitin>,
+ cbool_t<use_br2>, cbool_t<inverse>, cbool_t<aligned>, complex<T>* out,
+ const complex<T>* in, const complex<T>* twiddle)
+{
+ const size_t N4 = N / 4;
+ cvec<T, width> w1, w2, w3;
+ constexpr bool read_split = !splitin && splitout;
+ constexpr bool write_split = splitin && !splitout;
+
+ vec<T, width> re0, im0, re1, im1, re2, im2, re3, im3;
+
+ split(cread_split<width, aligned, read_split>(in + N4 * 0), re0, im0);
+ split(cread_split<width, aligned, read_split>(in + N4 * 1), re1, im1);
+ split(cread_split<width, aligned, read_split>(in + N4 * 2), re2, im2);
+ split(cread_split<width, aligned, read_split>(in + N4 * 3), re3, im3);
+
+ const vec<T, width> sum02re = re0 + re2;
+ const vec<T, width> sum02im = im0 + im2;
+ const vec<T, width> sum13re = re1 + re3;
+ const vec<T, width> sum13im = im1 + im3;
+
+ cwrite_split<width, aligned, write_split>(out, concat(sum02re + sum13re, sum02im + sum13im));
+ w2 = concat(sum02re - sum13re, sum02im - sum13im);
+ cwrite_split<width, aligned, write_split>(
+ out + N4 * (use_br2 ? 1 : 2), radix4_apply_twiddle(csize_t<width>(), ctrue, cbool_t<inverse>(), w2,
+ cread<width, true>(twiddle + width)));
+
+ const vec<T, width> diff02re = re0 - re2;
+ const vec<T, width> diff02im = im0 - im2;
+ const vec<T, width> diff13re = re1 - re3;
+ const vec<T, width> diff13im = im1 - im3;
+
+ (inverse ? w1 : w3) = concat(diff02re - diff13im, diff02im + diff13re);
+ (inverse ? w3 : w1) = concat(diff02re + diff13im, diff02im - diff13re);
+
+ cwrite_split<width, aligned, write_split>(
+ out + N4 * (use_br2 ? 2 : 1), radix4_apply_twiddle(csize_t<width>(), ctrue, cbool_t<inverse>(), w1,
+ cread<width, true>(twiddle + 0)));
+ cwrite_split<width, aligned, write_split>(
+ out + N4 * 3, radix4_apply_twiddle(csize_t<width>(), ctrue, cbool_t<inverse>(), w3,
+ cread<width, true>(twiddle + width * 2)));
+}
+
+template <typename T>
+CMT_NOINLINE cvec<T, 1> calculate_twiddle(size_t n, size_t size)
+{
+ if (n == 0)
+ {
+ return make_vector(static_cast<T>(1), static_cast<T>(0));
+ }
+ else if (n == size / 4)
+ {
+ return make_vector(static_cast<T>(0), static_cast<T>(-1));
+ }
+ else if (n == size / 2)
+ {
+ return make_vector(static_cast<T>(-1), static_cast<T>(0));
+ }
+ else if (n == size * 3 / 4)
+ {
+ return make_vector(static_cast<T>(0), static_cast<T>(1));
+ }
+ else
+ {
+ fbase kth = c_pi<fbase, 2> * (n / static_cast<fbase>(size));
+ fbase tcos = +kfr::cos(kth);
+ fbase tsin = -kfr::sin(kth);
+ return make_vector(static_cast<T>(tcos), static_cast<T>(tsin));
+ }
+}
+
+template <typename T, size_t width>
+KFR_SINTRIN void initialize_twiddles_impl(complex<T>*& twiddle, size_t nn, size_t nnstep, size_t size,
+ bool split_format)
+{
+ vec<T, 2 * width> result = T();
+ CMT_LOOP_UNROLL
+ for (size_t i = 0; i < width; i++)
+ {
+ const cvec<T, 1> r = calculate_twiddle<T>(nn + nnstep * i, size);
+ result[i * 2] = r[0];
+ result[i * 2 + 1] = r[1];
+ }
+ if (split_format)
+ ref_cast<cvec<T, width>>(twiddle[0]) = splitpairs(result);
+ else
+ ref_cast<cvec<T, width>>(twiddle[0]) = result;
+ twiddle += width;
+}
+
+template <typename T, size_t width>
+CMT_NOINLINE void initialize_twiddles(complex<T>*& twiddle, size_t stage_size, size_t size, bool split_format)
+{
+ const size_t count = stage_size / 4;
+ size_t nnstep = size / stage_size;
+ DFT_ASSERT(width <= count);
+ CMT_LOOP_NOUNROLL
+ for (size_t n = 0; n < count; n += width)
+ {
+ initialize_twiddles_impl<T, width>(twiddle, n * nnstep * 1, nnstep * 1, size, split_format);
+ initialize_twiddles_impl<T, width>(twiddle, n * nnstep * 2, nnstep * 2, size, split_format);
+ initialize_twiddles_impl<T, width>(twiddle, n * nnstep * 3, nnstep * 3, size, split_format);
+ }
+}
+
+#ifdef CMT_ARCH_X86
+#ifdef CMT_COMPILER_GNU
+#define KFR_PREFETCH(addr) __builtin_prefetch(::kfr::ptr_cast<void>(addr), 0, _MM_HINT_T0);
+#else
+#define KFR_PREFETCH(addr) _mm_prefetch(::kfr::ptr_cast<char>(addr), _MM_HINT_T0);
+#endif
+#else
+#define KFR_PREFETCH(addr) __builtin_prefetch(::kfr::ptr_cast<void>(addr));
+#endif
+
+template <typename T>
+KFR_SINTRIN void prefetch_one(const complex<T>* in)
+{
+ KFR_PREFETCH(in);
+}
+
+template <typename T>
+KFR_SINTRIN void prefetch_four(size_t stride, const complex<T>* in)
+{
+ KFR_PREFETCH(in);
+ KFR_PREFETCH(in + stride);
+ KFR_PREFETCH(in + stride * 2);
+ KFR_PREFETCH(in + stride * 3);
+}
+
+template <typename Ntype, size_t width, bool splitout, bool splitin, bool prefetch, bool use_br2,
+ bool inverse, bool aligned, typename T>
+KFR_SINTRIN cfalse_t radix4_pass(Ntype N, size_t blocks, csize_t<width>, cbool_t<splitout>, cbool_t<splitin>,
+ cbool_t<use_br2>, cbool_t<prefetch>, cbool_t<inverse>, cbool_t<aligned>,
+ complex<T>* out, const complex<T>* in, const complex<T>*& twiddle)
+{
+ constexpr static size_t prefetch_offset = width * 8;
+ const auto N4 = N / csize_t<4>();
+ const auto N43 = N4 * csize_t<3>();
+ CMT_ASSUME(blocks > 0);
+ CMT_ASSUME(N > 0);
+ CMT_ASSUME(N4 > 0);
+ DFT_ASSERT(width <= N4);
+ CMT_LOOP_NOUNROLL for (size_t b = 0; b < blocks; b++)
+ {
+ CMT_PRAGMA_CLANG(clang loop unroll_count(2))
+ for (size_t n2 = 0; n2 < N4; n2 += width)
+ {
+ if (prefetch)
+ prefetch_four(N4, in + prefetch_offset);
+ radix4_body(N, csize_t<width>(), cbool_t<(splitout || splitin)>(), cbool_t<splitout>(),
+ cbool_t<splitin>(), cbool_t<use_br2>(), cbool_t<inverse>(), cbool_t<aligned>(), out,
+ in, twiddle + n2 * 3);
+ in += width;
+ out += width;
+ }
+ in += N43;
+ out += N43;
+ }
+ twiddle += N43;
+ return {};
+}
+
+template <size_t width, bool prefetch, bool use_br2, bool inverse, bool aligned, typename T>
+KFR_SINTRIN ctrue_t radix4_pass(csize_t<32>, size_t blocks, csize_t<width>, cfalse_t, cfalse_t,
+ cbool_t<use_br2>, cbool_t<prefetch>, cbool_t<inverse>, cbool_t<aligned>,
+ complex<T>* out, const complex<T>*, const complex<T>*& /*twiddle*/)
+{
+ CMT_ASSUME(blocks > 0);
+ constexpr static size_t prefetch_offset = 32 * 4;
+ for (size_t b = 0; b < blocks; b++)
+ {
+ if (prefetch)
+ prefetch_four(csize_t<64>(), out + prefetch_offset);
+ cvec<T, 4> w0, w1, w2, w3, w4, w5, w6, w7;
+ split(cread<8, aligned>(out + 0), w0, w1);
+ split(cread<8, aligned>(out + 8), w2, w3);
+ split(cread<8, aligned>(out + 16), w4, w5);
+ split(cread<8, aligned>(out + 24), w6, w7);
+
+ butterfly8<4, inverse>(w0, w1, w2, w3, w4, w5, w6, w7);
+
+ w1 = cmul(w1, fixed_twiddle<T, 4, 32, 0, 1, inverse>());
+ w2 = cmul(w2, fixed_twiddle<T, 4, 32, 0, 2, inverse>());
+ w3 = cmul(w3, fixed_twiddle<T, 4, 32, 0, 3, inverse>());
+ w4 = cmul(w4, fixed_twiddle<T, 4, 32, 0, 4, inverse>());
+ w5 = cmul(w5, fixed_twiddle<T, 4, 32, 0, 5, inverse>());
+ w6 = cmul(w6, fixed_twiddle<T, 4, 32, 0, 6, inverse>());
+ w7 = cmul(w7, fixed_twiddle<T, 4, 32, 0, 7, inverse>());
+
+ cvec<T, 8> z0, z1, z2, z3;
+ transpose4x8(w0, w1, w2, w3, w4, w5, w6, w7, z0, z1, z2, z3);
+
+ butterfly4<8, inverse>(cfalse, z0, z1, z2, z3, z0, z1, z2, z3);
+ cwrite<32, aligned>(out, bitreverse<2>(concat(z0, z1, z2, z3)));
+ out += 32;
+ }
+ return {};
+}
+
+template <size_t width, bool prefetch, bool use_br2, bool inverse, bool aligned, typename T>
+KFR_SINTRIN ctrue_t radix4_pass(csize_t<8>, size_t blocks, csize_t<width>, cfalse_t, cfalse_t,
+ cbool_t<use_br2>, cbool_t<prefetch>, cbool_t<inverse>, cbool_t<aligned>,
+ complex<T>* out, const complex<T>*, const complex<T>*& /*twiddle*/)
+{
+ CMT_ASSUME(blocks > 0);
+ DFT_ASSERT(2 <= blocks);
+ constexpr static size_t prefetch_offset = width * 16;
+ for (size_t b = 0; b < blocks; b += 2)
+ {
+ if (prefetch)
+ prefetch_one(out + prefetch_offset);
+
+ cvec<T, 8> vlo = cread<8, aligned>(out + 0);
+ cvec<T, 8> vhi = cread<8, aligned>(out + 8);
+ butterfly8<inverse>(vlo);
+ butterfly8<inverse>(vhi);
+ vlo = permutegroups<(2), 0, 4, 2, 6, 1, 5, 3, 7>(vlo);
+ vhi = permutegroups<(2), 0, 4, 2, 6, 1, 5, 3, 7>(vhi);
+ cwrite<8, aligned>(out, vlo);
+ cwrite<8, aligned>(out + 8, vhi);
+ out += 16;
+ }
+ return {};
+}
+
+template <size_t width, bool prefetch, bool use_br2, bool inverse, bool aligned, typename T>
+KFR_SINTRIN ctrue_t radix4_pass(csize_t<16>, size_t blocks, csize_t<width>, cfalse_t, cfalse_t,
+ cbool_t<use_br2>, cbool_t<prefetch>, cbool_t<inverse>, cbool_t<aligned>,
+ complex<T>* out, const complex<T>*, const complex<T>*& /*twiddle*/)
+{
+ CMT_ASSUME(blocks > 0);
+ constexpr static size_t prefetch_offset = width * 4;
+ DFT_ASSERT(2 <= blocks);
+ CMT_PRAGMA_CLANG(clang loop unroll_count(2))
+ for (size_t b = 0; b < blocks; b += 2)
+ {
+ if (prefetch)
+ prefetch_one(out + prefetch_offset);
+
+ cvec<T, 16> vlo = cread<16, aligned>(out);
+ cvec<T, 16> vhi = cread<16, aligned>(out + 16);
+ butterfly4<4, inverse>(vlo);
+ butterfly4<4, inverse>(vhi);
+ apply_twiddles4<0, 4, 4, inverse>(vlo);
+ apply_twiddles4<0, 4, 4, inverse>(vhi);
+ vlo = digitreverse4<2>(vlo);
+ vhi = digitreverse4<2>(vhi);
+ butterfly4<4, inverse>(vlo);
+ butterfly4<4, inverse>(vhi);
+
+ use_br2 ? cbitreverse_write(out, vlo) : cdigitreverse4_write(out, vlo);
+ use_br2 ? cbitreverse_write(out + 16, vhi) : cdigitreverse4_write(out + 16, vhi);
+ out += 32;
+ }
+ return {};
+}
+
+template <size_t width, bool prefetch, bool use_br2, bool inverse, bool aligned, typename T>
+KFR_SINTRIN ctrue_t radix4_pass(csize_t<4>, size_t blocks, csize_t<width>, cfalse_t, cfalse_t,
+ cbool_t<use_br2>, cbool_t<prefetch>, cbool_t<inverse>, cbool_t<aligned>,
+ complex<T>* out, const complex<T>*, const complex<T>*& /*twiddle*/)
+{
+ constexpr static size_t prefetch_offset = width * 4;
+ CMT_ASSUME(blocks > 0);
+ DFT_ASSERT(4 <= blocks);
+ CMT_LOOP_NOUNROLL
+ for (size_t b = 0; b < blocks; b += 4)
+ {
+ if (prefetch)
+ prefetch_one(out + prefetch_offset);
+
+ cvec<T, 16> v16 = cdigitreverse4_read<16, aligned>(out);
+ butterfly4<4, inverse>(v16);
+ cdigitreverse4_write<aligned>(out, v16);
+
+ out += 4 * 4;
+ }
+ return {};
+}
+
+template <typename T, bool splitin, bool is_even, bool inverse>
+struct fft_stage_impl : dft_stage<T>
+{
+ fft_stage_impl(size_t stage_size)
+ {
+ this->stage_size = stage_size;
+ this->repeats = 4;
+ this->recursion = true;
+ this->data_size =
+ align_up(sizeof(complex<T>) * stage_size / 4 * 3, platform<>::native_cache_alignment);
+ }
+
+protected:
+ constexpr static bool prefetch = true;
+ constexpr static bool aligned = false;
+ constexpr static size_t width = fft_vector_width<T>;
+
+ virtual void do_initialize(size_t size) override final
+ {
+ complex<T>* twiddle = ptr_cast<complex<T>>(this->data);
+ initialize_twiddles<T, width>(twiddle, this->stage_size, size, true);
+ }
+
+ virtual void do_execute(complex<T>* out, const complex<T>* in, u8* /*temp*/) override final
+ {
+ const complex<T>* twiddle = ptr_cast<complex<T>>(this->data);
+ if (splitin)
+ in = out;
+ const size_t stg_size = this->stage_size;
+ CMT_ASSUME(stg_size >= 2048);
+ CMT_ASSUME(stg_size % 2048 == 0);
+ radix4_pass(stg_size, 1, csize_t<width>(), ctrue, cbool_t<splitin>(), cbool_t<!is_even>(),
+ cbool_t<prefetch>(), cbool_t<inverse>(), cbool_t<aligned>(), out, in, twiddle);
+ }
+};
+
+template <typename T, bool splitin, size_t size, bool inverse>
+struct fft_final_stage_impl : dft_stage<T>
+{
+ fft_final_stage_impl(size_t)
+ {
+ this->stage_size = size;
+ this->out_offset = size;
+ this->repeats = 4;
+ this->recursion = true;
+ this->data_size = align_up(sizeof(complex<T>) * size * 3 / 2, platform<>::native_cache_alignment);
+ }
+
+protected:
+ constexpr static size_t width = fft_vector_width<T>;
+ constexpr static bool is_even = cometa::is_even(ilog2(size));
+ constexpr static bool use_br2 = !is_even;
+ constexpr static bool aligned = false;
+ constexpr static bool prefetch = splitin;
+
+ KFR_INTRIN void init_twiddles(csize_t<8>, size_t, cfalse_t, complex<T>*&) {}
+ KFR_INTRIN void init_twiddles(csize_t<4>, size_t, cfalse_t, complex<T>*&) {}
+
+ template <size_t N, bool pass_splitin>
+ KFR_INTRIN void init_twiddles(csize_t<N>, size_t total_size, cbool_t<pass_splitin>, complex<T>*& twiddle)
+ {
+ constexpr bool pass_split = N / 4 > 8 && N / 4 / 4 >= width;
+ constexpr size_t pass_width = const_min(width, N / 4);
+ initialize_twiddles<T, pass_width>(twiddle, N, total_size, pass_split || pass_splitin);
+ init_twiddles(csize<N / 4>, total_size, cbool<pass_split>, twiddle);
+ }
+
+ virtual void do_initialize(size_t total_size) override final
+ {
+ complex<T>* twiddle = ptr_cast<complex<T>>(this->data);
+ init_twiddles(csize<size>, total_size, cbool<splitin>, twiddle);
+ }
+
+ virtual void do_execute(complex<T>* out, const complex<T>* in, u8* /*temp*/) override
+ {
+ const complex<T>* twiddle = ptr_cast<complex<T>>(this->data);
+ final_stage(csize<size>, 1, cbool<splitin>, out, in, twiddle);
+ }
+
+ // KFR_INTRIN void final_stage(csize_t<32>, size_t invN, cfalse_t, complex<T>* out, const complex<T>*,
+ // const complex<T>*& twiddle)
+ // {
+ // radix4_pass(csize_t<32>(), invN, csize_t<width>(), cfalse, cfalse, cbool_t<use_br2>(),
+ // cbool_t<prefetch>(), cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle);
+ // }
+ //
+ // KFR_INTRIN void final_stage(csize_t<16>, size_t invN, cfalse_t, complex<T>* out, const complex<T>*,
+ // const complex<T>*& twiddle)
+ // {
+ // radix4_pass(csize_t<16>(), invN, csize_t<width>(), cfalse, cfalse, cbool_t<use_br2>(),
+ // cbool_t<prefetch>(), cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle);
+ // }
+
+ KFR_INTRIN void final_stage(csize_t<8>, size_t invN, cfalse_t, complex<T>* out, const complex<T>*,
+ const complex<T>*& twiddle)
+ {
+ radix4_pass(csize_t<8>(), invN, csize_t<width>(), cfalse, cfalse, cbool_t<use_br2>(),
+ cbool_t<prefetch>(), cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle);
+ }
+
+ KFR_INTRIN void final_stage(csize_t<4>, size_t invN, cfalse_t, complex<T>* out, const complex<T>*,
+ const complex<T>*& twiddle)
+ {
+ radix4_pass(csize_t<4>(), invN, csize_t<width>(), cfalse, cfalse, cbool_t<use_br2>(),
+ cbool_t<prefetch>(), cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle);
+ }
+
+ template <size_t N, bool pass_splitin>
+ KFR_INTRIN void final_stage(csize_t<N>, size_t invN, cbool_t<pass_splitin>, complex<T>* out,
+ const complex<T>* in, const complex<T>*& twiddle)
+ {
+ static_assert(N > 8, "");
+ constexpr bool pass_split = N / 4 > 8 && N / 4 / 4 >= width;
+ constexpr size_t pass_width = const_min(width, N / 4);
+ static_assert(pass_width == width || (pass_split == pass_splitin), "");
+ static_assert(pass_width <= N / 4, "");
+ radix4_pass(N, invN, csize_t<pass_width>(), cbool<pass_split>, cbool_t<pass_splitin>(),
+ cbool_t<use_br2>(), cbool_t<prefetch>(), cbool_t<inverse>(), cbool_t<aligned>(), out, in,
+ twiddle);
+ final_stage(csize<N / 4>, invN * 4, cbool<pass_split>, out, out, twiddle);
+ }
+};
+
+template <typename T, bool is_even>
+struct fft_reorder_stage_impl : dft_stage<T>
+{
+ fft_reorder_stage_impl(size_t stage_size)
+ {
+ this->stage_size = stage_size;
+ log2n = ilog2(stage_size);
+ this->data_size = 0;
+ }
+
+protected:
+ size_t log2n;
+
+ virtual void do_initialize(size_t) override final {}
+
+ virtual void do_execute(complex<T>* out, const complex<T>*, u8* /*temp*/) override final
+ {
+ fft_reorder(out, log2n, cbool_t<!is_even>());
+ }
+};
+
+template <typename T, size_t log2n, bool inverse>
+struct fft_specialization;
+
+template <typename T, bool inverse>
+struct fft_specialization<T, 1, inverse> : dft_stage<T>
+{
+ fft_specialization(size_t) {}
+
+protected:
+ constexpr static bool aligned = false;
+ virtual void do_execute(complex<T>* out, const complex<T>* in, u8*) override final
+ {
+ cvec<T, 1> a0, a1;
+ split(cread<2, aligned>(in), a0, a1);
+ cwrite<2, aligned>(out, concat(a0 + a1, a0 - a1));
+ }
+};
+
+template <typename T, bool inverse>
+struct fft_specialization<T, 2, inverse> : dft_stage<T>
+{
+ fft_specialization(size_t) {}
+
+protected:
+ constexpr static bool aligned = false;
+ virtual void do_execute(complex<T>* out, const complex<T>* in, u8*) override final
+ {
+ cvec<T, 1> a0, a1, a2, a3;
+ split(cread<4>(in), a0, a1, a2, a3);
+ butterfly(cbool_t<inverse>(), a0, a1, a2, a3, a0, a1, a2, a3);
+ cwrite<4>(out, concat(a0, a1, a2, a3));
+ }
+};
+
+template <typename T, bool inverse>
+struct fft_specialization<T, 3, inverse> : dft_stage<T>
+{
+ fft_specialization(size_t) {}
+
+protected:
+ constexpr static bool aligned = false;
+ virtual void do_execute(complex<T>* out, const complex<T>* in, u8*) override final
+ {
+ cvec<T, 8> v8 = cread<8, aligned>(in);
+ butterfly8<inverse>(v8);
+ cwrite<8, aligned>(out, v8);
+ }
+};
+
+template <typename T, bool inverse>
+struct fft_specialization<T, 4, inverse> : dft_stage<T>
+{
+ fft_specialization(size_t) {}
+
+protected:
+ constexpr static bool aligned = false;
+ virtual void do_execute(complex<T>* out, const complex<T>* in, u8*) override final
+ {
+ cvec<T, 16> v16 = cread<16, aligned>(in);
+ butterfly16<inverse>(v16);
+ cwrite<16, aligned>(out, v16);
+ }
+};
+
+template <typename T, bool inverse>
+struct fft_specialization<T, 5, inverse> : dft_stage<T>
+{
+ fft_specialization(size_t) {}
+
+protected:
+ constexpr static bool aligned = false;
+ virtual void do_execute(complex<T>* out, const complex<T>* in, u8*) override final
+ {
+ cvec<T, 32> v32 = cread<32, aligned>(in);
+ butterfly32<inverse>(v32);
+ cwrite<32, aligned>(out, v32);
+ }
+};
+
+template <typename T, bool inverse>
+struct fft_specialization<T, 6, inverse> : dft_stage<T>
+{
+ fft_specialization(size_t) {}
+
+protected:
+ constexpr static bool aligned = false;
+ virtual void do_execute(complex<T>* out, const complex<T>* in, u8*) override final
+ {
+ butterfly64(cbool_t<inverse>(), cbool_t<aligned>(), out, in);
+ }
+};
+
+template <typename T, bool inverse>
+struct fft_specialization<T, 7, inverse> : dft_stage<T>
+{
+ fft_specialization(size_t)
+ {
+ this->stage_size = 128;
+ this->data_size = align_up(sizeof(complex<T>) * 128 * 3 / 2, platform<>::native_cache_alignment);
+ }
+
+protected:
+ constexpr static bool aligned = false;
+ constexpr static size_t width = platform<T>::vector_width;
+ constexpr static bool use_br2 = true;
+ constexpr static bool prefetch = false;
+ constexpr static bool is_double = sizeof(T) == 8;
+ constexpr static size_t final_size = is_double ? 8 : 32;
+ constexpr static size_t split_format = final_size == 8;
+
+ virtual void do_initialize(size_t total_size) override final
+ {
+ complex<T>* twiddle = ptr_cast<complex<T>>(this->data);
+ initialize_twiddles<T, width>(twiddle, 128, total_size, split_format);
+ initialize_twiddles<T, width>(twiddle, 32, total_size, split_format);
+ initialize_twiddles<T, width>(twiddle, 8, total_size, split_format);
+ }
+
+ virtual void do_execute(complex<T>* out, const complex<T>* in, u8* /*temp*/) override final
+ {
+ const complex<T>* twiddle = ptr_cast<complex<T>>(this->data);
+ final_pass(csize_t<final_size>(), out, in, twiddle);
+ fft_reorder(out, csize_t<7>());
+ }
+
+ KFR_INTRIN void final_pass(csize_t<8>, complex<T>* out, const complex<T>* in, const complex<T>* twiddle)
+ {
+ radix4_pass(128, 1, csize_t<width>(), ctrue, cfalse, cbool_t<use_br2>(), cbool_t<prefetch>(),
+ cbool_t<inverse>(), cbool_t<aligned>(), out, in, twiddle);
+ radix4_pass(32, 4, csize_t<width>(), cfalse, ctrue, cbool_t<use_br2>(), cbool_t<prefetch>(),
+ cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle);
+ radix4_pass(csize_t<8>(), 16, csize_t<width>(), cfalse, cfalse, cbool_t<use_br2>(),
+ cbool_t<prefetch>(), cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle);
+ }
+
+ KFR_INTRIN void final_pass(csize_t<32>, complex<T>* out, const complex<T>* in, const complex<T>* twiddle)
+ {
+ radix4_pass(128, 1, csize_t<width>(), cfalse, cfalse, cbool_t<use_br2>(), cbool_t<prefetch>(),
+ cbool_t<inverse>(), cbool_t<aligned>(), out, in, twiddle);
+ radix4_pass(csize_t<32>(), 4, csize_t<width>(), cfalse, cfalse, cbool_t<use_br2>(),
+ cbool_t<prefetch>(), cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle);
+ }
+};
+
+template <bool inverse>
+struct fft_specialization<float, 8, inverse> : dft_stage<float>
+{
+ fft_specialization(size_t) { this->temp_size = sizeof(complex<float>) * 256; }
+
+protected:
+ virtual void do_execute(complex<float>* out, const complex<float>* in, u8* temp) override final
+ {
+ complex<float>* scratch = ptr_cast<complex<float>>(temp);
+ if (out == in)
+ {
+ butterfly16_multi_flip<0, inverse>(scratch, out);
+ butterfly16_multi_flip<1, inverse>(scratch, out);
+ butterfly16_multi_flip<2, inverse>(scratch, out);
+ butterfly16_multi_flip<3, inverse>(scratch, out);
+
+ butterfly16_multi_natural<0, inverse>(out, scratch);
+ butterfly16_multi_natural<1, inverse>(out, scratch);
+ butterfly16_multi_natural<2, inverse>(out, scratch);
+ butterfly16_multi_natural<3, inverse>(out, scratch);
+ }
+ else
+ {
+ butterfly16_multi_flip<0, inverse>(out, in);
+ butterfly16_multi_flip<1, inverse>(out, in);
+ butterfly16_multi_flip<2, inverse>(out, in);
+ butterfly16_multi_flip<3, inverse>(out, in);
+
+ butterfly16_multi_natural<0, inverse>(out, out);
+ butterfly16_multi_natural<1, inverse>(out, out);
+ butterfly16_multi_natural<2, inverse>(out, out);
+ butterfly16_multi_natural<3, inverse>(out, out);
+ }
+ }
+};
+
+template <bool inverse>
+struct fft_specialization<double, 8, inverse> : fft_final_stage_impl<double, false, 256, inverse>
+{
+ using T = double;
+ using fft_final_stage_impl<double, false, 256, inverse>::fft_final_stage_impl;
+
+ virtual void do_execute(complex<T>* out, const complex<T>* in, u8* /*temp*/) override final
+ {
+ fft_final_stage_impl<double, false, 256, inverse>::do_execute(out, in, nullptr);
+ fft_reorder(out, csize_t<8>());
+ }
+};
+
+template <typename T, bool splitin, bool is_even>
+struct fft_stage_impl_t
+{
+ template <bool inverse>
+ using type = internal::fft_stage_impl<T, splitin, is_even, inverse>;
+};
+template <typename T, bool splitin, size_t size>
+struct fft_final_stage_impl_t
+{
+ template <bool inverse>
+ using type = internal::fft_final_stage_impl<T, splitin, size, inverse>;
+};
+template <typename T, bool is_even>
+struct fft_reorder_stage_impl_t
+{
+ template <bool>
+ using type = internal::fft_reorder_stage_impl<T, is_even>;
+};
+template <typename T, size_t log2n, bool aligned>
+struct fft_specialization_t
+{
+ template <bool inverse>
+ using type = internal::fft_specialization<T, log2n, inverse>;
+};
+} // namespace internal
+
+//
+
+template <typename T>
+template <template <bool inverse> class Stage>
+void dft_plan<T>::add_stage(size_t stage_size, cbools_t<true, true>)
+{
+ dft_stage<T>* direct_stage = new Stage<false>(stage_size);
+ direct_stage->name = nullptr;
+ dft_stage<T>* inverse_stage = new Stage<true>(stage_size);
+ inverse_stage->name = nullptr;
+ this->data_size += direct_stage->data_size;
+ this->temp_size += direct_stage->temp_size;
+ stages[0].push_back(dft_stage_ptr(direct_stage));
+ stages[1].push_back(dft_stage_ptr(inverse_stage));
+}
+
+template <typename T>
+template <template <bool inverse> class Stage>
+void dft_plan<T>::add_stage(size_t stage_size, cbools_t<true, false>)
+{
+ dft_stage<T>* direct_stage = new Stage<false>(stage_size);
+ direct_stage->name = nullptr;
+ this->data_size += direct_stage->data_size;
+ this->temp_size += direct_stage->temp_size;
+ stages[0].push_back(dft_stage_ptr(direct_stage));
+}
+
+template <typename T>
+template <template <bool inverse> class Stage>
+void dft_plan<T>::add_stage(size_t stage_size, cbools_t<false, true>)
+{
+ dft_stage<T>* inverse_stage = new Stage<true>(stage_size);
+ inverse_stage->name = nullptr;
+ this->data_size += inverse_stage->data_size;
+ this->temp_size += inverse_stage->temp_size;
+ stages[1].push_back(dft_stage_ptr(inverse_stage));
+}
+
+template <typename T>
+template <bool direct, bool inverse, bool is_even, bool first>
+void dft_plan<T>::make_fft(size_t stage_size, cbools_t<direct, inverse> type, cbool_t<is_even>,
+ cbool_t<first>)
+{
+ constexpr size_t final_size = is_even ? 1024 : 512;
+
+ using fft_stage_impl_t = internal::fft_stage_impl_t<T, !first, is_even>;
+ using fft_final_stage_impl_t = internal::fft_final_stage_impl_t<T, !first, final_size>;
+
+ if (stage_size >= 2048)
+ {
+ add_stage<fft_stage_impl_t::template type>(stage_size, type);
+
+ make_fft(stage_size / 4, cbools_t<direct, inverse>(), cbool_t<is_even>(), cfalse);
+ }
+ else
+ {
+ add_stage<fft_final_stage_impl_t::template type>(final_size, type);
+ }
+}
+
+template <typename T>
+template <bool direct, bool inverse>
+void dft_plan<T>::initialize(cbools_t<direct, inverse>)
+{
+ data = autofree<u8>(data_size);
+ if (direct)
+ {
+ size_t offset = 0;
+ for (dft_stage_ptr& stage : stages[0])
+ {
+ stage->data = data.data() + offset;
+ stage->initialize(this->size);
+ offset += stage->data_size;
+ }
+ }
+ if (inverse)
+ {
+ size_t offset = 0;
+ for (dft_stage_ptr& stage : stages[1])
+ {
+ stage->data = data.data() + offset;
+ if (!direct)
+ stage->initialize(this->size);
+ offset += stage->data_size;
+ }
+ }
+}
+
+template <typename T>
+template <bool inverse>
+void dft_plan<T>::execute_dft(cbool_t<inverse>, complex<T>* out, const complex<T>* in, u8* temp) const
+{
+ size_t stack[32] = { 0 };
+
+ const size_t count = stages[inverse].size();
+
+ for (size_t depth = 0; depth < count;)
+ {
+ if (stages[inverse][depth]->recursion)
+ {
+ complex<T>* rout = out;
+ const complex<T>* rin = in;
+ size_t rdepth = depth;
+ size_t maxdepth = depth;
+ do
+ {
+ if (stack[rdepth] == stages[inverse][rdepth]->repeats)
+ {
+ stack[rdepth] = 0;
+ rdepth--;
+ }
+ else
+ {
+ stages[inverse][rdepth]->execute(rout, rin, temp);
+ rout += stages[inverse][rdepth]->out_offset;
+ rin = rout;
+ stack[rdepth]++;
+ if (rdepth < count - 1 && stages[inverse][rdepth + 1]->recursion)
+ rdepth++;
+ else
+ maxdepth = rdepth;
+ }
+ } while (rdepth != depth);
+ depth = maxdepth + 1;
+ }
+ else
+ {
+ stages[inverse][depth]->execute(out, in, temp);
+ depth++;
+ }
+ in = out;
+ }
+}
+
+template <typename T>
+template <bool direct, bool inverse>
+dft_plan<T>::dft_plan(size_t size, cbools_t<direct, inverse> type) : size(size), temp_size(0), data_size(0)
+{
+ if (is_poweroftwo(size))
+ {
+ const size_t log2n = ilog2(size);
+ cswitch(
+ csizes_t<1, 2, 3, 4, 5, 6, 7, 8>(), log2n,
+ [&](auto log2n) {
+ (void)log2n;
+ this->add_stage<
+ internal::fft_specialization_t<T, val_of(decltype(log2n)()), false>::template type>(size,
+ type);
+ },
+ [&]() {
+ cswitch(cfalse_true, is_even(log2n), [&](auto is_even) {
+ this->make_fft(size, type, is_even, ctrue);
+ this->add_stage<
+ internal::fft_reorder_stage_impl_t<T, val_of(decltype(is_even)())>::template type>(
+ size, type);
+ });
+ });
+ initialize(type);
+ }
+}
+
+template <typename T>
+template <bool direct, bool inverse>
+dft_plan_real<T>::dft_plan_real(size_t size, cbools_t<direct, inverse> type)
+ : dft_plan<T>(size / 2, type), size(size), rtwiddle(size / 4)
+{
+ using namespace internal;
+
+ constexpr size_t width = platform<T>::vector_width * 2;
+
+ block_process(size / 4, csizes_t<width, 1>(), [=](size_t i, auto w) {
+ constexpr size_t width = val_of(decltype(w)());
+ cwrite<width>(rtwiddle.data() + i,
+ cossin(dup(-constants<T>::pi * ((enumerate<T, width>() + i + size / 4) / (size / 2)))));
+ });
+}
+
+template <typename T>
+void dft_plan_real<T>::to_fmt(complex<T>* out, dft_pack_format fmt) const
+{
+ using namespace internal;
+ size_t csize = this->size / 2; // const size_t causes internal compiler error: in tsubst_copy in GCC 5.2
+
+ constexpr size_t width = platform<T>::vector_width * 2;
+ const cvec<T, 1> dc = cread<1>(out);
+ const size_t count = csize / 2;
+
+ block_process(count, csizes_t<width, 1>(), [=](size_t i, auto w) {
+ constexpr size_t width = val_of(decltype(w)());
+ constexpr size_t widthm1 = width - 1;
+ const cvec<T, width> tw = cread<width>(rtwiddle.data() + i);
+ const cvec<T, width> fpk = cread<width>(out + i);
+ const cvec<T, width> fpnk = reverse<2>(negodd(cread<width>(out + csize - i - widthm1)));
+
+ const cvec<T, width> f1k = fpk + fpnk;
+ const cvec<T, width> f2k = fpk - fpnk;
+ const cvec<T, width> t = cmul(f2k, tw);
+ cwrite<width>(out + i, T(0.5) * (f1k + t));
+ cwrite<width>(out + csize - i - widthm1, reverse<2>(negodd(T(0.5) * (f1k - t))));
+ });
+
+ {
+ size_t k = csize / 2;
+ const cvec<T, 1> fpk = cread<1>(out + k);
+ const cvec<T, 1> fpnk = negodd(fpk);
+ cwrite<1>(out + k, fpnk);
+ }
+ if (fmt == dft_pack_format::CCs)
+ {
+ cwrite<1>(out, pack(dc[0] + dc[1], 0));
+ cwrite<1>(out + csize, pack(dc[0] - dc[1], 0));
+ }
+ else
+ {
+ cwrite<1>(out, pack(dc[0] + dc[1], dc[0] - dc[1]));
+ }
+}
+
+template <typename T>
+void dft_plan_real<T>::from_fmt(complex<T>* out, const complex<T>* in, dft_pack_format fmt) const
+{
+ using namespace internal;
+
+ const size_t csize = this->size / 2;
+
+ cvec<T, 1> dc;
+
+ if (fmt == dft_pack_format::CCs)
+ {
+ dc = pack(in[0].real() + in[csize].real(), in[0].real() - in[csize].real());
+ }
+ else
+ {
+ dc = pack(in[0].real() + in[0].imag(), in[0].real() - in[0].imag());
+ }
+
+ constexpr size_t width = platform<T>::vector_width * 2;
+ const size_t count = csize / 2;
+
+ block_process(count, csizes_t<width, 1>(), [=](size_t i, auto w) {
+ i++;
+ constexpr size_t width = val_of(decltype(w)());
+ constexpr size_t widthm1 = width - 1;
+ const cvec<T, width> tw = cread<width>(rtwiddle.data() + i);
+ const cvec<T, width> fpk = cread<width>(in + i);
+ const cvec<T, width> fpnk = reverse<2>(negodd(cread<width>(in + csize - i - widthm1)));
+
+ const cvec<T, width> f1k = fpk + fpnk;
+ const cvec<T, width> f2k = fpk - fpnk;
+ const cvec<T, width> t = cmul_conj(f2k, tw);
+ cwrite<width>(out + i, f1k + t);
+ cwrite<width>(out + csize - i - widthm1, reverse<2>(negodd(f1k - t)));
+ });
+
+ {
+ size_t k = csize / 2;
+ const cvec<T, 1> fpk = cread<1>(in + k);
+ const cvec<T, 1> fpnk = 2 * negodd(fpk);
+ cwrite<1>(out + k, fpnk);
+ }
+ cwrite<1>(out, dc);
+}
+
+template <typename T>
+dft_plan<T>::~dft_plan() = default;
+
+template dft_plan<float>::dft_plan(size_t, cbools_t<false, true>);
+template dft_plan<float>::dft_plan(size_t, cbools_t<true, false>);
+template dft_plan<float>::dft_plan(size_t, cbools_t<true, true>);
+template dft_plan<float>::~dft_plan();
+template void dft_plan<float>::execute_dft(cometa::cbool_t<false>, kfr::complex<float>* out,
+ const kfr::complex<float>* in, kfr::u8* temp) const;
+template void dft_plan<float>::execute_dft(cometa::cbool_t<true>, kfr::complex<float>* out,
+ const kfr::complex<float>* in, kfr::u8* temp) const;
+template dft_plan_real<float>::dft_plan_real(size_t, cbools_t<false, true>);
+template dft_plan_real<float>::dft_plan_real(size_t, cbools_t<true, false>);
+template dft_plan_real<float>::dft_plan_real(size_t, cbools_t<true, true>);
+template void dft_plan_real<float>::from_fmt(kfr::complex<float>* out, const kfr::complex<float>* in,
+ kfr::dft_pack_format fmt) const;
+template void dft_plan_real<float>::to_fmt(kfr::complex<float>* out, kfr::dft_pack_format fmt) const;
+
+template dft_plan<double>::dft_plan(size_t, cbools_t<false, true>);
+template dft_plan<double>::dft_plan(size_t, cbools_t<true, false>);
+template dft_plan<double>::dft_plan(size_t, cbools_t<true, true>);
+template dft_plan<double>::~dft_plan();
+template void dft_plan<double>::execute_dft(cometa::cbool_t<false>, kfr::complex<double>* out,
+ const kfr::complex<double>* in, kfr::u8* temp) const;
+template void dft_plan<double>::execute_dft(cometa::cbool_t<true>, kfr::complex<double>* out,
+ const kfr::complex<double>* in, kfr::u8* temp) const;
+template dft_plan_real<double>::dft_plan_real(size_t, cbools_t<false, true>);
+template dft_plan_real<double>::dft_plan_real(size_t, cbools_t<true, false>);
+template dft_plan_real<double>::dft_plan_real(size_t, cbools_t<true, true>);
+template void dft_plan_real<double>::from_fmt(kfr::complex<double>* out, const kfr::complex<double>* in,
+ kfr::dft_pack_format fmt) const;
+template void dft_plan_real<double>::to_fmt(kfr::complex<double>* out, kfr::dft_pack_format fmt) const;
+
+} // namespace kfr
+
+CMT_PRAGMA_GNU(GCC diagnostic pop)
+
+CMT_PRAGMA_MSVC(warning(pop))
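
With the DFT implementation moved out of the headers into dft-src.cpp, that file must be compiled into any binary that uses dft_plan, which is exactly what the added ../include/kfr/dft/dft-src.cpp source in examples/CMakeLists.txt does for the dft example. A minimal usage sketch (not part of the commit), assuming the (size, cbools_t) constructor, temp_size member and execute_dft instantiated above are publicly reachable through the declarations in include/kfr/dft/fft.hpp and that dft-src.cpp is built as its own translation unit:

    #include <kfr/dft.hpp>
    #include <vector>

    int main()
    {
        using namespace kfr;
        const size_t n = 1024;                           // power of two, handled by the make_fft path above
        dft_plan<float> plan(n, cbools_t<true, true>()); // build both direct and inverse stages
        std::vector<complex<float>> in(n), out(n);
        std::vector<u8> temp(plan.temp_size);            // scratch buffer sized by the plan
        in[0] = complex<float>(1.f, 0.f);                // impulse input
        // forward transform via the explicitly instantiated execute_dft
        plan.execute_dft(cbool_t<false>(), out.data(), in.data(), temp.data());
        return 0;
    }
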
diff --git a/include/kfr/dft/fft.hpp b/include/kfr/dft/fft.hpp
@@ -31,15 +31,14 @@
#include "../base/read_write.hpp"
#include "../base/small_buffer.hpp"
#include "../base/vec.hpp"
-#include "../testo/assert.hpp"
-
-#include "bitrev.hpp"
-#include "ft.hpp"
CMT_PRAGMA_GNU(GCC diagnostic push)
#if CMT_HAS_WARNING("-Wshadow")
CMT_PRAGMA_GNU(GCC diagnostic ignored "-Wshadow")
#endif
+#if CMT_HAS_WARNING("-Wundefined-inline")
+CMT_PRAGMA_GNU(GCC diagnostic ignored "-Wundefined-inline")
+#endif
CMT_PRAGMA_MSVC(warning(push))
CMT_PRAGMA_MSVC(warning(disable : 4100))
@@ -47,764 +46,6 @@ CMT_PRAGMA_MSVC(warning(disable : 4100))
namespace kfr
{
-#define DFT_ASSERT TESTO_ASSERT_ACTIVE
-
-template <typename T>
-constexpr size_t fft_vector_width = platform<T>::vector_width;
-
-template <typename T>
-struct dft_stage
-{
- size_t stage_size = 0;
- size_t data_size = 0;
- size_t temp_size = 0;
- u8* data = nullptr;
- size_t repeats = 1;
- size_t out_offset = 0;
- const char* name;
- bool recursion = false;
-
- void initialize(size_t size) { do_initialize(size); }
-
- KFR_INTRIN void execute(complex<T>* out, const complex<T>* in, u8* temp) { do_execute(out, in, temp); }
- virtual ~dft_stage() {}
-
-protected:
- virtual void do_initialize(size_t) {}
- virtual void do_execute(complex<T>*, const complex<T>*, u8* temp) = 0;
-};
-
-CMT_PRAGMA_GNU(GCC diagnostic push)
-#if CMT_HAS_WARNING("-Wassume")
-CMT_PRAGMA_GNU(GCC diagnostic ignored "-Wassume")
-#endif
-
-namespace internal
-{
-
-template <size_t width, bool inverse, typename T>
-KFR_SINTRIN cvec<T, width> radix4_apply_twiddle(csize_t<width>, cfalse_t /*split_format*/, cbool_t<inverse>,
- const cvec<T, width>& w, const cvec<T, width>& tw)
-{
- cvec<T, width> ww = w;
- cvec<T, width> tw_ = tw;
- cvec<T, width> b1 = ww * dupeven(tw_);
- ww = swap<2>(ww);
-
- if (inverse)
- tw_ = -(tw_);
- ww = subadd(b1, ww * dupodd(tw_));
- return ww;
-}
-
-template <size_t width, bool use_br2, bool inverse, bool aligned, typename T>
-KFR_SINTRIN void radix4_body(size_t N, csize_t<width>, cfalse_t, cfalse_t, cfalse_t, cbool_t<use_br2>,
- cbool_t<inverse>, cbool_t<aligned>, complex<T>* out, const complex<T>* in,
- const complex<T>* twiddle)
-{
- const size_t N4 = N / 4;
- cvec<T, width> w1, w2, w3;
-
- cvec<T, width> sum02, sum13, diff02, diff13;
-
- cvec<T, width> a0, a1, a2, a3;
- a0 = cread<width, aligned>(in + 0);
- a2 = cread<width, aligned>(in + N4 * 2);
- sum02 = a0 + a2;
-
- a1 = cread<width, aligned>(in + N4);
- a3 = cread<width, aligned>(in + N4 * 3);
- sum13 = a1 + a3;
-
- cwrite<width, aligned>(out, sum02 + sum13);
- w2 = sum02 - sum13;
- cwrite<width, aligned>(out + N4 * (use_br2 ? 1 : 2),
- radix4_apply_twiddle(csize_t<width>(), cfalse, cbool_t<inverse>(), w2,
- cread<width, true>(twiddle + width)));
- diff02 = a0 - a2;
- diff13 = a1 - a3;
- if (inverse)
- {
- diff13 = (diff13 ^ broadcast<width * 2, T>(T(), -T()));
- diff13 = swap<2>(diff13);
- }
- else
- {
- diff13 = swap<2>(diff13);
- diff13 = (diff13 ^ broadcast<width * 2, T>(T(), -T()));
- }
-
- w1 = diff02 + diff13;
-
- cwrite<width, aligned>(out + N4 * (use_br2 ? 2 : 1),
- radix4_apply_twiddle(csize_t<width>(), cfalse, cbool_t<inverse>(), w1,
- cread<width, true>(twiddle + 0)));
- w3 = diff02 - diff13;
- cwrite<width, aligned>(out + N4 * 3, radix4_apply_twiddle(csize_t<width>(), cfalse, cbool_t<inverse>(),
- w3, cread<width, true>(twiddle + width * 2)));
-}
-
-template <size_t width, bool inverse, typename T>
-KFR_SINTRIN cvec<T, width> radix4_apply_twiddle(csize_t<width>, ctrue_t /*split_format*/, cbool_t<inverse>,
- const cvec<T, width>& w, const cvec<T, width>& tw)
-{
- vec<T, width> re1, im1, twre, twim;
- split(w, re1, im1);
- split(tw, twre, twim);
-
- const vec<T, width> b1re = re1 * twre;
- const vec<T, width> b1im = im1 * twre;
- if (inverse)
- return concat(b1re + im1 * twim, b1im - re1 * twim);
- else
- return concat(b1re - im1 * twim, b1im + re1 * twim);
-}
-
-template <size_t width, bool splitout, bool splitin, bool use_br2, bool inverse, bool aligned, typename T>
-KFR_SINTRIN void radix4_body(size_t N, csize_t<width>, ctrue_t, cbool_t<splitout>, cbool_t<splitin>,
- cbool_t<use_br2>, cbool_t<inverse>, cbool_t<aligned>, complex<T>* out,
- const complex<T>* in, const complex<T>* twiddle)
-{
- const size_t N4 = N / 4;
- cvec<T, width> w1, w2, w3;
- constexpr bool read_split = !splitin && splitout;
- constexpr bool write_split = splitin && !splitout;
-
- vec<T, width> re0, im0, re1, im1, re2, im2, re3, im3;
-
- split(cread_split<width, aligned, read_split>(in + N4 * 0), re0, im0);
- split(cread_split<width, aligned, read_split>(in + N4 * 1), re1, im1);
- split(cread_split<width, aligned, read_split>(in + N4 * 2), re2, im2);
- split(cread_split<width, aligned, read_split>(in + N4 * 3), re3, im3);
-
- const vec<T, width> sum02re = re0 + re2;
- const vec<T, width> sum02im = im0 + im2;
- const vec<T, width> sum13re = re1 + re3;
- const vec<T, width> sum13im = im1 + im3;
-
- cwrite_split<width, aligned, write_split>(out, concat(sum02re + sum13re, sum02im + sum13im));
- w2 = concat(sum02re - sum13re, sum02im - sum13im);
- cwrite_split<width, aligned, write_split>(
- out + N4 * (use_br2 ? 1 : 2), radix4_apply_twiddle(csize_t<width>(), ctrue, cbool_t<inverse>(), w2,
- cread<width, true>(twiddle + width)));
-
- const vec<T, width> diff02re = re0 - re2;
- const vec<T, width> diff02im = im0 - im2;
- const vec<T, width> diff13re = re1 - re3;
- const vec<T, width> diff13im = im1 - im3;
-
- (inverse ? w1 : w3) = concat(diff02re - diff13im, diff02im + diff13re);
- (inverse ? w3 : w1) = concat(diff02re + diff13im, diff02im - diff13re);
-
- cwrite_split<width, aligned, write_split>(
- out + N4 * (use_br2 ? 2 : 1), radix4_apply_twiddle(csize_t<width>(), ctrue, cbool_t<inverse>(), w1,
- cread<width, true>(twiddle + 0)));
- cwrite_split<width, aligned, write_split>(
- out + N4 * 3, radix4_apply_twiddle(csize_t<width>(), ctrue, cbool_t<inverse>(), w3,
- cread<width, true>(twiddle + width * 2)));
-}
-
-template <typename T>
-CMT_NOINLINE cvec<T, 1> calculate_twiddle(size_t n, size_t size)
-{
- if (n == 0)
- {
- return make_vector(static_cast<T>(1), static_cast<T>(0));
- }
- else if (n == size / 4)
- {
- return make_vector(static_cast<T>(0), static_cast<T>(-1));
- }
- else if (n == size / 2)
- {
- return make_vector(static_cast<T>(-1), static_cast<T>(0));
- }
- else if (n == size * 3 / 4)
- {
- return make_vector(static_cast<T>(0), static_cast<T>(1));
- }
- else
- {
- fbase kth = c_pi<fbase, 2> * (n / static_cast<fbase>(size));
- fbase tcos = +kfr::cos(kth);
- fbase tsin = -kfr::sin(kth);
- return make_vector(static_cast<T>(tcos), static_cast<T>(tsin));
- }
-}
-
-template <typename T, size_t width>
-KFR_SINTRIN void initialize_twiddles_impl(complex<T>*& twiddle, size_t nn, size_t nnstep, size_t size,
- bool split_format)
-{
- vec<T, 2 * width> result = T();
- CMT_LOOP_UNROLL
- for (size_t i = 0; i < width; i++)
- {
- const cvec<T, 1> r = calculate_twiddle<T>(nn + nnstep * i, size);
- result[i * 2] = r[0];
- result[i * 2 + 1] = r[1];
- }
- if (split_format)
- ref_cast<cvec<T, width>>(twiddle[0]) = splitpairs(result);
- else
- ref_cast<cvec<T, width>>(twiddle[0]) = result;
- twiddle += width;
-}
-
-template <typename T, size_t width>
-CMT_NOINLINE void initialize_twiddles(complex<T>*& twiddle, size_t stage_size, size_t size, bool split_format)
-{
- const size_t count = stage_size / 4;
- // DFT_ASSERT(width <= count);
- size_t nnstep = size / stage_size;
- CMT_LOOP_NOUNROLL
- for (size_t n = 0; n < count; n += width)
- {
- initialize_twiddles_impl<T, width>(twiddle, n * nnstep * 1, nnstep * 1, size, split_format);
- initialize_twiddles_impl<T, width>(twiddle, n * nnstep * 2, nnstep * 2, size, split_format);
- initialize_twiddles_impl<T, width>(twiddle, n * nnstep * 3, nnstep * 3, size, split_format);
- }
-}
-
-#ifdef CMT_ARCH_X86
-#ifdef CMT_COMPILER_GNU
-#define KFR_PREFETCH(addr) __builtin_prefetch(::kfr::ptr_cast<void>(addr), 0, _MM_HINT_T0);
-#else
-#define KFR_PREFETCH(addr) _mm_prefetch(::kfr::ptr_cast<char>(addr), _MM_HINT_T0);
-#endif
-#else
-#define KFR_PREFETCH(addr) __builtin_prefetch(::kfr::ptr_cast<void>(addr));
-#endif
-
-template <typename T>
-KFR_SINTRIN void prefetch_one(const complex<T>* in)
-{
- KFR_PREFETCH(in);
-}
-
-template <typename T>
-KFR_SINTRIN void prefetch_four(size_t stride, const complex<T>* in)
-{
- KFR_PREFETCH(in);
- KFR_PREFETCH(in + stride);
- KFR_PREFETCH(in + stride * 2);
- KFR_PREFETCH(in + stride * 3);
-}
-
-template <typename Ntype, size_t width, bool splitout, bool splitin, bool prefetch, bool use_br2,
- bool inverse, bool aligned, typename T>
-KFR_SINTRIN cfalse_t radix4_pass(Ntype N, size_t blocks, csize_t<width>, cbool_t<splitout>, cbool_t<splitin>,
- cbool_t<use_br2>, cbool_t<prefetch>, cbool_t<inverse>, cbool_t<aligned>,
- complex<T>* out, const complex<T>* in, const complex<T>*& twiddle)
-{
- constexpr static size_t prefetch_offset = width * 8;
- const auto N4 = N / csize_t<4>();
- const auto N43 = N4 * csize_t<3>();
- CMT_ASSUME(blocks > 0);
- CMT_ASSUME(N > 0);
- CMT_ASSUME(N4 > 0);
- DFT_ASSERT(width <= N4);
- CMT_LOOP_NOUNROLL for (size_t b = 0; b < blocks; b++)
- {
- CMT_PRAGMA_CLANG(clang loop unroll_count(2))
- for (size_t n2 = 0; n2 < N4; n2 += width)
- {
- if (prefetch)
- prefetch_four(N4, in + prefetch_offset);
- radix4_body(N, csize_t<width>(), cbool_t<(splitout || splitin)>(), cbool_t<splitout>(),
- cbool_t<splitin>(), cbool_t<use_br2>(), cbool_t<inverse>(), cbool_t<aligned>(), out,
- in, twiddle + n2 * 3);
- in += width;
- out += width;
- }
- in += N43;
- out += N43;
- }
- twiddle += N43;
- return {};
-}
-
-template <size_t width, bool prefetch, bool use_br2, bool inverse, bool aligned, typename T>
-KFR_SINTRIN ctrue_t radix4_pass(csize_t<32>, size_t blocks, csize_t<width>, cfalse_t, cfalse_t,
- cbool_t<use_br2>, cbool_t<prefetch>, cbool_t<inverse>, cbool_t<aligned>,
- complex<T>* out, const complex<T>*, const complex<T>*& /*twiddle*/)
-{
- CMT_ASSUME(blocks > 0);
- constexpr static size_t prefetch_offset = 32 * 4;
- for (size_t b = 0; b < blocks; b++)
- {
- if (prefetch)
- prefetch_four(csize_t<64>(), out + prefetch_offset);
- cvec<T, 4> w0, w1, w2, w3, w4, w5, w6, w7;
- split(cread<8, aligned>(out + 0), w0, w1);
- split(cread<8, aligned>(out + 8), w2, w3);
- split(cread<8, aligned>(out + 16), w4, w5);
- split(cread<8, aligned>(out + 24), w6, w7);
-
- butterfly8<4, inverse>(w0, w1, w2, w3, w4, w5, w6, w7);
-
- w1 = cmul(w1, fixed_twiddle<T, 4, 32, 0, 1, inverse>());
- w2 = cmul(w2, fixed_twiddle<T, 4, 32, 0, 2, inverse>());
- w3 = cmul(w3, fixed_twiddle<T, 4, 32, 0, 3, inverse>());
- w4 = cmul(w4, fixed_twiddle<T, 4, 32, 0, 4, inverse>());
- w5 = cmul(w5, fixed_twiddle<T, 4, 32, 0, 5, inverse>());
- w6 = cmul(w6, fixed_twiddle<T, 4, 32, 0, 6, inverse>());
- w7 = cmul(w7, fixed_twiddle<T, 4, 32, 0, 7, inverse>());
-
- cvec<T, 8> z0, z1, z2, z3;
- transpose4x8(w0, w1, w2, w3, w4, w5, w6, w7, z0, z1, z2, z3);
-
- butterfly4<8, inverse>(cfalse, z0, z1, z2, z3, z0, z1, z2, z3);
- cwrite<32, aligned>(out, bitreverse<2>(concat(z0, z1, z2, z3)));
- out += 32;
- }
- return {};
-}
-
-template <size_t width, bool prefetch, bool use_br2, bool inverse, bool aligned, typename T>
-KFR_SINTRIN ctrue_t radix4_pass(csize_t<8>, size_t blocks, csize_t<width>, cfalse_t, cfalse_t,
- cbool_t<use_br2>, cbool_t<prefetch>, cbool_t<inverse>, cbool_t<aligned>,
- complex<T>* out, const complex<T>*, const complex<T>*& /*twiddle*/)
-{
- CMT_ASSUME(blocks > 0);
- constexpr static size_t prefetch_offset = width * 16;
- DFT_ASSERT(2 <= blocks);
- for (size_t b = 0; b < blocks; b += 2)
- {
- if (prefetch)
- prefetch_one(out + prefetch_offset);
-
- cvec<T, 8> vlo = cread<8, aligned>(out + 0);
- cvec<T, 8> vhi = cread<8, aligned>(out + 8);
- butterfly8<inverse>(vlo);
- butterfly8<inverse>(vhi);
- vlo = permutegroups<(2), 0, 4, 2, 6, 1, 5, 3, 7>(vlo);
- vhi = permutegroups<(2), 0, 4, 2, 6, 1, 5, 3, 7>(vhi);
- cwrite<8, aligned>(out, vlo);
- cwrite<8, aligned>(out + 8, vhi);
- out += 16;
- }
- return {};
-}
-
-template <size_t width, bool prefetch, bool use_br2, bool inverse, bool aligned, typename T>
-KFR_SINTRIN ctrue_t radix4_pass(csize_t<16>, size_t blocks, csize_t<width>, cfalse_t, cfalse_t,
- cbool_t<use_br2>, cbool_t<prefetch>, cbool_t<inverse>, cbool_t<aligned>,
- complex<T>* out, const complex<T>*, const complex<T>*& /*twiddle*/)
-{
- CMT_ASSUME(blocks > 0);
- constexpr static size_t prefetch_offset = width * 4;
- DFT_ASSERT(2 <= blocks);
- CMT_PRAGMA_CLANG(clang loop unroll_count(2))
- for (size_t b = 0; b < blocks; b += 2)
- {
- if (prefetch)
- prefetch_one(out + prefetch_offset);
-
- cvec<T, 16> vlo = cread<16, aligned>(out);
- cvec<T, 16> vhi = cread<16, aligned>(out + 16);
- butterfly4<4, inverse>(vlo);
- butterfly4<4, inverse>(vhi);
- apply_twiddles4<0, 4, 4, inverse>(vlo);
- apply_twiddles4<0, 4, 4, inverse>(vhi);
- vlo = digitreverse4<2>(vlo);
- vhi = digitreverse4<2>(vhi);
- butterfly4<4, inverse>(vlo);
- butterfly4<4, inverse>(vhi);
-
- use_br2 ? cbitreverse_write(out, vlo) : cdigitreverse4_write(out, vlo);
- use_br2 ? cbitreverse_write(out + 16, vhi) : cdigitreverse4_write(out + 16, vhi);
- out += 32;
- }
- return {};
-}
-
-template <size_t width, bool prefetch, bool use_br2, bool inverse, bool aligned, typename T>
-KFR_SINTRIN ctrue_t radix4_pass(csize_t<4>, size_t blocks, csize_t<width>, cfalse_t, cfalse_t,
- cbool_t<use_br2>, cbool_t<prefetch>, cbool_t<inverse>, cbool_t<aligned>,
- complex<T>* out, const complex<T>*, const complex<T>*& /*twiddle*/)
-{
- constexpr static size_t prefetch_offset = width * 4;
- CMT_ASSUME(blocks > 0);
- DFT_ASSERT(4 <= blocks);
- CMT_LOOP_NOUNROLL
- for (size_t b = 0; b < blocks; b += 4)
- {
- if (prefetch)
- prefetch_one(out + prefetch_offset);
-
- cvec<T, 16> v16 = cdigitreverse4_read<16, aligned>(out);
- butterfly4<4, inverse>(v16);
- cdigitreverse4_write<aligned>(out, v16);
-
- out += 4 * 4;
- }
- return {};
-}
-
-template <typename T, bool splitin, bool is_even, bool inverse>
-struct fft_stage_impl : dft_stage<T>
-{
- fft_stage_impl(size_t stage_size)
- {
- this->stage_size = stage_size;
- this->repeats = 4;
- this->recursion = true;
- this->data_size =
- align_up(sizeof(complex<T>) * stage_size / 4 * 3, platform<>::native_cache_alignment);
- }
-
-protected:
- constexpr static bool prefetch = true;
- constexpr static bool aligned = false;
- constexpr static size_t width = fft_vector_width<T>;
-
- virtual void do_initialize(size_t size) override final
- {
- complex<T>* twiddle = ptr_cast<complex<T>>(this->data);
- initialize_twiddles<T, width>(twiddle, this->stage_size, size, true);
- }
-
- virtual void do_execute(complex<T>* out, const complex<T>* in, u8* /*temp*/) override final
- {
- const complex<T>* twiddle = ptr_cast<complex<T>>(this->data);
- if (splitin)
- in = out;
- const size_t stg_size = this->stage_size;
- CMT_ASSUME(stg_size >= 2048);
- CMT_ASSUME(stg_size % 2048 == 0);
- radix4_pass(stg_size, 1, csize_t<width>(), ctrue, cbool_t<splitin>(), cbool_t<!is_even>(),
- cbool_t<prefetch>(), cbool_t<inverse>(), cbool_t<aligned>(), out, in, twiddle);
- }
-};
-
-template <typename T, bool splitin, size_t size, bool inverse>
-struct fft_final_stage_impl : dft_stage<T>
-{
- fft_final_stage_impl(size_t)
- {
- this->stage_size = size;
- this->out_offset = size;
- this->repeats = 4;
- this->recursion = true;
- this->data_size = align_up(sizeof(complex<T>) * size * 3 / 2, platform<>::native_cache_alignment);
- }
-
-protected:
- constexpr static size_t width = fft_vector_width<T>;
- constexpr static bool is_even = cometa::is_even(ilog2(size));
- constexpr static bool use_br2 = !is_even;
- constexpr static bool aligned = false;
- constexpr static bool prefetch = splitin;
-
- KFR_INTRIN void init_twiddles(csize_t<8>, size_t, cfalse_t, complex<T>*&) {}
- KFR_INTRIN void init_twiddles(csize_t<4>, size_t, cfalse_t, complex<T>*&) {}
-
- template <size_t N, bool pass_splitin>
- KFR_INTRIN void init_twiddles(csize_t<N>, size_t total_size, cbool_t<pass_splitin>, complex<T>*& twiddle)
- {
- constexpr bool pass_split = N / 4 > 8 && N / 4 / 4 >= width;
- constexpr size_t pass_width = const_min(width, N / 4);
- initialize_twiddles<T, pass_width>(twiddle, N, total_size, pass_split || pass_splitin);
- init_twiddles(csize<N / 4>, total_size, cbool<pass_split>, twiddle);
- }
-
- virtual void do_initialize(size_t total_size) override final
- {
- complex<T>* twiddle = ptr_cast<complex<T>>(this->data);
- init_twiddles(csize<size>, total_size, cbool<splitin>, twiddle);
- }
-
- virtual void do_execute(complex<T>* out, const complex<T>* in, u8* /*temp*/) override
- {
- const complex<T>* twiddle = ptr_cast<complex<T>>(this->data);
- final_stage(csize<size>, 1, cbool<splitin>, out, in, twiddle);
- }
-
- // KFR_INTRIN void final_stage(csize_t<32>, size_t invN, cfalse_t, complex<T>* out, const complex<T>*,
- // const complex<T>*& twiddle)
- // {
- // radix4_pass(csize_t<32>(), invN, csize_t<width>(), cfalse, cfalse, cbool_t<use_br2>(),
- // cbool_t<prefetch>(), cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle);
- // }
- //
- // KFR_INTRIN void final_stage(csize_t<16>, size_t invN, cfalse_t, complex<T>* out, const complex<T>*,
- // const complex<T>*& twiddle)
- // {
- // radix4_pass(csize_t<16>(), invN, csize_t<width>(), cfalse, cfalse, cbool_t<use_br2>(),
- // cbool_t<prefetch>(), cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle);
- // }
-
- KFR_INTRIN void final_stage(csize_t<8>, size_t invN, cfalse_t, complex<T>* out, const complex<T>*,
- const complex<T>*& twiddle)
- {
- radix4_pass(csize_t<8>(), invN, csize_t<width>(), cfalse, cfalse, cbool_t<use_br2>(),
- cbool_t<prefetch>(), cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle);
- }
-
- KFR_INTRIN void final_stage(csize_t<4>, size_t invN, cfalse_t, complex<T>* out, const complex<T>*,
- const complex<T>*& twiddle)
- {
- radix4_pass(csize_t<4>(), invN, csize_t<width>(), cfalse, cfalse, cbool_t<use_br2>(),
- cbool_t<prefetch>(), cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle);
- }
-
- template <size_t N, bool pass_splitin>
- KFR_INTRIN void final_stage(csize_t<N>, size_t invN, cbool_t<pass_splitin>, complex<T>* out,
- const complex<T>* in, const complex<T>*& twiddle)
- {
- static_assert(N > 8, "");
- constexpr bool pass_split = N / 4 > 8 && N / 4 / 4 >= width;
- constexpr size_t pass_width = const_min(width, N / 4);
- static_assert(pass_width == width || (pass_split == pass_splitin), "");
- static_assert(pass_width <= N / 4, "");
- radix4_pass(N, invN, csize_t<pass_width>(), cbool<pass_split>, cbool_t<pass_splitin>(),
- cbool_t<use_br2>(), cbool_t<prefetch>(), cbool_t<inverse>(), cbool_t<aligned>(), out, in,
- twiddle);
- final_stage(csize<N / 4>, invN * 4, cbool<pass_split>, out, out, twiddle);
- }
-};
-
-template <typename T, bool is_even>
-struct fft_reorder_stage_impl : dft_stage<T>
-{
- fft_reorder_stage_impl(size_t stage_size)
- {
- this->stage_size = stage_size;
- log2n = ilog2(stage_size);
- this->data_size = 0;
- }
-
-protected:
- size_t log2n;
-
- virtual void do_initialize(size_t) override final {}
-
- virtual void do_execute(complex<T>* out, const complex<T>*, u8* /*temp*/) override final
- {
- fft_reorder(out, log2n, cbool_t<!is_even>());
- }
-};
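The reorder stage exists because the radix-4/radix-8 passes leave the result in digit-reversed order; fft_reorder permutes it back to natural order in place. For the pure radix-2 view this is the classic bit-reversal permutation; a minimal illustrative helper (not KFR code) showing the index mapping:

    #include <cstddef>

    // Reverse the low log2n bits of an index: element k of the naturally
    // ordered output comes from position bit_reverse(k, log2n) of the
    // digit-reversed intermediate (radix-2 view; purely illustrative).
    inline std::size_t bit_reverse(std::size_t x, std::size_t log2n)
    {
        std::size_t r = 0;
        for (std::size_t i = 0; i < log2n; ++i)
        {
            r = (r << 1) | (x & 1);
            x >>= 1;
        }
        return r;
    }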
-
-template <typename T, size_t log2n, bool inverse>
-struct fft_specialization;
-
-template <typename T, bool inverse>
-struct fft_specialization<T, 1, inverse> : dft_stage<T>
-{
- fft_specialization(size_t) {}
-
-protected:
- constexpr static bool aligned = false;
- virtual void do_execute(complex<T>* out, const complex<T>* in, u8*) override final
- {
- cvec<T, 1> a0, a1;
- split(cread<2, aligned>(in), a0, a1);
- cwrite<2, aligned>(out, concat(a0 + a1, a0 - a1));
- }
-};
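The log2n == 1 case above is the entire 2-point transform in a single butterfly: X[0] = x[0] + x[1] and X[1] = x[0] - x[1]. The only twiddle factor is +/-1, so the forward and inverse transforms coincide up to normalization, which is why this specialization ignores `inverse`. A standalone sketch of the same computation, written with std::complex purely for illustration:

    #include <array>
    #include <complex>

    // 2-point DFT by definition; matches the a0 + a1 / a0 - a1 butterfly above.
    template <typename T>
    std::array<std::complex<T>, 2> dft2(const std::array<std::complex<T>, 2>& x)
    {
        return { x[0] + x[1], x[0] - x[1] };
    }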
-
-template <typename T, bool inverse>
-struct fft_specialization<T, 2, inverse> : dft_stage<T>
-{
- fft_specialization(size_t) {}
-
-protected:
- constexpr static bool aligned = false;
- virtual void do_execute(complex<T>* out, const complex<T>* in, u8*) override final
- {
- cvec<T, 1> a0, a1, a2, a3;
- split(cread<4>(in), a0, a1, a2, a3);
- butterfly(cbool_t<inverse>(), a0, a1, a2, a3, a0, a1, a2, a3);
- cwrite<4>(out, concat(a0, a1, a2, a3));
- }
-};
-
-template <typename T, bool inverse>
-struct fft_specialization<T, 3, inverse> : dft_stage<T>
-{
- fft_specialization(size_t) {}
-
-protected:
- constexpr static bool aligned = false;
- virtual void do_execute(complex<T>* out, const complex<T>* in, u8*) override final
- {
- cvec<T, 8> v8 = cread<8, aligned>(in);
- butterfly8<inverse>(v8);
- cwrite<8, aligned>(out, v8);
- }
-};
-
-template <typename T, bool inverse>
-struct fft_specialization<T, 4, inverse> : dft_stage<T>
-{
- fft_specialization(size_t) {}
-
-protected:
- constexpr static bool aligned = false;
- virtual void do_execute(complex<T>* out, const complex<T>* in, u8*) override final
- {
- cvec<T, 16> v16 = cread<16, aligned>(in);
- butterfly16<inverse>(v16);
- cwrite<16, aligned>(out, v16);
- }
-};
-
-template <typename T, bool inverse>
-struct fft_specialization<T, 5, inverse> : dft_stage<T>
-{
- fft_specialization(size_t) {}
-
-protected:
- constexpr static bool aligned = false;
- virtual void do_execute(complex<T>* out, const complex<T>* in, u8*) override final
- {
- cvec<T, 32> v32 = cread<32, aligned>(in);
- butterfly32<inverse>(v32);
- cwrite<32, aligned>(out, v32);
- }
-};
-
-template <typename T, bool inverse>
-struct fft_specialization<T, 6, inverse> : dft_stage<T>
-{
- fft_specialization(size_t) {}
-
-protected:
- constexpr static bool aligned = false;
- virtual void do_execute(complex<T>* out, const complex<T>* in, u8*) override final
- {
- butterfly64(cbool_t<inverse>(), cbool_t<aligned>(), out, in);
- }
-};
-
-template <typename T, bool inverse>
-struct fft_specialization<T, 7, inverse> : dft_stage<T>
-{
- fft_specialization(size_t)
- {
- this->stage_size = 128;
- this->data_size = align_up(sizeof(complex<T>) * 128 * 3 / 2, platform<>::native_cache_alignment);
- }
-
-protected:
- constexpr static bool aligned = false;
- constexpr static size_t width = fft_vector_width<T>;
- constexpr static bool use_br2 = true;
- constexpr static bool prefetch = false;
- constexpr static bool is_double = sizeof(T) == 8;
- constexpr static size_t final_size = is_double ? 8 : 32;
- constexpr static size_t split_format = final_size == 8;
-
- virtual void do_initialize(size_t total_size) override final
- {
- complex<T>* twiddle = ptr_cast<complex<T>>(this->data);
- initialize_twiddles<T, width>(twiddle, 128, total_size, split_format);
- initialize_twiddles<T, width>(twiddle, 32, total_size, split_format);
- initialize_twiddles<T, width>(twiddle, 8, total_size, split_format);
- }
-
- virtual void do_execute(complex<T>* out, const complex<T>* in, u8* /*temp*/) override final
- {
- const complex<T>* twiddle = ptr_cast<complex<T>>(this->data);
- final_pass(csize_t<final_size>(), out, in, twiddle);
- fft_reorder(out, csize_t<7>());
- }
-
- KFR_INTRIN void final_pass(csize_t<8>, complex<T>* out, const complex<T>* in, const complex<T>* twiddle)
- {
- radix4_pass(128, 1, csize_t<width>(), ctrue, cfalse, cbool_t<use_br2>(), cbool_t<prefetch>(),
- cbool_t<inverse>(), cbool_t<aligned>(), out, in, twiddle);
- radix4_pass(32, 4, csize_t<width>(), cfalse, ctrue, cbool_t<use_br2>(), cbool_t<prefetch>(),
- cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle);
- radix4_pass(csize_t<8>(), 16, csize_t<width>(), cfalse, cfalse, cbool_t<use_br2>(),
- cbool_t<prefetch>(), cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle);
- }
-
- KFR_INTRIN void final_pass(csize_t<32>, complex<T>* out, const complex<T>* in, const complex<T>* twiddle)
- {
- radix4_pass(128, 1, csize_t<width>(), cfalse, cfalse, cbool_t<use_br2>(), cbool_t<prefetch>(),
- cbool_t<inverse>(), cbool_t<aligned>(), out, in, twiddle);
- radix4_pass(csize_t<32>(), 4, csize_t<width>(), cfalse, cfalse, cbool_t<use_br2>(),
- cbool_t<prefetch>(), cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle);
- }
-};
-
-template <bool inverse>
-struct fft_specialization<float, 8, inverse> : dft_stage<float>
-{
- fft_specialization(size_t) { this->temp_size = sizeof(complex<float>) * 256; }
-
-protected:
- virtual void do_execute(complex<float>* out, const complex<float>* in, u8* temp) override final
- {
- complex<float>* scratch = ptr_cast<complex<float>>(temp);
- if (out == in)
- {
- butterfly16_multi_flip<0, inverse>(scratch, out);
- butterfly16_multi_flip<1, inverse>(scratch, out);
- butterfly16_multi_flip<2, inverse>(scratch, out);
- butterfly16_multi_flip<3, inverse>(scratch, out);
-
- butterfly16_multi_natural<0, inverse>(out, scratch);
- butterfly16_multi_natural<1, inverse>(out, scratch);
- butterfly16_multi_natural<2, inverse>(out, scratch);
- butterfly16_multi_natural<3, inverse>(out, scratch);
- }
- else
- {
- butterfly16_multi_flip<0, inverse>(out, in);
- butterfly16_multi_flip<1, inverse>(out, in);
- butterfly16_multi_flip<2, inverse>(out, in);
- butterfly16_multi_flip<3, inverse>(out, in);
-
- butterfly16_multi_natural<0, inverse>(out, out);
- butterfly16_multi_natural<1, inverse>(out, out);
- butterfly16_multi_natural<2, inverse>(out, out);
- butterfly16_multi_natural<3, inverse>(out, out);
- }
- }
-};
-
-template <bool inverse>
-struct fft_specialization<double, 8, inverse> : fft_final_stage_impl<double, false, 256, inverse>
-{
- using T = double;
- using fft_final_stage_impl<double, false, 256, inverse>::fft_final_stage_impl;
-
- virtual void do_execute(complex<T>* out, const complex<T>* in, u8* /*temp*/) override final
- {
- fft_final_stage_impl<double, false, 256, inverse>::do_execute(out, in, nullptr);
- fft_reorder(out, csize_t<8>());
- }
-};
-
-template <typename T, bool splitin, bool is_even>
-struct fft_stage_impl_t
-{
- template <bool inverse>
- using type = internal::fft_stage_impl<T, splitin, is_even, inverse>;
-};
-template <typename T, bool splitin, size_t size>
-struct fft_final_stage_impl_t
-{
- template <bool inverse>
- using type = internal::fft_final_stage_impl<T, splitin, size, inverse>;
-};
-template <typename T, bool is_even>
-struct fft_reorder_stage_impl_t
-{
- template <bool>
- using type = internal::fft_reorder_stage_impl<T, is_even>;
-};
-template <typename T, size_t log2n, bool aligned>
-struct fft_specialization_t
-{
- template <bool inverse>
- using type = internal::fft_specialization<T, log2n, inverse>;
-};
-} // namespace internal
-
namespace dft_type
{
constexpr cbools_t<true, true> both{};
@@ -813,6 +54,9 @@ constexpr cbools_t<false, true> inverse{};
} // namespace dft_type
template <typename T>
+struct dft_stage;
+
+template <typename T>
struct dft_plan
{
using dft_stage_ptr = std::unique_ptr<dft_stage<T>>;
@@ -821,28 +65,8 @@ struct dft_plan
size_t temp_size;
template <bool direct = true, bool inverse = true>
- dft_plan(size_t size, cbools_t<direct, inverse> type = dft_type::both)
- : size(size), temp_size(0), data_size(0)
- {
- if (is_poweroftwo(size))
- {
- const size_t log2n = ilog2(size);
- cswitch(csizes_t<1, 2, 3, 4, 5, 6, 7, 8>(), log2n,
- [&](auto log2n) {
- (void)log2n;
- this->add_stage<internal::fft_specialization_t<T, val_of(decltype(log2n)()),
- false>::template type>(size, type);
- },
- [&]() {
- cswitch(cfalse_true, is_even(log2n), [&](auto is_even) {
- this->make_fft(size, type, is_even, ctrue);
- this->add_stage<internal::fft_reorder_stage_impl_t<
- T, val_of(decltype(is_even)())>::template type>(size, type);
- });
- });
- initialize(type);
- }
- }
+ dft_plan(size_t size, cbools_t<direct, inverse> type = dft_type::both);
+
KFR_INTRIN void execute(complex<T>* out, const complex<T>* in, u8* temp, bool inverse = false) const
{
if (inverse)
@@ -850,6 +74,7 @@ struct dft_plan
else
execute_dft(cfalse, out, in, temp);
}
+ ~dft_plan();
template <bool inverse>
KFR_INTRIN void execute(complex<T>* out, const complex<T>* in, u8* temp, cbool_t<inverse> inv) const
{
@@ -878,126 +103,19 @@ protected:
std::vector<dft_stage_ptr> stages[2];
template <template <bool inverse> class Stage>
- void add_stage(size_t stage_size, cbools_t<true, true>)
- {
- dft_stage<T>* direct_stage = new Stage<false>(stage_size);
- direct_stage->name = nullptr;
- dft_stage<T>* inverse_stage = new Stage<true>(stage_size);
- inverse_stage->name = nullptr;
- this->data_size += direct_stage->data_size;
- this->temp_size += direct_stage->temp_size;
- stages[0].push_back(dft_stage_ptr(direct_stage));
- stages[1].push_back(dft_stage_ptr(inverse_stage));
- }
+ void add_stage(size_t stage_size, cbools_t<true, true>);
template <template <bool inverse> class Stage>
- void add_stage(size_t stage_size, cbools_t<true, false>)
- {
- dft_stage<T>* direct_stage = new Stage<false>(stage_size);
- direct_stage->name = nullptr;
- this->data_size += direct_stage->data_size;
- this->temp_size += direct_stage->temp_size;
- stages[0].push_back(dft_stage_ptr(direct_stage));
- }
+ void add_stage(size_t stage_size, cbools_t<true, false>);
template <template <bool inverse> class Stage>
- void add_stage(size_t stage_size, cbools_t<false, true>)
- {
- dft_stage<T>* inverse_stage = new Stage<true>(stage_size);
- inverse_stage->name = nullptr;
- this->data_size += inverse_stage->data_size;
- this->temp_size += inverse_stage->temp_size;
- stages[1].push_back(dft_stage_ptr(inverse_stage));
- }
+ void add_stage(size_t stage_size, cbools_t<false, true>);
template <bool direct, bool inverse, bool is_even, bool first>
- void make_fft(size_t stage_size, cbools_t<direct, inverse> type, cbool_t<is_even>, cbool_t<first>)
- {
- constexpr size_t final_size = is_even ? 1024 : 512;
-
- using fft_stage_impl_t = internal::fft_stage_impl_t<T, !first, is_even>;
- using fft_final_stage_impl_t = internal::fft_final_stage_impl_t<T, !first, final_size>;
-
- if (stage_size >= 2048)
- {
- add_stage<fft_stage_impl_t::template type>(stage_size, type);
-
- make_fft(stage_size / 4, cbools_t<direct, inverse>(), cbool_t<is_even>(), cfalse);
- }
- else
- {
- add_stage<fft_final_stage_impl_t::template type>(final_size, type);
- }
- }
+ void make_fft(size_t stage_size, cbools_t<direct, inverse> type, cbool_t<is_even>, cbool_t<first>);
template <bool direct, bool inverse>
- void initialize(cbools_t<direct, inverse>)
- {
- data = autofree<u8>(data_size);
- if (direct)
- {
- size_t offset = 0;
- for (dft_stage_ptr& stage : stages[0])
- {
- stage->data = data.data() + offset;
- stage->initialize(this->size);
- offset += stage->data_size;
- }
- }
- if (inverse)
- {
- size_t offset = 0;
- for (dft_stage_ptr& stage : stages[1])
- {
- stage->data = data.data() + offset;
- if (!direct)
- stage->initialize(this->size);
- offset += stage->data_size;
- }
- }
- }
+ void initialize(cbools_t<direct, inverse>);
template <bool inverse>
- KFR_INTRIN void execute_dft(cbool_t<inverse>, complex<T>* out, const complex<T>* in, u8* temp) const
- {
- size_t stack[32] = { 0 };
-
- const size_t count = stages[inverse].size();
-
- for (size_t depth = 0; depth < count;)
- {
- if (stages[inverse][depth]->recursion)
- {
- complex<T>* rout = out;
- const complex<T>* rin = in;
- size_t rdepth = depth;
- size_t maxdepth = depth;
- do
- {
- if (stack[rdepth] == stages[inverse][rdepth]->repeats)
- {
- stack[rdepth] = 0;
- rdepth--;
- }
- else
- {
- stages[inverse][rdepth]->execute(rout, rin, temp);
- rout += stages[inverse][rdepth]->out_offset;
- rin = rout;
- stack[rdepth]++;
- if (rdepth < count - 1 && stages[inverse][rdepth + 1]->recursion)
- rdepth++;
- else
- maxdepth = rdepth;
- }
- } while (rdepth != depth);
- depth = maxdepth + 1;
- }
- else
- {
- stages[inverse][depth]->execute(out, in, temp);
- depth++;
- }
- in = out;
- }
- }
+ void execute_dft(cbool_t<inverse>, complex<T>* out, const complex<T>* in, u8* temp) const;
};
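With the stage implementations moved into dft-src.cpp, dft_plan keeps the same public surface: construct it with the transform size, provide a scratch buffer of temp_size bytes, and call execute. A minimal usage sketch (buffer names are illustrative; it assumes the KFR DFT header is included and dft-src.cpp is compiled into the target, as the CMakeLists change below arranges for the tests):

    // 1024-point complex FFT, forward then inverse (sketch, not a full program).
    kfr::dft_plan<float> plan(1024);                        // both directions by default
    kfr::univector<kfr::complex<float>> in(1024), out(1024);
    kfr::univector<kfr::u8> temp(plan.temp_size);           // scratch required by the plan

    plan.execute(out.data(), in.data(), temp.data());       // direct transform
    plan.execute(in.data(), out.data(), temp.data(), true); // inverse transform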
enum class dft_pack_format
@@ -1011,20 +129,7 @@ struct dft_plan_real : dft_plan<T>
{
size_t size;
template <bool direct = true, bool inverse = true>
- dft_plan_real(size_t size, cbools_t<direct, inverse> type = dft_type::both)
- : dft_plan<T>(size / 2, type), size(size), rtwiddle(size / 4)
- {
- using namespace internal;
-
- constexpr size_t width = fft_vector_width<T> * 2;
-
- block_process(size / 4, csizes_t<width, 1>(), [=](size_t i, auto w) {
- constexpr size_t width = val_of(decltype(w)());
- cwrite<width>(
- rtwiddle.data() + i,
- cossin(dup(-constants<T>::pi * ((enumerate<T, width>() + i + size / 4) / (size / 2)))));
- });
- }
+ dft_plan_real(size_t size, cbools_t<direct, inverse> type = dft_type::both);
KFR_INTRIN void execute(complex<T>* out, const T* in, u8* temp,
dft_pack_format fmt = dft_pack_format::CCs) const
@@ -1059,89 +164,8 @@ struct dft_plan_real : dft_plan<T>
private:
univector<complex<T>> rtwiddle;
- void to_fmt(complex<T>* out, dft_pack_format fmt) const
- {
- using namespace internal;
- size_t csize =
- this->size / 2; // const size_t causes internal compiler error: in tsubst_copy in GCC 5.2
-
- constexpr size_t width = fft_vector_width<T> * 2;
- const cvec<T, 1> dc = cread<1>(out);
- const size_t count = csize / 2;
-
- block_process(count, csizes_t<width, 1>(), [=](size_t i, auto w) {
- constexpr size_t width = val_of(decltype(w)());
- constexpr size_t widthm1 = width - 1;
- const cvec<T, width> tw = cread<width>(rtwiddle.data() + i);
- const cvec<T, width> fpk = cread<width>(out + i);
- const cvec<T, width> fpnk = reverse<2>(negodd(cread<width>(out + csize - i - widthm1)));
-
- const cvec<T, width> f1k = fpk + fpnk;
- const cvec<T, width> f2k = fpk - fpnk;
- const cvec<T, width> t = cmul(f2k, tw);
- cwrite<width>(out + i, T(0.5) * (f1k + t));
- cwrite<width>(out + csize - i - widthm1, reverse<2>(negodd(T(0.5) * (f1k - t))));
- });
-
- {
- size_t k = csize / 2;
- const cvec<T, 1> fpk = cread<1>(out + k);
- const cvec<T, 1> fpnk = negodd(fpk);
- cwrite<1>(out + k, fpnk);
- }
- if (fmt == dft_pack_format::CCs)
- {
- cwrite<1>(out, pack(dc[0] + dc[1], 0));
- cwrite<1>(out + csize, pack(dc[0] - dc[1], 0));
- }
- else
- {
- cwrite<1>(out, pack(dc[0] + dc[1], dc[0] - dc[1]));
- }
- }
- void from_fmt(complex<T>* out, const complex<T>* in, dft_pack_format fmt) const
- {
- using namespace internal;
-
- const size_t csize = this->size / 2;
-
- cvec<T, 1> dc;
-
- if (fmt == dft_pack_format::CCs)
- {
- dc = pack(in[0].real() + in[csize].real(), in[0].real() - in[csize].real());
- }
- else
- {
- dc = pack(in[0].real() + in[0].imag(), in[0].real() - in[0].imag());
- }
-
- constexpr size_t width = fft_vector_width<T> * 2;
- const size_t count = csize / 2;
-
- block_process(count, csizes_t<width, 1>(), [=](size_t i, auto w) {
- i++;
- constexpr size_t width = val_of(decltype(w)());
- constexpr size_t widthm1 = width - 1;
- const cvec<T, width> tw = cread<width>(rtwiddle.data() + i);
- const cvec<T, width> fpk = cread<width>(in + i);
- const cvec<T, width> fpnk = reverse<2>(negodd(cread<width>(in + csize - i - widthm1)));
-
- const cvec<T, width> f1k = fpk + fpnk;
- const cvec<T, width> f2k = fpk - fpnk;
- const cvec<T, width> t = cmul_conj(f2k, tw);
- cwrite<width>(out + i, f1k + t);
- cwrite<width>(out + csize - i - widthm1, reverse<2>(negodd(f1k - t)));
- });
-
- {
- size_t k = csize / 2;
- const cvec<T, 1> fpk = cread<1>(in + k);
- const cvec<T, 1> fpnk = 2 * negodd(fpk);
- cwrite<1>(out + k, fpnk);
- }
- cwrite<1>(out, dc);
- }
+ void to_fmt(complex<T>* out, dft_pack_format fmt) const;
+ void from_fmt(complex<T>* out, const complex<T>* in, dft_pack_format fmt) const;
};
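dft_plan_real wraps a half-size complex plan plus the to_fmt/from_fmt post-processing above. With the default CCs packing the forward transform fills bins 0 through size/2 inclusive (to_fmt writes bin 0 and bin size/2 separately), so the output needs size/2 + 1 complex elements. A minimal sketch along the same lines as the complex case:

    // Real-to-complex FFT of 512 samples in CCs packing (sketch).
    kfr::dft_plan_real<float> plan(512);
    kfr::univector<float> in(512);
    kfr::univector<kfr::complex<float>> out(512 / 2 + 1);   // bins 0..N/2
    kfr::univector<kfr::u8> temp(plan.temp_size);

    plan.execute(out.data(), in.data(), temp.data(), kfr::dft_pack_format::CCs);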
template <typename T, size_t Tag1, size_t Tag2, size_t Tag3>
diff --git a/include/kfr/testo/assert.hpp b/include/kfr/testo/assert.hpp
@@ -58,15 +58,19 @@ bool check_assertion(const half_comparison<L>& comparison, const char* expr, con
TESTO_BREAKPOINT; \
} while (0)
+#define TESTO_ASSERT_INACTIVE(...) \
+ do \
+ { \
+ } while (false && (__VA_ARGS__))
+
#if defined(TESTO_ASSERTION_ON) || !(defined(NDEBUG) || defined(TESTO_ASSERTION_OFF))
#define TESTO_ASSERT TESTO_ASSERT_ACTIVE
#else
-#define TESTO_ASSERT(...) \
- do \
- { \
- } while (false && (__VA_ARGS__))
+
+#define TESTO_ASSERT TESTO_ASSERT_INACTIVE
+
#endif
#ifndef TESTO_NO_SHORT_MACROS
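The assertion change factors the disabled branch into a named TESTO_ASSERT_INACTIVE macro that other code can reuse directly. The asserted expression is still parsed and type-checked but never evaluated, because the do-while body is empty and `false &&` short-circuits. A small illustration of the effect when assertions are disabled (expensive_check is a hypothetical helper; assumes #include <kfr/testo/assert.hpp>):

    // When assertions are off, TESTO_ASSERT(expr) now expands to roughly
    //   do { } while (false && (expr));
    // so expr must still compile, but it is never executed.
    inline bool expensive_check(int x) { return x != 0; } // hypothetical
    inline void use(int x)
    {
        TESTO_ASSERT(x > 0);              // no-op at run time
        TESTO_ASSERT(expensive_check(x)); // type-checked, never called
    }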
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
@@ -39,7 +39,7 @@ endif ()
find_package(MPFR)
add_executable(intrinsic_test intrinsic_test.cpp ${KFR_SRC} ${TEST_SRC})
-add_executable(dft_test dft_test.cpp ${KFR_SRC} ${TEST_SRC})
+add_executable(dft_test dft_test.cpp ${KFR_SRC} ${TEST_SRC} ../include/kfr/dft/dft-src.cpp)
if (MPFR_FOUND)
include_directories(${MPFR_INCLUDE_DIR})
add_executable(transcendental_test transcendental_test.cpp ${KFR_SRC} ${TEST_SRC})
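As with dft_test above, any target that uses the DFT must now compile ../include/kfr/dft/dft-src.cpp in addition to including the headers, since the stage implementations are no longer header-only. If editing the build files is inconvenient, one workable (if less conventional) alternative is to pull the translation unit in from a source file of your own; a hedged sketch, assuming the KFR include/ directory is on the include path:

    // kfr_dft_impl.cpp -- hypothetical file in a user project; compile it exactly once.
    // Quoted includes inside dft-src.cpp resolve relative to its own directory,
    // so including the .cpp here is equivalent to listing it as a target source,
    // which is what the CMakeLists changes in this commit do.
    #include <kfr/dft/dft-src.cpp>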