kfr

Fast, modern C++ DSP framework, FFT, Sample Rate Conversion, FIR/IIR/Biquad Filters (SSE, AVX, AVX-512, ARM NEON)
Log | Files | Refs | README

commit 79c45e4968f694e92a3cc590650fcec7c1bade93
parent 194011494c612013f9dc8d75e96bd7fbd97e9b75
Author: d.levin256@gmail.com <d.levin256@gmail.com>
Date:   Tue, 27 Nov 2018 21:45:55 +0000

Refactor DFT sources

Diffstat:
MCMakeLists.txt | 9++++++++-
Mexamples/CMakeLists.txt | 19++++++++++++++-----
Dinclude/kfr/dft/bitrev.hpp | 390-------------------------------------------------------------------------------
Dinclude/kfr/dft/dft-src.cpp | 1971-------------------------------------------------------------------------------
Dinclude/kfr/dft/ft.hpp | 1760-------------------------------------------------------------------------------
Ainclude/kfr/dft/impl/bitrev.hpp | 390+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Ainclude/kfr/dft/impl/convolution-impl.cpp | 204+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Ainclude/kfr/dft/impl/dft-impl-f32.cpp | 29+++++++++++++++++++++++++++++
Ainclude/kfr/dft/impl/dft-impl-f64.cpp | 30++++++++++++++++++++++++++++++
Ainclude/kfr/dft/impl/dft-impl.hpp | 1689+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Ainclude/kfr/dft/impl/dft-src.cpp | 130+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Ainclude/kfr/dft/impl/dft-templates.hpp | 44++++++++++++++++++++++++++++++++++++++++++++
Ainclude/kfr/dft/impl/ft.hpp | 1760+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Msources.cmake | 6++++--
Mtests/CMakeLists.txt | 10+++++-----
15 files changed, 4307 insertions(+), 4134 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt @@ -41,6 +41,13 @@ add_definitions(-D_ENABLE_EXTENDED_ALIGNED_STORAGE) option(ENABLE_TESTS "Enable tests and examples. This changes many compiler flags" OFF) +set(KFR_DFT_SRC + ${CMAKE_SOURCE_DIR}/include/kfr/dft/impl/dft-src.cpp + ${CMAKE_SOURCE_DIR}/include/kfr/dft/dft_c.h + ${CMAKE_SOURCE_DIR}/include/kfr/dft/impl/dft-impl-f32.cpp + ${CMAKE_SOURCE_DIR}/include/kfr/dft/impl/dft-impl-f64.cpp + ${CMAKE_SOURCE_DIR}/include/kfr/dft/impl/convolution-impl.cpp) + if (ENABLE_TESTS) if (IOS) @@ -89,5 +96,5 @@ add_library(kfr INTERFACE) target_sources(kfr INTERFACE ${KFR_SRC}) target_include_directories(kfr INTERFACE include) -add_library(kfr_dft include/kfr/dft/dft-src.cpp include/kfr/dft/dft_c.h) +add_library(kfr_dft ${KFR_DFT_SRC}) target_link_libraries(kfr_dft kfr) diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt @@ -21,8 +21,17 @@ file(MAKE_DIRECTORY ${PROJECT_BINARY_DIR}/svg) include_directories(../include) -add_executable(biquads biquads.cpp ${KFR_SRC}) -add_executable(window window.cpp ${KFR_SRC}) -add_executable(fir fir.cpp ${KFR_SRC}) -add_executable(sample_rate_conversion sample_rate_conversion.cpp ${KFR_SRC}) -add_executable(dft dft.cpp ${KFR_SRC} ${DFT_SRC} ../include/kfr/dft/dft-src.cpp) +add_executable(biquads biquads.cpp) +target_link_libraries(biquads kfr) + +add_executable(window window.cpp) +target_link_libraries(window kfr) + +add_executable(fir fir.cpp) +target_link_libraries(fir kfr) + +add_executable(sample_rate_conversion sample_rate_conversion.cpp) +target_link_libraries(sample_rate_conversion kfr) + +add_executable(dft dft.cpp) +target_link_libraries(dft kfr kfr_dft) diff --git a/include/kfr/dft/bitrev.hpp b/include/kfr/dft/bitrev.hpp @@ -1,390 +0,0 @@ -/** @addtogroup dft - * @{ - */ -/* - Copyright (C) 2016 D Levin (https://www.kfrlib.com) - This file is part of KFR - - KFR is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - KFR is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with KFR. - - If GPL is not suitable for your project, you must purchase a commercial license to use KFR. - Buying a commercial license is mandatory as soon as you develop commercial activities without - disclosing the source code of your own applications. - See https://www.kfrlib.com for details. - */ -#pragma once - -#include "../base/complex.hpp" -#include "../base/constants.hpp" -#include "../base/digitreverse.hpp" -#include "../base/vec.hpp" - -#include "../data/bitrev.hpp" - -#include "ft.hpp" - -namespace kfr -{ - -namespace internal -{ - -constexpr bool fft_reorder_aligned = false; - -template <size_t Bits> -CMT_GNU_CONSTEXPR inline u32 bitrev_using_table(u32 x) -{ - constexpr size_t bitrev_table_log2N = ilog2(arraysize(data::bitrev_table)); - if (Bits > bitrev_table_log2N) - return bitreverse<Bits>(x); - - return data::bitrev_table[x] >> (bitrev_table_log2N - Bits); -} - -CMT_GNU_CONSTEXPR inline u32 bitrev_using_table(u32 x, size_t bits) -{ - constexpr size_t bitrev_table_log2N = ilog2(arraysize(data::bitrev_table)); - if (bits > bitrev_table_log2N) - return bitreverse<32>(x) >> (32 - bits); - - return data::bitrev_table[x] >> (bitrev_table_log2N - bits); -} - -CMT_GNU_CONSTEXPR inline u32 dig4rev_using_table(u32 x, size_t bits) -{ - constexpr size_t bitrev_table_log2N = ilog2(arraysize(data::bitrev_table)); - if (bits > bitrev_table_log2N) - return digitreverse4<32>(x) >> (32 - bits); - - x = data::bitrev_table[x]; - x = (((x & 0xaaaaaaaa) >> 1) | ((x & 0x55555555) << 1)); - x = x >> (bitrev_table_log2N - bits); - return x; -} - -template <size_t log2n, size_t bitrev, typename T> -KFR_INTRIN void fft_reorder_swap(T* inout, size_t i) -{ - using cxx = cvec<T, 16>; - constexpr size_t N = 1 << log2n; - constexpr size_t N4 = 2 * N / 4; - - cxx vi = cread_group<4, 4, N4 / 2, fft_reorder_aligned>(ptr_cast<complex<T>>(inout + i)); - vi = digitreverse<bitrev, 2>(vi); - cwrite_group<4, 4, N4 / 2, fft_reorder_aligned>(ptr_cast<complex<T>>(inout + i), vi); -} - -template <size_t log2n, size_t bitrev, typename T> -KFR_INTRIN void fft_reorder_swap_two(T* inout, size_t i, size_t j) -{ - CMT_ASSUME(i != j); - using cxx = cvec<T, 16>; - constexpr size_t N = 1 << log2n; - constexpr size_t N4 = 2 * N / 4; - - cxx vi = cread_group<4, 4, N4 / 2, fft_reorder_aligned>(ptr_cast<complex<T>>(inout + i)); - cxx vj = cread_group<4, 4, N4 / 2, fft_reorder_aligned>(ptr_cast<complex<T>>(inout + j)); - - vi = digitreverse<bitrev, 2>(vi); - cwrite_group<4, 4, N4 / 2, fft_reorder_aligned>(ptr_cast<complex<T>>(inout + i), vi); - vj = digitreverse<bitrev, 2>(vj); - cwrite_group<4, 4, N4 / 2, fft_reorder_aligned>(ptr_cast<complex<T>>(inout + j), vj); -} - -template <size_t log2n, size_t bitrev, typename T> -KFR_INTRIN void fft_reorder_swap(T* inout, size_t i, size_t j) -{ - CMT_ASSUME(i != j); - using cxx = cvec<T, 16>; - constexpr size_t N = 1 << log2n; - constexpr size_t N4 = 2 * N / 4; - - cxx vi = cread_group<4, 4, N4 / 2, fft_reorder_aligned>(ptr_cast<complex<T>>(inout + i)); - cxx vj = cread_group<4, 4, N4 / 2, fft_reorder_aligned>(ptr_cast<complex<T>>(inout + j)); - - vi = digitreverse<bitrev, 2>(vi); - cwrite_group<4, 4, N4 / 2, fft_reorder_aligned>(ptr_cast<complex<T>>(inout + j), vi); - vj = digitreverse<bitrev, 2>(vj); - cwrite_group<4, 4, N4 / 2, fft_reorder_aligned>(ptr_cast<complex<T>>(inout + i), vj); -} - -template <size_t log2n, size_t bitrev, typename T> -KFR_INTRIN void fft_reorder_swap(complex<T>* inout, size_t i) -{ - fft_reorder_swap<log2n, bitrev>(ptr_cast<T>(inout), i * 2); -} - -template <size_t log2n, size_t bitrev, typename T> -KFR_INTRIN void fft_reorder_swap_two(complex<T>* inout, size_t i0, size_t i1) -{ - fft_reorder_swap_two<log2n, bitrev>(ptr_cast<T>(inout), i0 * 2, i1 * 2); -} - -template <size_t log2n, size_t bitrev, typename T> -KFR_INTRIN void fft_reorder_swap(complex<T>* inout, size_t i, size_t j) -{ - fft_reorder_swap<log2n, bitrev>(ptr_cast<T>(inout), i * 2, j * 2); -} - -template <typename T> -KFR_INTRIN void fft_reorder(complex<T>* inout, csize_t<11>) -{ - fft_reorder_swap_two<11>(inout, 0 * 4, 8 * 4); - fft_reorder_swap<11>(inout, 1 * 4, 64 * 4); - fft_reorder_swap<11>(inout, 2 * 4, 32 * 4); - fft_reorder_swap<11>(inout, 3 * 4, 96 * 4); - fft_reorder_swap<11>(inout, 4 * 4, 16 * 4); - fft_reorder_swap<11>(inout, 5 * 4, 80 * 4); - fft_reorder_swap<11>(inout, 6 * 4, 48 * 4); - fft_reorder_swap<11>(inout, 7 * 4, 112 * 4); - fft_reorder_swap<11>(inout, 9 * 4, 72 * 4); - fft_reorder_swap<11>(inout, 10 * 4, 40 * 4); - fft_reorder_swap<11>(inout, 11 * 4, 104 * 4); - fft_reorder_swap<11>(inout, 12 * 4, 24 * 4); - fft_reorder_swap<11>(inout, 13 * 4, 88 * 4); - fft_reorder_swap<11>(inout, 14 * 4, 56 * 4); - fft_reorder_swap<11>(inout, 15 * 4, 120 * 4); - fft_reorder_swap<11>(inout, 17 * 4, 68 * 4); - fft_reorder_swap<11>(inout, 18 * 4, 36 * 4); - fft_reorder_swap<11>(inout, 19 * 4, 100 * 4); - fft_reorder_swap_two<11>(inout, 20 * 4, 28 * 4); - fft_reorder_swap<11>(inout, 21 * 4, 84 * 4); - fft_reorder_swap<11>(inout, 22 * 4, 52 * 4); - fft_reorder_swap<11>(inout, 23 * 4, 116 * 4); - fft_reorder_swap<11>(inout, 25 * 4, 76 * 4); - fft_reorder_swap<11>(inout, 26 * 4, 44 * 4); - fft_reorder_swap<11>(inout, 27 * 4, 108 * 4); - fft_reorder_swap<11>(inout, 29 * 4, 92 * 4); - fft_reorder_swap<11>(inout, 30 * 4, 60 * 4); - fft_reorder_swap<11>(inout, 31 * 4, 124 * 4); - fft_reorder_swap<11>(inout, 33 * 4, 66 * 4); - fft_reorder_swap_two<11>(inout, 34 * 4, 42 * 4); - fft_reorder_swap<11>(inout, 35 * 4, 98 * 4); - fft_reorder_swap<11>(inout, 37 * 4, 82 * 4); - fft_reorder_swap<11>(inout, 38 * 4, 50 * 4); - fft_reorder_swap<11>(inout, 39 * 4, 114 * 4); - fft_reorder_swap<11>(inout, 41 * 4, 74 * 4); - fft_reorder_swap<11>(inout, 43 * 4, 106 * 4); - fft_reorder_swap<11>(inout, 45 * 4, 90 * 4); - fft_reorder_swap<11>(inout, 46 * 4, 58 * 4); - fft_reorder_swap<11>(inout, 47 * 4, 122 * 4); - fft_reorder_swap<11>(inout, 49 * 4, 70 * 4); - fft_reorder_swap<11>(inout, 51 * 4, 102 * 4); - fft_reorder_swap<11>(inout, 53 * 4, 86 * 4); - fft_reorder_swap_two<11>(inout, 54 * 4, 62 * 4); - fft_reorder_swap<11>(inout, 55 * 4, 118 * 4); - fft_reorder_swap<11>(inout, 57 * 4, 78 * 4); - fft_reorder_swap<11>(inout, 59 * 4, 110 * 4); - fft_reorder_swap<11>(inout, 61 * 4, 94 * 4); - fft_reorder_swap<11>(inout, 63 * 4, 126 * 4); - fft_reorder_swap_two<11>(inout, 65 * 4, 73 * 4); - fft_reorder_swap<11>(inout, 67 * 4, 97 * 4); - fft_reorder_swap<11>(inout, 69 * 4, 81 * 4); - fft_reorder_swap<11>(inout, 71 * 4, 113 * 4); - fft_reorder_swap<11>(inout, 75 * 4, 105 * 4); - fft_reorder_swap<11>(inout, 77 * 4, 89 * 4); - fft_reorder_swap<11>(inout, 79 * 4, 121 * 4); - fft_reorder_swap<11>(inout, 83 * 4, 101 * 4); - fft_reorder_swap_two<11>(inout, 85 * 4, 93 * 4); - fft_reorder_swap<11>(inout, 87 * 4, 117 * 4); - fft_reorder_swap<11>(inout, 91 * 4, 109 * 4); - fft_reorder_swap<11>(inout, 95 * 4, 125 * 4); - fft_reorder_swap_two<11>(inout, 99 * 4, 107 * 4); - fft_reorder_swap<11>(inout, 103 * 4, 115 * 4); - fft_reorder_swap<11>(inout, 111 * 4, 123 * 4); - fft_reorder_swap_two<11>(inout, 119 * 4, 127 * 4); -} - -template <typename T> -KFR_INTRIN void fft_reorder(complex<T>* inout, csize_t<7>) -{ - constexpr size_t bitrev = 2; - fft_reorder_swap_two<7, bitrev>(inout, 0 * 4, 2 * 4); - fft_reorder_swap<7, bitrev>(inout, 1 * 4, 4 * 4); - fft_reorder_swap<7, bitrev>(inout, 3 * 4, 6 * 4); - fft_reorder_swap_two<7, bitrev>(inout, 5 * 4, 7 * 4); -} - -template <typename T> -KFR_INTRIN void fft_reorder(complex<T>* inout, csize_t<8>) -{ - constexpr size_t bitrev = 4; - fft_reorder_swap_two<8, bitrev>(inout, 0 * 4, 5 * 4); - fft_reorder_swap<8, bitrev>(inout, 1 * 4, 4 * 4); - fft_reorder_swap<8, bitrev>(inout, 2 * 4, 8 * 4); - fft_reorder_swap<8, bitrev>(inout, 3 * 4, 12 * 4); - fft_reorder_swap<8, bitrev>(inout, 6 * 4, 9 * 4); - fft_reorder_swap<8, bitrev>(inout, 7 * 4, 13 * 4); - fft_reorder_swap_two<8, bitrev>(inout, 10 * 4, 15 * 4); - fft_reorder_swap<8, bitrev>(inout, 11 * 4, 14 * 4); -} - -template <typename T> -KFR_INTRIN void fft_reorder(complex<T>* inout, csize_t<9>) -{ - constexpr size_t bitrev = 2; - fft_reorder_swap_two<9, bitrev>(inout, 0 * 4, 4 * 4); - fft_reorder_swap<9, bitrev>(inout, 1 * 4, 16 * 4); - fft_reorder_swap<9, bitrev>(inout, 2 * 4, 8 * 4); - fft_reorder_swap<9, bitrev>(inout, 3 * 4, 24 * 4); - fft_reorder_swap<9, bitrev>(inout, 5 * 4, 20 * 4); - fft_reorder_swap<9, bitrev>(inout, 6 * 4, 12 * 4); - fft_reorder_swap<9, bitrev>(inout, 7 * 4, 28 * 4); - fft_reorder_swap<9, bitrev>(inout, 9 * 4, 18 * 4); - fft_reorder_swap_two<9, bitrev>(inout, 10 * 4, 14 * 4); - fft_reorder_swap<9, bitrev>(inout, 11 * 4, 26 * 4); - fft_reorder_swap<9, bitrev>(inout, 13 * 4, 22 * 4); - fft_reorder_swap<9, bitrev>(inout, 15 * 4, 30 * 4); - fft_reorder_swap_two<9, bitrev>(inout, 17 * 4, 21 * 4); - fft_reorder_swap<9, bitrev>(inout, 19 * 4, 25 * 4); - fft_reorder_swap<9, bitrev>(inout, 23 * 4, 29 * 4); - fft_reorder_swap_two<9, bitrev>(inout, 27 * 4, 31 * 4); -} - -template <typename T, bool use_br2> -void cwrite_reordered(T* out, const cvec<T, 16>& value, size_t N4, cbool_t<use_br2>) -{ - cwrite_group<4, 4, fft_reorder_aligned>(ptr_cast<complex<T>>(out), N4, - digitreverse<(use_br2 ? 2 : 4), 2>(value)); -} - -template <typename T, bool use_br2> -KFR_INTRIN void fft_reorder_swap_n4(T* inout, size_t i, size_t j, size_t N4, cbool_t<use_br2>) -{ - CMT_ASSUME(i != j); - const cvec<T, 16> vi = cread_group<4, 4, fft_reorder_aligned>(ptr_cast<complex<T>>(inout + i), N4); - const cvec<T, 16> vj = cread_group<4, 4, fft_reorder_aligned>(ptr_cast<complex<T>>(inout + j), N4); - cwrite_reordered(inout + j, vi, N4, cbool_t<use_br2>()); - cwrite_reordered(inout + i, vj, N4, cbool_t<use_br2>()); -} - -template <typename T> -KFR_INTRIN void fft_reorder(complex<T>* inout, size_t log2n, ctrue_t use_br2) -{ - const size_t N = 1 << log2n; - const size_t N4 = N / 4; - const size_t iend = N / 16 * 4 * 2; - constexpr size_t istep = 2 * 4; - const size_t jstep1 = (1 << (log2n - 5)) * 4 * 2; - const size_t jstep2 = size_t(1 << (log2n - 5)) * 4 * 2 - size_t(1 << (log2n - 6)) * 4 * 2; - T* io = ptr_cast<T>(inout); - - for (size_t i = 0; i < iend;) - { - size_t j = bitrev_using_table(static_cast<u32>(i >> 3), log2n - 4) << 3; - if (i >= j) - fft_reorder_swap_n4(io, i, j, N4, use_br2); - i += istep; - j = j + jstep1; - - if (i >= j) - fft_reorder_swap_n4(io, i, j, N4, use_br2); - i += istep; - j = j - jstep2; - - if (i >= j) - fft_reorder_swap_n4(io, i, j, N4, use_br2); - i += istep; - j = j + jstep1; - - if (i >= j) - fft_reorder_swap_n4(io, i, j, N4, use_br2); - i += istep; - } -} - -template <typename T> -KFR_INTRIN void fft_reorder(complex<T>* inout, size_t log2n, cfalse_t use_br2) -{ - const size_t N = size_t(1) << log2n; - const size_t N4 = N / 4; - const size_t N16 = N * 2 / 16; - size_t iend = N16; - constexpr size_t istep = 2 * 4; - const size_t jstep = N / 64 * 4 * 2; - T* io = ptr_cast<T>(inout); - - size_t i = 0; - CMT_PRAGMA_CLANG(clang loop unroll_count(2)) - for (; i < iend;) - { - size_t j = dig4rev_using_table(static_cast<u32>(i >> 3), log2n - 4) << 3; - - if (i >= j) - fft_reorder_swap_n4(io, i, j, N4, use_br2); - i += istep * 4; - } - iend += N16; - CMT_PRAGMA_CLANG(clang loop unroll_count(2)) - for (; i < iend;) - { - size_t j = dig4rev_using_table(static_cast<u32>(i >> 3), log2n - 4) << 3; - - fft_reorder_swap_n4(io, i, j, N4, use_br2); - - i += istep; - j = j + jstep; - - if (i >= j) - fft_reorder_swap_n4(io, i, j, N4, use_br2); - i += istep * 3; - } - iend += N16; - CMT_PRAGMA_CLANG(clang loop unroll_count(2)) - for (; i < iend;) - { - size_t j = dig4rev_using_table(static_cast<u32>(i >> 3), log2n - 4) << 3; - - fft_reorder_swap_n4(io, i, j, N4, use_br2); - - i += istep; - j = j + jstep; - - fft_reorder_swap_n4(io, i, j, N4, use_br2); - - i += istep; - j = j + jstep; - - if (i >= j) - fft_reorder_swap_n4(io, i, j, N4, use_br2); - i += istep * 2; - } - iend += N16; - CMT_PRAGMA_CLANG(clang loop unroll_count(2)) - for (; i < iend;) - { - size_t j = dig4rev_using_table(static_cast<u32>(i >> 3), log2n - 4) << 3; - - fft_reorder_swap_n4(io, i, j, N4, use_br2); - - i += istep; - j = j + jstep; - - fft_reorder_swap_n4(io, i, j, N4, use_br2); - - i += istep; - j = j + jstep; - - fft_reorder_swap_n4(io, i, j, N4, use_br2); - - i += istep; - j = j + jstep; - - if (i >= j) - fft_reorder_swap_n4(io, i, j, N4, use_br2); - i += istep; - } -} -} -} diff --git a/include/kfr/dft/dft-src.cpp b/include/kfr/dft/dft-src.cpp @@ -1,1971 +0,0 @@ -/** @addtogroup dft - * @{ - */ -/* - Copyright (C) 2016 D Levin (https://www.kfrlib.com) - This file is part of KFR - - KFR is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - KFR is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with KFR. - - If GPL is not suitable for your project, you must purchase a commercial license to use KFR. - Buying a commercial license is mandatory as soon as you develop commercial activities without - disclosing the source code of your own applications. - See https://www.kfrlib.com for details. - */ - -#include "dft_c.h" - -#include "../base/basic_expressions.hpp" -#include "../testo/assert.hpp" -#include "bitrev.hpp" -#include "cache.hpp" -#include "convolution.hpp" -#include "fft.hpp" -#include "ft.hpp" - -CMT_PRAGMA_GNU(GCC diagnostic push) -#if CMT_HAS_WARNING("-Wshadow") -CMT_PRAGMA_GNU(GCC diagnostic ignored "-Wshadow") -#endif - -CMT_PRAGMA_MSVC(warning(push)) -CMT_PRAGMA_MSVC(warning(disable : 4100)) - -namespace kfr -{ - -constexpr csizes_t<2, 3, 4, 5, 6, 7, 8, 9, 10> dft_radices{}; - -#define DFT_ASSERT TESTO_ASSERT_INACTIVE - -template <typename T> -constexpr size_t fft_vector_width = platform<T>::vector_width; - -using cdirect_t = cfalse_t; -using cinvert_t = ctrue_t; - -template <typename T> -struct dft_stage -{ - size_t radix = 0; - size_t stage_size = 0; - size_t data_size = 0; - size_t temp_size = 0; - u8* data = nullptr; - size_t repeats = 1; - size_t out_offset = 0; - size_t blocks = 0; - const char* name = nullptr; - bool recursion = false; - bool can_inplace = true; - bool inplace = false; - bool to_scratch = false; - bool need_reorder = true; - - void initialize(size_t size) { do_initialize(size); } - - virtual void dump() const - { - printf("%s: \n\t%5zu,%5zu,%5zu,%5zu,%5zu,%5zu,%5zu, %d, %d, %d, %d\n", name ? name : "unnamed", radix, - stage_size, data_size, temp_size, repeats, out_offset, blocks, recursion, can_inplace, inplace, - to_scratch); - } - - KFR_INTRIN void execute(cdirect_t, complex<T>* out, const complex<T>* in, u8* temp) - { - do_execute(cdirect_t(), out, in, temp); - } - KFR_INTRIN void execute(cinvert_t, complex<T>* out, const complex<T>* in, u8* temp) - { - do_execute(cinvert_t(), out, in, temp); - } - virtual ~dft_stage() {} - -protected: - virtual void do_initialize(size_t) {} - virtual void do_execute(cdirect_t, complex<T>*, const complex<T>*, u8* temp) = 0; - virtual void do_execute(cinvert_t, complex<T>*, const complex<T>*, u8* temp) = 0; -}; - -#define DFT_STAGE_FN \ - void do_execute(cdirect_t, complex<T>* out, const complex<T>* in, u8* temp) override \ - { \ - return do_execute<false>(out, in, temp); \ - } \ - void do_execute(cinvert_t, complex<T>* out, const complex<T>* in, u8* temp) override \ - { \ - return do_execute<true>(out, in, temp); \ - } - -CMT_PRAGMA_GNU(GCC diagnostic push) -#if CMT_HAS_WARNING("-Wassume") -CMT_PRAGMA_GNU(GCC diagnostic ignored "-Wassume") -#endif - -namespace internal -{ - -template <size_t width, bool inverse, typename T> -KFR_SINTRIN cvec<T, width> radix4_apply_twiddle(csize_t<width>, cfalse_t /*split_format*/, cbool_t<inverse>, - const cvec<T, width>& w, const cvec<T, width>& tw) -{ - cvec<T, width> ww = w; - cvec<T, width> tw_ = tw; - cvec<T, width> b1 = ww * dupeven(tw_); - ww = swap<2>(ww); - - if (inverse) - tw_ = -(tw_); - ww = subadd(b1, ww * dupodd(tw_)); - return ww; -} - -template <size_t width, bool use_br2, bool inverse, bool aligned, typename T> -KFR_SINTRIN void radix4_body(size_t N, csize_t<width>, cfalse_t, cfalse_t, cfalse_t, cbool_t<use_br2>, - cbool_t<inverse>, cbool_t<aligned>, complex<T>* out, const complex<T>* in, - const complex<T>* twiddle) -{ - const size_t N4 = N / 4; - cvec<T, width> w1, w2, w3; - - cvec<T, width> sum02, sum13, diff02, diff13; - - cvec<T, width> a0, a1, a2, a3; - a0 = cread<width, aligned>(in + 0); - a2 = cread<width, aligned>(in + N4 * 2); - sum02 = a0 + a2; - - a1 = cread<width, aligned>(in + N4); - a3 = cread<width, aligned>(in + N4 * 3); - sum13 = a1 + a3; - - cwrite<width, aligned>(out, sum02 + sum13); - w2 = sum02 - sum13; - cwrite<width, aligned>(out + N4 * (use_br2 ? 1 : 2), - radix4_apply_twiddle(csize_t<width>(), cfalse, cbool_t<inverse>(), w2, - cread<width, true>(twiddle + width))); - diff02 = a0 - a2; - diff13 = a1 - a3; - if (inverse) - { - diff13 = (diff13 ^ broadcast<width * 2, T>(T(), -T())); - diff13 = swap<2>(diff13); - } - else - { - diff13 = swap<2>(diff13); - diff13 = (diff13 ^ broadcast<width * 2, T>(T(), -T())); - } - - w1 = diff02 + diff13; - - cwrite<width, aligned>(out + N4 * (use_br2 ? 2 : 1), - radix4_apply_twiddle(csize_t<width>(), cfalse, cbool_t<inverse>(), w1, - cread<width, true>(twiddle + 0))); - w3 = diff02 - diff13; - cwrite<width, aligned>(out + N4 * 3, radix4_apply_twiddle(csize_t<width>(), cfalse, cbool_t<inverse>(), - w3, cread<width, true>(twiddle + width * 2))); -} - -template <size_t width, bool inverse, typename T> -KFR_SINTRIN cvec<T, width> radix4_apply_twiddle(csize_t<width>, ctrue_t /*split_format*/, cbool_t<inverse>, - const cvec<T, width>& w, const cvec<T, width>& tw) -{ - vec<T, width> re1, im1, twre, twim; - split(w, re1, im1); - split(tw, twre, twim); - - const vec<T, width> b1re = re1 * twre; - const vec<T, width> b1im = im1 * twre; - if (inverse) - return concat(b1re + im1 * twim, b1im - re1 * twim); - else - return concat(b1re - im1 * twim, b1im + re1 * twim); -} - -template <size_t width, bool splitout, bool splitin, bool use_br2, bool inverse, bool aligned, typename T> -KFR_SINTRIN void radix4_body(size_t N, csize_t<width>, ctrue_t, cbool_t<splitout>, cbool_t<splitin>, - cbool_t<use_br2>, cbool_t<inverse>, cbool_t<aligned>, complex<T>* out, - const complex<T>* in, const complex<T>* twiddle) -{ - const size_t N4 = N / 4; - cvec<T, width> w1, w2, w3; - constexpr bool read_split = !splitin && splitout; - constexpr bool write_split = splitin && !splitout; - - vec<T, width> re0, im0, re1, im1, re2, im2, re3, im3; - - split(cread_split<width, aligned, read_split>(in + N4 * 0), re0, im0); - split(cread_split<width, aligned, read_split>(in + N4 * 1), re1, im1); - split(cread_split<width, aligned, read_split>(in + N4 * 2), re2, im2); - split(cread_split<width, aligned, read_split>(in + N4 * 3), re3, im3); - - const vec<T, width> sum02re = re0 + re2; - const vec<T, width> sum02im = im0 + im2; - const vec<T, width> sum13re = re1 + re3; - const vec<T, width> sum13im = im1 + im3; - - cwrite_split<width, aligned, write_split>(out, concat(sum02re + sum13re, sum02im + sum13im)); - w2 = concat(sum02re - sum13re, sum02im - sum13im); - cwrite_split<width, aligned, write_split>( - out + N4 * (use_br2 ? 1 : 2), radix4_apply_twiddle(csize_t<width>(), ctrue, cbool_t<inverse>(), w2, - cread<width, true>(twiddle + width))); - - const vec<T, width> diff02re = re0 - re2; - const vec<T, width> diff02im = im0 - im2; - const vec<T, width> diff13re = re1 - re3; - const vec<T, width> diff13im = im1 - im3; - - (inverse ? w1 : w3) = concat(diff02re - diff13im, diff02im + diff13re); - (inverse ? w3 : w1) = concat(diff02re + diff13im, diff02im - diff13re); - - cwrite_split<width, aligned, write_split>( - out + N4 * (use_br2 ? 2 : 1), radix4_apply_twiddle(csize_t<width>(), ctrue, cbool_t<inverse>(), w1, - cread<width, true>(twiddle + 0))); - cwrite_split<width, aligned, write_split>( - out + N4 * 3, radix4_apply_twiddle(csize_t<width>(), ctrue, cbool_t<inverse>(), w3, - cread<width, true>(twiddle + width * 2))); -} - -template <typename T> -CMT_NOINLINE cvec<T, 1> calculate_twiddle(size_t n, size_t size) -{ - if (n == 0) - { - return make_vector(static_cast<T>(1), static_cast<T>(0)); - } - else if (n == size / 4) - { - return make_vector(static_cast<T>(0), static_cast<T>(-1)); - } - else if (n == size / 2) - { - return make_vector(static_cast<T>(-1), static_cast<T>(0)); - } - else if (n == size * 3 / 4) - { - return make_vector(static_cast<T>(0), static_cast<T>(1)); - } - else - { - fbase kth = c_pi<fbase, 2> * (n / static_cast<fbase>(size)); - fbase tcos = +kfr::cos(kth); - fbase tsin = -kfr::sin(kth); - return make_vector(static_cast<T>(tcos), static_cast<T>(tsin)); - } -} - -template <typename T, size_t width> -KFR_SINTRIN void initialize_twiddles_impl(complex<T>*& twiddle, size_t nn, size_t nnstep, size_t size, - bool split_format) -{ - vec<T, 2 * width> result = T(); - CMT_LOOP_UNROLL - for (size_t i = 0; i < width; i++) - { - const cvec<T, 1> r = calculate_twiddle<T>(nn + nnstep * i, size); - result[i * 2] = r[0]; - result[i * 2 + 1] = r[1]; - } - if (split_format) - ref_cast<cvec<T, width>>(twiddle[0]) = splitpairs(result); - else - ref_cast<cvec<T, width>>(twiddle[0]) = result; - twiddle += width; -} - -template <typename T, size_t width> -CMT_NOINLINE void initialize_twiddles(complex<T>*& twiddle, size_t stage_size, size_t size, bool split_format) -{ - const size_t count = stage_size / 4; - size_t nnstep = size / stage_size; - DFT_ASSERT(width <= count); - CMT_LOOP_NOUNROLL - for (size_t n = 0; n < count; n += width) - { - initialize_twiddles_impl<T, width>(twiddle, n * nnstep * 1, nnstep * 1, size, split_format); - initialize_twiddles_impl<T, width>(twiddle, n * nnstep * 2, nnstep * 2, size, split_format); - initialize_twiddles_impl<T, width>(twiddle, n * nnstep * 3, nnstep * 3, size, split_format); - } -} - -#if defined CMT_ARCH_SSE -#ifdef CMT_COMPILER_GNU -#define KFR_PREFETCH(addr) __builtin_prefetch(::kfr::ptr_cast<void>(addr), 0, _MM_HINT_T0); -#else -#define KFR_PREFETCH(addr) _mm_prefetch(::kfr::ptr_cast<char>(addr), _MM_HINT_T0); -#endif -#else -#define KFR_PREFETCH(addr) __builtin_prefetch(::kfr::ptr_cast<void>(addr)); -#endif - -template <typename T> -KFR_SINTRIN void prefetch_one(const complex<T>* in) -{ - KFR_PREFETCH(in); -} - -template <typename T> -KFR_SINTRIN void prefetch_four(size_t stride, const complex<T>* in) -{ - KFR_PREFETCH(in); - KFR_PREFETCH(in + stride); - KFR_PREFETCH(in + stride * 2); - KFR_PREFETCH(in + stride * 3); -} - -template <typename Ntype, size_t width, bool splitout, bool splitin, bool prefetch, bool use_br2, - bool inverse, bool aligned, typename T> -KFR_SINTRIN cfalse_t radix4_pass(Ntype N, size_t blocks, csize_t<width>, cbool_t<splitout>, cbool_t<splitin>, - cbool_t<use_br2>, cbool_t<prefetch>, cbool_t<inverse>, cbool_t<aligned>, - complex<T>* out, const complex<T>* in, const complex<T>*& twiddle) -{ - constexpr static size_t prefetch_offset = width * 8; - const auto N4 = N / csize_t<4>(); - const auto N43 = N4 * csize_t<3>(); - CMT_ASSUME(blocks > 0); - CMT_ASSUME(N > 0); - CMT_ASSUME(N4 > 0); - DFT_ASSERT(width <= N4); - CMT_LOOP_NOUNROLL for (size_t b = 0; b < blocks; b++) - { - CMT_PRAGMA_CLANG(clang loop unroll_count(2)) - for (size_t n2 = 0; n2 < N4; n2 += width) - { - if (prefetch) - prefetch_four(N4, in + prefetch_offset); - radix4_body(N, csize_t<width>(), cbool_t<(splitout || splitin)>(), cbool_t<splitout>(), - cbool_t<splitin>(), cbool_t<use_br2>(), cbool_t<inverse>(), cbool_t<aligned>(), out, - in, twiddle + n2 * 3); - in += width; - out += width; - } - in += N43; - out += N43; - } - twiddle += N43; - return {}; -} - -template <bool splitin, bool prefetch, bool use_br2, bool inverse, bool aligned, typename T> -KFR_SINTRIN ctrue_t radix4_pass(csize_t<32>, size_t blocks, csize_t<8>, cfalse_t, cbool_t<splitin>, - cbool_t<use_br2>, cbool_t<prefetch>, cbool_t<inverse>, cbool_t<aligned>, - complex<T>* out, const complex<T>*, const complex<T>*& /*twiddle*/) -{ - CMT_ASSUME(blocks > 0); - constexpr static size_t prefetch_offset = 32 * 4; - for (size_t b = 0; b < blocks; b++) - { - if (prefetch) - prefetch_four(csize_t<64>(), out + prefetch_offset); - cvec<T, 4> w0, w1, w2, w3, w4, w5, w6, w7; - split(cread_split<8, aligned, splitin>(out + 0), w0, w1); - split(cread_split<8, aligned, splitin>(out + 8), w2, w3); - split(cread_split<8, aligned, splitin>(out + 16), w4, w5); - split(cread_split<8, aligned, splitin>(out + 24), w6, w7); - - butterfly8<4, inverse>(w0, w1, w2, w3, w4, w5, w6, w7); - - w1 = cmul(w1, fixed_twiddle<T, 4, 32, 0, 1, inverse>()); - w2 = cmul(w2, fixed_twiddle<T, 4, 32, 0, 2, inverse>()); - w3 = cmul(w3, fixed_twiddle<T, 4, 32, 0, 3, inverse>()); - w4 = cmul(w4, fixed_twiddle<T, 4, 32, 0, 4, inverse>()); - w5 = cmul(w5, fixed_twiddle<T, 4, 32, 0, 5, inverse>()); - w6 = cmul(w6, fixed_twiddle<T, 4, 32, 0, 6, inverse>()); - w7 = cmul(w7, fixed_twiddle<T, 4, 32, 0, 7, inverse>()); - - cvec<T, 8> z0, z1, z2, z3; - transpose4x8(w0, w1, w2, w3, w4, w5, w6, w7, z0, z1, z2, z3); - - butterfly4<8, inverse>(cfalse, z0, z1, z2, z3, z0, z1, z2, z3); - cwrite<32, aligned>(out, bitreverse<2>(concat(z0, z1, z2, z3))); - out += 32; - } - return {}; -} - -template <size_t width, bool prefetch, bool use_br2, bool inverse, bool aligned, typename T> -KFR_SINTRIN ctrue_t radix4_pass(csize_t<8>, size_t blocks, csize_t<width>, cfalse_t, cfalse_t, - cbool_t<use_br2>, cbool_t<prefetch>, cbool_t<inverse>, cbool_t<aligned>, - complex<T>* out, const complex<T>*, const complex<T>*& /*twiddle*/) -{ - CMT_ASSUME(blocks > 0); - DFT_ASSERT(2 <= blocks); - constexpr static size_t prefetch_offset = width * 16; - for (size_t b = 0; b < blocks; b += 2) - { - if (prefetch) - prefetch_one(out + prefetch_offset); - - cvec<T, 8> vlo = cread<8, aligned>(out + 0); - cvec<T, 8> vhi = cread<8, aligned>(out + 8); - butterfly8<inverse>(vlo); - butterfly8<inverse>(vhi); - vlo = permutegroups<(2), 0, 4, 2, 6, 1, 5, 3, 7>(vlo); - vhi = permutegroups<(2), 0, 4, 2, 6, 1, 5, 3, 7>(vhi); - cwrite<8, aligned>(out, vlo); - cwrite<8, aligned>(out + 8, vhi); - out += 16; - } - return {}; -} - -template <size_t width, bool prefetch, bool use_br2, bool inverse, bool aligned, typename T> -KFR_SINTRIN ctrue_t radix4_pass(csize_t<16>, size_t blocks, csize_t<width>, cfalse_t, cfalse_t, - cbool_t<use_br2>, cbool_t<prefetch>, cbool_t<inverse>, cbool_t<aligned>, - complex<T>* out, const complex<T>*, const complex<T>*& /*twiddle*/) -{ - CMT_ASSUME(blocks > 0); - constexpr static size_t prefetch_offset = width * 4; - DFT_ASSERT(2 <= blocks); - CMT_PRAGMA_CLANG(clang loop unroll_count(2)) - for (size_t b = 0; b < blocks; b += 2) - { - if (prefetch) - prefetch_one(out + prefetch_offset); - - cvec<T, 16> vlo = cread<16, aligned>(out); - cvec<T, 16> vhi = cread<16, aligned>(out + 16); - butterfly4<4, inverse>(vlo); - butterfly4<4, inverse>(vhi); - apply_twiddles4<0, 4, 4, inverse>(vlo); - apply_twiddles4<0, 4, 4, inverse>(vhi); - vlo = digitreverse4<2>(vlo); - vhi = digitreverse4<2>(vhi); - butterfly4<4, inverse>(vlo); - butterfly4<4, inverse>(vhi); - - use_br2 ? cbitreverse_write(out, vlo) : cdigitreverse4_write(out, vlo); - use_br2 ? cbitreverse_write(out + 16, vhi) : cdigitreverse4_write(out + 16, vhi); - out += 32; - } - return {}; -} - -template <size_t width, bool prefetch, bool use_br2, bool inverse, bool aligned, typename T> -KFR_SINTRIN ctrue_t radix4_pass(csize_t<4>, size_t blocks, csize_t<width>, cfalse_t, cfalse_t, - cbool_t<use_br2>, cbool_t<prefetch>, cbool_t<inverse>, cbool_t<aligned>, - complex<T>* out, const complex<T>*, const complex<T>*& /*twiddle*/) -{ - constexpr static size_t prefetch_offset = width * 4; - CMT_ASSUME(blocks > 0); - DFT_ASSERT(4 <= blocks); - CMT_LOOP_NOUNROLL - for (size_t b = 0; b < blocks; b += 4) - { - if (prefetch) - prefetch_one(out + prefetch_offset); - - cvec<T, 16> v16 = cdigitreverse4_read<16, aligned>(out); - butterfly4<4, inverse>(v16); - cdigitreverse4_write<aligned>(out, v16); - - out += 4 * 4; - } - return {}; -} - -template <typename T> -static void dft_stage_fixed_initialize(dft_stage<T>* stage, size_t width) -{ - complex<T>* twiddle = ptr_cast<complex<T>>(stage->data); - const size_t N = stage->repeats * stage->radix; - const size_t Nord = stage->repeats; - size_t i = 0; - - while (width > 0) - { - CMT_LOOP_NOUNROLL - for (; i < Nord / width * width; i += width) - { - CMT_LOOP_NOUNROLL - for (size_t j = 1; j < stage->radix; j++) - { - CMT_LOOP_NOUNROLL - for (size_t k = 0; k < width; k++) - { - cvec<T, 1> xx = cossin_conj(broadcast<2, T>(c_pi<T, 2> * (i + k) * j / N)); - ref_cast<cvec<T, 1>>(twiddle[k]) = xx; - } - twiddle += width; - } - } - width = width / 2; - } -} - -template <typename T, size_t radix> -struct dft_stage_fixed_impl : dft_stage<T> -{ - dft_stage_fixed_impl(size_t radix_, size_t iterations, size_t blocks) - { - this->name = type_name<decltype(*this)>(); - this->radix = radix; - this->blocks = blocks; - this->repeats = iterations; - this->recursion = false; // true; - this->data_size = - align_up((this->repeats * (radix - 1)) * sizeof(complex<T>), platform<>::native_cache_alignment); - } - - constexpr static size_t width = - radix >= 7 ? fft_vector_width<T> / 2 : radix >= 4 ? fft_vector_width<T> : fft_vector_width<T> * 2; - virtual void do_initialize(size_t size) override final { dft_stage_fixed_initialize(this, width); } - - DFT_STAGE_FN - template <bool inverse> - KFR_INTRIN void do_execute(complex<T>* out, const complex<T>* in, u8*) - { - const size_t Nord = this->repeats; - const complex<T>* twiddle = ptr_cast<complex<T>>(this->data); - - const size_t N = Nord * this->radix; - CMT_LOOP_NOUNROLL - for (size_t b = 0; b < this->blocks; b++) - { - butterflies(Nord, csize<width>, csize<radix>, cbool<inverse>, out, in, twiddle, Nord); - in += N; - out += N; - } - } -}; - -template <typename T, size_t radix> -struct dft_stage_fixed_final_impl : dft_stage<T> -{ - dft_stage_fixed_final_impl(size_t radix_, size_t iterations, size_t blocks) - { - this->name = type_name<decltype(*this)>(); - this->radix = radix; - this->blocks = blocks; - this->repeats = iterations; - this->recursion = false; - this->can_inplace = false; - } - constexpr static size_t width = - radix >= 7 ? fft_vector_width<T> / 2 : radix >= 4 ? fft_vector_width<T> : fft_vector_width<T> * 2; - - DFT_STAGE_FN - template <bool inverse> - KFR_INTRIN void do_execute(complex<T>* out, const complex<T>* in, u8*) - { - const size_t b = this->blocks; - const size_t size = b * radix; - - butterflies(b, csize<width>, csize<radix>, cbool<inverse>, out, in, b); - } -}; - -template <typename E> -inline E& apply_conj(E& e, cfalse_t) -{ - return e; -} - -template <typename E> -inline auto apply_conj(E& e, ctrue_t) -{ - return cconj(e); -} - -/// [0, N - 1, N - 2, N - 3, ..., 3, 2, 1] -template <typename E> -struct fft_inverse : expression_base<E> -{ - using value_type = value_type_of<E>; - - CMT_INLINE fft_inverse(E&& expr) noexcept : expression_base<E>(std::forward<E>(expr)) {} - - CMT_INLINE vec<value_type, 1> operator()(cinput_t input, size_t index, vec_t<value_type, 1>) const - { - return this->argument_first(input, index == 0 ? 0 : this->size() - index, vec_t<value_type, 1>()); - } - - template <size_t N> - CMT_INLINE vec<value_type, N> operator()(cinput_t input, size_t index, vec_t<value_type, N>) const - { - if (index == 0) - { - return concat( - this->argument_first(input, index, vec_t<value_type, 1>()), - reverse(this->argument_first(input, this->size() - (N - 1), vec_t<value_type, N - 1>()))); - } - return reverse(this->argument_first(input, this->size() - index - (N - 1), vec_t<value_type, N>())); - } -}; - -template <typename E> -inline auto apply_fft_inverse(E&& e) -{ - return fft_inverse<E>(std::forward<E>(e)); -} - -template <typename T> -struct dft_arblen_stage_impl : dft_stage<T> -{ - dft_arblen_stage_impl(size_t size) - : fftsize(next_poweroftwo(size) * 2), plan(fftsize, dft_order::internal), size(size) - { - this->name = type_name<decltype(*this)>(); - this->radix = size; - this->blocks = 1; - this->repeats = 1; - this->recursion = false; - this->can_inplace = false; - this->temp_size = plan.temp_size; - - chirp_ = render(cexp(sqr(linspace(T(1) - size, size - T(1), size * 2 - 1, true, true)) * - complex<T>(0, -1) * c_pi<T> / size)); - - ichirpp_ = render(truncate(padded(1 / slice(chirp_, 0, 2 * size - 1)), fftsize)); - - univector<u8> temp(plan.temp_size); - plan.execute(ichirpp_, ichirpp_, temp); - xp.resize(fftsize, 0); - xp_fft.resize(fftsize); - invN2 = T(1) / fftsize; - } - - DFT_STAGE_FN - template <bool inverse> - KFR_INTRIN void do_execute(complex<T>* out, const complex<T>* in, u8* temp) - { - const size_t n = this->size; - const size_t N2 = this->fftsize; - - auto&& chirp = apply_conj(chirp_, cbool<inverse>); - - xp.slice(0, n) = make_univector(in, n) * slice(chirp, n - 1); - - plan.execute(xp_fft.data(), xp.data(), temp); - - if (inverse) - xp_fft = xp_fft * cconj(apply_fft_inverse(ichirpp_)); - else - xp_fft = xp_fft * ichirpp_; - plan.execute(xp_fft.data(), xp_fft.data(), temp, ctrue); - - make_univector(out, n) = xp_fft.slice(n - 1) * slice(chirp, n - 1) * invN2; - } - - const size_t size; - const size_t fftsize; - T invN2; - dft_plan<T> plan; - univector<complex<T>> chirp_; - univector<complex<T>> ichirpp_; - univector<complex<T>> xp; - univector<complex<T>> xp_fft; -}; - -template <typename T, size_t radix1, size_t radix2, size_t size = radix1* radix2> -struct dft_special_stage_impl : dft_stage<T> -{ - dft_special_stage_impl() : stage1(radix1, size / radix1, 1), stage2(radix2, 1, size / radix2) - { - this->name = type_name<decltype(*this)>(); - this->radix = size; - this->blocks = 1; - this->repeats = 1; - this->recursion = false; - this->can_inplace = false; - this->temp_size = stage1.temp_size + stage2.temp_size + sizeof(complex<T>) * size; - this->data_size = stage1.data_size + stage2.data_size; - } - void dump() const override - { - dft_stage<T>::dump(); - printf(" "); - stage1.dump(); - printf(" "); - stage2.dump(); - } - void do_initialize(size_t stage_size) override - { - stage1.data = this->data; - stage2.data = this->data + stage1.data_size; - stage1.initialize(stage_size); - stage2.initialize(stage_size); - } - DFT_STAGE_FN - template <bool inverse> - KFR_INTRIN void do_execute(complex<T>* out, const complex<T>* in, u8* temp) - { - complex<T>* scratch = ptr_cast<complex<T>>(temp + stage1.temp_size + stage2.temp_size); - stage1.do_execute(cbool<inverse>, scratch, in, temp); - stage2.do_execute(cbool<inverse>, out, scratch, temp + stage1.temp_size); - } - dft_stage_fixed_impl<T, radix1> stage1; - dft_stage_fixed_final_impl<T, radix2> stage2; -}; - -template <typename T, bool final> -struct dft_stage_generic_impl : dft_stage<T> -{ - dft_stage_generic_impl(size_t radix, size_t iterations, size_t blocks) - { - this->name = type_name<decltype(*this)>(); - this->radix = radix; - this->blocks = blocks; - this->repeats = iterations; - this->recursion = false; // true; - this->can_inplace = false; - this->temp_size = align_up(sizeof(complex<T>) * radix, platform<>::native_cache_alignment); - this->data_size = - align_up(sizeof(complex<T>) * sqr(this->radix / 2), platform<>::native_cache_alignment); - } - -protected: - virtual void do_initialize(size_t size) override final - { - complex<T>* twiddle = ptr_cast<complex<T>>(this->data); - CMT_LOOP_NOUNROLL - for (size_t i = 0; i < this->radix / 2; i++) - { - CMT_LOOP_NOUNROLL - for (size_t j = 0; j < this->radix / 2; j++) - { - cwrite<1>(twiddle++, cossin_conj(broadcast<2>((i + 1) * (j + 1) * c_pi<T, 2> / this->radix))); - } - } - } - - DFT_STAGE_FN - template <bool inverse> - KFR_INTRIN void do_execute(complex<T>* out, const complex<T>* in, u8* temp) - { - const complex<T>* twiddle = ptr_cast<complex<T>>(this->data); - const size_t bl = this->blocks; - const size_t Nord = this->repeats; - const size_t N = Nord * this->radix; - - CMT_LOOP_NOUNROLL - for (size_t b = 0; b < bl; b++) - generic_butterfly(this->radix, cbool<inverse>, out + b, in + b * this->radix, - ptr_cast<complex<T>>(temp), twiddle, bl); - } -}; - -template <typename T, typename Tr2> -inline void dft_permute(complex<T>* out, const complex<T>* in, size_t r0, size_t r1, Tr2 first_radix) -{ - CMT_ASSUME(r0 > 1); - CMT_ASSUME(r1 > 1); - - CMT_LOOP_NOUNROLL - for (size_t p = 0; p < r0; p++) - { - const complex<T>* in1 = in; - CMT_LOOP_NOUNROLL - for (size_t i = 0; i < r1; i++) - { - const complex<T>* in2 = in1; - CMT_LOOP_UNROLL - for (size_t j = 0; j < first_radix; j++) - { - *out++ = *in2; - in2 += r1; - } - in1++; - in += first_radix; - } - } -} - -template <typename T, typename Tr2> -inline void dft_permute_deep(complex<T>*& out, const complex<T>* in, const size_t* radices, size_t count, - size_t index, size_t inscale, size_t inner_size, Tr2 first_radix) -{ - const bool b = index == 1; - const size_t radix = radices[index]; - if (b) - { - CMT_LOOP_NOUNROLL - for (size_t i = 0; i < radix; i++) - { - const complex<T>* in1 = in; - CMT_LOOP_UNROLL - for (size_t j = 0; j < first_radix; j++) - { - *out++ = *in1; - in1 += inner_size; - } - in += inscale; - } - } - else - { - const size_t steps = radix; - const size_t inscale_next = inscale * radix; - CMT_LOOP_NOUNROLL - for (size_t i = 0; i < steps; i++) - { - dft_permute_deep(out, in, radices, count, index - 1, inscale_next, inner_size, first_radix); - in += inscale; - } - } -} - -template <typename T> -struct dft_reorder_stage_impl : dft_stage<T> -{ - dft_reorder_stage_impl(const int* radices, size_t count) : count(count) - { - this->name = type_name<decltype(*this)>(); - this->can_inplace = false; - this->data_size = 0; - std::copy(radices, radices + count, this->radices); - this->inner_size = 1; - this->size = 1; - for (size_t r = 0; r < count; r++) - { - if (r != 0 && r != count - 1) - this->inner_size *= radices[r]; - this->size *= radices[r]; - } - } - -protected: - size_t radices[32]; - size_t count = 0; - size_t size = 0; - size_t inner_size = 0; - virtual void do_initialize(size_t) override final {} - - DFT_STAGE_FN - template <bool inverse> - KFR_INTRIN void do_execute(complex<T>* out, const complex<T>* in, u8*) - { - cswitch(dft_radices, radices[0], - [&](auto first_radix) { - if (count == 3) - { - dft_permute(out, in, radices[2], radices[1], first_radix); - } - else - { - const size_t rlast = radices[count - 1]; - for (size_t p = 0; p < rlast; p++) - { - dft_permute_deep(out, in, radices, count, count - 2, 1, inner_size, first_radix); - in += size / rlast; - } - } - }, - [&]() { - if (count == 3) - { - dft_permute(out, in, radices[2], radices[1], radices[0]); - } - else - { - const size_t rlast = radices[count - 1]; - for (size_t p = 0; p < rlast; p++) - { - dft_permute_deep(out, in, radices, count, count - 2, 1, inner_size, radices[0]); - in += size / rlast; - } - } - }); - } -}; - -template <typename T, bool splitin, bool is_even> -struct fft_stage_impl : dft_stage<T> -{ - fft_stage_impl(size_t stage_size) - { - this->name = type_name<decltype(*this)>(); - this->radix = 4; - this->stage_size = stage_size; - this->repeats = 4; - this->recursion = true; - this->data_size = - align_up(sizeof(complex<T>) * stage_size / 4 * 3, platform<>::native_cache_alignment); - } - -protected: - constexpr static bool prefetch = true; - constexpr static bool aligned = false; - constexpr static size_t width = fft_vector_width<T>; - - virtual void do_initialize(size_t size) override final - { - complex<T>* twiddle = ptr_cast<complex<T>>(this->data); - initialize_twiddles<T, width>(twiddle, this->stage_size, size, true); - } - - DFT_STAGE_FN - template <bool inverse> - KFR_INTRIN void do_execute(complex<T>* out, const complex<T>* in, u8*) - { - const complex<T>* twiddle = ptr_cast<complex<T>>(this->data); - if (splitin) - in = out; - const size_t stg_size = this->stage_size; - CMT_ASSUME(stg_size >= 2048); - CMT_ASSUME(stg_size % 2048 == 0); - radix4_pass(stg_size, 1, csize_t<width>(), ctrue, cbool_t<splitin>(), cbool_t<!is_even>(), - cbool_t<prefetch>(), cbool_t<inverse>(), cbool_t<aligned>(), out, in, twiddle); - } -}; - -template <typename T, bool splitin, size_t size> -struct fft_final_stage_impl : dft_stage<T> -{ - fft_final_stage_impl(size_t) - { - this->name = type_name<decltype(*this)>(); - this->radix = size; - this->stage_size = size; - this->out_offset = size; - this->repeats = 4; - this->recursion = true; - this->data_size = align_up(sizeof(complex<T>) * size * 3 / 2, platform<>::native_cache_alignment); - } - -protected: - constexpr static size_t width = fft_vector_width<T>; - constexpr static bool is_even = cometa::is_even(ilog2(size)); - constexpr static bool use_br2 = !is_even; - constexpr static bool aligned = false; - constexpr static bool prefetch = splitin; - - KFR_INTRIN void init_twiddles(csize_t<8>, size_t, cfalse_t, complex<T>*&) {} - KFR_INTRIN void init_twiddles(csize_t<4>, size_t, cfalse_t, complex<T>*&) {} - - template <size_t N, bool pass_splitin> - KFR_INTRIN void init_twiddles(csize_t<N>, size_t total_size, cbool_t<pass_splitin>, complex<T>*& twiddle) - { - constexpr bool pass_split = N / 4 > 8 && N / 4 / 4 >= width; - constexpr size_t pass_width = const_min(width, N / 4); - initialize_twiddles<T, pass_width>(twiddle, N, total_size, pass_split || pass_splitin); - init_twiddles(csize<N / 4>, total_size, cbool<pass_split>, twiddle); - } - - virtual void do_initialize(size_t total_size) override final - { - complex<T>* twiddle = ptr_cast<complex<T>>(this->data); - init_twiddles(csize<size>, total_size, cbool<splitin>, twiddle); - } - - DFT_STAGE_FN - template <bool inverse> - KFR_INTRIN void do_execute(complex<T>* out, const complex<T>* in, u8*) - { - const complex<T>* twiddle = ptr_cast<complex<T>>(this->data); - final_stage<inverse>(csize<size>, 1, cbool<splitin>, out, in, twiddle); - } - - template <bool inverse, typename U = T, KFR_ENABLE_IF(is_same<U, float>::value)> - KFR_INTRIN void final_stage(csize_t<32>, size_t invN, cfalse_t, complex<T>* out, const complex<T>*, - const complex<T>*& twiddle) - { - radix4_pass(csize_t<32>(), invN, csize_t<width>(), cfalse, cfalse, cbool_t<use_br2>(), - cbool_t<prefetch>(), cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle); - } - - template <bool inverse, typename U = T, KFR_ENABLE_IF(is_same<U, float>::value)> - KFR_INTRIN void final_stage(csize_t<16>, size_t invN, cfalse_t, complex<T>* out, const complex<T>*, - const complex<T>*& twiddle) - { - radix4_pass(csize_t<16>(), invN, csize_t<width>(), cfalse, cfalse, cbool_t<use_br2>(), - cbool_t<prefetch>(), cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle); - } - - template <bool inverse> - KFR_INTRIN void final_stage(csize_t<8>, size_t invN, cfalse_t, complex<T>* out, const complex<T>*, - const complex<T>*& twiddle) - { - radix4_pass(csize_t<8>(), invN, csize_t<width>(), cfalse, cfalse, cbool_t<use_br2>(), - cbool_t<prefetch>(), cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle); - } - - template <bool inverse> - KFR_INTRIN void final_stage(csize_t<4>, size_t invN, cfalse_t, complex<T>* out, const complex<T>*, - const complex<T>*& twiddle) - { - radix4_pass(csize_t<4>(), invN, csize_t<width>(), cfalse, cfalse, cbool_t<use_br2>(), - cbool_t<prefetch>(), cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle); - } - - template <bool inverse, size_t N, bool pass_splitin> - KFR_INTRIN void final_stage(csize_t<N>, size_t invN, cbool_t<pass_splitin>, complex<T>* out, - const complex<T>* in, const complex<T>*& twiddle) - { - static_assert(N > 8, ""); - constexpr bool pass_split = N / 4 > 8 && N / 4 / 4 >= width; - constexpr size_t pass_width = const_min(width, N / 4); - static_assert(pass_width == width || (pass_split == pass_splitin), ""); - static_assert(pass_width <= N / 4, ""); - radix4_pass(N, invN, csize_t<pass_width>(), cbool<pass_split>, cbool_t<pass_splitin>(), - cbool_t<use_br2>(), cbool_t<prefetch>(), cbool_t<inverse>(), cbool_t<aligned>(), out, in, - twiddle); - final_stage<inverse>(csize<N / 4>, invN * 4, cbool<pass_split>, out, out, twiddle); - } -}; - -template <typename T, bool is_even> -struct fft_reorder_stage_impl : dft_stage<T> -{ - fft_reorder_stage_impl(size_t stage_size) - { - this->name = type_name<decltype(*this)>(); - this->stage_size = stage_size; - log2n = ilog2(stage_size); - this->data_size = 0; - } - -protected: - size_t log2n; - - virtual void do_initialize(size_t) override final {} - - DFT_STAGE_FN - template <bool inverse> - KFR_INTRIN void do_execute(complex<T>* out, const complex<T>* in, u8*) - { - fft_reorder(out, log2n, cbool_t<!is_even>()); - } -}; - -template <typename T, size_t log2n> -struct fft_specialization; - -template <typename T> -struct fft_specialization<T, 1> : dft_stage<T> -{ - fft_specialization(size_t) { this->name = type_name<decltype(*this)>(); } - -protected: - constexpr static bool aligned = false; - DFT_STAGE_FN - - template <bool inverse> - KFR_INTRIN void do_execute(complex<T>* out, const complex<T>* in, u8*) - { - cvec<T, 1> a0, a1; - split(cread<2, aligned>(in), a0, a1); - cwrite<2, aligned>(out, concat(a0 + a1, a0 - a1)); - } -}; - -template <typename T> -struct fft_specialization<T, 2> : dft_stage<T> -{ - fft_specialization(size_t) { this->name = type_name<decltype(*this)>(); } - -protected: - constexpr static bool aligned = false; - DFT_STAGE_FN - template <bool inverse> - KFR_INTRIN void do_execute(complex<T>* out, const complex<T>* in, u8*) - { - cvec<T, 1> a0, a1, a2, a3; - split(cread<4>(in), a0, a1, a2, a3); - butterfly(cbool_t<inverse>(), a0, a1, a2, a3, a0, a1, a2, a3); - cwrite<4>(out, concat(a0, a1, a2, a3)); - } -}; - -template <typename T> -struct fft_specialization<T, 3> : dft_stage<T> -{ - fft_specialization(size_t) { this->name = type_name<decltype(*this)>(); } - -protected: - constexpr static bool aligned = false; - DFT_STAGE_FN - template <bool inverse> - KFR_INTRIN void do_execute(complex<T>* out, const complex<T>* in, u8*) - { - cvec<T, 8> v8 = cread<8, aligned>(in); - butterfly8<inverse>(v8); - cwrite<8, aligned>(out, v8); - } -}; - -template <typename T> -struct fft_specialization<T, 4> : dft_stage<T> -{ - fft_specialization(size_t) { this->name = type_name<decltype(*this)>(); } - -protected: - constexpr static bool aligned = false; - DFT_STAGE_FN - template <bool inverse> - KFR_INTRIN void do_execute(complex<T>* out, const complex<T>* in, u8*) - { - cvec<T, 16> v16 = cread<16, aligned>(in); - butterfly16<inverse>(v16); - cwrite<16, aligned>(out, v16); - } -}; - -template <typename T> -struct fft_specialization<T, 5> : dft_stage<T> -{ - fft_specialization(size_t) { this->name = type_name<decltype(*this)>(); } - -protected: - constexpr static bool aligned = false; - DFT_STAGE_FN - template <bool inverse> - KFR_INTRIN void do_execute(complex<T>* out, const complex<T>* in, u8*) - { - cvec<T, 32> v32 = cread<32, aligned>(in); - butterfly32<inverse>(v32); - cwrite<32, aligned>(out, v32); - } -}; - -template <typename T> -struct fft_specialization<T, 6> : dft_stage<T> -{ - fft_specialization(size_t) { this->name = type_name<decltype(*this)>(); } - -protected: - constexpr static bool aligned = false; - DFT_STAGE_FN - template <bool inverse> - KFR_INTRIN void do_execute(complex<T>* out, const complex<T>* in, u8*) - { - butterfly64(cbool_t<inverse>(), cbool_t<aligned>(), out, in); - } -}; - -template <typename T> -struct fft_specialization<T, 7> : dft_stage<T> -{ - fft_specialization(size_t) - { - this->name = type_name<decltype(*this)>(); - this->stage_size = 128; - this->data_size = align_up(sizeof(complex<T>) * 128 * 3 / 2, platform<>::native_cache_alignment); - } - -protected: - constexpr static bool aligned = false; - constexpr static size_t width = platform<T>::vector_width; - constexpr static bool use_br2 = true; - constexpr static bool prefetch = false; - constexpr static bool is_double = sizeof(T) == 8; - constexpr static size_t final_size = is_double ? 8 : 32; - constexpr static size_t split_format = final_size == 8; - - virtual void do_initialize(size_t total_size) override final - { - complex<T>* twiddle = ptr_cast<complex<T>>(this->data); - initialize_twiddles<T, width>(twiddle, 128, total_size, split_format); - initialize_twiddles<T, width>(twiddle, 32, total_size, split_format); - initialize_twiddles<T, width>(twiddle, 8, total_size, split_format); - } - - DFT_STAGE_FN - template <bool inverse> - KFR_INTRIN void do_execute(complex<T>* out, const complex<T>* in, u8*) - { - const complex<T>* twiddle = ptr_cast<complex<T>>(this->data); - final_pass<inverse>(csize_t<final_size>(), out, in, twiddle); - if (this->need_reorder) - fft_reorder(out, csize_t<7>()); - } - - template <bool inverse> - KFR_INTRIN void final_pass(csize_t<8>, complex<T>* out, const complex<T>* in, const complex<T>* twiddle) - { - radix4_pass(128, 1, csize_t<width>(), ctrue, cfalse, cbool_t<use_br2>(), cbool_t<prefetch>(), - cbool_t<inverse>(), cbool_t<aligned>(), out, in, twiddle); - radix4_pass(32, 4, csize_t<width>(), cfalse, ctrue, cbool_t<use_br2>(), cbool_t<prefetch>(), - cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle); - radix4_pass(csize_t<8>(), 16, csize_t<width>(), cfalse, cfalse, cbool_t<use_br2>(), - cbool_t<prefetch>(), cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle); - } - - template <bool inverse> - KFR_INTRIN void final_pass(csize_t<32>, complex<T>* out, const complex<T>* in, const complex<T>* twiddle) - { - radix4_pass(128, 1, csize_t<width>(), cfalse, cfalse, cbool_t<use_br2>(), cbool_t<prefetch>(), - cbool_t<inverse>(), cbool_t<aligned>(), out, in, twiddle); - radix4_pass(csize_t<32>(), 4, csize_t<width>(), cfalse, cfalse, cbool_t<use_br2>(), - cbool_t<prefetch>(), cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle); - } -}; - -template <> -struct fft_specialization<float, 8> : dft_stage<float> -{ - fft_specialization(size_t) - { - this->name = type_name<decltype(*this)>(); - this->temp_size = sizeof(complex<float>) * 256; - } - -protected: - using T = float; - DFT_STAGE_FN - template <bool inverse> - KFR_INTRIN void do_execute(complex<T>* out, const complex<T>* in, u8* temp) - { - complex<float>* scratch = ptr_cast<complex<float>>(temp); - if (out == in) - { - butterfly16_multi_flip<0, inverse>(scratch, out); - butterfly16_multi_flip<1, inverse>(scratch, out); - butterfly16_multi_flip<2, inverse>(scratch, out); - butterfly16_multi_flip<3, inverse>(scratch, out); - - butterfly16_multi_natural<0, inverse>(out, scratch); - butterfly16_multi_natural<1, inverse>(out, scratch); - butterfly16_multi_natural<2, inverse>(out, scratch); - butterfly16_multi_natural<3, inverse>(out, scratch); - } - else - { - butterfly16_multi_flip<0, inverse>(out, in); - butterfly16_multi_flip<1, inverse>(out, in); - butterfly16_multi_flip<2, inverse>(out, in); - butterfly16_multi_flip<3, inverse>(out, in); - - butterfly16_multi_natural<0, inverse>(out, out); - butterfly16_multi_natural<1, inverse>(out, out); - butterfly16_multi_natural<2, inverse>(out, out); - butterfly16_multi_natural<3, inverse>(out, out); - } - } -}; - -template <> -struct fft_specialization<double, 8> : fft_final_stage_impl<double, false, 256> -{ - using T = double; - fft_specialization(size_t stage_size) : fft_final_stage_impl<double, false, 256>(stage_size) - { - this->name = type_name<decltype(*this)>(); - } - - DFT_STAGE_FN - template <bool inverse> - KFR_INTRIN void do_execute(complex<T>* out, const complex<T>* in, u8*) - { - fft_final_stage_impl<double, false, 256>::template do_execute<inverse>(out, in, nullptr); - if (this->need_reorder) - fft_reorder(out, csize_t<8>()); - } -}; - -template <typename T> -struct fft_specialization<T, 9> : fft_final_stage_impl<T, false, 512> -{ - fft_specialization(size_t stage_size) : fft_final_stage_impl<T, false, 512>(stage_size) - { - this->name = type_name<decltype(*this)>(); - } - - DFT_STAGE_FN - template <bool inverse> - KFR_INTRIN void do_execute(complex<T>* out, const complex<T>* in, u8*) - { - fft_final_stage_impl<T, false, 512>::template do_execute<inverse>(out, in, nullptr); - if (this->need_reorder) - fft_reorder(out, csize_t<9>()); - } -}; - -template <typename T> -struct fft_specialization<T, 10> : fft_final_stage_impl<T, false, 1024> -{ - fft_specialization(size_t stage_size) : fft_final_stage_impl<T, false, 1024>(stage_size) - { - this->name = type_name<decltype(*this)>(); - } - - DFT_STAGE_FN - template <bool inverse> - KFR_INTRIN void do_execute(complex<T>* out, const complex<T>* in, u8*) - { - fft_final_stage_impl<T, false, 1024>::template do_execute<inverse>(out, in, nullptr); - if (this->need_reorder) - fft_reorder(out, 10, cfalse); - } -}; - -} // namespace internal - -// - -template <typename T> -template <typename Stage, typename... Args> -void dft_plan<T>::add_stage(Args... args) -{ - dft_stage<T>* stage = new Stage(args...); - stage->need_reorder = need_reorder; - this->data_size += stage->data_size; - this->temp_size += stage->temp_size; - stages.push_back(dft_stage_ptr(stage)); -} - -template <typename T> -template <bool is_final> -void dft_plan<T>::prepare_dft_stage(size_t radix, size_t iterations, size_t blocks, cbool_t<is_final>) -{ - return cswitch( - dft_radices, radix, - [&](auto radix) CMT_INLINE_LAMBDA { - add_stage<conditional<is_final, internal::dft_stage_fixed_final_impl<T, val_of(radix)>, - internal::dft_stage_fixed_impl<T, val_of(radix)>>>(radix, iterations, - blocks); - }, - [&]() { add_stage<internal::dft_stage_generic_impl<T, is_final>>(radix, iterations, blocks); }); -} - -template <typename T> -template <bool is_even, bool first> -void dft_plan<T>::make_fft(size_t stage_size, cbool_t<is_even>, cbool_t<first>) -{ - constexpr size_t final_size = is_even ? 1024 : 512; - - if (stage_size >= 2048) - { - add_stage<internal::fft_stage_impl<T, !first, is_even>>(stage_size); - - make_fft(stage_size / 4, cbool_t<is_even>(), cfalse); - } - else - { - add_stage<internal::fft_final_stage_impl<T, !first, final_size>>(final_size); - } -} - -template <typename T> -struct reverse_wrapper -{ - T& iterable; -}; - -template <typename T> -auto begin(reverse_wrapper<T> w) -{ - return std::rbegin(w.iterable); -} - -template <typename T> -auto end(reverse_wrapper<T> w) -{ - return std::rend(w.iterable); -} - -template <typename T> -reverse_wrapper<T> reversed(T&& iterable) -{ - return { iterable }; -} - -template <typename T> -void dft_plan<T>::initialize() -{ - data = autofree<u8>(data_size); - size_t offset = 0; - for (dft_stage_ptr& stage : stages) - { - stage->data = data.data() + offset; - stage->initialize(this->size); - offset += stage->data_size; - } - - bool to_scratch = false; - bool scratch_needed = false; - for (dft_stage_ptr& stage : reversed(stages)) - { - if (to_scratch) - { - scratch_needed = true; - } - stage->to_scratch = to_scratch; - if (!stage->can_inplace) - { - to_scratch = !to_scratch; - } - } - if (scratch_needed || !stages[0]->can_inplace) - this->temp_size += align_up(sizeof(complex<T>) * this->size, platform<>::native_cache_alignment); -} - -template <typename T> -const complex<T>* dft_plan<T>::select_in(size_t stage, const complex<T>* out, const complex<T>* in, - const complex<T>* scratch, bool in_scratch) const -{ - if (stage == 0) - return in_scratch ? scratch : in; - return stages[stage - 1]->to_scratch ? scratch : out; -} - -template <typename T> -complex<T>* dft_plan<T>::select_out(size_t stage, complex<T>* out, complex<T>* scratch) const -{ - return stages[stage]->to_scratch ? scratch : out; -} - -template <typename T> -template <bool inverse> -void dft_plan<T>::execute_dft(cbool_t<inverse>, complex<T>* out, const complex<T>* in, u8* temp) const -{ - if (stages.size() == 1 && (stages[0]->can_inplace || in != out)) - { - return stages[0]->execute(cbool<inverse>, out, in, temp); - } - size_t stack[32] = { 0 }; - - complex<T>* scratch = - ptr_cast<complex<T>>(temp + this->temp_size - - align_up(sizeof(complex<T>) * this->size, platform<>::native_cache_alignment)); - - bool in_scratch = !stages[0]->can_inplace && in == out; - if (in_scratch) - { - internal::builtin_memcpy(scratch, in, sizeof(complex<T>) * this->size); - } - - const size_t count = stages.size(); - - for (size_t depth = 0; depth < count;) - { - if (stages[depth]->recursion) - { - size_t offset = 0; - size_t rdepth = depth; - size_t maxdepth = depth; - do - { - if (stack[rdepth] == stages[rdepth]->repeats) - { - stack[rdepth] = 0; - rdepth--; - } - else - { - complex<T>* rout = select_out(rdepth, out, scratch); - const complex<T>* rin = select_in(rdepth, out, in, scratch, in_scratch); - stages[rdepth]->execute(cbool<inverse>, rout + offset, rin + offset, temp); - offset += stages[rdepth]->out_offset; - stack[rdepth]++; - if (rdepth < count - 1 && stages[rdepth + 1]->recursion) - rdepth++; - else - maxdepth = rdepth; - } - } while (rdepth != depth); - depth = maxdepth + 1; - } - else - { - stages[depth]->execute(cbool<inverse>, select_out(depth, out, scratch), - select_in(depth, out, in, scratch, in_scratch), temp); - depth++; - } - } -} - -template <typename T> -dft_plan<T>::dft_plan(size_t size, dft_order order) : size(size), temp_size(0), data_size(0) -{ - need_reorder = true; - if (is_poweroftwo(size)) - { - const size_t log2n = ilog2(size); - cswitch(csizes_t<1, 2, 3, 4, 5, 6, 7, 8, 9, 10>(), log2n, - [&](auto log2n) { - (void)log2n; - constexpr size_t log2nv = val_of(decltype(log2n)()); - this->add_stage<internal::fft_specialization<T, log2nv>>(size); - }, - [&]() { - cswitch(cfalse_true, is_even(log2n), [&](auto is_even) { - this->make_fft(size, is_even, ctrue); - constexpr size_t is_evenv = val_of(decltype(is_even)()); - if (need_reorder) - this->add_stage<internal::fft_reorder_stage_impl<T, is_evenv>>(size); - }); - }); - } -#ifndef KFR_DFT_NO_NPo2 - else - { - if (size == 60) - { - this->add_stage<internal::dft_special_stage_impl<T, 6, 10>>(); - } - else if (size == 48) - { - this->add_stage<internal::dft_special_stage_impl<T, 6, 8>>(); - } - else - { - size_t cur_size = size; - constexpr size_t radices_count = dft_radices.back() + 1; - u8 count[radices_count] = { 0 }; - int radices[32] = { 0 }; - size_t radices_size = 0; - - cforeach(dft_radices[csizeseq<dft_radices.size(), dft_radices.size() - 1, -1>], [&](auto radix) { - while (cur_size && cur_size % val_of(radix) == 0) - { - count[val_of(radix)]++; - cur_size /= val_of(radix); - } - }); - - if (cur_size >= 101) - { - this->add_stage<internal::dft_arblen_stage_impl<T>>(size); - } - else - { - size_t blocks = 1; - size_t iterations = size; - - for (size_t r = dft_radices.front(); r <= dft_radices.back(); r++) - { - for (size_t i = 0; i < count[r]; i++) - { - iterations /= r; - radices[radices_size++] = r; - if (iterations == 1) - this->prepare_dft_stage(r, iterations, blocks, ctrue); - else - this->prepare_dft_stage(r, iterations, blocks, cfalse); - blocks *= r; - } - } - - if (cur_size > 1) - { - iterations /= cur_size; - radices[radices_size++] = cur_size; - if (iterations == 1) - this->prepare_dft_stage(cur_size, iterations, blocks, ctrue); - else - this->prepare_dft_stage(cur_size, iterations, blocks, cfalse); - } - - if (stages.size() > 2) - this->add_stage<internal::dft_reorder_stage_impl<T>>(radices, radices_size); - } - } - } -#endif - initialize(); -} - -template <typename T> -dft_plan_real<T>::dft_plan_real(size_t size) : dft_plan<T>(size / 2), size(size), rtwiddle(size / 4) -{ - using namespace internal; - - constexpr size_t width = platform<T>::vector_width * 2; - - block_process(size / 4, csizes_t<width, 1>(), [=](size_t i, auto w) { - constexpr size_t width = val_of(decltype(w)()); - cwrite<width>(rtwiddle.data() + i, - cossin(dup(-constants<T>::pi * ((enumerate<T, width>() + i + size / 4) / (size / 2))))); - }); -} - -template <typename T> -void dft_plan_real<T>::to_fmt(complex<T>* out, dft_pack_format fmt) const -{ - using namespace internal; - size_t csize = this->size / 2; // const size_t causes internal compiler error: in tsubst_copy in GCC 5.2 - - constexpr size_t width = platform<T>::vector_width * 2; - const cvec<T, 1> dc = cread<1>(out); - const size_t count = csize / 2; - - block_process(count - 1, csizes_t<width, 1>(), [&](size_t i, auto w) { - i++; - constexpr size_t width = val_of(decltype(w)()); - constexpr size_t widthm1 = width - 1; - const cvec<T, width> tw = cread<width>(rtwiddle.data() + i); - const cvec<T, width> fpk = cread<width>(out + i); - const cvec<T, width> fpnk = reverse<2>(negodd(cread<width>(out + csize - i - widthm1))); - - const cvec<T, width> f1k = fpk + fpnk; - const cvec<T, width> f2k = fpk - fpnk; - const cvec<T, width> t = cmul(f2k, tw); - cwrite<width>(out + i, T(0.5) * (f1k + t)); - cwrite<width>(out + csize - i - widthm1, reverse<2>(negodd(T(0.5) * (f1k - t)))); - }); - - { - size_t k = csize / 2; - const cvec<T, 1> fpk = cread<1>(out + k); - const cvec<T, 1> fpnk = negodd(fpk); - cwrite<1>(out + k, fpnk); - } - if (fmt == dft_pack_format::CCs) - { - cwrite<1>(out, pack(dc[0] + dc[1], 0)); - cwrite<1>(out + csize, pack(dc[0] - dc[1], 0)); - } - else - { - cwrite<1>(out, pack(dc[0] + dc[1], dc[0] - dc[1])); - } -} - -template <typename T> -void dft_plan_real<T>::from_fmt(complex<T>* out, const complex<T>* in, dft_pack_format fmt) const -{ - using namespace internal; - - const size_t csize = this->size / 2; - - cvec<T, 1> dc; - - if (fmt == dft_pack_format::CCs) - { - dc = pack(in[0].real() + in[csize].real(), in[0].real() - in[csize].real()); - } - else - { - dc = pack(in[0].real() + in[0].imag(), in[0].real() - in[0].imag()); - } - - constexpr size_t width = platform<T>::vector_width * 2; - const size_t count = csize / 2; - - block_process(count - 1, csizes_t<width, 1>(), [&](size_t i, auto w) { - i++; - constexpr size_t width = val_of(decltype(w)()); - constexpr size_t widthm1 = width - 1; - const cvec<T, width> tw = cread<width>(rtwiddle.data() + i); - const cvec<T, width> fpk = cread<width>(in + i); - const cvec<T, width> fpnk = reverse<2>(negodd(cread<width>(in + csize - i - widthm1))); - - const cvec<T, width> f1k = fpk + fpnk; - const cvec<T, width> f2k = fpk - fpnk; - const cvec<T, width> t = cmul_conj(f2k, tw); - cwrite<width>(out + i, f1k + t); - cwrite<width>(out + csize - i - widthm1, reverse<2>(negodd(f1k - t))); - }); - - { - size_t k = csize / 2; - const cvec<T, 1> fpk = cread<1>(in + k); - const cvec<T, 1> fpnk = 2 * negodd(fpk); - cwrite<1>(out + k, fpnk); - } - cwrite<1>(out, dc); -} - -template <typename T> -dft_plan<T>::~dft_plan() -{ -} - -template <typename T> -void dft_plan<T>::dump() const -{ - for (const dft_stage_ptr& s : stages) - { - s->dump(); - } -} - -namespace internal -{ - -template <typename T> -univector<T> convolve(const univector_ref<const T>& src1, const univector_ref<const T>& src2) -{ - const size_t size = next_poweroftwo(src1.size() + src2.size() - 1); - univector<complex<T>> src1padded = src1; - univector<complex<T>> src2padded = src2; - src1padded.resize(size, 0); - src2padded.resize(size, 0); - - dft_plan_ptr<T> dft = dft_cache::instance().get(ctype_t<T>(), size); - univector<u8> temp(dft->temp_size); - dft->execute(src1padded, src1padded, temp); - dft->execute(src2padded, src2padded, temp); - src1padded = src1padded * src2padded; - dft->execute(src1padded, src1padded, temp, true); - const T invsize = reciprocal<T>(size); - return truncate(real(src1padded), src1.size() + src2.size() - 1) * invsize; -} - -template <typename T> -univector<T> correlate(const univector_ref<const T>& src1, const univector_ref<const T>& src2) -{ - const size_t size = next_poweroftwo(src1.size() + src2.size() - 1); - univector<complex<T>> src1padded = src1; - univector<complex<T>> src2padded = reverse(src2); - src1padded.resize(size, 0); - src2padded.resize(size, 0); - dft_plan_ptr<T> dft = dft_cache::instance().get(ctype_t<T>(), size); - univector<u8> temp(dft->temp_size); - dft->execute(src1padded, src1padded, temp); - dft->execute(src2padded, src2padded, temp); - src1padded = src1padded * src2padded; - dft->execute(src1padded, src1padded, temp, true); - const T invsize = reciprocal<T>(size); - return truncate(real(src1padded), src1.size() + src2.size() - 1) * invsize; -} - -template <typename T> -univector<T> autocorrelate(const univector_ref<const T>& src1) -{ - return internal::autocorrelate(src1.slice()); - univector<T> result = correlate(src1, src1); - result = result.slice(result.size() / 2); - return result; -} - -template univector<float> convolve<float>(const univector_ref<const float>&, - const univector_ref<const float>&); -template univector<double> convolve<double>(const univector_ref<const double>&, - const univector_ref<const double>&); -template univector<float> correlate<float>(const univector_ref<const float>&, - const univector_ref<const float>&); -template univector<double> correlate<double>(const univector_ref<const double>&, - const univector_ref<const double>&); - -template univector<float> autocorrelate<float>(const univector_ref<const float>&); -template univector<double> autocorrelate<double>(const univector_ref<const double>&); - -} // namespace internal - -template <typename T> -convolve_filter<T>::convolve_filter(size_t size, size_t block_size) - : fft(2 * next_poweroftwo(block_size)), size(size), block_size(block_size), temp(fft.temp_size), - segments((size + block_size - 1) / block_size) -{ -} - -template <typename T> -convolve_filter<T>::convolve_filter(const univector<T>& data, size_t block_size) - : fft(2 * next_poweroftwo(block_size)), size(data.size()), block_size(next_poweroftwo(block_size)), - temp(fft.temp_size), - segments((data.size() + next_poweroftwo(block_size) - 1) / next_poweroftwo(block_size)), - ir_segments((data.size() + next_poweroftwo(block_size) - 1) / next_poweroftwo(block_size)), - input_position(0), position(0) -{ - set_data(data); -} - -template <typename T> -void convolve_filter<T>::set_data(const univector<T>& data) -{ - univector<T> input(fft.size); - const T ifftsize = reciprocal(T(fft.size)); - for (size_t i = 0; i < ir_segments.size(); i++) - { - segments[i].resize(block_size); - ir_segments[i].resize(block_size, 0); - input = padded(data.slice(i * block_size, block_size)); - - fft.execute(ir_segments[i], input, temp, dft_pack_format::Perm); - process(ir_segments[i], ir_segments[i] * ifftsize); - } - saved_input.resize(block_size, 0); - scratch.resize(block_size * 2); - premul.resize(block_size, 0); - cscratch.resize(block_size); - overlap.resize(block_size, 0); -} - -template <typename T> -void convolve_filter<T>::process_buffer(T* output, const T* input, size_t size) -{ - size_t processed = 0; - while (processed < size) - { - const size_t processing = std::min(size - processed, block_size - input_position); - internal::builtin_memcpy(saved_input.data() + input_position, input + processed, - processing * sizeof(T)); - - process(scratch, padded(saved_input)); - fft.execute(segments[position], scratch, temp, dft_pack_format::Perm); - - if (input_position == 0) - { - process(premul, zeros()); - for (size_t i = 1; i < segments.size(); i++) - { - const size_t n = (position + i) % segments.size(); - fft_multiply_accumulate(premul, ir_segments[i], segments[n], dft_pack_format::Perm); - } - } - fft_multiply_accumulate(cscratch, premul, ir_segments[0], segments[position], dft_pack_format::Perm); - - fft.execute(scratch, cscratch, temp, dft_pack_format::Perm); - - process(make_univector(output + processed, processing), - scratch.slice(input_position) + overlap.slice(input_position)); - - input_position += processing; - if (input_position == block_size) - { - input_position = 0; - process(saved_input, zeros()); - - internal::builtin_memcpy(overlap.data(), scratch.data() + block_size, block_size * sizeof(T)); - - position = position > 0 ? position - 1 : segments.size() - 1; - } - - processed += processing; - } -} - -template convolve_filter<float>::convolve_filter(size_t, size_t); -template convolve_filter<double>::convolve_filter(size_t, size_t); - -template convolve_filter<float>::convolve_filter(const univector<float>&, size_t); -template convolve_filter<double>::convolve_filter(const univector<double>&, size_t); - -template void convolve_filter<float>::set_data(const univector<float>&); -template void convolve_filter<double>::set_data(const univector<double>&); - -template void convolve_filter<float>::process_buffer(float* output, const float* input, size_t size); -template void convolve_filter<double>::process_buffer(double* output, const double* input, size_t size); - -template dft_plan<float>::dft_plan(size_t, dft_order); -template dft_plan<float>::~dft_plan(); -template void dft_plan<float>::dump() const; -template void dft_plan<float>::execute_dft(cometa::cbool_t<false>, kfr::complex<float>* out, - const kfr::complex<float>* in, kfr::u8* temp) const; -template void dft_plan<float>::execute_dft(cometa::cbool_t<true>, kfr::complex<float>* out, - const kfr::complex<float>* in, kfr::u8* temp) const; -template dft_plan_real<float>::dft_plan_real(size_t); -template void dft_plan_real<float>::from_fmt(kfr::complex<float>* out, const kfr::complex<float>* in, - kfr::dft_pack_format fmt) const; -template void dft_plan_real<float>::to_fmt(kfr::complex<float>* out, kfr::dft_pack_format fmt) const; - -template dft_plan<double>::dft_plan(size_t, dft_order); -template dft_plan<double>::~dft_plan(); -template void dft_plan<double>::dump() const; -template void dft_plan<double>::execute_dft(cometa::cbool_t<false>, kfr::complex<double>* out, - const kfr::complex<double>* in, kfr::u8* temp) const; -template void dft_plan<double>::execute_dft(cometa::cbool_t<true>, kfr::complex<double>* out, - const kfr::complex<double>* in, kfr::u8* temp) const; -template dft_plan_real<double>::dft_plan_real(size_t); -template void dft_plan_real<double>::from_fmt(kfr::complex<double>* out, const kfr::complex<double>* in, - kfr::dft_pack_format fmt) const; -template void dft_plan_real<double>::to_fmt(kfr::complex<double>* out, kfr::dft_pack_format fmt) const; - -} // namespace kfr - -extern "C" -{ - - KFR_DFT_PLAN_F32* kfr_dft_create_plan_f32(size_t size) - { - return reinterpret_cast<KFR_DFT_PLAN_F32*>(new kfr::dft_plan<float>(size)); - } - KFR_DFT_PLAN_F64* kfr_dft_create_plan_f64(size_t size) - { - return reinterpret_cast<KFR_DFT_PLAN_F64*>(new kfr::dft_plan<double>(size)); - } - - void kfr_dft_execute_f32(KFR_DFT_PLAN_F32* plan, size_t size, float* out, const float* in, uint8_t* temp) - { - reinterpret_cast<kfr::dft_plan<float>*>(plan)->execute( - reinterpret_cast<kfr::complex<float>*>(out), reinterpret_cast<const kfr::complex<float>*>(in), - temp, kfr::cfalse); - } - void kfr_dft_execute_f64(KFR_DFT_PLAN_F64* plan, size_t size, double* out, const double* in, - uint8_t* temp) - { - reinterpret_cast<kfr::dft_plan<double>*>(plan)->execute( - reinterpret_cast<kfr::complex<double>*>(out), reinterpret_cast<const kfr::complex<double>*>(in), - temp, kfr::cfalse); - } - void kfr_dft_execute_inverse_f32(KFR_DFT_PLAN_F32* plan, size_t size, float* out, const float* in, - uint8_t* temp) - { - reinterpret_cast<kfr::dft_plan<float>*>(plan)->execute( - reinterpret_cast<kfr::complex<float>*>(out), reinterpret_cast<const kfr::complex<float>*>(in), - temp, kfr::ctrue); - } - void kfr_dft_execute_inverse_f64(KFR_DFT_PLAN_F64* plan, size_t size, double* out, const double* in, - uint8_t* temp) - { - reinterpret_cast<kfr::dft_plan<double>*>(plan)->execute( - reinterpret_cast<kfr::complex<double>*>(out), reinterpret_cast<const kfr::complex<double>*>(in), - temp, kfr::ctrue); - } - - void kfr_dft_delete_plan_f32(KFR_DFT_PLAN_F32* plan) - { - delete reinterpret_cast<kfr::dft_plan<float>*>(plan); - } - void kfr_dft_delete_plan_f64(KFR_DFT_PLAN_F64* plan) - { - delete reinterpret_cast<kfr::dft_plan<double>*>(plan); - } - - // Real DFT plans - - KFR_DFT_REAL_PLAN_F32* kfr_dft_create_real_plan_f32(size_t size) - { - return reinterpret_cast<KFR_DFT_REAL_PLAN_F32*>(new kfr::dft_plan_real<float>(size)); - } - KFR_DFT_REAL_PLAN_F64* kfr_dft_create_real_plan_f64(size_t size) - { - return reinterpret_cast<KFR_DFT_REAL_PLAN_F64*>(new kfr::dft_plan_real<double>(size)); - } - - void kfr_dft_execute_real_f32(KFR_DFT_REAL_PLAN_F32* plan, size_t size, float* out, const float* in, - uint8_t* temp, KFR_DFT_PACK_FORMAT pack_format) - { - reinterpret_cast<kfr::dft_plan_real<float>*>(plan)->execute( - reinterpret_cast<kfr::complex<float>*>(out), in, temp, - static_cast<kfr::dft_pack_format>(pack_format)); - } - void kfr_dft_execute_real_f64(KFR_DFT_REAL_PLAN_F64* plan, size_t size, double* out, const double* in, - uint8_t* temp, KFR_DFT_PACK_FORMAT pack_format) - { - reinterpret_cast<kfr::dft_plan_real<double>*>(plan)->execute( - reinterpret_cast<kfr::complex<double>*>(out), in, temp, - static_cast<kfr::dft_pack_format>(pack_format)); - } - void kfr_dft_execute_real_inverse_f32(KFR_DFT_REAL_PLAN_F32* plan, size_t size, float* out, - const float* in, uint8_t* temp, KFR_DFT_PACK_FORMAT pack_format) - { - reinterpret_cast<kfr::dft_plan_real<float>*>(plan)->execute( - out, reinterpret_cast<const kfr::complex<float>*>(in), temp, - static_cast<kfr::dft_pack_format>(pack_format)); - } - void kfr_dft_execute_real_inverse__f64(KFR_DFT_REAL_PLAN_F64* plan, size_t size, double* out, - const double* in, uint8_t* temp, KFR_DFT_PACK_FORMAT pack_format) - { - reinterpret_cast<kfr::dft_plan_real<double>*>(plan)->execute( - out, reinterpret_cast<const kfr::complex<double>*>(in), temp, - static_cast<kfr::dft_pack_format>(pack_format)); - } - - void kfr_dft_delete_real_plan_f32(KFR_DFT_REAL_PLAN_F32* plan) - { - delete reinterpret_cast<kfr::dft_plan_real<float>*>(plan); - } - void kfr_dft_delete_real_plan_f64(KFR_DFT_REAL_PLAN_F64* plan) - { - delete reinterpret_cast<kfr::dft_plan_real<double>*>(plan); - } -} - -CMT_PRAGMA_GNU(GCC diagnostic pop) - -CMT_PRAGMA_MSVC(warning(pop)) diff --git a/include/kfr/dft/ft.hpp b/include/kfr/dft/ft.hpp @@ -1,1760 +0,0 @@ -/** @addtogroup dft - * @{ - */ -/* - Copyright (C) 2016 D Levin (https://www.kfrlib.com) - This file is part of KFR - - KFR is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - KFR is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with KFR. - - If GPL is not suitable for your project, you must purchase a commercial license to use KFR. - Buying a commercial license is mandatory as soon as you develop commercial activities without - disclosing the source code of your own applications. - See https://www.kfrlib.com for details. - */ -#pragma once - -#include "../base/complex.hpp" -#include "../base/constants.hpp" -#include "../base/digitreverse.hpp" -#include "../base/read_write.hpp" -#include "../base/sin_cos.hpp" -#include "../base/small_buffer.hpp" -#include "../base/univector.hpp" -#include "../base/vec.hpp" - -#include "../base/memory.hpp" -#include "../data/sincos.hpp" - -CMT_PRAGMA_MSVC(warning(push)) -CMT_PRAGMA_MSVC(warning(disable : 4127)) - -namespace kfr -{ - -namespace internal -{ - -template <typename T, size_t N, KFR_ENABLE_IF(N >= 2)> -CMT_INLINE vec<T, N> cmul_impl(const vec<T, N>& x, const vec<T, N>& y) -{ - return subadd(x * dupeven(y), swap<2>(x) * dupodd(y)); -} -template <typename T, size_t N, KFR_ENABLE_IF(N > 2)> -CMT_INLINE vec<T, N> cmul_impl(const vec<T, N>& x, const vec<T, 2>& y) -{ - vec<T, N> yy = resize<N>(y); - return cmul_impl(x, yy); -} -template <typename T, size_t N, KFR_ENABLE_IF(N > 2)> -CMT_INLINE vec<T, N> cmul_impl(const vec<T, 2>& x, const vec<T, N>& y) -{ - vec<T, N> xx = resize<N>(x); - return cmul_impl(xx, y); -} - -/// Complex Multiplication -template <typename T, size_t N1, size_t N2> -CMT_INLINE vec<T, const_max(N1, N2)> cmul(const vec<T, N1>& x, const vec<T, N2>& y) -{ - return internal::cmul_impl(x, y); -} - -template <typename T, size_t N, KFR_ENABLE_IF(N >= 2)> -CMT_INLINE vec<T, N> cmul_conj(const vec<T, N>& x, const vec<T, N>& y) -{ - return swap<2>(subadd(swap<2>(x) * dupeven(y), x * dupodd(y))); -} -template <typename T, size_t N, KFR_ENABLE_IF(N >= 2)> -CMT_INLINE vec<T, N> cmul_2conj(const vec<T, N>& in0, const vec<T, N>& in1, const vec<T, N>& tw) -{ - return (in0 + in1) * dupeven(tw) + swap<2>(cnegimag(in0 - in1)) * dupodd(tw); -} -template <typename T, size_t N, KFR_ENABLE_IF(N >= 2)> -CMT_INLINE void cmul_2conj(vec<T, N>& out0, vec<T, N>& out1, const vec<T, 2>& in0, const vec<T, 2>& in1, - const vec<T, N>& tw) -{ - const vec<T, N> twr = dupeven(tw); - const vec<T, N> twi = dupodd(tw); - const vec<T, 2> sum = (in0 + in1); - const vec<T, 2> dif = swap<2>(negodd(in0 - in1)); - const vec<T, N> sumtw = resize<N>(sum) * twr; - const vec<T, N> diftw = resize<N>(dif) * twi; - out0 += sumtw + diftw; - out1 += sumtw - diftw; -} -template <typename T, size_t N, KFR_ENABLE_IF(N > 2)> -CMT_INLINE vec<T, N> cmul_conj(const vec<T, N>& x, const vec<T, 2>& y) -{ - vec<T, N> yy = resize<N>(y); - return cmul_conj(x, yy); -} -template <typename T, size_t N, KFR_ENABLE_IF(N > 2)> -CMT_INLINE vec<T, N> cmul_conj(const vec<T, 2>& x, const vec<T, N>& y) -{ - vec<T, N> xx = resize<N>(x); - return cmul_conj(xx, y); -} - -template <typename T, size_t N> -using cvec = vec<T, N * 2>; - -template <size_t N, bool A = false, typename T> -CMT_INLINE cvec<T, N> cread(const complex<T>* src) -{ - return cvec<T, N>(ptr_cast<T>(src), cbool_t<A>()); -} - -template <size_t N, bool A = false, typename T> -CMT_INLINE void cwrite(complex<T>* dest, const cvec<T, N>& value) -{ - value.write(ptr_cast<T>(dest)); -} - -template <size_t count, size_t N, size_t stride, bool A, typename T, size_t... indices> -CMT_INLINE cvec<T, count * N> cread_group_impl(const complex<T>* src, csizes_t<indices...>) -{ - return concat(read<N * 2, A>(ptr_cast<T>(src + stride * indices))...); -} -template <size_t count, size_t N, size_t stride, bool A, typename T, size_t... indices> -CMT_INLINE void cwrite_group_impl(complex<T>* dest, const cvec<T, count * N>& value, csizes_t<indices...>) -{ - swallow{ (write<A>(ptr_cast<T>(dest + stride * indices), slice<indices * N * 2, N * 2>(value)), 0)... }; -} - -template <size_t count, size_t N, bool A, typename T, size_t... indices> -CMT_INLINE cvec<T, count * N> cread_group_impl(const complex<T>* src, size_t stride, csizes_t<indices...>) -{ - return concat(read<N * 2, A>(ptr_cast<T>(src + stride * indices))...); -} -template <size_t count, size_t N, bool A, typename T, size_t... indices> -CMT_INLINE void cwrite_group_impl(complex<T>* dest, size_t stride, const cvec<T, count * N>& value, - csizes_t<indices...>) -{ - swallow{ (write<A>(ptr_cast<T>(dest + stride * indices), slice<indices * N * 2, N * 2>(value)), 0)... }; -} - -template <size_t count, size_t N, size_t stride, bool A = false, typename T> -CMT_INLINE cvec<T, count * N> cread_group(const complex<T>* src) -{ - return cread_group_impl<count, N, stride, A>(src, csizeseq_t<count>()); -} - -template <size_t count, size_t N, size_t stride, bool A = false, typename T> -CMT_INLINE void cwrite_group(complex<T>* dest, const cvec<T, count * N>& value) -{ - return cwrite_group_impl<count, N, stride, A>(dest, value, csizeseq_t<count>()); -} - -template <size_t count, size_t N, bool A = false, typename T> -CMT_INLINE cvec<T, count * N> cread_group(const complex<T>* src, size_t stride) -{ - return cread_group_impl<count, N, A>(src, stride, csizeseq_t<count>()); -} - -template <size_t count, size_t N, bool A = false, typename T> -CMT_INLINE void cwrite_group(complex<T>* dest, size_t stride, const cvec<T, count * N>& value) -{ - return cwrite_group_impl<count, N, A>(dest, stride, value, csizeseq_t<count>()); -} - -template <size_t N, bool A = false, bool split = false, typename T> -CMT_INLINE cvec<T, N> cread_split(const complex<T>* src) -{ - cvec<T, N> temp = cvec<T, N>(ptr_cast<T>(src), cbool_t<A>()); - if (split) - temp = splitpairs(temp); - return temp; -} - -template <size_t N, bool A = false, bool split = false, typename T> -CMT_INLINE void cwrite_split(complex<T>* dest, const cvec<T, N>& value) -{ - cvec<T, N> v = value; - if (split) - v = interleavehalfs(v); - v.write(ptr_cast<T>(dest), cbool_t<A>()); -} - -template <> -inline cvec<f32, 8> cread_split<8, false, true, f32>(const complex<f32>* src) -{ - const cvec<f32, 4> l = concat(cread<2>(src), cread<2>(src + 4)); - const cvec<f32, 4> h = concat(cread<2>(src + 2), cread<2>(src + 6)); - - return concat(shuffle<0, 2, 8 + 0, 8 + 2>(l, h), shuffle<1, 3, 8 + 1, 8 + 3>(l, h)); -} -template <> -inline cvec<f32, 8> cread_split<8, true, true, f32>(const complex<f32>* src) -{ - const cvec<f32, 4> l = concat(cread<2, true>(src), cread<2, true>(src + 4)); - const cvec<f32, 4> h = concat(cread<2, true>(src + 2), cread<2, true>(src + 6)); - - return concat(shuffle<0, 2, 8 + 0, 8 + 2>(l, h), shuffle<1, 3, 8 + 1, 8 + 3>(l, h)); -} - -template <> -inline cvec<f64, 4> cread_split<4, false, true, f64>(const complex<f64>* src) -{ - const cvec<f64, 2> l = concat(cread<1>(src), cread<1>(src + 2)); - const cvec<f64, 2> h = concat(cread<1>(src + 1), cread<1>(src + 3)); - - return concat(shuffle<0, 4, 2, 6>(l, h), shuffle<1, 5, 3, 7>(l, h)); -} - -template <> -inline void cwrite_split<8, false, true, f32>(complex<f32>* dest, const cvec<f32, 8>& x) -{ - const cvec<f32, 8> xx = - concat(shuffle<0, 8 + 0, 1, 8 + 1>(low(x), high(x)), shuffle<2, 8 + 2, 3, 8 + 3>(low(x), high(x))); - - cvec<f32, 2> a, b, c, d; - split(xx, a, b, c, d); - cwrite<2>(dest, a); - cwrite<2>(dest + 4, b); - cwrite<2>(dest + 2, c); - cwrite<2>(dest + 6, d); -} -template <> -inline void cwrite_split<8, true, true, f32>(complex<f32>* dest, const cvec<f32, 8>& x) -{ - const cvec<f32, 8> xx = - concat(shuffle<0, 8 + 0, 1, 8 + 1>(low(x), high(x)), shuffle<2, 8 + 2, 3, 8 + 3>(low(x), high(x))); - - cvec<f32, 2> a, b, c, d; - split(xx, a, b, c, d); - cwrite<2, true>(dest + 0, a); - cwrite<2, true>(dest + 4, b); - cwrite<2, true>(dest + 2, c); - cwrite<2, true>(dest + 6, d); -} - -template <> -inline void cwrite_split<4, false, true, f64>(complex<f64>* dest, const cvec<f64, 4>& x) -{ - const cvec<f64, 4> xx = - concat(shuffle<0, 4, 2, 6>(low(x), high(x)), shuffle<1, 5, 3, 7>(low(x), high(x))); - cwrite<1>(dest, part<4, 0>(xx)); - cwrite<1>(dest + 2, part<4, 1>(xx)); - cwrite<1>(dest + 1, part<4, 2>(xx)); - cwrite<1>(dest + 3, part<4, 3>(xx)); -} -template <> -inline void cwrite_split<4, true, true, f64>(complex<f64>* dest, const cvec<f64, 4>& x) -{ - const cvec<f64, 4> xx = - concat(shuffle<0, 4, 2, 6>(low(x), high(x)), shuffle<1, 5, 3, 7>(low(x), high(x))); - cwrite<1, true>(dest + 0, part<4, 0>(xx)); - cwrite<1, true>(dest + 2, part<4, 1>(xx)); - cwrite<1, true>(dest + 1, part<4, 2>(xx)); - cwrite<1, true>(dest + 3, part<4, 3>(xx)); -} - -template <size_t N, size_t stride, typename T, size_t... Indices> -CMT_INLINE cvec<T, N> cgather_helper(const complex<T>* base, csizes_t<Indices...>) -{ - return concat(ref_cast<cvec<T, 1>>(base[Indices * stride])...); -} - -template <size_t N, size_t stride, typename T> -CMT_INLINE cvec<T, N> cgather(const complex<T>* base) -{ - if (stride == 1) - { - return ref_cast<cvec<T, N>>(*base); - } - else - return cgather_helper<N, stride, T>(base, csizeseq_t<N>()); -} - -CMT_INLINE size_t cgather_next(size_t& index, size_t stride, size_t size, size_t) -{ - size_t temp = index; - index += stride; - if (index >= size) - index -= size; - return temp; -} -CMT_INLINE size_t cgather_next(size_t& index, size_t stride, size_t) -{ - size_t temp = index; - index += stride; - return temp; -} - -template <size_t N, typename T, size_t... Indices> -CMT_INLINE cvec<T, N> cgather_helper(const complex<T>* base, size_t& index, size_t stride, - csizes_t<Indices...>) -{ - return concat(ref_cast<cvec<T, 1>>(base[cgather_next(index, stride, Indices)])...); -} - -template <size_t N, typename T> -CMT_INLINE cvec<T, N> cgather(const complex<T>* base, size_t& index, size_t stride) -{ - return cgather_helper<N, T>(base, index, stride, csizeseq_t<N>()); -} -template <size_t N, typename T> -CMT_INLINE cvec<T, N> cgather(const complex<T>* base, size_t stride) -{ - size_t index = 0; - return cgather_helper<N, T>(base, index, stride, csizeseq_t<N>()); -} - -template <size_t N, typename T, size_t... Indices> -CMT_INLINE cvec<T, N> cgather_helper(const complex<T>* base, size_t& index, size_t stride, size_t size, - csizes_t<Indices...>) -{ - return concat(ref_cast<cvec<T, 1>>(base[cgather_next(index, stride, size, Indices)])...); -} - -template <size_t N, typename T> -CMT_INLINE cvec<T, N> cgather(const complex<T>* base, size_t& index, size_t stride, size_t size) -{ - return cgather_helper<N, T>(base, index, stride, size, csizeseq_t<N>()); -} - -template <size_t N, size_t stride, typename T, size_t... Indices> -CMT_INLINE void cscatter_helper(complex<T>* base, const cvec<T, N>& value, csizes_t<Indices...>) -{ - swallow{ (cwrite<1>(base + Indices * stride, slice<Indices * 2, 2>(value)), 0)... }; -} - -template <size_t N, size_t stride, typename T> -CMT_INLINE void cscatter(complex<T>* base, const cvec<T, N>& value) -{ - if (stride == 1) - { - cwrite<N>(base, value); - } - else - { - return cscatter_helper<N, stride, T>(base, value, csizeseq_t<N>()); - } -} - -template <size_t N, typename T, size_t... Indices> -CMT_INLINE void cscatter_helper(complex<T>* base, size_t stride, const cvec<T, N>& value, - csizes_t<Indices...>) -{ - swallow{ (cwrite<1>(base + Indices * stride, slice<Indices * 2, 2>(value)), 0)... }; -} - -template <size_t N, typename T> -CMT_INLINE void cscatter(complex<T>* base, size_t stride, const cvec<T, N>& value) -{ - return cscatter_helper<N, T>(base, stride, value, csizeseq_t<N>()); -} - -template <size_t groupsize = 1, typename T, size_t N, typename IT> -CMT_INLINE vec<T, N * 2 * groupsize> cgather(const complex<T>* base, const vec<IT, N>& offset) -{ - return gather_helper<2 * groupsize>(ptr_cast<T>(base), offset, csizeseq_t<N>()); -} - -template <size_t groupsize = 1, typename T, size_t N, typename IT> -CMT_INLINE void cscatter(complex<T>* base, const vec<IT, N>& offset, vec<T, N * 2 * groupsize> value) -{ - return scatter_helper<2 * groupsize>(ptr_cast<T>(base), offset, value, csizeseq_t<N>()); -} - -template <typename T> -KFR_INTRIN void transpose4x8(const cvec<T, 8>& z0, const cvec<T, 8>& z1, const cvec<T, 8>& z2, - const cvec<T, 8>& z3, cvec<T, 4>& w0, cvec<T, 4>& w1, cvec<T, 4>& w2, - cvec<T, 4>& w3, cvec<T, 4>& w4, cvec<T, 4>& w5, cvec<T, 4>& w6, cvec<T, 4>& w7) -{ - cvec<T, 16> a = concat(low(z0), low(z1), low(z2), low(z3)); - cvec<T, 16> b = concat(high(z0), high(z1), high(z2), high(z3)); - a = digitreverse4<2>(a); - b = digitreverse4<2>(b); - w0 = part<4, 0>(a); - w1 = part<4, 1>(a); - w2 = part<4, 2>(a); - w3 = part<4, 3>(a); - w4 = part<4, 0>(b); - w5 = part<4, 1>(b); - w6 = part<4, 2>(b); - w7 = part<4, 3>(b); -} - -template <typename T> -KFR_INTRIN void transpose4x8(const cvec<T, 4>& w0, const cvec<T, 4>& w1, const cvec<T, 4>& w2, - const cvec<T, 4>& w3, const cvec<T, 4>& w4, const cvec<T, 4>& w5, - const cvec<T, 4>& w6, const cvec<T, 4>& w7, cvec<T, 8>& z0, cvec<T, 8>& z1, - cvec<T, 8>& z2, cvec<T, 8>& z3) -{ - cvec<T, 16> a = concat(w0, w1, w2, w3); - cvec<T, 16> b = concat(w4, w5, w6, w7); - a = digitreverse4<2>(a); - b = digitreverse4<2>(b); - z0 = concat(part<4, 0>(a), part<4, 0>(b)); - z1 = concat(part<4, 1>(a), part<4, 1>(b)); - z2 = concat(part<4, 2>(a), part<4, 2>(b)); - z3 = concat(part<4, 3>(a), part<4, 3>(b)); -} - -template <typename T> -void transpose4(cvec<T, 16>& a, cvec<T, 16>& b, cvec<T, 16>& c, cvec<T, 16>& d) -{ - cvec<T, 4> a0, a1, a2, a3; - cvec<T, 4> b0, b1, b2, b3; - cvec<T, 4> c0, c1, c2, c3; - cvec<T, 4> d0, d1, d2, d3; - - split(a, a0, a1, a2, a3); - split(b, b0, b1, b2, b3); - split(c, c0, c1, c2, c3); - split(d, d0, d1, d2, d3); - - a = concat(a0, b0, c0, d0); - b = concat(a1, b1, c1, d1); - c = concat(a2, b2, c2, d2); - d = concat(a3, b3, c3, d3); -} -template <typename T> -void transpose4(cvec<T, 16>& a, cvec<T, 16>& b, cvec<T, 16>& c, cvec<T, 16>& d, cvec<T, 16>& aa, - cvec<T, 16>& bb, cvec<T, 16>& cc, cvec<T, 16>& dd) -{ - cvec<T, 4> a0, a1, a2, a3; - cvec<T, 4> b0, b1, b2, b3; - cvec<T, 4> c0, c1, c2, c3; - cvec<T, 4> d0, d1, d2, d3; - - split(a, a0, a1, a2, a3); - split(b, b0, b1, b2, b3); - split(c, c0, c1, c2, c3); - split(d, d0, d1, d2, d3); - - aa = concat(a0, b0, c0, d0); - bb = concat(a1, b1, c1, d1); - cc = concat(a2, b2, c2, d2); - dd = concat(a3, b3, c3, d3); -} - -template <bool b, typename T> -constexpr KFR_INTRIN T chsign(T x) -{ - return b ? -x : x; -} - -template <typename T, size_t N, size_t size, size_t start, size_t step, bool inverse = false, - size_t... indices> -constexpr KFR_INTRIN cvec<T, N> get_fixed_twiddle_helper(csizes_t<indices...>) -{ - return make_vector((indices & 1 ? chsign<inverse>(-sin_using_table<T>(size, (indices / 2 * step + start))) - : cos_using_table<T>(size, (indices / 2 * step + start)))...); -} - -template <typename T, size_t width, size_t... indices> -constexpr KFR_INTRIN cvec<T, width> get_fixed_twiddle_helper(csizes_t<indices...>, size_t size, size_t start, - size_t step) -{ - return make_vector((indices & 1 ? -sin_using_table<T>(size, indices / 2 * step + start) - : cos_using_table<T>(size, indices / 2 * step + start))...); -} - -template <typename T, size_t width, size_t size, size_t start, size_t step = 0, bool inverse = false> -constexpr KFR_INTRIN cvec<T, width> fixed_twiddle() -{ - return get_fixed_twiddle_helper<T, width, size, start, step, inverse>(csizeseq_t<width * 2>()); -} - -template <typename T, size_t width> -constexpr KFR_INTRIN cvec<T, width> fixed_twiddle(size_t size, size_t start, size_t step = 0) -{ - return get_fixed_twiddle_helper<T, width>(csizeseq_t<width * 2>(), start, step, size); -} - -// template <typename T, size_t N, size_t size, size_t start, size_t step = 0, bool inverse = false> -// constexpr cvec<T, N> fixed_twiddle = get_fixed_twiddle<T, N, size, start, step, inverse>(); - -template <typename T, size_t N, bool inverse> -constexpr cvec<T, N> twiddleimagmask() -{ - return inverse ? broadcast<N * 2, T>(-1, +1) : broadcast<N * 2, T>(+1, -1); -} - -CMT_PRAGMA_GNU(GCC diagnostic push) -CMT_PRAGMA_GNU(GCC diagnostic ignored "-Wconversion") - -CMT_PRAGMA_GNU(GCC diagnostic pop) - -template <typename T, size_t N> -CMT_NOINLINE static vec<T, N> cossin_conj(const vec<T, N>& x) -{ - return negodd(cossin(x)); -} - -template <size_t k, size_t size, bool inverse = false, typename T, size_t width, - size_t kk = (inverse ? size - k : k) % size> -KFR_INTRIN vec<T, width> cmul_by_twiddle(const vec<T, width>& x) -{ - constexpr T isqrt2 = static_cast<T>(0.70710678118654752440084436210485); - if (kk == 0) - { - return x; - } - else if (kk == size * 1 / 8) - { - return swap<2>(subadd(swap<2>(x), x)) * isqrt2; - } - else if (kk == size * 2 / 8) - { - return negodd(swap<2>(x)); - } - else if (kk == size * 3 / 8) - { - return subadd(x, swap<2>(x)) * -isqrt2; - } - else if (kk == size * 4 / 8) - { - return -x; - } - else if (kk == size * 5 / 8) - { - return swap<2>(subadd(swap<2>(x), x)) * -isqrt2; - } - else if (kk == size * 6 / 8) - { - return swap<2>(negodd(x)); - } - else if (kk == size * 7 / 8) - { - return subadd(x, swap<2>(x)) * isqrt2; - } - else - { - return cmul(x, resize<width>(fixed_twiddle<T, 1, size, kk>())); - } -} - -template <size_t N, typename T> -KFR_INTRIN void butterfly2(const cvec<T, N>& a0, const cvec<T, N>& a1, cvec<T, N>& w0, cvec<T, N>& w1) -{ - const cvec<T, N> sum = a0 + a1; - const cvec<T, N> dif = a0 - a1; - w0 = sum; - w1 = dif; -} - -template <size_t N, typename T> -KFR_INTRIN void butterfly2(cvec<T, N>& a0, cvec<T, N>& a1) -{ - butterfly2<N>(a0, a1, a0, a1); -} - -template <size_t N, bool inverse = false, typename T> -KFR_INTRIN void butterfly4(cfalse_t /*split_format*/, const cvec<T, N>& a0, const cvec<T, N>& a1, - const cvec<T, N>& a2, const cvec<T, N>& a3, cvec<T, N>& w0, cvec<T, N>& w1, - cvec<T, N>& w2, cvec<T, N>& w3) -{ - cvec<T, N> sum02, sum13, diff02, diff13; - cvec<T, N * 2> a01, a23, sum0213, diff0213; - - a01 = concat(a0, a1); - a23 = concat(a2, a3); - sum0213 = a01 + a23; - diff0213 = a01 - a23; - - sum02 = low(sum0213); - sum13 = high(sum0213); - diff02 = low(diff0213); - diff13 = high(diff0213); - w0 = sum02 + sum13; - w2 = sum02 - sum13; - if (inverse) - { - diff13 = (diff13 ^ broadcast<N * 2, T>(T(), -T())); - diff13 = swap<2>(diff13); - } - else - { - diff13 = swap<2>(diff13); - diff13 = (diff13 ^ broadcast<N * 2, T>(T(), -T())); - } - - w1 = diff02 + diff13; - w3 = diff02 - diff13; -} - -template <size_t N, bool inverse = false, typename T> -KFR_INTRIN void butterfly4(ctrue_t /*split_format*/, const cvec<T, N>& a0, const cvec<T, N>& a1, - const cvec<T, N>& a2, const cvec<T, N>& a3, cvec<T, N>& w0, cvec<T, N>& w1, - cvec<T, N>& w2, cvec<T, N>& w3) -{ - vec<T, N> re0, im0, re1, im1, re2, im2, re3, im3; - vec<T, N> wre0, wim0, wre1, wim1, wre2, wim2, wre3, wim3; - - cvec<T, N> sum02, sum13, diff02, diff13; - vec<T, N> sum02re, sum13re, diff02re, diff13re; - vec<T, N> sum02im, sum13im, diff02im, diff13im; - - sum02 = a0 + a2; - sum13 = a1 + a3; - - w0 = sum02 + sum13; - w2 = sum02 - sum13; - - diff02 = a0 - a2; - diff13 = a1 - a3; - split(diff02, diff02re, diff02im); - split(diff13, diff13re, diff13im); - - (inverse ? w3 : w1) = concat(diff02re + diff13im, diff02im - diff13re); - (inverse ? w1 : w3) = concat(diff02re - diff13im, diff02im + diff13re); -} - -template <size_t N, bool inverse = false, typename T> -KFR_INTRIN void butterfly8(const cvec<T, N>& a0, const cvec<T, N>& a1, const cvec<T, N>& a2, - const cvec<T, N>& a3, const cvec<T, N>& a4, const cvec<T, N>& a5, - const cvec<T, N>& a6, const cvec<T, N>& a7, cvec<T, N>& w0, cvec<T, N>& w1, - cvec<T, N>& w2, cvec<T, N>& w3, cvec<T, N>& w4, cvec<T, N>& w5, cvec<T, N>& w6, - cvec<T, N>& w7) -{ - cvec<T, N> b0 = a0, b2 = a2, b4 = a4, b6 = a6; - butterfly4<N, inverse>(cfalse, b0, b2, b4, b6, b0, b2, b4, b6); - cvec<T, N> b1 = a1, b3 = a3, b5 = a5, b7 = a7; - butterfly4<N, inverse>(cfalse, b1, b3, b5, b7, b1, b3, b5, b7); - w0 = b0 + b1; - w4 = b0 - b1; - - b3 = cmul_by_twiddle<1, 8, inverse>(b3); - b5 = cmul_by_twiddle<2, 8, inverse>(b5); - b7 = cmul_by_twiddle<3, 8, inverse>(b7); - - w1 = b2 + b3; - w5 = b2 - b3; - w2 = b4 + b5; - w6 = b4 - b5; - w3 = b6 + b7; - w7 = b6 - b7; -} - -template <size_t N, bool inverse = false, typename T> -KFR_INTRIN void butterfly8(cvec<T, N>& a0, cvec<T, N>& a1, cvec<T, N>& a2, cvec<T, N>& a3, cvec<T, N>& a4, - cvec<T, N>& a5, cvec<T, N>& a6, cvec<T, N>& a7) -{ - butterfly8<N, inverse>(a0, a1, a2, a3, a4, a5, a6, a7, a0, a1, a2, a3, a4, a5, a6, a7); -} - -template <bool inverse = false, typename T> -KFR_INTRIN void butterfly8(cvec<T, 2>& a01, cvec<T, 2>& a23, cvec<T, 2>& a45, cvec<T, 2>& a67) -{ - cvec<T, 2> b01 = a01, b23 = a23, b45 = a45, b67 = a67; - - butterfly4<2, inverse>(cfalse, b01, b23, b45, b67, b01, b23, b45, b67); - - cvec<T, 2> b02, b13, b46, b57; - - cvec<T, 8> b01234567 = concat(b01, b23, b45, b67); - cvec<T, 8> b02461357 = concat(even<2>(b01234567), odd<2>(b01234567)); - split(b02461357, b02, b46, b13, b57); - - b13 = cmul(b13, fixed_twiddle<T, 2, 8, 0, 1, inverse>()); - b57 = cmul(b57, fixed_twiddle<T, 2, 8, 2, 1, inverse>()); - a01 = b02 + b13; - a23 = b46 + b57; - a45 = b02 - b13; - a67 = b46 - b57; -} - -template <bool inverse = false, typename T> -KFR_INTRIN void butterfly8(cvec<T, 8>& v8) -{ - cvec<T, 2> w0, w1, w2, w3; - split(v8, w0, w1, w2, w3); - butterfly8<inverse>(w0, w1, w2, w3); - v8 = concat(w0, w1, w2, w3); -} - -template <bool inverse = false, typename T> -KFR_INTRIN void butterfly32(cvec<T, 32>& v32) -{ - cvec<T, 4> w0, w1, w2, w3, w4, w5, w6, w7; - split(v32, w0, w1, w2, w3, w4, w5, w6, w7); - butterfly8<4, inverse>(w0, w1, w2, w3, w4, w5, w6, w7); - - w1 = cmul(w1, fixed_twiddle<T, 4, 32, 0, 1, inverse>()); - w2 = cmul(w2, fixed_twiddle<T, 4, 32, 0, 2, inverse>()); - w3 = cmul(w3, fixed_twiddle<T, 4, 32, 0, 3, inverse>()); - w4 = cmul(w4, fixed_twiddle<T, 4, 32, 0, 4, inverse>()); - w5 = cmul(w5, fixed_twiddle<T, 4, 32, 0, 5, inverse>()); - w6 = cmul(w6, fixed_twiddle<T, 4, 32, 0, 6, inverse>()); - w7 = cmul(w7, fixed_twiddle<T, 4, 32, 0, 7, inverse>()); - - cvec<T, 8> z0, z1, z2, z3; - transpose4x8(w0, w1, w2, w3, w4, w5, w6, w7, z0, z1, z2, z3); - - butterfly4<8, inverse>(cfalse, z0, z1, z2, z3, z0, z1, z2, z3); - v32 = concat(z0, z1, z2, z3); -} - -template <size_t N, bool inverse = false, typename T> -KFR_INTRIN void butterfly4(cvec<T, N * 4>& a0123) -{ - cvec<T, N> a0; - cvec<T, N> a1; - cvec<T, N> a2; - cvec<T, N> a3; - split(a0123, a0, a1, a2, a3); - butterfly4<N, inverse>(cfalse, a0, a1, a2, a3, a0, a1, a2, a3); - a0123 = concat(a0, a1, a2, a3); -} - -template <size_t N, typename T> -KFR_INTRIN void butterfly2(cvec<T, N * 2>& a01) -{ - cvec<T, N> a0; - cvec<T, N> a1; - split(a01, a0, a1); - butterfly2<N>(a0, a1); - a01 = concat(a0, a1); -} - -template <size_t N, bool inverse = false, bool split_format = false, typename T> -KFR_INTRIN void apply_twiddle(const cvec<T, N>& a1, const cvec<T, N>& tw1, cvec<T, N>& w1) -{ - if (split_format) - { - vec<T, N> re1, im1, tw1re, tw1im; - split(a1, re1, im1); - split(tw1, tw1re, tw1im); - vec<T, N> b1re = re1 * tw1re; - vec<T, N> b1im = im1 * tw1re; - if (inverse) - w1 = concat(b1re + im1 * tw1im, b1im - re1 * tw1im); - else - w1 = concat(b1re - im1 * tw1im, b1im + re1 * tw1im); - } - else - { - const cvec<T, N> b1 = a1 * dupeven(tw1); - const cvec<T, N> a1_ = swap<2>(a1); - - cvec<T, N> tw1_ = tw1; - if (inverse) - tw1_ = -(tw1_); - w1 = subadd(b1, a1_ * dupodd(tw1_)); - } -} - -template <size_t N, bool inverse = false, bool split_format = false, typename T> -KFR_INTRIN void apply_twiddles4(const cvec<T, N>& a1, const cvec<T, N>& a2, const cvec<T, N>& a3, - const cvec<T, N>& tw1, const cvec<T, N>& tw2, const cvec<T, N>& tw3, - cvec<T, N>& w1, cvec<T, N>& w2, cvec<T, N>& w3) -{ - apply_twiddle<N, inverse, split_format>(a1, tw1, w1); - apply_twiddle<N, inverse, split_format>(a2, tw2, w2); - apply_twiddle<N, inverse, split_format>(a3, tw3, w3); -} - -template <size_t N, bool inverse = false, typename T> -KFR_INTRIN void apply_twiddles4(cvec<T, N>& __restrict a1, cvec<T, N>& __restrict a2, - cvec<T, N>& __restrict a3, const cvec<T, N>& tw1, const cvec<T, N>& tw2, - const cvec<T, N>& tw3) -{ - apply_twiddles4<N, inverse>(a1, a2, a3, tw1, tw2, tw3, a1, a2, a3); -} - -template <size_t N, bool inverse = false, typename T, typename = u8[N - 1]> -KFR_INTRIN void apply_twiddles4(cvec<T, N>& __restrict a1, cvec<T, N>& __restrict a2, - cvec<T, N>& __restrict a3, const cvec<T, 1>& tw1, const cvec<T, 1>& tw2, - const cvec<T, 1>& tw3) -{ - apply_twiddles4<N, inverse>(a1, a2, a3, resize<N * 2>(tw1), resize<N * 2>(tw2), resize<N * 2>(tw3)); -} - -template <size_t N, bool inverse = false, typename T, typename = u8[N - 2]> -KFR_INTRIN void apply_twiddles4(cvec<T, N>& __restrict a1, cvec<T, N>& __restrict a2, - cvec<T, N>& __restrict a3, cvec<T, N / 2> tw1, cvec<T, N / 2> tw2, - cvec<T, N / 2> tw3) -{ - apply_twiddles4<N, inverse>(a1, a2, a3, resize<N * 2>(tw1), resize<N * 2>(tw2), resize<N * 2>(tw3)); -} - -template <size_t N, bool inverse = false, typename T> -KFR_INTRIN void apply_vertical_twiddles4(cvec<T, N * 4>& b, cvec<T, N * 4>& c, cvec<T, N * 4>& d) -{ - cvec<T, 4> b0, b1, b2, b3; - cvec<T, 4> c0, c1, c2, c3; - cvec<T, 4> d0, d1, d2, d3; - - split(b, b0, b1, b2, b3); - split(c, c0, c1, c2, c3); - split(d, d0, d1, d2, d3); - - b1 = cmul_by_twiddle<4, 64, inverse>(b1); - b2 = cmul_by_twiddle<8, 64, inverse>(b2); - b3 = cmul_by_twiddle<12, 64, inverse>(b3); - - c1 = cmul_by_twiddle<8, 64, inverse>(c1); - c2 = cmul_by_twiddle<16, 64, inverse>(c2); - c3 = cmul_by_twiddle<24, 64, inverse>(c3); - - d1 = cmul_by_twiddle<12, 64, inverse>(d1); - d2 = cmul_by_twiddle<24, 64, inverse>(d2); - d3 = cmul_by_twiddle<36, 64, inverse>(d3); - - b = concat(b0, b1, b2, b3); - c = concat(c0, c1, c2, c3); - d = concat(d0, d1, d2, d3); -} - -template <size_t n2, size_t nnstep, size_t N, bool inverse = false, typename T> -KFR_INTRIN void apply_twiddles4(cvec<T, N * 4>& __restrict a0123) -{ - cvec<T, N> a0; - cvec<T, N> a1; - cvec<T, N> a2; - cvec<T, N> a3; - split(a0123, a0, a1, a2, a3); - - cvec<T, N> tw1 = fixed_twiddle<T, N, 64, n2 * nnstep * 1, nnstep * 1, inverse>(), - tw2 = fixed_twiddle<T, N, 64, n2 * nnstep * 2, nnstep * 2, inverse>(), - tw3 = fixed_twiddle<T, N, 64, n2 * nnstep * 3, nnstep * 3, inverse>(); - - apply_twiddles4<N>(a1, a2, a3, tw1, tw2, tw3); - - a0123 = concat(a0, a1, a2, a3); -} - -template <bool inverse, bool aligned, typename T> -KFR_INTRIN void butterfly64(cbool_t<inverse>, cbool_t<aligned>, complex<T>* out, const complex<T>* in) -{ - cvec<T, 16> w0, w1, w2, w3; - - w0 = cread_group<4, 4, 16, aligned>( - in); // concat(cread<4>(in + 0), cread<4>(in + 16), cread<4>(in + 32), cread<4>(in + 48)); - butterfly4<4, inverse>(w0); - apply_twiddles4<0, 1, 4, inverse>(w0); - - w1 = cread_group<4, 4, 16, aligned>( - in + 4); // concat(cread<4>(in + 4), cread<4>(in + 20), cread<4>(in + 36), cread<4>(in + 52)); - butterfly4<4, inverse>(w1); - apply_twiddles4<4, 1, 4, inverse>(w1); - - w2 = cread_group<4, 4, 16, aligned>( - in + 8); // concat(cread<4>(in + 8), cread<4>(in + 24), cread<4>(in + 40), cread<4>(in + 56)); - butterfly4<4, inverse>(w2); - apply_twiddles4<8, 1, 4, inverse>(w2); - - w3 = cread_group<4, 4, 16, aligned>( - in + 12); // concat(cread<4>(in + 12), cread<4>(in + 28), cread<4>(in + 44), cread<4>(in + 60)); - butterfly4<4, inverse>(w3); - apply_twiddles4<12, 1, 4, inverse>(w3); - - transpose4(w0, w1, w2, w3); - // pass 2: - - butterfly4<4, inverse>(w0); - butterfly4<4, inverse>(w1); - butterfly4<4, inverse>(w2); - butterfly4<4, inverse>(w3); - - transpose4(w0, w1, w2, w3); - - w0 = digitreverse4<2>(w0); - w1 = digitreverse4<2>(w1); - w2 = digitreverse4<2>(w2); - w3 = digitreverse4<2>(w3); - - apply_vertical_twiddles4<4, inverse>(w1, w2, w3); - - // pass 3: - butterfly4<4, inverse>(w3); - cwrite_group<4, 4, 16, aligned>(out + 12, w3); // split(w3, out[3], out[7], out[11], out[15]); - - butterfly4<4, inverse>(w2); - cwrite_group<4, 4, 16, aligned>(out + 8, w2); // split(w2, out[2], out[6], out[10], out[14]); - - butterfly4<4, inverse>(w1); - cwrite_group<4, 4, 16, aligned>(out + 4, w1); // split(w1, out[1], out[5], out[9], out[13]); - - butterfly4<4, inverse>(w0); - cwrite_group<4, 4, 16, aligned>(out, w0); // split(w0, out[0], out[4], out[8], out[12]); -} - -template <bool inverse = false, typename T> -KFR_INTRIN void butterfly16(cvec<T, 16>& v16) -{ - butterfly4<4, inverse>(v16); - apply_twiddles4<0, 4, 4, inverse>(v16); - v16 = digitreverse4<2>(v16); - butterfly4<4, inverse>(v16); -} - -template <size_t index, bool inverse = false, typename T> -KFR_INTRIN void butterfly16_multi_natural(complex<T>* out, const complex<T>* in) -{ - constexpr size_t N = 4; - - cvec<T, 4> a1 = cread<4>(in + index * 4 + 16 * 1); - cvec<T, 4> a5 = cread<4>(in + index * 4 + 16 * 5); - cvec<T, 4> a9 = cread<4>(in + index * 4 + 16 * 9); - cvec<T, 4> a13 = cread<4>(in + index * 4 + 16 * 13); - butterfly4<N, inverse>(cfalse, a1, a5, a9, a13, a1, a5, a9, a13); - a5 = cmul_by_twiddle<1, 16, inverse>(a5); - a9 = cmul_by_twiddle<2, 16, inverse>(a9); - a13 = cmul_by_twiddle<3, 16, inverse>(a13); - - cvec<T, 4> a2 = cread<4>(in + index * 4 + 16 * 2); - cvec<T, 4> a6 = cread<4>(in + index * 4 + 16 * 6); - cvec<T, 4> a10 = cread<4>(in + index * 4 + 16 * 10); - cvec<T, 4> a14 = cread<4>(in + index * 4 + 16 * 14); - butterfly4<N, inverse>(cfalse, a2, a6, a10, a14, a2, a6, a10, a14); - a6 = cmul_by_twiddle<2, 16, inverse>(a6); - a10 = cmul_by_twiddle<4, 16, inverse>(a10); - a14 = cmul_by_twiddle<6, 16, inverse>(a14); - - cvec<T, 4> a3 = cread<4>(in + index * 4 + 16 * 3); - cvec<T, 4> a7 = cread<4>(in + index * 4 + 16 * 7); - cvec<T, 4> a11 = cread<4>(in + index * 4 + 16 * 11); - cvec<T, 4> a15 = cread<4>(in + index * 4 + 16 * 15); - butterfly4<N, inverse>(cfalse, a3, a7, a11, a15, a3, a7, a11, a15); - a7 = cmul_by_twiddle<3, 16, inverse>(a7); - a11 = cmul_by_twiddle<6, 16, inverse>(a11); - a15 = cmul_by_twiddle<9, 16, inverse>(a15); - - cvec<T, 4> a0 = cread<4>(in + index * 4 + 16 * 0); - cvec<T, 4> a4 = cread<4>(in + index * 4 + 16 * 4); - cvec<T, 4> a8 = cread<4>(in + index * 4 + 16 * 8); - cvec<T, 4> a12 = cread<4>(in + index * 4 + 16 * 12); - butterfly4<N, inverse>(cfalse, a0, a4, a8, a12, a0, a4, a8, a12); - butterfly4<N, inverse>(cfalse, a0, a1, a2, a3, a0, a1, a2, a3); - cwrite<4>(out + index * 4 + 16 * 0, a0); - cwrite<4>(out + index * 4 + 16 * 4, a1); - cwrite<4>(out + index * 4 + 16 * 8, a2); - cwrite<4>(out + index * 4 + 16 * 12, a3); - butterfly4<N, inverse>(cfalse, a4, a5, a6, a7, a4, a5, a6, a7); - cwrite<4>(out + index * 4 + 16 * 1, a4); - cwrite<4>(out + index * 4 + 16 * 5, a5); - cwrite<4>(out + index * 4 + 16 * 9, a6); - cwrite<4>(out + index * 4 + 16 * 13, a7); - butterfly4<N, inverse>(cfalse, a8, a9, a10, a11, a8, a9, a10, a11); - cwrite<4>(out + index * 4 + 16 * 2, a8); - cwrite<4>(out + index * 4 + 16 * 6, a9); - cwrite<4>(out + index * 4 + 16 * 10, a10); - cwrite<4>(out + index * 4 + 16 * 14, a11); - butterfly4<N, inverse>(cfalse, a12, a13, a14, a15, a12, a13, a14, a15); - cwrite<4>(out + index * 4 + 16 * 3, a12); - cwrite<4>(out + index * 4 + 16 * 7, a13); - cwrite<4>(out + index * 4 + 16 * 11, a14); - cwrite<4>(out + index * 4 + 16 * 15, a15); -} - -template <size_t index, bool inverse = false, typename T> -KFR_INTRIN void butterfly16_multi_flip(complex<T>* out, const complex<T>* in) -{ - constexpr size_t N = 4; - - cvec<T, 4> a1 = cread<4>(in + index * 4 + 16 * 1); - cvec<T, 4> a5 = cread<4>(in + index * 4 + 16 * 5); - cvec<T, 4> a9 = cread<4>(in + index * 4 + 16 * 9); - cvec<T, 4> a13 = cread<4>(in + index * 4 + 16 * 13); - butterfly4<N, inverse>(cfalse, a1, a5, a9, a13, a1, a5, a9, a13); - a5 = cmul_by_twiddle<1, 16, inverse>(a5); - a9 = cmul_by_twiddle<2, 16, inverse>(a9); - a13 = cmul_by_twiddle<3, 16, inverse>(a13); - - cvec<T, 4> a2 = cread<4>(in + index * 4 + 16 * 2); - cvec<T, 4> a6 = cread<4>(in + index * 4 + 16 * 6); - cvec<T, 4> a10 = cread<4>(in + index * 4 + 16 * 10); - cvec<T, 4> a14 = cread<4>(in + index * 4 + 16 * 14); - butterfly4<N, inverse>(cfalse, a2, a6, a10, a14, a2, a6, a10, a14); - a6 = cmul_by_twiddle<2, 16, inverse>(a6); - a10 = cmul_by_twiddle<4, 16, inverse>(a10); - a14 = cmul_by_twiddle<6, 16, inverse>(a14); - - cvec<T, 4> a3 = cread<4>(in + index * 4 + 16 * 3); - cvec<T, 4> a7 = cread<4>(in + index * 4 + 16 * 7); - cvec<T, 4> a11 = cread<4>(in + index * 4 + 16 * 11); - cvec<T, 4> a15 = cread<4>(in + index * 4 + 16 * 15); - butterfly4<N, inverse>(cfalse, a3, a7, a11, a15, a3, a7, a11, a15); - a7 = cmul_by_twiddle<3, 16, inverse>(a7); - a11 = cmul_by_twiddle<6, 16, inverse>(a11); - a15 = cmul_by_twiddle<9, 16, inverse>(a15); - - cvec<T, 16> w1 = concat(a1, a5, a9, a13); - cvec<T, 16> w2 = concat(a2, a6, a10, a14); - cvec<T, 16> w3 = concat(a3, a7, a11, a15); - - cvec<T, 4> a0 = cread<4>(in + index * 4 + 16 * 0); - cvec<T, 4> a4 = cread<4>(in + index * 4 + 16 * 4); - cvec<T, 4> a8 = cread<4>(in + index * 4 + 16 * 8); - cvec<T, 4> a12 = cread<4>(in + index * 4 + 16 * 12); - butterfly4<N, inverse>(cfalse, a0, a4, a8, a12, a0, a4, a8, a12); - cvec<T, 16> w0 = concat(a0, a4, a8, a12); - - butterfly4<N * 4, inverse>(cfalse, w0, w1, w2, w3, w0, w1, w2, w3); - - w0 = digitreverse4<2>(w0); - w1 = digitreverse4<2>(w1); - w2 = digitreverse4<2>(w2); - w3 = digitreverse4<2>(w3); - - transpose4(w0, w1, w2, w3); - cwrite<16>(out + index * 64 + 16 * 0, cmul(w0, fixed_twiddle<T, 16, 256, 0, index * 4 + 0, inverse>())); - cwrite<16>(out + index * 64 + 16 * 1, cmul(w1, fixed_twiddle<T, 16, 256, 0, index * 4 + 1, inverse>())); - cwrite<16>(out + index * 64 + 16 * 2, cmul(w2, fixed_twiddle<T, 16, 256, 0, index * 4 + 2, inverse>())); - cwrite<16>(out + index * 64 + 16 * 3, cmul(w3, fixed_twiddle<T, 16, 256, 0, index * 4 + 3, inverse>())); -} - -template <size_t n2, size_t nnstep, size_t N, typename T> -KFR_INTRIN void apply_twiddles2(cvec<T, N>& a1) -{ - cvec<T, N> tw1 = fixed_twiddle<T, N, 64, n2 * nnstep * 1, nnstep * 1>(); - - a1 = cmul(a1, tw1); -} - -template <typename T, size_t N, bool inverse> -static const cvec<T, N> tw3r1 = static_cast<T>(-0.5 - 1.0); - -template <typename T, size_t N, bool inverse> -static const cvec<T, N> tw3i1 = - static_cast<T>(0.86602540378443864676372317075) * twiddleimagmask<T, N, inverse>(); - -template <size_t N, bool inverse = false, typename T> -KFR_INTRIN void butterfly3(cvec<T, N> a00, cvec<T, N> a01, cvec<T, N> a02, cvec<T, N>& w00, cvec<T, N>& w01, - cvec<T, N>& w02) -{ - - const cvec<T, N> sum1 = a01 + a02; - const cvec<T, N> dif1 = swap<2>(a01 - a02); - w00 = a00 + sum1; - - const cvec<T, N> s1 = w00 + sum1 * tw3r1<T, N, inverse>; - - const cvec<T, N> d1 = dif1 * tw3i1<T, N, inverse>; - - w01 = s1 + d1; - w02 = s1 - d1; -} - -template <size_t N, bool inverse = false, typename T> -KFR_INTRIN void butterfly3(cvec<T, N>& a0, cvec<T, N>& a1, cvec<T, N>& a2) -{ - butterfly3<N, inverse>(a0, a1, a2, a0, a1, a2); -} - -template <size_t N, bool inverse = false, typename T> -KFR_INTRIN void butterfly6(const cvec<T, N>& a0, const cvec<T, N>& a1, const cvec<T, N>& a2, - const cvec<T, N>& a3, const cvec<T, N>& a4, const cvec<T, N>& a5, cvec<T, N>& w0, - cvec<T, N>& w1, cvec<T, N>& w2, cvec<T, N>& w3, cvec<T, N>& w4, cvec<T, N>& w5) -{ - cvec<T, N* 2> a03 = concat(a0, a3); - cvec<T, N* 2> a25 = concat(a2, a5); - cvec<T, N* 2> a41 = concat(a4, a1); - butterfly3<N * 2, inverse>(a03, a25, a41, a03, a25, a41); - cvec<T, N> t0, t1, t2, t3, t4, t5; - split(a03, t0, t1); - split(a25, t2, t3); - split(a41, t4, t5); - t3 = -t3; - cvec<T, N* 2> a04 = concat(t0, t4); - cvec<T, N* 2> a15 = concat(t1, t5); - cvec<T, N * 2> w02, w35; - butterfly2<N * 2>(a04, a15, w02, w35); - split(w02, w0, w2); - split(w35, w3, w5); - - butterfly2<N>(t2, t3, w1, w4); -} - -template <size_t N, bool inverse = false, typename T> -KFR_INTRIN void butterfly6(cvec<T, N>& a0, cvec<T, N>& a1, cvec<T, N>& a2, cvec<T, N>& a3, cvec<T, N>& a4, - cvec<T, N>& a5) -{ - butterfly6<N, inverse>(a0, a1, a2, a3, a4, a5, a0, a1, a2, a3, a4, a5); -} - -template <typename T, bool inverse = false> -const static cvec<T, 1> tw9_1 = { T(0.76604444311897803520239265055541), - (inverse ? -1 : 1) * T(-0.64278760968653932632264340990727) }; -template <typename T, bool inverse = false> -const static cvec<T, 1> tw9_2 = { T(0.17364817766693034885171662676931), - (inverse ? -1 : 1) * T(-0.98480775301220805936674302458952) }; -template <typename T, bool inverse = false> -const static cvec<T, 1> tw9_4 = { T(-0.93969262078590838405410927732473), - (inverse ? -1 : 1) * T(-0.34202014332566873304409961468226) }; - -template <size_t N, bool inverse = false, typename T> -KFR_INTRIN void butterfly9(const cvec<T, N>& a0, const cvec<T, N>& a1, const cvec<T, N>& a2, - const cvec<T, N>& a3, const cvec<T, N>& a4, const cvec<T, N>& a5, - const cvec<T, N>& a6, const cvec<T, N>& a7, const cvec<T, N>& a8, cvec<T, N>& w0, - cvec<T, N>& w1, cvec<T, N>& w2, cvec<T, N>& w3, cvec<T, N>& w4, cvec<T, N>& w5, - cvec<T, N>& w6, cvec<T, N>& w7, cvec<T, N>& w8) -{ - cvec<T, N* 3> a012 = concat(a0, a1, a2); - cvec<T, N* 3> a345 = concat(a3, a4, a5); - cvec<T, N* 3> a678 = concat(a6, a7, a8); - butterfly3<N * 3, inverse>(a012, a345, a678, a012, a345, a678); - cvec<T, N> t0, t1, t2, t3, t4, t5, t6, t7, t8; - split(a012, t0, t1, t2); - split(a345, t3, t4, t5); - split(a678, t6, t7, t8); - - t4 = cmul(t4, tw9_1<T, inverse>); - t5 = cmul(t5, tw9_2<T, inverse>); - t7 = cmul(t7, tw9_2<T, inverse>); - t8 = cmul(t8, tw9_4<T, inverse>); - - cvec<T, N* 3> t036 = concat(t0, t3, t6); - cvec<T, N* 3> t147 = concat(t1, t4, t7); - cvec<T, N* 3> t258 = concat(t2, t5, t8); - - butterfly3<N * 3, inverse>(t036, t147, t258, t036, t147, t258); - split(t036, w0, w1, w2); - split(t147, w3, w4, w5); - split(t258, w6, w7, w8); -} - -template <size_t N, bool inverse = false, typename T> -KFR_INTRIN void butterfly9(cvec<T, N>& a0, cvec<T, N>& a1, cvec<T, N>& a2, cvec<T, N>& a3, cvec<T, N>& a4, - cvec<T, N>& a5, cvec<T, N>& a6, cvec<T, N>& a7, cvec<T, N>& a8) -{ - butterfly9<N, inverse>(a0, a1, a2, a3, a4, a5, a6, a7, a8, a0, a1, a2, a3, a4, a5, a6, a7, a8); -} - -template <typename T, size_t N, bool inverse> -static const cvec<T, N> tw7r1 = static_cast<T>(0.623489801858733530525004884 - 1.0); - -template <typename T, size_t N, bool inverse> -static const cvec<T, N> tw7i1 = - static_cast<T>(0.78183148246802980870844452667) * twiddleimagmask<T, N, inverse>(); - -template <typename T, size_t N, bool inverse> -static const cvec<T, N> tw7r2 = static_cast<T>(-0.2225209339563144042889025645 - 1.0); - -template <typename T, size_t N, bool inverse> -static const cvec<T, N> tw7i2 = - static_cast<T>(0.97492791218182360701813168299) * twiddleimagmask<T, N, inverse>(); - -template <typename T, size_t N, bool inverse> -static const cvec<T, N> tw7r3 = static_cast<T>(-0.90096886790241912623610231951 - 1.0); - -template <typename T, size_t N, bool inverse> -static const cvec<T, N> tw7i3 = - static_cast<T>(0.43388373911755812047576833285) * twiddleimagmask<T, N, inverse>(); - -template <size_t N, bool inverse = false, typename T> -KFR_INTRIN void butterfly7(cvec<T, N> a00, cvec<T, N> a01, cvec<T, N> a02, cvec<T, N> a03, cvec<T, N> a04, - cvec<T, N> a05, cvec<T, N> a06, cvec<T, N>& w00, cvec<T, N>& w01, cvec<T, N>& w02, - cvec<T, N>& w03, cvec<T, N>& w04, cvec<T, N>& w05, cvec<T, N>& w06) -{ - const cvec<T, N> sum1 = a01 + a06; - const cvec<T, N> dif1 = swap<2>(a01 - a06); - const cvec<T, N> sum2 = a02 + a05; - const cvec<T, N> dif2 = swap<2>(a02 - a05); - const cvec<T, N> sum3 = a03 + a04; - const cvec<T, N> dif3 = swap<2>(a03 - a04); - w00 = a00 + sum1 + sum2 + sum3; - - const cvec<T, N> s1 = - w00 + sum1 * tw7r1<T, N, inverse> + sum2 * tw7r2<T, N, inverse> + sum3 * tw7r3<T, N, inverse>; - const cvec<T, N> s2 = - w00 + sum1 * tw7r2<T, N, inverse> + sum2 * tw7r3<T, N, inverse> + sum3 * tw7r1<T, N, inverse>; - const cvec<T, N> s3 = - w00 + sum1 * tw7r3<T, N, inverse> + sum2 * tw7r1<T, N, inverse> + sum3 * tw7r2<T, N, inverse>; - - const cvec<T, N> d1 = - dif1 * tw7i1<T, N, inverse> + dif2 * tw7i2<T, N, inverse> + dif3 * tw7i3<T, N, inverse>; - const cvec<T, N> d2 = - dif1 * tw7i2<T, N, inverse> - dif2 * tw7i3<T, N, inverse> - dif3 * tw7i1<T, N, inverse>; - const cvec<T, N> d3 = - dif1 * tw7i3<T, N, inverse> - dif2 * tw7i1<T, N, inverse> + dif3 * tw7i2<T, N, inverse>; - - w01 = s1 + d1; - w06 = s1 - d1; - w02 = s2 + d2; - w05 = s2 - d2; - w03 = s3 + d3; - w04 = s3 - d3; -} - -template <size_t N, bool inverse = false, typename T> -KFR_INTRIN void butterfly7(cvec<T, N>& a0, cvec<T, N>& a1, cvec<T, N>& a2, cvec<T, N>& a3, cvec<T, N>& a4, - cvec<T, N>& a5, cvec<T, N>& a6) -{ - butterfly7<N, inverse>(a0, a1, a2, a3, a4, a5, a6, a0, a1, a2, a3, a4, a5, a6); -} - -template <typename T, size_t N, bool inverse> -static const cvec<T, N> tw11r1 = static_cast<T>(0.84125353283118116886181164892 - 1.0); - -template <typename T, size_t N, bool inverse> -static const cvec<T, N> tw11i1 = - static_cast<T>(0.54064081745559758210763595432) * twiddleimagmask<T, N, inverse>(); - -template <typename T, size_t N, bool inverse> -static const cvec<T, N> tw11r2 = static_cast<T>(0.41541501300188642552927414923 - 1.0); - -template <typename T, size_t N, bool inverse> -static const cvec<T, N> tw11i2 = - static_cast<T>(0.90963199535451837141171538308) * twiddleimagmask<T, N, inverse>(); - -template <typename T, size_t N, bool inverse> -static const cvec<T, N> tw11r3 = static_cast<T>(-0.14231483827328514044379266862 - 1.0); - -template <typename T, size_t N, bool inverse> -static const cvec<T, N> tw11i3 = - static_cast<T>(0.98982144188093273237609203778) * twiddleimagmask<T, N, inverse>(); - -template <typename T, size_t N, bool inverse> -static const cvec<T, N> tw11r4 = static_cast<T>(-0.65486073394528506405692507247 - 1.0); - -template <typename T, size_t N, bool inverse> -static const cvec<T, N> tw11i4 = - static_cast<T>(0.75574957435425828377403584397) * twiddleimagmask<T, N, inverse>(); - -template <typename T, size_t N, bool inverse> -static const cvec<T, N> tw11r5 = static_cast<T>(-0.95949297361449738989036805707 - 1.0); - -template <typename T, size_t N, bool inverse> -static const cvec<T, N> tw11i5 = - static_cast<T>(0.28173255684142969771141791535) * twiddleimagmask<T, N, inverse>(); - -template <size_t N, bool inverse = false, typename T> -KFR_INTRIN void butterfly11(cvec<T, N> a00, cvec<T, N> a01, cvec<T, N> a02, cvec<T, N> a03, cvec<T, N> a04, - cvec<T, N> a05, cvec<T, N> a06, cvec<T, N> a07, cvec<T, N> a08, cvec<T, N> a09, - cvec<T, N> a10, cvec<T, N>& w00, cvec<T, N>& w01, cvec<T, N>& w02, - cvec<T, N>& w03, cvec<T, N>& w04, cvec<T, N>& w05, cvec<T, N>& w06, - cvec<T, N>& w07, cvec<T, N>& w08, cvec<T, N>& w09, cvec<T, N>& w10) -{ - const cvec<T, N> sum1 = a01 + a10; - const cvec<T, N> dif1 = swap<2>(a01 - a10); - const cvec<T, N> sum2 = a02 + a09; - const cvec<T, N> dif2 = swap<2>(a02 - a09); - const cvec<T, N> sum3 = a03 + a08; - const cvec<T, N> dif3 = swap<2>(a03 - a08); - const cvec<T, N> sum4 = a04 + a07; - const cvec<T, N> dif4 = swap<2>(a04 - a07); - const cvec<T, N> sum5 = a05 + a06; - const cvec<T, N> dif5 = swap<2>(a05 - a06); - w00 = a00 + sum1 + sum2 + sum3 + sum4 + sum5; - - const cvec<T, N> s1 = w00 + sum1 * tw11r1<T, N, inverse> + sum2 * tw11r2<T, N, inverse> + - sum3 * tw11r3<T, N, inverse> + sum4 * tw11r4<T, N, inverse> + - sum5 * tw11r5<T, N, inverse>; - const cvec<T, N> s2 = w00 + sum1 * tw11r2<T, N, inverse> + sum2 * tw11r3<T, N, inverse> + - sum3 * tw11r4<T, N, inverse> + sum4 * tw11r5<T, N, inverse> + - sum5 * tw11r1<T, N, inverse>; - const cvec<T, N> s3 = w00 + sum1 * tw11r3<T, N, inverse> + sum2 * tw11r4<T, N, inverse> + - sum3 * tw11r5<T, N, inverse> + sum4 * tw11r1<T, N, inverse> + - sum5 * tw11r2<T, N, inverse>; - const cvec<T, N> s4 = w00 + sum1 * tw11r4<T, N, inverse> + sum2 * tw11r5<T, N, inverse> + - sum3 * tw11r1<T, N, inverse> + sum4 * tw11r2<T, N, inverse> + - sum5 * tw11r3<T, N, inverse>; - const cvec<T, N> s5 = w00 + sum1 * tw11r5<T, N, inverse> + sum2 * tw11r1<T, N, inverse> + - sum3 * tw11r2<T, N, inverse> + sum4 * tw11r3<T, N, inverse> + - sum5 * tw11r4<T, N, inverse>; - - const cvec<T, N> d1 = dif1 * tw11i1<T, N, inverse> + dif2 * tw11i2<T, N, inverse> + - dif3 * tw11i3<T, N, inverse> + dif4 * tw11i4<T, N, inverse> + - dif5 * tw11i5<T, N, inverse>; - const cvec<T, N> d2 = dif1 * tw11i2<T, N, inverse> - dif2 * tw11i3<T, N, inverse> - - dif3 * tw11i4<T, N, inverse> - dif4 * tw11i5<T, N, inverse> - - dif5 * tw11i1<T, N, inverse>; - const cvec<T, N> d3 = dif1 * tw11i3<T, N, inverse> - dif2 * tw11i4<T, N, inverse> + - dif3 * tw11i5<T, N, inverse> + dif4 * tw11i1<T, N, inverse> + - dif5 * tw11i2<T, N, inverse>; - const cvec<T, N> d4 = dif1 * tw11i4<T, N, inverse> - dif2 * tw11i5<T, N, inverse> + - dif3 * tw11i1<T, N, inverse> - dif4 * tw11i2<T, N, inverse> - - dif5 * tw11i3<T, N, inverse>; - const cvec<T, N> d5 = dif1 * tw11i5<T, N, inverse> - dif2 * tw11i1<T, N, inverse> + - dif3 * tw11i2<T, N, inverse> - dif4 * tw11i3<T, N, inverse> + - dif5 * tw11i4<T, N, inverse>; - - w01 = s1 + d1; - w10 = s1 - d1; - w02 = s2 + d2; - w09 = s2 - d2; - w03 = s3 + d3; - w08 = s3 - d3; - w04 = s4 + d4; - w07 = s4 - d4; - w05 = s5 + d5; - w06 = s5 - d5; -} - -template <typename T, size_t N, bool inverse> -const static cvec<T, N> tw5r1 = static_cast<T>(0.30901699437494742410229341718 - 1.0); -template <typename T, size_t N, bool inverse> -const static cvec<T, N> tw5i1 = - static_cast<T>(0.95105651629515357211643933338) * twiddleimagmask<T, N, inverse>(); -template <typename T, size_t N, bool inverse> -const static cvec<T, N> tw5r2 = static_cast<T>(-0.80901699437494742410229341718 - 1.0); -template <typename T, size_t N, bool inverse> -const static cvec<T, N> tw5i2 = - static_cast<T>(0.58778525229247312916870595464) * twiddleimagmask<T, N, inverse>(); - -template <size_t N, bool inverse = false, typename T> -KFR_INTRIN void butterfly5(const cvec<T, N>& a00, const cvec<T, N>& a01, const cvec<T, N>& a02, - const cvec<T, N>& a03, const cvec<T, N>& a04, cvec<T, N>& w00, cvec<T, N>& w01, - cvec<T, N>& w02, cvec<T, N>& w03, cvec<T, N>& w04) -{ - const cvec<T, N> sum1 = a01 + a04; - const cvec<T, N> dif1 = swap<2>(a01 - a04); - const cvec<T, N> sum2 = a02 + a03; - const cvec<T, N> dif2 = swap<2>(a02 - a03); - w00 = a00 + sum1 + sum2; - - const cvec<T, N> s1 = w00 + sum1 * tw5r1<T, N, inverse> + sum2 * tw5r2<T, N, inverse>; - const cvec<T, N> s2 = w00 + sum1 * tw5r2<T, N, inverse> + sum2 * tw5r1<T, N, inverse>; - - const cvec<T, N> d1 = dif1 * tw5i1<T, N, inverse> + dif2 * tw5i2<T, N, inverse>; - const cvec<T, N> d2 = dif1 * tw5i2<T, N, inverse> - dif2 * tw5i1<T, N, inverse>; - - w01 = s1 + d1; - w04 = s1 - d1; - w02 = s2 + d2; - w03 = s2 - d2; -} - -template <size_t N, bool inverse = false, typename T> -KFR_INTRIN void butterfly10(const cvec<T, N>& a0, const cvec<T, N>& a1, const cvec<T, N>& a2, - const cvec<T, N>& a3, const cvec<T, N>& a4, const cvec<T, N>& a5, - const cvec<T, N>& a6, const cvec<T, N>& a7, const cvec<T, N>& a8, - const cvec<T, N>& a9, cvec<T, N>& w0, cvec<T, N>& w1, cvec<T, N>& w2, - cvec<T, N>& w3, cvec<T, N>& w4, cvec<T, N>& w5, cvec<T, N>& w6, cvec<T, N>& w7, - cvec<T, N>& w8, cvec<T, N>& w9) -{ - cvec<T, N* 2> a05 = concat(a0, a5); - cvec<T, N* 2> a27 = concat(a2, a7); - cvec<T, N* 2> a49 = concat(a4, a9); - cvec<T, N* 2> a61 = concat(a6, a1); - cvec<T, N* 2> a83 = concat(a8, a3); - butterfly5<N * 2, inverse>(a05, a27, a49, a61, a83, a05, a27, a49, a61, a83); - cvec<T, N> t0, t1, t2, t3, t4, t5, t6, t7, t8, t9; - split(a05, t0, t1); - split(a27, t2, t3); - split(a49, t4, t5); - split(a61, t6, t7); - split(a83, t8, t9); - t5 = -t5; - - cvec<T, N * 2> t02, t13; - cvec<T, N * 2> w06, w51; - t02 = concat(t0, t2); - t13 = concat(t1, t3); - butterfly2<N * 2>(t02, t13, w06, w51); - split(w06, w0, w6); - split(w51, w5, w1); - - cvec<T, N * 2> t68, t79; - cvec<T, N * 2> w84, w39; - t68 = concat(t6, t8); - t79 = concat(t7, t9); - butterfly2<N * 2>(t68, t79, w84, w39); - split(w84, w8, w4); - split(w39, w3, w9); - butterfly2<N>(t4, t5, w7, w2); -} - -template <bool inverse, typename T, size_t N> -KFR_INTRIN void butterfly(cbool_t<inverse>, const vec<T, N>& in0, const vec<T, N>& in1, vec<T, N>& out0, - vec<T, N>& out1) -{ - butterfly2<N / 2>(in0, in1, out0, out1); -} -template <bool inverse, typename T, size_t N> -KFR_INTRIN void butterfly(cbool_t<inverse>, const vec<T, N>& in0, const vec<T, N>& in1, const vec<T, N>& in2, - vec<T, N>& out0, vec<T, N>& out1, vec<T, N>& out2) -{ - butterfly3<N / 2, inverse>(in0, in1, in2, out0, out1, out2); -} - -template <bool inverse, typename T, size_t N> -KFR_INTRIN void butterfly(cbool_t<inverse>, const vec<T, N>& in0, const vec<T, N>& in1, const vec<T, N>& in2, - const vec<T, N>& in3, vec<T, N>& out0, vec<T, N>& out1, vec<T, N>& out2, - vec<T, N>& out3) -{ - butterfly4<N / 2, inverse>(cfalse, in0, in1, in2, in3, out0, out1, out2, out3); -} -template <bool inverse, typename T, size_t N> -KFR_INTRIN void butterfly(cbool_t<inverse>, const vec<T, N>& in0, const vec<T, N>& in1, const vec<T, N>& in2, - const vec<T, N>& in3, const vec<T, N>& in4, vec<T, N>& out0, vec<T, N>& out1, - vec<T, N>& out2, vec<T, N>& out3, vec<T, N>& out4) -{ - butterfly5<N / 2, inverse>(in0, in1, in2, in3, in4, out0, out1, out2, out3, out4); -} -template <bool inverse, typename T, size_t N> -KFR_INTRIN void butterfly(cbool_t<inverse>, const vec<T, N>& in0, const vec<T, N>& in1, const vec<T, N>& in2, - const vec<T, N>& in3, const vec<T, N>& in4, const vec<T, N>& in5, vec<T, N>& out0, - vec<T, N>& out1, vec<T, N>& out2, vec<T, N>& out3, vec<T, N>& out4, vec<T, N>& out5) -{ - butterfly6<N / 2, inverse>(in0, in1, in2, in3, in4, in5, out0, out1, out2, out3, out4, out5); -} -template <bool inverse, typename T, size_t N> -KFR_INTRIN void butterfly(cbool_t<inverse>, const vec<T, N>& in0, const vec<T, N>& in1, const vec<T, N>& in2, - const vec<T, N>& in3, const vec<T, N>& in4, const vec<T, N>& in5, - const vec<T, N>& in6, vec<T, N>& out0, vec<T, N>& out1, vec<T, N>& out2, - vec<T, N>& out3, vec<T, N>& out4, vec<T, N>& out5, vec<T, N>& out6) -{ - butterfly7<N / 2, inverse>(in0, in1, in2, in3, in4, in5, in6, out0, out1, out2, out3, out4, out5, out6); -} -template <bool inverse, typename T, size_t N> -KFR_INTRIN void butterfly(cbool_t<inverse>, const vec<T, N>& in0, const vec<T, N>& in1, const vec<T, N>& in2, - const vec<T, N>& in3, const vec<T, N>& in4, const vec<T, N>& in5, - const vec<T, N>& in6, const vec<T, N>& in7, vec<T, N>& out0, vec<T, N>& out1, - vec<T, N>& out2, vec<T, N>& out3, vec<T, N>& out4, vec<T, N>& out5, vec<T, N>& out6, - vec<T, N>& out7) -{ - butterfly8<N / 2, inverse>(in0, in1, in2, in3, in4, in5, in6, in7, out0, out1, out2, out3, out4, out5, - out6, out7); -} -template <bool inverse, typename T, size_t N> -KFR_INTRIN void butterfly(cbool_t<inverse>, const vec<T, N>& in0, const vec<T, N>& in1, const vec<T, N>& in2, - const vec<T, N>& in3, const vec<T, N>& in4, const vec<T, N>& in5, - const vec<T, N>& in6, const vec<T, N>& in7, const vec<T, N>& in8, vec<T, N>& out0, - vec<T, N>& out1, vec<T, N>& out2, vec<T, N>& out3, vec<T, N>& out4, vec<T, N>& out5, - vec<T, N>& out6, vec<T, N>& out7, vec<T, N>& out8) -{ - butterfly9<N / 2, inverse>(in0, in1, in2, in3, in4, in5, in6, in7, in8, out0, out1, out2, out3, out4, - out5, out6, out7, out8); -} -template <bool inverse, typename T, size_t N> -KFR_INTRIN void butterfly(cbool_t<inverse>, const vec<T, N>& in0, const vec<T, N>& in1, const vec<T, N>& in2, - const vec<T, N>& in3, const vec<T, N>& in4, const vec<T, N>& in5, - const vec<T, N>& in6, const vec<T, N>& in7, const vec<T, N>& in8, - const vec<T, N>& in9, vec<T, N>& out0, vec<T, N>& out1, vec<T, N>& out2, - vec<T, N>& out3, vec<T, N>& out4, vec<T, N>& out5, vec<T, N>& out6, vec<T, N>& out7, - vec<T, N>& out8, vec<T, N>& out9) -{ - butterfly10<N / 2, inverse>(in0, in1, in2, in3, in4, in5, in6, in7, in8, in9, out0, out1, out2, out3, - out4, out5, out6, out7, out8, out9); -} -template <bool inverse, typename T, size_t N> -KFR_INTRIN void butterfly(cbool_t<inverse>, const vec<T, N>& in0, const vec<T, N>& in1, const vec<T, N>& in2, - const vec<T, N>& in3, const vec<T, N>& in4, const vec<T, N>& in5, - const vec<T, N>& in6, const vec<T, N>& in7, const vec<T, N>& in8, - const vec<T, N>& in9, const vec<T, N>& in10, vec<T, N>& out0, vec<T, N>& out1, - vec<T, N>& out2, vec<T, N>& out3, vec<T, N>& out4, vec<T, N>& out5, vec<T, N>& out6, - vec<T, N>& out7, vec<T, N>& out8, vec<T, N>& out9, vec<T, N>& out10) -{ - butterfly11<N / 2, inverse>(in0, in1, in2, in3, in4, in5, in6, in7, in8, in9, in10, out0, out1, out2, - out3, out4, out5, out6, out7, out8, out9, out10); -} -template <bool transposed, typename T, size_t... N, size_t Nout = csum<size_t, N...>()> -KFR_INTRIN void cread_transposed(cbool_t<transposed>, const complex<T>* ptr, vec<T, N>&... w) -{ - vec<T, Nout> temp = read<Nout>(ptr_cast<T>(ptr)); - if (transposed) - temp = ctranspose<sizeof...(N)>(temp); - split(temp, w...); -} - -// Warning: Reads past the end. Use with care -KFR_INTRIN void cread_transposed(cbool_t<true>, const complex<f32>* ptr, cvec<f32, 4>& w0, cvec<f32, 4>& w1, - cvec<f32, 4>& w2) -{ - cvec<f32, 4> w3; - cvec<f32, 16> v16 = concat(cread<4>(ptr), cread<4>(ptr + 3), cread<4>(ptr + 6), cread<4>(ptr + 9)); - v16 = digitreverse4<2>(v16); - split(v16, w0, w1, w2, w3); -} - -KFR_INTRIN void cread_transposed(cbool_t<true>, const complex<f32>* ptr, cvec<f32, 4>& w0, cvec<f32, 4>& w1, - cvec<f32, 4>& w2, cvec<f32, 4>& w3, cvec<f32, 4>& w4) -{ - cvec<f32, 16> v16 = concat(cread<4>(ptr), cread<4>(ptr + 5), cread<4>(ptr + 10), cread<4>(ptr + 15)); - v16 = digitreverse4<2>(v16); - split(v16, w0, w1, w2, w3); - w4 = cgather<4, 5>(ptr + 4); -} - -template <bool transposed, typename T, size_t... N, size_t Nout = csum<size_t, N...>()> -KFR_INTRIN void cwrite_transposed(cbool_t<transposed>, complex<T>* ptr, vec<T, N>... args) -{ - auto temp = concat(args...); - if (transposed) - temp = ctransposeinverse<sizeof...(N)>(temp); - write(ptr_cast<T>(ptr), temp); -} - -template <size_t I, size_t radix, typename T, size_t N, size_t width = N / 2> -KFR_INTRIN vec<T, N> mul_tw(cbool_t<false>, const vec<T, N>& x, const complex<T>* twiddle) -{ - return I == 0 ? x : cmul(x, cread<width>(twiddle + width * (I - 1))); -} -template <size_t I, size_t radix, typename T, size_t N, size_t width = N / 2> -KFR_INTRIN vec<T, N> mul_tw(cbool_t<true>, const vec<T, N>& x, const complex<T>* twiddle) -{ - return I == 0 ? x : cmul_conj(x, cread<width>(twiddle + width * (I - 1))); -} - -// Non-final -template <typename T, size_t width, size_t radix, bool inverse, size_t... I> -KFR_INTRIN void butterfly_helper(csizes_t<I...>, size_t i, csize_t<width>, csize_t<radix>, cbool_t<inverse>, - complex<T>* out, const complex<T>* in, const complex<T>* tw, size_t stride) -{ - carray<cvec<T, width>, radix> inout; - - swallow{ (inout.get(csize_t<I>()) = cread<width>(in + i + stride * I))... }; - - butterfly(cbool_t<inverse>(), inout.template get<I>()..., inout.template get<I>()...); - - swallow{ ( - cwrite<width>(out + i + stride * I, - mul_tw<I, radix>(cbool_t<inverse>(), inout.template get<I>(), tw + i * (radix - 1))), - 0)... }; -} - -// Final -template <typename T, size_t width, size_t radix, bool inverse, size_t... I> -KFR_INTRIN void butterfly_helper(csizes_t<I...>, size_t i, csize_t<width>, csize_t<radix>, cbool_t<inverse>, - complex<T>* out, const complex<T>* in, size_t stride) -{ - carray<cvec<T, width>, radix> inout; - - // swallow{ ( inout.get( csize<I> ) = infn( i, I, cvec<T, width>( ) ) )... }; - cread_transposed(ctrue, in + i * radix, inout.template get<I>()...); - - butterfly(cbool_t<inverse>(), inout.template get<I>()..., inout.template get<I>()...); - - swallow{ (cwrite<width>(out + i + stride * I, inout.get(csize_t<I>())), 0)... }; -} - -template <size_t width, size_t radix, typename... Args> -KFR_INTRIN void butterfly(size_t i, csize_t<width>, csize_t<radix>, Args&&... args) -{ - butterfly_helper(csizeseq_t<radix>(), i, csize_t<width>(), csize_t<radix>(), std::forward<Args>(args)...); -} - -template <typename... Args> -KFR_INTRIN void butterfly_cycle(size_t&, size_t, csize_t<0>, Args&&...) -{ -} -template <size_t width, typename... Args> -KFR_INTRIN void butterfly_cycle(size_t& i, size_t count, csize_t<width>, Args&&... args) -{ - CMT_LOOP_NOUNROLL - for (; i < count / width * width; i += width) - butterfly(i, csize_t<width>(), std::forward<Args>(args)...); - butterfly_cycle(i, count, csize_t<width / 2>(), std::forward<Args>(args)...); -} - -template <size_t width, typename... Args> -KFR_INTRIN void butterflies(size_t count, csize_t<width>, Args&&... args) -{ - CMT_ASSUME(count > 0); - size_t i = 0; - butterfly_cycle(i, count, csize_t<width>(), std::forward<Args>(args)...); -} - -template <typename T, bool inverse, typename Tradix, typename Tstride> -KFR_INTRIN void generic_butterfly_cycle(csize_t<0>, Tradix radix, cbool_t<inverse>, complex<T>*, - const complex<T>*, Tstride, size_t, size_t, const complex<T>*, size_t) -{ -} - -template <size_t width, bool inverse, typename T, typename Tradix, typename Thalfradix, - typename Thalfradixsqr, typename Tstride> -KFR_INTRIN void generic_butterfly_cycle(csize_t<width>, Tradix radix, cbool_t<inverse>, complex<T>* out, - const complex<T>* in, Tstride ostride, Thalfradix halfradix, - Thalfradixsqr halfradix_sqr, const complex<T>* twiddle, size_t i) -{ - CMT_LOOP_NOUNROLL - for (; i < halfradix / width * width; i += width) - { - const cvec<T, 1> in0 = cread<1>(in); - cvec<T, width> sum0 = resize<2 * width>(in0); - cvec<T, width> sum1 = sum0; - - for (size_t j = 0; j < halfradix; j++) - { - const cvec<T, 1> ina = cread<1>(in + (1 + j)); - const cvec<T, 1> inb = cread<1>(in + radix - (j + 1)); - cvec<T, width> tw = cread<width>(twiddle); - if (inverse) - tw = negodd /*cconj*/ (tw); - - cmul_2conj(sum0, sum1, ina, inb, tw); - twiddle += halfradix; - } - twiddle = twiddle - halfradix_sqr + width; - - // if (inverse) - // std::swap(sum0, sum1); - - if (is_constant_val(ostride)) - { - cwrite<width>(out + (1 + i), sum0); - cwrite<width>(out + (radix - (i + 1)) - (width - 1), reverse<2>(sum1)); - } - else - { - cscatter<width>(out + (i + 1) * ostride, ostride, sum0); - cscatter<width>(out + (radix - (i + 1)) * ostride - (width - 1) * ostride, ostride, - reverse<2>(sum1)); - } - } - generic_butterfly_cycle(csize_t<width / 2>(), radix, cbool_t<inverse>(), out, in, ostride, halfradix, - halfradix_sqr, twiddle, i); -} - -template <typename T> -KFR_SINTRIN vec<T, 2> hcadd(vec<T, 2> value) -{ - return value; -} -template <typename T, size_t N, KFR_ENABLE_IF(N >= 4)> -KFR_SINTRIN vec<T, 2> hcadd(vec<T, N> value) -{ - return hcadd(low(value) + high(value)); -} - -template <size_t width, typename T, bool inverse, typename Tstride = csize_t<1>> -KFR_INTRIN void generic_butterfly_w(size_t radix, cbool_t<inverse>, complex<T>* out, const complex<T>* in, - const complex<T>* twiddle, Tstride ostride = Tstride{}) -{ - CMT_ASSUME(radix > 0); - { - cvec<T, width> sum = T(); - size_t j = 0; - CMT_LOOP_NOUNROLL - for (; j < radix / width * width; j += width) - { - sum += cread<width>(in + j); - } - cvec<T, 1> sums = T(); - CMT_LOOP_NOUNROLL - for (; j < radix; j++) - { - sums += cread<1>(in + j); - } - cwrite<1>(out, hcadd(sum) + sums); - } - const auto halfradix = radix / 2; - const auto halfradix_sqr = halfradix * halfradix; - CMT_ASSUME(halfradix > 0); - size_t i = 0; - - generic_butterfly_cycle(csize_t<width>(), radix, cbool_t<inverse>(), out, in, ostride, halfradix, - halfradix * halfradix, twiddle, i); -} - -template <size_t width, size_t radix, typename T, bool inverse, typename Tstride = csize_t<1>> -KFR_INTRIN void spec_generic_butterfly_w(csize_t<radix>, cbool_t<inverse>, complex<T>* out, - const complex<T>* in, const complex<T>* twiddle, - Tstride ostride = Tstride{}) -{ - { - cvec<T, width> sum = T(); - size_t j = 0; - CMT_LOOP_UNROLL - for (; j < radix / width * width; j += width) - { - sum += cread<width>(in + j); - } - cvec<T, 1> sums = T(); - CMT_LOOP_UNROLL - for (; j < radix; j++) - { - sums += cread<1>(in + j); - } - cwrite<1>(out, hcadd(sum) + sums); - } - const size_t halfradix = radix / 2; - const size_t halfradix_sqr = halfradix * halfradix; - CMT_ASSUME(halfradix > 0); - size_t i = 0; - - generic_butterfly_cycle(csize_t<width>(), radix, cbool_t<inverse>(), out, in, ostride, halfradix, - halfradix_sqr, twiddle, i); -} - -template <typename T, bool inverse, typename Tstride = csize_t<1>> -KFR_INTRIN void generic_butterfly(size_t radix, cbool_t<inverse>, complex<T>* out, const complex<T>* in, - complex<T>* temp, const complex<T>* twiddle, Tstride ostride = Tstride{}) -{ - constexpr size_t width = platform<T>::vector_width; - - cswitch(csizes_t<11, 13>(), radix, - [&](auto radix_) CMT_INLINE_LAMBDA { - spec_generic_butterfly_w<width>(radix_, cbool_t<inverse>(), out, in, twiddle, ostride); - }, - [&]() CMT_INLINE_LAMBDA { - generic_butterfly_w<width>(radix, cbool_t<inverse>(), out, in, twiddle, ostride); - }); -} - -template <typename T, size_t N> -constexpr cvec<T, N> cmask08 = broadcast<N * 2, T>(T(), -T()); - -template <typename T, size_t N> -constexpr cvec<T, N> cmask0088 = broadcast<N * 4, T>(T(), T(), -T(), -T()); - -template <bool A = false, typename T, size_t N> -KFR_INTRIN void cbitreverse_write(complex<T>* dest, const vec<T, N>& x) -{ - cwrite<N / 2, A>(dest, bitreverse<2>(x)); -} - -template <bool A = false, typename T, size_t N> -KFR_INTRIN void cdigitreverse4_write(complex<T>* dest, const vec<T, N>& x) -{ - cwrite<N / 2, A>(dest, digitreverse4<2>(x)); -} - -template <size_t N, bool A = false, typename T> -KFR_INTRIN cvec<T, N> cbitreverse_read(const complex<T>* src) -{ - return bitreverse<2>(cread<N, A>(src)); -} - -template <size_t N, bool A = false, typename T> -KFR_INTRIN cvec<T, N> cdigitreverse4_read(const complex<T>* src) -{ - return digitreverse4<2>(cread<N, A>(src)); -} - -#if 1 - -template <> -KFR_INTRIN cvec<f64, 16> cdigitreverse4_read<16, false, f64>(const complex<f64>* src) -{ - return concat(cread<1>(src + 0), cread<1>(src + 4), cread<1>(src + 8), cread<1>(src + 12), - cread<1>(src + 1), cread<1>(src + 5), cread<1>(src + 9), cread<1>(src + 13), - cread<1>(src + 2), cread<1>(src + 6), cread<1>(src + 10), cread<1>(src + 14), - cread<1>(src + 3), cread<1>(src + 7), cread<1>(src + 11), cread<1>(src + 15)); -} -template <> -KFR_INTRIN void cdigitreverse4_write<false, f64, 32>(complex<f64>* dest, const vec<f64, 32>& x) -{ - cwrite<1>(dest, part<16, 0>(x)); - cwrite<1>(dest + 4, part<16, 1>(x)); - cwrite<1>(dest + 8, part<16, 2>(x)); - cwrite<1>(dest + 12, part<16, 3>(x)); - - cwrite<1>(dest + 1, part<16, 4>(x)); - cwrite<1>(dest + 5, part<16, 5>(x)); - cwrite<1>(dest + 9, part<16, 6>(x)); - cwrite<1>(dest + 13, part<16, 7>(x)); - - cwrite<1>(dest + 2, part<16, 8>(x)); - cwrite<1>(dest + 6, part<16, 9>(x)); - cwrite<1>(dest + 10, part<16, 10>(x)); - cwrite<1>(dest + 14, part<16, 11>(x)); - - cwrite<1>(dest + 3, part<16, 12>(x)); - cwrite<1>(dest + 7, part<16, 13>(x)); - cwrite<1>(dest + 11, part<16, 14>(x)); - cwrite<1>(dest + 15, part<16, 15>(x)); -} -#endif -} // namespace internal -} // namespace kfr - -CMT_PRAGMA_MSVC(warning(pop)) diff --git a/include/kfr/dft/impl/bitrev.hpp b/include/kfr/dft/impl/bitrev.hpp @@ -0,0 +1,390 @@ +/** @addtogroup dft + * @{ + */ +/* + Copyright (C) 2016 D Levin (https://www.kfrlib.com) + This file is part of KFR + + KFR is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + KFR is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with KFR. + + If GPL is not suitable for your project, you must purchase a commercial license to use KFR. + Buying a commercial license is mandatory as soon as you develop commercial activities without + disclosing the source code of your own applications. + See https://www.kfrlib.com for details. + */ +#pragma once + +#include "../../base/complex.hpp" +#include "../../base/constants.hpp" +#include "../../base/digitreverse.hpp" +#include "../../base/vec.hpp" + +#include "../../data/bitrev.hpp" + +#include "ft.hpp" + +namespace kfr +{ + +namespace internal +{ + +constexpr bool fft_reorder_aligned = false; + +template <size_t Bits> +CMT_GNU_CONSTEXPR inline u32 bitrev_using_table(u32 x) +{ + constexpr size_t bitrev_table_log2N = ilog2(arraysize(data::bitrev_table)); + if (Bits > bitrev_table_log2N) + return bitreverse<Bits>(x); + + return data::bitrev_table[x] >> (bitrev_table_log2N - Bits); +} + +CMT_GNU_CONSTEXPR inline u32 bitrev_using_table(u32 x, size_t bits) +{ + constexpr size_t bitrev_table_log2N = ilog2(arraysize(data::bitrev_table)); + if (bits > bitrev_table_log2N) + return bitreverse<32>(x) >> (32 - bits); + + return data::bitrev_table[x] >> (bitrev_table_log2N - bits); +} + +CMT_GNU_CONSTEXPR inline u32 dig4rev_using_table(u32 x, size_t bits) +{ + constexpr size_t bitrev_table_log2N = ilog2(arraysize(data::bitrev_table)); + if (bits > bitrev_table_log2N) + return digitreverse4<32>(x) >> (32 - bits); + + x = data::bitrev_table[x]; + x = (((x & 0xaaaaaaaa) >> 1) | ((x & 0x55555555) << 1)); + x = x >> (bitrev_table_log2N - bits); + return x; +} + +template <size_t log2n, size_t bitrev, typename T> +KFR_INTRIN void fft_reorder_swap(T* inout, size_t i) +{ + using cxx = cvec<T, 16>; + constexpr size_t N = 1 << log2n; + constexpr size_t N4 = 2 * N / 4; + + cxx vi = cread_group<4, 4, N4 / 2, fft_reorder_aligned>(ptr_cast<complex<T>>(inout + i)); + vi = digitreverse<bitrev, 2>(vi); + cwrite_group<4, 4, N4 / 2, fft_reorder_aligned>(ptr_cast<complex<T>>(inout + i), vi); +} + +template <size_t log2n, size_t bitrev, typename T> +KFR_INTRIN void fft_reorder_swap_two(T* inout, size_t i, size_t j) +{ + CMT_ASSUME(i != j); + using cxx = cvec<T, 16>; + constexpr size_t N = 1 << log2n; + constexpr size_t N4 = 2 * N / 4; + + cxx vi = cread_group<4, 4, N4 / 2, fft_reorder_aligned>(ptr_cast<complex<T>>(inout + i)); + cxx vj = cread_group<4, 4, N4 / 2, fft_reorder_aligned>(ptr_cast<complex<T>>(inout + j)); + + vi = digitreverse<bitrev, 2>(vi); + cwrite_group<4, 4, N4 / 2, fft_reorder_aligned>(ptr_cast<complex<T>>(inout + i), vi); + vj = digitreverse<bitrev, 2>(vj); + cwrite_group<4, 4, N4 / 2, fft_reorder_aligned>(ptr_cast<complex<T>>(inout + j), vj); +} + +template <size_t log2n, size_t bitrev, typename T> +KFR_INTRIN void fft_reorder_swap(T* inout, size_t i, size_t j) +{ + CMT_ASSUME(i != j); + using cxx = cvec<T, 16>; + constexpr size_t N = 1 << log2n; + constexpr size_t N4 = 2 * N / 4; + + cxx vi = cread_group<4, 4, N4 / 2, fft_reorder_aligned>(ptr_cast<complex<T>>(inout + i)); + cxx vj = cread_group<4, 4, N4 / 2, fft_reorder_aligned>(ptr_cast<complex<T>>(inout + j)); + + vi = digitreverse<bitrev, 2>(vi); + cwrite_group<4, 4, N4 / 2, fft_reorder_aligned>(ptr_cast<complex<T>>(inout + j), vi); + vj = digitreverse<bitrev, 2>(vj); + cwrite_group<4, 4, N4 / 2, fft_reorder_aligned>(ptr_cast<complex<T>>(inout + i), vj); +} + +template <size_t log2n, size_t bitrev, typename T> +KFR_INTRIN void fft_reorder_swap(complex<T>* inout, size_t i) +{ + fft_reorder_swap<log2n, bitrev>(ptr_cast<T>(inout), i * 2); +} + +template <size_t log2n, size_t bitrev, typename T> +KFR_INTRIN void fft_reorder_swap_two(complex<T>* inout, size_t i0, size_t i1) +{ + fft_reorder_swap_two<log2n, bitrev>(ptr_cast<T>(inout), i0 * 2, i1 * 2); +} + +template <size_t log2n, size_t bitrev, typename T> +KFR_INTRIN void fft_reorder_swap(complex<T>* inout, size_t i, size_t j) +{ + fft_reorder_swap<log2n, bitrev>(ptr_cast<T>(inout), i * 2, j * 2); +} + +template <typename T> +KFR_INTRIN void fft_reorder(complex<T>* inout, csize_t<11>) +{ + fft_reorder_swap_two<11>(inout, 0 * 4, 8 * 4); + fft_reorder_swap<11>(inout, 1 * 4, 64 * 4); + fft_reorder_swap<11>(inout, 2 * 4, 32 * 4); + fft_reorder_swap<11>(inout, 3 * 4, 96 * 4); + fft_reorder_swap<11>(inout, 4 * 4, 16 * 4); + fft_reorder_swap<11>(inout, 5 * 4, 80 * 4); + fft_reorder_swap<11>(inout, 6 * 4, 48 * 4); + fft_reorder_swap<11>(inout, 7 * 4, 112 * 4); + fft_reorder_swap<11>(inout, 9 * 4, 72 * 4); + fft_reorder_swap<11>(inout, 10 * 4, 40 * 4); + fft_reorder_swap<11>(inout, 11 * 4, 104 * 4); + fft_reorder_swap<11>(inout, 12 * 4, 24 * 4); + fft_reorder_swap<11>(inout, 13 * 4, 88 * 4); + fft_reorder_swap<11>(inout, 14 * 4, 56 * 4); + fft_reorder_swap<11>(inout, 15 * 4, 120 * 4); + fft_reorder_swap<11>(inout, 17 * 4, 68 * 4); + fft_reorder_swap<11>(inout, 18 * 4, 36 * 4); + fft_reorder_swap<11>(inout, 19 * 4, 100 * 4); + fft_reorder_swap_two<11>(inout, 20 * 4, 28 * 4); + fft_reorder_swap<11>(inout, 21 * 4, 84 * 4); + fft_reorder_swap<11>(inout, 22 * 4, 52 * 4); + fft_reorder_swap<11>(inout, 23 * 4, 116 * 4); + fft_reorder_swap<11>(inout, 25 * 4, 76 * 4); + fft_reorder_swap<11>(inout, 26 * 4, 44 * 4); + fft_reorder_swap<11>(inout, 27 * 4, 108 * 4); + fft_reorder_swap<11>(inout, 29 * 4, 92 * 4); + fft_reorder_swap<11>(inout, 30 * 4, 60 * 4); + fft_reorder_swap<11>(inout, 31 * 4, 124 * 4); + fft_reorder_swap<11>(inout, 33 * 4, 66 * 4); + fft_reorder_swap_two<11>(inout, 34 * 4, 42 * 4); + fft_reorder_swap<11>(inout, 35 * 4, 98 * 4); + fft_reorder_swap<11>(inout, 37 * 4, 82 * 4); + fft_reorder_swap<11>(inout, 38 * 4, 50 * 4); + fft_reorder_swap<11>(inout, 39 * 4, 114 * 4); + fft_reorder_swap<11>(inout, 41 * 4, 74 * 4); + fft_reorder_swap<11>(inout, 43 * 4, 106 * 4); + fft_reorder_swap<11>(inout, 45 * 4, 90 * 4); + fft_reorder_swap<11>(inout, 46 * 4, 58 * 4); + fft_reorder_swap<11>(inout, 47 * 4, 122 * 4); + fft_reorder_swap<11>(inout, 49 * 4, 70 * 4); + fft_reorder_swap<11>(inout, 51 * 4, 102 * 4); + fft_reorder_swap<11>(inout, 53 * 4, 86 * 4); + fft_reorder_swap_two<11>(inout, 54 * 4, 62 * 4); + fft_reorder_swap<11>(inout, 55 * 4, 118 * 4); + fft_reorder_swap<11>(inout, 57 * 4, 78 * 4); + fft_reorder_swap<11>(inout, 59 * 4, 110 * 4); + fft_reorder_swap<11>(inout, 61 * 4, 94 * 4); + fft_reorder_swap<11>(inout, 63 * 4, 126 * 4); + fft_reorder_swap_two<11>(inout, 65 * 4, 73 * 4); + fft_reorder_swap<11>(inout, 67 * 4, 97 * 4); + fft_reorder_swap<11>(inout, 69 * 4, 81 * 4); + fft_reorder_swap<11>(inout, 71 * 4, 113 * 4); + fft_reorder_swap<11>(inout, 75 * 4, 105 * 4); + fft_reorder_swap<11>(inout, 77 * 4, 89 * 4); + fft_reorder_swap<11>(inout, 79 * 4, 121 * 4); + fft_reorder_swap<11>(inout, 83 * 4, 101 * 4); + fft_reorder_swap_two<11>(inout, 85 * 4, 93 * 4); + fft_reorder_swap<11>(inout, 87 * 4, 117 * 4); + fft_reorder_swap<11>(inout, 91 * 4, 109 * 4); + fft_reorder_swap<11>(inout, 95 * 4, 125 * 4); + fft_reorder_swap_two<11>(inout, 99 * 4, 107 * 4); + fft_reorder_swap<11>(inout, 103 * 4, 115 * 4); + fft_reorder_swap<11>(inout, 111 * 4, 123 * 4); + fft_reorder_swap_two<11>(inout, 119 * 4, 127 * 4); +} + +template <typename T> +KFR_INTRIN void fft_reorder(complex<T>* inout, csize_t<7>) +{ + constexpr size_t bitrev = 2; + fft_reorder_swap_two<7, bitrev>(inout, 0 * 4, 2 * 4); + fft_reorder_swap<7, bitrev>(inout, 1 * 4, 4 * 4); + fft_reorder_swap<7, bitrev>(inout, 3 * 4, 6 * 4); + fft_reorder_swap_two<7, bitrev>(inout, 5 * 4, 7 * 4); +} + +template <typename T> +KFR_INTRIN void fft_reorder(complex<T>* inout, csize_t<8>) +{ + constexpr size_t bitrev = 4; + fft_reorder_swap_two<8, bitrev>(inout, 0 * 4, 5 * 4); + fft_reorder_swap<8, bitrev>(inout, 1 * 4, 4 * 4); + fft_reorder_swap<8, bitrev>(inout, 2 * 4, 8 * 4); + fft_reorder_swap<8, bitrev>(inout, 3 * 4, 12 * 4); + fft_reorder_swap<8, bitrev>(inout, 6 * 4, 9 * 4); + fft_reorder_swap<8, bitrev>(inout, 7 * 4, 13 * 4); + fft_reorder_swap_two<8, bitrev>(inout, 10 * 4, 15 * 4); + fft_reorder_swap<8, bitrev>(inout, 11 * 4, 14 * 4); +} + +template <typename T> +KFR_INTRIN void fft_reorder(complex<T>* inout, csize_t<9>) +{ + constexpr size_t bitrev = 2; + fft_reorder_swap_two<9, bitrev>(inout, 0 * 4, 4 * 4); + fft_reorder_swap<9, bitrev>(inout, 1 * 4, 16 * 4); + fft_reorder_swap<9, bitrev>(inout, 2 * 4, 8 * 4); + fft_reorder_swap<9, bitrev>(inout, 3 * 4, 24 * 4); + fft_reorder_swap<9, bitrev>(inout, 5 * 4, 20 * 4); + fft_reorder_swap<9, bitrev>(inout, 6 * 4, 12 * 4); + fft_reorder_swap<9, bitrev>(inout, 7 * 4, 28 * 4); + fft_reorder_swap<9, bitrev>(inout, 9 * 4, 18 * 4); + fft_reorder_swap_two<9, bitrev>(inout, 10 * 4, 14 * 4); + fft_reorder_swap<9, bitrev>(inout, 11 * 4, 26 * 4); + fft_reorder_swap<9, bitrev>(inout, 13 * 4, 22 * 4); + fft_reorder_swap<9, bitrev>(inout, 15 * 4, 30 * 4); + fft_reorder_swap_two<9, bitrev>(inout, 17 * 4, 21 * 4); + fft_reorder_swap<9, bitrev>(inout, 19 * 4, 25 * 4); + fft_reorder_swap<9, bitrev>(inout, 23 * 4, 29 * 4); + fft_reorder_swap_two<9, bitrev>(inout, 27 * 4, 31 * 4); +} + +template <typename T, bool use_br2> +void cwrite_reordered(T* out, const cvec<T, 16>& value, size_t N4, cbool_t<use_br2>) +{ + cwrite_group<4, 4, fft_reorder_aligned>(ptr_cast<complex<T>>(out), N4, + digitreverse<(use_br2 ? 2 : 4), 2>(value)); +} + +template <typename T, bool use_br2> +KFR_INTRIN void fft_reorder_swap_n4(T* inout, size_t i, size_t j, size_t N4, cbool_t<use_br2>) +{ + CMT_ASSUME(i != j); + const cvec<T, 16> vi = cread_group<4, 4, fft_reorder_aligned>(ptr_cast<complex<T>>(inout + i), N4); + const cvec<T, 16> vj = cread_group<4, 4, fft_reorder_aligned>(ptr_cast<complex<T>>(inout + j), N4); + cwrite_reordered(inout + j, vi, N4, cbool_t<use_br2>()); + cwrite_reordered(inout + i, vj, N4, cbool_t<use_br2>()); +} + +template <typename T> +KFR_INTRIN void fft_reorder(complex<T>* inout, size_t log2n, ctrue_t use_br2) +{ + const size_t N = 1 << log2n; + const size_t N4 = N / 4; + const size_t iend = N / 16 * 4 * 2; + constexpr size_t istep = 2 * 4; + const size_t jstep1 = (1 << (log2n - 5)) * 4 * 2; + const size_t jstep2 = size_t(1 << (log2n - 5)) * 4 * 2 - size_t(1 << (log2n - 6)) * 4 * 2; + T* io = ptr_cast<T>(inout); + + for (size_t i = 0; i < iend;) + { + size_t j = bitrev_using_table(static_cast<u32>(i >> 3), log2n - 4) << 3; + if (i >= j) + fft_reorder_swap_n4(io, i, j, N4, use_br2); + i += istep; + j = j + jstep1; + + if (i >= j) + fft_reorder_swap_n4(io, i, j, N4, use_br2); + i += istep; + j = j - jstep2; + + if (i >= j) + fft_reorder_swap_n4(io, i, j, N4, use_br2); + i += istep; + j = j + jstep1; + + if (i >= j) + fft_reorder_swap_n4(io, i, j, N4, use_br2); + i += istep; + } +} + +template <typename T> +KFR_INTRIN void fft_reorder(complex<T>* inout, size_t log2n, cfalse_t use_br2) +{ + const size_t N = size_t(1) << log2n; + const size_t N4 = N / 4; + const size_t N16 = N * 2 / 16; + size_t iend = N16; + constexpr size_t istep = 2 * 4; + const size_t jstep = N / 64 * 4 * 2; + T* io = ptr_cast<T>(inout); + + size_t i = 0; + CMT_PRAGMA_CLANG(clang loop unroll_count(2)) + for (; i < iend;) + { + size_t j = dig4rev_using_table(static_cast<u32>(i >> 3), log2n - 4) << 3; + + if (i >= j) + fft_reorder_swap_n4(io, i, j, N4, use_br2); + i += istep * 4; + } + iend += N16; + CMT_PRAGMA_CLANG(clang loop unroll_count(2)) + for (; i < iend;) + { + size_t j = dig4rev_using_table(static_cast<u32>(i >> 3), log2n - 4) << 3; + + fft_reorder_swap_n4(io, i, j, N4, use_br2); + + i += istep; + j = j + jstep; + + if (i >= j) + fft_reorder_swap_n4(io, i, j, N4, use_br2); + i += istep * 3; + } + iend += N16; + CMT_PRAGMA_CLANG(clang loop unroll_count(2)) + for (; i < iend;) + { + size_t j = dig4rev_using_table(static_cast<u32>(i >> 3), log2n - 4) << 3; + + fft_reorder_swap_n4(io, i, j, N4, use_br2); + + i += istep; + j = j + jstep; + + fft_reorder_swap_n4(io, i, j, N4, use_br2); + + i += istep; + j = j + jstep; + + if (i >= j) + fft_reorder_swap_n4(io, i, j, N4, use_br2); + i += istep * 2; + } + iend += N16; + CMT_PRAGMA_CLANG(clang loop unroll_count(2)) + for (; i < iend;) + { + size_t j = dig4rev_using_table(static_cast<u32>(i >> 3), log2n - 4) << 3; + + fft_reorder_swap_n4(io, i, j, N4, use_br2); + + i += istep; + j = j + jstep; + + fft_reorder_swap_n4(io, i, j, N4, use_br2); + + i += istep; + j = j + jstep; + + fft_reorder_swap_n4(io, i, j, N4, use_br2); + + i += istep; + j = j + jstep; + + if (i >= j) + fft_reorder_swap_n4(io, i, j, N4, use_br2); + i += istep; + } +} +} +} diff --git a/include/kfr/dft/impl/convolution-impl.cpp b/include/kfr/dft/impl/convolution-impl.cpp @@ -0,0 +1,204 @@ +/** @addtogroup dft + * @{ + */ +/* + Copyright (C) 2016 D Levin (https://www.kfrlib.com) + This file is part of KFR + + KFR is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + KFR is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with KFR. + + If GPL is not suitable for your project, you must purchase a commercial license to use KFR. + Buying a commercial license is mandatory as soon as you develop commercial activities without + disclosing the source code of your own applications. + See https://www.kfrlib.com for details. + */ +#include "../convolution.hpp" + +namespace kfr +{ + +namespace internal +{ + +template <typename T> +univector<T> convolve(const univector_ref<const T>& src1, const univector_ref<const T>& src2) +{ + const size_t size = next_poweroftwo(src1.size() + src2.size() - 1); + univector<complex<T>> src1padded = src1; + univector<complex<T>> src2padded = src2; + src1padded.resize(size, 0); + src2padded.resize(size, 0); + + dft_plan_ptr<T> dft = dft_cache::instance().get(ctype_t<T>(), size); + univector<u8> temp(dft->temp_size); + dft->execute(src1padded, src1padded, temp); + dft->execute(src2padded, src2padded, temp); + src1padded = src1padded * src2padded; + dft->execute(src1padded, src1padded, temp, true); + const T invsize = reciprocal<T>(size); + return truncate(real(src1padded), src1.size() + src2.size() - 1) * invsize; +} + +template <typename T> +univector<T> correlate(const univector_ref<const T>& src1, const univector_ref<const T>& src2) +{ + const size_t size = next_poweroftwo(src1.size() + src2.size() - 1); + univector<complex<T>> src1padded = src1; + univector<complex<T>> src2padded = reverse(src2); + src1padded.resize(size, 0); + src2padded.resize(size, 0); + dft_plan_ptr<T> dft = dft_cache::instance().get(ctype_t<T>(), size); + univector<u8> temp(dft->temp_size); + dft->execute(src1padded, src1padded, temp); + dft->execute(src2padded, src2padded, temp); + src1padded = src1padded * src2padded; + dft->execute(src1padded, src1padded, temp, true); + const T invsize = reciprocal<T>(size); + return truncate(real(src1padded), src1.size() + src2.size() - 1) * invsize; +} + +template <typename T> +univector<T> autocorrelate(const univector_ref<const T>& src1) +{ + univector<T> result = correlate(src1, src1); + result = result.slice(result.size() / 2); + return result; +} + +} // namespace internal + +template <typename T> +convolve_filter<T>::convolve_filter(size_t size, size_t block_size) + : fft(2 * next_poweroftwo(block_size)), size(size), block_size(block_size), temp(fft.temp_size), + segments((size + block_size - 1) / block_size) +{ +} + +template <typename T> +convolve_filter<T>::convolve_filter(const univector<T>& data, size_t block_size) + : fft(2 * next_poweroftwo(block_size)), size(data.size()), block_size(next_poweroftwo(block_size)), + temp(fft.temp_size), + segments((data.size() + next_poweroftwo(block_size) - 1) / next_poweroftwo(block_size)), + ir_segments((data.size() + next_poweroftwo(block_size) - 1) / next_poweroftwo(block_size)), + input_position(0), position(0) +{ + set_data(data); +} + +template <typename T> +void convolve_filter<T>::set_data(const univector<T>& data) +{ + univector<T> input(fft.size); + const T ifftsize = reciprocal(T(fft.size)); + for (size_t i = 0; i < ir_segments.size(); i++) + { + segments[i].resize(block_size); + ir_segments[i].resize(block_size, 0); + input = padded(data.slice(i * block_size, block_size)); + + fft.execute(ir_segments[i], input, temp, dft_pack_format::Perm); + process(ir_segments[i], ir_segments[i] * ifftsize); + } + saved_input.resize(block_size, 0); + scratch.resize(block_size * 2); + premul.resize(block_size, 0); + cscratch.resize(block_size); + overlap.resize(block_size, 0); +} + +template <typename T> +void convolve_filter<T>::process_buffer(T* output, const T* input, size_t size) +{ + size_t processed = 0; + while (processed < size) + { + const size_t processing = std::min(size - processed, block_size - input_position); + internal::builtin_memcpy(saved_input.data() + input_position, input + processed, + processing * sizeof(T)); + + process(scratch, padded(saved_input)); + fft.execute(segments[position], scratch, temp, dft_pack_format::Perm); + + if (input_position == 0) + { + process(premul, zeros()); + for (size_t i = 1; i < segments.size(); i++) + { + const size_t n = (position + i) % segments.size(); + fft_multiply_accumulate(premul, ir_segments[i], segments[n], dft_pack_format::Perm); + } + } + fft_multiply_accumulate(cscratch, premul, ir_segments[0], segments[position], dft_pack_format::Perm); + + fft.execute(scratch, cscratch, temp, dft_pack_format::Perm); + + process(make_univector(output + processed, processing), + scratch.slice(input_position) + overlap.slice(input_position)); + + input_position += processing; + if (input_position == block_size) + { + input_position = 0; + process(saved_input, zeros()); + + internal::builtin_memcpy(overlap.data(), scratch.data() + block_size, block_size * sizeof(T)); + + position = position > 0 ? position - 1 : segments.size() - 1; + } + + processed += processing; + } +} + +namespace internal +{ + +template univector<float> convolve<float>(const univector_ref<const float>&, + const univector_ref<const float>&); +template univector<float> correlate<float>(const univector_ref<const float>&, + const univector_ref<const float>&); + +template univector<float> autocorrelate<float>(const univector_ref<const float>&); + +} // namespace internal + +template convolve_filter<float>::convolve_filter(size_t, size_t); + +template convolve_filter<float>::convolve_filter(const univector<float>&, size_t); + +template void convolve_filter<float>::set_data(const univector<float>&); + +template void convolve_filter<float>::process_buffer(float* output, const float* input, size_t size); + +namespace internal +{ + +template univector<double> convolve<double>(const univector_ref<const double>&, + const univector_ref<const double>&); +template univector<double> correlate<double>(const univector_ref<const double>&, + const univector_ref<const double>&); + +template univector<double> autocorrelate<double>(const univector_ref<const double>&); + +} // namespace internal + +template convolve_filter<double>::convolve_filter(size_t, size_t); + +template convolve_filter<double>::convolve_filter(const univector<double>&, size_t); + +template void convolve_filter<double>::set_data(const univector<double>&); + +template void convolve_filter<double>::process_buffer(double* output, const double* input, size_t size); + +} // namespace kfr diff --git a/include/kfr/dft/impl/dft-impl-f32.cpp b/include/kfr/dft/impl/dft-impl-f32.cpp @@ -0,0 +1,29 @@ +/** @addtogroup dft + * @{ + */ +/* + Copyright (C) 2016 D Levin (https://www.kfrlib.com) + This file is part of KFR + + KFR is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + KFR is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with KFR. + + If GPL is not suitable for your project, you must purchase a commercial license to use KFR. + Buying a commercial license is mandatory as soon as you develop commercial activities without + disclosing the source code of your own applications. + See https://www.kfrlib.com for details. + */ +#include "dft-impl.hpp" + +#define FLOAT float +#include "dft-templates.hpp" diff --git a/include/kfr/dft/impl/dft-impl-f64.cpp b/include/kfr/dft/impl/dft-impl-f64.cpp @@ -0,0 +1,30 @@ +/** @addtogroup dft + * @{ + */ +/* + Copyright (C) 2016 D Levin (https://www.kfrlib.com) + This file is part of KFR + + KFR is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + KFR is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with KFR. + + If GPL is not suitable for your project, you must purchase a commercial license to use KFR. + Buying a commercial license is mandatory as soon as you develop commercial activities without + disclosing the source code of your own applications. + See https://www.kfrlib.com for details. + */ +#include "dft-impl.hpp" + +#define FLOAT double +#include "dft-templates.hpp" + diff --git a/include/kfr/dft/impl/dft-impl.hpp b/include/kfr/dft/impl/dft-impl.hpp @@ -0,0 +1,1689 @@ +/** @addtogroup dft + * @{ + */ +/* + Copyright (C) 2016 D Levin (https://www.kfrlib.com) + This file is part of KFR + + KFR is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + KFR is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with KFR. + + If GPL is not suitable for your project, you must purchase a commercial license to use KFR. + Buying a commercial license is mandatory as soon as you develop commercial activities without + disclosing the source code of your own applications. + See https://www.kfrlib.com for details. + */ + +#include "../dft_c.h" + +#include "../../base/basic_expressions.hpp" +#include "../../testo/assert.hpp" +#include "bitrev.hpp" +#include "../cache.hpp" +#include "../fft.hpp" +#include "ft.hpp" + +CMT_PRAGMA_GNU(GCC diagnostic push) +#if CMT_HAS_WARNING("-Wshadow") +CMT_PRAGMA_GNU(GCC diagnostic ignored "-Wshadow") +#endif + +CMT_PRAGMA_MSVC(warning(push)) +CMT_PRAGMA_MSVC(warning(disable : 4100)) + +namespace kfr +{ + +constexpr csizes_t<2, 3, 4, 5, 6, 7, 8, 9, 10> dft_radices{}; + +#define DFT_ASSERT TESTO_ASSERT_INACTIVE + +template <typename T> +constexpr size_t fft_vector_width = platform<T>::vector_width; + +using cdirect_t = cfalse_t; +using cinvert_t = ctrue_t; + +template <typename T> +struct dft_stage +{ + size_t radix = 0; + size_t stage_size = 0; + size_t data_size = 0; + size_t temp_size = 0; + u8* data = nullptr; + size_t repeats = 1; + size_t out_offset = 0; + size_t blocks = 0; + const char* name = nullptr; + bool recursion = false; + bool can_inplace = true; + bool inplace = false; + bool to_scratch = false; + bool need_reorder = true; + + void initialize(size_t size) { do_initialize(size); } + + virtual void dump() const + { + printf("%s: \n\t%5zu,%5zu,%5zu,%5zu,%5zu,%5zu,%5zu, %d, %d, %d, %d\n", name ? name : "unnamed", radix, + stage_size, data_size, temp_size, repeats, out_offset, blocks, recursion, can_inplace, inplace, + to_scratch); + } + + KFR_INTRIN void execute(cdirect_t, complex<T>* out, const complex<T>* in, u8* temp) + { + do_execute(cdirect_t(), out, in, temp); + } + KFR_INTRIN void execute(cinvert_t, complex<T>* out, const complex<T>* in, u8* temp) + { + do_execute(cinvert_t(), out, in, temp); + } + virtual ~dft_stage() {} + +protected: + virtual void do_initialize(size_t) {} + virtual void do_execute(cdirect_t, complex<T>*, const complex<T>*, u8* temp) = 0; + virtual void do_execute(cinvert_t, complex<T>*, const complex<T>*, u8* temp) = 0; +}; + +#define DFT_STAGE_FN \ + void do_execute(cdirect_t, complex<T>* out, const complex<T>* in, u8* temp) override \ + { \ + return do_execute<false>(out, in, temp); \ + } \ + void do_execute(cinvert_t, complex<T>* out, const complex<T>* in, u8* temp) override \ + { \ + return do_execute<true>(out, in, temp); \ + } + +CMT_PRAGMA_GNU(GCC diagnostic push) +#if CMT_HAS_WARNING("-Wassume") +CMT_PRAGMA_GNU(GCC diagnostic ignored "-Wassume") +#endif + +namespace internal +{ + +template <size_t width, bool inverse, typename T> +KFR_SINTRIN cvec<T, width> radix4_apply_twiddle(csize_t<width>, cfalse_t /*split_format*/, cbool_t<inverse>, + const cvec<T, width>& w, const cvec<T, width>& tw) +{ + cvec<T, width> ww = w; + cvec<T, width> tw_ = tw; + cvec<T, width> b1 = ww * dupeven(tw_); + ww = swap<2>(ww); + + if (inverse) + tw_ = -(tw_); + ww = subadd(b1, ww * dupodd(tw_)); + return ww; +} + +template <size_t width, bool use_br2, bool inverse, bool aligned, typename T> +KFR_SINTRIN void radix4_body(size_t N, csize_t<width>, cfalse_t, cfalse_t, cfalse_t, cbool_t<use_br2>, + cbool_t<inverse>, cbool_t<aligned>, complex<T>* out, const complex<T>* in, + const complex<T>* twiddle) +{ + const size_t N4 = N / 4; + cvec<T, width> w1, w2, w3; + + cvec<T, width> sum02, sum13, diff02, diff13; + + cvec<T, width> a0, a1, a2, a3; + a0 = cread<width, aligned>(in + 0); + a2 = cread<width, aligned>(in + N4 * 2); + sum02 = a0 + a2; + + a1 = cread<width, aligned>(in + N4); + a3 = cread<width, aligned>(in + N4 * 3); + sum13 = a1 + a3; + + cwrite<width, aligned>(out, sum02 + sum13); + w2 = sum02 - sum13; + cwrite<width, aligned>(out + N4 * (use_br2 ? 1 : 2), + radix4_apply_twiddle(csize_t<width>(), cfalse, cbool_t<inverse>(), w2, + cread<width, true>(twiddle + width))); + diff02 = a0 - a2; + diff13 = a1 - a3; + if (inverse) + { + diff13 = (diff13 ^ broadcast<width * 2, T>(T(), -T())); + diff13 = swap<2>(diff13); + } + else + { + diff13 = swap<2>(diff13); + diff13 = (diff13 ^ broadcast<width * 2, T>(T(), -T())); + } + + w1 = diff02 + diff13; + + cwrite<width, aligned>(out + N4 * (use_br2 ? 2 : 1), + radix4_apply_twiddle(csize_t<width>(), cfalse, cbool_t<inverse>(), w1, + cread<width, true>(twiddle + 0))); + w3 = diff02 - diff13; + cwrite<width, aligned>(out + N4 * 3, radix4_apply_twiddle(csize_t<width>(), cfalse, cbool_t<inverse>(), + w3, cread<width, true>(twiddle + width * 2))); +} + +template <size_t width, bool inverse, typename T> +KFR_SINTRIN cvec<T, width> radix4_apply_twiddle(csize_t<width>, ctrue_t /*split_format*/, cbool_t<inverse>, + const cvec<T, width>& w, const cvec<T, width>& tw) +{ + vec<T, width> re1, im1, twre, twim; + split(w, re1, im1); + split(tw, twre, twim); + + const vec<T, width> b1re = re1 * twre; + const vec<T, width> b1im = im1 * twre; + if (inverse) + return concat(b1re + im1 * twim, b1im - re1 * twim); + else + return concat(b1re - im1 * twim, b1im + re1 * twim); +} + +template <size_t width, bool splitout, bool splitin, bool use_br2, bool inverse, bool aligned, typename T> +KFR_SINTRIN void radix4_body(size_t N, csize_t<width>, ctrue_t, cbool_t<splitout>, cbool_t<splitin>, + cbool_t<use_br2>, cbool_t<inverse>, cbool_t<aligned>, complex<T>* out, + const complex<T>* in, const complex<T>* twiddle) +{ + const size_t N4 = N / 4; + cvec<T, width> w1, w2, w3; + constexpr bool read_split = !splitin && splitout; + constexpr bool write_split = splitin && !splitout; + + vec<T, width> re0, im0, re1, im1, re2, im2, re3, im3; + + split(cread_split<width, aligned, read_split>(in + N4 * 0), re0, im0); + split(cread_split<width, aligned, read_split>(in + N4 * 1), re1, im1); + split(cread_split<width, aligned, read_split>(in + N4 * 2), re2, im2); + split(cread_split<width, aligned, read_split>(in + N4 * 3), re3, im3); + + const vec<T, width> sum02re = re0 + re2; + const vec<T, width> sum02im = im0 + im2; + const vec<T, width> sum13re = re1 + re3; + const vec<T, width> sum13im = im1 + im3; + + cwrite_split<width, aligned, write_split>(out, concat(sum02re + sum13re, sum02im + sum13im)); + w2 = concat(sum02re - sum13re, sum02im - sum13im); + cwrite_split<width, aligned, write_split>( + out + N4 * (use_br2 ? 1 : 2), radix4_apply_twiddle(csize_t<width>(), ctrue, cbool_t<inverse>(), w2, + cread<width, true>(twiddle + width))); + + const vec<T, width> diff02re = re0 - re2; + const vec<T, width> diff02im = im0 - im2; + const vec<T, width> diff13re = re1 - re3; + const vec<T, width> diff13im = im1 - im3; + + (inverse ? w1 : w3) = concat(diff02re - diff13im, diff02im + diff13re); + (inverse ? w3 : w1) = concat(diff02re + diff13im, diff02im - diff13re); + + cwrite_split<width, aligned, write_split>( + out + N4 * (use_br2 ? 2 : 1), radix4_apply_twiddle(csize_t<width>(), ctrue, cbool_t<inverse>(), w1, + cread<width, true>(twiddle + 0))); + cwrite_split<width, aligned, write_split>( + out + N4 * 3, radix4_apply_twiddle(csize_t<width>(), ctrue, cbool_t<inverse>(), w3, + cread<width, true>(twiddle + width * 2))); +} + +template <typename T> +CMT_NOINLINE cvec<T, 1> calculate_twiddle(size_t n, size_t size) +{ + if (n == 0) + { + return make_vector(static_cast<T>(1), static_cast<T>(0)); + } + else if (n == size / 4) + { + return make_vector(static_cast<T>(0), static_cast<T>(-1)); + } + else if (n == size / 2) + { + return make_vector(static_cast<T>(-1), static_cast<T>(0)); + } + else if (n == size * 3 / 4) + { + return make_vector(static_cast<T>(0), static_cast<T>(1)); + } + else + { + fbase kth = c_pi<fbase, 2> * (n / static_cast<fbase>(size)); + fbase tcos = +kfr::cos(kth); + fbase tsin = -kfr::sin(kth); + return make_vector(static_cast<T>(tcos), static_cast<T>(tsin)); + } +} + +template <typename T, size_t width> +KFR_SINTRIN void initialize_twiddles_impl(complex<T>*& twiddle, size_t nn, size_t nnstep, size_t size, + bool split_format) +{ + vec<T, 2 * width> result = T(); + CMT_LOOP_UNROLL + for (size_t i = 0; i < width; i++) + { + const cvec<T, 1> r = calculate_twiddle<T>(nn + nnstep * i, size); + result[i * 2] = r[0]; + result[i * 2 + 1] = r[1]; + } + if (split_format) + ref_cast<cvec<T, width>>(twiddle[0]) = splitpairs(result); + else + ref_cast<cvec<T, width>>(twiddle[0]) = result; + twiddle += width; +} + +template <typename T, size_t width> +CMT_NOINLINE void initialize_twiddles(complex<T>*& twiddle, size_t stage_size, size_t size, bool split_format) +{ + const size_t count = stage_size / 4; + size_t nnstep = size / stage_size; + DFT_ASSERT(width <= count); + CMT_LOOP_NOUNROLL + for (size_t n = 0; n < count; n += width) + { + initialize_twiddles_impl<T, width>(twiddle, n * nnstep * 1, nnstep * 1, size, split_format); + initialize_twiddles_impl<T, width>(twiddle, n * nnstep * 2, nnstep * 2, size, split_format); + initialize_twiddles_impl<T, width>(twiddle, n * nnstep * 3, nnstep * 3, size, split_format); + } +} + +#if defined CMT_ARCH_SSE +#ifdef CMT_COMPILER_GNU +#define KFR_PREFETCH(addr) __builtin_prefetch(::kfr::ptr_cast<void>(addr), 0, _MM_HINT_T0); +#else +#define KFR_PREFETCH(addr) _mm_prefetch(::kfr::ptr_cast<char>(addr), _MM_HINT_T0); +#endif +#else +#define KFR_PREFETCH(addr) __builtin_prefetch(::kfr::ptr_cast<void>(addr)); +#endif + +template <typename T> +KFR_SINTRIN void prefetch_one(const complex<T>* in) +{ + KFR_PREFETCH(in); +} + +template <typename T> +KFR_SINTRIN void prefetch_four(size_t stride, const complex<T>* in) +{ + KFR_PREFETCH(in); + KFR_PREFETCH(in + stride); + KFR_PREFETCH(in + stride * 2); + KFR_PREFETCH(in + stride * 3); +} + +template <typename Ntype, size_t width, bool splitout, bool splitin, bool prefetch, bool use_br2, + bool inverse, bool aligned, typename T> +KFR_SINTRIN cfalse_t radix4_pass(Ntype N, size_t blocks, csize_t<width>, cbool_t<splitout>, cbool_t<splitin>, + cbool_t<use_br2>, cbool_t<prefetch>, cbool_t<inverse>, cbool_t<aligned>, + complex<T>* out, const complex<T>* in, const complex<T>*& twiddle) +{ + constexpr static size_t prefetch_offset = width * 8; + const auto N4 = N / csize_t<4>(); + const auto N43 = N4 * csize_t<3>(); + CMT_ASSUME(blocks > 0); + CMT_ASSUME(N > 0); + CMT_ASSUME(N4 > 0); + DFT_ASSERT(width <= N4); + CMT_LOOP_NOUNROLL for (size_t b = 0; b < blocks; b++) + { + CMT_PRAGMA_CLANG(clang loop unroll_count(2)) + for (size_t n2 = 0; n2 < N4; n2 += width) + { + if (prefetch) + prefetch_four(N4, in + prefetch_offset); + radix4_body(N, csize_t<width>(), cbool_t<(splitout || splitin)>(), cbool_t<splitout>(), + cbool_t<splitin>(), cbool_t<use_br2>(), cbool_t<inverse>(), cbool_t<aligned>(), out, + in, twiddle + n2 * 3); + in += width; + out += width; + } + in += N43; + out += N43; + } + twiddle += N43; + return {}; +} + +template <bool splitin, bool prefetch, bool use_br2, bool inverse, bool aligned, typename T> +KFR_SINTRIN ctrue_t radix4_pass(csize_t<32>, size_t blocks, csize_t<8>, cfalse_t, cbool_t<splitin>, + cbool_t<use_br2>, cbool_t<prefetch>, cbool_t<inverse>, cbool_t<aligned>, + complex<T>* out, const complex<T>*, const complex<T>*& /*twiddle*/) +{ + CMT_ASSUME(blocks > 0); + constexpr static size_t prefetch_offset = 32 * 4; + for (size_t b = 0; b < blocks; b++) + { + if (prefetch) + prefetch_four(csize_t<64>(), out + prefetch_offset); + cvec<T, 4> w0, w1, w2, w3, w4, w5, w6, w7; + split(cread_split<8, aligned, splitin>(out + 0), w0, w1); + split(cread_split<8, aligned, splitin>(out + 8), w2, w3); + split(cread_split<8, aligned, splitin>(out + 16), w4, w5); + split(cread_split<8, aligned, splitin>(out + 24), w6, w7); + + butterfly8<4, inverse>(w0, w1, w2, w3, w4, w5, w6, w7); + + w1 = cmul(w1, fixed_twiddle<T, 4, 32, 0, 1, inverse>()); + w2 = cmul(w2, fixed_twiddle<T, 4, 32, 0, 2, inverse>()); + w3 = cmul(w3, fixed_twiddle<T, 4, 32, 0, 3, inverse>()); + w4 = cmul(w4, fixed_twiddle<T, 4, 32, 0, 4, inverse>()); + w5 = cmul(w5, fixed_twiddle<T, 4, 32, 0, 5, inverse>()); + w6 = cmul(w6, fixed_twiddle<T, 4, 32, 0, 6, inverse>()); + w7 = cmul(w7, fixed_twiddle<T, 4, 32, 0, 7, inverse>()); + + cvec<T, 8> z0, z1, z2, z3; + transpose4x8(w0, w1, w2, w3, w4, w5, w6, w7, z0, z1, z2, z3); + + butterfly4<8, inverse>(cfalse, z0, z1, z2, z3, z0, z1, z2, z3); + cwrite<32, aligned>(out, bitreverse<2>(concat(z0, z1, z2, z3))); + out += 32; + } + return {}; +} + +template <size_t width, bool prefetch, bool use_br2, bool inverse, bool aligned, typename T> +KFR_SINTRIN ctrue_t radix4_pass(csize_t<8>, size_t blocks, csize_t<width>, cfalse_t, cfalse_t, + cbool_t<use_br2>, cbool_t<prefetch>, cbool_t<inverse>, cbool_t<aligned>, + complex<T>* out, const complex<T>*, const complex<T>*& /*twiddle*/) +{ + CMT_ASSUME(blocks > 0); + DFT_ASSERT(2 <= blocks); + constexpr static size_t prefetch_offset = width * 16; + for (size_t b = 0; b < blocks; b += 2) + { + if (prefetch) + prefetch_one(out + prefetch_offset); + + cvec<T, 8> vlo = cread<8, aligned>(out + 0); + cvec<T, 8> vhi = cread<8, aligned>(out + 8); + butterfly8<inverse>(vlo); + butterfly8<inverse>(vhi); + vlo = permutegroups<(2), 0, 4, 2, 6, 1, 5, 3, 7>(vlo); + vhi = permutegroups<(2), 0, 4, 2, 6, 1, 5, 3, 7>(vhi); + cwrite<8, aligned>(out, vlo); + cwrite<8, aligned>(out + 8, vhi); + out += 16; + } + return {}; +} + +template <size_t width, bool prefetch, bool use_br2, bool inverse, bool aligned, typename T> +KFR_SINTRIN ctrue_t radix4_pass(csize_t<16>, size_t blocks, csize_t<width>, cfalse_t, cfalse_t, + cbool_t<use_br2>, cbool_t<prefetch>, cbool_t<inverse>, cbool_t<aligned>, + complex<T>* out, const complex<T>*, const complex<T>*& /*twiddle*/) +{ + CMT_ASSUME(blocks > 0); + constexpr static size_t prefetch_offset = width * 4; + DFT_ASSERT(2 <= blocks); + CMT_PRAGMA_CLANG(clang loop unroll_count(2)) + for (size_t b = 0; b < blocks; b += 2) + { + if (prefetch) + prefetch_one(out + prefetch_offset); + + cvec<T, 16> vlo = cread<16, aligned>(out); + cvec<T, 16> vhi = cread<16, aligned>(out + 16); + butterfly4<4, inverse>(vlo); + butterfly4<4, inverse>(vhi); + apply_twiddles4<0, 4, 4, inverse>(vlo); + apply_twiddles4<0, 4, 4, inverse>(vhi); + vlo = digitreverse4<2>(vlo); + vhi = digitreverse4<2>(vhi); + butterfly4<4, inverse>(vlo); + butterfly4<4, inverse>(vhi); + + use_br2 ? cbitreverse_write(out, vlo) : cdigitreverse4_write(out, vlo); + use_br2 ? cbitreverse_write(out + 16, vhi) : cdigitreverse4_write(out + 16, vhi); + out += 32; + } + return {}; +} + +template <size_t width, bool prefetch, bool use_br2, bool inverse, bool aligned, typename T> +KFR_SINTRIN ctrue_t radix4_pass(csize_t<4>, size_t blocks, csize_t<width>, cfalse_t, cfalse_t, + cbool_t<use_br2>, cbool_t<prefetch>, cbool_t<inverse>, cbool_t<aligned>, + complex<T>* out, const complex<T>*, const complex<T>*& /*twiddle*/) +{ + constexpr static size_t prefetch_offset = width * 4; + CMT_ASSUME(blocks > 0); + DFT_ASSERT(4 <= blocks); + CMT_LOOP_NOUNROLL + for (size_t b = 0; b < blocks; b += 4) + { + if (prefetch) + prefetch_one(out + prefetch_offset); + + cvec<T, 16> v16 = cdigitreverse4_read<16, aligned>(out); + butterfly4<4, inverse>(v16); + cdigitreverse4_write<aligned>(out, v16); + + out += 4 * 4; + } + return {}; +} + +template <typename T> +static void dft_stage_fixed_initialize(dft_stage<T>* stage, size_t width) +{ + complex<T>* twiddle = ptr_cast<complex<T>>(stage->data); + const size_t N = stage->repeats * stage->radix; + const size_t Nord = stage->repeats; + size_t i = 0; + + while (width > 0) + { + CMT_LOOP_NOUNROLL + for (; i < Nord / width * width; i += width) + { + CMT_LOOP_NOUNROLL + for (size_t j = 1; j < stage->radix; j++) + { + CMT_LOOP_NOUNROLL + for (size_t k = 0; k < width; k++) + { + cvec<T, 1> xx = cossin_conj(broadcast<2, T>(c_pi<T, 2> * (i + k) * j / N)); + ref_cast<cvec<T, 1>>(twiddle[k]) = xx; + } + twiddle += width; + } + } + width = width / 2; + } +} + +template <typename T, size_t radix> +struct dft_stage_fixed_impl : dft_stage<T> +{ + dft_stage_fixed_impl(size_t radix_, size_t iterations, size_t blocks) + { + this->name = type_name<decltype(*this)>(); + this->radix = radix; + this->blocks = blocks; + this->repeats = iterations; + this->recursion = false; // true; + this->data_size = + align_up((this->repeats * (radix - 1)) * sizeof(complex<T>), platform<>::native_cache_alignment); + } + + constexpr static size_t width = + radix >= 7 ? fft_vector_width<T> / 2 : radix >= 4 ? fft_vector_width<T> : fft_vector_width<T> * 2; + virtual void do_initialize(size_t size) override final { dft_stage_fixed_initialize(this, width); } + + DFT_STAGE_FN + template <bool inverse> + KFR_INTRIN void do_execute(complex<T>* out, const complex<T>* in, u8*) + { + const size_t Nord = this->repeats; + const complex<T>* twiddle = ptr_cast<complex<T>>(this->data); + + const size_t N = Nord * this->radix; + CMT_LOOP_NOUNROLL + for (size_t b = 0; b < this->blocks; b++) + { + butterflies(Nord, csize<width>, csize<radix>, cbool<inverse>, out, in, twiddle, Nord); + in += N; + out += N; + } + } +}; + +template <typename T, size_t radix> +struct dft_stage_fixed_final_impl : dft_stage<T> +{ + dft_stage_fixed_final_impl(size_t radix_, size_t iterations, size_t blocks) + { + this->name = type_name<decltype(*this)>(); + this->radix = radix; + this->blocks = blocks; + this->repeats = iterations; + this->recursion = false; + this->can_inplace = false; + } + constexpr static size_t width = + radix >= 7 ? fft_vector_width<T> / 2 : radix >= 4 ? fft_vector_width<T> : fft_vector_width<T> * 2; + + DFT_STAGE_FN + template <bool inverse> + KFR_INTRIN void do_execute(complex<T>* out, const complex<T>* in, u8*) + { + const size_t b = this->blocks; + const size_t size = b * radix; + + butterflies(b, csize<width>, csize<radix>, cbool<inverse>, out, in, b); + } +}; + +template <typename E> +inline E& apply_conj(E& e, cfalse_t) +{ + return e; +} + +template <typename E> +inline auto apply_conj(E& e, ctrue_t) +{ + return cconj(e); +} + +/// [0, N - 1, N - 2, N - 3, ..., 3, 2, 1] +template <typename E> +struct fft_inverse : expression_base<E> +{ + using value_type = value_type_of<E>; + + CMT_INLINE fft_inverse(E&& expr) noexcept : expression_base<E>(std::forward<E>(expr)) {} + + CMT_INLINE vec<value_type, 1> operator()(cinput_t input, size_t index, vec_t<value_type, 1>) const + { + return this->argument_first(input, index == 0 ? 0 : this->size() - index, vec_t<value_type, 1>()); + } + + template <size_t N> + CMT_INLINE vec<value_type, N> operator()(cinput_t input, size_t index, vec_t<value_type, N>) const + { + if (index == 0) + { + return concat( + this->argument_first(input, index, vec_t<value_type, 1>()), + reverse(this->argument_first(input, this->size() - (N - 1), vec_t<value_type, N - 1>()))); + } + return reverse(this->argument_first(input, this->size() - index - (N - 1), vec_t<value_type, N>())); + } +}; + +template <typename E> +inline auto apply_fft_inverse(E&& e) +{ + return fft_inverse<E>(std::forward<E>(e)); +} + +template <typename T> +struct dft_arblen_stage_impl : dft_stage<T> +{ + dft_arblen_stage_impl(size_t size) + : fftsize(next_poweroftwo(size) * 2), plan(fftsize, dft_order::internal), size(size) + { + this->name = type_name<decltype(*this)>(); + this->radix = size; + this->blocks = 1; + this->repeats = 1; + this->recursion = false; + this->can_inplace = false; + this->temp_size = plan.temp_size; + + chirp_ = render(cexp(sqr(linspace(T(1) - size, size - T(1), size * 2 - 1, true, true)) * + complex<T>(0, -1) * c_pi<T> / size)); + + ichirpp_ = render(truncate(padded(1 / slice(chirp_, 0, 2 * size - 1)), fftsize)); + + univector<u8> temp(plan.temp_size); + plan.execute(ichirpp_, ichirpp_, temp); + xp.resize(fftsize, 0); + xp_fft.resize(fftsize); + invN2 = T(1) / fftsize; + } + + DFT_STAGE_FN + template <bool inverse> + KFR_INTRIN void do_execute(complex<T>* out, const complex<T>* in, u8* temp) + { + const size_t n = this->size; + const size_t N2 = this->fftsize; + + auto&& chirp = apply_conj(chirp_, cbool<inverse>); + + xp.slice(0, n) = make_univector(in, n) * slice(chirp, n - 1); + + plan.execute(xp_fft.data(), xp.data(), temp); + + if (inverse) + xp_fft = xp_fft * cconj(apply_fft_inverse(ichirpp_)); + else + xp_fft = xp_fft * ichirpp_; + plan.execute(xp_fft.data(), xp_fft.data(), temp, ctrue); + + make_univector(out, n) = xp_fft.slice(n - 1) * slice(chirp, n - 1) * invN2; + } + + const size_t size; + const size_t fftsize; + T invN2; + dft_plan<T> plan; + univector<complex<T>> chirp_; + univector<complex<T>> ichirpp_; + univector<complex<T>> xp; + univector<complex<T>> xp_fft; +}; + +template <typename T, size_t radix1, size_t radix2, size_t size = radix1* radix2> +struct dft_special_stage_impl : dft_stage<T> +{ + dft_special_stage_impl() : stage1(radix1, size / radix1, 1), stage2(radix2, 1, size / radix2) + { + this->name = type_name<decltype(*this)>(); + this->radix = size; + this->blocks = 1; + this->repeats = 1; + this->recursion = false; + this->can_inplace = false; + this->temp_size = stage1.temp_size + stage2.temp_size + sizeof(complex<T>) * size; + this->data_size = stage1.data_size + stage2.data_size; + } + void dump() const override + { + dft_stage<T>::dump(); + printf(" "); + stage1.dump(); + printf(" "); + stage2.dump(); + } + void do_initialize(size_t stage_size) override + { + stage1.data = this->data; + stage2.data = this->data + stage1.data_size; + stage1.initialize(stage_size); + stage2.initialize(stage_size); + } + DFT_STAGE_FN + template <bool inverse> + KFR_INTRIN void do_execute(complex<T>* out, const complex<T>* in, u8* temp) + { + complex<T>* scratch = ptr_cast<complex<T>>(temp + stage1.temp_size + stage2.temp_size); + stage1.do_execute(cbool<inverse>, scratch, in, temp); + stage2.do_execute(cbool<inverse>, out, scratch, temp + stage1.temp_size); + } + dft_stage_fixed_impl<T, radix1> stage1; + dft_stage_fixed_final_impl<T, radix2> stage2; +}; + +template <typename T, bool final> +struct dft_stage_generic_impl : dft_stage<T> +{ + dft_stage_generic_impl(size_t radix, size_t iterations, size_t blocks) + { + this->name = type_name<decltype(*this)>(); + this->radix = radix; + this->blocks = blocks; + this->repeats = iterations; + this->recursion = false; // true; + this->can_inplace = false; + this->temp_size = align_up(sizeof(complex<T>) * radix, platform<>::native_cache_alignment); + this->data_size = + align_up(sizeof(complex<T>) * sqr(this->radix / 2), platform<>::native_cache_alignment); + } + +protected: + virtual void do_initialize(size_t size) override final + { + complex<T>* twiddle = ptr_cast<complex<T>>(this->data); + CMT_LOOP_NOUNROLL + for (size_t i = 0; i < this->radix / 2; i++) + { + CMT_LOOP_NOUNROLL + for (size_t j = 0; j < this->radix / 2; j++) + { + cwrite<1>(twiddle++, cossin_conj(broadcast<2>((i + 1) * (j + 1) * c_pi<T, 2> / this->radix))); + } + } + } + + DFT_STAGE_FN + template <bool inverse> + KFR_INTRIN void do_execute(complex<T>* out, const complex<T>* in, u8* temp) + { + const complex<T>* twiddle = ptr_cast<complex<T>>(this->data); + const size_t bl = this->blocks; + const size_t Nord = this->repeats; + const size_t N = Nord * this->radix; + + CMT_LOOP_NOUNROLL + for (size_t b = 0; b < bl; b++) + generic_butterfly(this->radix, cbool<inverse>, out + b, in + b * this->radix, + ptr_cast<complex<T>>(temp), twiddle, bl); + } +}; + +template <typename T, typename Tr2> +inline void dft_permute(complex<T>* out, const complex<T>* in, size_t r0, size_t r1, Tr2 first_radix) +{ + CMT_ASSUME(r0 > 1); + CMT_ASSUME(r1 > 1); + + CMT_LOOP_NOUNROLL + for (size_t p = 0; p < r0; p++) + { + const complex<T>* in1 = in; + CMT_LOOP_NOUNROLL + for (size_t i = 0; i < r1; i++) + { + const complex<T>* in2 = in1; + CMT_LOOP_UNROLL + for (size_t j = 0; j < first_radix; j++) + { + *out++ = *in2; + in2 += r1; + } + in1++; + in += first_radix; + } + } +} + +template <typename T, typename Tr2> +inline void dft_permute_deep(complex<T>*& out, const complex<T>* in, const size_t* radices, size_t count, + size_t index, size_t inscale, size_t inner_size, Tr2 first_radix) +{ + const bool b = index == 1; + const size_t radix = radices[index]; + if (b) + { + CMT_LOOP_NOUNROLL + for (size_t i = 0; i < radix; i++) + { + const complex<T>* in1 = in; + CMT_LOOP_UNROLL + for (size_t j = 0; j < first_radix; j++) + { + *out++ = *in1; + in1 += inner_size; + } + in += inscale; + } + } + else + { + const size_t steps = radix; + const size_t inscale_next = inscale * radix; + CMT_LOOP_NOUNROLL + for (size_t i = 0; i < steps; i++) + { + dft_permute_deep(out, in, radices, count, index - 1, inscale_next, inner_size, first_radix); + in += inscale; + } + } +} + +template <typename T> +struct dft_reorder_stage_impl : dft_stage<T> +{ + dft_reorder_stage_impl(const int* radices, size_t count) : count(count) + { + this->name = type_name<decltype(*this)>(); + this->can_inplace = false; + this->data_size = 0; + std::copy(radices, radices + count, this->radices); + this->inner_size = 1; + this->size = 1; + for (size_t r = 0; r < count; r++) + { + if (r != 0 && r != count - 1) + this->inner_size *= radices[r]; + this->size *= radices[r]; + } + } + +protected: + size_t radices[32]; + size_t count = 0; + size_t size = 0; + size_t inner_size = 0; + virtual void do_initialize(size_t) override final {} + + DFT_STAGE_FN + template <bool inverse> + KFR_INTRIN void do_execute(complex<T>* out, const complex<T>* in, u8*) + { + cswitch(dft_radices, radices[0], + [&](auto first_radix) { + if (count == 3) + { + dft_permute(out, in, radices[2], radices[1], first_radix); + } + else + { + const size_t rlast = radices[count - 1]; + for (size_t p = 0; p < rlast; p++) + { + dft_permute_deep(out, in, radices, count, count - 2, 1, inner_size, first_radix); + in += size / rlast; + } + } + }, + [&]() { + if (count == 3) + { + dft_permute(out, in, radices[2], radices[1], radices[0]); + } + else + { + const size_t rlast = radices[count - 1]; + for (size_t p = 0; p < rlast; p++) + { + dft_permute_deep(out, in, radices, count, count - 2, 1, inner_size, radices[0]); + in += size / rlast; + } + } + }); + } +}; + +template <typename T, bool splitin, bool is_even> +struct fft_stage_impl : dft_stage<T> +{ + fft_stage_impl(size_t stage_size) + { + this->name = type_name<decltype(*this)>(); + this->radix = 4; + this->stage_size = stage_size; + this->repeats = 4; + this->recursion = true; + this->data_size = + align_up(sizeof(complex<T>) * stage_size / 4 * 3, platform<>::native_cache_alignment); + } + +protected: + constexpr static bool prefetch = true; + constexpr static bool aligned = false; + constexpr static size_t width = fft_vector_width<T>; + + virtual void do_initialize(size_t size) override final + { + complex<T>* twiddle = ptr_cast<complex<T>>(this->data); + initialize_twiddles<T, width>(twiddle, this->stage_size, size, true); + } + + DFT_STAGE_FN + template <bool inverse> + KFR_INTRIN void do_execute(complex<T>* out, const complex<T>* in, u8*) + { + const complex<T>* twiddle = ptr_cast<complex<T>>(this->data); + if (splitin) + in = out; + const size_t stg_size = this->stage_size; + CMT_ASSUME(stg_size >= 2048); + CMT_ASSUME(stg_size % 2048 == 0); + radix4_pass(stg_size, 1, csize_t<width>(), ctrue, cbool_t<splitin>(), cbool_t<!is_even>(), + cbool_t<prefetch>(), cbool_t<inverse>(), cbool_t<aligned>(), out, in, twiddle); + } +}; + +template <typename T, bool splitin, size_t size> +struct fft_final_stage_impl : dft_stage<T> +{ + fft_final_stage_impl(size_t) + { + this->name = type_name<decltype(*this)>(); + this->radix = size; + this->stage_size = size; + this->out_offset = size; + this->repeats = 4; + this->recursion = true; + this->data_size = align_up(sizeof(complex<T>) * size * 3 / 2, platform<>::native_cache_alignment); + } + +protected: + constexpr static size_t width = fft_vector_width<T>; + constexpr static bool is_even = cometa::is_even(ilog2(size)); + constexpr static bool use_br2 = !is_even; + constexpr static bool aligned = false; + constexpr static bool prefetch = splitin; + + KFR_INTRIN void init_twiddles(csize_t<8>, size_t, cfalse_t, complex<T>*&) {} + KFR_INTRIN void init_twiddles(csize_t<4>, size_t, cfalse_t, complex<T>*&) {} + + template <size_t N, bool pass_splitin> + KFR_INTRIN void init_twiddles(csize_t<N>, size_t total_size, cbool_t<pass_splitin>, complex<T>*& twiddle) + { + constexpr bool pass_split = N / 4 > 8 && N / 4 / 4 >= width; + constexpr size_t pass_width = const_min(width, N / 4); + initialize_twiddles<T, pass_width>(twiddle, N, total_size, pass_split || pass_splitin); + init_twiddles(csize<N / 4>, total_size, cbool<pass_split>, twiddle); + } + + virtual void do_initialize(size_t total_size) override final + { + complex<T>* twiddle = ptr_cast<complex<T>>(this->data); + init_twiddles(csize<size>, total_size, cbool<splitin>, twiddle); + } + + DFT_STAGE_FN + template <bool inverse> + KFR_INTRIN void do_execute(complex<T>* out, const complex<T>* in, u8*) + { + const complex<T>* twiddle = ptr_cast<complex<T>>(this->data); + final_stage<inverse>(csize<size>, 1, cbool<splitin>, out, in, twiddle); + } + + template <bool inverse, typename U = T, KFR_ENABLE_IF(is_same<U, float>::value)> + KFR_INTRIN void final_stage(csize_t<32>, size_t invN, cfalse_t, complex<T>* out, const complex<T>*, + const complex<T>*& twiddle) + { + radix4_pass(csize_t<32>(), invN, csize_t<width>(), cfalse, cfalse, cbool_t<use_br2>(), + cbool_t<prefetch>(), cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle); + } + + template <bool inverse, typename U = T, KFR_ENABLE_IF(is_same<U, float>::value)> + KFR_INTRIN void final_stage(csize_t<16>, size_t invN, cfalse_t, complex<T>* out, const complex<T>*, + const complex<T>*& twiddle) + { + radix4_pass(csize_t<16>(), invN, csize_t<width>(), cfalse, cfalse, cbool_t<use_br2>(), + cbool_t<prefetch>(), cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle); + } + + template <bool inverse> + KFR_INTRIN void final_stage(csize_t<8>, size_t invN, cfalse_t, complex<T>* out, const complex<T>*, + const complex<T>*& twiddle) + { + radix4_pass(csize_t<8>(), invN, csize_t<width>(), cfalse, cfalse, cbool_t<use_br2>(), + cbool_t<prefetch>(), cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle); + } + + template <bool inverse> + KFR_INTRIN void final_stage(csize_t<4>, size_t invN, cfalse_t, complex<T>* out, const complex<T>*, + const complex<T>*& twiddle) + { + radix4_pass(csize_t<4>(), invN, csize_t<width>(), cfalse, cfalse, cbool_t<use_br2>(), + cbool_t<prefetch>(), cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle); + } + + template <bool inverse, size_t N, bool pass_splitin> + KFR_INTRIN void final_stage(csize_t<N>, size_t invN, cbool_t<pass_splitin>, complex<T>* out, + const complex<T>* in, const complex<T>*& twiddle) + { + static_assert(N > 8, ""); + constexpr bool pass_split = N / 4 > 8 && N / 4 / 4 >= width; + constexpr size_t pass_width = const_min(width, N / 4); + static_assert(pass_width == width || (pass_split == pass_splitin), ""); + static_assert(pass_width <= N / 4, ""); + radix4_pass(N, invN, csize_t<pass_width>(), cbool<pass_split>, cbool_t<pass_splitin>(), + cbool_t<use_br2>(), cbool_t<prefetch>(), cbool_t<inverse>(), cbool_t<aligned>(), out, in, + twiddle); + final_stage<inverse>(csize<N / 4>, invN * 4, cbool<pass_split>, out, out, twiddle); + } +}; + +template <typename T, bool is_even> +struct fft_reorder_stage_impl : dft_stage<T> +{ + fft_reorder_stage_impl(size_t stage_size) + { + this->name = type_name<decltype(*this)>(); + this->stage_size = stage_size; + log2n = ilog2(stage_size); + this->data_size = 0; + } + +protected: + size_t log2n; + + virtual void do_initialize(size_t) override final {} + + DFT_STAGE_FN + template <bool inverse> + KFR_INTRIN void do_execute(complex<T>* out, const complex<T>* in, u8*) + { + fft_reorder(out, log2n, cbool_t<!is_even>()); + } +}; + +template <typename T, size_t log2n> +struct fft_specialization; + +template <typename T> +struct fft_specialization<T, 1> : dft_stage<T> +{ + fft_specialization(size_t) { this->name = type_name<decltype(*this)>(); } + +protected: + constexpr static bool aligned = false; + DFT_STAGE_FN + + template <bool inverse> + KFR_INTRIN void do_execute(complex<T>* out, const complex<T>* in, u8*) + { + cvec<T, 1> a0, a1; + split(cread<2, aligned>(in), a0, a1); + cwrite<2, aligned>(out, concat(a0 + a1, a0 - a1)); + } +}; + +template <typename T> +struct fft_specialization<T, 2> : dft_stage<T> +{ + fft_specialization(size_t) { this->name = type_name<decltype(*this)>(); } + +protected: + constexpr static bool aligned = false; + DFT_STAGE_FN + template <bool inverse> + KFR_INTRIN void do_execute(complex<T>* out, const complex<T>* in, u8*) + { + cvec<T, 1> a0, a1, a2, a3; + split(cread<4>(in), a0, a1, a2, a3); + butterfly(cbool_t<inverse>(), a0, a1, a2, a3, a0, a1, a2, a3); + cwrite<4>(out, concat(a0, a1, a2, a3)); + } +}; + +template <typename T> +struct fft_specialization<T, 3> : dft_stage<T> +{ + fft_specialization(size_t) { this->name = type_name<decltype(*this)>(); } + +protected: + constexpr static bool aligned = false; + DFT_STAGE_FN + template <bool inverse> + KFR_INTRIN void do_execute(complex<T>* out, const complex<T>* in, u8*) + { + cvec<T, 8> v8 = cread<8, aligned>(in); + butterfly8<inverse>(v8); + cwrite<8, aligned>(out, v8); + } +}; + +template <typename T> +struct fft_specialization<T, 4> : dft_stage<T> +{ + fft_specialization(size_t) { this->name = type_name<decltype(*this)>(); } + +protected: + constexpr static bool aligned = false; + DFT_STAGE_FN + template <bool inverse> + KFR_INTRIN void do_execute(complex<T>* out, const complex<T>* in, u8*) + { + cvec<T, 16> v16 = cread<16, aligned>(in); + butterfly16<inverse>(v16); + cwrite<16, aligned>(out, v16); + } +}; + +template <typename T> +struct fft_specialization<T, 5> : dft_stage<T> +{ + fft_specialization(size_t) { this->name = type_name<decltype(*this)>(); } + +protected: + constexpr static bool aligned = false; + DFT_STAGE_FN + template <bool inverse> + KFR_INTRIN void do_execute(complex<T>* out, const complex<T>* in, u8*) + { + cvec<T, 32> v32 = cread<32, aligned>(in); + butterfly32<inverse>(v32); + cwrite<32, aligned>(out, v32); + } +}; + +template <typename T> +struct fft_specialization<T, 6> : dft_stage<T> +{ + fft_specialization(size_t) { this->name = type_name<decltype(*this)>(); } + +protected: + constexpr static bool aligned = false; + DFT_STAGE_FN + template <bool inverse> + KFR_INTRIN void do_execute(complex<T>* out, const complex<T>* in, u8*) + { + butterfly64(cbool_t<inverse>(), cbool_t<aligned>(), out, in); + } +}; + +template <typename T> +struct fft_specialization<T, 7> : dft_stage<T> +{ + fft_specialization(size_t) + { + this->name = type_name<decltype(*this)>(); + this->stage_size = 128; + this->data_size = align_up(sizeof(complex<T>) * 128 * 3 / 2, platform<>::native_cache_alignment); + } + +protected: + constexpr static bool aligned = false; + constexpr static size_t width = platform<T>::vector_width; + constexpr static bool use_br2 = true; + constexpr static bool prefetch = false; + constexpr static bool is_double = sizeof(T) == 8; + constexpr static size_t final_size = is_double ? 8 : 32; + constexpr static size_t split_format = final_size == 8; + + virtual void do_initialize(size_t total_size) override final + { + complex<T>* twiddle = ptr_cast<complex<T>>(this->data); + initialize_twiddles<T, width>(twiddle, 128, total_size, split_format); + initialize_twiddles<T, width>(twiddle, 32, total_size, split_format); + initialize_twiddles<T, width>(twiddle, 8, total_size, split_format); + } + + DFT_STAGE_FN + template <bool inverse> + KFR_INTRIN void do_execute(complex<T>* out, const complex<T>* in, u8*) + { + const complex<T>* twiddle = ptr_cast<complex<T>>(this->data); + final_pass<inverse>(csize_t<final_size>(), out, in, twiddle); + if (this->need_reorder) + fft_reorder(out, csize_t<7>()); + } + + template <bool inverse> + KFR_INTRIN void final_pass(csize_t<8>, complex<T>* out, const complex<T>* in, const complex<T>* twiddle) + { + radix4_pass(128, 1, csize_t<width>(), ctrue, cfalse, cbool_t<use_br2>(), cbool_t<prefetch>(), + cbool_t<inverse>(), cbool_t<aligned>(), out, in, twiddle); + radix4_pass(32, 4, csize_t<width>(), cfalse, ctrue, cbool_t<use_br2>(), cbool_t<prefetch>(), + cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle); + radix4_pass(csize_t<8>(), 16, csize_t<width>(), cfalse, cfalse, cbool_t<use_br2>(), + cbool_t<prefetch>(), cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle); + } + + template <bool inverse> + KFR_INTRIN void final_pass(csize_t<32>, complex<T>* out, const complex<T>* in, const complex<T>* twiddle) + { + radix4_pass(128, 1, csize_t<width>(), cfalse, cfalse, cbool_t<use_br2>(), cbool_t<prefetch>(), + cbool_t<inverse>(), cbool_t<aligned>(), out, in, twiddle); + radix4_pass(csize_t<32>(), 4, csize_t<width>(), cfalse, cfalse, cbool_t<use_br2>(), + cbool_t<prefetch>(), cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle); + } +}; + +template <> +struct fft_specialization<float, 8> : dft_stage<float> +{ + fft_specialization(size_t) + { + this->name = type_name<decltype(*this)>(); + this->temp_size = sizeof(complex<float>) * 256; + } + +protected: + using T = float; + DFT_STAGE_FN + template <bool inverse> + KFR_INTRIN void do_execute(complex<T>* out, const complex<T>* in, u8* temp) + { + complex<float>* scratch = ptr_cast<complex<float>>(temp); + if (out == in) + { + butterfly16_multi_flip<0, inverse>(scratch, out); + butterfly16_multi_flip<1, inverse>(scratch, out); + butterfly16_multi_flip<2, inverse>(scratch, out); + butterfly16_multi_flip<3, inverse>(scratch, out); + + butterfly16_multi_natural<0, inverse>(out, scratch); + butterfly16_multi_natural<1, inverse>(out, scratch); + butterfly16_multi_natural<2, inverse>(out, scratch); + butterfly16_multi_natural<3, inverse>(out, scratch); + } + else + { + butterfly16_multi_flip<0, inverse>(out, in); + butterfly16_multi_flip<1, inverse>(out, in); + butterfly16_multi_flip<2, inverse>(out, in); + butterfly16_multi_flip<3, inverse>(out, in); + + butterfly16_multi_natural<0, inverse>(out, out); + butterfly16_multi_natural<1, inverse>(out, out); + butterfly16_multi_natural<2, inverse>(out, out); + butterfly16_multi_natural<3, inverse>(out, out); + } + } +}; + +template <> +struct fft_specialization<double, 8> : fft_final_stage_impl<double, false, 256> +{ + using T = double; + fft_specialization(size_t stage_size) : fft_final_stage_impl<double, false, 256>(stage_size) + { + this->name = type_name<decltype(*this)>(); + } + + DFT_STAGE_FN + template <bool inverse> + KFR_INTRIN void do_execute(complex<T>* out, const complex<T>* in, u8*) + { + fft_final_stage_impl<double, false, 256>::template do_execute<inverse>(out, in, nullptr); + if (this->need_reorder) + fft_reorder(out, csize_t<8>()); + } +}; + +template <typename T> +struct fft_specialization<T, 9> : fft_final_stage_impl<T, false, 512> +{ + fft_specialization(size_t stage_size) : fft_final_stage_impl<T, false, 512>(stage_size) + { + this->name = type_name<decltype(*this)>(); + } + + DFT_STAGE_FN + template <bool inverse> + KFR_INTRIN void do_execute(complex<T>* out, const complex<T>* in, u8*) + { + fft_final_stage_impl<T, false, 512>::template do_execute<inverse>(out, in, nullptr); + if (this->need_reorder) + fft_reorder(out, csize_t<9>()); + } +}; + +template <typename T> +struct fft_specialization<T, 10> : fft_final_stage_impl<T, false, 1024> +{ + fft_specialization(size_t stage_size) : fft_final_stage_impl<T, false, 1024>(stage_size) + { + this->name = type_name<decltype(*this)>(); + } + + DFT_STAGE_FN + template <bool inverse> + KFR_INTRIN void do_execute(complex<T>* out, const complex<T>* in, u8*) + { + fft_final_stage_impl<T, false, 1024>::template do_execute<inverse>(out, in, nullptr); + if (this->need_reorder) + fft_reorder(out, 10, cfalse); + } +}; + +} // namespace internal + +// + +template <typename T> +template <typename Stage, typename... Args> +void dft_plan<T>::add_stage(Args... args) +{ + dft_stage<T>* stage = new Stage(args...); + stage->need_reorder = need_reorder; + this->data_size += stage->data_size; + this->temp_size += stage->temp_size; + stages.push_back(dft_stage_ptr(stage)); +} + +template <typename T> +template <bool is_final> +void dft_plan<T>::prepare_dft_stage(size_t radix, size_t iterations, size_t blocks, cbool_t<is_final>) +{ + return cswitch( + dft_radices, radix, + [&](auto radix) CMT_INLINE_LAMBDA { + add_stage<conditional<is_final, internal::dft_stage_fixed_final_impl<T, val_of(radix)>, + internal::dft_stage_fixed_impl<T, val_of(radix)>>>(radix, iterations, + blocks); + }, + [&]() { add_stage<internal::dft_stage_generic_impl<T, is_final>>(radix, iterations, blocks); }); +} + +template <typename T> +template <bool is_even, bool first> +void dft_plan<T>::make_fft(size_t stage_size, cbool_t<is_even>, cbool_t<first>) +{ + constexpr size_t final_size = is_even ? 1024 : 512; + + if (stage_size >= 2048) + { + add_stage<internal::fft_stage_impl<T, !first, is_even>>(stage_size); + + make_fft(stage_size / 4, cbool_t<is_even>(), cfalse); + } + else + { + add_stage<internal::fft_final_stage_impl<T, !first, final_size>>(final_size); + } +} + +template <typename T> +struct reverse_wrapper +{ + T& iterable; +}; + +template <typename T> +auto begin(reverse_wrapper<T> w) +{ + return std::rbegin(w.iterable); +} + +template <typename T> +auto end(reverse_wrapper<T> w) +{ + return std::rend(w.iterable); +} + +template <typename T> +reverse_wrapper<T> reversed(T&& iterable) +{ + return { iterable }; +} + +template <typename T> +void dft_plan<T>::initialize() +{ + data = autofree<u8>(data_size); + size_t offset = 0; + for (dft_stage_ptr& stage : stages) + { + stage->data = data.data() + offset; + stage->initialize(this->size); + offset += stage->data_size; + } + + bool to_scratch = false; + bool scratch_needed = false; + for (dft_stage_ptr& stage : reversed(stages)) + { + if (to_scratch) + { + scratch_needed = true; + } + stage->to_scratch = to_scratch; + if (!stage->can_inplace) + { + to_scratch = !to_scratch; + } + } + if (scratch_needed || !stages[0]->can_inplace) + this->temp_size += align_up(sizeof(complex<T>) * this->size, platform<>::native_cache_alignment); +} + +template <typename T> +const complex<T>* dft_plan<T>::select_in(size_t stage, const complex<T>* out, const complex<T>* in, + const complex<T>* scratch, bool in_scratch) const +{ + if (stage == 0) + return in_scratch ? scratch : in; + return stages[stage - 1]->to_scratch ? scratch : out; +} + +template <typename T> +complex<T>* dft_plan<T>::select_out(size_t stage, complex<T>* out, complex<T>* scratch) const +{ + return stages[stage]->to_scratch ? scratch : out; +} + +template <typename T> +template <bool inverse> +void dft_plan<T>::execute_dft(cbool_t<inverse>, complex<T>* out, const complex<T>* in, u8* temp) const +{ + if (stages.size() == 1 && (stages[0]->can_inplace || in != out)) + { + return stages[0]->execute(cbool<inverse>, out, in, temp); + } + size_t stack[32] = { 0 }; + + complex<T>* scratch = + ptr_cast<complex<T>>(temp + this->temp_size - + align_up(sizeof(complex<T>) * this->size, platform<>::native_cache_alignment)); + + bool in_scratch = !stages[0]->can_inplace && in == out; + if (in_scratch) + { + internal::builtin_memcpy(scratch, in, sizeof(complex<T>) * this->size); + } + + const size_t count = stages.size(); + + for (size_t depth = 0; depth < count;) + { + if (stages[depth]->recursion) + { + size_t offset = 0; + size_t rdepth = depth; + size_t maxdepth = depth; + do + { + if (stack[rdepth] == stages[rdepth]->repeats) + { + stack[rdepth] = 0; + rdepth--; + } + else + { + complex<T>* rout = select_out(rdepth, out, scratch); + const complex<T>* rin = select_in(rdepth, out, in, scratch, in_scratch); + stages[rdepth]->execute(cbool<inverse>, rout + offset, rin + offset, temp); + offset += stages[rdepth]->out_offset; + stack[rdepth]++; + if (rdepth < count - 1 && stages[rdepth + 1]->recursion) + rdepth++; + else + maxdepth = rdepth; + } + } while (rdepth != depth); + depth = maxdepth + 1; + } + else + { + stages[depth]->execute(cbool<inverse>, select_out(depth, out, scratch), + select_in(depth, out, in, scratch, in_scratch), temp); + depth++; + } + } +} + +template <typename T> +dft_plan<T>::dft_plan(size_t size, dft_order order) : size(size), temp_size(0), data_size(0) +{ + need_reorder = true; + if (is_poweroftwo(size)) + { + const size_t log2n = ilog2(size); + cswitch(csizes_t<1, 2, 3, 4, 5, 6, 7, 8, 9, 10>(), log2n, + [&](auto log2n) { + (void)log2n; + constexpr size_t log2nv = val_of(decltype(log2n)()); + this->add_stage<internal::fft_specialization<T, log2nv>>(size); + }, + [&]() { + cswitch(cfalse_true, is_even(log2n), [&](auto is_even) { + this->make_fft(size, is_even, ctrue); + constexpr size_t is_evenv = val_of(decltype(is_even)()); + if (need_reorder) + this->add_stage<internal::fft_reorder_stage_impl<T, is_evenv>>(size); + }); + }); + } +#ifndef KFR_DFT_NO_NPo2 + else + { + if (size == 60) + { + this->add_stage<internal::dft_special_stage_impl<T, 6, 10>>(); + } + else if (size == 48) + { + this->add_stage<internal::dft_special_stage_impl<T, 6, 8>>(); + } + else + { + size_t cur_size = size; + constexpr size_t radices_count = dft_radices.back() + 1; + u8 count[radices_count] = { 0 }; + int radices[32] = { 0 }; + size_t radices_size = 0; + + cforeach(dft_radices[csizeseq<dft_radices.size(), dft_radices.size() - 1, -1>], [&](auto radix) { + while (cur_size && cur_size % val_of(radix) == 0) + { + count[val_of(radix)]++; + cur_size /= val_of(radix); + } + }); + + if (cur_size >= 101) + { + this->add_stage<internal::dft_arblen_stage_impl<T>>(size); + } + else + { + size_t blocks = 1; + size_t iterations = size; + + for (size_t r = dft_radices.front(); r <= dft_radices.back(); r++) + { + for (size_t i = 0; i < count[r]; i++) + { + iterations /= r; + radices[radices_size++] = r; + if (iterations == 1) + this->prepare_dft_stage(r, iterations, blocks, ctrue); + else + this->prepare_dft_stage(r, iterations, blocks, cfalse); + blocks *= r; + } + } + + if (cur_size > 1) + { + iterations /= cur_size; + radices[radices_size++] = cur_size; + if (iterations == 1) + this->prepare_dft_stage(cur_size, iterations, blocks, ctrue); + else + this->prepare_dft_stage(cur_size, iterations, blocks, cfalse); + } + + if (stages.size() > 2) + this->add_stage<internal::dft_reorder_stage_impl<T>>(radices, radices_size); + } + } + } +#endif + initialize(); +} + +template <typename T> +dft_plan_real<T>::dft_plan_real(size_t size) : dft_plan<T>(size / 2), size(size), rtwiddle(size / 4) +{ + using namespace internal; + + constexpr size_t width = platform<T>::vector_width * 2; + + block_process(size / 4, csizes_t<width, 1>(), [=](size_t i, auto w) { + constexpr size_t width = val_of(decltype(w)()); + cwrite<width>(rtwiddle.data() + i, + cossin(dup(-constants<T>::pi * ((enumerate<T, width>() + i + size / 4) / (size / 2))))); + }); +} + +template <typename T> +void dft_plan_real<T>::to_fmt(complex<T>* out, dft_pack_format fmt) const +{ + using namespace internal; + size_t csize = this->size / 2; // const size_t causes internal compiler error: in tsubst_copy in GCC 5.2 + + constexpr size_t width = platform<T>::vector_width * 2; + const cvec<T, 1> dc = cread<1>(out); + const size_t count = csize / 2; + + block_process(count - 1, csizes_t<width, 1>(), [&](size_t i, auto w) { + i++; + constexpr size_t width = val_of(decltype(w)()); + constexpr size_t widthm1 = width - 1; + const cvec<T, width> tw = cread<width>(rtwiddle.data() + i); + const cvec<T, width> fpk = cread<width>(out + i); + const cvec<T, width> fpnk = reverse<2>(negodd(cread<width>(out + csize - i - widthm1))); + + const cvec<T, width> f1k = fpk + fpnk; + const cvec<T, width> f2k = fpk - fpnk; + const cvec<T, width> t = cmul(f2k, tw); + cwrite<width>(out + i, T(0.5) * (f1k + t)); + cwrite<width>(out + csize - i - widthm1, reverse<2>(negodd(T(0.5) * (f1k - t)))); + }); + + { + size_t k = csize / 2; + const cvec<T, 1> fpk = cread<1>(out + k); + const cvec<T, 1> fpnk = negodd(fpk); + cwrite<1>(out + k, fpnk); + } + if (fmt == dft_pack_format::CCs) + { + cwrite<1>(out, pack(dc[0] + dc[1], 0)); + cwrite<1>(out + csize, pack(dc[0] - dc[1], 0)); + } + else + { + cwrite<1>(out, pack(dc[0] + dc[1], dc[0] - dc[1])); + } +} + +template <typename T> +void dft_plan_real<T>::from_fmt(complex<T>* out, const complex<T>* in, dft_pack_format fmt) const +{ + using namespace internal; + + const size_t csize = this->size / 2; + + cvec<T, 1> dc; + + if (fmt == dft_pack_format::CCs) + { + dc = pack(in[0].real() + in[csize].real(), in[0].real() - in[csize].real()); + } + else + { + dc = pack(in[0].real() + in[0].imag(), in[0].real() - in[0].imag()); + } + + constexpr size_t width = platform<T>::vector_width * 2; + const size_t count = csize / 2; + + block_process(count - 1, csizes_t<width, 1>(), [&](size_t i, auto w) { + i++; + constexpr size_t width = val_of(decltype(w)()); + constexpr size_t widthm1 = width - 1; + const cvec<T, width> tw = cread<width>(rtwiddle.data() + i); + const cvec<T, width> fpk = cread<width>(in + i); + const cvec<T, width> fpnk = reverse<2>(negodd(cread<width>(in + csize - i - widthm1))); + + const cvec<T, width> f1k = fpk + fpnk; + const cvec<T, width> f2k = fpk - fpnk; + const cvec<T, width> t = cmul_conj(f2k, tw); + cwrite<width>(out + i, f1k + t); + cwrite<width>(out + csize - i - widthm1, reverse<2>(negodd(f1k - t))); + }); + + { + size_t k = csize / 2; + const cvec<T, 1> fpk = cread<1>(in + k); + const cvec<T, 1> fpnk = 2 * negodd(fpk); + cwrite<1>(out + k, fpnk); + } + cwrite<1>(out, dc); +} + +template <typename T> +dft_plan<T>::~dft_plan() +{ +} + +template <typename T> +void dft_plan<T>::dump() const +{ + for (const dft_stage_ptr& s : stages) + { + s->dump(); + } +} + +} // namespace kfr + +CMT_PRAGMA_GNU(GCC diagnostic pop) + +CMT_PRAGMA_MSVC(warning(pop)) diff --git a/include/kfr/dft/impl/dft-src.cpp b/include/kfr/dft/impl/dft-src.cpp @@ -0,0 +1,130 @@ +/** @addtogroup dft + * @{ + */ +/* + Copyright (C) 2016 D Levin (https://www.kfrlib.com) + This file is part of KFR + + KFR is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + KFR is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with KFR. + + If GPL is not suitable for your project, you must purchase a commercial license to use KFR. + Buying a commercial license is mandatory as soon as you develop commercial activities without + disclosing the source code of your own applications. + See https://www.kfrlib.com for details. + */ + +#include "dft-impl.hpp" + +namespace kfr +{ + +extern "C" +{ + + KFR_DFT_PLAN_F32* kfr_dft_create_plan_f32(size_t size) + { + return reinterpret_cast<KFR_DFT_PLAN_F32*>(new kfr::dft_plan<float>(size)); + } + KFR_DFT_PLAN_F64* kfr_dft_create_plan_f64(size_t size) + { + return reinterpret_cast<KFR_DFT_PLAN_F64*>(new kfr::dft_plan<double>(size)); + } + + void kfr_dft_execute_f32(KFR_DFT_PLAN_F32* plan, size_t size, float* out, const float* in, uint8_t* temp) + { + reinterpret_cast<kfr::dft_plan<float>*>(plan)->execute( + reinterpret_cast<kfr::complex<float>*>(out), reinterpret_cast<const kfr::complex<float>*>(in), + temp, kfr::cfalse); + } + void kfr_dft_execute_f64(KFR_DFT_PLAN_F64* plan, size_t size, double* out, const double* in, + uint8_t* temp) + { + reinterpret_cast<kfr::dft_plan<double>*>(plan)->execute( + reinterpret_cast<kfr::complex<double>*>(out), reinterpret_cast<const kfr::complex<double>*>(in), + temp, kfr::cfalse); + } + void kfr_dft_execute_inverse_f32(KFR_DFT_PLAN_F32* plan, size_t size, float* out, const float* in, + uint8_t* temp) + { + reinterpret_cast<kfr::dft_plan<float>*>(plan)->execute( + reinterpret_cast<kfr::complex<float>*>(out), reinterpret_cast<const kfr::complex<float>*>(in), + temp, kfr::ctrue); + } + void kfr_dft_execute_inverse_f64(KFR_DFT_PLAN_F64* plan, size_t size, double* out, const double* in, + uint8_t* temp) + { + reinterpret_cast<kfr::dft_plan<double>*>(plan)->execute( + reinterpret_cast<kfr::complex<double>*>(out), reinterpret_cast<const kfr::complex<double>*>(in), + temp, kfr::ctrue); + } + + void kfr_dft_delete_plan_f32(KFR_DFT_PLAN_F32* plan) + { + delete reinterpret_cast<kfr::dft_plan<float>*>(plan); + } + void kfr_dft_delete_plan_f64(KFR_DFT_PLAN_F64* plan) + { + delete reinterpret_cast<kfr::dft_plan<double>*>(plan); + } + + // Real DFT plans + + KFR_DFT_REAL_PLAN_F32* kfr_dft_create_real_plan_f32(size_t size) + { + return reinterpret_cast<KFR_DFT_REAL_PLAN_F32*>(new kfr::dft_plan_real<float>(size)); + } + KFR_DFT_REAL_PLAN_F64* kfr_dft_create_real_plan_f64(size_t size) + { + return reinterpret_cast<KFR_DFT_REAL_PLAN_F64*>(new kfr::dft_plan_real<double>(size)); + } + + void kfr_dft_execute_real_f32(KFR_DFT_REAL_PLAN_F32* plan, size_t size, float* out, const float* in, + uint8_t* temp, KFR_DFT_PACK_FORMAT pack_format) + { + reinterpret_cast<kfr::dft_plan_real<float>*>(plan)->execute( + reinterpret_cast<kfr::complex<float>*>(out), in, temp, + static_cast<kfr::dft_pack_format>(pack_format)); + } + void kfr_dft_execute_real_f64(KFR_DFT_REAL_PLAN_F64* plan, size_t size, double* out, const double* in, + uint8_t* temp, KFR_DFT_PACK_FORMAT pack_format) + { + reinterpret_cast<kfr::dft_plan_real<double>*>(plan)->execute( + reinterpret_cast<kfr::complex<double>*>(out), in, temp, + static_cast<kfr::dft_pack_format>(pack_format)); + } + void kfr_dft_execute_real_inverse_f32(KFR_DFT_REAL_PLAN_F32* plan, size_t size, float* out, + const float* in, uint8_t* temp, KFR_DFT_PACK_FORMAT pack_format) + { + reinterpret_cast<kfr::dft_plan_real<float>*>(plan)->execute( + out, reinterpret_cast<const kfr::complex<float>*>(in), temp, + static_cast<kfr::dft_pack_format>(pack_format)); + } + void kfr_dft_execute_real_inverse__f64(KFR_DFT_REAL_PLAN_F64* plan, size_t size, double* out, + const double* in, uint8_t* temp, KFR_DFT_PACK_FORMAT pack_format) + { + reinterpret_cast<kfr::dft_plan_real<double>*>(plan)->execute( + out, reinterpret_cast<const kfr::complex<double>*>(in), temp, + static_cast<kfr::dft_pack_format>(pack_format)); + } + + void kfr_dft_delete_real_plan_f32(KFR_DFT_REAL_PLAN_F32* plan) + { + delete reinterpret_cast<kfr::dft_plan_real<float>*>(plan); + } + void kfr_dft_delete_real_plan_f64(KFR_DFT_REAL_PLAN_F64* plan) + { + delete reinterpret_cast<kfr::dft_plan_real<double>*>(plan); + } +} +} // namespace kfr diff --git a/include/kfr/dft/impl/dft-templates.hpp b/include/kfr/dft/impl/dft-templates.hpp @@ -0,0 +1,44 @@ +/** @addtogroup dft + * @{ + */ +/* + Copyright (C) 2016 D Levin (https://www.kfrlib.com) + This file is part of KFR + + KFR is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + KFR is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with KFR. + + If GPL is not suitable for your project, you must purchase a commercial license to use KFR. + Buying a commercial license is mandatory as soon as you develop commercial activities without + disclosing the source code of your own applications. + See https://www.kfrlib.com for details. + */ + +#include "../fft.hpp" + +namespace kfr +{ + +template dft_plan<FLOAT>::dft_plan(size_t, dft_order); +template dft_plan<FLOAT>::~dft_plan(); +template void dft_plan<FLOAT>::dump() const; +template void dft_plan<FLOAT>::execute_dft(cometa::cbool_t<false>, kfr::complex<FLOAT>* out, + const kfr::complex<FLOAT>* in, kfr::u8* temp) const; +template void dft_plan<FLOAT>::execute_dft(cometa::cbool_t<true>, kfr::complex<FLOAT>* out, + const kfr::complex<FLOAT>* in, kfr::u8* temp) const; +template dft_plan_real<FLOAT>::dft_plan_real(size_t); +template void dft_plan_real<FLOAT>::from_fmt(kfr::complex<FLOAT>* out, const kfr::complex<FLOAT>* in, + kfr::dft_pack_format fmt) const; +template void dft_plan_real<FLOAT>::to_fmt(kfr::complex<FLOAT>* out, kfr::dft_pack_format fmt) const; + +} // namespace kfr diff --git a/include/kfr/dft/impl/ft.hpp b/include/kfr/dft/impl/ft.hpp @@ -0,0 +1,1760 @@ +/** @addtogroup dft + * @{ + */ +/* + Copyright (C) 2016 D Levin (https://www.kfrlib.com) + This file is part of KFR + + KFR is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + KFR is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with KFR. + + If GPL is not suitable for your project, you must purchase a commercial license to use KFR. + Buying a commercial license is mandatory as soon as you develop commercial activities without + disclosing the source code of your own applications. + See https://www.kfrlib.com for details. + */ +#pragma once + +#include "../../base/complex.hpp" +#include "../../base/constants.hpp" +#include "../../base/digitreverse.hpp" +#include "../../base/read_write.hpp" +#include "../../base/sin_cos.hpp" +#include "../../base/small_buffer.hpp" +#include "../../base/univector.hpp" +#include "../../base/vec.hpp" + +#include "../../base/memory.hpp" +#include "../../data/sincos.hpp" + +CMT_PRAGMA_MSVC(warning(push)) +CMT_PRAGMA_MSVC(warning(disable : 4127)) + +namespace kfr +{ + +namespace internal +{ + +template <typename T, size_t N, KFR_ENABLE_IF(N >= 2)> +CMT_INLINE vec<T, N> cmul_impl(const vec<T, N>& x, const vec<T, N>& y) +{ + return subadd(x * dupeven(y), swap<2>(x) * dupodd(y)); +} +template <typename T, size_t N, KFR_ENABLE_IF(N > 2)> +CMT_INLINE vec<T, N> cmul_impl(const vec<T, N>& x, const vec<T, 2>& y) +{ + vec<T, N> yy = resize<N>(y); + return cmul_impl(x, yy); +} +template <typename T, size_t N, KFR_ENABLE_IF(N > 2)> +CMT_INLINE vec<T, N> cmul_impl(const vec<T, 2>& x, const vec<T, N>& y) +{ + vec<T, N> xx = resize<N>(x); + return cmul_impl(xx, y); +} + +/// Complex Multiplication +template <typename T, size_t N1, size_t N2> +CMT_INLINE vec<T, const_max(N1, N2)> cmul(const vec<T, N1>& x, const vec<T, N2>& y) +{ + return internal::cmul_impl(x, y); +} + +template <typename T, size_t N, KFR_ENABLE_IF(N >= 2)> +CMT_INLINE vec<T, N> cmul_conj(const vec<T, N>& x, const vec<T, N>& y) +{ + return swap<2>(subadd(swap<2>(x) * dupeven(y), x * dupodd(y))); +} +template <typename T, size_t N, KFR_ENABLE_IF(N >= 2)> +CMT_INLINE vec<T, N> cmul_2conj(const vec<T, N>& in0, const vec<T, N>& in1, const vec<T, N>& tw) +{ + return (in0 + in1) * dupeven(tw) + swap<2>(cnegimag(in0 - in1)) * dupodd(tw); +} +template <typename T, size_t N, KFR_ENABLE_IF(N >= 2)> +CMT_INLINE void cmul_2conj(vec<T, N>& out0, vec<T, N>& out1, const vec<T, 2>& in0, const vec<T, 2>& in1, + const vec<T, N>& tw) +{ + const vec<T, N> twr = dupeven(tw); + const vec<T, N> twi = dupodd(tw); + const vec<T, 2> sum = (in0 + in1); + const vec<T, 2> dif = swap<2>(negodd(in0 - in1)); + const vec<T, N> sumtw = resize<N>(sum) * twr; + const vec<T, N> diftw = resize<N>(dif) * twi; + out0 += sumtw + diftw; + out1 += sumtw - diftw; +} +template <typename T, size_t N, KFR_ENABLE_IF(N > 2)> +CMT_INLINE vec<T, N> cmul_conj(const vec<T, N>& x, const vec<T, 2>& y) +{ + vec<T, N> yy = resize<N>(y); + return cmul_conj(x, yy); +} +template <typename T, size_t N, KFR_ENABLE_IF(N > 2)> +CMT_INLINE vec<T, N> cmul_conj(const vec<T, 2>& x, const vec<T, N>& y) +{ + vec<T, N> xx = resize<N>(x); + return cmul_conj(xx, y); +} + +template <typename T, size_t N> +using cvec = vec<T, N * 2>; + +template <size_t N, bool A = false, typename T> +CMT_INLINE cvec<T, N> cread(const complex<T>* src) +{ + return cvec<T, N>(ptr_cast<T>(src), cbool_t<A>()); +} + +template <size_t N, bool A = false, typename T> +CMT_INLINE void cwrite(complex<T>* dest, const cvec<T, N>& value) +{ + value.write(ptr_cast<T>(dest)); +} + +template <size_t count, size_t N, size_t stride, bool A, typename T, size_t... indices> +CMT_INLINE cvec<T, count * N> cread_group_impl(const complex<T>* src, csizes_t<indices...>) +{ + return concat(read<N * 2, A>(ptr_cast<T>(src + stride * indices))...); +} +template <size_t count, size_t N, size_t stride, bool A, typename T, size_t... indices> +CMT_INLINE void cwrite_group_impl(complex<T>* dest, const cvec<T, count * N>& value, csizes_t<indices...>) +{ + swallow{ (write<A>(ptr_cast<T>(dest + stride * indices), slice<indices * N * 2, N * 2>(value)), 0)... }; +} + +template <size_t count, size_t N, bool A, typename T, size_t... indices> +CMT_INLINE cvec<T, count * N> cread_group_impl(const complex<T>* src, size_t stride, csizes_t<indices...>) +{ + return concat(read<N * 2, A>(ptr_cast<T>(src + stride * indices))...); +} +template <size_t count, size_t N, bool A, typename T, size_t... indices> +CMT_INLINE void cwrite_group_impl(complex<T>* dest, size_t stride, const cvec<T, count * N>& value, + csizes_t<indices...>) +{ + swallow{ (write<A>(ptr_cast<T>(dest + stride * indices), slice<indices * N * 2, N * 2>(value)), 0)... }; +} + +template <size_t count, size_t N, size_t stride, bool A = false, typename T> +CMT_INLINE cvec<T, count * N> cread_group(const complex<T>* src) +{ + return cread_group_impl<count, N, stride, A>(src, csizeseq_t<count>()); +} + +template <size_t count, size_t N, size_t stride, bool A = false, typename T> +CMT_INLINE void cwrite_group(complex<T>* dest, const cvec<T, count * N>& value) +{ + return cwrite_group_impl<count, N, stride, A>(dest, value, csizeseq_t<count>()); +} + +template <size_t count, size_t N, bool A = false, typename T> +CMT_INLINE cvec<T, count * N> cread_group(const complex<T>* src, size_t stride) +{ + return cread_group_impl<count, N, A>(src, stride, csizeseq_t<count>()); +} + +template <size_t count, size_t N, bool A = false, typename T> +CMT_INLINE void cwrite_group(complex<T>* dest, size_t stride, const cvec<T, count * N>& value) +{ + return cwrite_group_impl<count, N, A>(dest, stride, value, csizeseq_t<count>()); +} + +template <size_t N, bool A = false, bool split = false, typename T> +CMT_INLINE cvec<T, N> cread_split(const complex<T>* src) +{ + cvec<T, N> temp = cvec<T, N>(ptr_cast<T>(src), cbool_t<A>()); + if (split) + temp = splitpairs(temp); + return temp; +} + +template <size_t N, bool A = false, bool split = false, typename T> +CMT_INLINE void cwrite_split(complex<T>* dest, const cvec<T, N>& value) +{ + cvec<T, N> v = value; + if (split) + v = interleavehalfs(v); + v.write(ptr_cast<T>(dest), cbool_t<A>()); +} + +template <> +inline cvec<f32, 8> cread_split<8, false, true, f32>(const complex<f32>* src) +{ + const cvec<f32, 4> l = concat(cread<2>(src), cread<2>(src + 4)); + const cvec<f32, 4> h = concat(cread<2>(src + 2), cread<2>(src + 6)); + + return concat(shuffle<0, 2, 8 + 0, 8 + 2>(l, h), shuffle<1, 3, 8 + 1, 8 + 3>(l, h)); +} +template <> +inline cvec<f32, 8> cread_split<8, true, true, f32>(const complex<f32>* src) +{ + const cvec<f32, 4> l = concat(cread<2, true>(src), cread<2, true>(src + 4)); + const cvec<f32, 4> h = concat(cread<2, true>(src + 2), cread<2, true>(src + 6)); + + return concat(shuffle<0, 2, 8 + 0, 8 + 2>(l, h), shuffle<1, 3, 8 + 1, 8 + 3>(l, h)); +} + +template <> +inline cvec<f64, 4> cread_split<4, false, true, f64>(const complex<f64>* src) +{ + const cvec<f64, 2> l = concat(cread<1>(src), cread<1>(src + 2)); + const cvec<f64, 2> h = concat(cread<1>(src + 1), cread<1>(src + 3)); + + return concat(shuffle<0, 4, 2, 6>(l, h), shuffle<1, 5, 3, 7>(l, h)); +} + +template <> +inline void cwrite_split<8, false, true, f32>(complex<f32>* dest, const cvec<f32, 8>& x) +{ + const cvec<f32, 8> xx = + concat(shuffle<0, 8 + 0, 1, 8 + 1>(low(x), high(x)), shuffle<2, 8 + 2, 3, 8 + 3>(low(x), high(x))); + + cvec<f32, 2> a, b, c, d; + split(xx, a, b, c, d); + cwrite<2>(dest, a); + cwrite<2>(dest + 4, b); + cwrite<2>(dest + 2, c); + cwrite<2>(dest + 6, d); +} +template <> +inline void cwrite_split<8, true, true, f32>(complex<f32>* dest, const cvec<f32, 8>& x) +{ + const cvec<f32, 8> xx = + concat(shuffle<0, 8 + 0, 1, 8 + 1>(low(x), high(x)), shuffle<2, 8 + 2, 3, 8 + 3>(low(x), high(x))); + + cvec<f32, 2> a, b, c, d; + split(xx, a, b, c, d); + cwrite<2, true>(dest + 0, a); + cwrite<2, true>(dest + 4, b); + cwrite<2, true>(dest + 2, c); + cwrite<2, true>(dest + 6, d); +} + +template <> +inline void cwrite_split<4, false, true, f64>(complex<f64>* dest, const cvec<f64, 4>& x) +{ + const cvec<f64, 4> xx = + concat(shuffle<0, 4, 2, 6>(low(x), high(x)), shuffle<1, 5, 3, 7>(low(x), high(x))); + cwrite<1>(dest, part<4, 0>(xx)); + cwrite<1>(dest + 2, part<4, 1>(xx)); + cwrite<1>(dest + 1, part<4, 2>(xx)); + cwrite<1>(dest + 3, part<4, 3>(xx)); +} +template <> +inline void cwrite_split<4, true, true, f64>(complex<f64>* dest, const cvec<f64, 4>& x) +{ + const cvec<f64, 4> xx = + concat(shuffle<0, 4, 2, 6>(low(x), high(x)), shuffle<1, 5, 3, 7>(low(x), high(x))); + cwrite<1, true>(dest + 0, part<4, 0>(xx)); + cwrite<1, true>(dest + 2, part<4, 1>(xx)); + cwrite<1, true>(dest + 1, part<4, 2>(xx)); + cwrite<1, true>(dest + 3, part<4, 3>(xx)); +} + +template <size_t N, size_t stride, typename T, size_t... Indices> +CMT_INLINE cvec<T, N> cgather_helper(const complex<T>* base, csizes_t<Indices...>) +{ + return concat(ref_cast<cvec<T, 1>>(base[Indices * stride])...); +} + +template <size_t N, size_t stride, typename T> +CMT_INLINE cvec<T, N> cgather(const complex<T>* base) +{ + if (stride == 1) + { + return ref_cast<cvec<T, N>>(*base); + } + else + return cgather_helper<N, stride, T>(base, csizeseq_t<N>()); +} + +CMT_INLINE size_t cgather_next(size_t& index, size_t stride, size_t size, size_t) +{ + size_t temp = index; + index += stride; + if (index >= size) + index -= size; + return temp; +} +CMT_INLINE size_t cgather_next(size_t& index, size_t stride, size_t) +{ + size_t temp = index; + index += stride; + return temp; +} + +template <size_t N, typename T, size_t... Indices> +CMT_INLINE cvec<T, N> cgather_helper(const complex<T>* base, size_t& index, size_t stride, + csizes_t<Indices...>) +{ + return concat(ref_cast<cvec<T, 1>>(base[cgather_next(index, stride, Indices)])...); +} + +template <size_t N, typename T> +CMT_INLINE cvec<T, N> cgather(const complex<T>* base, size_t& index, size_t stride) +{ + return cgather_helper<N, T>(base, index, stride, csizeseq_t<N>()); +} +template <size_t N, typename T> +CMT_INLINE cvec<T, N> cgather(const complex<T>* base, size_t stride) +{ + size_t index = 0; + return cgather_helper<N, T>(base, index, stride, csizeseq_t<N>()); +} + +template <size_t N, typename T, size_t... Indices> +CMT_INLINE cvec<T, N> cgather_helper(const complex<T>* base, size_t& index, size_t stride, size_t size, + csizes_t<Indices...>) +{ + return concat(ref_cast<cvec<T, 1>>(base[cgather_next(index, stride, size, Indices)])...); +} + +template <size_t N, typename T> +CMT_INLINE cvec<T, N> cgather(const complex<T>* base, size_t& index, size_t stride, size_t size) +{ + return cgather_helper<N, T>(base, index, stride, size, csizeseq_t<N>()); +} + +template <size_t N, size_t stride, typename T, size_t... Indices> +CMT_INLINE void cscatter_helper(complex<T>* base, const cvec<T, N>& value, csizes_t<Indices...>) +{ + swallow{ (cwrite<1>(base + Indices * stride, slice<Indices * 2, 2>(value)), 0)... }; +} + +template <size_t N, size_t stride, typename T> +CMT_INLINE void cscatter(complex<T>* base, const cvec<T, N>& value) +{ + if (stride == 1) + { + cwrite<N>(base, value); + } + else + { + return cscatter_helper<N, stride, T>(base, value, csizeseq_t<N>()); + } +} + +template <size_t N, typename T, size_t... Indices> +CMT_INLINE void cscatter_helper(complex<T>* base, size_t stride, const cvec<T, N>& value, + csizes_t<Indices...>) +{ + swallow{ (cwrite<1>(base + Indices * stride, slice<Indices * 2, 2>(value)), 0)... }; +} + +template <size_t N, typename T> +CMT_INLINE void cscatter(complex<T>* base, size_t stride, const cvec<T, N>& value) +{ + return cscatter_helper<N, T>(base, stride, value, csizeseq_t<N>()); +} + +template <size_t groupsize = 1, typename T, size_t N, typename IT> +CMT_INLINE vec<T, N * 2 * groupsize> cgather(const complex<T>* base, const vec<IT, N>& offset) +{ + return gather_helper<2 * groupsize>(ptr_cast<T>(base), offset, csizeseq_t<N>()); +} + +template <size_t groupsize = 1, typename T, size_t N, typename IT> +CMT_INLINE void cscatter(complex<T>* base, const vec<IT, N>& offset, vec<T, N * 2 * groupsize> value) +{ + return scatter_helper<2 * groupsize>(ptr_cast<T>(base), offset, value, csizeseq_t<N>()); +} + +template <typename T> +KFR_INTRIN void transpose4x8(const cvec<T, 8>& z0, const cvec<T, 8>& z1, const cvec<T, 8>& z2, + const cvec<T, 8>& z3, cvec<T, 4>& w0, cvec<T, 4>& w1, cvec<T, 4>& w2, + cvec<T, 4>& w3, cvec<T, 4>& w4, cvec<T, 4>& w5, cvec<T, 4>& w6, cvec<T, 4>& w7) +{ + cvec<T, 16> a = concat(low(z0), low(z1), low(z2), low(z3)); + cvec<T, 16> b = concat(high(z0), high(z1), high(z2), high(z3)); + a = digitreverse4<2>(a); + b = digitreverse4<2>(b); + w0 = part<4, 0>(a); + w1 = part<4, 1>(a); + w2 = part<4, 2>(a); + w3 = part<4, 3>(a); + w4 = part<4, 0>(b); + w5 = part<4, 1>(b); + w6 = part<4, 2>(b); + w7 = part<4, 3>(b); +} + +template <typename T> +KFR_INTRIN void transpose4x8(const cvec<T, 4>& w0, const cvec<T, 4>& w1, const cvec<T, 4>& w2, + const cvec<T, 4>& w3, const cvec<T, 4>& w4, const cvec<T, 4>& w5, + const cvec<T, 4>& w6, const cvec<T, 4>& w7, cvec<T, 8>& z0, cvec<T, 8>& z1, + cvec<T, 8>& z2, cvec<T, 8>& z3) +{ + cvec<T, 16> a = concat(w0, w1, w2, w3); + cvec<T, 16> b = concat(w4, w5, w6, w7); + a = digitreverse4<2>(a); + b = digitreverse4<2>(b); + z0 = concat(part<4, 0>(a), part<4, 0>(b)); + z1 = concat(part<4, 1>(a), part<4, 1>(b)); + z2 = concat(part<4, 2>(a), part<4, 2>(b)); + z3 = concat(part<4, 3>(a), part<4, 3>(b)); +} + +template <typename T> +void transpose4(cvec<T, 16>& a, cvec<T, 16>& b, cvec<T, 16>& c, cvec<T, 16>& d) +{ + cvec<T, 4> a0, a1, a2, a3; + cvec<T, 4> b0, b1, b2, b3; + cvec<T, 4> c0, c1, c2, c3; + cvec<T, 4> d0, d1, d2, d3; + + split(a, a0, a1, a2, a3); + split(b, b0, b1, b2, b3); + split(c, c0, c1, c2, c3); + split(d, d0, d1, d2, d3); + + a = concat(a0, b0, c0, d0); + b = concat(a1, b1, c1, d1); + c = concat(a2, b2, c2, d2); + d = concat(a3, b3, c3, d3); +} +template <typename T> +void transpose4(cvec<T, 16>& a, cvec<T, 16>& b, cvec<T, 16>& c, cvec<T, 16>& d, cvec<T, 16>& aa, + cvec<T, 16>& bb, cvec<T, 16>& cc, cvec<T, 16>& dd) +{ + cvec<T, 4> a0, a1, a2, a3; + cvec<T, 4> b0, b1, b2, b3; + cvec<T, 4> c0, c1, c2, c3; + cvec<T, 4> d0, d1, d2, d3; + + split(a, a0, a1, a2, a3); + split(b, b0, b1, b2, b3); + split(c, c0, c1, c2, c3); + split(d, d0, d1, d2, d3); + + aa = concat(a0, b0, c0, d0); + bb = concat(a1, b1, c1, d1); + cc = concat(a2, b2, c2, d2); + dd = concat(a3, b3, c3, d3); +} + +template <bool b, typename T> +constexpr KFR_INTRIN T chsign(T x) +{ + return b ? -x : x; +} + +template <typename T, size_t N, size_t size, size_t start, size_t step, bool inverse = false, + size_t... indices> +constexpr KFR_INTRIN cvec<T, N> get_fixed_twiddle_helper(csizes_t<indices...>) +{ + return make_vector((indices & 1 ? chsign<inverse>(-sin_using_table<T>(size, (indices / 2 * step + start))) + : cos_using_table<T>(size, (indices / 2 * step + start)))...); +} + +template <typename T, size_t width, size_t... indices> +constexpr KFR_INTRIN cvec<T, width> get_fixed_twiddle_helper(csizes_t<indices...>, size_t size, size_t start, + size_t step) +{ + return make_vector((indices & 1 ? -sin_using_table<T>(size, indices / 2 * step + start) + : cos_using_table<T>(size, indices / 2 * step + start))...); +} + +template <typename T, size_t width, size_t size, size_t start, size_t step = 0, bool inverse = false> +constexpr KFR_INTRIN cvec<T, width> fixed_twiddle() +{ + return get_fixed_twiddle_helper<T, width, size, start, step, inverse>(csizeseq_t<width * 2>()); +} + +template <typename T, size_t width> +constexpr KFR_INTRIN cvec<T, width> fixed_twiddle(size_t size, size_t start, size_t step = 0) +{ + return get_fixed_twiddle_helper<T, width>(csizeseq_t<width * 2>(), start, step, size); +} + +// template <typename T, size_t N, size_t size, size_t start, size_t step = 0, bool inverse = false> +// constexpr cvec<T, N> fixed_twiddle = get_fixed_twiddle<T, N, size, start, step, inverse>(); + +template <typename T, size_t N, bool inverse> +constexpr cvec<T, N> twiddleimagmask() +{ + return inverse ? broadcast<N * 2, T>(-1, +1) : broadcast<N * 2, T>(+1, -1); +} + +CMT_PRAGMA_GNU(GCC diagnostic push) +CMT_PRAGMA_GNU(GCC diagnostic ignored "-Wconversion") + +CMT_PRAGMA_GNU(GCC diagnostic pop) + +template <typename T, size_t N> +CMT_NOINLINE static vec<T, N> cossin_conj(const vec<T, N>& x) +{ + return negodd(cossin(x)); +} + +template <size_t k, size_t size, bool inverse = false, typename T, size_t width, + size_t kk = (inverse ? size - k : k) % size> +KFR_INTRIN vec<T, width> cmul_by_twiddle(const vec<T, width>& x) +{ + constexpr T isqrt2 = static_cast<T>(0.70710678118654752440084436210485); + if (kk == 0) + { + return x; + } + else if (kk == size * 1 / 8) + { + return swap<2>(subadd(swap<2>(x), x)) * isqrt2; + } + else if (kk == size * 2 / 8) + { + return negodd(swap<2>(x)); + } + else if (kk == size * 3 / 8) + { + return subadd(x, swap<2>(x)) * -isqrt2; + } + else if (kk == size * 4 / 8) + { + return -x; + } + else if (kk == size * 5 / 8) + { + return swap<2>(subadd(swap<2>(x), x)) * -isqrt2; + } + else if (kk == size * 6 / 8) + { + return swap<2>(negodd(x)); + } + else if (kk == size * 7 / 8) + { + return subadd(x, swap<2>(x)) * isqrt2; + } + else + { + return cmul(x, resize<width>(fixed_twiddle<T, 1, size, kk>())); + } +} + +template <size_t N, typename T> +KFR_INTRIN void butterfly2(const cvec<T, N>& a0, const cvec<T, N>& a1, cvec<T, N>& w0, cvec<T, N>& w1) +{ + const cvec<T, N> sum = a0 + a1; + const cvec<T, N> dif = a0 - a1; + w0 = sum; + w1 = dif; +} + +template <size_t N, typename T> +KFR_INTRIN void butterfly2(cvec<T, N>& a0, cvec<T, N>& a1) +{ + butterfly2<N>(a0, a1, a0, a1); +} + +template <size_t N, bool inverse = false, typename T> +KFR_INTRIN void butterfly4(cfalse_t /*split_format*/, const cvec<T, N>& a0, const cvec<T, N>& a1, + const cvec<T, N>& a2, const cvec<T, N>& a3, cvec<T, N>& w0, cvec<T, N>& w1, + cvec<T, N>& w2, cvec<T, N>& w3) +{ + cvec<T, N> sum02, sum13, diff02, diff13; + cvec<T, N * 2> a01, a23, sum0213, diff0213; + + a01 = concat(a0, a1); + a23 = concat(a2, a3); + sum0213 = a01 + a23; + diff0213 = a01 - a23; + + sum02 = low(sum0213); + sum13 = high(sum0213); + diff02 = low(diff0213); + diff13 = high(diff0213); + w0 = sum02 + sum13; + w2 = sum02 - sum13; + if (inverse) + { + diff13 = (diff13 ^ broadcast<N * 2, T>(T(), -T())); + diff13 = swap<2>(diff13); + } + else + { + diff13 = swap<2>(diff13); + diff13 = (diff13 ^ broadcast<N * 2, T>(T(), -T())); + } + + w1 = diff02 + diff13; + w3 = diff02 - diff13; +} + +template <size_t N, bool inverse = false, typename T> +KFR_INTRIN void butterfly4(ctrue_t /*split_format*/, const cvec<T, N>& a0, const cvec<T, N>& a1, + const cvec<T, N>& a2, const cvec<T, N>& a3, cvec<T, N>& w0, cvec<T, N>& w1, + cvec<T, N>& w2, cvec<T, N>& w3) +{ + vec<T, N> re0, im0, re1, im1, re2, im2, re3, im3; + vec<T, N> wre0, wim0, wre1, wim1, wre2, wim2, wre3, wim3; + + cvec<T, N> sum02, sum13, diff02, diff13; + vec<T, N> sum02re, sum13re, diff02re, diff13re; + vec<T, N> sum02im, sum13im, diff02im, diff13im; + + sum02 = a0 + a2; + sum13 = a1 + a3; + + w0 = sum02 + sum13; + w2 = sum02 - sum13; + + diff02 = a0 - a2; + diff13 = a1 - a3; + split(diff02, diff02re, diff02im); + split(diff13, diff13re, diff13im); + + (inverse ? w3 : w1) = concat(diff02re + diff13im, diff02im - diff13re); + (inverse ? w1 : w3) = concat(diff02re - diff13im, diff02im + diff13re); +} + +template <size_t N, bool inverse = false, typename T> +KFR_INTRIN void butterfly8(const cvec<T, N>& a0, const cvec<T, N>& a1, const cvec<T, N>& a2, + const cvec<T, N>& a3, const cvec<T, N>& a4, const cvec<T, N>& a5, + const cvec<T, N>& a6, const cvec<T, N>& a7, cvec<T, N>& w0, cvec<T, N>& w1, + cvec<T, N>& w2, cvec<T, N>& w3, cvec<T, N>& w4, cvec<T, N>& w5, cvec<T, N>& w6, + cvec<T, N>& w7) +{ + cvec<T, N> b0 = a0, b2 = a2, b4 = a4, b6 = a6; + butterfly4<N, inverse>(cfalse, b0, b2, b4, b6, b0, b2, b4, b6); + cvec<T, N> b1 = a1, b3 = a3, b5 = a5, b7 = a7; + butterfly4<N, inverse>(cfalse, b1, b3, b5, b7, b1, b3, b5, b7); + w0 = b0 + b1; + w4 = b0 - b1; + + b3 = cmul_by_twiddle<1, 8, inverse>(b3); + b5 = cmul_by_twiddle<2, 8, inverse>(b5); + b7 = cmul_by_twiddle<3, 8, inverse>(b7); + + w1 = b2 + b3; + w5 = b2 - b3; + w2 = b4 + b5; + w6 = b4 - b5; + w3 = b6 + b7; + w7 = b6 - b7; +} + +template <size_t N, bool inverse = false, typename T> +KFR_INTRIN void butterfly8(cvec<T, N>& a0, cvec<T, N>& a1, cvec<T, N>& a2, cvec<T, N>& a3, cvec<T, N>& a4, + cvec<T, N>& a5, cvec<T, N>& a6, cvec<T, N>& a7) +{ + butterfly8<N, inverse>(a0, a1, a2, a3, a4, a5, a6, a7, a0, a1, a2, a3, a4, a5, a6, a7); +} + +template <bool inverse = false, typename T> +KFR_INTRIN void butterfly8(cvec<T, 2>& a01, cvec<T, 2>& a23, cvec<T, 2>& a45, cvec<T, 2>& a67) +{ + cvec<T, 2> b01 = a01, b23 = a23, b45 = a45, b67 = a67; + + butterfly4<2, inverse>(cfalse, b01, b23, b45, b67, b01, b23, b45, b67); + + cvec<T, 2> b02, b13, b46, b57; + + cvec<T, 8> b01234567 = concat(b01, b23, b45, b67); + cvec<T, 8> b02461357 = concat(even<2>(b01234567), odd<2>(b01234567)); + split(b02461357, b02, b46, b13, b57); + + b13 = cmul(b13, fixed_twiddle<T, 2, 8, 0, 1, inverse>()); + b57 = cmul(b57, fixed_twiddle<T, 2, 8, 2, 1, inverse>()); + a01 = b02 + b13; + a23 = b46 + b57; + a45 = b02 - b13; + a67 = b46 - b57; +} + +template <bool inverse = false, typename T> +KFR_INTRIN void butterfly8(cvec<T, 8>& v8) +{ + cvec<T, 2> w0, w1, w2, w3; + split(v8, w0, w1, w2, w3); + butterfly8<inverse>(w0, w1, w2, w3); + v8 = concat(w0, w1, w2, w3); +} + +template <bool inverse = false, typename T> +KFR_INTRIN void butterfly32(cvec<T, 32>& v32) +{ + cvec<T, 4> w0, w1, w2, w3, w4, w5, w6, w7; + split(v32, w0, w1, w2, w3, w4, w5, w6, w7); + butterfly8<4, inverse>(w0, w1, w2, w3, w4, w5, w6, w7); + + w1 = cmul(w1, fixed_twiddle<T, 4, 32, 0, 1, inverse>()); + w2 = cmul(w2, fixed_twiddle<T, 4, 32, 0, 2, inverse>()); + w3 = cmul(w3, fixed_twiddle<T, 4, 32, 0, 3, inverse>()); + w4 = cmul(w4, fixed_twiddle<T, 4, 32, 0, 4, inverse>()); + w5 = cmul(w5, fixed_twiddle<T, 4, 32, 0, 5, inverse>()); + w6 = cmul(w6, fixed_twiddle<T, 4, 32, 0, 6, inverse>()); + w7 = cmul(w7, fixed_twiddle<T, 4, 32, 0, 7, inverse>()); + + cvec<T, 8> z0, z1, z2, z3; + transpose4x8(w0, w1, w2, w3, w4, w5, w6, w7, z0, z1, z2, z3); + + butterfly4<8, inverse>(cfalse, z0, z1, z2, z3, z0, z1, z2, z3); + v32 = concat(z0, z1, z2, z3); +} + +template <size_t N, bool inverse = false, typename T> +KFR_INTRIN void butterfly4(cvec<T, N * 4>& a0123) +{ + cvec<T, N> a0; + cvec<T, N> a1; + cvec<T, N> a2; + cvec<T, N> a3; + split(a0123, a0, a1, a2, a3); + butterfly4<N, inverse>(cfalse, a0, a1, a2, a3, a0, a1, a2, a3); + a0123 = concat(a0, a1, a2, a3); +} + +template <size_t N, typename T> +KFR_INTRIN void butterfly2(cvec<T, N * 2>& a01) +{ + cvec<T, N> a0; + cvec<T, N> a1; + split(a01, a0, a1); + butterfly2<N>(a0, a1); + a01 = concat(a0, a1); +} + +template <size_t N, bool inverse = false, bool split_format = false, typename T> +KFR_INTRIN void apply_twiddle(const cvec<T, N>& a1, const cvec<T, N>& tw1, cvec<T, N>& w1) +{ + if (split_format) + { + vec<T, N> re1, im1, tw1re, tw1im; + split(a1, re1, im1); + split(tw1, tw1re, tw1im); + vec<T, N> b1re = re1 * tw1re; + vec<T, N> b1im = im1 * tw1re; + if (inverse) + w1 = concat(b1re + im1 * tw1im, b1im - re1 * tw1im); + else + w1 = concat(b1re - im1 * tw1im, b1im + re1 * tw1im); + } + else + { + const cvec<T, N> b1 = a1 * dupeven(tw1); + const cvec<T, N> a1_ = swap<2>(a1); + + cvec<T, N> tw1_ = tw1; + if (inverse) + tw1_ = -(tw1_); + w1 = subadd(b1, a1_ * dupodd(tw1_)); + } +} + +template <size_t N, bool inverse = false, bool split_format = false, typename T> +KFR_INTRIN void apply_twiddles4(const cvec<T, N>& a1, const cvec<T, N>& a2, const cvec<T, N>& a3, + const cvec<T, N>& tw1, const cvec<T, N>& tw2, const cvec<T, N>& tw3, + cvec<T, N>& w1, cvec<T, N>& w2, cvec<T, N>& w3) +{ + apply_twiddle<N, inverse, split_format>(a1, tw1, w1); + apply_twiddle<N, inverse, split_format>(a2, tw2, w2); + apply_twiddle<N, inverse, split_format>(a3, tw3, w3); +} + +template <size_t N, bool inverse = false, typename T> +KFR_INTRIN void apply_twiddles4(cvec<T, N>& __restrict a1, cvec<T, N>& __restrict a2, + cvec<T, N>& __restrict a3, const cvec<T, N>& tw1, const cvec<T, N>& tw2, + const cvec<T, N>& tw3) +{ + apply_twiddles4<N, inverse>(a1, a2, a3, tw1, tw2, tw3, a1, a2, a3); +} + +template <size_t N, bool inverse = false, typename T, typename = u8[N - 1]> +KFR_INTRIN void apply_twiddles4(cvec<T, N>& __restrict a1, cvec<T, N>& __restrict a2, + cvec<T, N>& __restrict a3, const cvec<T, 1>& tw1, const cvec<T, 1>& tw2, + const cvec<T, 1>& tw3) +{ + apply_twiddles4<N, inverse>(a1, a2, a3, resize<N * 2>(tw1), resize<N * 2>(tw2), resize<N * 2>(tw3)); +} + +template <size_t N, bool inverse = false, typename T, typename = u8[N - 2]> +KFR_INTRIN void apply_twiddles4(cvec<T, N>& __restrict a1, cvec<T, N>& __restrict a2, + cvec<T, N>& __restrict a3, cvec<T, N / 2> tw1, cvec<T, N / 2> tw2, + cvec<T, N / 2> tw3) +{ + apply_twiddles4<N, inverse>(a1, a2, a3, resize<N * 2>(tw1), resize<N * 2>(tw2), resize<N * 2>(tw3)); +} + +template <size_t N, bool inverse = false, typename T> +KFR_INTRIN void apply_vertical_twiddles4(cvec<T, N * 4>& b, cvec<T, N * 4>& c, cvec<T, N * 4>& d) +{ + cvec<T, 4> b0, b1, b2, b3; + cvec<T, 4> c0, c1, c2, c3; + cvec<T, 4> d0, d1, d2, d3; + + split(b, b0, b1, b2, b3); + split(c, c0, c1, c2, c3); + split(d, d0, d1, d2, d3); + + b1 = cmul_by_twiddle<4, 64, inverse>(b1); + b2 = cmul_by_twiddle<8, 64, inverse>(b2); + b3 = cmul_by_twiddle<12, 64, inverse>(b3); + + c1 = cmul_by_twiddle<8, 64, inverse>(c1); + c2 = cmul_by_twiddle<16, 64, inverse>(c2); + c3 = cmul_by_twiddle<24, 64, inverse>(c3); + + d1 = cmul_by_twiddle<12, 64, inverse>(d1); + d2 = cmul_by_twiddle<24, 64, inverse>(d2); + d3 = cmul_by_twiddle<36, 64, inverse>(d3); + + b = concat(b0, b1, b2, b3); + c = concat(c0, c1, c2, c3); + d = concat(d0, d1, d2, d3); +} + +template <size_t n2, size_t nnstep, size_t N, bool inverse = false, typename T> +KFR_INTRIN void apply_twiddles4(cvec<T, N * 4>& __restrict a0123) +{ + cvec<T, N> a0; + cvec<T, N> a1; + cvec<T, N> a2; + cvec<T, N> a3; + split(a0123, a0, a1, a2, a3); + + cvec<T, N> tw1 = fixed_twiddle<T, N, 64, n2 * nnstep * 1, nnstep * 1, inverse>(), + tw2 = fixed_twiddle<T, N, 64, n2 * nnstep * 2, nnstep * 2, inverse>(), + tw3 = fixed_twiddle<T, N, 64, n2 * nnstep * 3, nnstep * 3, inverse>(); + + apply_twiddles4<N>(a1, a2, a3, tw1, tw2, tw3); + + a0123 = concat(a0, a1, a2, a3); +} + +template <bool inverse, bool aligned, typename T> +KFR_INTRIN void butterfly64(cbool_t<inverse>, cbool_t<aligned>, complex<T>* out, const complex<T>* in) +{ + cvec<T, 16> w0, w1, w2, w3; + + w0 = cread_group<4, 4, 16, aligned>( + in); // concat(cread<4>(in + 0), cread<4>(in + 16), cread<4>(in + 32), cread<4>(in + 48)); + butterfly4<4, inverse>(w0); + apply_twiddles4<0, 1, 4, inverse>(w0); + + w1 = cread_group<4, 4, 16, aligned>( + in + 4); // concat(cread<4>(in + 4), cread<4>(in + 20), cread<4>(in + 36), cread<4>(in + 52)); + butterfly4<4, inverse>(w1); + apply_twiddles4<4, 1, 4, inverse>(w1); + + w2 = cread_group<4, 4, 16, aligned>( + in + 8); // concat(cread<4>(in + 8), cread<4>(in + 24), cread<4>(in + 40), cread<4>(in + 56)); + butterfly4<4, inverse>(w2); + apply_twiddles4<8, 1, 4, inverse>(w2); + + w3 = cread_group<4, 4, 16, aligned>( + in + 12); // concat(cread<4>(in + 12), cread<4>(in + 28), cread<4>(in + 44), cread<4>(in + 60)); + butterfly4<4, inverse>(w3); + apply_twiddles4<12, 1, 4, inverse>(w3); + + transpose4(w0, w1, w2, w3); + // pass 2: + + butterfly4<4, inverse>(w0); + butterfly4<4, inverse>(w1); + butterfly4<4, inverse>(w2); + butterfly4<4, inverse>(w3); + + transpose4(w0, w1, w2, w3); + + w0 = digitreverse4<2>(w0); + w1 = digitreverse4<2>(w1); + w2 = digitreverse4<2>(w2); + w3 = digitreverse4<2>(w3); + + apply_vertical_twiddles4<4, inverse>(w1, w2, w3); + + // pass 3: + butterfly4<4, inverse>(w3); + cwrite_group<4, 4, 16, aligned>(out + 12, w3); // split(w3, out[3], out[7], out[11], out[15]); + + butterfly4<4, inverse>(w2); + cwrite_group<4, 4, 16, aligned>(out + 8, w2); // split(w2, out[2], out[6], out[10], out[14]); + + butterfly4<4, inverse>(w1); + cwrite_group<4, 4, 16, aligned>(out + 4, w1); // split(w1, out[1], out[5], out[9], out[13]); + + butterfly4<4, inverse>(w0); + cwrite_group<4, 4, 16, aligned>(out, w0); // split(w0, out[0], out[4], out[8], out[12]); +} + +template <bool inverse = false, typename T> +KFR_INTRIN void butterfly16(cvec<T, 16>& v16) +{ + butterfly4<4, inverse>(v16); + apply_twiddles4<0, 4, 4, inverse>(v16); + v16 = digitreverse4<2>(v16); + butterfly4<4, inverse>(v16); +} + +template <size_t index, bool inverse = false, typename T> +KFR_INTRIN void butterfly16_multi_natural(complex<T>* out, const complex<T>* in) +{ + constexpr size_t N = 4; + + cvec<T, 4> a1 = cread<4>(in + index * 4 + 16 * 1); + cvec<T, 4> a5 = cread<4>(in + index * 4 + 16 * 5); + cvec<T, 4> a9 = cread<4>(in + index * 4 + 16 * 9); + cvec<T, 4> a13 = cread<4>(in + index * 4 + 16 * 13); + butterfly4<N, inverse>(cfalse, a1, a5, a9, a13, a1, a5, a9, a13); + a5 = cmul_by_twiddle<1, 16, inverse>(a5); + a9 = cmul_by_twiddle<2, 16, inverse>(a9); + a13 = cmul_by_twiddle<3, 16, inverse>(a13); + + cvec<T, 4> a2 = cread<4>(in + index * 4 + 16 * 2); + cvec<T, 4> a6 = cread<4>(in + index * 4 + 16 * 6); + cvec<T, 4> a10 = cread<4>(in + index * 4 + 16 * 10); + cvec<T, 4> a14 = cread<4>(in + index * 4 + 16 * 14); + butterfly4<N, inverse>(cfalse, a2, a6, a10, a14, a2, a6, a10, a14); + a6 = cmul_by_twiddle<2, 16, inverse>(a6); + a10 = cmul_by_twiddle<4, 16, inverse>(a10); + a14 = cmul_by_twiddle<6, 16, inverse>(a14); + + cvec<T, 4> a3 = cread<4>(in + index * 4 + 16 * 3); + cvec<T, 4> a7 = cread<4>(in + index * 4 + 16 * 7); + cvec<T, 4> a11 = cread<4>(in + index * 4 + 16 * 11); + cvec<T, 4> a15 = cread<4>(in + index * 4 + 16 * 15); + butterfly4<N, inverse>(cfalse, a3, a7, a11, a15, a3, a7, a11, a15); + a7 = cmul_by_twiddle<3, 16, inverse>(a7); + a11 = cmul_by_twiddle<6, 16, inverse>(a11); + a15 = cmul_by_twiddle<9, 16, inverse>(a15); + + cvec<T, 4> a0 = cread<4>(in + index * 4 + 16 * 0); + cvec<T, 4> a4 = cread<4>(in + index * 4 + 16 * 4); + cvec<T, 4> a8 = cread<4>(in + index * 4 + 16 * 8); + cvec<T, 4> a12 = cread<4>(in + index * 4 + 16 * 12); + butterfly4<N, inverse>(cfalse, a0, a4, a8, a12, a0, a4, a8, a12); + butterfly4<N, inverse>(cfalse, a0, a1, a2, a3, a0, a1, a2, a3); + cwrite<4>(out + index * 4 + 16 * 0, a0); + cwrite<4>(out + index * 4 + 16 * 4, a1); + cwrite<4>(out + index * 4 + 16 * 8, a2); + cwrite<4>(out + index * 4 + 16 * 12, a3); + butterfly4<N, inverse>(cfalse, a4, a5, a6, a7, a4, a5, a6, a7); + cwrite<4>(out + index * 4 + 16 * 1, a4); + cwrite<4>(out + index * 4 + 16 * 5, a5); + cwrite<4>(out + index * 4 + 16 * 9, a6); + cwrite<4>(out + index * 4 + 16 * 13, a7); + butterfly4<N, inverse>(cfalse, a8, a9, a10, a11, a8, a9, a10, a11); + cwrite<4>(out + index * 4 + 16 * 2, a8); + cwrite<4>(out + index * 4 + 16 * 6, a9); + cwrite<4>(out + index * 4 + 16 * 10, a10); + cwrite<4>(out + index * 4 + 16 * 14, a11); + butterfly4<N, inverse>(cfalse, a12, a13, a14, a15, a12, a13, a14, a15); + cwrite<4>(out + index * 4 + 16 * 3, a12); + cwrite<4>(out + index * 4 + 16 * 7, a13); + cwrite<4>(out + index * 4 + 16 * 11, a14); + cwrite<4>(out + index * 4 + 16 * 15, a15); +} + +template <size_t index, bool inverse = false, typename T> +KFR_INTRIN void butterfly16_multi_flip(complex<T>* out, const complex<T>* in) +{ + constexpr size_t N = 4; + + cvec<T, 4> a1 = cread<4>(in + index * 4 + 16 * 1); + cvec<T, 4> a5 = cread<4>(in + index * 4 + 16 * 5); + cvec<T, 4> a9 = cread<4>(in + index * 4 + 16 * 9); + cvec<T, 4> a13 = cread<4>(in + index * 4 + 16 * 13); + butterfly4<N, inverse>(cfalse, a1, a5, a9, a13, a1, a5, a9, a13); + a5 = cmul_by_twiddle<1, 16, inverse>(a5); + a9 = cmul_by_twiddle<2, 16, inverse>(a9); + a13 = cmul_by_twiddle<3, 16, inverse>(a13); + + cvec<T, 4> a2 = cread<4>(in + index * 4 + 16 * 2); + cvec<T, 4> a6 = cread<4>(in + index * 4 + 16 * 6); + cvec<T, 4> a10 = cread<4>(in + index * 4 + 16 * 10); + cvec<T, 4> a14 = cread<4>(in + index * 4 + 16 * 14); + butterfly4<N, inverse>(cfalse, a2, a6, a10, a14, a2, a6, a10, a14); + a6 = cmul_by_twiddle<2, 16, inverse>(a6); + a10 = cmul_by_twiddle<4, 16, inverse>(a10); + a14 = cmul_by_twiddle<6, 16, inverse>(a14); + + cvec<T, 4> a3 = cread<4>(in + index * 4 + 16 * 3); + cvec<T, 4> a7 = cread<4>(in + index * 4 + 16 * 7); + cvec<T, 4> a11 = cread<4>(in + index * 4 + 16 * 11); + cvec<T, 4> a15 = cread<4>(in + index * 4 + 16 * 15); + butterfly4<N, inverse>(cfalse, a3, a7, a11, a15, a3, a7, a11, a15); + a7 = cmul_by_twiddle<3, 16, inverse>(a7); + a11 = cmul_by_twiddle<6, 16, inverse>(a11); + a15 = cmul_by_twiddle<9, 16, inverse>(a15); + + cvec<T, 16> w1 = concat(a1, a5, a9, a13); + cvec<T, 16> w2 = concat(a2, a6, a10, a14); + cvec<T, 16> w3 = concat(a3, a7, a11, a15); + + cvec<T, 4> a0 = cread<4>(in + index * 4 + 16 * 0); + cvec<T, 4> a4 = cread<4>(in + index * 4 + 16 * 4); + cvec<T, 4> a8 = cread<4>(in + index * 4 + 16 * 8); + cvec<T, 4> a12 = cread<4>(in + index * 4 + 16 * 12); + butterfly4<N, inverse>(cfalse, a0, a4, a8, a12, a0, a4, a8, a12); + cvec<T, 16> w0 = concat(a0, a4, a8, a12); + + butterfly4<N * 4, inverse>(cfalse, w0, w1, w2, w3, w0, w1, w2, w3); + + w0 = digitreverse4<2>(w0); + w1 = digitreverse4<2>(w1); + w2 = digitreverse4<2>(w2); + w3 = digitreverse4<2>(w3); + + transpose4(w0, w1, w2, w3); + cwrite<16>(out + index * 64 + 16 * 0, cmul(w0, fixed_twiddle<T, 16, 256, 0, index * 4 + 0, inverse>())); + cwrite<16>(out + index * 64 + 16 * 1, cmul(w1, fixed_twiddle<T, 16, 256, 0, index * 4 + 1, inverse>())); + cwrite<16>(out + index * 64 + 16 * 2, cmul(w2, fixed_twiddle<T, 16, 256, 0, index * 4 + 2, inverse>())); + cwrite<16>(out + index * 64 + 16 * 3, cmul(w3, fixed_twiddle<T, 16, 256, 0, index * 4 + 3, inverse>())); +} + +template <size_t n2, size_t nnstep, size_t N, typename T> +KFR_INTRIN void apply_twiddles2(cvec<T, N>& a1) +{ + cvec<T, N> tw1 = fixed_twiddle<T, N, 64, n2 * nnstep * 1, nnstep * 1>(); + + a1 = cmul(a1, tw1); +} + +template <typename T, size_t N, bool inverse> +static const cvec<T, N> tw3r1 = static_cast<T>(-0.5 - 1.0); + +template <typename T, size_t N, bool inverse> +static const cvec<T, N> tw3i1 = + static_cast<T>(0.86602540378443864676372317075) * twiddleimagmask<T, N, inverse>(); + +template <size_t N, bool inverse = false, typename T> +KFR_INTRIN void butterfly3(cvec<T, N> a00, cvec<T, N> a01, cvec<T, N> a02, cvec<T, N>& w00, cvec<T, N>& w01, + cvec<T, N>& w02) +{ + + const cvec<T, N> sum1 = a01 + a02; + const cvec<T, N> dif1 = swap<2>(a01 - a02); + w00 = a00 + sum1; + + const cvec<T, N> s1 = w00 + sum1 * tw3r1<T, N, inverse>; + + const cvec<T, N> d1 = dif1 * tw3i1<T, N, inverse>; + + w01 = s1 + d1; + w02 = s1 - d1; +} + +template <size_t N, bool inverse = false, typename T> +KFR_INTRIN void butterfly3(cvec<T, N>& a0, cvec<T, N>& a1, cvec<T, N>& a2) +{ + butterfly3<N, inverse>(a0, a1, a2, a0, a1, a2); +} + +template <size_t N, bool inverse = false, typename T> +KFR_INTRIN void butterfly6(const cvec<T, N>& a0, const cvec<T, N>& a1, const cvec<T, N>& a2, + const cvec<T, N>& a3, const cvec<T, N>& a4, const cvec<T, N>& a5, cvec<T, N>& w0, + cvec<T, N>& w1, cvec<T, N>& w2, cvec<T, N>& w3, cvec<T, N>& w4, cvec<T, N>& w5) +{ + cvec<T, N* 2> a03 = concat(a0, a3); + cvec<T, N* 2> a25 = concat(a2, a5); + cvec<T, N* 2> a41 = concat(a4, a1); + butterfly3<N * 2, inverse>(a03, a25, a41, a03, a25, a41); + cvec<T, N> t0, t1, t2, t3, t4, t5; + split(a03, t0, t1); + split(a25, t2, t3); + split(a41, t4, t5); + t3 = -t3; + cvec<T, N* 2> a04 = concat(t0, t4); + cvec<T, N* 2> a15 = concat(t1, t5); + cvec<T, N * 2> w02, w35; + butterfly2<N * 2>(a04, a15, w02, w35); + split(w02, w0, w2); + split(w35, w3, w5); + + butterfly2<N>(t2, t3, w1, w4); +} + +template <size_t N, bool inverse = false, typename T> +KFR_INTRIN void butterfly6(cvec<T, N>& a0, cvec<T, N>& a1, cvec<T, N>& a2, cvec<T, N>& a3, cvec<T, N>& a4, + cvec<T, N>& a5) +{ + butterfly6<N, inverse>(a0, a1, a2, a3, a4, a5, a0, a1, a2, a3, a4, a5); +} + +template <typename T, bool inverse = false> +const static cvec<T, 1> tw9_1 = { T(0.76604444311897803520239265055541), + (inverse ? -1 : 1) * T(-0.64278760968653932632264340990727) }; +template <typename T, bool inverse = false> +const static cvec<T, 1> tw9_2 = { T(0.17364817766693034885171662676931), + (inverse ? -1 : 1) * T(-0.98480775301220805936674302458952) }; +template <typename T, bool inverse = false> +const static cvec<T, 1> tw9_4 = { T(-0.93969262078590838405410927732473), + (inverse ? -1 : 1) * T(-0.34202014332566873304409961468226) }; + +template <size_t N, bool inverse = false, typename T> +KFR_INTRIN void butterfly9(const cvec<T, N>& a0, const cvec<T, N>& a1, const cvec<T, N>& a2, + const cvec<T, N>& a3, const cvec<T, N>& a4, const cvec<T, N>& a5, + const cvec<T, N>& a6, const cvec<T, N>& a7, const cvec<T, N>& a8, cvec<T, N>& w0, + cvec<T, N>& w1, cvec<T, N>& w2, cvec<T, N>& w3, cvec<T, N>& w4, cvec<T, N>& w5, + cvec<T, N>& w6, cvec<T, N>& w7, cvec<T, N>& w8) +{ + cvec<T, N* 3> a012 = concat(a0, a1, a2); + cvec<T, N* 3> a345 = concat(a3, a4, a5); + cvec<T, N* 3> a678 = concat(a6, a7, a8); + butterfly3<N * 3, inverse>(a012, a345, a678, a012, a345, a678); + cvec<T, N> t0, t1, t2, t3, t4, t5, t6, t7, t8; + split(a012, t0, t1, t2); + split(a345, t3, t4, t5); + split(a678, t6, t7, t8); + + t4 = cmul(t4, tw9_1<T, inverse>); + t5 = cmul(t5, tw9_2<T, inverse>); + t7 = cmul(t7, tw9_2<T, inverse>); + t8 = cmul(t8, tw9_4<T, inverse>); + + cvec<T, N* 3> t036 = concat(t0, t3, t6); + cvec<T, N* 3> t147 = concat(t1, t4, t7); + cvec<T, N* 3> t258 = concat(t2, t5, t8); + + butterfly3<N * 3, inverse>(t036, t147, t258, t036, t147, t258); + split(t036, w0, w1, w2); + split(t147, w3, w4, w5); + split(t258, w6, w7, w8); +} + +template <size_t N, bool inverse = false, typename T> +KFR_INTRIN void butterfly9(cvec<T, N>& a0, cvec<T, N>& a1, cvec<T, N>& a2, cvec<T, N>& a3, cvec<T, N>& a4, + cvec<T, N>& a5, cvec<T, N>& a6, cvec<T, N>& a7, cvec<T, N>& a8) +{ + butterfly9<N, inverse>(a0, a1, a2, a3, a4, a5, a6, a7, a8, a0, a1, a2, a3, a4, a5, a6, a7, a8); +} + +template <typename T, size_t N, bool inverse> +static const cvec<T, N> tw7r1 = static_cast<T>(0.623489801858733530525004884 - 1.0); + +template <typename T, size_t N, bool inverse> +static const cvec<T, N> tw7i1 = + static_cast<T>(0.78183148246802980870844452667) * twiddleimagmask<T, N, inverse>(); + +template <typename T, size_t N, bool inverse> +static const cvec<T, N> tw7r2 = static_cast<T>(-0.2225209339563144042889025645 - 1.0); + +template <typename T, size_t N, bool inverse> +static const cvec<T, N> tw7i2 = + static_cast<T>(0.97492791218182360701813168299) * twiddleimagmask<T, N, inverse>(); + +template <typename T, size_t N, bool inverse> +static const cvec<T, N> tw7r3 = static_cast<T>(-0.90096886790241912623610231951 - 1.0); + +template <typename T, size_t N, bool inverse> +static const cvec<T, N> tw7i3 = + static_cast<T>(0.43388373911755812047576833285) * twiddleimagmask<T, N, inverse>(); + +template <size_t N, bool inverse = false, typename T> +KFR_INTRIN void butterfly7(cvec<T, N> a00, cvec<T, N> a01, cvec<T, N> a02, cvec<T, N> a03, cvec<T, N> a04, + cvec<T, N> a05, cvec<T, N> a06, cvec<T, N>& w00, cvec<T, N>& w01, cvec<T, N>& w02, + cvec<T, N>& w03, cvec<T, N>& w04, cvec<T, N>& w05, cvec<T, N>& w06) +{ + const cvec<T, N> sum1 = a01 + a06; + const cvec<T, N> dif1 = swap<2>(a01 - a06); + const cvec<T, N> sum2 = a02 + a05; + const cvec<T, N> dif2 = swap<2>(a02 - a05); + const cvec<T, N> sum3 = a03 + a04; + const cvec<T, N> dif3 = swap<2>(a03 - a04); + w00 = a00 + sum1 + sum2 + sum3; + + const cvec<T, N> s1 = + w00 + sum1 * tw7r1<T, N, inverse> + sum2 * tw7r2<T, N, inverse> + sum3 * tw7r3<T, N, inverse>; + const cvec<T, N> s2 = + w00 + sum1 * tw7r2<T, N, inverse> + sum2 * tw7r3<T, N, inverse> + sum3 * tw7r1<T, N, inverse>; + const cvec<T, N> s3 = + w00 + sum1 * tw7r3<T, N, inverse> + sum2 * tw7r1<T, N, inverse> + sum3 * tw7r2<T, N, inverse>; + + const cvec<T, N> d1 = + dif1 * tw7i1<T, N, inverse> + dif2 * tw7i2<T, N, inverse> + dif3 * tw7i3<T, N, inverse>; + const cvec<T, N> d2 = + dif1 * tw7i2<T, N, inverse> - dif2 * tw7i3<T, N, inverse> - dif3 * tw7i1<T, N, inverse>; + const cvec<T, N> d3 = + dif1 * tw7i3<T, N, inverse> - dif2 * tw7i1<T, N, inverse> + dif3 * tw7i2<T, N, inverse>; + + w01 = s1 + d1; + w06 = s1 - d1; + w02 = s2 + d2; + w05 = s2 - d2; + w03 = s3 + d3; + w04 = s3 - d3; +} + +template <size_t N, bool inverse = false, typename T> +KFR_INTRIN void butterfly7(cvec<T, N>& a0, cvec<T, N>& a1, cvec<T, N>& a2, cvec<T, N>& a3, cvec<T, N>& a4, + cvec<T, N>& a5, cvec<T, N>& a6) +{ + butterfly7<N, inverse>(a0, a1, a2, a3, a4, a5, a6, a0, a1, a2, a3, a4, a5, a6); +} + +template <typename T, size_t N, bool inverse> +static const cvec<T, N> tw11r1 = static_cast<T>(0.84125353283118116886181164892 - 1.0); + +template <typename T, size_t N, bool inverse> +static const cvec<T, N> tw11i1 = + static_cast<T>(0.54064081745559758210763595432) * twiddleimagmask<T, N, inverse>(); + +template <typename T, size_t N, bool inverse> +static const cvec<T, N> tw11r2 = static_cast<T>(0.41541501300188642552927414923 - 1.0); + +template <typename T, size_t N, bool inverse> +static const cvec<T, N> tw11i2 = + static_cast<T>(0.90963199535451837141171538308) * twiddleimagmask<T, N, inverse>(); + +template <typename T, size_t N, bool inverse> +static const cvec<T, N> tw11r3 = static_cast<T>(-0.14231483827328514044379266862 - 1.0); + +template <typename T, size_t N, bool inverse> +static const cvec<T, N> tw11i3 = + static_cast<T>(0.98982144188093273237609203778) * twiddleimagmask<T, N, inverse>(); + +template <typename T, size_t N, bool inverse> +static const cvec<T, N> tw11r4 = static_cast<T>(-0.65486073394528506405692507247 - 1.0); + +template <typename T, size_t N, bool inverse> +static const cvec<T, N> tw11i4 = + static_cast<T>(0.75574957435425828377403584397) * twiddleimagmask<T, N, inverse>(); + +template <typename T, size_t N, bool inverse> +static const cvec<T, N> tw11r5 = static_cast<T>(-0.95949297361449738989036805707 - 1.0); + +template <typename T, size_t N, bool inverse> +static const cvec<T, N> tw11i5 = + static_cast<T>(0.28173255684142969771141791535) * twiddleimagmask<T, N, inverse>(); + +template <size_t N, bool inverse = false, typename T> +KFR_INTRIN void butterfly11(cvec<T, N> a00, cvec<T, N> a01, cvec<T, N> a02, cvec<T, N> a03, cvec<T, N> a04, + cvec<T, N> a05, cvec<T, N> a06, cvec<T, N> a07, cvec<T, N> a08, cvec<T, N> a09, + cvec<T, N> a10, cvec<T, N>& w00, cvec<T, N>& w01, cvec<T, N>& w02, + cvec<T, N>& w03, cvec<T, N>& w04, cvec<T, N>& w05, cvec<T, N>& w06, + cvec<T, N>& w07, cvec<T, N>& w08, cvec<T, N>& w09, cvec<T, N>& w10) +{ + const cvec<T, N> sum1 = a01 + a10; + const cvec<T, N> dif1 = swap<2>(a01 - a10); + const cvec<T, N> sum2 = a02 + a09; + const cvec<T, N> dif2 = swap<2>(a02 - a09); + const cvec<T, N> sum3 = a03 + a08; + const cvec<T, N> dif3 = swap<2>(a03 - a08); + const cvec<T, N> sum4 = a04 + a07; + const cvec<T, N> dif4 = swap<2>(a04 - a07); + const cvec<T, N> sum5 = a05 + a06; + const cvec<T, N> dif5 = swap<2>(a05 - a06); + w00 = a00 + sum1 + sum2 + sum3 + sum4 + sum5; + + const cvec<T, N> s1 = w00 + sum1 * tw11r1<T, N, inverse> + sum2 * tw11r2<T, N, inverse> + + sum3 * tw11r3<T, N, inverse> + sum4 * tw11r4<T, N, inverse> + + sum5 * tw11r5<T, N, inverse>; + const cvec<T, N> s2 = w00 + sum1 * tw11r2<T, N, inverse> + sum2 * tw11r3<T, N, inverse> + + sum3 * tw11r4<T, N, inverse> + sum4 * tw11r5<T, N, inverse> + + sum5 * tw11r1<T, N, inverse>; + const cvec<T, N> s3 = w00 + sum1 * tw11r3<T, N, inverse> + sum2 * tw11r4<T, N, inverse> + + sum3 * tw11r5<T, N, inverse> + sum4 * tw11r1<T, N, inverse> + + sum5 * tw11r2<T, N, inverse>; + const cvec<T, N> s4 = w00 + sum1 * tw11r4<T, N, inverse> + sum2 * tw11r5<T, N, inverse> + + sum3 * tw11r1<T, N, inverse> + sum4 * tw11r2<T, N, inverse> + + sum5 * tw11r3<T, N, inverse>; + const cvec<T, N> s5 = w00 + sum1 * tw11r5<T, N, inverse> + sum2 * tw11r1<T, N, inverse> + + sum3 * tw11r2<T, N, inverse> + sum4 * tw11r3<T, N, inverse> + + sum5 * tw11r4<T, N, inverse>; + + const cvec<T, N> d1 = dif1 * tw11i1<T, N, inverse> + dif2 * tw11i2<T, N, inverse> + + dif3 * tw11i3<T, N, inverse> + dif4 * tw11i4<T, N, inverse> + + dif5 * tw11i5<T, N, inverse>; + const cvec<T, N> d2 = dif1 * tw11i2<T, N, inverse> - dif2 * tw11i3<T, N, inverse> - + dif3 * tw11i4<T, N, inverse> - dif4 * tw11i5<T, N, inverse> - + dif5 * tw11i1<T, N, inverse>; + const cvec<T, N> d3 = dif1 * tw11i3<T, N, inverse> - dif2 * tw11i4<T, N, inverse> + + dif3 * tw11i5<T, N, inverse> + dif4 * tw11i1<T, N, inverse> + + dif5 * tw11i2<T, N, inverse>; + const cvec<T, N> d4 = dif1 * tw11i4<T, N, inverse> - dif2 * tw11i5<T, N, inverse> + + dif3 * tw11i1<T, N, inverse> - dif4 * tw11i2<T, N, inverse> - + dif5 * tw11i3<T, N, inverse>; + const cvec<T, N> d5 = dif1 * tw11i5<T, N, inverse> - dif2 * tw11i1<T, N, inverse> + + dif3 * tw11i2<T, N, inverse> - dif4 * tw11i3<T, N, inverse> + + dif5 * tw11i4<T, N, inverse>; + + w01 = s1 + d1; + w10 = s1 - d1; + w02 = s2 + d2; + w09 = s2 - d2; + w03 = s3 + d3; + w08 = s3 - d3; + w04 = s4 + d4; + w07 = s4 - d4; + w05 = s5 + d5; + w06 = s5 - d5; +} + +template <typename T, size_t N, bool inverse> +const static cvec<T, N> tw5r1 = static_cast<T>(0.30901699437494742410229341718 - 1.0); +template <typename T, size_t N, bool inverse> +const static cvec<T, N> tw5i1 = + static_cast<T>(0.95105651629515357211643933338) * twiddleimagmask<T, N, inverse>(); +template <typename T, size_t N, bool inverse> +const static cvec<T, N> tw5r2 = static_cast<T>(-0.80901699437494742410229341718 - 1.0); +template <typename T, size_t N, bool inverse> +const static cvec<T, N> tw5i2 = + static_cast<T>(0.58778525229247312916870595464) * twiddleimagmask<T, N, inverse>(); + +template <size_t N, bool inverse = false, typename T> +KFR_INTRIN void butterfly5(const cvec<T, N>& a00, const cvec<T, N>& a01, const cvec<T, N>& a02, + const cvec<T, N>& a03, const cvec<T, N>& a04, cvec<T, N>& w00, cvec<T, N>& w01, + cvec<T, N>& w02, cvec<T, N>& w03, cvec<T, N>& w04) +{ + const cvec<T, N> sum1 = a01 + a04; + const cvec<T, N> dif1 = swap<2>(a01 - a04); + const cvec<T, N> sum2 = a02 + a03; + const cvec<T, N> dif2 = swap<2>(a02 - a03); + w00 = a00 + sum1 + sum2; + + const cvec<T, N> s1 = w00 + sum1 * tw5r1<T, N, inverse> + sum2 * tw5r2<T, N, inverse>; + const cvec<T, N> s2 = w00 + sum1 * tw5r2<T, N, inverse> + sum2 * tw5r1<T, N, inverse>; + + const cvec<T, N> d1 = dif1 * tw5i1<T, N, inverse> + dif2 * tw5i2<T, N, inverse>; + const cvec<T, N> d2 = dif1 * tw5i2<T, N, inverse> - dif2 * tw5i1<T, N, inverse>; + + w01 = s1 + d1; + w04 = s1 - d1; + w02 = s2 + d2; + w03 = s2 - d2; +} + +template <size_t N, bool inverse = false, typename T> +KFR_INTRIN void butterfly10(const cvec<T, N>& a0, const cvec<T, N>& a1, const cvec<T, N>& a2, + const cvec<T, N>& a3, const cvec<T, N>& a4, const cvec<T, N>& a5, + const cvec<T, N>& a6, const cvec<T, N>& a7, const cvec<T, N>& a8, + const cvec<T, N>& a9, cvec<T, N>& w0, cvec<T, N>& w1, cvec<T, N>& w2, + cvec<T, N>& w3, cvec<T, N>& w4, cvec<T, N>& w5, cvec<T, N>& w6, cvec<T, N>& w7, + cvec<T, N>& w8, cvec<T, N>& w9) +{ + cvec<T, N* 2> a05 = concat(a0, a5); + cvec<T, N* 2> a27 = concat(a2, a7); + cvec<T, N* 2> a49 = concat(a4, a9); + cvec<T, N* 2> a61 = concat(a6, a1); + cvec<T, N* 2> a83 = concat(a8, a3); + butterfly5<N * 2, inverse>(a05, a27, a49, a61, a83, a05, a27, a49, a61, a83); + cvec<T, N> t0, t1, t2, t3, t4, t5, t6, t7, t8, t9; + split(a05, t0, t1); + split(a27, t2, t3); + split(a49, t4, t5); + split(a61, t6, t7); + split(a83, t8, t9); + t5 = -t5; + + cvec<T, N * 2> t02, t13; + cvec<T, N * 2> w06, w51; + t02 = concat(t0, t2); + t13 = concat(t1, t3); + butterfly2<N * 2>(t02, t13, w06, w51); + split(w06, w0, w6); + split(w51, w5, w1); + + cvec<T, N * 2> t68, t79; + cvec<T, N * 2> w84, w39; + t68 = concat(t6, t8); + t79 = concat(t7, t9); + butterfly2<N * 2>(t68, t79, w84, w39); + split(w84, w8, w4); + split(w39, w3, w9); + butterfly2<N>(t4, t5, w7, w2); +} + +template <bool inverse, typename T, size_t N> +KFR_INTRIN void butterfly(cbool_t<inverse>, const vec<T, N>& in0, const vec<T, N>& in1, vec<T, N>& out0, + vec<T, N>& out1) +{ + butterfly2<N / 2>(in0, in1, out0, out1); +} +template <bool inverse, typename T, size_t N> +KFR_INTRIN void butterfly(cbool_t<inverse>, const vec<T, N>& in0, const vec<T, N>& in1, const vec<T, N>& in2, + vec<T, N>& out0, vec<T, N>& out1, vec<T, N>& out2) +{ + butterfly3<N / 2, inverse>(in0, in1, in2, out0, out1, out2); +} + +template <bool inverse, typename T, size_t N> +KFR_INTRIN void butterfly(cbool_t<inverse>, const vec<T, N>& in0, const vec<T, N>& in1, const vec<T, N>& in2, + const vec<T, N>& in3, vec<T, N>& out0, vec<T, N>& out1, vec<T, N>& out2, + vec<T, N>& out3) +{ + butterfly4<N / 2, inverse>(cfalse, in0, in1, in2, in3, out0, out1, out2, out3); +} +template <bool inverse, typename T, size_t N> +KFR_INTRIN void butterfly(cbool_t<inverse>, const vec<T, N>& in0, const vec<T, N>& in1, const vec<T, N>& in2, + const vec<T, N>& in3, const vec<T, N>& in4, vec<T, N>& out0, vec<T, N>& out1, + vec<T, N>& out2, vec<T, N>& out3, vec<T, N>& out4) +{ + butterfly5<N / 2, inverse>(in0, in1, in2, in3, in4, out0, out1, out2, out3, out4); +} +template <bool inverse, typename T, size_t N> +KFR_INTRIN void butterfly(cbool_t<inverse>, const vec<T, N>& in0, const vec<T, N>& in1, const vec<T, N>& in2, + const vec<T, N>& in3, const vec<T, N>& in4, const vec<T, N>& in5, vec<T, N>& out0, + vec<T, N>& out1, vec<T, N>& out2, vec<T, N>& out3, vec<T, N>& out4, vec<T, N>& out5) +{ + butterfly6<N / 2, inverse>(in0, in1, in2, in3, in4, in5, out0, out1, out2, out3, out4, out5); +} +template <bool inverse, typename T, size_t N> +KFR_INTRIN void butterfly(cbool_t<inverse>, const vec<T, N>& in0, const vec<T, N>& in1, const vec<T, N>& in2, + const vec<T, N>& in3, const vec<T, N>& in4, const vec<T, N>& in5, + const vec<T, N>& in6, vec<T, N>& out0, vec<T, N>& out1, vec<T, N>& out2, + vec<T, N>& out3, vec<T, N>& out4, vec<T, N>& out5, vec<T, N>& out6) +{ + butterfly7<N / 2, inverse>(in0, in1, in2, in3, in4, in5, in6, out0, out1, out2, out3, out4, out5, out6); +} +template <bool inverse, typename T, size_t N> +KFR_INTRIN void butterfly(cbool_t<inverse>, const vec<T, N>& in0, const vec<T, N>& in1, const vec<T, N>& in2, + const vec<T, N>& in3, const vec<T, N>& in4, const vec<T, N>& in5, + const vec<T, N>& in6, const vec<T, N>& in7, vec<T, N>& out0, vec<T, N>& out1, + vec<T, N>& out2, vec<T, N>& out3, vec<T, N>& out4, vec<T, N>& out5, vec<T, N>& out6, + vec<T, N>& out7) +{ + butterfly8<N / 2, inverse>(in0, in1, in2, in3, in4, in5, in6, in7, out0, out1, out2, out3, out4, out5, + out6, out7); +} +template <bool inverse, typename T, size_t N> +KFR_INTRIN void butterfly(cbool_t<inverse>, const vec<T, N>& in0, const vec<T, N>& in1, const vec<T, N>& in2, + const vec<T, N>& in3, const vec<T, N>& in4, const vec<T, N>& in5, + const vec<T, N>& in6, const vec<T, N>& in7, const vec<T, N>& in8, vec<T, N>& out0, + vec<T, N>& out1, vec<T, N>& out2, vec<T, N>& out3, vec<T, N>& out4, vec<T, N>& out5, + vec<T, N>& out6, vec<T, N>& out7, vec<T, N>& out8) +{ + butterfly9<N / 2, inverse>(in0, in1, in2, in3, in4, in5, in6, in7, in8, out0, out1, out2, out3, out4, + out5, out6, out7, out8); +} +template <bool inverse, typename T, size_t N> +KFR_INTRIN void butterfly(cbool_t<inverse>, const vec<T, N>& in0, const vec<T, N>& in1, const vec<T, N>& in2, + const vec<T, N>& in3, const vec<T, N>& in4, const vec<T, N>& in5, + const vec<T, N>& in6, const vec<T, N>& in7, const vec<T, N>& in8, + const vec<T, N>& in9, vec<T, N>& out0, vec<T, N>& out1, vec<T, N>& out2, + vec<T, N>& out3, vec<T, N>& out4, vec<T, N>& out5, vec<T, N>& out6, vec<T, N>& out7, + vec<T, N>& out8, vec<T, N>& out9) +{ + butterfly10<N / 2, inverse>(in0, in1, in2, in3, in4, in5, in6, in7, in8, in9, out0, out1, out2, out3, + out4, out5, out6, out7, out8, out9); +} +template <bool inverse, typename T, size_t N> +KFR_INTRIN void butterfly(cbool_t<inverse>, const vec<T, N>& in0, const vec<T, N>& in1, const vec<T, N>& in2, + const vec<T, N>& in3, const vec<T, N>& in4, const vec<T, N>& in5, + const vec<T, N>& in6, const vec<T, N>& in7, const vec<T, N>& in8, + const vec<T, N>& in9, const vec<T, N>& in10, vec<T, N>& out0, vec<T, N>& out1, + vec<T, N>& out2, vec<T, N>& out3, vec<T, N>& out4, vec<T, N>& out5, vec<T, N>& out6, + vec<T, N>& out7, vec<T, N>& out8, vec<T, N>& out9, vec<T, N>& out10) +{ + butterfly11<N / 2, inverse>(in0, in1, in2, in3, in4, in5, in6, in7, in8, in9, in10, out0, out1, out2, + out3, out4, out5, out6, out7, out8, out9, out10); +} +template <bool transposed, typename T, size_t... N, size_t Nout = csum<size_t, N...>()> +KFR_INTRIN void cread_transposed(cbool_t<transposed>, const complex<T>* ptr, vec<T, N>&... w) +{ + vec<T, Nout> temp = read<Nout>(ptr_cast<T>(ptr)); + if (transposed) + temp = ctranspose<sizeof...(N)>(temp); + split(temp, w...); +} + +// Warning: Reads past the end. Use with care +KFR_INTRIN void cread_transposed(cbool_t<true>, const complex<f32>* ptr, cvec<f32, 4>& w0, cvec<f32, 4>& w1, + cvec<f32, 4>& w2) +{ + cvec<f32, 4> w3; + cvec<f32, 16> v16 = concat(cread<4>(ptr), cread<4>(ptr + 3), cread<4>(ptr + 6), cread<4>(ptr + 9)); + v16 = digitreverse4<2>(v16); + split(v16, w0, w1, w2, w3); +} + +KFR_INTRIN void cread_transposed(cbool_t<true>, const complex<f32>* ptr, cvec<f32, 4>& w0, cvec<f32, 4>& w1, + cvec<f32, 4>& w2, cvec<f32, 4>& w3, cvec<f32, 4>& w4) +{ + cvec<f32, 16> v16 = concat(cread<4>(ptr), cread<4>(ptr + 5), cread<4>(ptr + 10), cread<4>(ptr + 15)); + v16 = digitreverse4<2>(v16); + split(v16, w0, w1, w2, w3); + w4 = cgather<4, 5>(ptr + 4); +} + +template <bool transposed, typename T, size_t... N, size_t Nout = csum<size_t, N...>()> +KFR_INTRIN void cwrite_transposed(cbool_t<transposed>, complex<T>* ptr, vec<T, N>... args) +{ + auto temp = concat(args...); + if (transposed) + temp = ctransposeinverse<sizeof...(N)>(temp); + write(ptr_cast<T>(ptr), temp); +} + +template <size_t I, size_t radix, typename T, size_t N, size_t width = N / 2> +KFR_INTRIN vec<T, N> mul_tw(cbool_t<false>, const vec<T, N>& x, const complex<T>* twiddle) +{ + return I == 0 ? x : cmul(x, cread<width>(twiddle + width * (I - 1))); +} +template <size_t I, size_t radix, typename T, size_t N, size_t width = N / 2> +KFR_INTRIN vec<T, N> mul_tw(cbool_t<true>, const vec<T, N>& x, const complex<T>* twiddle) +{ + return I == 0 ? x : cmul_conj(x, cread<width>(twiddle + width * (I - 1))); +} + +// Non-final +template <typename T, size_t width, size_t radix, bool inverse, size_t... I> +KFR_INTRIN void butterfly_helper(csizes_t<I...>, size_t i, csize_t<width>, csize_t<radix>, cbool_t<inverse>, + complex<T>* out, const complex<T>* in, const complex<T>* tw, size_t stride) +{ + carray<cvec<T, width>, radix> inout; + + swallow{ (inout.get(csize_t<I>()) = cread<width>(in + i + stride * I))... }; + + butterfly(cbool_t<inverse>(), inout.template get<I>()..., inout.template get<I>()...); + + swallow{ ( + cwrite<width>(out + i + stride * I, + mul_tw<I, radix>(cbool_t<inverse>(), inout.template get<I>(), tw + i * (radix - 1))), + 0)... }; +} + +// Final +template <typename T, size_t width, size_t radix, bool inverse, size_t... I> +KFR_INTRIN void butterfly_helper(csizes_t<I...>, size_t i, csize_t<width>, csize_t<radix>, cbool_t<inverse>, + complex<T>* out, const complex<T>* in, size_t stride) +{ + carray<cvec<T, width>, radix> inout; + + // swallow{ ( inout.get( csize<I> ) = infn( i, I, cvec<T, width>( ) ) )... }; + cread_transposed(ctrue, in + i * radix, inout.template get<I>()...); + + butterfly(cbool_t<inverse>(), inout.template get<I>()..., inout.template get<I>()...); + + swallow{ (cwrite<width>(out + i + stride * I, inout.get(csize_t<I>())), 0)... }; +} + +template <size_t width, size_t radix, typename... Args> +KFR_INTRIN void butterfly(size_t i, csize_t<width>, csize_t<radix>, Args&&... args) +{ + butterfly_helper(csizeseq_t<radix>(), i, csize_t<width>(), csize_t<radix>(), std::forward<Args>(args)...); +} + +template <typename... Args> +KFR_INTRIN void butterfly_cycle(size_t&, size_t, csize_t<0>, Args&&...) +{ +} +template <size_t width, typename... Args> +KFR_INTRIN void butterfly_cycle(size_t& i, size_t count, csize_t<width>, Args&&... args) +{ + CMT_LOOP_NOUNROLL + for (; i < count / width * width; i += width) + butterfly(i, csize_t<width>(), std::forward<Args>(args)...); + butterfly_cycle(i, count, csize_t<width / 2>(), std::forward<Args>(args)...); +} + +template <size_t width, typename... Args> +KFR_INTRIN void butterflies(size_t count, csize_t<width>, Args&&... args) +{ + CMT_ASSUME(count > 0); + size_t i = 0; + butterfly_cycle(i, count, csize_t<width>(), std::forward<Args>(args)...); +} + +template <typename T, bool inverse, typename Tradix, typename Tstride> +KFR_INTRIN void generic_butterfly_cycle(csize_t<0>, Tradix radix, cbool_t<inverse>, complex<T>*, + const complex<T>*, Tstride, size_t, size_t, const complex<T>*, size_t) +{ +} + +template <size_t width, bool inverse, typename T, typename Tradix, typename Thalfradix, + typename Thalfradixsqr, typename Tstride> +KFR_INTRIN void generic_butterfly_cycle(csize_t<width>, Tradix radix, cbool_t<inverse>, complex<T>* out, + const complex<T>* in, Tstride ostride, Thalfradix halfradix, + Thalfradixsqr halfradix_sqr, const complex<T>* twiddle, size_t i) +{ + CMT_LOOP_NOUNROLL + for (; i < halfradix / width * width; i += width) + { + const cvec<T, 1> in0 = cread<1>(in); + cvec<T, width> sum0 = resize<2 * width>(in0); + cvec<T, width> sum1 = sum0; + + for (size_t j = 0; j < halfradix; j++) + { + const cvec<T, 1> ina = cread<1>(in + (1 + j)); + const cvec<T, 1> inb = cread<1>(in + radix - (j + 1)); + cvec<T, width> tw = cread<width>(twiddle); + if (inverse) + tw = negodd /*cconj*/ (tw); + + cmul_2conj(sum0, sum1, ina, inb, tw); + twiddle += halfradix; + } + twiddle = twiddle - halfradix_sqr + width; + + // if (inverse) + // std::swap(sum0, sum1); + + if (is_constant_val(ostride)) + { + cwrite<width>(out + (1 + i), sum0); + cwrite<width>(out + (radix - (i + 1)) - (width - 1), reverse<2>(sum1)); + } + else + { + cscatter<width>(out + (i + 1) * ostride, ostride, sum0); + cscatter<width>(out + (radix - (i + 1)) * ostride - (width - 1) * ostride, ostride, + reverse<2>(sum1)); + } + } + generic_butterfly_cycle(csize_t<width / 2>(), radix, cbool_t<inverse>(), out, in, ostride, halfradix, + halfradix_sqr, twiddle, i); +} + +template <typename T> +KFR_SINTRIN vec<T, 2> hcadd(vec<T, 2> value) +{ + return value; +} +template <typename T, size_t N, KFR_ENABLE_IF(N >= 4)> +KFR_SINTRIN vec<T, 2> hcadd(vec<T, N> value) +{ + return hcadd(low(value) + high(value)); +} + +template <size_t width, typename T, bool inverse, typename Tstride = csize_t<1>> +KFR_INTRIN void generic_butterfly_w(size_t radix, cbool_t<inverse>, complex<T>* out, const complex<T>* in, + const complex<T>* twiddle, Tstride ostride = Tstride{}) +{ + CMT_ASSUME(radix > 0); + { + cvec<T, width> sum = T(); + size_t j = 0; + CMT_LOOP_NOUNROLL + for (; j < radix / width * width; j += width) + { + sum += cread<width>(in + j); + } + cvec<T, 1> sums = T(); + CMT_LOOP_NOUNROLL + for (; j < radix; j++) + { + sums += cread<1>(in + j); + } + cwrite<1>(out, hcadd(sum) + sums); + } + const auto halfradix = radix / 2; + const auto halfradix_sqr = halfradix * halfradix; + CMT_ASSUME(halfradix > 0); + size_t i = 0; + + generic_butterfly_cycle(csize_t<width>(), radix, cbool_t<inverse>(), out, in, ostride, halfradix, + halfradix * halfradix, twiddle, i); +} + +template <size_t width, size_t radix, typename T, bool inverse, typename Tstride = csize_t<1>> +KFR_INTRIN void spec_generic_butterfly_w(csize_t<radix>, cbool_t<inverse>, complex<T>* out, + const complex<T>* in, const complex<T>* twiddle, + Tstride ostride = Tstride{}) +{ + { + cvec<T, width> sum = T(); + size_t j = 0; + CMT_LOOP_UNROLL + for (; j < radix / width * width; j += width) + { + sum += cread<width>(in + j); + } + cvec<T, 1> sums = T(); + CMT_LOOP_UNROLL + for (; j < radix; j++) + { + sums += cread<1>(in + j); + } + cwrite<1>(out, hcadd(sum) + sums); + } + const size_t halfradix = radix / 2; + const size_t halfradix_sqr = halfradix * halfradix; + CMT_ASSUME(halfradix > 0); + size_t i = 0; + + generic_butterfly_cycle(csize_t<width>(), radix, cbool_t<inverse>(), out, in, ostride, halfradix, + halfradix_sqr, twiddle, i); +} + +template <typename T, bool inverse, typename Tstride = csize_t<1>> +KFR_INTRIN void generic_butterfly(size_t radix, cbool_t<inverse>, complex<T>* out, const complex<T>* in, + complex<T>* temp, const complex<T>* twiddle, Tstride ostride = Tstride{}) +{ + constexpr size_t width = platform<T>::vector_width; + + cswitch(csizes_t<11, 13>(), radix, + [&](auto radix_) CMT_INLINE_LAMBDA { + spec_generic_butterfly_w<width>(radix_, cbool_t<inverse>(), out, in, twiddle, ostride); + }, + [&]() CMT_INLINE_LAMBDA { + generic_butterfly_w<width>(radix, cbool_t<inverse>(), out, in, twiddle, ostride); + }); +} + +template <typename T, size_t N> +constexpr cvec<T, N> cmask08 = broadcast<N * 2, T>(T(), -T()); + +template <typename T, size_t N> +constexpr cvec<T, N> cmask0088 = broadcast<N * 4, T>(T(), T(), -T(), -T()); + +template <bool A = false, typename T, size_t N> +KFR_INTRIN void cbitreverse_write(complex<T>* dest, const vec<T, N>& x) +{ + cwrite<N / 2, A>(dest, bitreverse<2>(x)); +} + +template <bool A = false, typename T, size_t N> +KFR_INTRIN void cdigitreverse4_write(complex<T>* dest, const vec<T, N>& x) +{ + cwrite<N / 2, A>(dest, digitreverse4<2>(x)); +} + +template <size_t N, bool A = false, typename T> +KFR_INTRIN cvec<T, N> cbitreverse_read(const complex<T>* src) +{ + return bitreverse<2>(cread<N, A>(src)); +} + +template <size_t N, bool A = false, typename T> +KFR_INTRIN cvec<T, N> cdigitreverse4_read(const complex<T>* src) +{ + return digitreverse4<2>(cread<N, A>(src)); +} + +#if 1 + +template <> +KFR_INTRIN cvec<f64, 16> cdigitreverse4_read<16, false, f64>(const complex<f64>* src) +{ + return concat(cread<1>(src + 0), cread<1>(src + 4), cread<1>(src + 8), cread<1>(src + 12), + cread<1>(src + 1), cread<1>(src + 5), cread<1>(src + 9), cread<1>(src + 13), + cread<1>(src + 2), cread<1>(src + 6), cread<1>(src + 10), cread<1>(src + 14), + cread<1>(src + 3), cread<1>(src + 7), cread<1>(src + 11), cread<1>(src + 15)); +} +template <> +KFR_INTRIN void cdigitreverse4_write<false, f64, 32>(complex<f64>* dest, const vec<f64, 32>& x) +{ + cwrite<1>(dest, part<16, 0>(x)); + cwrite<1>(dest + 4, part<16, 1>(x)); + cwrite<1>(dest + 8, part<16, 2>(x)); + cwrite<1>(dest + 12, part<16, 3>(x)); + + cwrite<1>(dest + 1, part<16, 4>(x)); + cwrite<1>(dest + 5, part<16, 5>(x)); + cwrite<1>(dest + 9, part<16, 6>(x)); + cwrite<1>(dest + 13, part<16, 7>(x)); + + cwrite<1>(dest + 2, part<16, 8>(x)); + cwrite<1>(dest + 6, part<16, 9>(x)); + cwrite<1>(dest + 10, part<16, 10>(x)); + cwrite<1>(dest + 14, part<16, 11>(x)); + + cwrite<1>(dest + 3, part<16, 12>(x)); + cwrite<1>(dest + 7, part<16, 13>(x)); + cwrite<1>(dest + 11, part<16, 14>(x)); + cwrite<1>(dest + 15, part<16, 15>(x)); +} +#endif +} // namespace internal +} // namespace kfr + +CMT_PRAGMA_MSVC(warning(pop)) diff --git a/sources.cmake b/sources.cmake @@ -91,13 +91,15 @@ set( ${PROJECT_SOURCE_DIR}/include/kfr/cpuid/cpuid_auto.hpp ${PROJECT_SOURCE_DIR}/include/kfr/data/bitrev.hpp ${PROJECT_SOURCE_DIR}/include/kfr/data/sincos.hpp - ${PROJECT_SOURCE_DIR}/include/kfr/dft/bitrev.hpp ${PROJECT_SOURCE_DIR}/include/kfr/dft/cache.hpp ${PROJECT_SOURCE_DIR}/include/kfr/dft/convolution.hpp ${PROJECT_SOURCE_DIR}/include/kfr/dft/fft.hpp - ${PROJECT_SOURCE_DIR}/include/kfr/dft/ft.hpp ${PROJECT_SOURCE_DIR}/include/kfr/dft/reference_dft.hpp ${PROJECT_SOURCE_DIR}/include/kfr/dft/dft_c.h + ${PROJECT_SOURCE_DIR}/include/kfr/dft/impl/bitrev.hpp + ${PROJECT_SOURCE_DIR}/include/kfr/dft/impl/dft-impl.hpp + ${PROJECT_SOURCE_DIR}/include/kfr/dft/impl/dft-templates.hpp + ${PROJECT_SOURCE_DIR}/include/kfr/dft/impl/ft.hpp ${PROJECT_SOURCE_DIR}/include/kfr/dsp/biquad.hpp ${PROJECT_SOURCE_DIR}/include/kfr/dsp/biquad_design.hpp ${PROJECT_SOURCE_DIR}/include/kfr/dsp/dcremove.hpp diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt @@ -52,14 +52,14 @@ else () message(STATUS "MPFR is not found. Skipping transcendental_test") endif () -add_executable(all_tests ${ALL_TESTS_CPP} ../include/kfr/dft/dft-src.cpp) +add_executable(all_tests ${ALL_TESTS_CPP}) target_compile_definitions(all_tests PRIVATE KFR_NO_MAIN) -target_link_libraries(all_tests kfr) +target_link_libraries(all_tests kfr kfr_dft) add_executable(intrinsic_test intrinsic_test.cpp) target_link_libraries(intrinsic_test kfr) -add_executable(dft_test dft_test.cpp ../include/kfr/dft/dft-src.cpp) -target_link_libraries(dft_test kfr) +add_executable(dft_test dft_test.cpp) +target_link_libraries(dft_test kfr kfr_dft) if (MPFR_FOUND AND GMP_FOUND) add_definitions(-DHAVE_MPFR) @@ -72,7 +72,7 @@ endif () function(add_x86_test NAME FLAGS) separate_arguments(FLAGS) - add_executable(all_tests_${NAME} ${ALL_TESTS_CPP} ../include/kfr/dft/dft-src.cpp) + add_executable(all_tests_${NAME} ${ALL_TESTS_CPP} ${KFR_DFT_SRC}) target_compile_options(all_tests_${NAME} PRIVATE ${FLAGS}) target_compile_definitions(all_tests_${NAME} PRIVATE KFR_NO_MAIN) target_link_libraries(all_tests_${NAME} kfr)