commit e0199eab4ac465129068dbeff7b94ce6d722b4d3
parent 96249c0adef7f4f54630398c00b90cbeb7291a78
Author: d.levin256@gmail.com <d.levin256@gmail.com>
Date: Wed, 2 Nov 2022 08:19:23 +0000
expression_trace
Diffstat:
2 files changed, 39 insertions(+), 4 deletions(-)
diff --git a/include/kfr/base/basic_expressions.hpp b/include/kfr/base/basic_expressions.hpp
@@ -826,8 +826,8 @@ template <typename Arg1, typename Arg2, index_t ConcatAxis, index_t NDims, index
KFR_INTRINSIC vec<T, N> get_elements(const expression_concatenate<Arg1, Arg2, ConcatAxis>& self,
const shape<NDims>& index, const axis_params<Axis, N>& sh)
{
- const shape<NDims> size1 = self.size1;
- constexpr index_t Naxis = ConcatAxis == Axis ? N : 1;
+ const shape<NDims> size1 = self.size1;
+ constexpr index_t Naxis = ConcatAxis == Axis ? N : 1;
if (index[ConcatAxis] >= size1[ConcatAxis])
{
shape index1 = index;
@@ -947,7 +947,7 @@ private:
template <typename... E, enable_if_output_expressions<E...>* = nullptr>
KFR_FUNCTION expression_unpack<E...> unpack(E&&... e)
{
- return expression_unpack<E...>(std::forward<E>(e)...);
+ return { std::forward<E>(e)... };
}
// ----------------------------------------------------------------------------
@@ -983,8 +983,38 @@ struct expression_adjacent : expression_with_traits<E>
template <typename Fn, typename E1>
KFR_INTRINSIC expression_adjacent<Fn, E1> adjacent(Fn&& fn, E1&& e1)
{
- return expression_adjacent<Fn, E1>(std::forward<Fn>(fn), std::forward<E1>(e1));
+ return { std::forward<Fn>(fn), std::forward<E1>(e1) };
}
+// ----------------------------------------------------------------------------
+
+template <typename E>
+struct expression_trace : public expression_with_traits<E>
+{
+ using expression_with_traits<E>::expression_with_traits;
+ using value_type = typename expression_with_traits<E>::value_type;
+ constexpr static inline index_t dims = expression_with_traits<E>::dims;
+
+ template <size_t N, index_t VecAxis>
+ KFR_INTRINSIC friend vec<value_type, N> get_elements(const expression_trace& self, shape<dims> index,
+ axis_params<VecAxis, N> sh)
+ {
+ const vec<value_type, N> in = get_elements(self.first(), index, sh);
+ println("[", fmt<'s', 16>(as_string(index)), "] = ", in);
+ return in;
+ }
+};
+
+/**
+ * @brief Returns template expression that returns the result of calling \f$ fn(x_i, x_{i-1}) \f$
+ */
+template <typename E1>
+KFR_INTRINSIC expression_trace<E1> trace(E1&& e1)
+{
+ return { std::forward<E1>(e1) };
+}
+
+// ----------------------------------------------------------------------------
+
} // namespace CMT_ARCH_NAME
} // namespace kfr
diff --git a/tests/unit/base/basic_expressions.cpp b/tests/unit/base/basic_expressions.cpp
@@ -133,5 +133,10 @@ TEST(assign_expression)
CHECK_EXPRESSION(b, { 1000, 1010, 1020, 1030, 1040, 1050, 1060, 1070, 1080, 1090 });
}
+TEST(trace)
+{
+ render(trace(counter()), 44);
+}
+
} // namespace CMT_ARCH_NAME
} // namespace kfr