commit 00dbe97ada07f41e12c445c2a2e5630c23f91bb5
parent a99d76846dc5809042536f3981e046e3b22db441
Author: d.levin256@gmail.com <d.levin256@gmail.com>
Date: Mon, 22 Jan 2024 07:30:15 +0000
Multidimensional DFT
Diffstat:
7 files changed, 821 insertions(+), 156 deletions(-)
diff --git a/include/kfr/base/shape.hpp b/include/kfr/base/shape.hpp
@@ -97,26 +97,28 @@ KFR_INTRINSIC bool increment_indices(shape<dims>& indices, const shape<dims>& st
index_t dim = dims - 1);
} // namespace internal_generic
-template <index_t dims>
-struct shape : static_array_base<index_t, csizeseq_t<dims>>
+template <index_t Dims>
+struct shape : static_array_base<index_t, csizeseq_t<Dims>>
{
- static_assert(dims <= 256, "Too many dimensions");
- using base = static_array_base<index_t, csizeseq_t<dims>>;
+ static_assert(Dims <= 256, "Too many dimensions");
+ using base = static_array_base<index_t, csizeseq_t<Dims>>;
using base::base;
constexpr shape(const base& a) : base(a) {}
- static_assert(dims < maximum_dims);
+ static_assert(Dims <= maximum_dims);
+
+ static constexpr size_t dims() { return base::static_size; }
- template <int dummy = 0, KFR_ENABLE_IF(dummy == 0 && dims == 1)>
+ template <int dummy = 0, KFR_ENABLE_IF(dummy == 0 && Dims == 1)>
operator index_t() const
{
return this->front();
}
template <typename TI>
- static constexpr shape from_std_array(const std::array<TI, dims>& a)
+ static constexpr shape from_std_array(const std::array<TI, Dims>& a)
{
shape result;
std::copy(a.begin(), a.end(), result.begin());
@@ -124,16 +126,16 @@ struct shape : static_array_base<index_t, csizeseq_t<dims>>
}
template <typename TI = index_t>
- constexpr std::array<TI, dims> to_std_array() const
+ constexpr std::array<TI, Dims> to_std_array() const
{
- std::array<TI, dims> result{};
+ std::array<TI, Dims> result{};
std::copy(this->begin(), this->end(), result.begin());
return result;
}
bool ge(const shape& other) const
{
- if constexpr (dims == 1)
+ if constexpr (Dims == 1)
{
return this->front() >= other.front();
}
@@ -145,17 +147,17 @@ struct shape : static_array_base<index_t, csizeseq_t<dims>>
index_t trailing_zeros() const
{
- for (index_t i = 0; i < dims; ++i)
+ for (index_t i = 0; i < Dims; ++i)
{
if (revindex(i) != 0)
return i;
}
- return dims;
+ return Dims;
}
bool le(const shape& other) const
{
- if constexpr (dims == 1)
+ if constexpr (Dims == 1)
{
return this->front() <= other.front();
}
@@ -184,7 +186,7 @@ struct shape : static_array_base<index_t, csizeseq_t<dims>>
constexpr bool has_infinity() const
{
- for (index_t i = 0; i < dims; ++i)
+ for (index_t i = 0; i < Dims; ++i)
{
if (CMT_UNLIKELY(this->operator[](i) == infinite_size))
return true;
@@ -228,13 +230,13 @@ struct shape : static_array_base<index_t, csizeseq_t<dims>>
constexpr const base* operator->() const { return static_cast<const base*>(this); }
- KFR_MEM_INTRINSIC constexpr size_t to_flat(const shape<dims>& indices) const
+ KFR_MEM_INTRINSIC constexpr size_t to_flat(const shape<Dims>& indices) const
{
- if constexpr (dims == 1)
+ if constexpr (Dims == 1)
{
return indices[0];
}
- else if constexpr (dims == 2)
+ else if constexpr (Dims == 2)
{
return (*this)[1] * indices[0] + indices[1];
}
@@ -243,33 +245,33 @@ struct shape : static_array_base<index_t, csizeseq_t<dims>>
size_t result = 0;
size_t scale = 1;
CMT_LOOP_UNROLL
- for (size_t i = 0; i < dims; ++i)
+ for (size_t i = 0; i < Dims; ++i)
{
- result += scale * indices[dims - 1 - i];
- scale *= (*this)[dims - 1 - i];
+ result += scale * indices[Dims - 1 - i];
+ scale *= (*this)[Dims - 1 - i];
}
return result;
}
}
- KFR_MEM_INTRINSIC constexpr shape<dims> from_flat(size_t index) const
+ KFR_MEM_INTRINSIC constexpr shape<Dims> from_flat(size_t index) const
{
- if constexpr (dims == 1)
+ if constexpr (Dims == 1)
{
return { static_cast<index_t>(index) };
}
- else if constexpr (dims == 2)
+ else if constexpr (Dims == 2)
{
index_t sz = (*this)[1];
return { static_cast<index_t>(index / sz), static_cast<index_t>(index % sz) };
}
else
{
- shape<dims> indices;
+ shape<Dims> indices;
CMT_LOOP_UNROLL
- for (size_t i = 0; i < dims; ++i)
+ for (size_t i = 0; i < Dims; ++i)
{
- size_t sz = (*this)[dims - 1 - i];
- indices[dims - 1 - i] = index % sz;
+ size_t sz = (*this)[Dims - 1 - i];
+ indices[Dims - 1 - i] = index % sz;
index /= sz;
}
return indices;
@@ -281,11 +283,11 @@ struct shape : static_array_base<index_t, csizeseq_t<dims>>
template <index_t indims, bool stop = false>
KFR_MEM_INTRINSIC constexpr shape adapt(const shape<indims>& other, cbool_t<stop> = {}) const
{
- static_assert(indims >= dims);
+ static_assert(indims >= Dims);
if constexpr (stop)
- return other.template trim<dims>()->min(**this);
+ return other.template trim<Dims>()->min(**this);
else
- return other.template trim<dims>()->min(**this - 1);
+ return other.template trim<Dims>()->min(**this - 1);
}
KFR_MEM_INTRINSIC constexpr index_t product() const { return (*this)->product(); }
@@ -293,9 +295,9 @@ struct shape : static_array_base<index_t, csizeseq_t<dims>>
KFR_MEM_INTRINSIC constexpr dimset tomask() const
{
dimset result(0);
- for (index_t i = 0; i < dims; ++i)
+ for (index_t i = 0; i < Dims; ++i)
{
- result[i + maximum_dims - dims] = this->operator[](i) == 1 ? 0 : -1;
+ result[i + maximum_dims - Dims] = this->operator[](i) == 1 ? 0 : -1;
}
return result;
}
@@ -303,20 +305,20 @@ struct shape : static_array_base<index_t, csizeseq_t<dims>>
template <index_t new_dims>
constexpr KFR_MEM_INTRINSIC shape<new_dims> extend(index_t value = infinite_size) const
{
- static_assert(new_dims >= dims);
- if constexpr (new_dims == dims)
+ static_assert(new_dims >= Dims);
+ if constexpr (new_dims == Dims)
return *this;
else
- return shape<new_dims>{ shape<new_dims - dims>(value), *this };
+ return shape<new_dims>{ shape<new_dims - Dims>(value), *this };
}
template <index_t odims>
constexpr shape<odims> trim() const
{
- static_assert(odims <= dims);
+ static_assert(odims <= Dims);
if constexpr (odims > 0)
{
- return this->template slice<dims - odims, odims>();
+ return this->template slice<Dims - odims, odims>();
}
else
{
@@ -324,11 +326,34 @@ struct shape : static_array_base<index_t, csizeseq_t<dims>>
}
}
- constexpr KFR_MEM_INTRINSIC shape<dims - 1> trunc() const
+ // 0,1,2,3 -> 1,2,3,0
+ constexpr KFR_MEM_INTRINSIC shape rotate_left() const
+ {
+ return this->shuffle(csizeseq<Dims, 1> % csize<Dims>);
+ }
+
+ // 0,1,2,3 -> 3,0,1,2
+ constexpr KFR_MEM_INTRINSIC shape rotate_right() const
{
- if constexpr (dims > 1)
+ return this->shuffle(csizeseq<Dims, Dims - 1> % csize<Dims>);
+ }
+
+ constexpr KFR_MEM_INTRINSIC shape<Dims - 1> remove_back() const
+ {
+ if constexpr (Dims > 1)
{
- return this->template slice<0, dims - 1>();
+ return this->template slice<0, Dims - 1>();
+ }
+ else
+ {
+ return {};
+ }
+ }
+ constexpr KFR_MEM_INTRINSIC shape<Dims - 1> remove_front() const
+ {
+ if constexpr (Dims > 1)
+ {
+ return this->template slice<1, Dims - 1>();
}
else
{
@@ -336,19 +361,21 @@ struct shape : static_array_base<index_t, csizeseq_t<dims>>
}
}
+ constexpr KFR_MEM_INTRINSIC shape<Dims - 1> trunc() const { return remove_back(); }
+
KFR_MEM_INTRINSIC constexpr index_t revindex(size_t index) const
{
- return index < dims ? this->operator[](dims - 1 - index) : 1;
+ return index < Dims ? this->operator[](Dims - 1 - index) : 1;
}
KFR_MEM_INTRINSIC constexpr void set_revindex(size_t index, index_t val)
{
- if (CMT_LIKELY(index < dims))
- this->operator[](dims - 1 - index) = val;
+ if (CMT_LIKELY(index < Dims))
+ this->operator[](Dims - 1 - index) = val;
}
KFR_MEM_INTRINSIC constexpr shape transpose() const
{
- return this->shuffle(csizeseq<dims, dims - 1, -1>);
+ return this->shuffle(csizeseq<Dims, Dims - 1, -1>);
}
};
@@ -359,6 +386,8 @@ struct shape<0>
static constexpr size_t size() { return static_size; }
+ static constexpr size_t dims() { return static_size; }
+
constexpr shape() = default;
constexpr shape(index_t value) {}
@@ -404,6 +433,76 @@ struct shape<0>
KFR_MEM_INTRINSIC void set_revindex(size_t index, index_t val) {}
};
+constexpr inline size_t dynamic_shape = std::numeric_limits<size_t>::max();
+
+template <>
+struct shape<dynamic_shape> : protected std::vector<index_t>
+{
+ using std::vector<index_t>::vector;
+
+ using std::vector<index_t>::begin;
+ using std::vector<index_t>::end;
+ using std::vector<index_t>::data;
+ using std::vector<index_t>::size;
+ using std::vector<index_t>::front;
+ using std::vector<index_t>::back;
+ using std::vector<index_t>::operator[];
+
+ template <index_t Dims, CMT_ENABLE_IF(Dims != dynamic_shape)>
+ shape(shape<Dims> sh) : shape(sh.begin(), sh.end())
+ {
+ }
+
+ size_t dims() const { return size(); }
+
+ KFR_MEM_INTRINSIC index_t product() const
+ {
+ if (std::vector<index_t>::empty())
+ return 0;
+ index_t p = this->front();
+ for (size_t i = 1; i < size(); ++i)
+ {
+ p *= this->operator[](i);
+ }
+ return p;
+ }
+
+ // 0,1,2,3 -> 1,2,3,0
+ KFR_MEM_INTRINSIC shape rotate_left() const
+ {
+ shape result = *this;
+ if (result.size() > 1)
+ std::rotate(result.begin(), result.begin() + 1, result.end());
+ return result;
+ }
+
+ // 0,1,2,3 -> 3,0,1,2
+ KFR_MEM_INTRINSIC shape rotate_right() const
+ {
+ shape result = *this;
+ if (result.size() > 1)
+ std::rotate(result.begin(), result.end() - 1, result.end());
+ return result;
+ }
+
+ KFR_MEM_INTRINSIC shape remove_back() const
+ {
+ shape result = *this;
+ if (!result.empty())
+ result.erase(result.end() - 1);
+ return result;
+ }
+ KFR_MEM_INTRINSIC shape remove_front() const
+ {
+ shape result = *this;
+ if (!result.empty())
+ {
+ result.erase(result.begin());
+ }
+ return result;
+ }
+};
+
template <typename... Args>
shape(Args&&... args) -> shape<sizeof...(Args)>;
diff --git a/include/kfr/base/tensor.hpp b/include/kfr/base/tensor.hpp
@@ -880,6 +880,12 @@ private:
memory_finalizer m_finalizer;
};
+template <typename T>
+struct tensor<T, dynamic_shape>
+{
+ // Not implemented yet
+};
+
// template <typename T>
// struct tensor<T, 0>
// {
diff --git a/include/kfr/dft/fft.hpp b/include/kfr/dft/fft.hpp
@@ -27,11 +27,13 @@
#include "../base/basic_expressions.hpp"
#include "../base/memory.hpp"
+#include "../base/tensor.hpp"
#include "../base/univector.hpp"
#include "../math/sin_cos.hpp"
#include "../simd/complex.hpp"
#include "../simd/constants.hpp"
#include <bitset>
+#include <functional>
CMT_PRAGMA_GNU(GCC diagnostic push)
#if CMT_HAS_WARNING("-Wshadow")
@@ -76,6 +78,10 @@ struct dft_stage
printf("%s: %zu, %zu, %zu, %zu, %zu, %zu, %zu, %d, %d\n", name ? name : "unnamed", radix, stage_size,
data_size, temp_size, repeats, out_offset, blocks, recursion, can_inplace);
}
+ virtual void copy_input(bool invert, complex<T>* out, const complex<T>* in, size_t size)
+ {
+ builtin_memcpy(out, in, sizeof(complex<T>) * size);
+ }
KFR_MEM_INTRINSIC void execute(cdirect_t, complex<T>* out, const complex<T>* in, u8* temp)
{
@@ -201,6 +207,22 @@ struct dft_plan
execute_dft(inv, out.data(), in.data(), temp.data());
}
+ template <univector_tag Tag1, univector_tag Tag2>
+ KFR_MEM_INTRINSIC void execute(univector<complex<T>, Tag1>& out, const univector<complex<T>, Tag2>& in,
+ u8* temp, bool inverse = false) const
+ {
+ if (inverse)
+ execute_dft(ctrue, out.data(), in.data(), temp);
+ else
+ execute_dft(cfalse, out.data(), in.data(), temp);
+ }
+ template <bool inverse, univector_tag Tag1, univector_tag Tag2>
+ KFR_MEM_INTRINSIC void execute(univector<complex<T>, Tag1>& out, const univector<complex<T>, Tag2>& in,
+ u8* temp, cbool_t<inverse> inv) const
+ {
+ execute_dft(inv, out.data(), in.data(), temp);
+ }
+
autofree<u8> data;
size_t data_size;
@@ -309,6 +331,11 @@ protected:
template <bool inverse>
void execute_dft(cbool_t<inverse>, complex<T>* out, const complex<T>* in, u8* temp) const
{
+ if (temp == nullptr && temp_size > 0)
+ {
+ return call_with_temp(temp_size, std::bind(&dft_plan<T>::execute_dft<inverse>, this,
+ cbool_t<inverse>{}, out, in, std::placeholders::_1));
+ }
auto&& stages = this->stages[inverse];
if (stages.size() == 1 && (stages[0]->can_inplace || in != out))
{
@@ -321,12 +348,12 @@ protected:
complex<T>* scratch = ptr_cast<complex<T>>(
temp + this->temp_size -
- align_up(sizeof(complex<T>) * this->size, platform<>::native_cache_alignment));
+ align_up(sizeof(complex<T>) * (this->size + 1), platform<>::native_cache_alignment));
bool in_scratch = disposition.test(0);
if (in_scratch)
{
- builtin_memcpy(scratch, in, sizeof(complex<T>) * this->size);
+ stages[0]->copy_input(inverse, scratch, in, this->size);
}
const size_t count = stages.size();
@@ -392,6 +419,9 @@ struct dft_plan_real : dft_plan<T>
bool is_initialized() const { return size != 0; }
+ size_t complex_size() const { return complex_size_for(size); }
+ constexpr static size_t complex_size_for(size_t size) { return size / 2 + 1; }
+
[[deprecated("cpu parameter is deprecated. Runtime dispatch is used if built with "
"KFR_ENABLE_MULTIARCH")]] explicit dft_plan_real(cpu_t cpu, size_t size,
dft_pack_format fmt = dft_pack_format::CCs)
@@ -420,6 +450,14 @@ struct dft_plan_real : dft_plan<T>
void execute(univector<complex<T>, Tag1>&, const univector<complex<T>, Tag2>&, univector<u8, Tag3>&,
cbool_t<inverse>) const = delete;
+ template <univector_tag Tag1, univector_tag Tag2>
+ void execute(univector<complex<T>, Tag1>& out, const univector<complex<T>, Tag2>& in, u8* temp,
+ bool inverse = false) const = delete;
+
+ template <bool inverse, univector_tag Tag1, univector_tag Tag2>
+ void execute(univector<complex<T>, Tag1>& out, const univector<complex<T>, Tag2>& in, u8* temp,
+ cbool_t<inverse> inv) const = delete;
+
KFR_MEM_INTRINSIC void execute(complex<T>* out, const T* in, u8* temp, cdirect_t = {}) const
{
this->execute_dft(cfalse, out, ptr_cast<complex<T>>(in), temp);
@@ -442,17 +480,373 @@ struct dft_plan_real : dft_plan<T>
this->execute_dft(ctrue, ptr_cast<complex<T>>(out.data()), in.data(), temp.data());
}
- // Deprecated. fmt must be passed to constructor instead
- void execute(complex<T>*, const T*, u8*, dft_pack_format) const = delete;
- void execute(T*, const complex<T>*, u8*, dft_pack_format) const = delete;
+ template <univector_tag Tag1, univector_tag Tag2>
+ KFR_MEM_INTRINSIC void execute(univector<complex<T>, Tag1>& out, const univector<T, Tag2>& in, u8* temp,
+ cdirect_t = {}) const
+ {
+ this->execute_dft(cfalse, out.data(), ptr_cast<complex<T>>(in.data()), temp);
+ }
+ template <univector_tag Tag1, univector_tag Tag2>
+ KFR_MEM_INTRINSIC void execute(univector<T, Tag1>& out, const univector<complex<T>, Tag2>& in, u8* temp,
+ cinvert_t = {}) const
+ {
+ this->execute_dft(ctrue, ptr_cast<complex<T>>(out.data()), in.data(), temp);
+ }
+};
- // Deprecated. fmt must be passed to constructor instead
- template <univector_tag Tag1, univector_tag Tag2, univector_tag Tag3>
- void execute(univector<complex<T>, Tag1>&, const univector<T, Tag2>&, univector<u8, Tag3>&,
- dft_pack_format) const = delete;
- template <univector_tag Tag1, univector_tag Tag2, univector_tag Tag3>
- void execute(univector<T, Tag1>&, const univector<complex<T>, Tag2>&, univector<u8, Tag3>&,
- dft_pack_format) const = delete;
+/// @brief Multidimensional DFT
+template <typename T, index_t Dims>
+struct dft_plan_md
+{
+ shape<Dims> size;
+ size_t temp_size;
+
+ dft_plan_md(const dft_plan_md&) = delete;
+ dft_plan_md(dft_plan_md&&) = default;
+ dft_plan_md& operator=(const dft_plan_md&) = delete;
+ dft_plan_md& operator=(dft_plan_md&&) = default;
+
+ bool is_initialized() const { return size.product() != 0; }
+
+ void dump() const
+ {
+ for (const auto& d : dfts)
+ {
+ d.dump();
+ }
+ }
+
+ explicit dft_plan_md(shape<Dims> size) : size(std::move(size)), temp_size(0)
+ {
+ if constexpr (Dims == dynamic_shape)
+ {
+ dfts.resize(this->size.dims());
+ }
+ for (index_t i = 0; i < this->size.dims(); ++i)
+ {
+ dfts[i] = dft_plan<T>(this->size[i]);
+ temp_size = std::max(temp_size, dfts[i].temp_size);
+ }
+ }
+
+ void execute(complex<T>* out, const complex<T>* in, u8* temp, bool inverse = false) const
+ {
+ if (inverse)
+ execute_dft(ctrue, out, in, temp);
+ else
+ execute_dft(cfalse, out, in, temp);
+ }
+
+ template <index_t UDims = Dims, CMT_ENABLE_IF(UDims != dynamic_shape)>
+ void execute(const tensor<complex<T>, Dims>& out, const tensor<complex<T>, Dims>& in, u8* temp,
+ bool inverse = false) const
+ {
+ KFR_LOGIC_CHECK(in.shape() == this->size && out.shape() == this->size,
+ "dft_plan_md: incorrect tensor shapes");
+ KFR_LOGIC_CHECK(in.is_contiguous() && out.is_contiguous(), "dft_plan_md: tensors must be contiguous");
+ if (inverse)
+ execute_dft(ctrue, out.data(), in.data(), temp);
+ else
+ execute_dft(cfalse, out.data(), in.data(), temp);
+ }
+ template <bool inverse = false>
+ void execute(complex<T>* out, const complex<T>* in, u8* temp, cbool_t<inverse> = {}) const
+ {
+ execute_dft(cbool<inverse>, out, in, temp);
+ }
+
+private:
+ template <bool inverse>
+ KFR_INTRINSIC void execute_dft(cbool_t<inverse>, complex<T>* out, const complex<T>* in, u8* temp) const
+ {
+ if (temp == nullptr && temp_size > 0)
+ {
+ return call_with_temp(temp_size, std::bind(&dft_plan_md<T, Dims>::execute_dft<inverse>, this,
+ cbool_t<inverse>{}, out, in, std::placeholders::_1));
+ }
+ if (size.dims() == 1)
+ {
+ dfts[0].execute(out, in, temp, cbool<inverse>);
+ }
+ else
+ {
+ execute_dim(cbool<inverse>, out, in, temp);
+ }
+ }
+ KFR_INTRINSIC void execute_dim(cfalse_t, complex<T>* out, const complex<T>* in, u8* temp) const
+ {
+ shape<Dims> sh = size;
+ index_t total = size.product();
+ index_t axis = size.dims() - 1;
+ for (;;)
+ {
+ if (size[axis] > 1)
+ {
+ for (index_t o = 0; o < total; o += sh.back())
+ dfts[axis].execute(out + o, in + o, temp, cfalse);
+ }
+ else
+ {
+ builtin_memcpy(out, in, sizeof(complex<T>) * total);
+ }
+
+ matrix_transpose(out, out, shape{ sh.remove_back().product(), sh.back() });
+
+ if (axis == 0)
+ break;
+
+ sh = sh.rotate_right();
+ in = out;
+ --axis;
+ }
+ }
+ KFR_INTRINSIC void execute_dim(ctrue_t, complex<T>* out, const complex<T>* in, u8* temp) const
+ {
+ shape<Dims> sh = size;
+ index_t total = size.product();
+ index_t axis = 0;
+ for (;;)
+ {
+ matrix_transpose(out, in, shape{ sh.front(), sh.remove_front().product() });
+
+ if (size[axis] > 1)
+ {
+ for (index_t o = 0; o < total; o += sh.front())
+ dfts[axis].execute(out + o, out + o, temp, ctrue);
+ }
+
+ if (axis == size.dims() - 1)
+ break;
+
+ sh = sh.rotate_left();
+ in = out;
+ ++axis;
+ }
+ }
+ using dft_list =
+ std::conditional_t<Dims == dynamic_shape, std::vector<dft_plan<T>>, std::array<dft_plan<T>, Dims>>;
+ dft_list dfts;
+};
+
+/// @brief Multidimensional DFT
+template <typename T, index_t Dims>
+struct dft_plan_md_real
+{
+ shape<Dims> size;
+ size_t temp_size;
+ bool real_out_is_enough;
+
+ dft_plan_md_real(const dft_plan_md_real&) = delete;
+ dft_plan_md_real(dft_plan_md_real&&) = default;
+ dft_plan_md_real& operator=(const dft_plan_md_real&) = delete;
+ dft_plan_md_real& operator=(dft_plan_md_real&&) = default;
+
+ bool is_initialized() const { return size.product() != 0; }
+
+ void dump() const
+ {
+ for (const auto& d : dfts)
+ {
+ d.dump();
+ }
+ dft_real.dump();
+ }
+
+ shape<Dims> complex_size() const { return complex_size_for(size); }
+ constexpr static shape<Dims> complex_size_for(shape<Dims> size)
+ {
+ if (size.dims() > 0)
+ size.back() = dft_plan_real<T>::complex_size_for(size.back());
+ return size;
+ }
+
+ size_t real_out_size() const { return real_out_size_for(size); }
+ constexpr static size_t real_out_size_for(shape<Dims> size)
+ {
+ return complex_size_for(size).product() * 2;
+ }
+
+ explicit dft_plan_md_real(shape<Dims> size, bool real_out_is_enough = false)
+ : size(std::move(size)), temp_size(0), real_out_is_enough(real_out_is_enough)
+ {
+ if (this->size.dims() > 0)
+ {
+ if constexpr (Dims == dynamic_shape)
+ {
+ dfts.resize(this->size.dims());
+ }
+ for (index_t i = 0; i < this->size.dims() - 1; ++i)
+ {
+ dfts[i] = dft_plan<T>(this->size[i]);
+ temp_size = std::max(temp_size, dfts[i].temp_size);
+ }
+ dft_real = dft_plan_real<T>(this->size.back());
+ temp_size = std::max(temp_size, dft_real.temp_size);
+ }
+ if (!this->real_out_is_enough)
+ {
+ temp_size += complex_size().product() * sizeof(complex<T>);
+ }
+ }
+
+ void execute(complex<T>* out, const T* in, u8* temp, cdirect_t = {}) const
+ {
+ execute_dft(cfalse, out, in, temp);
+ }
+ void execute(T* out, const complex<T>* in, u8* temp, cinvert_t = {}) const
+ {
+ execute_dft(ctrue, out, in, temp);
+ }
+
+ template <index_t UDims = Dims, CMT_ENABLE_IF(UDims != dynamic_shape)>
+ void execute(const tensor<complex<T>, Dims>& out, const tensor<T, Dims>& in, u8* temp,
+ cdirect_t = {}) const
+ {
+ KFR_LOGIC_CHECK(in.shape() == this->size && out.shape() == complex_size(),
+ "dft_plan_md_real: incorrect tensor shapes");
+ KFR_LOGIC_CHECK(in.is_contiguous() && out.is_contiguous(),
+ "dft_plan_md_real: tensors must be contiguous");
+ execute_dft(cfalse, out.data(), in.data(), temp);
+ }
+ template <index_t UDims = Dims, CMT_ENABLE_IF(UDims != dynamic_shape)>
+ void execute(const tensor<T, Dims>& out, const tensor<complex<T>, Dims>& in, u8* temp,
+ cinvert_t = {}) const
+ {
+ KFR_LOGIC_CHECK(in.shape() == complex_size() && out.shape() == this->size,
+ "dft_plan_md_real: incorrect tensor shapes");
+ KFR_LOGIC_CHECK(in.is_contiguous() && out.is_contiguous(),
+ "dft_plan_md_real: tensors must be contiguous");
+ execute_dft(ctrue, out.data(), in.data(), temp);
+ }
+ void execute(complex<T>* out, const T* in, u8* temp, bool inverse) const
+ {
+ KFR_LOGIC_CHECK(inverse, "dft_plan_md_real: incorrect usage");
+ execute_dft(cfalse, out, in, temp);
+ }
+ void execute(T* out, const complex<T>* in, u8* temp, bool inverse) const
+ {
+ KFR_LOGIC_CHECK(!inverse, "dft_plan_md_real: incorrect usage");
+ execute_dft(ctrue, out, in, temp);
+ }
+
+private:
+ template <bool inverse, typename Tout, typename Tin>
+ KFR_INTRINSIC void execute_dft(cbool_t<inverse>, Tout* out, const Tin* in, u8* temp) const
+ {
+ if (temp == nullptr && temp_size > 0)
+ {
+ return call_with_temp(temp_size,
+ std::bind(&dft_plan_md_real<T, Dims>::execute_dft<inverse, Tout, Tin>, this,
+ cbool_t<inverse>{}, out, in, std::placeholders::_1));
+ }
+ if (this->size.dims() == 1)
+ {
+ dft_real.execute(out, in, temp, cbool<inverse>);
+ }
+ else
+ {
+ execute_dim(cbool<inverse>, out, in, temp);
+ }
+ }
+ void expand(T* out, const T* in, size_t count, size_t last_axis) const
+ {
+ size_t last_axis_ex = dft_real.complex_size() * 2;
+ if (in != out)
+ {
+ builtin_memmove(out, in, last_axis * sizeof(T));
+ }
+ in += last_axis * (count - 1);
+ out += last_axis_ex * (count - 1);
+ for (size_t i = 1; i < count; ++i)
+ {
+ builtin_memmove(out, in, last_axis * sizeof(T));
+ in -= last_axis;
+ out -= last_axis_ex;
+ }
+#ifdef KFR_DEBUG
+ for (size_t i = 0; i < count; ++i)
+ {
+ builtin_memset(out + last_axis, 0xFF, (last_axis_ex - last_axis) * sizeof(T));
+ out += last_axis_ex;
+ }
+#endif
+ }
+ void contract(T* out, const T* in, size_t count, size_t last_axis) const
+ {
+ size_t last_axis_ex = dft_real.complex_size() * 2;
+ if (in != out)
+ builtin_memmove(out, in, last_axis * sizeof(T));
+ in += last_axis_ex;
+ out += last_axis;
+ for (size_t i = 1; i < count; ++i)
+ {
+ builtin_memmove(out, in, last_axis * sizeof(T));
+ in += last_axis_ex;
+ out += last_axis;
+ }
+ }
+ KFR_INTRINSIC void execute_dim(cfalse_t, complex<T>* out, const T* in_real, u8* temp) const
+ {
+ shape<Dims> sh = complex_size();
+ index_t total = sh.product();
+ index_t axis = size.dims() - 1;
+ expand(ptr_cast<T>(out), in_real, size.remove_back().product(), size.back());
+ for (;;)
+ {
+ if (size[axis] > 1)
+ {
+ if (axis == size.dims() - 1)
+ for (index_t o = 0; o < total; o += sh.back())
+ dft_real.execute(out + o, ptr_cast<T>(out + o), temp, cfalse);
+ else
+ for (index_t o = 0; o < total; o += sh.back())
+ dfts[axis].execute(out + o, out + o, temp, cfalse);
+ }
+
+ matrix_transpose(out, out, shape{ sh.remove_back().product(), sh.back() });
+
+ if (axis == 0)
+ break;
+
+ sh = sh.rotate_right();
+ --axis;
+ }
+ }
+ KFR_INTRINSIC void execute_dim(ctrue_t, T* out_real, const complex<T>* in, u8* temp) const
+ {
+ shape<Dims> sh = complex_size();
+ index_t total = sh.product();
+ complex<T>* out = real_out_is_enough
+ ? ptr_cast<complex<T>>(out_real)
+ : ptr_cast<complex<T>>(temp + temp_size - total * sizeof(complex<T>));
+ index_t axis = 0;
+ for (;;)
+ {
+ matrix_transpose(out, in, shape{ sh.front(), sh.remove_front().product() });
+
+ if (size[axis] > 1)
+ {
+ if (axis == size.dims() - 1)
+ for (index_t o = 0; o < total; o += sh.front())
+ dft_real.execute(ptr_cast<T>(out + o), out + o, temp, ctrue);
+ else
+ for (index_t o = 0; o < total; o += sh.front())
+ dfts[axis].execute(out + o, out + o, temp, ctrue);
+ }
+
+ if (axis == size.dims() - 1)
+ break;
+
+ sh = sh.rotate_left();
+ in = out;
+ ++axis;
+ }
+ contract(out_real, ptr_cast<T>(out), size.remove_back().product(), size.back());
+ }
+ using dft_list = std::conditional_t<Dims == dynamic_shape, std::vector<dft_plan<T>>,
+ std::array<dft_plan<T>, const_max(Dims, 1) - 1>>;
+ dft_list dfts;
+ dft_plan_real<T> dft_real;
};
/// @brief DCT type 2 (unscaled)
diff --git a/include/kfr/dft/reference_dft.hpp b/include/kfr/dft/reference_dft.hpp
@@ -26,7 +26,6 @@
#pragma once
#include "../base/memory.hpp"
-#include "../base/small_buffer.hpp"
#include "../base/univector.hpp"
#include "../simd/complex.hpp"
#include "../simd/constants.hpp"
diff --git a/include/kfr/simd/impl/intrinsics.h b/include/kfr/simd/impl/intrinsics.h
@@ -38,12 +38,17 @@ CMT_INLINE void builtin_memcpy(void* dest, const void* src, size_t size)
{
__builtin_memcpy(dest, src, size);
}
+CMT_INLINE void builtin_memmove(void* dest, const void* src, size_t size)
+{
+ __builtin_memmove(dest, src, size);
+}
CMT_INLINE void builtin_memset(void* dest, int val, size_t size) { __builtin_memset(dest, val, size); }
#else
CMT_INLINE float builtin_sqrt(float x) { return ::sqrtf(x); }
CMT_INLINE double builtin_sqrt(double x) { return ::sqrt(x); }
CMT_INLINE long double builtin_sqrt(long double x) { return ::sqrtl(x); }
CMT_INLINE void builtin_memcpy(void* dest, const void* src, size_t size) { ::memcpy(dest, src, size); }
+CMT_INLINE void builtin_memmove(void* dest, const void* src, size_t size) { ::memmove(dest, src, size); }
CMT_INLINE void builtin_memset(void* dest, int val, size_t size) { ::memset(dest, val, size); }
#endif
diff --git a/src/dft/fft-impl.hpp b/src/dft/fft-impl.hpp
@@ -1759,7 +1759,8 @@ KFR_INTRINSIC void initialize_order(dft_plan<T>* self)
typename dft_plan<T>::bitset ored = self->disposition_inplace[0] | self->disposition_inplace[1] |
self->disposition_outofplace[0] | self->disposition_outofplace[1];
if (ored.any()) // if scratch needed
- self->temp_size += align_up(sizeof(complex<T>) * self->size, platform<>::native_cache_alignment);
+ self->temp_size +=
+ align_up(sizeof(complex<T>) * (self->size + 1), platform<>::native_cache_alignment);
}
template <typename T>
@@ -1812,6 +1813,7 @@ to_fmt(size_t real_size, const complex<T>* rtwiddle, complex<T>* out, const comp
constexpr size_t width = vector_width<T> * 2;
const cvec<T, 1> dc = cread<1>(in);
+ cvec<T, 1> inmid = cread<1>(in + csize / 2);
const size_t count = (csize + 1) / 2;
block_process(count - 1, csizes_t<width, 1>(),
@@ -1833,10 +1835,7 @@ to_fmt(size_t real_size, const complex<T>* rtwiddle, complex<T>* out, const comp
if (is_even(csize))
{
- size_t k = csize / 2;
- const cvec<T, 1> fpk = cread<1>(in + k);
- const cvec<T, 1> fpnk = negodd(fpk);
- cwrite<1>(out + k, fpnk);
+ cwrite<1>(out + csize / 2, negodd(inmid));
}
if (fmt == dft_pack_format::CCs)
{
@@ -1899,10 +1898,7 @@ void from_fmt(size_t real_size, complex<T>* rtwiddle, complex<T>* out, const com
});
if (is_even(csize))
{
- size_t k = csize / 2;
- const cvec<T, 1> fpk = inmid;
- const cvec<T, 1> fpnk = 2 * negodd(fpk);
- cwrite<1>(out + k, fpnk);
+ cwrite<1>(out + csize / 2, 2 * negodd(inmid));
}
cwrite<1>(out, dc);
}
@@ -1985,6 +1981,11 @@ public:
from_fmt(this->stage_size, ptr_cast<complex<T>>(this->data), out, in,
static_cast<dft_pack_format>(this->user));
}
+ void copy_input(bool invert, complex<T>* out, const complex<T>* in, size_t size) final
+ {
+ size_t extra = invert && static_cast<dft_pack_format>(this->user) == dft_pack_format::CCs ? 1 : 0;
+ builtin_memcpy(out, in, sizeof(complex<T>) * (size + extra));
+ }
};
} // namespace intrinsics
diff --git a/tests/dft_test.cpp b/tests/dft_test.cpp
@@ -175,38 +175,6 @@ constexpr size_t dft_stopsize = 257;
#endif
#endif
-TEST(fft_real)
-{
- using float_type = double;
- random_state gen = random_init(2247448713, 915890490, 864203735, 2982561);
-
- constexpr size_t size = 64;
-
- kfr::univector<float_type, size> in = gen_random_range<float_type>(gen, -1.0, +1.0);
- kfr::univector<kfr::complex<float_type>, size / 2 + 1> out = realdft(in);
- kfr::univector<float_type, size> rev = irealdft(out) / size;
- CHECK(rms(rev - in) <= 0.00001f);
-}
-
-#ifndef KFR_DFT_NO_NPo2
-TEST(fft_real_not_size_4N)
-{
- kfr::univector<double, 6> in = counter();
- auto out = realdft(in);
- kfr::univector<kfr::complex<double>> expected{ 15.0, { -3, 5.19615242 }, { -3, +1.73205081 }, -3.0 };
- CHECK(rms(cabs(out - expected)) <= 0.00001f);
- kfr::univector<double, 6> rev = irealdft(out) / 6;
- CHECK(rms(rev - in) <= 0.00001f);
-
- random_state gen = random_init(2247448713, 915890490, 864203735, 2982561);
- constexpr size_t size = 66;
- kfr::univector<double, size> in2 = gen_random_range<double>(gen, -1.0, +1.0);
- kfr::univector<kfr::complex<double>, size / 2 + 1> out2 = realdft(in2);
- kfr::univector<double, size> rev2 = irealdft(out2) / size;
- CHECK(rms(rev2 - in2) <= 0.00001f);
-}
-#endif
-
TEST(fft_accuracy)
{
#ifdef DEBUG_DFT_PROGRESS
@@ -229,66 +197,85 @@ TEST(fft_accuracy)
println(sizes);
#endif
- testo::matrix(named("type") = dft_float_types, //
- named("size") = sizes, //
- [&gen](auto type, size_t size)
- {
- using float_type = type_of<decltype(type)>;
- const double min_prec = 0.000001 * std::log(size) * size;
-
- for (bool inverse : { false, true })
- {
- testo::scope s(inverse ? "complex-inverse" : "complex-direct");
- univector<complex<float_type>> in =
- truncate(gen_random_range<float_type>(gen, -1.0, +1.0), size);
- univector<complex<float_type>> out = in;
- univector<complex<float_type>> refout = out;
- univector<complex<float_type>> outo = in;
- const dft_plan<float_type> dft(size);
- double min_prec2 = dft.arblen ? 2 * min_prec : min_prec;
- if (!inverse)
- {
+ testo::matrix(
+ named("type") = dft_float_types, //
+ named("size") = sizes, //
+ [&gen](auto type, size_t size)
+ {
+ using float_type = type_of<decltype(type)>;
+ const double min_prec = 0.000001 * std::log(size) * size;
+
+ for (bool inverse : { false, true })
+ {
+ testo::scope s(inverse ? "complex-inverse" : "complex-direct");
+ univector<complex<float_type>> in =
+ truncate(gen_random_range<float_type>(gen, -1.0, +1.0), size);
+ univector<complex<float_type>> out = in;
+ univector<complex<float_type>> refout = out;
+ univector<complex<float_type>> outo = in;
+ const dft_plan<float_type> dft(size);
+ double min_prec2 = dft.arblen ? 2 * min_prec : min_prec;
+ if (!inverse)
+ {
#if DEBUG_DFT_PROGRESS
- dft.dump();
+ dft.dump();
#endif
- }
- univector<u8> temp(dft.temp_size);
-
- reference_dft(refout.data(), in.data(), size, inverse);
- dft.execute(outo, in, temp, inverse);
- dft.execute(out, out, temp, inverse);
-
- const float_type rms_diff_inplace = rms(cabs(refout - out));
- CHECK(rms_diff_inplace <= min_prec2);
- const float_type rms_diff_outofplace = rms(cabs(refout - outo));
- CHECK(rms_diff_outofplace <= min_prec2);
- }
-
- if (is_even(size))
- {
- univector<float_type> in =
- truncate(gen_random_range<float_type>(gen, -1.0, +1.0), size);
-
- univector<complex<float_type>> out = truncate(dimensions<1>(scalar(qnan)), size);
- univector<complex<float_type>> refout = truncate(dimensions<1>(scalar(qnan)), size);
- const dft_plan_real<float_type> dft(size);
- univector<u8> temp(dft.temp_size);
-
- testo::scope s("real-direct");
- reference_dft(refout.data(), in.data(), size);
- dft.execute(out, in, temp);
- float_type rms_diff =
- rms(cabs(refout.truncate(size / 2 + 1) - out.truncate(size / 2 + 1)));
- CHECK(rms_diff <= min_prec);
-
- univector<float_type> out2(size, 0.f);
- s.text = "real-inverse";
- dft.execute(out2, out, temp);
- out2 = out2 / size;
- rms_diff = rms(in - out2);
- CHECK(rms_diff <= min_prec);
- }
- });
+ }
+ univector<u8> temp(dft.temp_size);
+
+ reference_dft(refout.data(), in.data(), size, inverse);
+ dft.execute(outo, in, temp, inverse);
+ dft.execute(out, out, temp, inverse);
+
+ const float_type rms_diff_inplace = rms(cabs(refout - out));
+ CHECK(rms_diff_inplace <= min_prec2);
+ const float_type rms_diff_outofplace = rms(cabs(refout - outo));
+ CHECK(rms_diff_outofplace <= min_prec2);
+ }
+
+ if (is_even(size))
+ {
+ index_t csize = dft_plan_real<float_type>::complex_size_for(size);
+ univector<float_type> in = truncate(gen_random_range<float_type>(gen, -1.0, +1.0), size);
+
+ univector<complex<float_type>> out = truncate(dimensions<1>(scalar(qnan)), csize);
+ univector<complex<float_type>> refout = truncate(dimensions<1>(scalar(qnan)), csize);
+ const dft_plan_real<float_type> dft(size);
+ univector<u8> temp(dft.temp_size);
+
+ {
+ testo::scope s("real-direct");
+ reference_dft(refout.data(), in.data(), size);
+ dft.execute(out, in, temp);
+ float_type rms_diff_outofplace = rms(cabs(refout - out));
+ CHECK(rms_diff_outofplace <= min_prec);
+
+ univector<complex<float_type>> outi(csize);
+ outi = padded(make_univector(ptr_cast<complex<float_type>>(in.data()), size / 2),
+ complex<float_type>{ 0.f });
+ dft.execute(outi.data(), ptr_cast<float_type>(outi.data()), temp.data());
+ float_type rms_diff_inplace = rms(cabs(refout - outi.truncate(csize)));
+ CHECK(rms_diff_inplace <= min_prec);
+ }
+
+ {
+ testo::scope s("real-inverse");
+ univector<float_type> out2(size, 0.f);
+ dft.execute(out2, out, temp);
+ out2 = out2 / size;
+ float_type rms_diff_outofplace = rms(in - out2);
+ CHECK(rms_diff_outofplace <= min_prec);
+
+ univector<float_type> outi(2 * csize);
+ outi = make_univector(ptr_cast<float_type>(out.data()), 2 * csize);
+
+ dft.execute(outi.data(), ptr_cast<complex<float_type>>(outi.data()), temp.data());
+ outi = outi / size;
+ float_type rms_diff_inplace = rms(in - outi.truncate(size));
+ CHECK(rms_diff_inplace <= min_prec);
+ }
+ }
+ });
}
TEST(dct)
@@ -319,6 +306,180 @@ TEST(dct)
CHECK(rms(refoutinv - outinv) < 0.00001f);
}
+
+template <typename T, index_t Dims, typename dft_type, typename dft_real_type>
+static void test_dft_md_t(random_state& gen, shape<Dims> shape)
+{
+ index_t size = shape.product();
+ testo::scope s(as_string("shape=", shape));
+
+ const double min_prec = 0.000002 * std::log(size) * size;
+
+ {
+ const dft_type dft(shape);
+#if DEBUG_DFT_PROGRESS
+ dft.dump();
+#endif
+ univector<complex<T>> in = truncate(gen_random_range<T>(gen, -1.0, +1.0), size);
+ for (bool inverse : { false, true })
+ {
+ testo::scope s(inverse ? "complex-inverse" : "complex-direct");
+ univector<complex<T>> out = in;
+ univector<complex<T>> refout = out;
+ univector<complex<T>> outo = in;
+ univector<u8> temp(dft.temp_size);
+
+ reference_dft_md(refout.data(), in.data(), shape, inverse);
+ dft.execute(outo.data(), in.data(), temp.data(), inverse);
+ dft.execute(out.data(), out.data(), temp.data(), inverse);
+
+ const T rms_diff_inplace = rms(cabs(refout - out));
+ CHECK(rms_diff_inplace <= min_prec);
+ const T rms_diff_outofplace = rms(cabs(refout - outo));
+ CHECK(rms_diff_outofplace <= min_prec);
+ }
+ }
+
+ if (is_even(shape.back()))
+ {
+        index_t csize = dft_plan_md_real<T, Dims>::complex_size_for(shape).product();
+ univector<T> in = truncate(gen_random_range<T>(gen, -1.0, +1.0), size);
+
+ univector<complex<T>> out = truncate(dimensions<1>(scalar(qnan)), csize);
+ univector<complex<T>> refout = truncate(dimensions<1>(scalar(qnan)), csize);
+ const dft_real_type dft(shape, true);
+#if DEBUG_DFT_PROGRESS
+ dft.dump();
+#endif
+ univector<u8> temp(dft.temp_size);
+
+ {
+ testo::scope s("real-direct");
+ reference_dft_md(refout.data(), in.data(), shape);
+ dft.execute(out.data(), in.data(), temp.data());
+ T rms_diff_outofplace = rms(cabs(refout - out));
+ CHECK(rms_diff_outofplace <= min_prec);
+
+ univector<complex<T>> outi(csize);
+ outi = padded(make_univector(ptr_cast<complex<T>>(in.data()), size / 2), complex<T>{ 0.f });
+ dft.execute(outi.data(), ptr_cast<T>(outi.data()), temp.data());
+ T rms_diff_inplace = rms(cabs(refout - outi));
+ CHECK(rms_diff_inplace <= min_prec);
+ }
+
+ {
+ testo::scope s("real-inverse");
+ univector<T> out2(dft.real_out_size(), 0.f);
+ dft.execute(out2.data(), out.data(), temp.data());
+ out2 = out2 / size;
+ T rms_diff_outofplace = rms(in - out2.truncate(size));
+ CHECK(rms_diff_outofplace <= min_prec);
+
+ univector<T> outi(2 * csize);
+ outi = make_univector(ptr_cast<T>(out.data()), 2 * csize);
+ dft.execute(outi.data(), ptr_cast<complex<T>>(outi.data()), temp.data());
+ outi = outi / size;
+ T rms_diff_inplace = rms(in - outi.truncate(size));
+ CHECK(rms_diff_inplace <= min_prec);
+ }
+ }
+}
+
+template <typename T, index_t Dims>
+static void test_dft_md(random_state& gen, shape<Dims> shape)
+{
+ {
+ testo::scope s("compile-time dims");
+ test_dft_md_t<T, Dims, dft_plan_md<T, Dims>, dft_plan_md_real<T, Dims>>(gen, shape);
+ }
+ {
+ testo::scope s("runtime dims");
+ test_dft_md_t<T, Dims, dft_plan_md<T, dynamic_shape>, dft_plan_md_real<T, dynamic_shape>>(gen, shape);
+ }
+}
+
+TEST(dft_md)
+{
+ random_state gen = random_init(2247448713, 915890490, 864203735, 2982561);
+
+ testo::matrix(named("type") = dft_float_types, //
+ [&gen](auto type)
+ {
+ using float_type = type_of<decltype(type)>;
+ test_dft_md<float_type>(gen, shape{ 120 });
+ test_dft_md<float_type>(gen, shape{ 2, 60 });
+ test_dft_md<float_type>(gen, shape{ 3, 40 });
+ test_dft_md<float_type>(gen, shape{ 4, 30 });
+ test_dft_md<float_type>(gen, shape{ 5, 24 });
+ test_dft_md<float_type>(gen, shape{ 6, 20 });
+ test_dft_md<float_type>(gen, shape{ 8, 15 });
+ test_dft_md<float_type>(gen, shape{ 10, 12 });
+ test_dft_md<float_type>(gen, shape{ 12, 10 });
+ test_dft_md<float_type>(gen, shape{ 15, 8 });
+ test_dft_md<float_type>(gen, shape{ 20, 6 });
+ test_dft_md<float_type>(gen, shape{ 24, 5 });
+ test_dft_md<float_type>(gen, shape{ 30, 4 });
+ test_dft_md<float_type>(gen, shape{ 40, 3 });
+ test_dft_md<float_type>(gen, shape{ 60, 2 });
+
+ test_dft_md<float_type>(gen, shape{ 2, 3, 24 });
+ test_dft_md<float_type>(gen, shape{ 12, 5, 2 });
+ test_dft_md<float_type>(gen, shape{ 5, 12, 2 });
+
+ test_dft_md<float_type>(gen, shape{ 2, 3, 2, 12 });
+ test_dft_md<float_type>(gen, shape{ 3, 4, 5, 2 });
+ test_dft_md<float_type>(gen, shape{ 5, 4, 3, 2 });
+
+ test_dft_md<float_type>(gen, shape{ 5, 2, 2, 3, 2 });
+ test_dft_md<float_type>(gen, shape{ 2, 5, 2, 2, 3 });
+
+ test_dft_md<float_type>(gen, shape{ 1, 120 });
+ test_dft_md<float_type>(gen, shape{ 120, 1 });
+ test_dft_md<float_type>(gen, shape{ 2, 1, 1, 60 });
+ test_dft_md<float_type>(gen, shape{ 1, 2, 10, 2, 1, 3 });
+
+ test_dft_md<float_type>(gen, shape{ 4, 4 });
+ test_dft_md<float_type>(gen, shape{ 4, 4, 4 });
+ test_dft_md<float_type>(gen, shape{ 4, 4, 4, 4 });
+ test_dft_md<float_type>(gen, shape{ 4, 4, 4, 4, 4 });
+ test_dft_md<float_type>(gen, shape{ 4, 4, 4, 4, 4, 4 });
+ test_dft_md<float_type>(gen, shape{ 4, 4, 4, 4, 4, 4, 4 });
+ test_dft_md<float_type>(gen, shape{ 4, 4, 4, 4, 4, 4, 4, 4 });
+ test_dft_md<float_type>(gen, shape{ 4, 4, 4, 4, 4, 4, 4, 4, 4 });
+ test_dft_md<float_type>(gen, shape{ 4, 4, 4, 4, 4, 4, 4, 4, 4, 4 });
+#if defined NDEBUG
+ test_dft_md<float_type>(gen, shape{ 512, 512 });
+ test_dft_md<float_type>(gen, shape{ 32, 32, 32 });
+ test_dft_md<float_type>(gen, shape{ 8, 8, 8, 8 });
+ test_dft_md<float_type>(gen, shape{ 2, 2, 2, 2, 2, 2 });
+
+ test_dft_md<float_type>(gen, shape{ 1, 65536 });
+ test_dft_md<float_type>(gen, shape{ 2, 65536 });
+ test_dft_md<float_type>(gen, shape{ 3, 65536 });
+ test_dft_md<float_type>(gen, shape{ 4, 65536 });
+ test_dft_md<float_type>(gen, shape{ 65536, 1 });
+ test_dft_md<float_type>(gen, shape{ 65536, 2 });
+ test_dft_md<float_type>(gen, shape{ 65536, 3 });
+ test_dft_md<float_type>(gen, shape{ 65536, 4 });
+
+ test_dft_md<float_type>(gen, shape{ 1, 2 });
+ test_dft_md<float_type>(gen, shape{ 1, 2, 3 });
+ test_dft_md<float_type>(gen, shape{ 1, 2, 3, 4 });
+ test_dft_md<float_type>(gen, shape{ 1, 2, 3, 4, 5 });
+ test_dft_md<float_type>(gen, shape{ 1, 2, 3, 4, 5, 6 });
+ test_dft_md<float_type>(gen, shape{ 1, 2, 3, 4, 5, 6, 7 });
+ test_dft_md<float_type>(gen, shape{ 1, 2, 3, 4, 5, 6, 7, 8 });
+ test_dft_md<float_type>(gen, shape{ 2, 1 });
+ test_dft_md<float_type>(gen, shape{ 3, 2, 1 });
+ test_dft_md<float_type>(gen, shape{ 4, 3, 2, 1 });
+ test_dft_md<float_type>(gen, shape{ 5, 4, 3, 2, 1 });
+ test_dft_md<float_type>(gen, shape{ 6, 5, 4, 3, 2, 1 });
+ test_dft_md<float_type>(gen, shape{ 7, 6, 5, 4, 3, 2, 1 });
+ test_dft_md<float_type>(gen, shape{ 8, 7, 6, 5, 4, 3, 2, 1 });
+#endif
+ });
+}
+
} // namespace CMT_ARCH_NAME
#ifndef KFR_NO_MAIN