commit 830601705db50b1f174ff407661b4326b70f9655
parent 6aea976a464de59d522d0c629e64bf0c044e6777
Author: d.levin256@gmail.com <d.levin256@gmail.com>
Date: Mon, 13 Jan 2025 11:45:38 +0100
Improve dft performance on arm64
Diffstat:
3 files changed, 44 insertions(+), 11 deletions(-)
diff --git a/src/dft/bitrev.hpp b/src/dft/bitrev.hpp
@@ -49,15 +49,22 @@ constexpr inline static size_t bitrev_table_log2N = ilog2(arraysize(data::bitrev
template <size_t Bits>
CMT_GNU_CONSTEXPR inline u32 bitrev_using_table(u32 x)
{
+#ifdef CMT_ARCH_NEON
+ return __builtin_bitreverse32(x) >> (32 - Bits);
+#else
if constexpr (Bits > bitrev_table_log2N)
return bitreverse<Bits>(x);
return data::bitrev_table[x] >> (bitrev_table_log2N - Bits);
+#endif
}
template <bool use_table>
CMT_GNU_CONSTEXPR inline u32 bitrev_using_table(u32 x, size_t bits, cbool_t<use_table>)
{
+#ifdef CMT_ARCH_NEON
+ return __builtin_bitreverse32(x) >> (32 - bits);
+#else
if constexpr (use_table)
{
return data::bitrev_table[x] >> (bitrev_table_log2N - bits);
@@ -66,10 +73,17 @@ CMT_GNU_CONSTEXPR inline u32 bitrev_using_table(u32 x, size_t bits, cbool_t<use_
{
return bitreverse<32>(x) >> (32 - bits);
}
+#endif
}
CMT_GNU_CONSTEXPR inline u32 dig4rev_using_table(u32 x, size_t bits)
{
+#ifdef CMT_ARCH_NEON
+ x = __builtin_bitreverse32(x);
+ x = (((x & 0xaaaaaaaa) >> 1) | ((x & 0x55555555) << 1));
+ x = x >> (32 - bits);
+ return x;
+#else
if (bits > bitrev_table_log2N)
{
if (bits <= 16)
@@ -82,6 +96,7 @@ CMT_GNU_CONSTEXPR inline u32 dig4rev_using_table(u32 x, size_t bits)
x = (((x & 0xaaaaaaaa) >> 1) | ((x & 0x55555555) << 1));
x = x >> (bitrev_table_log2N - bits);
return x;
+#endif
}
template <size_t log2n, size_t bitrev, typename T>
diff --git a/src/dft/fft-impl.hpp b/src/dft/fft-impl.hpp
@@ -52,22 +52,30 @@ template <typename T>
inline std::bitset<DFT_MAX_STAGES> fft_algorithm_selection;
template <>
-inline std::bitset<DFT_MAX_STAGES> fft_algorithm_selection<float>{ (1ull << 15) - 1 };
+inline std::bitset<DFT_MAX_STAGES> fft_algorithm_selection<float>{
+#ifdef CMT_ARCH_NEON
+ 0
+#else
+ (1ull << 15) - 1
+#endif
+};
template <>
inline std::bitset<DFT_MAX_STAGES> fft_algorithm_selection<double>{ 0 };
template <typename T>
-constexpr bool inline use_autosort(size_t log2n)
+inline bool use_autosort(size_t log2n)
{
return fft_algorithm_selection<T>[log2n];
}
+#ifndef CMT_ARCH_NEON
+#define KFR_AUTOSORT_FOR_2048
#define KFR_AUTOSORT_FOR_128D
#define KFR_AUTOSORT_FOR_256D
#define KFR_AUTOSORT_FOR_512
#define KFR_AUTOSORT_FOR_1024
-#define KFR_AUTOSORT_FOR_2048
+#endif
#ifdef CMT_ARCH_AVX
template <>
@@ -855,7 +863,11 @@ template <typename T>
struct fft_config
{
constexpr static inline const bool recursion = true;
- constexpr static inline const bool prefetch = true;
+#ifdef CMT_ARCH_NEON
+ constexpr static inline const bool prefetch = false;
+#else
+ constexpr static inline const bool prefetch = true;
+#endif
constexpr static inline const size_t process_width =
const_max(static_cast<size_t>(1), vector_capacity<T> / 16);
};
@@ -1606,7 +1618,7 @@ struct fft_specialization<T, 10> : fft_final_stage_impl<T, false, 1024>
{
fft_final_stage_impl<T, false, 1024>::template do_execute<inverse>(out, in, nullptr);
if (this->need_reorder)
- fft_reorder(out, 10, cfalse);
+ fft_reorder(out, csize_t<10>{}, cbool_t<always_br2>{});
}
};
#endif
@@ -1649,8 +1661,6 @@ struct fft_specialization<T, 11> : dft_stage<T>
radix8_autosort_pass_last(256, csize<width>, no, no, no, cbool<inverse>, out, out, tw);
}
};
-
-#else
#endif
template <bool is_even, bool first, typename T, bool autosort>
@@ -1768,7 +1778,13 @@ KFR_INTRINSIC void init_fft(dft_plan<T>* self, size_t size, dft_order)
{
const size_t log2n = ilog2(size);
cswitch(
- csizes_t<0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11>(), log2n,
+ csizes_t<0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10
+#ifdef KFR_AUTOSORT_FOR_2048
+ ,
+ 11
+#endif
+ >(),
+ log2n,
[&](auto log2n)
{
(void)log2n;
diff --git a/tests/dft_test.cpp b/tests/dft_test.cpp
@@ -33,11 +33,13 @@ constexpr ctypes_t<float, double> dft_float_types{};
constexpr ctypes_t<float> dft_float_types{};
#endif
-#if defined(CMT_ARCH_X86) && !defined(KFR_NO_PERF_TESTS)
+#if !defined(KFR_NO_PERF_TESTS)
static void full_barrier()
{
-#ifdef CMT_COMPILER_GNU
+#if defined(CMT_ARCH_NEON)
+ asm volatile("dmb ish" ::: "memory");
+#elif defined(CMT_COMPILER_GNU)
asm volatile("mfence" ::: "memory");
#else
_ReadWriteBarrier();
@@ -235,7 +237,7 @@ TEST(fft_accuracy)
if (is_even(size))
{
- index_t csize = dft_plan_real<float_type>::complex_size_for(size, dft_pack_format::CCs);
+ index_t csize = dft_plan_real<float_type>::complex_size_for(size, dft_pack_format::CCs);
univector<float_type> in = truncate(gen_random_range<float_type>(gen, -1.0, +1.0), size);
univector<complex<float_type>> out = truncate(dimensions<1>(scalar(qnan)), csize);