kfr

Fast, modern C++ DSP framework, FFT, Sample Rate Conversion, FIR/IIR/Biquad Filters (SSE, AVX, AVX-512, ARM NEON)
Log | Files | Refs | README

commit e7e510cbd6d5e0284b42a3cb953352f523a43bf9
parent e5b25505221b88bbede90d571233cacb891f08ec
Author: d.levin256@gmail.com <d.levin256@gmail.com>
Date:   Sat, 12 Nov 2022 18:13:49 +0000

get_element, indices and select

Diffstat:
Minclude/kfr/base/expression.hpp | 15+++++++++++++++
Minclude/kfr/simd/select.hpp | 4++--
Mtests/unit/base/basic_expressions.cpp | 2++
Mtests/unit/base/tensor.cpp | 31+++++++++++++++++++++++++++++++
4 files changed, 50 insertions(+), 2 deletions(-)

diff --git a/include/kfr/base/expression.hpp b/include/kfr/base/expression.hpp @@ -206,6 +206,21 @@ struct expression_traits<T, std::enable_if_t<is_expr_element<T>>> : expression_t KFR_MEM_INTRINSIC constexpr static shape<0> shapeof() { return {}; } }; +template <typename E, enable_if_input_expression<E>* = nullptr, index_t Dims = expression_dims<E>> +inline expression_value_type<E> get_element(E&& expr, shape<Dims> index) +{ + return get_elements(expr, index, axis_params_v<0, 1>).front(); +} + +template <index_t Axis, index_t Dims, index_t VecAxis, size_t N> +inline vec<index_t, N> indices(const shape<Dims>& index, axis_params<VecAxis, N>) +{ + if constexpr (Axis == VecAxis) + return index[Axis] + enumerate<index_t, N, 0, 1>(); + else + return index[Axis]; +} + namespace internal_generic { struct anything diff --git a/include/kfr/simd/select.hpp b/include/kfr/simd/select.hpp @@ -42,8 +42,8 @@ template <typename T1, size_t N, typename T2, typename T3, KFR_ENABLE_IF(is_nume typename Tout = subtype<std::common_type_t<T2, T3>>> KFR_INTRINSIC vec<Tout, N> select(const mask<T1, N>& m, const T2& x, const T3& y) { - static_assert(sizeof(T1) == sizeof(Tout), "select: incompatible types"); - return intrinsics::select(bitcast<Tout>(m.asvec()).asmask(), broadcastto<Tout>(x), broadcastto<Tout>(y)); + return intrinsics::select(bitcast<Tout>(cast<itype<Tout>>(bitcast<itype<T1>>(m.asvec()))).asmask(), + broadcastto<Tout>(x), broadcastto<Tout>(y)); } } // namespace CMT_ARCH_NAME } // namespace kfr diff --git a/tests/unit/base/basic_expressions.cpp b/tests/unit/base/basic_expressions.cpp @@ -151,5 +151,7 @@ TEST(assign_expression) TEST(trace) { render(trace(counter()), 44); } +TEST(get_element) { CHECK(get_element(counter(0, 1, 10, 100), { 1, 2, 3 }) == 321); } + } // namespace CMT_ARCH_NAME } // namespace kfr diff --git a/tests/unit/base/tensor.cpp b/tests/unit/base/tensor.cpp @@ -798,5 +798,36 @@ TEST(tensor_from_container) } // namespace CMT_ARCH_NAME +template <typename T, index_t Size> +struct identity_matrix +{ +}; + +template <typename T, index_t Size> +struct expression_traits<identity_matrix<T, Size>> : expression_traits_defaults +{ + using value_type = T; + constexpr static size_t dims = 2; + constexpr static shape<2> shapeof(const identity_matrix<T, Size>& self) { return { Size, Size }; } + constexpr static shape<2> shapeof() { return { Size, Size }; } +}; + +template <typename T, index_t Size, index_t Axis, size_t N> +vec<T, N> get_elements(const identity_matrix<T, Size>& self, const shape<2>& index, + const axis_params<Axis, N>& sh) +{ + return select(indices<0>(index, sh) == indices<1>(index, sh), 1, 0); +} + +inline namespace CMT_ARCH_NAME +{ + +TEST(identity_matrix) +{ + CHECK(trender(identity_matrix<float, 3>{}) == tensor<float, 2>{ { 1, 0, 0 }, { 0, 1, 0 }, { 0, 0, 1 } }); +} + +} // namespace CMT_ARCH_NAME + } // namespace kfr CMT_PRAGMA_MSVC(warning(pop))