commit 27610ee7e598cf72334893ae88a36fcb19417497
parent 8015037e306c16313de730484835b72b048cc928
Author: d.levin256@gmail.com <d.levin256@gmail.com>
Date: Thu, 4 Aug 2022 14:08:05 +0100
Multidimensional expressions
Diffstat:
5 files changed, 353 insertions(+), 71 deletions(-)
diff --git a/include/kfr/base/expression.hpp b/include/kfr/base/expression.hpp
@@ -125,13 +125,13 @@ struct expression_traits<T, std::enable_if_t<is_simd_type<T>>> : expression_trai
inline namespace CMT_ARCH_NAME
{
-template <typename T, typename U, size_t N, KFR_ENABLE_IF(is_simd_type<std::decay_t<T>>)>
-KFR_MEM_INTRINSIC vec<U, N> get_elements(T&& self, const shape<0>& index, vec_shape<U, N> sh)
+template <typename T, size_t N, KFR_ENABLE_IF(is_simd_type<std::decay_t<T>>)>
+KFR_INTRINSIC vec<std::decay_t<T>, N> get_elements(T&& self, const shape<0>& index, csize_t<N> sh)
{
return self;
}
-template <typename T, typename U, size_t N, KFR_ENABLE_IF(is_simd_type<std::decay_t<T>>)>
-KFR_MEM_INTRINSIC void set_elements(T& self, const shape<0>& index, const vec<U, N>& val)
+template <typename T, size_t N, KFR_ENABLE_IF(is_simd_type<std::decay_t<T>>)>
+KFR_INTRINSIC void set_elements(T& self, const shape<0>& index, csize_t<N> sh, const identity<vec<T, N>>& val)
{
static_assert(N == 1);
static_assert(!std::is_const_v<T>);
@@ -149,20 +149,22 @@ KFR_INTRINSIC static void tprocess_body(Out&& out, In&& in, size_t start, size_t
size_t x = start;
if constexpr (w > gw)
{
+ csize_t<w> wval;
CMT_LOOP_NOUNROLL
for (; x < stop / w * w; x += w)
{
outidx.set_revindex(0, x);
inidx.set_revindex(0, std::min(x, insize - 1));
- set_elements(out, outidx, get_elements(in, inidx, vec_shape<Tin, w>()));
+ set_elements(out, outidx, wval, get_elements(in, inidx, wval));
}
}
+ csize_t<gw> gwval;
CMT_LOOP_NOUNROLL
for (; x < stop / gw * gw; x += gw)
{
outidx.set_revindex(0, x);
inidx.set_revindex(0, std::min(x, insize - 1));
- set_elements(out, outidx, get_elements(in, inidx, vec_shape<Tin, gw>()));
+ set_elements(out, outidx, gwval, get_elements(in, inidx, gwval));
}
}
@@ -170,8 +172,7 @@ template <size_t width = 0, typename Out, typename In, size_t gw = 1,
CMT_ENABLE_IF(expression_traits<Out>::dims == 0)>
static auto tprocess(Out&& out, In&& in, shape<0> = {}, shape<0> = {}, csize_t<gw> = {}) -> shape<0>
{
- set_elements(out, shape<0>{},
- get_elements(in, shape<0>{}, vec_shape<typename expression_traits<In>::value_type, 1>()));
+ set_elements(out, shape<0>{}, csize_t<1>(), get_elements(in, shape<0>{}, csize_t<1>()));
return {};
}
@@ -212,9 +213,7 @@ static auto tprocess(Out&& out, In&& in, shape<outdims> start = 0, shape<outdims
const shape<indims> inshape = shapeof(in);
if (CMT_UNLIKELY(!internal_generic::can_assign_from(outshape, inshape)))
return { 0 };
- shape<outdims> stop = min(start + size, outshape);
-
- // min(out, in, size + start) - start
+ shape<outdims> stop = min(start.add_inf(size), outshape);
shape<outdims> outidx;
if constexpr (outdims == 1)
@@ -477,10 +476,10 @@ struct expression_traits<T, std::enable_if_t<std::is_base_of_v<input_expression,
inline namespace CMT_ARCH_NAME
{
-template <typename T, typename U, size_t N, KFR_ENABLE_IF(is_input_expression<T>)>
-KFR_MEM_INTRINSIC vec<U, N> get_elements(T&& self, const shape<1>& index, vec_shape<U, N> sh)
+template <typename E, size_t N, KFR_ENABLE_IF(is_input_expression<E>), typename T = value_type_of<E>>
+KFR_MEM_INTRINSIC vec<T, N> get_elements(E&& self, const shape<1>& index, csize_t<N> sh)
{
- return get_elements(self, cinput_t{}, index[0], sh);
+ return get_elements(self, cinput_t{}, index[0], vec_shape<T, N>{});
}
} // namespace CMT_ARCH_NAME
@@ -494,6 +493,13 @@ struct xwitharguments
template <size_t idx>
using nth = typename type_list::template nth<idx>;
+ using first_arg = typename type_list::template nth<0>;
+
+ template <size_t idx>
+ using nth_trait = expression_traits<typename type_list::template nth<idx>>;
+
+ using first_arg_trait = expression_traits<first_arg>;
+
std::tuple<Args...> args;
std::array<dimset, count> masks;
@@ -625,14 +631,15 @@ inline namespace CMT_ARCH_NAME
namespace internal
{
-template <index_t outdims, typename Fn, typename... Args, typename U, size_t N, index_t Dims, size_t idx>
-KFR_MEM_INTRINSIC vec<U, N> get_arg(const xfunction<Fn, Args...>& self, const shape<Dims>& index,
- vec_shape<U, N> sh, csize_t<idx>)
+template <index_t outdims, typename Fn, typename... Args, size_t N, index_t Dims, size_t idx,
+ typename Traits = expression_traits<typename xfunction<Fn, Args...>::template nth<idx>>>
+KFR_MEM_INTRINSIC vec<typename Traits::value_type, N> get_arg(const xfunction<Fn, Args...>& self,
+ const shape<Dims>& index, csize_t<N> sh,
+ csize_t<idx>)
{
- using Traits = expression_traits<typename xfunction<Fn, Args...>::template nth<idx>>;
if constexpr (Traits::dims == 0)
{
- return repeat<N>(get_elements(std::get<idx>(self.args), {}, vec_shape<U, 1>{}));
+ return repeat<N>(get_elements(std::get<idx>(self.args), {}, csize_t<1>{}));
}
else
{
@@ -641,14 +648,14 @@ KFR_MEM_INTRINSIC vec<U, N> get_arg(const xfunction<Fn, Args...>& self, const sh
if constexpr (last_dim > 0)
{
return repeat<N / std::min(last_dim, N)>(
- get_elements(std::get<idx>(self.args), indices, vec_shape<U, std::min(last_dim, N)>{}));
+ get_elements(std::get<idx>(self.args), indices, csize_t<std::min(last_dim, N)>{}));
}
else
{
if constexpr (N > 1)
{
if (CMT_UNLIKELY(self.masks[idx].back() == 0))
- return get_elements(std::get<idx>(self.args), indices, vec_shape<U, 1>{}).front();
+ return get_elements(std::get<idx>(self.args), indices, csize_t<1>{}).front();
else
return get_elements(std::get<idx>(self.args), indices, sh);
}
@@ -661,13 +668,14 @@ KFR_MEM_INTRINSIC vec<U, N> get_arg(const xfunction<Fn, Args...>& self, const sh
}
} // namespace internal
-template <typename Fn, typename... Args, typename U, size_t N, index_t Dims>
-KFR_MEM_INTRINSIC vec<U, N> get_elements(const xfunction<Fn, Args...>& self, const shape<Dims>& index,
- vec_shape<U, N> sh)
+template <typename Fn, typename... Args, size_t N, index_t Dims,
+ typename Tr = expression_traits<xfunction<Fn, Args...>>, typename T = typename Tr::value_type>
+KFR_INTRINSIC vec<T, N> get_elements(const xfunction<Fn, Args...>& self, const shape<Dims>& index,
+ csize_t<N> sh)
{
- constexpr index_t outdims = expression_traits<xfunction<Fn, Args...>>::dims;
+ constexpr index_t outdims = Tr::dims;
return self.fold_idx(
- [&](auto... idx) CMT_INLINE_LAMBDA -> vec<U, N> {
+ [&](auto... idx) CMT_INLINE_LAMBDA -> vec<T, N> {
return self.fn(internal::get_arg<outdims>(self, index, sh, idx)...);
});
}
diff --git a/include/kfr/base/new_expressions.hpp b/include/kfr/base/new_expressions.hpp
@@ -0,0 +1,238 @@
+/** @addtogroup expressions
+ * @{
+ */
+/*
+ Copyright (C) 2016 D Levin (https://www.kfrlib.com)
+ This file is part of KFR
+
+ KFR is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 2 of the License, or
+ (at your option) any later version.
+
+ KFR is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with KFR.
+
+ If GPL is not suitable for your project, you must purchase a commercial license to use KFR.
+ Buying a commercial license is mandatory as soon as you develop commercial activities without
+ disclosing the source code of your own applications.
+ See https://www.kfrlib.com for details.
+ */
+#pragma once
+
+#include "expression.hpp"
+
+namespace kfr
+{
+
+template <typename T, typename Arg>
+struct xcastto : public xwitharguments<Arg> // expression whose elements are the elements of Arg cast to T
+{
+ using xwitharguments<Arg>::xwitharguments; // reuse the base-class constructors
+};
+
+template <typename T, typename Arg, KFR_ACCEPT_EXPRESSIONS(Arg)>
+KFR_INTRINSIC xcastto<T, Arg> x_castto(Arg&& arg)
+{
+ return { std::forward<Arg>(arg) };
+}
+
+template <typename T, typename Arg, KFR_ACCEPT_EXPRESSIONS(Arg)>
+KFR_INTRINSIC xcastto<T, Arg> x_castto(Arg&& arg, ctype_t<T>)
+{
+ return { std::forward<Arg>(arg) };
+}
+
+template <typename T, typename Arg>
+struct expression_traits<xcastto<T, Arg>> : expression_traits_defaults
+{
+ using ArgTraits = expression_traits<Arg>;
+ // The element type is the cast target T; the rank is taken from the wrapped argument.
+ using value_type = T;
+ constexpr static size_t dims = ArgTraits::dims;
+ // Both shapeof overloads delegate to the traits of the wrapped argument.
+ KFR_MEM_INTRINSIC constexpr static shape<dims> shapeof(const xcastto<T, Arg>& self)
+ {
+ return ArgTraits::shapeof(self.first());
+ }
+ KFR_MEM_INTRINSIC constexpr static shape<dims> shapeof() { return ArgTraits::shapeof(); }
+};
+
+inline namespace CMT_ARCH_NAME
+{
+
+template <typename T, typename Arg, index_t NDims, size_t N>
+KFR_INTRINSIC vec<T, N> get_elements(const xcastto<T, Arg>& self, const shape<NDims>& index, csize_t<N> sh)
+{
+ return static_cast<vec<T, N>>(get_elements(self.first(), index, sh));
+}
+
+template <typename T, typename Arg, index_t NDims, size_t N>
+KFR_INTRINSIC void set_elements(const xcastto<T, Arg>& self, const shape<NDims>& index, csize_t<N> sh,
+ const identity<vec<T, N>>& value)
+{
+ set_elements(self.first(), index, sh, value);
+}
+} // namespace CMT_ARCH_NAME
+
+template <typename T, index_t Dims, typename Fn>
+struct xlambda // expression backed by a callable; Fn is L& for lvalue callables and L for rvalues
+{
+    Fn fn; // not `Fn&&`: that member would dangle once the temporary given to x_lambda is destroyed
+};
+
+template <typename T, index_t Dims = 1, typename Fn>
+KFR_INTRINSIC xlambda<T, Dims, Fn> x_lambda(Fn&& fn)
+{
+ return { std::forward<Fn>(fn) };
+}
+
+template <typename T, index_t Dims, typename Fn>
+struct expression_traits<xlambda<T, Dims, Fn>> : expression_traits_defaults
+{
+ using value_type = T;
+ constexpr static size_t dims = Dims;
+
+ KFR_MEM_INTRINSIC constexpr static shape<Dims> shapeof(const xlambda<T, Dims, Fn>& self)
+ {
+ return infinite_size;
+ }
+ KFR_MEM_INTRINSIC constexpr static shape<Dims> shapeof() { return infinite_size; }
+};
+
+inline namespace CMT_ARCH_NAME
+{
+// Reads N elements from the callable. Uses std::is_invocable_v: C++17 has no std::is_callable_v.
+template <typename T, index_t Dims, typename Fn, size_t N>
+KFR_INTRINSIC vec<T, N> get_elements(const xlambda<T, Dims, Fn>& self, const shape<Dims>& index,
+                                     csize_t<N> sh)
+{
+    if constexpr (std::is_invocable_v<Fn, shape<Dims>, csize_t<N>>) // fn(index, width)
+        return self.fn(index, sh);
+    else if constexpr (std::is_invocable_v<Fn, shape<Dims>>) // fn(index), called once per lane
+        return vec<T, N>{ [&](size_t idx) { return self.fn(index.add(idx)); } };
+    else if constexpr (std::is_invocable_v<Fn>) // fn() without any index
+        return apply<N>(self.fn);
+    else
+        return czeros; // Fn accepts none of the supported argument lists
+}
+
+} // namespace CMT_ARCH_NAME
+
+template <typename Arg>
+struct xpadded : public xwitharguments<Arg>
+{
+ using ArgTraits = typename xwitharguments<Arg>::first_arg_trait;
+ typename ArgTraits::value_type fill_value; // value returned by get_elements outside input_shape
+ shape<ArgTraits::dims> input_shape; // shape of the wrapped argument, captured at construction
+ // Forwards the argument to the base class, then records the fill value and the argument's shape.
+ KFR_MEM_INTRINSIC xpadded(Arg&& arg, typename ArgTraits::value_type fill_value)
+ : xwitharguments<Arg>{ std::forward<Arg>(arg) }, fill_value(std::move(fill_value)),
+ input_shape(ArgTraits::shapeof(this->first()))
+ {
+ }
+};
+
+template <typename Arg, typename T = expression_value_type<Arg>>
+KFR_INTRINSIC xpadded<Arg> x_padded(Arg&& arg, T fill_value = T{})
+{
+ static_assert(expression_dims<Arg> >= 1);
+ return { std::forward<Arg>(arg), std::move(fill_value) };
+}
+
+template <typename Arg>
+struct expression_traits<xpadded<Arg>> : expression_traits_defaults
+{
+ using ArgTraits = expression_traits<Arg>;
+ // Element type and rank are the same as those of the wrapped argument.
+ using value_type = typename ArgTraits::value_type;
+ constexpr static size_t dims = ArgTraits::dims;
+ // Both overloads report infinite_size, regardless of the wrapped argument's own shape.
+ KFR_MEM_INTRINSIC constexpr static shape<dims> shapeof(const xpadded<Arg>& self) { return infinite_size; }
+ KFR_MEM_INTRINSIC constexpr static shape<dims> shapeof() { return infinite_size; }
+};
+
+inline namespace CMT_ARCH_NAME
+{
+// Reads N elements of a padded expression: source data where available, fill_value elsewhere.
+template <typename Arg, size_t N, typename Traits = expression_traits<xpadded<Arg>>,
+          typename T = typename Traits::value_type>
+KFR_INTRINSIC vec<T, N> get_elements(const xpadded<Arg>& self, const shape<Traits::dims>& index,
+                                     csize_t<N> sh)
+{
+    if (index.ge(self.input_shape)) // member is input_shape (xpadded has no input_size)
+    {
+        return self.fill_value;
+    }
+    else if (CMT_LIKELY(index.add(N).le(self.input_shape))) // index + N still within input_shape
+    {
+        return get_elements(self.first(), index, sh);
+    }
+    else // mixed case: start from the fill value and load the in-range lanes one at a time
+    {
+        vec<T, N> x = self.fill_value;
+        for (size_t i = 0; i < N; i++)
+        {
+            shape ish = index.add(i);
+            if (ish.back() < self.input_shape.back())
+                x[i] = get_elements(self.first(), ish, csize_t<1>()).front();
+        }
+        return x;
+    }
+}
+
+} // namespace CMT_ARCH_NAME
+
+template <typename Arg>
+struct xreverse : public xwitharguments<Arg>
+{
+ using ArgTraits = typename xwitharguments<Arg>::first_arg_trait;
+ shape<ArgTraits::dims> input_shape; // shape of the wrapped argument, captured at construction
+ // Forwards the argument to the base class and records its shape for use by get_elements.
+ KFR_MEM_INTRINSIC xreverse(Arg&& arg)
+ : xwitharguments<Arg>{ std::forward<Arg>(arg) }, input_shape(ArgTraits::shapeof(this->first()))
+ {
+ }
+};
+
+template <typename Arg>
+KFR_INTRINSIC xreverse<Arg> x_reverse(Arg&& arg)
+{
+ static_assert(expression_dims<Arg> >= 1);
+ return { std::forward<Arg>(arg) };
+}
+
+template <typename Arg>
+struct expression_traits<xreverse<Arg>> : expression_traits_defaults // traits of the reversed view
+{
+    using ArgTraits = expression_traits<Arg>;
+    // Element type and rank are inherited unchanged from the wrapped argument.
+    using value_type = typename ArgTraits::value_type;
+    constexpr static size_t dims = ArgTraits::dims;
+    // ArgTraits::shapeof expects the wrapped argument (as in the xcastto traits), not the wrapper.
+    KFR_MEM_INTRINSIC constexpr static shape<dims> shapeof(const xreverse<Arg>& self)
+    {
+        return ArgTraits::shapeof(self.first());
+    }
+    KFR_MEM_INTRINSIC constexpr static shape<dims> shapeof() { return ArgTraits::shapeof(); }
+};
+
+inline namespace CMT_ARCH_NAME
+{
+// Reads N elements of the wrapped argument at `input_shape - index - N` and returns them reversed.
+template <typename Arg, size_t N, typename Traits = expression_traits<xreverse<Arg>>,
+ typename T = typename Traits::value_type>
+KFR_INTRINSIC vec<T, N> get_elements(const xreverse<Arg>& self, const shape<Traits::dims>& index,
+ csize_t<N> sh)
+{ // NOTE(review): this commit renames shape::operator- to sub_inf; confirm `-` below still resolves
+ return reverse(get_elements(self.first(), self.input_shape - index - N, sh));
+}
+
+} // namespace CMT_ARCH_NAME
+
+} // namespace kfr
diff --git a/include/kfr/base/shape.hpp b/include/kfr/base/shape.hpp
@@ -27,10 +27,10 @@
#include "impl/static_array.hpp"
+#include "../math/min_max.hpp"
#include "../simd/shuffle.hpp"
#include "../simd/types.hpp"
#include "../simd/vec.hpp"
-#include "../math/min_max.hpp"
#include <bitset>
@@ -85,14 +85,47 @@ struct shape : static_array_base<index_t, csizeseq_t<dims>>
static_assert(dims < maximum_dims);
- shape operator+(const shape& other) const
+ bool ge(const shape& other) const
+ {
+ if constexpr (dims == 1)
+ {
+ return front() >= other.front();
+ }
+ else
+ {
+ return all(**this >= *other);
+ }
+ }
+
+ bool le(const shape& other) const
+ {
+ if constexpr (dims == 1)
+ {
+ return front() <= other.front();
+ }
+ else
+ {
+ return all(**this <= *other);
+ }
+ }
+
+ shape add(index_t value) const
+ {
+ shape result = *this;
+ result.back() += value;
+ return result;
+ }
+ shape add(const shape& other) const { return **this + *other; }
+ shape sub(const shape& other) const { return **this - *other; }
+
+ shape add_inf(const shape& other) const
{
vec<index_t, dims> x = **this;
vec<index_t, dims> y = *other;
mask<index_t, dims> inf = (x == infinite_size) || (y == infinite_size);
return select(inf, infinite_size, x + y);
}
- shape operator-(const shape& other) const
+ shape sub_inf(const shape& other) const
{
vec<index_t, dims> x = **this;
vec<index_t, dims> y = *other;
diff --git a/include/kfr/base/tensor.hpp b/include/kfr/base/tensor.hpp
@@ -740,32 +740,32 @@ struct expression_traits<tensor<T, Dims>> : expression_traits_defaults
inline namespace CMT_ARCH_NAME
{
-template <typename T, index_t NDims, typename U, size_t N>
-KFR_INTRINSIC vec<U, N> get_elements(const tensor<T, NDims>& self, const shape<NDims>& index, vec_shape<U, N>)
+template <typename T, index_t NDims, size_t N>
+KFR_INTRINSIC vec<T, N> get_elements(const tensor<T, NDims>& self, const shape<NDims>& index, csize_t<N>)
{
const T* data = self.data() + self.calc_index(index);
if (self.is_last_contiguous())
{
- return static_cast<vec<U, N>>(read<N>(data));
+ return read<N>(data);
}
else
{
- return static_cast<vec<U, N>>(gather_stride<N>(data, self.strides().back()));
+ return gather_stride<N>(data, self.strides().back());
}
}
-template <typename T, index_t NDims, typename U, size_t N>
-KFR_INTRINSIC void set_elements(const tensor<T, NDims>& self, const shape<NDims>& index,
- const vec<U, N>& value)
+template <typename T, index_t NDims, size_t N>
+KFR_INTRINSIC void set_elements(const tensor<T, NDims>& self, const shape<NDims>& index, csize_t<N>,
+ const identity<vec<T, N>>& value)
{
T* data = self.data() + self.calc_index(index);
if (self.is_last_contiguous())
{
- write(data, vec<T, N>(value));
+ write(data, value);
}
else
{
- scatter_stride(data, vec<T, N>(value), self.strides().back());
+ scatter_stride(data, value, self.strides().back());
}
}
@@ -819,4 +819,4 @@ struct representation<kfr::shape<dims>>
};
} // namespace cometa
-CMT_PRAGMA_MSVC(warning(pop))
-\ No newline at end of file
+CMT_PRAGMA_MSVC(warning(pop))
diff --git a/tests/unit/base/tensor.cpp b/tests/unit/base/tensor.cpp
@@ -5,6 +5,7 @@
*/
#include <kfr/base.hpp>
+#include <kfr/base/new_expressions.hpp>
#include <kfr/base/tensor.hpp>
#include <kfr/io/tostring.hpp>
@@ -35,10 +36,14 @@ TEST(shape)
CHECK(internal_generic::strides_for_shape(shape{ 2, 3, 4 }, 10) == shape{ 120, 40, 10 });
- CHECK(increment_indices_return(shape{ 0, 0, 0 }, shape{ 0, 0, 0 }, shape{ 2, 3, 4 }) == shape{ 0, 0, 1 });
- CHECK(increment_indices_return(shape{ 0, 0, 3 }, shape{ 0, 0, 0 }, shape{ 2, 3, 4 }) == shape{ 0, 1, 0 });
- CHECK(increment_indices_return(shape{ 0, 2, 0 }, shape{ 0, 0, 0 }, shape{ 2, 3, 4 }) == shape{ 0, 2, 1 });
- CHECK(increment_indices_return(shape{ 0, 2, 3 }, shape{ 0, 0, 0 }, shape{ 2, 3, 4 }) == shape{ 1, 0, 0 });
+ CHECK(increment_indices_return(shape{ 0, 0, 0 }, shape{ 0, 0, 0 }, shape{ 2, 3, 4 }) ==
+ shape{ 0, 0, 1 });
+ CHECK(increment_indices_return(shape{ 0, 0, 3 }, shape{ 0, 0, 0 }, shape{ 2, 3, 4 }) ==
+ shape{ 0, 1, 0 });
+ CHECK(increment_indices_return(shape{ 0, 2, 0 }, shape{ 0, 0, 0 }, shape{ 2, 3, 4 }) ==
+ shape{ 0, 2, 1 });
+ CHECK(increment_indices_return(shape{ 0, 2, 3 }, shape{ 0, 0, 0 }, shape{ 2, 3, 4 }) ==
+ shape{ 1, 0, 0 });
CHECK(increment_indices_return(shape{ 1, 2, 3 }, shape{ 0, 0, 0 }, shape{ 2, 3, 4 }) ==
shape{ null_index, null_index, null_index });
@@ -218,7 +223,7 @@ TEST(tensor_broadcast)
tensor<float, 2> t2{ shape{ 5, 1 }, { 10.f, 20.f, 30.f, 40.f, 50.f } };
tensor<float, 1> t4{ shape{ 5 }, { 1.f, 2.f, 3.f, 4.f, 5.f } };
tensor<float, 2> tresult{ shape{ 5, 5 }, { 11, 12, 13, 14, 15, 21, 22, 23, 24, 25, 31, 32, 33,
- 34, 35, 41, 42, 43, 44, 45, 51, 52, 53, 54, 55 } };
+ 34, 35, 41, 42, 43, 44, 45, 51, 52, 53, 54, 55 } };
tensor<float, 2> t3 = tapply(t1, t2, fn::add{});
@@ -279,52 +284,52 @@ struct expression_traits<std::array<std::array<T, N1>, N2>> : expression_traits_
inline namespace CMT_ARCH_NAME
{
-template <typename T, typename U, size_t N>
-KFR_INTRINSIC vec<U, N> get_elements(const tcounter<T, 1>& self, const shape<1>& index, vec_shape<U, N> sh)
+template <typename T, size_t N>
+KFR_INTRINSIC vec<T, N> get_elements(const tcounter<T, 1>& self, const shape<1>& index, csize_t<N> sh)
{
T acc = self.start;
acc += static_cast<T>(index.front()) * self.steps.front();
- return static_cast<vec<U, N>>(acc + enumerate(vec_shape<T, N>(), self.steps.back()));
+ return acc + enumerate(vec_shape<T, N>(), self.steps.back());
}
-template <typename T, index_t dims, typename U, size_t N>
-KFR_INTRINSIC vec<U, N> get_elements(const tcounter<T, dims>& self, const shape<dims>& index,
- vec_shape<U, N> sh)
+template <typename T, index_t dims, size_t N>
+KFR_INTRINSIC vec<T, N> get_elements(const tcounter<T, dims>& self, const shape<dims>& index, csize_t<N> sh)
{
T acc = self.start;
vec<T, dims> tindices = cast<T>(*index);
cfor(csize<0>, csize<dims>, [&](auto i) CMT_INLINE_LAMBDA { acc += tindices[i] * self.steps[i]; });
- return static_cast<vec<U, N>>(acc + enumerate(vec_shape<T, N>(), self.steps.back()));
+ return acc + enumerate(vec_shape<T, N>(), self.steps.back());
}
-template <typename T, size_t N1, typename U, size_t N>
-KFR_INTRINSIC vec<U, N> get_elements(const std::array<T, N1>& CMT_RESTRICT self, const shape<1>& index,
- vec_shape<U, N> sh)
+template <typename T, size_t N1, size_t N>
+KFR_INTRINSIC vec<T, N> get_elements(const std::array<T, N1>& CMT_RESTRICT self, const shape<1>& index,
+ csize_t<N> sh)
{
const T* CMT_RESTRICT const data = self.data();
return read<N>(data + std::min(index[0], static_cast<index_t>(N1 - 1)));
}
-template <typename T, size_t N1, typename U, size_t N>
-KFR_INTRINSIC void set_elements(std::array<T, N1>& CMT_RESTRICT self, const shape<1>& index, vec<U, N> val)
+template <typename T, size_t N1, size_t N>
+KFR_INTRINSIC void set_elements(std::array<T, N1>& CMT_RESTRICT self, const shape<1>& index, csize_t<N>,
+ const identity<vec<T, N>>& val)
{
T* CMT_RESTRICT const data = self.data();
- write(data + std::min(index[0], static_cast<index_t>(N1 - 1)), static_cast<vec<T, N>>(val));
+ write(data + std::min(index[0], static_cast<index_t>(N1 - 1)), val);
}
-template <typename T, size_t N1, size_t N2, typename U, size_t N>
-KFR_INTRINSIC vec<U, N> get_elements(const std::array<std::array<T, N1>, N2>& CMT_RESTRICT self,
- const shape<2>& index, vec_shape<U, N> sh)
+template <typename T, size_t N1, size_t N2, size_t N>
+KFR_INTRINSIC vec<T, N> get_elements(const std::array<std::array<T, N1>, N2>& CMT_RESTRICT self,
+ const shape<2>& index, csize_t<N> sh)
{
const T* CMT_RESTRICT const data = self[std::min(index[0], static_cast<index_t>(N2 - 1))].data();
return read<N>(data + std::min(index[1], static_cast<index_t>(N1 - 1)));
}
-template <typename T, size_t N1, size_t N2, typename U, size_t N>
+template <typename T, size_t N1, size_t N2, size_t N>
KFR_INTRINSIC void set_elements(std::array<std::array<T, N1>, N2>& CMT_RESTRICT self, const shape<2>& index,
- vec<U, N> val)
+ csize_t<N>, const identity<vec<T, N>>& val)
{
T* CMT_RESTRICT const data = self[std::min(index[0], static_cast<index_t>(N2 - 1))].data();
- write(data + std::min(index[1], static_cast<index_t>(N1 - 1)), static_cast<vec<T, N>>(val));
+ write(data + std::min(index[1], static_cast<index_t>(N1 - 1)), val);
}
TEST(tensor_expressions2)
@@ -332,12 +337,12 @@ TEST(tensor_expressions2)
auto aa = std::array<std::array<double, 2>, 2>{ { { { 1, 2 } }, { { 3, 4 } } } };
static_assert(expression_traits<decltype(aa)>::dims == 2);
CHECK(expression_traits<decltype(aa)>::shapeof(aa) == shape{ 2, 2 });
- CHECK(get_elements(aa, { 1, 1 }, vec_shape<float, 1>{}) == vec{ 4.f });
- CHECK(get_elements(aa, { 1, 0 }, vec_shape<float, 2>{}) == vec{ 3.f, 4.f });
+ CHECK(get_elements(aa, { 1, 1 }, csize_t<1>{}) == vec{ 4. });
+ CHECK(get_elements(aa, { 1, 0 }, csize_t<2>{}) == vec{ 3., 4. });
static_assert(expression_traits<decltype(1234.f)>::dims == 0);
CHECK(expression_traits<decltype(1234.f)>::shapeof(1234.f) == shape{});
- CHECK(get_elements(1234.f, {}, vec_shape<float, 3>{}) == vec{ 1234.f, 1234.f, 1234.f });
+ CHECK(get_elements(1234.f, {}, csize_t<3>{}) == vec{ 1234.f, 1234.f, 1234.f });
tprocess(aa, 123.45f);
@@ -504,8 +509,8 @@ extern "C" __declspec(dllexport) void assembly_test11(f64x2& x, u64x2 y) { x = y
extern "C" __declspec(dllexport) void assembly_test12(
std::array<std::array<uint32_t, 4>, 4>& x,
- const xfunction<std::plus<>, std::array<std::array<uint32_t, 1>, 4>&, std::array<std::array<uint32_t, 4>, 1>&>&
- y)
+ const xfunction<std::plus<>, std::array<std::array<uint32_t, 1>, 4>&,
+ std::array<std::array<uint32_t, 4>, 1>&>& y)
{
// [[maybe_unused]] constexpr auto sh1 = expression_traits<decltype(x)>::shapeof();
// [[maybe_unused]] constexpr auto sh2 = expression_traits<decltype(y)>::shapeof();
@@ -567,4 +572,4 @@ TEST(enumerate)
} // namespace CMT_ARCH_NAME
} // namespace kfr
-CMT_PRAGMA_MSVC(warning(pop))
-\ No newline at end of file
+CMT_PRAGMA_MSVC(warning(pop))