commit e6d7b12e3ad1aeca953ab548100740e1ab63550e
parent e77e2ee66be5abf5bbc273add81504f0e46a3c21
Author: d.levin256@gmail.com <d.levin256@gmail.com>
Date: Thu, 14 Jul 2022 19:40:27 +0100
Fix gather_helper and scatter_stride
Diffstat:
1 file changed, 32 insertions(+), 10 deletions(-)
diff --git a/include/kfr/simd/read_write.hpp b/include/kfr/simd/read_write.hpp
@@ -79,7 +79,7 @@ KFR_INTRINSIC vec<T, Nout> gather_stride(const T* base, csizes_t<Indices...>)
template <size_t Nout, size_t groupsize, typename T, size_t... Indices>
KFR_INTRINSIC vec<T, Nout> gather_stride_s(const T* base, size_t stride, csizes_t<Indices...>)
{
- return make_vector(read<groupsize>(base + Indices * groupsize * stride)...);
+ return concat(read<groupsize>(base + Indices * groupsize * stride)...); // group i is read at base + i * groupsize * stride
}
} // namespace internal
@@ -92,7 +92,15 @@ KFR_INTRINSIC vec<T, N> gather(const T* base, const vec<u32, N>& indices)
template <size_t Nout, size_t groupsize = 1, typename T>
KFR_INTRINSIC vec<T, Nout * groupsize> gather_stride(const T* base, size_t stride)
{
- return internal::gather_stride_s<Nout, groupsize>(base, stride, csizeseq<Nout>);
+ if constexpr (Nout > 2) // more than two groups: gather groups [0, Nlow) and [Nlow, Nout) separately
+ {
+ constexpr size_t Nlow = prev_poweroftwo(Nout - 1);
+ return concat( // the upper batch starts at group Nlow, i.e. Nlow * groupsize * stride elements past base
+ internal::gather_stride_s<Nlow, groupsize>(base, stride, csizeseq<Nlow>),
+ internal::gather_stride_s<Nout - Nlow, groupsize>(base + Nlow * groupsize * stride, stride, csizeseq<Nout - Nlow>));
+ }
+ else
+ return internal::gather_stride_s<Nout, groupsize>(base, stride, csizeseq<Nout>);
}
template <size_t Nout, size_t Stride, typename T>
@@ -105,7 +113,7 @@ template <size_t groupsize, typename T, size_t N, typename IT, size_t... Indices
KFR_INTRINSIC vec<T, N * groupsize> gather_helper(const T* base, const vec<IT, N>& offset,
csizes_t<Indices...>)
{
- return concat(read<groupsize>(base + groupsize * (*offset)[Indices])...);
+ return concat(read<groupsize>(base + groupsize * offset[Indices])...); // group i is read at base + groupsize * offset[i]
}
template <size_t groupsize = 1, typename T, size_t N, typename IT>
KFR_INTRINSIC vec<T, N * groupsize> gather(const T* base, const vec<IT, N>& offset)
@@ -113,28 +121,42 @@ KFR_INTRINSIC vec<T, N * groupsize> gather(const T* base, const vec<IT, N>& offs
return gather_helper<groupsize>(base, offset, csizeseq<N>);
}
+namespace internal // scatter helpers; code outside this namespace has to name them as internal::...
+{
template <size_t groupsize, typename T, size_t N, size_t Nout = N* groupsize, typename IT, size_t... Indices>
KFR_INTRINSIC void scatter_helper(T* base, const vec<IT, N>& offset, const vec<T, Nout>& value,
csizes_t<Indices...>)
{
- swallow{ (write(base + groupsize * (*offset)[Indices], slice<Indices * groupsize, groupsize>(value)),
+ swallow{ (write(base + groupsize * offset[Indices], slice<Indices * groupsize, groupsize>(value)), // group i goes to base + groupsize * offset[i]
0)... };
}
-template <size_t groupsize, typename T, size_t N, size_t Nout = N* groupsize, size_t... Indices>
-KFR_INTRINSIC void scatter_helper_s(T* base, size_t stride, const vec<T, Nout>& value, csizes_t<Indices...>)
+template <size_t groupsize, typename T, size_t N, size_t... Indices>
+KFR_INTRINSIC void scatter_helper_s(T* base, size_t stride, const vec<T, N>& value, csizes_t<Indices...>)
{
- swallow{ (write(base + groupsize * stride, slice<Indices * groupsize, groupsize>(value)), 0)... };
+ swallow{ (write(base + groupsize * Indices * stride, slice<Indices * groupsize, groupsize>(value)), // group i goes to base + i * groupsize * stride
+ 0)... }; // the trailing 0 gives every pack element a value for the swallow initializer list
}
+} // namespace internal
+
template <size_t groupsize = 1, typename T, size_t N, size_t Nout = N* groupsize, typename IT>
KFR_INTRINSIC void scatter(T* base, const vec<IT, N>& offset, const vec<T, Nout>& value)
{
- return scatter_helper<groupsize>(base, offset, value, csizeseq<N>);
+ return internal::scatter_helper<groupsize>(base, offset, value, csizeseq<N>); // scatter_helper is declared in namespace internal
}
-template <size_t groupsize = 1, typename T, size_t N, size_t Nout = N* groupsize, typename IT>
-KFR_INTRINSIC void scatter_stride(T* base, const vec<T, Nout>& value, size_t stride)
+template <size_t groupsize = 1, typename T, size_t N>
+KFR_INTRINSIC void scatter_stride(T* base, const vec<T, N>& value, size_t stride)
{
- return scatter_helper_s<groupsize>(base, stride, value, csizeseq<N>);
+ constexpr size_t Nout = N / groupsize; // number of groupsize-element groups held in value
+ if constexpr (Nout > 2) // more than two groups: scatter groups [0, Nlow) and [Nlow, Nout) separately
+ {
+ constexpr size_t Nlow = prev_poweroftwo(Nout - 1);
+ internal::scatter_helper_s<groupsize>(base, stride, slice<0, Nlow * groupsize>(value), csizeseq<Nlow>); // slice bounds are in elements
+ internal::scatter_helper_s<groupsize>(base + Nlow * groupsize * stride, stride, // upper batch starts at group Nlow
+ slice<Nlow * groupsize, (Nout - Nlow) * groupsize>(value), csizeseq<(Nout - Nlow)>);
+ }
+ else
+ return internal::scatter_helper_s<groupsize>(base, stride, value, csizeseq<Nout>);
}
template <typename T, size_t groupsize = 1>