commit 712972cae512057f99a53a19ffdb8e0ccc920a6a
parent 17985b7d87fece531920085be1b2a0b0bbfdda9e
Author: d.levin256@gmail.com <d.levin256@gmail.com>
Date: Sun, 26 Nov 2023 15:11:25 +0000
transpose
Diffstat:
5 files changed, 429 insertions(+), 1 deletion(-)
diff --git a/include/kfr/base/shape.hpp b/include/kfr/base/shape.hpp
@@ -344,6 +344,11 @@ struct shape : static_array_base<index_t, csizeseq_t<dims>>
if (CMT_LIKELY(index < dims))
this->operator[](dims - 1 - index) = val;
}
+
+ /// @brief Returns this shape with the axis order reversed,
+ /// e.g. {3, 4, 5} -> {5, 4, 3} (see tests/unit/base/shape.cpp).
+ KFR_MEM_INTRINSIC constexpr shape transpose() const
+ {
+ // csizeseq<dims, dims - 1, -1> generates the index sequence
+ // {dims-1, dims-2, ..., 1, 0}, i.e. a full reversal.
+ return this->shuffle(csizeseq<dims, dims - 1, -1>);
+ }
};
template <>
diff --git a/include/kfr/base/tensor.hpp b/include/kfr/base/tensor.hpp
@@ -38,6 +38,7 @@
#include "expression.hpp"
#include "memory.hpp"
#include "shape.hpp"
+#include "transpose.hpp"
CMT_PRAGMA_MSVC(warning(push))
CMT_PRAGMA_MSVC(warning(disable : 4324))
@@ -372,7 +373,7 @@ public:
}
#else
tensor& operator=(const tensor& src) & = default;
- tensor& operator=(tensor&& src) & = default;
+ tensor& operator=(tensor&& src) & = default;
#endif
KFR_MEM_INTRINSIC const tensor& operator=(const tensor& src) const&
@@ -509,6 +510,23 @@ public:
using tensor_subscript<T, tensor<T, NDims>, std::make_integer_sequence<index_t, NDims>>::operator();
+ /// @brief Returns a transposed view of this tensor (all axes reversed).
+ /// No element data is copied: the result shares m_data with this tensor
+ /// and simply reverses both the shape and the strides.
+ KFR_MEM_INTRINSIC tensor transpose() const
+ {
+ if constexpr (dims <= 1)
+ {
+ // A 0-D or 1-D tensor is its own transpose.
+ return *this;
+ }
+ else
+ {
+ // Reversing shape and strides together yields the transposed
+ // view over the same buffer; passing m_finalizer keeps the
+ // underlying storage alive for the new tensor.
+ return tensor<T, dims>{
+ m_data,
+ m_shape.transpose(),
+ m_strides.transpose(),
+ m_finalizer,
+ };
+ }
+ }
+
template <index_t dims>
KFR_MEM_INTRINSIC tensor<T, dims> reshape_may_copy(const kfr::shape<dims>& new_shape,
bool allow_copy = true) const
diff --git a/include/kfr/base/transpose.hpp b/include/kfr/base/transpose.hpp
@@ -0,0 +1,302 @@
+/** @addtogroup tensor
+ * @{
+ */
+/*
+ Copyright (C) 2016-2023 Dan Cazarin (https://www.kfrlib.com)
+ This file is part of KFR
+
+ KFR is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 2 of the License, or
+ (at your option) any later version.
+
+ KFR is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with KFR.
+
+ If GPL is not suitable for your project, you must purchase a commercial license to use KFR.
+ Buying a commercial license is mandatory as soon as you develop commercial activities without
+ disclosing the source code of your own applications.
+ See https://www.kfrlib.com for details.
+ */
+#pragma once
+
+#include "../simd/read_write.hpp"
+#include "../simd/types.hpp"
+#include "expression.hpp"
+#include "memory.hpp"
+#include "shape.hpp"
+
+namespace kfr
+{
+inline namespace CMT_ARCH_NAME
+{
+/// @brief Matrix transpose.
+/// Transposes the data in @p in (dimensions given by @p shape) into @p out.
+/// In-place operation (out == in) is supported (exercised by the unit tests).
+/// For Dims > 2 the transpose reverses all axes. @p group is the number of
+/// interleaved scalars treated as one element (used for complex support).
+template <size_t group = 1, typename T, index_t Dims>
+void matrix_transpose(T* out, const T* in, shape<Dims> shape);
+
+/// @brief Matrix transpose (complex). Forwards to the scalar version with
+/// the group size doubled, treating each complex as two interleaved T.
+template <size_t group = 1, typename T, index_t Dims>
+void matrix_transpose(complex<T>* out, const complex<T>* in, shape<Dims> shape);
+
+namespace internal
+{
+
+// Transposes a single N x N tile located at flat element offset `i`
+// (a tile that maps onto itself, i.e. one lying on the main diagonal).
+// `stride` is the row stride of the full matrix in elements; `group`
+// scalars are moved together as one logical element.
+template <size_t group = 1, typename T, size_t N>
+void matrix_transpose_block_one(T* out, const T* in, size_t i, size_t stride)
+{
+ if constexpr (N == 1)
+ {
+ // 1x1 tile: transposition is the identity, just copy the element
+ // (needed for the out-of-place case; no-op when out == in).
+ write(out + group * i, kfr::read<group>(in + group * i));
+ }
+ else
+ {
+ // Gather the N x N tile into a vector, transpose it in registers,
+ // and scatter it back to the same location.
+ vec<T, (group * N * N)> vi = read_group<N, N, group>(in + group * i, stride);
+ vi = transpose<N, group>(vi);
+ write_group<N, N, group>(out + group * i, stride, vi);
+ }
+}
+
+// Transposes and swaps a mirror pair of N x N tiles at flat element
+// offsets `i` and `j` (tiles symmetric about the main diagonal): each tile
+// is transposed in registers and written to the other tile's position.
+template <size_t group = 1, typename T, size_t N>
+void matrix_transpose_block_two(T* out, const T* in, size_t i, size_t j, size_t stride)
+{
+ if constexpr (N == 1)
+ {
+ // Single elements: a plain swap, no inner transpose required.
+ vec<T, group> vi = kfr::read<group>(in + group * i);
+ vec<T, group> vj = kfr::read<group>(in + group * j);
+ write(out + group * i, vj);
+ write(out + group * j, vi);
+ }
+ else
+ {
+ // Read both tiles first so the swap is safe when out == in.
+ vec<T, (group * N * N)> vi = read_group<N, N, group>(in + group * i, stride);
+ vec<T, (group * N * N)> vj = read_group<N, N, group>(in + group * j, stride);
+ vi = transpose<N, group>(vi);
+ vj = transpose<N, group>(vj);
+ // Cross-write: tile j lands at position i and vice versa.
+ write_group<N, N, group>(out + group * i, stride, vj);
+ write_group<N, N, group>(out + group * j, stride, vi);
+ }
+}
+
+// Transposes a small square n x n matrix (1 <= n <= 6) in one shot:
+// the entire matrix is read into a single vector, transposed in
+// registers and written back. `cswitch` dispatches the runtime n to a
+// compile-time constant so the vector width is known statically.
+template <size_t group = 1, typename T>
+void matrix_transpose_square_small(T* out, const T* in, size_t n)
+{
+ cswitch(csizeseq<6, 1>, n, // 1, 2, 3, 4, 5 or 6
+ [&](auto n_) CMT_INLINE_LAMBDA
+ {
+ constexpr size_t n = CMT_CVAL(n_);
+ write(out, transpose<n, group>(kfr::read<n * n * group>(in)));
+ });
+}
+
+// Transposes a square n x n matrix using width x width tiles.
+// Diagonal tiles are transposed in place (block_one); off-diagonal mirror
+// pairs are transposed and swapped (block_two). Works in place or
+// out of place. `stride` is the row stride (callers pass stride == n for
+// a pure square, or the full column count when transposing the square
+// sub-matrix of a rectangular one).
+//
+// The enabled branch maintains flat offsets incrementally instead of
+// recomputing i * stride + j in the inner loops. Invariants kept:
+//   istridei == i * stride + i  (current diagonal tile)
+//   istridej == i * stride + j  (tile at row block i, column j)
+//   jstridei == j * stride + i  (its mirror at row block j, column i)
+// NOTE(review): the #else branch is the straightforward reference
+// implementation of the same algorithm, kept disabled for comparison.
+template <size_t group = 1, typename T>
+void matrix_transpose_square(T* out, const T* in, size_t n, size_t stride)
+{
+#if 1
+ constexpr size_t width = 4;
+ const size_t nw = align_down(n, width);
+ const size_t wstride = width * stride;
+
+ size_t i = 0;
+ size_t istridei = 0;
+ CMT_LOOP_NOUNROLL
+ for (; i < nw; i += width)
+ {
+ // Diagonal tile: transpose in place.
+ matrix_transpose_block_one<group, T, width>(out, in, istridei, stride);
+
+ size_t j = i + width;
+ size_t istridej = istridei + width;
+ size_t jstridei = istridei + wstride;
+ CMT_LOOP_NOUNROLL
+ for (; j < nw; j += width)
+ {
+ // Full off-diagonal tile pair: swap-and-transpose.
+ matrix_transpose_block_two<group, T, width>(out, in, istridej, jstridei, stride);
+ istridej += width;
+ jstridei += wstride;
+ }
+ CMT_LOOP_NOUNROLL
+ for (; j < n; ++j)
+ {
+ // Ragged right edge (n not a multiple of width): handle the
+ // remaining columns of this row block element by element.
+ CMT_LOOP_NOUNROLL
+ for (size_t ii = i; ii < i + width; ++ii)
+ {
+ matrix_transpose_block_two<group, T, 1>(out, in, istridej, jstridei, stride);
+ istridej += stride; // next row within the block: +stride
+ jstridei += 1; // mirror moves along its row: +1
+ }
+ // Rewind to the top of the block and step to column j + 1.
+ istridej = istridej - stride * width + 1;
+ jstridei = jstridei - width + stride;
+ }
+ // Advance the diagonal offset to the next width x width block:
+ // (i + width) * stride + (i + width).
+ istridei += width * (stride + 1);
+ }
+
+ // Bottom-right remainder (rows/cols beyond the last full tile):
+ // plain element-wise transposition.
+ CMT_LOOP_NOUNROLL
+ for (; i < n; ++i)
+ {
+ matrix_transpose_block_one<group, T, 1>(out, in, i * stride + i, stride);
+ CMT_LOOP_NOUNROLL
+ for (size_t j = i + 1; j < n; ++j)
+ {
+ matrix_transpose_block_two<group, T, 1>(out, in, i * stride + j, j * stride + i, stride);
+ }
+ }
+#else
+ // Reference implementation: identical tiling, offsets computed
+ // directly as i * stride + j each iteration.
+ constexpr size_t width = 4;
+ const size_t nw = align_down(n, width);
+
+ size_t i = 0;
+ CMT_LOOP_NOUNROLL
+ for (; i < nw; i += width)
+ {
+ matrix_transpose_block_one<group, T, width>(out, in, i * stride + i, stride);
+
+ size_t j = i + width;
+ CMT_LOOP_NOUNROLL
+ for (; j < nw; j += width)
+ {
+ matrix_transpose_block_two<group, T, width>(out, in, i * stride + j, j * stride + i, stride);
+ }
+ CMT_LOOP_NOUNROLL
+ for (; j < n; ++j)
+ {
+ CMT_LOOP_NOUNROLL
+ for (size_t ii = i; ii < i + width; ++ii)
+ {
+ matrix_transpose_block_two<group, T, 1>(out, in, ii * stride + j, j * stride + ii, stride);
+ }
+ }
+ }
+
+ CMT_LOOP_NOUNROLL
+ for (; i < n; ++i)
+ {
+ matrix_transpose_block_one<group, T, 1>(out, in, i * stride + i, stride);
+ CMT_LOOP_NOUNROLL
+ for (size_t j = i + 1; j < n; ++j)
+ {
+ matrix_transpose_block_two<group, T, 1>(out, in, i * stride + j, j * stride + i, stride);
+ }
+ }
+#endif
+}
+
+// Transposes a rectangular rows x cols matrix, supporting in-place
+// operation (in == out). Strategy: transpose the leading side x side
+// square sub-matrix (side = min(rows, cols)) with the fast square
+// routine, then rearrange the leftover strip with std::rotate and
+// transpose it recursively via matrix_transpose.
+template <size_t group = 1, typename T>
+void matrix_transpose_any(T* out, const T* in, size_t rows, size_t cols)
+{
+ // 1. transpose square sub-matrix
+ const size_t side = std::min(cols, rows);
+ matrix_transpose_square<group>(out, in, side, cols);
+
+ if (cols > rows)
+ {
+ // 2. copy remaining
+ // Out-of-place only: bring the right-hand strip (columns past the
+ // square) into the output buffer; in place it is already there.
+ size_t remaining = cols - rows;
+ if (in != out)
+ {
+ for (size_t r = 0; r < rows; ++r)
+ {
+ builtin_memcpy(out + group * (side + r * cols), in + group * (side + r * cols),
+ group * remaining * sizeof(T));
+ }
+ }
+
+ // 3. shift rows
+ // Compact the transposed square rows and gather the strip elements
+ // to the tail of the buffer. Rotations operate on whole group-sized
+ // elements (vec<T, group>).
+ // NOTE(review): rotation offsets (remaining + r * remaining vs
+ // side + remaining + r * remaining) are assumed correct per the
+ // exhaustive 1..100 x 1..100 unit tests — verify before changing.
+ auto* p = ptr_cast<vec<T, group>>(out) + side;
+ for (size_t r = 0; r + 1 < rows; ++r)
+ {
+ std::rotate(p, p + remaining + r * remaining, p + side + remaining + r * remaining);
+ p += side;
+ }
+ // 4. transpose remainder
+ // The gathered strip is a side x remaining matrix; transpose it in
+ // place (recursive call handles tall/wide cases).
+ matrix_transpose<group>(out + group * side * side, out + group * side * side,
+ shape{ side, remaining });
+ }
+ else // if (cols < rows)
+ {
+ // 2. copy remaining
+ // Out-of-place only: bring the bottom strip (rows past the square)
+ // into the output buffer.
+ size_t remaining = rows - cols;
+ if (in != out)
+ {
+ for (size_t r = 0; r < remaining; ++r)
+ {
+ builtin_memcpy(out + group * ((cols + r) * cols), in + group * ((cols + r) * cols),
+ group * cols * sizeof(T));
+ }
+ }
+
+ // 3. transpose remainder
+
+ // Transpose the bottom remaining x cols strip in place first.
+ matrix_transpose<group>(out + group * side * side, out + group * side * side,
+ shape{ remaining, cols });
+
+ // 4. shift cols
+ // Interleave the transposed strip's columns behind each square row,
+ // walking backwards from the last row so rotations do not clobber
+ // data still to be moved.
+ // NOTE(review): offsets assumed correct per the exhaustive unit
+ // tests — verify before changing.
+ auto* p = ptr_cast<vec<T, group>>(out) + side * (cols - 1);
+ for (size_t c = cols - 1; c >= 1;)
+ {
+ --c;
+ std::rotate(p, p + side, p + (side + remaining + c * remaining));
+ p -= side;
+ }
+ }
+}
+
+// Degenerate transpose (1-D data, or a single row/column): the element
+// order is unchanged, so this is a plain copy — or nothing at all when
+// operating in place.
+template <size_t group = 1, typename T>
+KFR_INTRINSIC void matrix_transpose_noop(T* out, const T* in, size_t total)
+{
+ if (out == in)
+ return;
+ builtin_memcpy(out, in, total * sizeof(T) * group);
+}
+} // namespace internal
+
+// Primary matrix_transpose implementation (see declaration above for the
+// contract). Dispatches on the static dimensionality:
+//   Dims <= 1 : nothing to transpose — copy (or no-op in place).
+//   Dims == 2 : classic 2-D transpose with fast paths for degenerate,
+//               small-square and square shapes.
+//   Dims > 2  : full axis reversal, computed recursively.
+template <size_t group, typename T, index_t Dims>
+void matrix_transpose(T* out, const T* in, shape<Dims> tshape)
+{
+ if constexpr (Dims <= 1)
+ {
+ return internal::matrix_transpose_noop<group>(out, in, tshape.product());
+ }
+ else if constexpr (Dims == 2)
+ {
+ const index_t rows = tshape[0];
+ const index_t cols = tshape[1];
+ if (cols == 1 || rows == 1)
+ {
+ // A single row or column transposes to itself element-wise.
+ return internal::matrix_transpose_noop<group>(out, in, tshape.product());
+ }
+ // TODO: special cases for tall or wide matrices
+ if (cols == rows)
+ {
+ if (cols <= 6)
+ return internal::matrix_transpose_square_small<group>(out, in, cols);
+ return internal::matrix_transpose_square<group>(out, in, cols, cols);
+ }
+ return internal::matrix_transpose_any<group>(out, in, rows, cols);
+ }
+ else
+ {
+ // N-D axis reversal: view the tensor as a 2-D matrix of shape
+ // (product of leading axes) x (last axis), transpose that, then
+ // recursively reverse the leading axes within each of the y slices.
+ shape<Dims - 1> x = tshape.template slice<0, Dims - 1>();
+ index_t xproduct = x.product();
+ index_t y = tshape.back();
+ matrix_transpose<group>(out, in, shape<2>{ xproduct, y });
+ for (index_t i = 0; i < y; ++i)
+ {
+ matrix_transpose<group>(out, out, x);
+ out += group * xproduct;
+ }
+ }
+}
+
+// Complex overload: a complex<T> is two interleaved T scalars, so the
+// transpose is delegated to the scalar version with the group size
+// doubled, keeping each real/imaginary pair together.
+template <size_t group, typename T, index_t Dims>
+void matrix_transpose(complex<T>* out, const complex<T>* in, shape<Dims> shape)
+{
+ return matrix_transpose<2 * group>(ptr_cast<T>(out), ptr_cast<T>(in), shape);
+}
+
+} // namespace CMT_ARCH_NAME
+
+} // namespace kfr
diff --git a/tests/unit/base/shape.cpp b/tests/unit/base/shape.cpp
@@ -37,6 +37,8 @@ TEST(shape)
CHECK(shape{ 3, 4, 5 }.from_flat(0) == shape{ 0, 0, 0 });
CHECK(shape{ 3, 4, 5 }.from_flat(59) == shape{ 2, 3, 4 });
+
+ CHECK(shape{ 3, 4, 5 }.transpose() == shape{ 5, 4, 3 });
}
TEST(shape_broadcast)
{
diff --git a/tests/unit/base/tensor.cpp b/tests/unit/base/tensor.cpp
@@ -840,6 +840,107 @@ TEST(identity_matrix)
CHECK(trender(identity_matrix<float, 3>{}) == tensor<float, 2>{ { 1, 0, 0 }, { 0, 1, 0 }, { 0, 0, 1 } });
}
+// Test expression producing a rows x cols matrix whose element at
+// (r, c) equals r * mark + c, so every element encodes its own position.
+// With Transposed = true the same values appear at swapped positions,
+// i.e. it is the expected result of transposing the Transposed = false
+// matrix — used as the reference in test_transpose below.
+template <typename T, bool Transposed = false>
+struct expression_test_matrix : public expression_traits_defaults
+{
+ shape<2> matrix_shape;
+ index_t mark;
+ expression_test_matrix(index_t rows, size_t cols, index_t mark = 10000)
+ : matrix_shape({ rows, cols }), mark(mark)
+ {
+ if constexpr (Transposed)
+ std::swap(matrix_shape[0], matrix_shape[1]);
+ }
+
+ using value_type = T;
+ constexpr static size_t dims = 2;
+ constexpr static shape<2> get_shape(const expression_test_matrix& self) { return self.matrix_shape; }
+ // NOTE(review): static shape is defaulted ({}), presumably meaning
+ // "not known at compile time" — confirm against expression_traits.
+ constexpr static shape<2> get_shape() { return {}; }
+
+ // Produces N consecutive elements along Axis starting at `index`;
+ // each value is row * mark + col (axes swapped when Transposed).
+ template <index_t Axis, size_t N>
+ friend vec<T, N> get_elements(const expression_test_matrix& self, shape<2> index,
+ const axis_params<Axis, N>&)
+ {
+ shape<2> scale{ self.mark, 1 };
+ if constexpr (Transposed)
+ std::swap(scale[0], scale[1]);
+ vec<T, N> result;
+ for (size_t i = 0; i < N; ++i)
+ {
+ result[i] = index[0] * scale[0] + index[1] * scale[1];
+ index[Axis] += 1;
+ }
+ return result;
+ }
+};
+
+// Checks matrix_transpose for one rows x cols size, both out-of-place
+// (t -> t2) and in-place (d2/d alias the same buffer tt), comparing each
+// result against the analytically transposed expression.
+template <typename T>
+static void test_transpose(size_t rows, size_t cols, size_t mark = 10000)
+{
+ tensor<T, 2> t = expression_test_matrix<T>(rows, cols, mark);
+
+ tensor<T, 2> t2(shape<2>{ cols, rows });
+ univector<T> tt(t.size());
+ // d and d2 deliberately wrap the SAME buffer (tt) with the source and
+ // destination shapes, so transposing d into d2 runs the in-place path.
+ auto d = tensor<T, 2>(tt.data(), shape{ rows, cols }, nullptr);
+ auto d2 = tensor<T, 2>(tt.data(), shape{ cols, rows }, nullptr);
+ CHECK(d.data() == d2.data());
+ d = expression_test_matrix<T>(rows, cols, mark);
+ // Poison the out-of-place destination to catch unwritten elements.
+ t2 = -1;
+ matrix_transpose(t2.data(), t.data(), shape{ rows, cols });
+
+ matrix_transpose(d2.data(), d.data(), shape{ rows, cols });
+
+ testo::scope s(as_string("type=", type_name<T>(), " rows=", rows, " cols=", cols));
+
+ // erro: out-of-place result; erri: in-place result. Values are exact
+ // position encodings, so the max absolute difference must be zero.
+ auto erro = maxof(cabs(t2 - expression_test_matrix<T, true>(rows, cols, mark)));
+ CHECK(erro == 0);
+
+ auto erri = maxof(cabs(d2 - expression_test_matrix<T, true>(rows, cols, mark)));
+ CHECK(erri == 0);
+}
+
+// Runs test_transpose for every supported element type, covering both the
+// scalar path and the complex path (which doubles the group size).
+[[maybe_unused]] static void test_transpose_t(size_t rows, size_t cols, size_t mark = 10000)
+{
+ test_transpose<float>(rows, cols, mark);
+ test_transpose<double>(rows, cols, mark);
+ test_transpose<complex<float>>(rows, cols, mark);
+ test_transpose<complex<double>>(rows, cols, mark);
+}
+
+TEST(matrix_transpose)
+{
+ // Exhaustive 2-D coverage: every size from 1x1 up to 100x100,
+ // exercising all fast paths (noop, small square, tiled square,
+ // rectangular) and their remainder handling.
+ for (int i = 1; i <= 100; ++i)
+ {
+ for (int j = 1; j <= 100; ++j)
+ {
+ test_transpose_t(i, j);
+ }
+ }
+
+ // 3-D axis reversal (2,3,4) -> (4,3,2), in place, against a
+ // hand-computed expected ordering.
+ univector<int, 24> x = counter();
+ matrix_transpose(x.data(), x.data(), shape{ 2, 3, 4 });
+ CHECK(x == univector<int, 24>{ 0, 12, 4, 16, 8, 20, 1, 13, 5, 17, 9, 21,
+ 2, 14, 6, 18, 10, 22, 3, 15, 7, 19, 11, 23 });
+
+ // 4-D axis reversal (2,3,4,5) -> (5,4,3,2), in place, with a uint8_t
+ // element type to check non-vector-friendly scalar widths.
+ univector<uint8_t, 120> x2 = counter();
+ matrix_transpose(x2.data(), x2.data(), shape{ 2, 3, 4, 5 });
+ CHECK(x2 == univector<uint8_t, 120>{ 0, 60, 20, 80, 40, 100, 5, 65, 25, 85, 45, 105, 10, 70, 30,
+ 90, 50, 110, 15, 75, 35, 95, 55, 115, 1, 61, 21, 81, 41, 101,
+ 6, 66, 26, 86, 46, 106, 11, 71, 31, 91, 51, 111, 16, 76, 36,
+ 96, 56, 116, 2, 62, 22, 82, 42, 102, 7, 67, 27, 87, 47, 107,
+ 12, 72, 32, 92, 52, 112, 17, 77, 37, 97, 57, 117, 3, 63, 23,
+ 83, 43, 103, 8, 68, 28, 88, 48, 108, 13, 73, 33, 93, 53, 113,
+ 18, 78, 38, 98, 58, 118, 4, 64, 24, 84, 44, 104, 9, 69, 29,
+ 89, 49, 109, 14, 74, 34, 94, 54, 114, 19, 79, 39, 99, 59, 119 });
+
+ // tensor::transpose(): must match the 3-D matrix_transpose result
+ // above when the transposed view is flattened.
+ tensor<int, 1> d{ shape{ 24 } };
+ d = counter();
+ tensor<int, 3> dd = d.reshape(shape{ 2, 3, 4 });
+ tensor<int, 3> ddd = dd.transpose();
+ CHECK(trender(ddd.flatten_may_copy()) == tensor<int, 1>{ 0, 12, 4, 16, 8, 20, 1, 13, 5, 17, 9, 21,
+ 2, 14, 6, 18, 10, 22, 3, 15, 7, 19, 11, 23 });
+}
+
} // namespace CMT_ARCH_NAME
} // namespace kfr