commit 8e6230b42eff654a74c36e5b9f0013ee7e71d509
parent eef498a3a3cd8f61704c6c6a9eb9aa4c59cac052
Author: d.levin256@gmail.com <d.levin256@gmail.com>
Date: Tue, 27 Nov 2018 18:35:11 +0000
groupsize support for gather/scatter
Diffstat:
1 file changed, 41 insertions(+), 7 deletions(-)
diff --git a/include/kfr/base/read_write.hpp b/include/kfr/base/read_write.hpp
@@ -75,12 +75,12 @@ CMT_INLINE vec<T, Nout> gather_stride(const T* base, csizes_t<Indices...>)
{
return make_vector(base[Indices * Stride]...);
}
-template <size_t Nout, typename T, size_t... Indices>
+template <size_t Nout, size_t groupsize, typename T, size_t... Indices> // NOTE(review): the declared result is still vec<T, Nout> although Nout reads of groupsize elements are packed — confirm.
CMT_INLINE vec<T, Nout> gather_stride_s(const T* base, size_t stride, csizes_t<Indices...>)
{
- return make_vector(base[Indices * stride]...);
-}
+ return make_vector(read<groupsize>(base + Indices * groupsize * stride)...); // one groupsize-wide read per index, starting at base + index * groupsize * stride
}
+} // namespace internal
template <typename T, size_t N>
CMT_INLINE vec<T, N> gather(const T* base, const vec<u32, N>& indices)
@@ -88,10 +88,10 @@ CMT_INLINE vec<T, N> gather(const T* base, const vec<u32, N>& indices)
return internal::gather(base, indices, csizeseq_t<N>());
}
-template <size_t Nout, typename T>
-CMT_INLINE vec<T, Nout> gather_stride(const T* base, size_t stride)
+template <size_t Nout, size_t groupsize = 1, typename T> // Nout is passed on as the number of groups; groupsize defaults to 1
+CMT_INLINE vec<T, Nout * groupsize> gather_stride(const T* base, size_t stride)
{
- return internal::gather_stride_s<Nout>(base, stride, csizeseq_t<Nout>());
+ return internal::gather_stride_s<Nout, groupsize>(base, stride, csizeseq_t<Nout>()); // the index sequence 0..Nout-1 selects the groups to read
}
template <size_t Nout, size_t Stride, typename T>
@@ -118,12 +118,46 @@ CMT_INLINE void scatter_helper(T* base, const vec<IT, N>& offset, const vec<T, N
swallow{ (write(base + groupsize * (*offset)[Indices], slice<Indices * groupsize, groupsize>(value)),
0)... };
}
+template <size_t groupsize, typename T, size_t N, size_t Nout = N* groupsize, size_t... Indices>
+CMT_INLINE void scatter_helper_s(T* base, size_t stride, const vec<T, Nout>& value, csizes_t<Indices...>)
+{ // Group i of `value` is written at base + i * groupsize * stride (same addressing as gather_stride_s). N appears in no function parameter, so callers must pass it explicitly.
+ swallow{ (write(base + Indices * groupsize * stride, slice<Indices * groupsize, groupsize>(value)), 0)... };
+}
template <size_t groupsize = 1, typename T, size_t N, size_t Nout = N* groupsize, typename IT>
CMT_INLINE void scatter(T* base, const vec<IT, N>& offset, const vec<T, Nout>& value)
{
return scatter_helper<groupsize>(base, offset, value, csizeseq_t<N>());
}
+template <size_t groupsize = 1, typename T, size_t N = 0, size_t Nout = N* groupsize, typename IT = void>
+CMT_INLINE void scatter_stride(T* base, const vec<T, Nout>& value, size_t stride)
+{ // N and IT appear in no function parameter, so they are defaulted and ignored (kept only for explicit argument lists); the group count is Nout / groupsize.
+ return scatter_helper_s<groupsize, T, Nout / groupsize>(base, stride, value, csizeseq_t<Nout / groupsize>()); // T and the count are explicit because the helper cannot deduce its N
+}
+
+template <typename T, size_t groupsize = 1>
+struct stride_pointer : public stride_pointer<const T, groupsize>
+{ // Writable strided view; `ptr` and `stride` are inherited from the stride_pointer<const T, groupsize> base.
+ template <size_t N>
+ void write(const vec<T, N>& val, csize_t<N> = csize_t<N>())
+ { // The base stores `const T*`, so constness is cast away — assumes the pointee is really mutable (TODO confirm with callers).
+ kfr::scatter_stride<groupsize, T, N / groupsize, N, size_t>(const_cast<T*>(this->ptr), val, this->stride); // all template arguments spelled out: scatter_stride's N and IT are not deducible; IT is unused there
+ }
+};
+
+template <typename T, size_t groupsize>
+struct stride_pointer<const T, groupsize>
+{ // Read-only strided view: a base pointer plus the stride that is handed to gather_stride.
+ const T* ptr;
+ const size_t stride;
+
+ template <size_t N>
+ vec<T, N> read(csize_t<N> = csize_t<N>())
+ { // N counts elements while gather_stride counts groups: N / groupsize groups give the declared vec<T, N> (N is expected to be a multiple of groupsize).
+ return kfr::gather_stride<N / groupsize, groupsize>(ptr, stride);
+ }
+};
+
template <typename T>
constexpr T partial_masks[] = { constants<T>::allones(),
constants<T>::allones(),
@@ -202,4 +236,4 @@ CMT_INLINE vec<T, N> partial_mask(size_t index, vec_t<T, N>)
{
return partial_mask<T, N>(index);
}
-}
+} // namespace kfr