kfr

Fast, modern C++ DSP framework, FFT, Sample Rate Conversion, FIR/IIR/Biquad Filters (SSE, AVX, AVX-512, ARM NEON)
Log | Files | Refs | README

commit 00dbe97ada07f41e12c445c2a2e5630c23f91bb5
parent a99d76846dc5809042536f3981e046e3b22db441
Author: d.levin256@gmail.com <d.levin256@gmail.com>
Date:   Mon, 22 Jan 2024 07:30:15 +0000

Multidimensional DFT

Diffstat:
Minclude/kfr/base/shape.hpp | 187++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-------------------
Minclude/kfr/base/tensor.hpp | 6++++++
Minclude/kfr/dft/fft.hpp | 418++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++---
Minclude/kfr/dft/reference_dft.hpp | 1-
Minclude/kfr/simd/impl/intrinsics.h | 5+++++
Msrc/dft/fft-impl.hpp | 19++++++++++---------
Mtests/dft_test.cpp | 341++++++++++++++++++++++++++++++++++++++++++++++++++++++++++---------------------
7 files changed, 821 insertions(+), 156 deletions(-)

diff --git a/include/kfr/base/shape.hpp b/include/kfr/base/shape.hpp @@ -97,26 +97,28 @@ KFR_INTRINSIC bool increment_indices(shape<dims>& indices, const shape<dims>& st index_t dim = dims - 1); } // namespace internal_generic -template <index_t dims> -struct shape : static_array_base<index_t, csizeseq_t<dims>> +template <index_t Dims> +struct shape : static_array_base<index_t, csizeseq_t<Dims>> { - static_assert(dims <= 256, "Too many dimensions"); - using base = static_array_base<index_t, csizeseq_t<dims>>; + static_assert(Dims <= 256, "Too many dimensions"); + using base = static_array_base<index_t, csizeseq_t<Dims>>; using base::base; constexpr shape(const base& a) : base(a) {} - static_assert(dims < maximum_dims); + static_assert(Dims <= maximum_dims); + + static constexpr size_t dims() { return base::static_size; } - template <int dummy = 0, KFR_ENABLE_IF(dummy == 0 && dims == 1)> + template <int dummy = 0, KFR_ENABLE_IF(dummy == 0 && Dims == 1)> operator index_t() const { return this->front(); } template <typename TI> - static constexpr shape from_std_array(const std::array<TI, dims>& a) + static constexpr shape from_std_array(const std::array<TI, Dims>& a) { shape result; std::copy(a.begin(), a.end(), result.begin()); @@ -124,16 +126,16 @@ struct shape : static_array_base<index_t, csizeseq_t<dims>> } template <typename TI = index_t> - constexpr std::array<TI, dims> to_std_array() const + constexpr std::array<TI, Dims> to_std_array() const { - std::array<TI, dims> result{}; + std::array<TI, Dims> result{}; std::copy(this->begin(), this->end(), result.begin()); return result; } bool ge(const shape& other) const { - if constexpr (dims == 1) + if constexpr (Dims == 1) { return this->front() >= other.front(); } @@ -145,17 +147,17 @@ struct shape : static_array_base<index_t, csizeseq_t<dims>> index_t trailing_zeros() const { - for (index_t i = 0; i < dims; ++i) + for (index_t i = 0; i < Dims; ++i) { if (revindex(i) != 0) return i; } - return dims; + return Dims; } bool le(const shape& other) const { - if constexpr (dims == 1) + if constexpr (Dims == 1) { return this->front() <= other.front(); } @@ -184,7 +186,7 @@ struct shape : static_array_base<index_t, csizeseq_t<dims>> constexpr bool has_infinity() const { - for (index_t i = 0; i < dims; ++i) + for (index_t i = 0; i < Dims; ++i) { if (CMT_UNLIKELY(this->operator[](i) == infinite_size)) return true; @@ -228,13 +230,13 @@ struct shape : static_array_base<index_t, csizeseq_t<dims>> constexpr const base* operator->() const { return static_cast<const base*>(this); } - KFR_MEM_INTRINSIC constexpr size_t to_flat(const shape<dims>& indices) const + KFR_MEM_INTRINSIC constexpr size_t to_flat(const shape<Dims>& indices) const { - if constexpr (dims == 1) + if constexpr (Dims == 1) { return indices[0]; } - else if constexpr (dims == 2) + else if constexpr (Dims == 2) { return (*this)[1] * indices[0] + indices[1]; } @@ -243,33 +245,33 @@ struct shape : static_array_base<index_t, csizeseq_t<dims>> size_t result = 0; size_t scale = 1; CMT_LOOP_UNROLL - for (size_t i = 0; i < dims; ++i) + for (size_t i = 0; i < Dims; ++i) { - result += scale * indices[dims - 1 - i]; - scale *= (*this)[dims - 1 - i]; + result += scale * indices[Dims - 1 - i]; + scale *= (*this)[Dims - 1 - i]; } return result; } } - KFR_MEM_INTRINSIC constexpr shape<dims> from_flat(size_t index) const + KFR_MEM_INTRINSIC constexpr shape<Dims> from_flat(size_t index) const { - if constexpr (dims == 1) + if constexpr (Dims == 1) { return { static_cast<index_t>(index) }; } - else if constexpr (dims == 2) + else if constexpr (Dims == 2) { index_t sz = (*this)[1]; return { static_cast<index_t>(index / sz), static_cast<index_t>(index % sz) }; } else { - shape<dims> indices; + shape<Dims> indices; CMT_LOOP_UNROLL - for (size_t i = 0; i < dims; ++i) + for (size_t i = 0; i < Dims; ++i) { - size_t sz = (*this)[dims - 1 - i]; - indices[dims - 1 - i] = index % sz; + size_t sz = (*this)[Dims - 1 - i]; + indices[Dims - 1 - i] = index % sz; index /= sz; } return indices; @@ -281,11 +283,11 @@ struct shape : static_array_base<index_t, csizeseq_t<dims>> template <index_t indims, bool stop = false> KFR_MEM_INTRINSIC constexpr shape adapt(const shape<indims>& other, cbool_t<stop> = {}) const { - static_assert(indims >= dims); + static_assert(indims >= Dims); if constexpr (stop) - return other.template trim<dims>()->min(**this); + return other.template trim<Dims>()->min(**this); else - return other.template trim<dims>()->min(**this - 1); + return other.template trim<Dims>()->min(**this - 1); } KFR_MEM_INTRINSIC constexpr index_t product() const { return (*this)->product(); } @@ -293,9 +295,9 @@ struct shape : static_array_base<index_t, csizeseq_t<dims>> KFR_MEM_INTRINSIC constexpr dimset tomask() const { dimset result(0); - for (index_t i = 0; i < dims; ++i) + for (index_t i = 0; i < Dims; ++i) { - result[i + maximum_dims - dims] = this->operator[](i) == 1 ? 0 : -1; + result[i + maximum_dims - Dims] = this->operator[](i) == 1 ? 0 : -1; } return result; } @@ -303,20 +305,20 @@ struct shape : static_array_base<index_t, csizeseq_t<dims>> template <index_t new_dims> constexpr KFR_MEM_INTRINSIC shape<new_dims> extend(index_t value = infinite_size) const { - static_assert(new_dims >= dims); - if constexpr (new_dims == dims) + static_assert(new_dims >= Dims); + if constexpr (new_dims == Dims) return *this; else - return shape<new_dims>{ shape<new_dims - dims>(value), *this }; + return shape<new_dims>{ shape<new_dims - Dims>(value), *this }; } template <index_t odims> constexpr shape<odims> trim() const { - static_assert(odims <= dims); + static_assert(odims <= Dims); if constexpr (odims > 0) { - return this->template slice<dims - odims, odims>(); + return this->template slice<Dims - odims, odims>(); } else { @@ -324,11 +326,34 @@ struct shape : static_array_base<index_t, csizeseq_t<dims>> } } - constexpr KFR_MEM_INTRINSIC shape<dims - 1> trunc() const + // 0,1,2,3 -> 1,2,3,0 + constexpr KFR_MEM_INTRINSIC shape rotate_left() const + { + return this->shuffle(csizeseq<Dims, 1> % csize<Dims>); + } + + // 0,1,2,3 -> 3,0,1,2 + constexpr KFR_MEM_INTRINSIC shape rotate_right() const { - if constexpr (dims > 1) + return this->shuffle(csizeseq<Dims, Dims - 1> % csize<Dims>); + } + + constexpr KFR_MEM_INTRINSIC shape<Dims - 1> remove_back() const + { + if constexpr (Dims > 1) { - return this->template slice<0, dims - 1>(); + return this->template slice<0, Dims - 1>(); + } + else + { + return {}; + } + } + constexpr KFR_MEM_INTRINSIC shape<Dims - 1> remove_front() const + { + if constexpr (Dims > 1) + { + return this->template slice<1, Dims - 1>(); } else { @@ -336,19 +361,21 @@ struct shape : static_array_base<index_t, csizeseq_t<dims>> } } + constexpr KFR_MEM_INTRINSIC shape<Dims - 1> trunc() const { return remove_back(); } + KFR_MEM_INTRINSIC constexpr index_t revindex(size_t index) const { - return index < dims ? this->operator[](dims - 1 - index) : 1; + return index < Dims ? this->operator[](Dims - 1 - index) : 1; } KFR_MEM_INTRINSIC constexpr void set_revindex(size_t index, index_t val) { - if (CMT_LIKELY(index < dims)) - this->operator[](dims - 1 - index) = val; + if (CMT_LIKELY(index < Dims)) + this->operator[](Dims - 1 - index) = val; } KFR_MEM_INTRINSIC constexpr shape transpose() const { - return this->shuffle(csizeseq<dims, dims - 1, -1>); + return this->shuffle(csizeseq<Dims, Dims - 1, -1>); } }; @@ -359,6 +386,8 @@ struct shape<0> static constexpr size_t size() { return static_size; } + static constexpr size_t dims() { return static_size; } + constexpr shape() = default; constexpr shape(index_t value) {} @@ -404,6 +433,76 @@ struct shape<0> KFR_MEM_INTRINSIC void set_revindex(size_t index, index_t val) {} }; +constexpr inline size_t dynamic_shape = std::numeric_limits<size_t>::max(); + +template <> +struct shape<dynamic_shape> : protected std::vector<index_t> +{ + using std::vector<index_t>::vector; + + using std::vector<index_t>::begin; + using std::vector<index_t>::end; + using std::vector<index_t>::data; + using std::vector<index_t>::size; + using std::vector<index_t>::front; + using std::vector<index_t>::back; + using std::vector<index_t>::operator[]; + + template <index_t Dims, CMT_ENABLE_IF(Dims != dynamic_shape)> + shape(shape<Dims> sh) : shape(sh.begin(), sh.end()) + { + } + + size_t dims() const { return size(); } + + KFR_MEM_INTRINSIC index_t product() const + { + if (std::vector<index_t>::empty()) + return 0; + index_t p = this->front(); + for (size_t i = 1; i < size(); ++i) + { + p *= this->operator[](i); + } + return p; + } + + // 0,1,2,3 -> 1,2,3,0 + KFR_MEM_INTRINSIC shape rotate_left() const + { + shape result = *this; + if (result.size() > 1) + std::rotate(result.begin(), result.begin() + 1, result.end()); + return result; + } + + // 0,1,2,3 -> 3,0,1,2 + KFR_MEM_INTRINSIC shape rotate_right() const + { + shape result = *this; + if (result.size() > 1) + std::rotate(result.begin(), result.end() - 1, result.end()); + return result; + } + + KFR_MEM_INTRINSIC shape remove_back() const + { + shape result = *this; + if (!result.empty()) + result.erase(result.end() - 1); + return result; + } + KFR_MEM_INTRINSIC shape remove_front() const + { + shape result = *this; + if (!result.empty()) + { + result.erase(result.begin()); + } + return result; + } +}; + template <typename... Args> shape(Args&&... args) -> shape<sizeof...(Args)>; diff --git a/include/kfr/base/tensor.hpp b/include/kfr/base/tensor.hpp @@ -880,6 +880,12 @@ private: memory_finalizer m_finalizer; }; +template <typename T> +struct tensor<T, dynamic_shape> +{ + // Not implemented yet +}; + // template <typename T> // struct tensor<T, 0> // { diff --git a/include/kfr/dft/fft.hpp b/include/kfr/dft/fft.hpp @@ -27,11 +27,13 @@ #include "../base/basic_expressions.hpp" #include "../base/memory.hpp" +#include "../base/tensor.hpp" #include "../base/univector.hpp" #include "../math/sin_cos.hpp" #include "../simd/complex.hpp" #include "../simd/constants.hpp" #include <bitset> +#include <functional> CMT_PRAGMA_GNU(GCC diagnostic push) #if CMT_HAS_WARNING("-Wshadow") @@ -76,6 +78,10 @@ struct dft_stage printf("%s: %zu, %zu, %zu, %zu, %zu, %zu, %zu, %d, %d\n", name ? name : "unnamed", radix, stage_size, data_size, temp_size, repeats, out_offset, blocks, recursion, can_inplace); } + virtual void copy_input(bool invert, complex<T>* out, const complex<T>* in, size_t size) + { + builtin_memcpy(out, in, sizeof(complex<T>) * size); + } KFR_MEM_INTRINSIC void execute(cdirect_t, complex<T>* out, const complex<T>* in, u8* temp) { @@ -201,6 +207,22 @@ struct dft_plan execute_dft(inv, out.data(), in.data(), temp.data()); } + template <univector_tag Tag1, univector_tag Tag2> + KFR_MEM_INTRINSIC void execute(univector<complex<T>, Tag1>& out, const univector<complex<T>, Tag2>& in, + u8* temp, bool inverse = false) const + { + if (inverse) + execute_dft(ctrue, out.data(), in.data(), temp); + else + execute_dft(cfalse, out.data(), in.data(), temp); + } + template <bool inverse, univector_tag Tag1, univector_tag Tag2> + KFR_MEM_INTRINSIC void execute(univector<complex<T>, Tag1>& out, const univector<complex<T>, Tag2>& in, + u8* temp, cbool_t<inverse> inv) const + { + execute_dft(inv, out.data(), in.data(), temp); + } + autofree<u8> data; size_t data_size; @@ -309,6 +331,11 @@ protected: template <bool inverse> void execute_dft(cbool_t<inverse>, complex<T>* out, const complex<T>* in, u8* temp) const { + if (temp == nullptr && temp_size > 0) + { + return call_with_temp(temp_size, std::bind(&dft_plan<T>::execute_dft<inverse>, this, + cbool_t<inverse>{}, out, in, std::placeholders::_1)); + } auto&& stages = this->stages[inverse]; if (stages.size() == 1 && (stages[0]->can_inplace || in != out)) { @@ -321,12 +348,12 @@ protected: complex<T>* scratch = ptr_cast<complex<T>>( temp + this->temp_size - - align_up(sizeof(complex<T>) * this->size, platform<>::native_cache_alignment)); + align_up(sizeof(complex<T>) * (this->size + 1), platform<>::native_cache_alignment)); bool in_scratch = disposition.test(0); if (in_scratch) { - builtin_memcpy(scratch, in, sizeof(complex<T>) * this->size); + stages[0]->copy_input(inverse, scratch, in, this->size); } const size_t count = stages.size(); @@ -392,6 +419,9 @@ struct dft_plan_real : dft_plan<T> bool is_initialized() const { return size != 0; } + size_t complex_size() const { return complex_size_for(size); } + constexpr static size_t complex_size_for(size_t size) { return size / 2 + 1; } + [[deprecated("cpu parameter is deprecated. Runtime dispatch is used if built with " "KFR_ENABLE_MULTIARCH")]] explicit dft_plan_real(cpu_t cpu, size_t size, dft_pack_format fmt = dft_pack_format::CCs) @@ -420,6 +450,14 @@ struct dft_plan_real : dft_plan<T> void execute(univector<complex<T>, Tag1>&, const univector<complex<T>, Tag2>&, univector<u8, Tag3>&, cbool_t<inverse>) const = delete; + template <univector_tag Tag1, univector_tag Tag2> + void execute(univector<complex<T>, Tag1>& out, const univector<complex<T>, Tag2>& in, u8* temp, + bool inverse = false) const = delete; + + template <bool inverse, univector_tag Tag1, univector_tag Tag2> + void execute(univector<complex<T>, Tag1>& out, const univector<complex<T>, Tag2>& in, u8* temp, + cbool_t<inverse> inv) const = delete; + KFR_MEM_INTRINSIC void execute(complex<T>* out, const T* in, u8* temp, cdirect_t = {}) const { this->execute_dft(cfalse, out, ptr_cast<complex<T>>(in), temp); @@ -442,17 +480,373 @@ struct dft_plan_real : dft_plan<T> this->execute_dft(ctrue, ptr_cast<complex<T>>(out.data()), in.data(), temp.data()); } - // Deprecated. fmt must be passed to constructor instead - void execute(complex<T>*, const T*, u8*, dft_pack_format) const = delete; - void execute(T*, const complex<T>*, u8*, dft_pack_format) const = delete; + template <univector_tag Tag1, univector_tag Tag2> + KFR_MEM_INTRINSIC void execute(univector<complex<T>, Tag1>& out, const univector<T, Tag2>& in, u8* temp, + cdirect_t = {}) const + { + this->execute_dft(cfalse, out.data(), ptr_cast<complex<T>>(in.data()), temp); + } + template <univector_tag Tag1, univector_tag Tag2> + KFR_MEM_INTRINSIC void execute(univector<T, Tag1>& out, const univector<complex<T>, Tag2>& in, u8* temp, + cinvert_t = {}) const + { + this->execute_dft(ctrue, ptr_cast<complex<T>>(out.data()), in.data(), temp); + } +}; - // Deprecated. fmt must be passed to constructor instead - template <univector_tag Tag1, univector_tag Tag2, univector_tag Tag3> - void execute(univector<complex<T>, Tag1>&, const univector<T, Tag2>&, univector<u8, Tag3>&, - dft_pack_format) const = delete; - template <univector_tag Tag1, univector_tag Tag2, univector_tag Tag3> - void execute(univector<T, Tag1>&, const univector<complex<T>, Tag2>&, univector<u8, Tag3>&, - dft_pack_format) const = delete; +/// @brief Multidimensional DFT +template <typename T, index_t Dims> +struct dft_plan_md +{ + shape<Dims> size; + size_t temp_size; + + dft_plan_md(const dft_plan_md&) = delete; + dft_plan_md(dft_plan_md&&) = default; + dft_plan_md& operator=(const dft_plan_md&) = delete; + dft_plan_md& operator=(dft_plan_md&&) = default; + + bool is_initialized() const { return size.product() != 0; } + + void dump() const + { + for (const auto& d : dfts) + { + d.dump(); + } + } + + explicit dft_plan_md(shape<Dims> size) : size(std::move(size)), temp_size(0) + { + if constexpr (Dims == dynamic_shape) + { + dfts.resize(this->size.dims()); + } + for (index_t i = 0; i < this->size.dims(); ++i) + { + dfts[i] = dft_plan<T>(this->size[i]); + temp_size = std::max(temp_size, dfts[i].temp_size); + } + } + + void execute(complex<T>* out, const complex<T>* in, u8* temp, bool inverse = false) const + { + if (inverse) + execute_dft(ctrue, out, in, temp); + else + execute_dft(cfalse, out, in, temp); + } + + template <index_t UDims = Dims, CMT_ENABLE_IF(UDims != dynamic_shape)> + void execute(const tensor<complex<T>, Dims>& out, const tensor<complex<T>, Dims>& in, u8* temp, + bool inverse = false) const + { + KFR_LOGIC_CHECK(in.shape() == this->size && out.shape() == this->size, + "dft_plan_md: incorrect tensor shapes"); + KFR_LOGIC_CHECK(in.is_contiguous() && out.is_contiguous(), "dft_plan_md: tensors must be contiguous"); + if (inverse) + execute_dft(ctrue, out.data(), in.data(), temp); + else + execute_dft(cfalse, out.data(), in.data(), temp); + } + template <bool inverse = false> + void execute(complex<T>* out, const complex<T>* in, u8* temp, cbool_t<inverse> = {}) const + { + execute_dft(cbool<inverse>, out, in, temp); + } + +private: + template <bool inverse> + KFR_INTRINSIC void execute_dft(cbool_t<inverse>, complex<T>* out, const complex<T>* in, u8* temp) const + { + if (temp == nullptr && temp_size > 0) + { + return call_with_temp(temp_size, std::bind(&dft_plan_md<T, Dims>::execute_dft<inverse>, this, + cbool_t<inverse>{}, out, in, std::placeholders::_1)); + } + if (size.dims() == 1) + { + dfts[0].execute(out, in, temp, cbool<inverse>); + } + else + { + execute_dim(cbool<inverse>, out, in, temp); + } + } + KFR_INTRINSIC void execute_dim(cfalse_t, complex<T>* out, const complex<T>* in, u8* temp) const + { + shape<Dims> sh = size; + index_t total = size.product(); + index_t axis = size.dims() - 1; + for (;;) + { + if (size[axis] > 1) + { + for (index_t o = 0; o < total; o += sh.back()) + dfts[axis].execute(out + o, in + o, temp, cfalse); + } + else + { + builtin_memcpy(out, in, sizeof(complex<T>) * total); + } + + matrix_transpose(out, out, shape{ sh.remove_back().product(), sh.back() }); + + if (axis == 0) + break; + + sh = sh.rotate_right(); + in = out; + --axis; + } + } + KFR_INTRINSIC void execute_dim(ctrue_t, complex<T>* out, const complex<T>* in, u8* temp) const + { + shape<Dims> sh = size; + index_t total = size.product(); + index_t axis = 0; + for (;;) + { + matrix_transpose(out, in, shape{ sh.front(), sh.remove_front().product() }); + + if (size[axis] > 1) + { + for (index_t o = 0; o < total; o += sh.front()) + dfts[axis].execute(out + o, out + o, temp, ctrue); + } + + if (axis == size.dims() - 1) + break; + + sh = sh.rotate_left(); + in = out; + ++axis; + } + } + using dft_list = + std::conditional_t<Dims == dynamic_shape, std::vector<dft_plan<T>>, std::array<dft_plan<T>, Dims>>; + dft_list dfts; +}; + +/// @brief Multidimensional DFT +template <typename T, index_t Dims> +struct dft_plan_md_real +{ + shape<Dims> size; + size_t temp_size; + bool real_out_is_enough; + + dft_plan_md_real(const dft_plan_md_real&) = delete; + dft_plan_md_real(dft_plan_md_real&&) = default; + dft_plan_md_real& operator=(const dft_plan_md_real&) = delete; + dft_plan_md_real& operator=(dft_plan_md_real&&) = default; + + bool is_initialized() const { return size.product() != 0; } + + void dump() const + { + for (const auto& d : dfts) + { + d.dump(); + } + dft_real.dump(); + } + + shape<Dims> complex_size() const { return complex_size_for(size); } + constexpr static shape<Dims> complex_size_for(shape<Dims> size) + { + if (size.dims() > 0) + size.back() = dft_plan_real<T>::complex_size_for(size.back()); + return size; + } + + size_t real_out_size() const { return real_out_size_for(size); } + constexpr static size_t real_out_size_for(shape<Dims> size) + { + return complex_size_for(size).product() * 2; + } + + explicit dft_plan_md_real(shape<Dims> size, bool real_out_is_enough = false) + : size(std::move(size)), temp_size(0), real_out_is_enough(real_out_is_enough) + { + if (this->size.dims() > 0) + { + if constexpr (Dims == dynamic_shape) + { + dfts.resize(this->size.dims()); + } + for (index_t i = 0; i < this->size.dims() - 1; ++i) + { + dfts[i] = dft_plan<T>(this->size[i]); + temp_size = std::max(temp_size, dfts[i].temp_size); + } + dft_real = dft_plan_real<T>(this->size.back()); + temp_size = std::max(temp_size, dft_real.temp_size); + } + if (!this->real_out_is_enough) + { + temp_size += complex_size().product() * sizeof(complex<T>); + } + } + + void execute(complex<T>* out, const T* in, u8* temp, cdirect_t = {}) const + { + execute_dft(cfalse, out, in, temp); + } + void execute(T* out, const complex<T>* in, u8* temp, cinvert_t = {}) const + { + execute_dft(ctrue, out, in, temp); + } + + template <index_t UDims = Dims, CMT_ENABLE_IF(UDims != dynamic_shape)> + void execute(const tensor<complex<T>, Dims>& out, const tensor<T, Dims>& in, u8* temp, + cdirect_t = {}) const + { + KFR_LOGIC_CHECK(in.shape() == this->size && out.shape() == complex_size(), + "dft_plan_md_real: incorrect tensor shapes"); + KFR_LOGIC_CHECK(in.is_contiguous() && out.is_contiguous(), + "dft_plan_md_real: tensors must be contiguous"); + execute_dft(cfalse, out.data(), in.data(), temp); + } + template <index_t UDims = Dims, CMT_ENABLE_IF(UDims != dynamic_shape)> + void execute(const tensor<T, Dims>& out, const tensor<complex<T>, Dims>& in, u8* temp, + cinvert_t = {}) const + { + KFR_LOGIC_CHECK(in.shape() == complex_size() && out.shape() == this->size, + "dft_plan_md_real: incorrect tensor shapes"); + KFR_LOGIC_CHECK(in.is_contiguous() && out.is_contiguous(), + "dft_plan_md_real: tensors must be contiguous"); + execute_dft(ctrue, out.data(), in.data(), temp); + } + void execute(complex<T>* out, const T* in, u8* temp, bool inverse) const + { + KFR_LOGIC_CHECK(inverse, "dft_plan_md_real: incorrect usage"); + execute_dft(cfalse, out, in, temp); + } + void execute(T* out, const complex<T>* in, u8* temp, bool inverse) const + { + KFR_LOGIC_CHECK(!inverse, "dft_plan_md_real: incorrect usage"); + execute_dft(ctrue, out, in, temp); + } + +private: + template <bool inverse, typename Tout, typename Tin> + KFR_INTRINSIC void execute_dft(cbool_t<inverse>, Tout* out, const Tin* in, u8* temp) const + { + if (temp == nullptr && temp_size > 0) + { + return call_with_temp(temp_size, + std::bind(&dft_plan_md_real<T, Dims>::execute_dft<inverse, Tout, Tin>, this, + cbool_t<inverse>{}, out, in, std::placeholders::_1)); + } + if (this->size.dims() == 1) + { + dft_real.execute(out, in, temp, cbool<inverse>); + } + else + { + execute_dim(cbool<inverse>, out, in, temp); + } + } + void expand(T* out, const T* in, size_t count, size_t last_axis) const + { + size_t last_axis_ex = dft_real.complex_size() * 2; + if (in != out) + { + builtin_memmove(out, in, last_axis * sizeof(T)); + } + in += last_axis * (count - 1); + out += last_axis_ex * (count - 1); + for (size_t i = 1; i < count; ++i) + { + builtin_memmove(out, in, last_axis * sizeof(T)); + in -= last_axis; + out -= last_axis_ex; + } +#ifdef KFR_DEBUG + for (size_t i = 0; i < count; ++i) + { + builtin_memset(out + last_axis, 0xFF, (last_axis_ex - last_axis) * sizeof(T)); + out += last_axis_ex; + } +#endif + } + void contract(T* out, const T* in, size_t count, size_t last_axis) const + { + size_t last_axis_ex = dft_real.complex_size() * 2; + if (in != out) + builtin_memmove(out, in, last_axis * sizeof(T)); + in += last_axis_ex; + out += last_axis; + for (size_t i = 1; i < count; ++i) + { + builtin_memmove(out, in, last_axis * sizeof(T)); + in += last_axis_ex; + out += last_axis; + } + } + KFR_INTRINSIC void execute_dim(cfalse_t, complex<T>* out, const T* in_real, u8* temp) const + { + shape<Dims> sh = complex_size(); + index_t total = sh.product(); + index_t axis = size.dims() - 1; + expand(ptr_cast<T>(out), in_real, size.remove_back().product(), size.back()); + for (;;) + { + if (size[axis] > 1) + { + if (axis == size.dims() - 1) + for (index_t o = 0; o < total; o += sh.back()) + dft_real.execute(out + o, ptr_cast<T>(out + o), temp, cfalse); + else + for (index_t o = 0; o < total; o += sh.back()) + dfts[axis].execute(out + o, out + o, temp, cfalse); + } + + matrix_transpose(out, out, shape{ sh.remove_back().product(), sh.back() }); + + if (axis == 0) + break; + + sh = sh.rotate_right(); + --axis; + } + } + KFR_INTRINSIC void execute_dim(ctrue_t, T* out_real, const complex<T>* in, u8* temp) const + { + shape<Dims> sh = complex_size(); + index_t total = sh.product(); + complex<T>* out = real_out_is_enough + ? ptr_cast<complex<T>>(out_real) + : ptr_cast<complex<T>>(temp + temp_size - total * sizeof(complex<T>)); + index_t axis = 0; + for (;;) + { + matrix_transpose(out, in, shape{ sh.front(), sh.remove_front().product() }); + + if (size[axis] > 1) + { + if (axis == size.dims() - 1) + for (index_t o = 0; o < total; o += sh.front()) + dft_real.execute(ptr_cast<T>(out + o), out + o, temp, ctrue); + else + for (index_t o = 0; o < total; o += sh.front()) + dfts[axis].execute(out + o, out + o, temp, ctrue); + } + + if (axis == size.dims() - 1) + break; + + sh = sh.rotate_left(); + in = out; + ++axis; + } + contract(out_real, ptr_cast<T>(out), size.remove_back().product(), size.back()); + } + using dft_list = std::conditional_t<Dims == dynamic_shape, std::vector<dft_plan<T>>, + std::array<dft_plan<T>, const_max(Dims, 1) - 1>>; + dft_list dfts; + dft_plan_real<T> dft_real; }; /// @brief DCT type 2 (unscaled) diff --git a/include/kfr/dft/reference_dft.hpp b/include/kfr/dft/reference_dft.hpp @@ -26,7 +26,6 @@ #pragma once #include "../base/memory.hpp" -#include "../base/small_buffer.hpp" #include "../base/univector.hpp" #include "../simd/complex.hpp" #include "../simd/constants.hpp" diff --git a/include/kfr/simd/impl/intrinsics.h b/include/kfr/simd/impl/intrinsics.h @@ -38,12 +38,17 @@ CMT_INLINE void builtin_memcpy(void* dest, const void* src, size_t size) { __builtin_memcpy(dest, src, size); } +CMT_INLINE void builtin_memmove(void* dest, const void* src, size_t size) +{ + __builtin_memmove(dest, src, size); +} CMT_INLINE void builtin_memset(void* dest, int val, size_t size) { __builtin_memset(dest, val, size); } #else CMT_INLINE float builtin_sqrt(float x) { return ::sqrtf(x); } CMT_INLINE double builtin_sqrt(double x) { return ::sqrt(x); } CMT_INLINE long double builtin_sqrt(long double x) { return ::sqrtl(x); } CMT_INLINE void builtin_memcpy(void* dest, const void* src, size_t size) { ::memcpy(dest, src, size); } +CMT_INLINE void builtin_memmove(void* dest, const void* src, size_t size) { ::memmove(dest, src, size); } CMT_INLINE void builtin_memset(void* dest, int val, size_t size) { ::memset(dest, val, size); } #endif diff --git a/src/dft/fft-impl.hpp b/src/dft/fft-impl.hpp @@ -1759,7 +1759,8 @@ KFR_INTRINSIC void initialize_order(dft_plan<T>* self) typename dft_plan<T>::bitset ored = self->disposition_inplace[0] | self->disposition_inplace[1] | self->disposition_outofplace[0] | self->disposition_outofplace[1]; if (ored.any()) // if scratch needed - self->temp_size += align_up(sizeof(complex<T>) * self->size, platform<>::native_cache_alignment); + self->temp_size += + align_up(sizeof(complex<T>) * (self->size + 1), platform<>::native_cache_alignment); } template <typename T> @@ -1812,6 +1813,7 @@ to_fmt(size_t real_size, const complex<T>* rtwiddle, complex<T>* out, const comp constexpr size_t width = vector_width<T> * 2; const cvec<T, 1> dc = cread<1>(in); + cvec<T, 1> inmid = cread<1>(in + csize / 2); const size_t count = (csize + 1) / 2; block_process(count - 1, csizes_t<width, 1>(), @@ -1833,10 +1835,7 @@ to_fmt(size_t real_size, const complex<T>* rtwiddle, complex<T>* out, const comp if (is_even(csize)) { - size_t k = csize / 2; - const cvec<T, 1> fpk = cread<1>(in + k); - const cvec<T, 1> fpnk = negodd(fpk); - cwrite<1>(out + k, fpnk); + cwrite<1>(out + csize / 2, negodd(inmid)); } if (fmt == dft_pack_format::CCs) { @@ -1899,10 +1898,7 @@ void from_fmt(size_t real_size, complex<T>* rtwiddle, complex<T>* out, const com }); if (is_even(csize)) { - size_t k = csize / 2; - const cvec<T, 1> fpk = inmid; - const cvec<T, 1> fpnk = 2 * negodd(fpk); - cwrite<1>(out + k, fpnk); + cwrite<1>(out + csize / 2, 2 * negodd(inmid)); } cwrite<1>(out, dc); } @@ -1985,6 +1981,11 @@ public: from_fmt(this->stage_size, ptr_cast<complex<T>>(this->data), out, in, static_cast<dft_pack_format>(this->user)); } + void copy_input(bool invert, complex<T>* out, const complex<T>* in, size_t size) final + { + size_t extra = invert && static_cast<dft_pack_format>(this->user) == dft_pack_format::CCs ? 1 : 0; + builtin_memcpy(out, in, sizeof(complex<T>) * (size + extra)); + } }; } // namespace intrinsics diff --git a/tests/dft_test.cpp b/tests/dft_test.cpp @@ -175,38 +175,6 @@ constexpr size_t dft_stopsize = 257; #endif #endif -TEST(fft_real) -{ - using float_type = double; - random_state gen = random_init(2247448713, 915890490, 864203735, 2982561); - - constexpr size_t size = 64; - - kfr::univector<float_type, size> in = gen_random_range<float_type>(gen, -1.0, +1.0); - kfr::univector<kfr::complex<float_type>, size / 2 + 1> out = realdft(in); - kfr::univector<float_type, size> rev = irealdft(out) / size; - CHECK(rms(rev - in) <= 0.00001f); -} - -#ifndef KFR_DFT_NO_NPo2 -TEST(fft_real_not_size_4N) -{ - kfr::univector<double, 6> in = counter(); - auto out = realdft(in); - kfr::univector<kfr::complex<double>> expected{ 15.0, { -3, 5.19615242 }, { -3, +1.73205081 }, -3.0 }; - CHECK(rms(cabs(out - expected)) <= 0.00001f); - kfr::univector<double, 6> rev = irealdft(out) / 6; - CHECK(rms(rev - in) <= 0.00001f); - - random_state gen = random_init(2247448713, 915890490, 864203735, 2982561); - constexpr size_t size = 66; - kfr::univector<double, size> in2 = gen_random_range<double>(gen, -1.0, +1.0); - kfr::univector<kfr::complex<double>, size / 2 + 1> out2 = realdft(in2); - kfr::univector<double, size> rev2 = irealdft(out2) / size; - CHECK(rms(rev2 - in2) <= 0.00001f); -} -#endif - TEST(fft_accuracy) { #ifdef DEBUG_DFT_PROGRESS @@ -229,66 +197,85 @@ TEST(fft_accuracy) println(sizes); #endif - testo::matrix(named("type") = dft_float_types, // - named("size") = sizes, // - [&gen](auto type, size_t size) - { - using float_type = type_of<decltype(type)>; - const double min_prec = 0.000001 * std::log(size) * size; - - for (bool inverse : { false, true }) - { - testo::scope s(inverse ? "complex-inverse" : "complex-direct"); - univector<complex<float_type>> in = - truncate(gen_random_range<float_type>(gen, -1.0, +1.0), size); - univector<complex<float_type>> out = in; - univector<complex<float_type>> refout = out; - univector<complex<float_type>> outo = in; - const dft_plan<float_type> dft(size); - double min_prec2 = dft.arblen ? 2 * min_prec : min_prec; - if (!inverse) - { + testo::matrix( + named("type") = dft_float_types, // + named("size") = sizes, // + [&gen](auto type, size_t size) + { + using float_type = type_of<decltype(type)>; + const double min_prec = 0.000001 * std::log(size) * size; + + for (bool inverse : { false, true }) + { + testo::scope s(inverse ? "complex-inverse" : "complex-direct"); + univector<complex<float_type>> in = + truncate(gen_random_range<float_type>(gen, -1.0, +1.0), size); + univector<complex<float_type>> out = in; + univector<complex<float_type>> refout = out; + univector<complex<float_type>> outo = in; + const dft_plan<float_type> dft(size); + double min_prec2 = dft.arblen ? 2 * min_prec : min_prec; + if (!inverse) + { #if DEBUG_DFT_PROGRESS - dft.dump(); + dft.dump(); #endif - } - univector<u8> temp(dft.temp_size); - - reference_dft(refout.data(), in.data(), size, inverse); - dft.execute(outo, in, temp, inverse); - dft.execute(out, out, temp, inverse); - - const float_type rms_diff_inplace = rms(cabs(refout - out)); - CHECK(rms_diff_inplace <= min_prec2); - const float_type rms_diff_outofplace = rms(cabs(refout - outo)); - CHECK(rms_diff_outofplace <= min_prec2); - } - - if (is_even(size)) - { - univector<float_type> in = - truncate(gen_random_range<float_type>(gen, -1.0, +1.0), size); - - univector<complex<float_type>> out = truncate(dimensions<1>(scalar(qnan)), size); - univector<complex<float_type>> refout = truncate(dimensions<1>(scalar(qnan)), size); - const dft_plan_real<float_type> dft(size); - univector<u8> temp(dft.temp_size); - - testo::scope s("real-direct"); - reference_dft(refout.data(), in.data(), size); - dft.execute(out, in, temp); - float_type rms_diff = - rms(cabs(refout.truncate(size / 2 + 1) - out.truncate(size / 2 + 1))); - CHECK(rms_diff <= min_prec); - - univector<float_type> out2(size, 0.f); - s.text = "real-inverse"; - dft.execute(out2, out, temp); - out2 = out2 / size; - rms_diff = rms(in - out2); - CHECK(rms_diff <= min_prec); - } - }); + } + univector<u8> temp(dft.temp_size); + + reference_dft(refout.data(), in.data(), size, inverse); + dft.execute(outo, in, temp, inverse); + dft.execute(out, out, temp, inverse); + + const float_type rms_diff_inplace = rms(cabs(refout - out)); + CHECK(rms_diff_inplace <= min_prec2); + const float_type rms_diff_outofplace = rms(cabs(refout - outo)); + CHECK(rms_diff_outofplace <= min_prec2); + } + + if (is_even(size)) + { + index_t csize = dft_plan_real<float_type>::complex_size_for(size); + univector<float_type> in = truncate(gen_random_range<float_type>(gen, -1.0, +1.0), size); + + univector<complex<float_type>> out = truncate(dimensions<1>(scalar(qnan)), csize); + univector<complex<float_type>> refout = truncate(dimensions<1>(scalar(qnan)), csize); + const dft_plan_real<float_type> dft(size); + univector<u8> temp(dft.temp_size); + + { + testo::scope s("real-direct"); + reference_dft(refout.data(), in.data(), size); + dft.execute(out, in, temp); + float_type rms_diff_outofplace = rms(cabs(refout - out)); + CHECK(rms_diff_outofplace <= min_prec); + + univector<complex<float_type>> outi(csize); + outi = padded(make_univector(ptr_cast<complex<float_type>>(in.data()), size / 2), + complex<float_type>{ 0.f }); + dft.execute(outi.data(), ptr_cast<float_type>(outi.data()), temp.data()); + float_type rms_diff_inplace = rms(cabs(refout - outi.truncate(csize))); + CHECK(rms_diff_inplace <= min_prec); + } + + { + testo::scope s("real-inverse"); + univector<float_type> out2(size, 0.f); + dft.execute(out2, out, temp); + out2 = out2 / size; + float_type rms_diff_outofplace = rms(in - out2); + CHECK(rms_diff_outofplace <= min_prec); + + univector<float_type> outi(2 * csize); + outi = make_univector(ptr_cast<float_type>(out.data()), 2 * csize); + + dft.execute(outi.data(), ptr_cast<complex<float_type>>(outi.data()), temp.data()); + outi = outi / size; + float_type rms_diff_inplace = rms(in - outi.truncate(size)); + CHECK(rms_diff_inplace <= min_prec); + } + } + }); } TEST(dct) @@ -319,6 +306,180 @@ TEST(dct) CHECK(rms(refoutinv - outinv) < 0.00001f); } + +template <typename T, index_t Dims, typename dft_type, typename dft_real_type> +static void test_dft_md_t(random_state& gen, shape<Dims> shape) +{ + index_t size = shape.product(); + testo::scope s(as_string("shape=", shape)); + + const double min_prec = 0.000002 * std::log(size) * size; + + { + const dft_type dft(shape); +#if DEBUG_DFT_PROGRESS + dft.dump(); +#endif + univector<complex<T>> in = truncate(gen_random_range<T>(gen, -1.0, +1.0), size); + for (bool inverse : { false, true }) + { + testo::scope s(inverse ? "complex-inverse" : "complex-direct"); + univector<complex<T>> out = in; + univector<complex<T>> refout = out; + univector<complex<T>> outo = in; + univector<u8> temp(dft.temp_size); + + reference_dft_md(refout.data(), in.data(), shape, inverse); + dft.execute(outo.data(), in.data(), temp.data(), inverse); + dft.execute(out.data(), out.data(), temp.data(), inverse); + + const T rms_diff_inplace = rms(cabs(refout - out)); + CHECK(rms_diff_inplace <= min_prec); + const T rms_diff_outofplace = rms(cabs(refout - outo)); + CHECK(rms_diff_outofplace <= min_prec); + } + } + + if (is_even(shape.back())) + { + index_t csize = dft_plan_md_real<float, Dims>::complex_size_for(shape).product(); + univector<T> in = truncate(gen_random_range<T>(gen, -1.0, +1.0), size); + + univector<complex<T>> out = truncate(dimensions<1>(scalar(qnan)), csize); + univector<complex<T>> refout = truncate(dimensions<1>(scalar(qnan)), csize); + const dft_real_type dft(shape, true); +#if DEBUG_DFT_PROGRESS + dft.dump(); +#endif + univector<u8> temp(dft.temp_size); + + { + testo::scope s("real-direct"); + reference_dft_md(refout.data(), in.data(), shape); + dft.execute(out.data(), in.data(), temp.data()); + T rms_diff_outofplace = rms(cabs(refout - out)); + CHECK(rms_diff_outofplace <= min_prec); + + univector<complex<T>> outi(csize); + outi = padded(make_univector(ptr_cast<complex<T>>(in.data()), size / 2), complex<T>{ 0.f }); + dft.execute(outi.data(), ptr_cast<T>(outi.data()), temp.data()); + T rms_diff_inplace = rms(cabs(refout - outi)); + CHECK(rms_diff_inplace <= min_prec); + } + + { + testo::scope s("real-inverse"); + univector<T> out2(dft.real_out_size(), 0.f); + dft.execute(out2.data(), out.data(), temp.data()); + out2 = out2 / size; + T rms_diff_outofplace = rms(in - out2.truncate(size)); + CHECK(rms_diff_outofplace <= min_prec); + + univector<T> outi(2 * csize); + outi = make_univector(ptr_cast<T>(out.data()), 2 * csize); + dft.execute(outi.data(), ptr_cast<complex<T>>(outi.data()), temp.data()); + outi = outi / size; + T rms_diff_inplace = rms(in - outi.truncate(size)); + CHECK(rms_diff_inplace <= min_prec); + } + } +} + +template <typename T, index_t Dims> +static void test_dft_md(random_state& gen, shape<Dims> shape) +{ + { + testo::scope s("compile-time dims"); + test_dft_md_t<T, Dims, dft_plan_md<T, Dims>, dft_plan_md_real<T, Dims>>(gen, shape); + } + { + testo::scope s("runtime dims"); + test_dft_md_t<T, Dims, dft_plan_md<T, dynamic_shape>, dft_plan_md_real<T, dynamic_shape>>(gen, shape); + } +} + +TEST(dft_md) +{ + random_state gen = random_init(2247448713, 915890490, 864203735, 2982561); + + testo::matrix(named("type") = dft_float_types, // + [&gen](auto type) + { + using float_type = type_of<decltype(type)>; + test_dft_md<float_type>(gen, shape{ 120 }); + test_dft_md<float_type>(gen, shape{ 2, 60 }); + test_dft_md<float_type>(gen, shape{ 3, 40 }); + test_dft_md<float_type>(gen, shape{ 4, 30 }); + test_dft_md<float_type>(gen, shape{ 5, 24 }); + test_dft_md<float_type>(gen, shape{ 6, 20 }); + test_dft_md<float_type>(gen, shape{ 8, 15 }); + test_dft_md<float_type>(gen, shape{ 10, 12 }); + test_dft_md<float_type>(gen, shape{ 12, 10 }); + test_dft_md<float_type>(gen, shape{ 15, 8 }); + test_dft_md<float_type>(gen, shape{ 20, 6 }); + test_dft_md<float_type>(gen, shape{ 24, 5 }); + test_dft_md<float_type>(gen, shape{ 30, 4 }); + test_dft_md<float_type>(gen, shape{ 40, 3 }); + test_dft_md<float_type>(gen, shape{ 60, 2 }); + + test_dft_md<float_type>(gen, shape{ 2, 3, 24 }); + test_dft_md<float_type>(gen, shape{ 12, 5, 2 }); + test_dft_md<float_type>(gen, shape{ 5, 12, 2 }); + + test_dft_md<float_type>(gen, shape{ 2, 3, 2, 12 }); + test_dft_md<float_type>(gen, shape{ 3, 4, 5, 2 }); + test_dft_md<float_type>(gen, shape{ 5, 4, 3, 2 }); + + test_dft_md<float_type>(gen, shape{ 5, 2, 2, 3, 2 }); + test_dft_md<float_type>(gen, shape{ 2, 5, 2, 2, 3 }); + + test_dft_md<float_type>(gen, shape{ 1, 120 }); + test_dft_md<float_type>(gen, shape{ 120, 1 }); + test_dft_md<float_type>(gen, shape{ 2, 1, 1, 60 }); + test_dft_md<float_type>(gen, shape{ 1, 2, 10, 2, 1, 3 }); + + test_dft_md<float_type>(gen, shape{ 4, 4 }); + test_dft_md<float_type>(gen, shape{ 4, 4, 4 }); + test_dft_md<float_type>(gen, shape{ 4, 4, 4, 4 }); + test_dft_md<float_type>(gen, shape{ 4, 4, 4, 4, 4 }); + test_dft_md<float_type>(gen, shape{ 4, 4, 4, 4, 4, 4 }); + test_dft_md<float_type>(gen, shape{ 4, 4, 4, 4, 4, 4, 4 }); + test_dft_md<float_type>(gen, shape{ 4, 4, 4, 4, 4, 4, 4, 4 }); + test_dft_md<float_type>(gen, shape{ 4, 4, 4, 4, 4, 4, 4, 4, 4 }); + test_dft_md<float_type>(gen, shape{ 4, 4, 4, 4, 4, 4, 4, 4, 4, 4 }); +#if defined NDEBUG + test_dft_md<float_type>(gen, shape{ 512, 512 }); + test_dft_md<float_type>(gen, shape{ 32, 32, 32 }); + test_dft_md<float_type>(gen, shape{ 8, 8, 8, 8 }); + test_dft_md<float_type>(gen, shape{ 2, 2, 2, 2, 2, 2 }); + + test_dft_md<float_type>(gen, shape{ 1, 65536 }); + test_dft_md<float_type>(gen, shape{ 2, 65536 }); + test_dft_md<float_type>(gen, shape{ 3, 65536 }); + test_dft_md<float_type>(gen, shape{ 4, 65536 }); + test_dft_md<float_type>(gen, shape{ 65536, 1 }); + test_dft_md<float_type>(gen, shape{ 65536, 2 }); + test_dft_md<float_type>(gen, shape{ 65536, 3 }); + test_dft_md<float_type>(gen, shape{ 65536, 4 }); + + test_dft_md<float_type>(gen, shape{ 1, 2 }); + test_dft_md<float_type>(gen, shape{ 1, 2, 3 }); + test_dft_md<float_type>(gen, shape{ 1, 2, 3, 4 }); + test_dft_md<float_type>(gen, shape{ 1, 2, 3, 4, 5 }); + test_dft_md<float_type>(gen, shape{ 1, 2, 3, 4, 5, 6 }); + test_dft_md<float_type>(gen, shape{ 1, 2, 3, 4, 5, 6, 7 }); + test_dft_md<float_type>(gen, shape{ 1, 2, 3, 4, 5, 6, 7, 8 }); + test_dft_md<float_type>(gen, shape{ 2, 1 }); + test_dft_md<float_type>(gen, shape{ 3, 2, 1 }); + test_dft_md<float_type>(gen, shape{ 4, 3, 2, 1 }); + test_dft_md<float_type>(gen, shape{ 5, 4, 3, 2, 1 }); + test_dft_md<float_type>(gen, shape{ 6, 5, 4, 3, 2, 1 }); + test_dft_md<float_type>(gen, shape{ 7, 6, 5, 4, 3, 2, 1 }); + test_dft_md<float_type>(gen, shape{ 8, 7, 6, 5, 4, 3, 2, 1 }); +#endif + }); +} + } // namespace CMT_ARCH_NAME #ifndef KFR_NO_MAIN