commit e7e510cbd6d5e0284b42a3cb953352f523a43bf9
parent e5b25505221b88bbede90d571233cacb891f08ec
Author: d.levin256@gmail.com <d.levin256@gmail.com>
Date: Sat, 12 Nov 2022 18:13:49 +0000
get_element, indices and select
Diffstat:
4 files changed, 50 insertions(+), 2 deletions(-)
diff --git a/include/kfr/base/expression.hpp b/include/kfr/base/expression.hpp
@@ -206,6 +206,21 @@ struct expression_traits<T, std::enable_if_t<is_expr_element<T>>> : expression_t
KFR_MEM_INTRINSIC constexpr static shape<0> shapeof() { return {}; }
};
+template <typename E, enable_if_input_expression<E>* = nullptr, index_t Dims = expression_dims<E>>
+inline expression_value_type<E> get_element(E&& expr, shape<Dims> index)
+{
+ return get_elements(expr, index, axis_params_v<0, 1>).front();
+}
+
+template <index_t Axis, index_t Dims, index_t VecAxis, size_t N>
+inline vec<index_t, N> indices(const shape<Dims>& index, axis_params<VecAxis, N>)
+{
+ if constexpr (Axis == VecAxis)
+ return index[Axis] + enumerate<index_t, N, 0, 1>();
+ else
+ return index[Axis];
+}
+
namespace internal_generic
{
struct anything
diff --git a/include/kfr/simd/select.hpp b/include/kfr/simd/select.hpp
@@ -42,8 +42,8 @@ template <typename T1, size_t N, typename T2, typename T3, KFR_ENABLE_IF(is_nume
typename Tout = subtype<std::common_type_t<T2, T3>>>
KFR_INTRINSIC vec<Tout, N> select(const mask<T1, N>& m, const T2& x, const T3& y)
{
- static_assert(sizeof(T1) == sizeof(Tout), "select: incompatible types");
- return intrinsics::select(bitcast<Tout>(m.asvec()).asmask(), broadcastto<Tout>(x), broadcastto<Tout>(y));
+ return intrinsics::select(bitcast<Tout>(cast<itype<Tout>>(bitcast<itype<T1>>(m.asvec()))).asmask(),
+ broadcastto<Tout>(x), broadcastto<Tout>(y));
}
} // namespace CMT_ARCH_NAME
} // namespace kfr
diff --git a/tests/unit/base/basic_expressions.cpp b/tests/unit/base/basic_expressions.cpp
@@ -151,5 +151,7 @@ TEST(assign_expression)
TEST(trace) { render(trace(counter()), 44); }
+TEST(get_element) { CHECK(get_element(counter(0, 1, 10, 100), { 1, 2, 3 }) == 321); }
+
} // namespace CMT_ARCH_NAME
} // namespace kfr
diff --git a/tests/unit/base/tensor.cpp b/tests/unit/base/tensor.cpp
@@ -798,5 +798,36 @@ TEST(tensor_from_container)
} // namespace CMT_ARCH_NAME
+template <typename T, index_t Size>
+struct identity_matrix
+{
+};
+
+template <typename T, index_t Size>
+struct expression_traits<identity_matrix<T, Size>> : expression_traits_defaults
+{
+ using value_type = T;
+ constexpr static size_t dims = 2;
+ constexpr static shape<2> shapeof(const identity_matrix<T, Size>& self) { return { Size, Size }; }
+ constexpr static shape<2> shapeof() { return { Size, Size }; }
+};
+
+template <typename T, index_t Size, index_t Axis, size_t N>
+vec<T, N> get_elements(const identity_matrix<T, Size>& self, const shape<2>& index,
+ const axis_params<Axis, N>& sh)
+{
+ return select(indices<0>(index, sh) == indices<1>(index, sh), 1, 0);
+}
+
+inline namespace CMT_ARCH_NAME
+{
+
+TEST(identity_matrix)
+{
+ CHECK(trender(identity_matrix<float, 3>{}) == tensor<float, 2>{ { 1, 0, 0 }, { 0, 1, 0 }, { 0, 0, 1 } });
+}
+
+} // namespace CMT_ARCH_NAME
+
} // namespace kfr
CMT_PRAGMA_MSVC(warning(pop))