commit 55e8c107122016518884e07ebe6b21704b27cb31
parent 07e4dc5314a8c1dcaf82af8eeca978599788b790
Author: d.levin256@gmail.com <d.levin256@gmail.com>
Date: Tue, 13 Oct 2020 12:35:30 +0100
DFT: calculate width from vector capacity
Diffstat:
7 files changed, 121 insertions(+), 66 deletions(-)
diff --git a/include/kfr/dft/fft.hpp b/include/kfr/dft/fft.hpp
@@ -291,8 +291,13 @@ protected:
}
else
{
- stages[depth]->execute(cbool<inverse>, select_out(depth, out, scratch),
- select_in(depth, out, in, scratch, in_scratch), temp);
+ size_t offset = 0;
+ while (offset < this->size)
+ {
+ stages[depth]->execute(cbool<inverse>, select_out(depth, out, scratch) + offset,
+ select_in(depth, out, in, scratch, in_scratch) + offset, temp);
+ offset += stages[depth]->stage_size;
+ }
depth++;
}
}
diff --git a/include/kfr/dft/impl/bitrev.hpp b/include/kfr/dft/impl/bitrev.hpp
@@ -42,12 +42,13 @@ inline namespace CMT_ARCH_NAME
namespace intrinsics
{
-constexpr bool fft_reorder_aligned = false;
+constexpr inline static bool fft_reorder_aligned = false;
+
+constexpr inline static size_t bitrev_table_log2N = ilog2(arraysize(data::bitrev_table));
template <size_t Bits>
CMT_GNU_CONSTEXPR inline u32 bitrev_using_table(u32 x)
{
- constexpr size_t bitrev_table_log2N = ilog2(arraysize(data::bitrev_table));
if (Bits > bitrev_table_log2N)
return bitreverse<Bits>(x);
@@ -56,7 +57,6 @@ CMT_GNU_CONSTEXPR inline u32 bitrev_using_table(u32 x)
CMT_GNU_CONSTEXPR inline u32 bitrev_using_table(u32 x, size_t bits)
{
- constexpr size_t bitrev_table_log2N = ilog2(arraysize(data::bitrev_table));
if (bits > bitrev_table_log2N)
return bitreverse<32>(x) >> (32 - bits);
@@ -65,7 +65,6 @@ CMT_GNU_CONSTEXPR inline u32 bitrev_using_table(u32 x, size_t bits)
CMT_GNU_CONSTEXPR inline u32 dig4rev_using_table(u32 x, size_t bits)
{
- constexpr size_t bitrev_table_log2N = ilog2(arraysize(data::bitrev_table));
if (bits > bitrev_table_log2N)
return digitreverse4<32>(x) >> (32 - bits);
diff --git a/include/kfr/dft/impl/dft-impl.hpp b/include/kfr/dft/impl/dft-impl.hpp
@@ -82,12 +82,13 @@ struct dft_stage_fixed_impl : dft_stage<T>
{
dft_stage_fixed_impl(size_t, size_t iterations, size_t blocks)
{
- this->name = type_name<decltype(*this)>();
- this->radix = fixed_radix;
- this->blocks = blocks;
- this->repeats = iterations;
- this->recursion = false; // true;
- this->data_size = align_up((this->repeats * (fixed_radix - 1)) * sizeof(complex<T>),
+ this->name = type_name<decltype(*this)>();
+ this->radix = fixed_radix;
+ this->blocks = blocks;
+ this->repeats = iterations;
+ this->recursion = false; // true;
+ this->stage_size = fixed_radix * iterations * blocks;
+ this->data_size = align_up((this->repeats * (fixed_radix - 1)) * sizeof(complex<T>),
platform<>::native_cache_alignment);
}
@@ -125,6 +126,7 @@ struct dft_stage_fixed_final_impl : dft_stage<T>
this->radix = fixed_radix;
this->blocks = blocks;
this->repeats = iterations;
+ this->stage_size = fixed_radix * iterations * blocks;
this->recursion = false;
this->can_inplace = false;
}
@@ -204,6 +206,7 @@ struct dft_arblen_stage_impl : dft_stage<T>
this->recursion = false;
this->can_inplace = false;
this->temp_size = plan.temp_size;
+ this->stage_size = size;
chirp_ = render(cexp(sqr(linspace(T(1) - size, size - T(1), size * 2 - 1, true, true)) *
complex<T>(0, -1) * c_pi<T> / size));
@@ -259,6 +262,7 @@ struct dft_special_stage_impl : dft_stage<T>
this->repeats = 1;
this->recursion = false;
this->can_inplace = false;
+ this->stage_size = size;
this->temp_size = stage1.temp_size + stage2.temp_size + sizeof(complex<T>) * size;
this->data_size = stage1.data_size + stage2.data_size;
}
@@ -300,6 +304,7 @@ struct dft_stage_generic_impl : dft_stage<T>
this->repeats = iterations;
this->recursion = false; // true;
this->can_inplace = false;
+ this->stage_size = radix * iterations * blocks;
this->temp_size = align_up(sizeof(complex<T>) * radix, platform<>::native_cache_alignment);
this->data_size =
align_up(sizeof(complex<T>) * sqr(this->radix / 2), platform<>::native_cache_alignment);
@@ -411,6 +416,7 @@ struct dft_reorder_stage_impl : dft_stage<T>
this->inner_size *= radices[r];
this->size *= radices[r];
}
+ this->stage_size = this->size;
}
protected:
diff --git a/include/kfr/dft/impl/fft-impl.hpp b/include/kfr/dft/impl/fft-impl.hpp
@@ -55,7 +55,7 @@ KFR_INTRINSIC cvec<T, width> radix4_apply_twiddle(csize_t<width>, cfalse_t /*spl
cvec<T, width> b1 = ww * dupeven(tw_);
ww = swap<2>(ww);
- if (inverse)
+ if constexpr (inverse)
tw_ = -(tw_);
ww = subadd(b1, ww * dupodd(tw_));
return ww;
@@ -87,7 +87,7 @@ KFR_INTRINSIC void radix4_body(size_t N, csize_t<width>, cfalse_t, cfalse_t, cfa
cread<width, true>(twiddle + width)));
diff02 = a0 - a2;
diff13 = a1 - a3;
- if (inverse)
+ if constexpr (inverse)
{
diff13 = (diff13 ^ broadcast<width * 2, T>(T(), -T()));
diff13 = swap<2>(diff13);
@@ -118,7 +118,7 @@ KFR_INTRINSIC cvec<T, width> radix4_apply_twiddle(csize_t<width>, ctrue_t /*spli
const vec<T, width> b1re = re1 * twre;
const vec<T, width> b1im = im1 * twre;
- if (inverse)
+ if constexpr (inverse)
return concat(b1re + im1 * twim, b1im - re1 * twim);
else
return concat(b1re - im1 * twim, b1im + re1 * twim);
@@ -200,6 +200,7 @@ template <typename T, size_t width>
KFR_INTRINSIC void initialize_twiddles_impl(complex<T>*& twiddle, size_t nn, size_t nnstep, size_t size,
bool split_format)
{
+ static_assert(width > 0, "width cannot be zero");
vec<T, 2 * width> result = T();
CMT_LOOP_UNROLL
for (size_t i = 0; i < width; i++)
@@ -218,6 +219,7 @@ KFR_INTRINSIC void initialize_twiddles_impl(complex<T>*& twiddle, size_t nn, siz
template <typename T, size_t width>
CMT_NOINLINE void initialize_twiddles(complex<T>*& twiddle, size_t stage_size, size_t size, bool split_format)
{
+ static_assert(width > 0, "width cannot be zero");
const size_t count = stage_size / 4;
size_t nnstep = size / stage_size;
DFT_ASSERT(width <= count);
@@ -271,6 +273,7 @@ KFR_INTRINSIC cfalse_t radix4_pass(Ntype N, size_t blocks, csize_t<width>, cbool
cbool_t<aligned>, complex<T>* out, const complex<T>* in,
const complex<T>*& twiddle)
{
+ static_assert(width > 0, "width cannot be zero");
constexpr static size_t prefetch_offset = width * 8;
const auto N4 = N / csize_t<4>();
const auto N43 = N4 * csize_t<3>();
@@ -283,7 +286,7 @@ KFR_INTRINSIC cfalse_t radix4_pass(Ntype N, size_t blocks, csize_t<width>, cbool
CMT_PRAGMA_CLANG(clang loop unroll_count(2))
for (size_t n2 = 0; n2 < N4; n2 += width)
{
- if (prefetch)
+ if constexpr (prefetch)
prefetch_four(N4, in + prefetch_offset);
radix4_body(N, csize_t<width>(), cbool_t<(splitout || splitin)>(), cbool_t<splitout>(),
cbool_t<splitin>(), cbool_t<use_br2>(), cbool_t<inverse>(), cbool_t<aligned>(), out,
@@ -307,7 +310,7 @@ KFR_INTRINSIC ctrue_t radix4_pass(csize_t<32>, size_t blocks, csize_t<width>, cf
constexpr static size_t prefetch_offset = 32 * 4;
for (size_t b = 0; b < blocks; b++)
{
- if (prefetch)
+ if constexpr (prefetch)
prefetch_four(csize_t<64>(), out + prefetch_offset);
cvec<T, 4> w0, w1, w2, w3, w4, w5, w6, w7;
split(cread_split<8, aligned, splitin>(out + 0), w0, w1);
@@ -345,7 +348,7 @@ KFR_INTRINSIC ctrue_t radix4_pass(csize_t<8>, size_t blocks, csize_t<width>, cfa
constexpr static size_t prefetch_offset = width * 16;
for (size_t b = 0; b < blocks; b += 2)
{
- if (prefetch)
+ if constexpr (prefetch)
prefetch_one(out + prefetch_offset);
cvec<T, 8> vlo = cread<8, aligned>(out + 0);
@@ -372,7 +375,7 @@ KFR_INTRINSIC ctrue_t radix4_pass(csize_t<16>, size_t blocks, csize_t<width>, cf
CMT_PRAGMA_CLANG(clang loop unroll_count(2))
for (size_t b = 0; b < blocks; b += 2)
{
- if (prefetch)
+ if constexpr (prefetch)
prefetch_one(out + prefetch_offset);
cvec<T, 16> vlo = cread<16, aligned>(out);
@@ -399,12 +402,11 @@ KFR_INTRINSIC ctrue_t radix4_pass(csize_t<4>, size_t blocks, csize_t<width>, cfa
complex<T>* out, const complex<T>*, const complex<T>*& /*twiddle*/)
{
constexpr static size_t prefetch_offset = width * 4;
- CMT_ASSUME(blocks > 0);
- DFT_ASSERT(4 <= blocks);
- CMT_LOOP_NOUNROLL
+ CMT_ASSUME(blocks > 8);
+ DFT_ASSERT(8 <= blocks);
for (size_t b = 0; b < blocks; b += 4)
{
- if (prefetch)
+ if constexpr (prefetch)
prefetch_one(out + prefetch_offset);
cvec<T, 16> v16 = cdigitreverse4_read<16, aligned>(out);
@@ -416,6 +418,15 @@ KFR_INTRINSIC ctrue_t radix4_pass(csize_t<4>, size_t blocks, csize_t<width>, cfa
return {};
}
+template <typename T>
+struct fft_config
+{
+ constexpr static inline const bool recursion = true;
+ constexpr static inline const bool prefetch = true;
+ constexpr static inline const size_t process_width =
+ const_max(static_cast<size_t>(1), vector_capacity<T> / 16);
+};
+
template <typename T, bool splitin, bool is_even>
struct fft_stage_impl : dft_stage<T>
{
@@ -425,14 +436,14 @@ struct fft_stage_impl : dft_stage<T>
this->radix = 4;
this->stage_size = stage_size;
this->repeats = 4;
- this->recursion = true;
+ this->recursion = fft_config<T>::recursion;
this->data_size =
align_up(sizeof(complex<T>) * stage_size / 4 * 3, platform<>::native_cache_alignment);
}
- constexpr static bool prefetch = true;
+ constexpr static bool prefetch = fft_config<T>::prefetch;
constexpr static bool aligned = false;
- constexpr static size_t width = fft_vector_width<T>;
+ constexpr static size_t width = fft_config<T>::process_width;
virtual void do_initialize(size_t size) override final
{
@@ -445,7 +456,7 @@ struct fft_stage_impl : dft_stage<T>
KFR_MEM_INTRINSIC void do_execute(complex<T>* out, const complex<T>* in, u8*)
{
const complex<T>* twiddle = ptr_cast<complex<T>>(this->data);
- if (splitin)
+ if constexpr (splitin)
in = out;
const size_t stg_size = this->stage_size;
CMT_ASSUME(stg_size >= 2048);
@@ -465,27 +476,29 @@ struct fft_final_stage_impl : dft_stage<T>
this->stage_size = size;
this->out_offset = size;
this->repeats = 4;
- this->recursion = true;
+ this->recursion = fft_config<T>::recursion;
this->data_size = align_up(sizeof(complex<T>) * size * 3 / 2, platform<>::native_cache_alignment);
}
- constexpr static size_t width = fft_vector_width<T>;
+ constexpr static size_t width = fft_config<T>::process_width;
constexpr static bool is_even = cometa::is_even(ilog2(size));
constexpr static bool use_br2 = !is_even;
constexpr static bool aligned = false;
- constexpr static bool prefetch = splitin;
+ constexpr static bool prefetch = fft_config<T>::prefetch && splitin;
KFR_MEM_INTRINSIC void init_twiddles(csize_t<8>, size_t, cfalse_t, complex<T>*&) {}
KFR_MEM_INTRINSIC void init_twiddles(csize_t<4>, size_t, cfalse_t, complex<T>*&) {}
+ static constexpr bool get_pass_splitout(size_t N) { return N / 4 > 8 && N / 4 / 4 >= width; }
+
template <size_t N, bool pass_splitin>
KFR_MEM_INTRINSIC void init_twiddles(csize_t<N>, size_t total_size, cbool_t<pass_splitin>,
complex<T>*& twiddle)
{
- constexpr bool pass_split = N / 4 > 8 && N / 4 / 4 >= width;
- constexpr size_t pass_width = const_min(width, N / 4);
- initialize_twiddles<T, pass_width>(twiddle, N, total_size, pass_split || pass_splitin);
- init_twiddles(csize<N / 4>, total_size, cbool<pass_split>, twiddle);
+ constexpr bool pass_splitout = get_pass_splitout(N);
+ constexpr size_t pass_width = const_min(width, N / 4);
+ initialize_twiddles<T, pass_width>(twiddle, N, total_size, pass_splitout || pass_splitin);
+ init_twiddles(csize<N / 4>, total_size, cbool<pass_splitout>, twiddle);
}
virtual void do_initialize(size_t total_size) override final
@@ -539,14 +552,14 @@ struct fft_final_stage_impl : dft_stage<T>
const complex<T>* in, const complex<T>*& twiddle)
{
static_assert(N > 8, "");
- constexpr bool pass_split = N / 4 > 8 && N / 4 / 4 >= width;
- constexpr size_t pass_width = const_min(width, N / 4);
- static_assert(pass_width == width || (pass_split == pass_splitin), "");
+ constexpr bool pass_splitout = get_pass_splitout(N);
+ constexpr size_t pass_width = const_min(width, N / 4);
+ static_assert(pass_width == width || !pass_splitin, "");
static_assert(pass_width <= N / 4, "");
- radix4_pass(N, invN, csize_t<pass_width>(), cbool<pass_split>, cbool_t<pass_splitin>(),
+ radix4_pass(N, invN, csize_t<pass_width>(), cbool<pass_splitout>, cbool_t<pass_splitin>(),
cbool_t<use_br2>(), cbool_t<prefetch>(), cbool_t<inverse>(), cbool_t<aligned>(), out, in,
twiddle);
- final_stage<inverse>(csize<N / 4>, invN * 4, cbool<pass_split>, out, out, twiddle);
+ final_stage<inverse>(csize<N / 4>, invN * 4, cbool<pass_splitout>, out, out, twiddle);
}
};
@@ -670,9 +683,10 @@ struct fft_specialization<T, 6> : dft_stage<T>
}
};
-template <typename T>
-struct fft_specialization<T, 7> : dft_stage<T>
+template <>
+struct fft_specialization<double, 7> : dft_stage<double>
{
+ using T = double;
fft_specialization(size_t)
{
this->name = type_name<decltype(*this)>();
@@ -681,12 +695,10 @@ struct fft_specialization<T, 7> : dft_stage<T>
}
constexpr static bool aligned = false;
- constexpr static size_t width = vector_width<T>;
+ constexpr static size_t width = const_min(fft_config<T>::process_width, size_t(8));
constexpr static bool use_br2 = true;
constexpr static bool prefetch = false;
- constexpr static bool is_double = sizeof(T) == 8;
- constexpr static size_t final_size = is_double ? 8 : 32;
- constexpr static size_t split_format = final_size == 8;
+ constexpr static size_t split_format = true;
virtual void do_initialize(size_t total_size) override final
{
@@ -701,31 +713,54 @@ struct fft_specialization<T, 7> : dft_stage<T>
KFR_MEM_INTRINSIC void do_execute(complex<T>* out, const complex<T>* in, u8*)
{
const complex<T>* twiddle = ptr_cast<complex<T>>(this->data);
- final_pass<inverse>(csize_t<final_size>(), out, in, twiddle);
- if (this->need_reorder)
- fft_reorder(out, csize_t<7>());
- }
-
- template <bool inverse>
- KFR_MEM_INTRINSIC void final_pass(csize_t<8>, complex<T>* out, const complex<T>* in,
- const complex<T>* twiddle)
- {
radix4_pass(128, 1, csize_t<width>(), ctrue, cfalse, cbool_t<use_br2>(), cbool_t<prefetch>(),
cbool_t<inverse>(), cbool_t<aligned>(), out, in, twiddle);
radix4_pass(32, 4, csize_t<width>(), cfalse, ctrue, cbool_t<use_br2>(), cbool_t<prefetch>(),
cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle);
radix4_pass(csize_t<8>(), 16, csize_t<width>(), cfalse, cfalse, cbool_t<use_br2>(),
cbool_t<prefetch>(), cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle);
+ if (this->need_reorder)
+ fft_reorder(out, csize_t<7>());
+ }
+};
+
+template <>
+struct fft_specialization<float, 7> : dft_stage<float>
+{
+ using T = float;
+ fft_specialization(size_t)
+ {
+ this->name = type_name<decltype(*this)>();
+ this->stage_size = 128;
+ this->data_size = align_up(sizeof(complex<T>) * 128 * 3 / 2, platform<>::native_cache_alignment);
}
+ constexpr static bool aligned = false;
+ constexpr static size_t width = const_min(fft_config<T>::process_width, size_t(16));
+ constexpr static bool use_br2 = true;
+ constexpr static bool prefetch = false;
+ constexpr static size_t final_size = 32;
+ constexpr static size_t split_format = false;
+
+ virtual void do_initialize(size_t total_size) override final
+ {
+ complex<T>* twiddle = ptr_cast<complex<T>>(this->data);
+ initialize_twiddles<T, width>(twiddle, 128, total_size, split_format);
+ initialize_twiddles<T, width>(twiddle, 32, total_size, split_format);
+ initialize_twiddles<T, width>(twiddle, 8, total_size, split_format);
+ }
+
+ DFT_STAGE_FN
template <bool inverse>
- KFR_MEM_INTRINSIC void final_pass(csize_t<32>, complex<T>* out, const complex<T>* in,
- const complex<T>* twiddle)
+ KFR_MEM_INTRINSIC void do_execute(complex<T>* out, const complex<T>* in, u8*)
{
+ const complex<T>* twiddle = ptr_cast<complex<T>>(this->data);
radix4_pass(128, 1, csize_t<width>(), cfalse, cfalse, cbool_t<use_br2>(), cbool_t<prefetch>(),
cbool_t<inverse>(), cbool_t<aligned>(), out, in, twiddle);
radix4_pass(csize_t<32>(), 4, csize_t<width>(), cfalse, cfalse, cbool_t<use_br2>(),
cbool_t<prefetch>(), cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle);
+ if (this->need_reorder)
+ fft_reorder(out, csize_t<7>());
}
};
diff --git a/include/kfr/simd/constants.hpp b/include/kfr/simd/constants.hpp
@@ -89,7 +89,7 @@ public:
};
template <size_t Value>
-constexpr size_t force_compiletime_size_t = Value;
+constexpr inline size_t force_compiletime_size_t = Value;
CMT_PRAGMA_GNU(GCC diagnostic pop)
diff --git a/include/kfr/simd/vec.hpp b/include/kfr/simd/vec.hpp
@@ -167,16 +167,23 @@ struct compoundcast<vec<vec<vec<T, N1>, N2>, N3>>
static vec<vec<vec<T, N1>, N2>, N3> from_flat(const vec<T, N1 * N2 * N3>& x) { return x.v; }
};
+
+template <typename T, size_t N_>
+inline constexpr size_t vec_alignment =
+ const_max(alignof(intrinsics::simd<typename compound_type_traits<T>::deep_subtype,
+ const_max(size_t(1), N_) * compound_type_traits<T>::deep_width>),
+ const_min(size_t(platform<>::native_vector_alignment),
+ next_poweroftwo(sizeof(typename compound_type_traits<T>::deep_subtype) *
+ const_max(size_t(1), N_) * compound_type_traits<T>::deep_width)));
+
} // namespace internal
-template <typename T, size_t N>
-struct alignas(force_compiletime_size_t<
- const_max(alignof(intrinsics::simd<typename compound_type_traits<T>::deep_subtype,
- N * compound_type_traits<T>::deep_width>),
- const_min(size_t(platform<>::native_vector_alignment),
- next_poweroftwo(sizeof(typename compound_type_traits<T>::deep_subtype) *
- N * compound_type_traits<T>::deep_width)))>) vec
+template <typename T, size_t N_>
+struct alignas(internal::vec_alignment<T, N_>) vec
{
+ static_assert(N_ > 0, "vec<T, N>: vector width cannot be zero");
+
+ constexpr static inline size_t N = const_max(size_t(1), N_);
static constexpr vec_shape<T, N> shape() CMT_NOEXCEPT { return {}; }
// type and size
@@ -187,8 +194,8 @@ struct alignas(force_compiletime_size_t<
using ST = typename compound_type_traits<T>::deep_subtype;
using scalar_type = ST;
- constexpr static size_t SW = compound_type_traits<T>::deep_width;
- constexpr static size_t SN = N * SW;
+ constexpr static inline size_t SW = compound_type_traits<T>::deep_width;
+ constexpr static inline size_t SN = N * SW;
constexpr static size_t scalar_size() CMT_NOEXCEPT { return SN; }
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
@@ -109,6 +109,8 @@ set(ALL_TESTS_CPP
if (ENABLE_DFT)
list(APPEND ALL_TESTS_CPP dft_test.cpp)
+
+ add_executable(dft_test dft_test.cpp)
endif ()
find_package(MPFR)
@@ -136,6 +138,7 @@ target_compile_definitions(all_tests PRIVATE KFR_NO_MAIN)
target_link_libraries(all_tests kfr use_arch)
if (ENABLE_DFT)
target_link_libraries(all_tests kfr_dft)
+ target_link_libraries(dft_test kfr_dft)
endif ()
target_link_libraries(all_tests kfr_io)