kfr

Fast, modern C++ DSP framework, FFT, Sample Rate Conversion, FIR/IIR/Biquad Filters (SSE, AVX, AVX-512, ARM NEON)
Log | Files | Refs | README

commit a0a3ee4ca2258f2254e38e7cbe41c63de9f32295
parent f71d5430d6edfb99c346e7b1d020cffe00a48244
Author: d.levin256@gmail.com <d.levin256@gmail.com>
Date:   Tue, 16 Jan 2024 23:24:33 +0000

Multidimensional reference DFT

Diffstat:
Minclude/kfr/dft/reference_dft.hpp | 221++++++++++++++++++++++++++++++++++++++++---------------------------------------
1 file changed, 111 insertions(+), 110 deletions(-)

diff --git a/include/kfr/dft/reference_dft.hpp b/include/kfr/dft/reference_dft.hpp @@ -38,53 +38,43 @@ namespace kfr { -template <typename Tnumber = double> -void reference_fft_pass(Tnumber pi2, size_t N, size_t offset, size_t delta, int flag, Tnumber (*x)[2], - Tnumber (*X)[2], Tnumber (*XX)[2]) +namespace internal_generic { - KFR_LOGIC_CHECK(N >= 2, "reference_fft_pass: invalid N"); - const size_t N2 = N / 2; - using std::cos; - using std::sin; + +template <typename T> +void reference_dft_po2_pass(size_t N, int flag, const complex<T>* in, complex<T>* out, complex<T>* scratch, + size_t in_delta = 1, size_t out_delta = 1, size_t scratch_delta = 1) +{ + const T pi2 = c_pi<T, 2, 1>; + const size_t N2 = N / 2; + const complex<T> w = pi2 * complex<T>{ 0, -T(flag) }; if (N != 2) { - reference_fft_pass(pi2, N2, offset, 2 * delta, flag, x, XX, X); - reference_fft_pass(pi2, N2, offset + delta, 2 * delta, flag, x, XX, X); + reference_dft_po2_pass(N2, flag, in, scratch, out, 2 * in_delta, 2 * scratch_delta, 2 * out_delta); + reference_dft_po2_pass(N2, flag, in + in_delta, scratch + scratch_delta, out + out_delta, + 2 * in_delta, 2 * scratch_delta, 2 * out_delta); for (size_t k = 0; k < N2; k++) { - const size_t k00 = offset + k * delta; - const size_t k01 = k00 + N2 * delta; - const size_t k10 = offset + 2 * k * delta; - const size_t k11 = k10 + delta; - const Tnumber m = static_cast<Tnumber>(k) / N; - const Tnumber cs = cos(pi2 * m); - const Tnumber sn = flag * sin(pi2 * m); - const Tnumber tmp0 = cs * XX[k11][0] + sn * XX[k11][1]; - const Tnumber tmp1 = cs * XX[k11][1] - sn * XX[k11][0]; - X[k01][0] = XX[k10][0] - tmp0; - X[k01][1] = XX[k10][1] - tmp1; - X[k00][0] = XX[k10][0] + tmp0; - X[k00][1] = XX[k10][1] + tmp1; + const T m = static_cast<T>(k) / N; + const complex<T> tw = std::exp(w * m); + const complex<T> tmp = scratch[(2 * k + 1) * scratch_delta] * tw; + out[(k + N2) * out_delta] = scratch[(2 * k) * scratch_delta] - tmp; + out[(k)*out_delta] = scratch[(2 * k) * scratch_delta] + tmp; } } else { - const size_t k00 = offset; - const size_t k01 = k00 + delta; - X[k01][0] = x[k00][0] - x[k01][0]; - X[k01][1] = x[k00][1] - x[k01][1]; - X[k00][0] = x[k00][0] + x[k01][0]; - X[k00][1] = x[k00][1] + x[k01][1]; + out[out_delta] = in[0] - in[in_delta]; + out[0] = in[0] + in[in_delta]; } } -/// @brief Performs Complex FFT using reference implementation (slow, used for testing) -template <typename Tnumber = double, typename T> -void reference_fft(complex<T>* out, const complex<T>* in, size_t size, bool inversion = false) +template <typename T> +void reference_dft_po2(complex<T>* out, const complex<T>* in, size_t size, bool inversion, + size_t out_delta = 1, size_t in_delta = 1) { - using Tcmplx = Tnumber(*)[2]; if (size < 1) return; if (size == 1) @@ -92,114 +82,125 @@ void reference_fft(complex<T>* out, const complex<T>* in, size_t size, bool inve out[0] = in[0]; return; } - std::vector<complex<Tnumber>> datain(size); - std::vector<complex<Tnumber>> dataout(size); - std::vector<complex<Tnumber>> temp(size); - std::copy(in, in + size, datain.begin()); - const Tnumber pi2 = c_pi<Tnumber, 2, 1>; - reference_fft_pass<Tnumber>(pi2, size, 0, 1, inversion ? -1 : +1, Tcmplx(datain.data()), - Tcmplx(dataout.data()), Tcmplx(temp.data())); - std::copy(dataout.begin(), dataout.end(), out); + std::vector<complex<T>> temp(size); + reference_dft_po2_pass(size, inversion ? -1 : +1, in, out, temp.data(), in_delta, out_delta, 1); } -/// @brief Performs Complex DFT using reference implementation (slow, used for testing) -template <typename Tnumber = double, typename T> -void reference_dft(complex<T>* out, const complex<T>* in, size_t size, bool inversion = false) +/// @brief Performs Complex FFT using reference implementation (slow, used for testing) +template <typename T> +void reference_dft_nonpo2(complex<T>* out, const complex<T>* in, size_t size, bool inversion, + size_t out_delta = 1, size_t in_delta = 1) { - using std::cos; - using std::sin; - if (is_poweroftwo(size)) - { - return reference_fft<Tnumber>(out, in, size, inversion); - } - constexpr Tnumber pi2 = c_pi<Tnumber, 2>; + constexpr T pi2 = c_pi<T, 2>; + const complex<T> w = pi2 * complex<T>{ 0, T(inversion ? +1 : -1) }; if (size < 2) return; - std::vector<complex<T>> datain; - if (out == in) { - datain.resize(size); - std::copy_n(in, size, datain.begin()); - in = datain.data(); - } - { - Tnumber sumr = 0; - Tnumber sumi = 0; + complex<T> sum = 0; for (size_t j = 0; j < size; j++) - { - sumr += static_cast<Tnumber>(in[j].real()); - sumi += static_cast<Tnumber>(in[j].imag()); - } - out[0] = { static_cast<T>(sumr), static_cast<T>(sumi) }; + sum += in[j * in_delta]; + out[0] = sum; } for (size_t i = 1; i < size; i++) { - Tnumber sumr = static_cast<Tnumber>(in[0].real()); - Tnumber sumi = static_cast<Tnumber>(in[0].imag()); - + complex<T> sum = in[0]; for (size_t j = 1; j < size; j++) { - const Tnumber x = pi2 * ((i * j) % size) / size; - Tnumber twr = cos(x); - Tnumber twi = sin(x); - if (inversion) - twi = -twi; - - sumr += twr * static_cast<Tnumber>(in[j].real()) + twi * static_cast<Tnumber>(in[j].imag()); - sumi += twr * static_cast<Tnumber>(in[j].imag()) - twi * static_cast<Tnumber>(in[j].real()); - out[i] = { static_cast<T>(sumr), static_cast<T>(sumi) }; + complex<T> tw = std::exp(w * (static_cast<T>(i) * j / size)); + sum += tw * in[j * in_delta]; } + out[i * out_delta] = sum; } } +} // namespace internal_generic -/// @brief Performs Direct Real DFT using reference implementation (slow, used for testing) +/// @brief Performs Complex DFT using reference implementation (slow, used for testing) template <typename T> -void reference_dft(complex<T>* out, const T* in, size_t size) +void reference_dft(complex<T>* out, const complex<T>* in, size_t size, bool inversion = false, + size_t out_delta = 1, size_t in_delta = 1) { - if (size < 1) - return; - std::vector<complex<T>> datain(size); - std::copy(in, in + size, datain.begin()); - reference_dft(out, datain.data(), size, false); + if (in == out) + { + std::vector<complex<T>> tmpin(size); + for (int i = 0; i < size; ++i) + tmpin[i] = in[i * in_delta]; + return reference_dft(out, tmpin.data(), size, inversion, out_delta, 1); + } + if (is_poweroftwo(size)) + { + return internal_generic::reference_dft_po2(out, in, size, inversion, out_delta, in_delta); + } + else + { + return internal_generic::reference_dft_nonpo2(out, in, size, inversion, out_delta, in_delta); + } } -/// @brief Performs Inverse Real DFT using reference implementation (slow, used for testing) +/// @brief Performs Direct Real DFT using reference implementation (slow, used for testing) template <typename T> -void reference_dft(T* out, const complex<T>* in, size_t size) +void reference_dft(complex<T>* out, const T* in, size_t size, size_t out_delta = 1, size_t in_delta = 1) { if (size < 1) return; - std::vector<complex<T>> dataout(size); - reference_dft(dataout.data(), in, size, true); - for (size_t i = 0; i < size; i++) - out[i] = dataout[i].real(); -} - -/// @brief Performs DFT using reference implementation (slow, used for testing) -template <typename Tnumber = double, typename T> -inline univector<complex<T>> reference_dft(const univector<complex<T>>& in, bool inversion = false) -{ - univector<complex<T>> out(in.size()); - reference_dft(&out[0], &in[0], in.size(), inversion); - return out; + std::vector<complex<T>> tmpin(size); + for (index_t i = 0; i < size; ++i) + tmpin[i] = in[i * in_delta]; + std::vector<complex<T>> tmpout(size); + reference_dft(tmpout.data(), tmpin.data(), size, false, 1, 1); + for (index_t i = 0; i < size / 2 + 1; i++) + out[i * out_delta] = tmpout[i]; } +/// @brief Performs Multidimensional Complex DFT using reference implementation (slow, used for testing) template <typename T> -struct reference_dft_plan +void reference_dft_md(complex<T>* out, const complex<T>* in, shape<dynamic_shape> size, + bool inversion = false, size_t out_delta = 1, size_t in_delta = 1) { - reference_dft_plan(size_t size) : size(size) {} - void execute(complex<T>* out, const complex<T>* in, u8*, bool inverse = false) const + index_t total = size.product(); + if (total < 1) + return; + if (total == 1) { - reference_dft(out, in, size, inverse); + out[0] = in[0]; + return; } - - template <size_t N, size_t N2> - void execute(univector<complex<T>, N>& out, const univector<const complex<T>, N>& in, univector<u8, N2>&, - bool inverse = false) const + index_t inner = 1; + index_t outer = total; + for (int axis = size.dims() - 1; axis >= 0; --axis) { - this->execute(out.data(), in.data(), nullptr, inverse); + index_t d = size[axis]; + outer /= d; + for (index_t o = 0; o < outer; ++o) + { + for (index_t i = 0; i < inner; ++i) + { + reference_dft(out + (i + o * inner * d) * out_delta, in + (i + o * inner * d) * in_delta, d, + inversion, out_delta * inner, in_delta * inner); + } + } + in = out; + in_delta = out_delta; + inner *= d; } - static constexpr size_t temp_size = 0; - const size_t size; -}; +} + +/// @brief Performs Multidimensional Direct Real DFT using reference implementation (slow, used for testing) +template <typename T> +void reference_dft_md(complex<T>* out, const T* in, shape<dynamic_shape> shape, bool inversion = false, + size_t out_delta = 1, size_t in_delta = 1) +{ + index_t size = shape.product(); + if (size < 1) + return; + std::vector<complex<T>> tmpin(size); + for (index_t i = 0; i < size; ++i) + tmpin[i] = in[i * in_delta]; + std::vector<complex<T>> tmpout(size); + reference_dft_md(tmpout.data(), tmpin.data(), shape, inversion, 1, 1); + index_t last = shape.back() / 2 + 1; + for (index_t i = 0; i < std::max(index_t(1), shape.remove_back().product()); ++i) + for (index_t j = 0; j < last; j++) + out[(i * last + j) * out_delta] = tmpout[i * shape.back() + j]; +} + } // namespace kfr