commit 125177d1b75db4af7ffa1761ec31ed1769bb8bb1
parent 1f2d9979a26406e321d3ee2abf33c70081ce1587
Author: d.levin256@gmail.com <d.levin256@gmail.com>
Date: Mon, 1 Apr 2019 17:21:35 +0000
expression_scalar: support for vec<T>
Diffstat:
2 files changed, 10 insertions(+), 28 deletions(-)
diff --git a/include/kfr/base/expression.hpp b/include/kfr/base/expression.hpp
@@ -286,8 +286,8 @@ protected:
vec_shape<U, N>) const
{
static_assert(ArgIndex < count, "Incorrect ArgIndex");
- return get_elements(
- static_cast<vec<U, N>>(std::get<ArgIndex>(this->args), cinput, index, vec_shape<T, N>()));
+ return static_cast<vec<U, N>>(
+ get_elements(std::get<ArgIndex>(this->args), cinput, index, vec_shape<T, N>()));
}
template <typename U, size_t N,
typename T = value_type_of<typename details::get_nth_type<0, Args...>::type>>
@@ -317,20 +317,19 @@ private:
}
};
-template <typename T, size_t width = 1>
+template <typename T>
struct expression_scalar : input_expression
{
using value_type = T;
expression_scalar() = delete;
constexpr expression_scalar(const T& val) CMT_NOEXCEPT : val(val) {}
- constexpr expression_scalar(const vec<T, width>& val) CMT_NOEXCEPT : val(val) {}
- vec<T, width> val;
+ T val;
template <size_t N>
friend KFR_INTRINSIC vec<T, N> get_elements(const expression_scalar& self, cinput_t, size_t,
vec_shape<T, N>)
{
- return resize<N>(self.val);
+ return broadcast<N>(self.val);
}
};
@@ -341,23 +340,11 @@ struct arg_impl
};
template <typename T1, typename T2>
-struct arg_impl<T1, T2, void_t<enable_if<is_number<T1>::value>>>
+struct arg_impl<T1, T2, void_t<enable_if<is_vec_element<T1>::value>>>
{
using type = expression_scalar<T1>;
};
-template <typename T1, typename T2>
-struct arg_impl<complex<T1>, T2>
-{
- using type = expression_scalar<complex<T1>>;
-};
-
-template <typename T1, typename T2, size_t N>
-struct arg_impl<vec<T1, N>, T2>
-{
- using type = expression_scalar<T1, N>;
-};
-
template <typename T>
using arg = typename internal::arg_impl<decay<T>, T>::type;
@@ -404,12 +391,6 @@ CMT_INTRINSIC internal::expression_scalar<T> scalar(const T& val)
return internal::expression_scalar<T>(val);
}
-template <typename T, size_t N>
-CMT_INTRINSIC internal::expression_scalar<T, N> scalar(const vec<T, N>& val)
-{
- return internal::expression_scalar<T, N>(val);
-}
-
template <typename Fn, typename... Args>
CMT_INTRINSIC internal::expression_function<decay<Fn>, Args...> bind_expression(Fn&& fn, Args&&... args)
{
@@ -428,7 +409,8 @@ CMT_INTRINSIC internal::expression_function<Fn, NewArgs...> rebind(
return internal::expression_function<Fn, NewArgs...>(e.get_fn(), std::forward<NewArgs>(args)...);
}
-template <size_t width = 0, typename OutputExpr, typename InputExpr, size_t groupsize = 1>
+template <size_t width = 0, typename OutputExpr, typename InputExpr, size_t groupsize = 1,
+ typename Tvec = vec<value_type_of<InputExpr>, 1>>
CMT_INTRINSIC static size_t process(OutputExpr&& out, const InputExpr& in, size_t start = 0,
size_t size = infinite_size, coutput_t coutput = nullptr,
cinput_t cinput = nullptr, csize_t<groupsize> = csize_t<groupsize>())
diff --git a/tests/complex_test.cpp b/tests/complex_test.cpp
@@ -200,9 +200,9 @@ TEST(static_tests)
testo::assert_is_same<ftype<vec<complex<i32>, 4>>, vec<complex<f32>, 4>>();
testo::assert_is_same<ftype<vec<complex<i64>, 8>>, vec<complex<f64>, 8>>();
- testo::assert_is_same<kfr::internal::arg<int>, kfr::internal::expression_scalar<int, 1>>();
+ testo::assert_is_same<kfr::internal::arg<int>, kfr::internal::expression_scalar<int>>();
testo::assert_is_same<kfr::internal::arg<complex<int>>,
- kfr::internal::expression_scalar<kfr::complex<int>, 1>>();
+ kfr::internal::expression_scalar<kfr::complex<int>>>();
testo::assert_is_same<kfr::common_type<complex<int>, double>, complex<double>>();
}