commit 939d53b331a65816db6fe21e0744b149829b0349
parent b0cee0b4456734fbfb7b8673d9c68657adf0f949
Author: d.levin256@gmail.com <d.levin256@gmail.com>
Date: Mon, 25 Nov 2019 15:24:06 +0000
zip & column
Diffstat:
2 files changed, 35 insertions(+), 0 deletions(-)
diff --git a/include/kfr/simd/shuffle.hpp b/include/kfr/simd/shuffle.hpp
@@ -483,6 +483,23 @@ KFR_INTRINSIC vec<T, Nout> interleave(const vec<T, N>& x, const vec<T, N>& y)
}
KFR_FN(interleave)
+template <typename T, size_t N1, size_t... Ns, size_t size = N1 + csum<size_t, Ns...>(),
+ size_t side2 = 1 + sizeof...(Ns), size_t side1 = size / side2>
+KFR_INTRINSIC vec<vec<T, side2>, side1> zip(const vec<T, N1>& x, const vec<T, Ns>&... y)
+{
+ static_assert(is_poweroftwo(1 + sizeof...(Ns)), "number of vectors must be power of two");
+ return vec<vec<T, side2>, side1>::from_flatten(concat(x, y...).shuffle(scale<1>(
+ csizeseq_t<size>() % csize_t<side2>() * csize_t<side1>() + csizeseq_t<size>() / csize_t<side2>())));
+}
+KFR_FN(zip)
+
+template <size_t index, typename T, size_t N1, size_t N2>
+KFR_INTRINSIC vec<T, N2> column(const vec<vec<T, N1>, N2>& x)
+{
+ static_assert(index < N1, "column index must be less than inner vector length");
+ return x.flatten().shuffle(csizeseq_t<N2>() * csize_t<N1>() + csize_t<index>());
+}
+
template <size_t group = 1, typename T, size_t N, size_t size = N / group, size_t side2 = 2,
size_t side1 = size / side2>
KFR_INTRINSIC vec<T, N> interleavehalves(const vec<T, N>& x)
diff --git a/tests/unit/simd/shuffle.cpp b/tests/unit/simd/shuffle.cpp
@@ -4,7 +4,9 @@
* See LICENSE.txt for details
*/
+#include <kfr/io.hpp>
#include <kfr/simd/shuffle.hpp>
+
namespace kfr
{
inline namespace CMT_ARCH_NAME
@@ -75,6 +77,22 @@ TEST(split_interleave)
CHECK(interleavehalves<2>(pack(0, 1, 2, 3, 4, 5, 6, 7)) == pack(0, 1, 4, 5, 2, 3, 6, 7));
}
+TEST(zip)
+{
+ CHECK(zip(pack(1, 2, 3, 4), pack(10, 20, 30, 40)) ==
+ pack(pack(1, 10), pack(2, 20), pack(3, 30), pack(4, 40)));
+
+ CHECK(zip(pack(1, 2, 3, 4), pack(10, 20, 30, 40), pack(111, 222, 333, 444), pack(-1, -2, -3, -4)) ==
+ pack(pack(1, 10, 111, -1), pack(2, 20, 222, -2), pack(3, 30, 333, -3), pack(4, 40, 444, -4)));
+}
+
+TEST(column)
+{
+ CHECK(column<1>(pack(pack(0, 1), pack(2, 3), pack(4, 5), pack(6, 7))) == pack(1, 3, 5, 7));
+
+ CHECK(column<0>(pack(pack(0., 1.), pack(2., 3.), pack(4., 5.), pack(6., 7.))) == pack(0., 2., 4., 6.));
+}
+
TEST(broadcast)
{
CHECK(broadcast<8>(1) == pack(1, 1, 1, 1, 1, 1, 1, 1));