commit 86d0df769de84a2638ea720421a4d891acaf5740
parent a87cb75207e8860d1252308c6f4c7e5d925ed592
Author: d.levin256@gmail.com <d.levin256@gmail.com>
Date: Wed, 9 Nov 2022 15:45:14 +0000
Fix DFT fallback algorithm tests
Diffstat:
3 files changed, 8 insertions(+), 5 deletions(-)
diff --git a/include/kfr/dft/fft.hpp b/include/kfr/dft/fft.hpp
@@ -139,7 +139,7 @@ struct dft_plan
#ifdef KFR_DFT_MULTI
explicit dft_plan(cpu_t cpu, size_t size, dft_order order = dft_order::normal)
- : size(size), temp_size(0), data_size(0)
+ : size(size), temp_size(0), data_size(0), arblen(false)
{
if (cpu == cpu_t::runtime)
cpu = get_cpu();
@@ -168,7 +168,7 @@ struct dft_plan
}
#else
explicit dft_plan(size_t size, dft_order order = dft_order::normal)
- : size(size), temp_size(0), data_size(0)
+ : size(size), temp_size(0), data_size(0), arblen(false)
{
dft_initialize(*this);
}
@@ -217,13 +217,14 @@ struct dft_plan
autofree<u8> data;
size_t data_size;
std::vector<dft_stage_ptr<T>> stages;
+ bool arblen;
protected:
struct noinit
{
};
explicit dft_plan(noinit, size_t size, dft_order order = dft_order::normal)
- : size(size), temp_size(0), data_size(0)
+ : size(size), temp_size(0), data_size(0), arblen(false)
{
}
const complex<T>* select_in(size_t stage, const complex<T>* out, const complex<T>* in,
diff --git a/include/kfr/dft/impl/dft-impl.hpp b/include/kfr/dft/impl/dft-impl.hpp
@@ -515,6 +515,7 @@ void init_dft(dft_plan<T>* self, size_t size, dft_order)
if (cur_size >= 101)
{
add_stage<intrinsics::dft_arblen_stage_impl<T>>(self, size);
+ self->arblen = true;
}
else
{
diff --git a/tests/dft_test.cpp b/tests/dft_test.cpp
@@ -197,6 +197,7 @@ TEST(fft_accuracy)
univector<complex<float_type>> refout = out;
univector<complex<float_type>> outo = in;
const dft_plan<float_type> dft(size);
+ double min_prec2 = dft.arblen ? 2 * min_prec : min_prec;
if (!inverse)
{
#if DEBUG_DFT_PROGRESS
@@ -210,9 +211,9 @@ TEST(fft_accuracy)
dft.execute(out, out, temp, inverse);
const float_type rms_diff_inplace = rms(cabs(refout - out));
- CHECK(rms_diff_inplace < min_prec);
+ CHECK(rms_diff_inplace < min_prec2);
const float_type rms_diff_outofplace = rms(cabs(refout - outo));
- CHECK(rms_diff_outofplace < min_prec);
+ CHECK(rms_diff_outofplace < min_prec2);
}
if (size >= 4 && is_poweroftwo(size))