commit 1481b43e32b3cb95ccc69bb6291d92e4e0dd485d
parent 79e9543e7bdb3110e1efbe1fee5bae381510e831
Author: d.levin256@gmail.com <d.levin256@gmail.com>
Date: Sun, 26 Nov 2023 10:21:40 +0000
read_group: drop the compile-time-stride overloads of cread_group/cwrite_group and route the runtime-stride versions through new read_group/write_group helpers in simd/read_write.hpp
Diffstat:
3 files changed, 50 insertions(+), 57 deletions(-)
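
This commit removes the overloads of cread_group/cwrite_group that bake the stride into the template arguments, keeping only the runtime-stride overloads, whose implementation now lives in kfr/simd/read_write.hpp as generic read_group/write_group helpers. A condensed sketch of the two declarations involved, reduced to the parts visible in the hunks below:

    // removed overload: stride is a template parameter
    template <size_t count, size_t N, size_t stride, bool A = false, typename T>
    cvec<T, count * N> cread_group(const complex<T>* src);

    // kept overload: stride is a runtime argument, now forwarding to internal::read_group_impl
    template <size_t count, size_t N, bool A = false, typename T>
    cvec<T, count * N> cread_group(const complex<T>* src, size_t stride);
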
diff --git a/include/kfr/dft/impl/bitrev.hpp b/include/kfr/dft/impl/bitrev.hpp
@@ -104,13 +104,13 @@ KFR_INTRINSIC void fft_reorder_swap_two(T* inout, size_t i, size_t j)
constexpr size_t N = 1 << log2n;
constexpr size_t N4 = 2 * N / 4;
- cxx vi = cread_group<4, 4, N4 / 2, fft_reorder_aligned>(ptr_cast<complex<T>>(inout + i));
- cxx vj = cread_group<4, 4, N4 / 2, fft_reorder_aligned>(ptr_cast<complex<T>>(inout + j));
+ cxx vi = cread_group<4, 4, fft_reorder_aligned>(ptr_cast<complex<T>>(inout + i), N4 / 2);
+ cxx vj = cread_group<4, 4, fft_reorder_aligned>(ptr_cast<complex<T>>(inout + j), N4 / 2);
vi = digitreverse<bitrev, 2>(vi);
- cwrite_group<4, 4, N4 / 2, fft_reorder_aligned>(ptr_cast<complex<T>>(inout + i), vi);
+ cwrite_group<4, 4, fft_reorder_aligned>(ptr_cast<complex<T>>(inout + i), N4 / 2, vi);
vj = digitreverse<bitrev, 2>(vj);
- cwrite_group<4, 4, N4 / 2, fft_reorder_aligned>(ptr_cast<complex<T>>(inout + j), vj);
+ cwrite_group<4, 4, fft_reorder_aligned>(ptr_cast<complex<T>>(inout + j), N4 / 2, vj);
}
template <size_t log2n, size_t bitrev, typename T>
@@ -121,13 +121,13 @@ KFR_INTRINSIC void fft_reorder_swap(T* inout, size_t i, size_t j)
constexpr size_t N = 1 << log2n;
constexpr size_t N4 = 2 * N / 4;
- cxx vi = cread_group<4, 4, N4 / 2, fft_reorder_aligned>(ptr_cast<complex<T>>(inout + i));
- cxx vj = cread_group<4, 4, N4 / 2, fft_reorder_aligned>(ptr_cast<complex<T>>(inout + j));
+ cxx vi = cread_group<4, 4, fft_reorder_aligned>(ptr_cast<complex<T>>(inout + i), N4 / 2);
+ cxx vj = cread_group<4, 4, fft_reorder_aligned>(ptr_cast<complex<T>>(inout + j), N4 / 2);
vi = digitreverse<bitrev, 2>(vi);
- cwrite_group<4, 4, N4 / 2, fft_reorder_aligned>(ptr_cast<complex<T>>(inout + j), vi);
+ cwrite_group<4, 4, fft_reorder_aligned>(ptr_cast<complex<T>>(inout + j), N4 / 2, vi);
vj = digitreverse<bitrev, 2>(vj);
- cwrite_group<4, 4, N4 / 2, fft_reorder_aligned>(ptr_cast<complex<T>>(inout + i), vj);
+ cwrite_group<4, 4, fft_reorder_aligned>(ptr_cast<complex<T>>(inout + i), N4 / 2, vj);
}
template <size_t log2n, size_t bitrev, typename T>
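
The call sites in bitrev.hpp change only in how the stride is passed; the data gathered is the same. Conceptually, cread_group<4, 4, A>(p, stride) concatenates four blocks of four complex values spaced stride complex elements apart (a sketch in terms of the cread helper from ft.hpp, not the actual implementation path):

    // roughly what cread_group<4, 4, A>(p, stride) yields
    cvec<T, 16> v = concat(cread<4, A>(p + 0 * stride),
                           cread<4, A>(p + 1 * stride),
                           cread<4, A>(p + 2 * stride),
                           cread<4, A>(p + 3 * stride));
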
diff --git a/include/kfr/dft/impl/ft.hpp b/include/kfr/dft/impl/ft.hpp
@@ -128,51 +128,16 @@ KFR_INTRINSIC void cwrite(complex<T>* dest, const cvec<T, N>& value)
value.write(ptr_cast<T>(dest), cbool_t<A>());
}
-template <size_t count, size_t N, size_t stride, bool A, typename T, size_t... indices>
-KFR_INTRINSIC cvec<T, count * N> cread_group_impl(const complex<T>* src, csizes_t<indices...>)
-{
- return concat(read(cbool<A>, csize<N * 2>, ptr_cast<T>(src + stride * indices))...);
-}
-template <size_t count, size_t N, size_t stride, bool A, typename T, size_t... indices>
-KFR_INTRINSIC void cwrite_group_impl(complex<T>* dest, const cvec<T, count * N>& value, csizes_t<indices...>)
-{
- swallow{ (write<A>(ptr_cast<T>(dest + stride * indices), slice<indices * N * 2, N * 2>(value)), 0)... };
-}
-
-template <size_t count, size_t N, bool A, typename T, size_t... indices>
-KFR_INTRINSIC cvec<T, count * N> cread_group_impl(const complex<T>* src, size_t stride, csizes_t<indices...>)
-{
- return concat(read(cbool<A>, csize<N * 2>, ptr_cast<T>(src + stride * indices))...);
-}
-template <size_t count, size_t N, bool A, typename T, size_t... indices>
-KFR_INTRINSIC void cwrite_group_impl(complex<T>* dest, size_t stride, const cvec<T, count * N>& value,
- csizes_t<indices...>)
-{
- swallow{ (write<A>(ptr_cast<T>(dest + stride * indices), slice<indices * N * 2, N * 2>(value)), 0)... };
-}
-
-template <size_t count, size_t N, size_t stride, bool A = false, typename T>
-KFR_INTRINSIC cvec<T, count * N> cread_group(const complex<T>* src)
-{
- return cread_group_impl<count, N, stride, A>(src, csizeseq_t<count>());
-}
-
-template <size_t count, size_t N, size_t stride, bool A = false, typename T>
-KFR_INTRINSIC void cwrite_group(complex<T>* dest, const cvec<T, count * N>& value)
-{
- return cwrite_group_impl<count, N, stride, A>(dest, value, csizeseq_t<count>());
-}
-
template <size_t count, size_t N, bool A = false, typename T>
KFR_INTRINSIC cvec<T, count * N> cread_group(const complex<T>* src, size_t stride)
{
- return cread_group_impl<count, N, A>(src, stride, csizeseq_t<count>());
+ return internal::read_group_impl<2, count, N, A>(ptr_cast<T>(src), stride, csizeseq_t<count>());
}
template <size_t count, size_t N, bool A = false, typename T>
KFR_INTRINSIC void cwrite_group(complex<T>* dest, size_t stride, const cvec<T, count * N>& value)
{
- return cwrite_group_impl<count, N, A>(dest, stride, value, csizeseq_t<count>());
+ return internal::write_group_impl<2, count, N, A>(ptr_cast<T>(dest), stride, value, csizeseq_t<count>());
}
template <size_t N, bool A = false, bool split = false, typename T>
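
The complex wrappers now delegate to the shared scalar helpers with group = 2, so each complex<T> contributes two T lanes and the pointer arithmetic moves from complex units to scalar units. For a single block index the two formulations load the same memory (a sketch using the names from this hunk):

    // old cread_group_impl: offset in complex<T> units, then cast
    read(cbool<A>, csize<N * 2>, ptr_cast<T>(src + stride * index));

    // new internal::read_group_impl with group = 2: cast first, offset in T units
    intrinsics::read(cbool<A>, csize<N * 2>, ptr_cast<T>(src) + 2 * stride * index);
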
@@ -844,23 +809,23 @@ KFR_INTRINSIC void butterfly64_memory(cbool_t<inverse>, cbool_t<aligned>, comple
{
cvec<T, 16> w0, w1, w2, w3;
- w0 = cread_group<4, 4, 16, aligned>(
- in); // concat(cread<4>(in + 0), cread<4>(in + 16), cread<4>(in + 32), cread<4>(in + 48));
+ w0 = cread_group<4, 4, aligned>(
+ in, 16); // concat(cread<4>(in + 0), cread<4>(in + 16), cread<4>(in + 32), cread<4>(in + 48));
butterfly4_packed<4, inverse>(w0);
apply_twiddles4<0, 1, 4, inverse>(w0);
- w1 = cread_group<4, 4, 16, aligned>(
- in + 4); // concat(cread<4>(in + 4), cread<4>(in + 20), cread<4>(in + 36), cread<4>(in + 52));
+ w1 = cread_group<4, 4, aligned>(
+ in + 4, 16); // concat(cread<4>(in + 4), cread<4>(in + 20), cread<4>(in + 36), cread<4>(in + 52));
butterfly4_packed<4, inverse>(w1);
apply_twiddles4<4, 1, 4, inverse>(w1);
- w2 = cread_group<4, 4, 16, aligned>(
- in + 8); // concat(cread<4>(in + 8), cread<4>(in + 24), cread<4>(in + 40), cread<4>(in + 56));
+ w2 = cread_group<4, 4, aligned>(
+ in + 8, 16); // concat(cread<4>(in + 8), cread<4>(in + 24), cread<4>(in + 40), cread<4>(in + 56));
butterfly4_packed<4, inverse>(w2);
apply_twiddles4<8, 1, 4, inverse>(w2);
- w3 = cread_group<4, 4, 16, aligned>(
- in + 12); // concat(cread<4>(in + 12), cread<4>(in + 28), cread<4>(in + 44), cread<4>(in + 60));
+ w3 = cread_group<4, 4, aligned>(
+ in + 12, 16); // concat(cread<4>(in + 12), cread<4>(in + 28), cread<4>(in + 44), cread<4>(in + 60));
butterfly4_packed<4, inverse>(w3);
apply_twiddles4<12, 1, 4, inverse>(w3);
@@ -883,16 +848,16 @@ KFR_INTRINSIC void butterfly64_memory(cbool_t<inverse>, cbool_t<aligned>, comple
// pass 3:
butterfly4_packed<4, inverse>(w3);
- cwrite_group<4, 4, 16, aligned>(out + 12, w3); // split(w3, out[3], out[7], out[11], out[15]);
+ cwrite_group<4, 4, aligned>(out + 12, 16, w3); // split(w3, out[3], out[7], out[11], out[15]);
butterfly4_packed<4, inverse>(w2);
- cwrite_group<4, 4, 16, aligned>(out + 8, w2); // split(w2, out[2], out[6], out[10], out[14]);
+ cwrite_group<4, 4, aligned>(out + 8, 16, w2); // split(w2, out[2], out[6], out[10], out[14]);
butterfly4_packed<4, inverse>(w1);
- cwrite_group<4, 4, 16, aligned>(out + 4, w1); // split(w1, out[1], out[5], out[9], out[13]);
+ cwrite_group<4, 4, aligned>(out + 4, 16, w1); // split(w1, out[1], out[5], out[9], out[13]);
butterfly4_packed<4, inverse>(w0);
- cwrite_group<4, 4, 16, aligned>(out, w0); // split(w0, out[0], out[4], out[8], out[12]);
+ cwrite_group<4, 4, aligned>(out, 16, w0); // split(w0, out[0], out[4], out[8], out[12]);
}
template <bool inverse = false, typename T>
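
In butterfly64_memory the fixed stride of 16 complex elements simply moves into the argument list, keeping the read/write pair symmetric. A condensed sketch of the new call shape for one row, with the butterfly and twiddle steps shown above elided:

    cvec<T, 16> w0 = cread_group<4, 4, aligned>(in, 16);  // gather four blocks of 4, 16 apart
    // ... butterfly4_packed / apply_twiddles4 ...
    cwrite_group<4, 4, aligned>(out, 16, w0);             // scatter the four blocks back, same stride
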
diff --git a/include/kfr/simd/read_write.hpp b/include/kfr/simd/read_write.hpp
@@ -46,6 +46,34 @@ KFR_INTRINSIC void write(T* dest, const vec<T, N>& value)
intrinsics::write(cbool<A>, ptr_cast<deep_subtype<T>>(dest), value.flatten());
}
+namespace internal
+{
+template <size_t group, size_t count, size_t N, bool A, typename T, size_t... indices>
+KFR_INTRINSIC vec<T, group * count * N> read_group_impl(const T* src, size_t stride, csizes_t<indices...>)
+{
+ return concat(intrinsics::read(cbool<A>, csize<N * group>, src + group * stride * indices)...);
+}
+template <size_t group, size_t count, size_t N, bool A, typename T, size_t... indices>
+KFR_INTRINSIC void write_group_impl(T* dest, size_t stride, const vec<T, group * count * N>& value,
+ csizes_t<indices...>)
+{
+ swallow{ (write<A>(dest + group * stride * indices, slice<group * indices * N, group * N>(value)),
+ 0)... };
+}
+} // namespace internal
+
+template <size_t count, size_t N, size_t group = 1, bool A = false, typename T>
+KFR_INTRINSIC vec<T, group * count * N> read_group(const T* src, size_t stride)
+{
+ return internal::read_group_impl<group, count, N, A>(ptr_cast<T>(src), stride, csizeseq_t<count>());
+}
+
+template <size_t count, size_t N, size_t group = 1, bool A = false, typename T>
+KFR_INTRINSIC void write_group(T* dest, size_t stride, const vec<T, group * count * N>& value)
+{
+ return internal::write_group_impl<group, count, N, A>(dest, stride, value, csizeseq_t<count>());
+}
+
template <typename... Indices, typename T, size_t Nout = 1 + sizeof...(Indices)>
KFR_INTRINSIC vec<T, Nout> gather(const T* base, size_t index, Indices... indices)
{
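
Outside the FFT code, the new helpers also work on plain scalar buffers: read_group<count, N, group, A>(src, stride) concatenates count blocks of group * N elements taken every group * stride elements, and write_group scatters such a vector back. A minimal usage sketch on a float buffer (the function name and sizes are illustrative, not from this commit):

    #include <kfr/simd/read_write.hpp>
    using namespace kfr;

    // Gather four rows of four floats, each row `stride` floats apart,
    // into one vec<float, 16>, then store it contiguously into dst.
    void gather_rows(const float* src, float* dst, size_t stride)
    {
        vec<float, 16> v = read_group<4, 4>(src, stride); // group = 1, unaligned by default
        write(dst, v);
    }
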