commit a0a3ee4ca2258f2254e38e7cbe41c63de9f32295
parent f71d5430d6edfb99c346e7b1d020cffe00a48244
Author: d.levin256@gmail.com <d.levin256@gmail.com>
Date: Tue, 16 Jan 2024 23:24:33 +0000
Multidimensional reference DFT
Diffstat:
1 file changed, 111 insertions(+), 110 deletions(-)
diff --git a/include/kfr/dft/reference_dft.hpp b/include/kfr/dft/reference_dft.hpp
@@ -38,53 +38,43 @@
namespace kfr
{
-template <typename Tnumber = double>
-void reference_fft_pass(Tnumber pi2, size_t N, size_t offset, size_t delta, int flag, Tnumber (*x)[2],
- Tnumber (*X)[2], Tnumber (*XX)[2])
+namespace internal_generic
{
- KFR_LOGIC_CHECK(N >= 2, "reference_fft_pass: invalid N");
- const size_t N2 = N / 2;
- using std::cos;
- using std::sin;
+
+template <typename T>
+void reference_dft_po2_pass(size_t N, int flag, const complex<T>* in, complex<T>* out, complex<T>* scratch,
+ size_t in_delta = 1, size_t out_delta = 1, size_t scratch_delta = 1)
+{
+ const T pi2 = c_pi<T, 2, 1>;
+ const size_t N2 = N / 2;
+ const complex<T> w = pi2 * complex<T>{ 0, -T(flag) };
if (N != 2)
{
- reference_fft_pass(pi2, N2, offset, 2 * delta, flag, x, XX, X);
- reference_fft_pass(pi2, N2, offset + delta, 2 * delta, flag, x, XX, X);
+ reference_dft_po2_pass(N2, flag, in, scratch, out, 2 * in_delta, 2 * scratch_delta, 2 * out_delta);
+ reference_dft_po2_pass(N2, flag, in + in_delta, scratch + scratch_delta, out + out_delta,
+ 2 * in_delta, 2 * scratch_delta, 2 * out_delta);
for (size_t k = 0; k < N2; k++)
{
- const size_t k00 = offset + k * delta;
- const size_t k01 = k00 + N2 * delta;
- const size_t k10 = offset + 2 * k * delta;
- const size_t k11 = k10 + delta;
- const Tnumber m = static_cast<Tnumber>(k) / N;
- const Tnumber cs = cos(pi2 * m);
- const Tnumber sn = flag * sin(pi2 * m);
- const Tnumber tmp0 = cs * XX[k11][0] + sn * XX[k11][1];
- const Tnumber tmp1 = cs * XX[k11][1] - sn * XX[k11][0];
- X[k01][0] = XX[k10][0] - tmp0;
- X[k01][1] = XX[k10][1] - tmp1;
- X[k00][0] = XX[k10][0] + tmp0;
- X[k00][1] = XX[k10][1] + tmp1;
+ const T m = static_cast<T>(k) / N;
+ const complex<T> tw = std::exp(w * m);
+ const complex<T> tmp = scratch[(2 * k + 1) * scratch_delta] * tw;
+ out[(k + N2) * out_delta] = scratch[(2 * k) * scratch_delta] - tmp;
+ out[(k)*out_delta] = scratch[(2 * k) * scratch_delta] + tmp;
}
}
else
{
- const size_t k00 = offset;
- const size_t k01 = k00 + delta;
- X[k01][0] = x[k00][0] - x[k01][0];
- X[k01][1] = x[k00][1] - x[k01][1];
- X[k00][0] = x[k00][0] + x[k01][0];
- X[k00][1] = x[k00][1] + x[k01][1];
+ out[out_delta] = in[0] - in[in_delta];
+ out[0] = in[0] + in[in_delta];
}
}
-/// @brief Performs Complex FFT using reference implementation (slow, used for testing)
-template <typename Tnumber = double, typename T>
-void reference_fft(complex<T>* out, const complex<T>* in, size_t size, bool inversion = false)
+template <typename T>
+void reference_dft_po2(complex<T>* out, const complex<T>* in, size_t size, bool inversion,
+ size_t out_delta = 1, size_t in_delta = 1)
{
- using Tcmplx = Tnumber(*)[2];
if (size < 1)
return;
if (size == 1)
@@ -92,114 +82,125 @@ void reference_fft(complex<T>* out, const complex<T>* in, size_t size, bool inve
out[0] = in[0];
return;
}
- std::vector<complex<Tnumber>> datain(size);
- std::vector<complex<Tnumber>> dataout(size);
- std::vector<complex<Tnumber>> temp(size);
- std::copy(in, in + size, datain.begin());
- const Tnumber pi2 = c_pi<Tnumber, 2, 1>;
- reference_fft_pass<Tnumber>(pi2, size, 0, 1, inversion ? -1 : +1, Tcmplx(datain.data()),
- Tcmplx(dataout.data()), Tcmplx(temp.data()));
- std::copy(dataout.begin(), dataout.end(), out);
+ std::vector<complex<T>> temp(size);
+ reference_dft_po2_pass(size, inversion ? -1 : +1, in, out, temp.data(), in_delta, out_delta, 1);
}
-/// @brief Performs Complex DFT using reference implementation (slow, used for testing)
-template <typename Tnumber = double, typename T>
-void reference_dft(complex<T>* out, const complex<T>* in, size_t size, bool inversion = false)
+/// @brief Performs Complex FFT using reference implementation (slow, used for testing)
+template <typename T>
+void reference_dft_nonpo2(complex<T>* out, const complex<T>* in, size_t size, bool inversion,
+ size_t out_delta = 1, size_t in_delta = 1)
{
- using std::cos;
- using std::sin;
- if (is_poweroftwo(size))
- {
- return reference_fft<Tnumber>(out, in, size, inversion);
- }
- constexpr Tnumber pi2 = c_pi<Tnumber, 2>;
+ constexpr T pi2 = c_pi<T, 2>;
+ const complex<T> w = pi2 * complex<T>{ 0, T(inversion ? +1 : -1) };
if (size < 2)
return;
- std::vector<complex<T>> datain;
- if (out == in)
{
- datain.resize(size);
- std::copy_n(in, size, datain.begin());
- in = datain.data();
- }
- {
- Tnumber sumr = 0;
- Tnumber sumi = 0;
+ complex<T> sum = 0;
for (size_t j = 0; j < size; j++)
- {
- sumr += static_cast<Tnumber>(in[j].real());
- sumi += static_cast<Tnumber>(in[j].imag());
- }
- out[0] = { static_cast<T>(sumr), static_cast<T>(sumi) };
+ sum += in[j * in_delta];
+ out[0] = sum;
}
for (size_t i = 1; i < size; i++)
{
- Tnumber sumr = static_cast<Tnumber>(in[0].real());
- Tnumber sumi = static_cast<Tnumber>(in[0].imag());
-
+ complex<T> sum = in[0];
for (size_t j = 1; j < size; j++)
{
- const Tnumber x = pi2 * ((i * j) % size) / size;
- Tnumber twr = cos(x);
- Tnumber twi = sin(x);
- if (inversion)
- twi = -twi;
-
- sumr += twr * static_cast<Tnumber>(in[j].real()) + twi * static_cast<Tnumber>(in[j].imag());
- sumi += twr * static_cast<Tnumber>(in[j].imag()) - twi * static_cast<Tnumber>(in[j].real());
- out[i] = { static_cast<T>(sumr), static_cast<T>(sumi) };
+ complex<T> tw = std::exp(w * (static_cast<T>(i) * j / size));
+ sum += tw * in[j * in_delta];
}
+ out[i * out_delta] = sum;
}
}
+} // namespace internal_generic
-/// @brief Performs Direct Real DFT using reference implementation (slow, used for testing)
+/// @brief Performs Complex DFT using reference implementation (slow, used for testing)
template <typename T>
-void reference_dft(complex<T>* out, const T* in, size_t size)
+void reference_dft(complex<T>* out, const complex<T>* in, size_t size, bool inversion = false,
+ size_t out_delta = 1, size_t in_delta = 1)
{
- if (size < 1)
- return;
- std::vector<complex<T>> datain(size);
- std::copy(in, in + size, datain.begin());
- reference_dft(out, datain.data(), size, false);
+ if (in == out)
+ {
+ std::vector<complex<T>> tmpin(size);
+ for (int i = 0; i < size; ++i)
+ tmpin[i] = in[i * in_delta];
+ return reference_dft(out, tmpin.data(), size, inversion, out_delta, 1);
+ }
+ if (is_poweroftwo(size))
+ {
+ return internal_generic::reference_dft_po2(out, in, size, inversion, out_delta, in_delta);
+ }
+ else
+ {
+ return internal_generic::reference_dft_nonpo2(out, in, size, inversion, out_delta, in_delta);
+ }
}
-/// @brief Performs Inverse Real DFT using reference implementation (slow, used for testing)
+/// @brief Performs Direct Real DFT using reference implementation (slow, used for testing)
template <typename T>
-void reference_dft(T* out, const complex<T>* in, size_t size)
+void reference_dft(complex<T>* out, const T* in, size_t size, size_t out_delta = 1, size_t in_delta = 1)
{
if (size < 1)
return;
- std::vector<complex<T>> dataout(size);
- reference_dft(dataout.data(), in, size, true);
- for (size_t i = 0; i < size; i++)
- out[i] = dataout[i].real();
-}
-
-/// @brief Performs DFT using reference implementation (slow, used for testing)
-template <typename Tnumber = double, typename T>
-inline univector<complex<T>> reference_dft(const univector<complex<T>>& in, bool inversion = false)
-{
- univector<complex<T>> out(in.size());
- reference_dft(&out[0], &in[0], in.size(), inversion);
- return out;
+ std::vector<complex<T>> tmpin(size);
+ for (index_t i = 0; i < size; ++i)
+ tmpin[i] = in[i * in_delta];
+ std::vector<complex<T>> tmpout(size);
+ reference_dft(tmpout.data(), tmpin.data(), size, false, 1, 1);
+ for (index_t i = 0; i < size / 2 + 1; i++)
+ out[i * out_delta] = tmpout[i];
}
+/// @brief Performs Multidimensional Complex DFT using reference implementation (slow, used for testing)
template <typename T>
-struct reference_dft_plan
+void reference_dft_md(complex<T>* out, const complex<T>* in, shape<dynamic_shape> size,
+ bool inversion = false, size_t out_delta = 1, size_t in_delta = 1)
{
- reference_dft_plan(size_t size) : size(size) {}
- void execute(complex<T>* out, const complex<T>* in, u8*, bool inverse = false) const
+ index_t total = size.product();
+ if (total < 1)
+ return;
+ if (total == 1)
{
- reference_dft(out, in, size, inverse);
+ out[0] = in[0];
+ return;
}
-
- template <size_t N, size_t N2>
- void execute(univector<complex<T>, N>& out, const univector<const complex<T>, N>& in, univector<u8, N2>&,
- bool inverse = false) const
+ index_t inner = 1;
+ index_t outer = total;
+ for (int axis = size.dims() - 1; axis >= 0; --axis)
{
- this->execute(out.data(), in.data(), nullptr, inverse);
+ index_t d = size[axis];
+ outer /= d;
+ for (index_t o = 0; o < outer; ++o)
+ {
+ for (index_t i = 0; i < inner; ++i)
+ {
+ reference_dft(out + (i + o * inner * d) * out_delta, in + (i + o * inner * d) * in_delta, d,
+ inversion, out_delta * inner, in_delta * inner);
+ }
+ }
+ in = out;
+ in_delta = out_delta;
+ inner *= d;
}
- static constexpr size_t temp_size = 0;
- const size_t size;
-};
+}
+
+/// @brief Performs Multidimensional Direct Real DFT using reference implementation (slow, used for testing)
+template <typename T>
+void reference_dft_md(complex<T>* out, const T* in, shape<dynamic_shape> shape, bool inversion = false,
+ size_t out_delta = 1, size_t in_delta = 1)
+{
+ index_t size = shape.product();
+ if (size < 1)
+ return;
+ std::vector<complex<T>> tmpin(size);
+ for (index_t i = 0; i < size; ++i)
+ tmpin[i] = in[i * in_delta];
+ std::vector<complex<T>> tmpout(size);
+ reference_dft_md(tmpout.data(), tmpin.data(), shape, inversion, 1, 1);
+ index_t last = shape.back() / 2 + 1;
+ for (index_t i = 0; i < std::max(index_t(1), shape.remove_back().product()); ++i)
+ for (index_t j = 0; j < last; j++)
+ out[(i * last + j) * out_delta] = tmpout[i * shape.back() + j];
+}
+
} // namespace kfr