commit ece6b29cd94d43babd6faed513fb72a9a0a7b6ef
parent 27610ee7e598cf72334893ae88a36fcb19417497
Author: d.levin256@gmail.com <d.levin256@gmail.com>
Date: Wed, 31 Aug 2022 17:04:52 +0100
Ability to select axis
Diffstat:
5 files changed, 664 insertions(+), 172 deletions(-)
diff --git a/include/kfr/base/expression.hpp b/include/kfr/base/expression.hpp
@@ -125,13 +125,15 @@ struct expression_traits<T, std::enable_if_t<is_simd_type<T>>> : expression_trai
inline namespace CMT_ARCH_NAME
{
-template <typename T, size_t N, KFR_ENABLE_IF(is_simd_type<std::decay_t<T>>)>
-KFR_INTRINSIC vec<std::decay_t<T>, N> get_elements(T&& self, const shape<0>& index, csize_t<N> sh)
+template <typename T, index_t Axis, size_t N, KFR_ENABLE_IF(is_simd_type<std::decay_t<T>>)>
+KFR_INTRINSIC vec<std::decay_t<T>, N> get_elements(T&& self, const shape<0>& index,
+ const axis_params<Axis, N>&)
{
return self;
}
-template <typename T, size_t N, KFR_ENABLE_IF(is_simd_type<std::decay_t<T>>)>
-KFR_INTRINSIC void set_elements(T& self, const shape<0>& index, csize_t<N> sh, const identity<vec<T, N>>& val)
+template <typename T, index_t Axis, size_t N, KFR_ENABLE_IF(is_simd_type<std::decay_t<T>>)>
+KFR_INTRINSIC void set_elements(T& self, const shape<0>& index, const axis_params<Axis, N>&,
+ const identity<vec<T, N>>& val)
{
static_assert(N == 1);
static_assert(!std::is_const_v<T>);
@@ -142,44 +144,69 @@ KFR_INTRINSIC void set_elements(T& self, const shape<0>& index, csize_t<N> sh, c
inline namespace CMT_ARCH_NAME
{
-template <typename Out, typename In, size_t w, size_t gw, typename Tin, index_t outdims, index_t indims>
+template <typename Out, typename In, index_t OutAxis, size_t w, size_t gw, typename Tin, index_t outdims,
+ index_t indims>
KFR_INTRINSIC static void tprocess_body(Out&& out, In&& in, size_t start, size_t stop, size_t insize,
shape<outdims> outidx, shape<indims> inidx)
{
- size_t x = start;
- if constexpr (w > gw)
+ if constexpr (indims == 0)
{
- csize_t<w> wval;
+ size_t x = start;
+ const vec<Tin, 1> val = get_elements(in, inidx, axis_params_v<0, 1>);
+ if constexpr (w > gw)
+ {
+ CMT_LOOP_NOUNROLL
+ for (; x < stop / w * w; x += w)
+ {
+ outidx[OutAxis] = x;
+ set_elements(out, outidx, axis_params_v<OutAxis, w>, repeat<w>(val));
+ }
+ }
CMT_LOOP_NOUNROLL
- for (; x < stop / w * w; x += w)
+ for (; x < stop / gw * gw; x += gw)
{
- outidx.set_revindex(0, x);
- inidx.set_revindex(0, std::min(x, insize - 1));
- set_elements(out, outidx, wval, get_elements(in, inidx, wval));
+ outidx[OutAxis] = x;
+ set_elements(out, outidx, axis_params_v<OutAxis, gw>, repeat<gw>(val));
}
}
- csize_t<gw> gwval;
- CMT_LOOP_NOUNROLL
- for (; x < stop / gw * gw; x += gw)
+ else
{
- outidx.set_revindex(0, x);
- inidx.set_revindex(0, std::min(x, insize - 1));
- set_elements(out, outidx, gwval, get_elements(in, inidx, gwval));
+ constexpr index_t InAxis = OutAxis + indims - outdims;
+ size_t x = start;
+ if constexpr (w > gw)
+ {
+ CMT_LOOP_NOUNROLL
+ for (; x < stop / w * w; x += w)
+ {
+ outidx[OutAxis] = x;
+ inidx[InAxis] = std::min(x, insize - 1);
+ set_elements(out, outidx, axis_params_v<OutAxis, w>,
+ get_elements(in, inidx, axis_params_v<InAxis, w>));
+ }
+ }
+ CMT_LOOP_NOUNROLL
+ for (; x < stop / gw * gw; x += gw)
+ {
+ outidx[OutAxis] = x;
+ inidx[InAxis] = std::min(x, insize - 1);
+ set_elements(out, outidx, axis_params_v<OutAxis, gw>,
+ get_elements(in, inidx, axis_params_v<InAxis, gw>));
+ }
}
}
-template <size_t width = 0, typename Out, typename In, size_t gw = 1,
+template <size_t width = 0, index_t Axis = 0, typename Out, typename In, size_t gw = 1,
CMT_ENABLE_IF(expression_traits<Out>::dims == 0)>
static auto tprocess(Out&& out, In&& in, shape<0> = {}, shape<0> = {}, csize_t<gw> = {}) -> shape<0>
{
- set_elements(out, shape<0>{}, csize_t<1>(), get_elements(in, shape<0>{}, csize_t<1>()));
+ set_elements(out, shape<0>{}, axis_params_v<0, 1>, get_elements(in, shape<0>{}, axis_params_v<0, 1>));
return {};
}
namespace internal
{
-constexpr size_t select_process_width(size_t width, size_t vec_width, index_t last_dim_size)
+constexpr KFR_INTRINSIC size_t select_process_width(size_t width, size_t vec_width, index_t last_dim_size)
{
if (width != 0)
return width;
@@ -188,14 +215,46 @@ constexpr size_t select_process_width(size_t width, size_t vec_width, index_t la
return std::min(vec_width, last_dim_size);
}
+
+constexpr KFR_INTRINSIC index_t select_axis(index_t ndims, index_t axis)
+{
+ if (axis >= ndims)
+ return ndims - 1;
+ return axis;
+}
+
+template <index_t VecAxis, index_t LoopAxis, index_t outdims>
+KFR_INTRINSIC index_t axis_start(const shape<outdims>& sh)
+{
+ static_assert(VecAxis < outdims);
+ static_assert(LoopAxis < outdims);
+ if constexpr (VecAxis == LoopAxis)
+ return 0;
+ else
+ return sh[LoopAxis];
+}
+template <index_t VecAxis, index_t LoopAxis, index_t outdims>
+KFR_INTRINSIC index_t axis_stop(const shape<outdims>& sh)
+{
+ static_assert(VecAxis < outdims);
+ static_assert(LoopAxis < outdims);
+ if constexpr (VecAxis == LoopAxis)
+ return 1;
+ else
+ return sh[LoopAxis];
+}
+
} // namespace internal
-template <size_t width = 0, typename Out, typename In, size_t gw = 1,
+template <size_t width = 0, index_t Axis = infinite_size, typename Out, typename In, size_t gw = 1,
typename Tin = expression_value_type<In>, typename Tout = expression_value_type<Out>,
index_t outdims = expression_dims<Out>, CMT_ENABLE_IF(expression_dims<Out> > 0)>
static auto tprocess(Out&& out, In&& in, shape<outdims> start = 0, shape<outdims> size = infinite_size,
csize_t<gw> = {}) -> shape<outdims>
{
+ using internal::axis_start;
+ using internal::axis_stop;
+
constexpr index_t indims = expression_dims<In>;
static_assert(outdims >= indims);
@@ -209,55 +268,71 @@ static auto tprocess(Out&& out, In&& in, shape<outdims> start = 0, shape<outdims
constexpr size_t w = internal::select_process_width(width, vec_width, last_dim_size);
+ constexpr index_t out_axis = internal::select_axis(outdims, Axis);
+ constexpr index_t in_axis = out_axis + indims - outdims;
+
const shape<outdims> outshape = shapeof(out);
const shape<indims> inshape = shapeof(in);
if (CMT_UNLIKELY(!internal_generic::can_assign_from(outshape, inshape)))
return { 0 };
shape<outdims> stop = min(start.add_inf(size), outshape);
+ index_t in_size = 0;
+ if constexpr (indims > 0)
+ in_size = inshape[in_axis];
+
shape<outdims> outidx;
if constexpr (outdims == 1)
{
outidx = shape<outdims>{ 0 };
- tprocess_body<Out, In, w, gw, Tin, outdims, indims>(
- std::forward<Out>(out), std::forward<In>(in), start.revindex(0), stop.revindex(0),
- inshape.revindex(0), outidx, inshape.adapt(outidx));
+ tprocess_body<Out, In, out_axis, w, gw, Tin, outdims, indims>(
+ std::forward<Out>(out), std::forward<In>(in), start[out_axis], stop[out_axis], in_size, outidx,
+ inshape.adapt(outidx));
}
else if constexpr (outdims == 2)
{
- for (index_t x = start.revindex(1); x < stop.revindex(1); ++x)
+ for (index_t i0 = axis_start<out_axis, 0>(start); i0 < axis_stop<out_axis, 0>(stop); ++i0)
{
- outidx = shape<outdims>{ x, 0 };
- tprocess_body<Out, In, w, gw, Tin, outdims, indims>(
- std::forward<Out>(out), std::forward<In>(in), start.revindex(0), stop.revindex(0),
- inshape.revindex(0), outidx, inshape.adapt(outidx));
+ for (index_t i1 = axis_start<out_axis, 1>(start); i1 < axis_stop<out_axis, 1>(stop); ++i1)
+ {
+ outidx = shape<outdims>{ i0, i1 };
+ tprocess_body<Out, In, out_axis, w, gw, Tin, outdims, indims>(
+ std::forward<Out>(out), std::forward<In>(in), start[out_axis], stop[out_axis], in_size,
+ outidx, inshape.adapt(outidx));
+ }
}
}
else if constexpr (outdims == 3)
{
- for (index_t x = start.revindex(2); x < stop.revindex(2); ++x)
+ for (index_t i0 = axis_start<out_axis, 0>(start); i0 < axis_stop<out_axis, 0>(stop); ++i0)
{
- for (index_t y = start.revindex(1); y < stop.revindex(1); ++y)
+ for (index_t i1 = axis_start<out_axis, 1>(start); i1 < axis_stop<out_axis, 1>(stop); ++i1)
{
- outidx = shape<outdims>{ x, y, 0 };
- tprocess_body<Out, In, w, gw, Tin, outdims, indims>(
- std::forward<Out>(out), std::forward<In>(in), start.revindex(0), stop.revindex(0),
- inshape.revindex(0), outidx, inshape.adapt(outidx));
+ for (index_t i2 = axis_start<out_axis, 2>(start); i2 < axis_stop<out_axis, 2>(stop); ++i2)
+ {
+ outidx = shape<outdims>{ i0, i1, i2 };
+ tprocess_body<Out, In, out_axis, w, gw, Tin, outdims, indims>(
+ std::forward<Out>(out), std::forward<In>(in), start[out_axis], stop[out_axis],
+ in_size, outidx, inshape.adapt(outidx));
+ }
}
}
}
else if constexpr (outdims == 4)
{
- for (index_t x = start.revindex(3); x < stop.revindex(3); ++x)
+ for (index_t i0 = axis_start<out_axis, 0>(start); i0 < axis_stop<out_axis, 0>(stop); ++i0)
{
- for (index_t y = start.revindex(2); y < stop.revindex(2); ++y)
+ for (index_t i1 = axis_start<out_axis, 1>(start); i1 < axis_stop<out_axis, 1>(stop); ++i1)
{
- for (index_t z = start.revindex(1); z < stop.revindex(1); ++z)
+ for (index_t i2 = axis_start<out_axis, 2>(start); i2 < axis_stop<out_axis, 2>(stop); ++i2)
{
- outidx = shape<outdims>{ x, y, z, 0 };
- tprocess_body<Out, In, w, gw, Tin, outdims, indims>(
- std::forward<Out>(out), std::forward<In>(in), start.revindex(0), stop.revindex(0),
- inshape.revindex(0), outidx, inshape.adapt(outidx));
+ for (index_t i3 = axis_start<out_axis, 3>(start); i3 < axis_stop<out_axis, 3>(stop); ++i3)
+ {
+ outidx = shape<outdims>{ i0, i1, i2, i3 };
+ tprocess_body<Out, In, out_axis, w, gw, Tin, outdims, indims>(
+ std::forward<Out>(out), std::forward<In>(in), start[out_axis], stop[out_axis],
+ in_size, outidx, inshape.adapt(outidx));
+ }
}
}
}
@@ -265,14 +340,15 @@ static auto tprocess(Out&& out, In&& in, shape<outdims> start = 0, shape<outdims
else
{
shape<outdims> outidx = start;
- if (CMT_UNLIKELY(!internal_generic::compare_indices(outidx, stop, outdims - 2)))
+ if (CMT_UNLIKELY(!internal_generic::compare_indices(outidx, stop)))
return stop;
do
{
- tprocess_body<Out, In, w, gw, Tin, outdims, indims>(
- std::forward<Out>(out), std::forward<In>(in), start.revindex(0), stop.revindex(0),
- inshape.revindex(0), outidx, inshape.adapt(outidx));
- } while (internal_generic::increment_indices(outidx, start, stop, outdims - 2));
+ tprocess_body<Out, In, out_axis, w, gw, Tin, outdims, indims>(
+ std::forward<Out>(out), std::forward<In>(in), start[out_axis], stop[out_axis], in_size,
+ outidx, inshape.adapt(outidx));
+ outidx[out_axis] = stop[out_axis] - 1;
+ } while (internal_generic::increment_indices(outidx, start, stop));
}
return stop;
}
@@ -477,7 +553,7 @@ struct expression_traits<T, std::enable_if_t<std::is_base_of_v<input_expression,
inline namespace CMT_ARCH_NAME
{
template <typename E, size_t N, KFR_ENABLE_IF(is_input_expression<E>), typename T = value_type_of<E>>
-KFR_MEM_INTRINSIC vec<T, N> get_elements(E&& self, const shape<1>& index, csize_t<N> sh)
+KFR_MEM_INTRINSIC vec<T, N> get_elements(E&& self, const shape<1>& index, const axis_params<0, N>& sh)
{
return get_elements(self, cinput_t{}, index[0], vec_shape<T, N>{});
}
@@ -511,7 +587,7 @@ struct xwitharguments
{
static_assert(idx < count);
using Traits = expression_traits<nth<idx>>;
- if constexpr (Traits::dims == 0)
+ if constexpr (sizeof...(Args) <= 1 || Traits::dims == 0)
{
return -1;
}
@@ -564,16 +640,63 @@ private:
}
};
-template <typename... Args>
-xwitharguments(Args&&... args) -> xwitharguments<Args...>;
-
-template <index_t Dims, typename Arg>
-struct xreshape : public xwitharguments<Arg>
+template <typename Arg>
+struct xwitharguments<Arg>
{
- shape<Dims> old_shape;
- shape<Dims> new_shape;
+ constexpr static size_t count = 1;
+
+ using type_list = ctypes_t<Arg>;
+
+ template <size_t idx>
+ using nth = Arg;
+
+ using first_arg = Arg;
+
+ template <size_t idx>
+ using nth_trait = expression_traits<Arg>;
+
+ using first_arg_trait = expression_traits<first_arg>;
+
+ std::tuple<Arg> args;
+
+ KFR_INTRINSIC auto& first() { return std::get<0>(args); }
+ KFR_INTRINSIC const auto& first() const { return std::get<0>(args); }
+
+ template <size_t idx>
+ KFR_INTRINSIC dimset getmask(csize_t<idx> = {}) const
+ {
+ return -1;
+ }
+
+ template <typename Fn>
+ KFR_INTRINSIC constexpr auto fold(Fn&& fn) const
+ {
+ return fold_impl(std::forward<Fn>(fn), csizeseq<count>);
+ }
+ template <typename Fn>
+ KFR_INTRINSIC constexpr static auto fold_idx(Fn&& fn)
+ {
+ return fold_idx_impl(std::forward<Fn>(fn), csizeseq<count>);
+ }
+
+ KFR_INTRINSIC xwitharguments(Arg&& arg) : args{ std::forward<Arg>(arg) } {}
+
+private:
+ template <typename Fn, size_t... indices>
+ KFR_INTRINSIC constexpr auto fold_impl(Fn&& fn, csizes_t<indices...>) const
+ {
+ return fn(std::get<indices>(args)...);
+ }
+ template <typename Fn, size_t... indices>
+ KFR_INTRINSIC constexpr static auto fold_idx_impl(Fn&& fn, csizes_t<indices...>)
+ {
+ return fn(csize<indices>...);
+ }
};
+template <typename... Args>
+xwitharguments(Args&&... args) -> xwitharguments<Args...>;
+
template <typename Fn, typename... Args>
struct xfunction : public xwitharguments<Args...>
{
@@ -617,29 +740,20 @@ struct expression_traits<xfunction<Fn, Args...>> : expression_traits_defaults
}
};
-template <index_t Dims, typename Arg>
-struct expression_traits<xreshape<Dims, Arg>> : expression_traits_defaults
-{
- using value_type = typename expression_traits<Arg>::value_type;
- constexpr static size_t dims = Dims;
-
- constexpr static shape<dims> shapeof(const xreshape<Dims, Arg>& self) { return self.new_shape; }
-};
-
inline namespace CMT_ARCH_NAME
{
namespace internal
{
-template <index_t outdims, typename Fn, typename... Args, size_t N, index_t Dims, size_t idx,
+template <index_t outdims, typename Fn, typename... Args, index_t Axis, size_t N, index_t Dims, size_t idx,
typename Traits = expression_traits<typename xfunction<Fn, Args...>::template nth<idx>>>
KFR_MEM_INTRINSIC vec<typename Traits::value_type, N> get_arg(const xfunction<Fn, Args...>& self,
- const shape<Dims>& index, csize_t<N> sh,
- csize_t<idx>)
+ const shape<Dims>& index,
+ const axis_params<Axis, N>& sh, csize_t<idx>)
{
if constexpr (Traits::dims == 0)
{
- return repeat<N>(get_elements(std::get<idx>(self.args), {}, csize_t<1>{}));
+ return repeat<N>(get_elements(std::get<idx>(self.args), {}, axis_params<Axis, 1>{}));
}
else
{
@@ -648,14 +762,14 @@ KFR_MEM_INTRINSIC vec<typename Traits::value_type, N> get_arg(const xfunction<Fn
if constexpr (last_dim > 0)
{
return repeat<N / std::min(last_dim, N)>(
- get_elements(std::get<idx>(self.args), indices, csize_t<std::min(last_dim, N)>{}));
+ get_elements(std::get<idx>(self.args), indices, axis_params<Axis, std::min(last_dim, N)>{}));
}
else
{
- if constexpr (N > 1)
+ if constexpr (sizeof...(Args) > 1 && N > 1)
{
if (CMT_UNLIKELY(self.masks[idx].back() == 0))
- return get_elements(std::get<idx>(self.args), indices, csize_t<1>{}).front();
+ return get_elements(std::get<idx>(self.args), indices, axis_params<Axis, 1>{}).front();
else
return get_elements(std::get<idx>(self.args), indices, sh);
}
@@ -668,10 +782,10 @@ KFR_MEM_INTRINSIC vec<typename Traits::value_type, N> get_arg(const xfunction<Fn
}
} // namespace internal
-template <typename Fn, typename... Args, size_t N, index_t Dims,
+template <typename Fn, typename... Args, index_t Axis, size_t N, index_t Dims,
typename Tr = expression_traits<xfunction<Fn, Args...>>, typename T = typename Tr::value_type>
KFR_INTRINSIC vec<T, N> get_elements(const xfunction<Fn, Args...>& self, const shape<Dims>& index,
- csize_t<N> sh)
+ const axis_params<Axis, N>& sh)
{
constexpr index_t outdims = Tr::dims;
return self.fold_idx(
diff --git a/include/kfr/base/new_expressions.hpp b/include/kfr/base/new_expressions.hpp
@@ -30,6 +30,8 @@
namespace kfr
{
+// ----------------------------------------------------------------------------
+
template <typename T, typename Arg>
struct xcastto : public xwitharguments<Arg>
{
@@ -66,20 +68,23 @@ struct expression_traits<xcastto<T, Arg>> : expression_traits_defaults
inline namespace CMT_ARCH_NAME
{
-template <typename T, typename Arg, index_t NDims, size_t N>
-KFR_INTRINSIC vec<T, N> get_elements(const xcastto<T, Arg>& self, const shape<NDims>& index, csize_t<N> sh)
+template <typename T, typename Arg, index_t NDims, index_t Axis, size_t N>
+KFR_INTRINSIC vec<T, N> get_elements(const xcastto<T, Arg>& self, const shape<NDims>& index,
+ const axis_params<Axis, N>& sh)
{
return static_cast<vec<T, N>>(get_elements(self.first(), index, sh));
}
-template <typename T, typename Arg, index_t NDims, size_t N>
-KFR_INTRINSIC void set_elements(const xcastto<T, Arg>& self, const shape<NDims>& index, csize_t<N> sh,
- const identity<vec<T, N>>& value)
+template <typename T, typename Arg, index_t NDims, index_t Axis, size_t N>
+KFR_INTRINSIC void set_elements(const xcastto<T, Arg>& self, const shape<NDims>& index,
+ const axis_params<Axis, N>& sh, const identity<vec<T, N>>& value)
{
set_elements(self.first(), index, sh, value);
}
} // namespace CMT_ARCH_NAME
+// ----------------------------------------------------------------------------
+
template <typename T, index_t Dims, typename Fn>
struct xlambda
{
@@ -108,9 +113,9 @@ struct expression_traits<xlambda<T, Dims, Fn>> : expression_traits_defaults
inline namespace CMT_ARCH_NAME
{
-template <typename T, index_t Dims, typename Fn, size_t N>
+template <typename T, index_t Dims, typename Fn, index_t Axis, size_t N>
KFR_INTRINSIC vec<T, N> get_elements(const xlambda<T, Dims, Fn>& self, const shape<Dims>& index,
- csize_t<N> sh)
+ const axis_params<Axis, N>& sh)
{
if constexpr (std::is_callable_v<Fn, shape<Dims>, csize_t<N>>)
return self.fn(index, sh);
@@ -124,6 +129,8 @@ KFR_INTRINSIC vec<T, N> get_elements(const xlambda<T, Dims, Fn>& self, const sha
} // namespace CMT_ARCH_NAME
+// ----------------------------------------------------------------------------
+
template <typename Arg>
struct xpadded : public xwitharguments<Arg>
{
@@ -160,10 +167,10 @@ struct expression_traits<xpadded<Arg>> : expression_traits_defaults
inline namespace CMT_ARCH_NAME
{
-template <typename Arg, size_t N, typename Traits = expression_traits<xpadded<Arg>>,
+template <typename Arg, index_t Axis, size_t N, typename Traits = expression_traits<xpadded<Arg>>,
typename T = typename Traits::value_type>
KFR_INTRINSIC vec<T, N> get_elements(const xpadded<Arg>& self, const shape<Traits::dims>& index,
- csize_t<N> sh)
+ const axis_params<Axis, N>& sh)
{
if (index.ge(self.input_size))
{
@@ -188,6 +195,8 @@ KFR_INTRINSIC vec<T, N> get_elements(const xpadded<Arg>& self, const shape<Trait
} // namespace CMT_ARCH_NAME
+// ----------------------------------------------------------------------------
+
template <typename Arg>
struct xreverse : public xwitharguments<Arg>
{
@@ -217,7 +226,7 @@ struct expression_traits<xreverse<Arg>> : expression_traits_defaults
KFR_MEM_INTRINSIC constexpr static shape<dims> shapeof(const xreverse<Arg>& self)
{
- return ArgTraits::shapeof(self);
+ return ArgTraits::shapeof(self.first());
}
KFR_MEM_INTRINSIC constexpr static shape<dims> shapeof() { return ArgTraits::shapeof(); }
};
@@ -225,12 +234,220 @@ struct expression_traits<xreverse<Arg>> : expression_traits_defaults
inline namespace CMT_ARCH_NAME
{
-template <typename Arg, size_t N, typename Traits = expression_traits<xreverse<Arg>>,
+template <typename Arg, index_t Axis, size_t N, typename Traits = expression_traits<xreverse<Arg>>,
typename T = typename Traits::value_type>
KFR_INTRINSIC vec<T, N> get_elements(const xreverse<Arg>& self, const shape<Traits::dims>& index,
- csize_t<N> sh)
+ const axis_params<Axis, N>& sh)
+{
+ return reverse(get_elements(self.first(), self.input_shape.sub(index).sub(N), sh));
+}
+
+} // namespace CMT_ARCH_NAME
+
+// ----------------------------------------------------------------------------
+
+template <index_t... Values>
+struct static_shape
+{
+ constexpr static shape<sizeof...(Values)> get() { return { Values... }; }
+};
+
+template <typename Arg, typename Shape>
+struct xfixshape : public xwitharguments<Arg>
+{
+ using ArgTraits = typename xwitharguments<Arg>::first_arg_trait;
+
+ KFR_MEM_INTRINSIC xfixshape(Arg&& arg) : xwitharguments<Arg>{ std::forward<Arg>(arg) } {}
+};
+
+template <typename Arg, index_t... ShapeValues>
+KFR_INTRINSIC xfixshape<Arg, static_shape<ShapeValues...>> x_fixshape(Arg&& arg,
+ const static_shape<ShapeValues...>&)
+{
+ return { std::forward<Arg>(arg) };
+}
+
+template <typename Arg, index_t... ShapeValues>
+struct expression_traits<xfixshape<Arg, static_shape<ShapeValues...>>> : expression_traits_defaults
+{
+ using ArgTraits = expression_traits<Arg>;
+
+ using value_type = typename ArgTraits::value_type;
+ constexpr static size_t dims = ArgTraits::dims;
+
+ KFR_MEM_INTRINSIC constexpr static shape<dims> shapeof(
+ const xfixshape<Arg, static_shape<ShapeValues...>>& self)
+ {
+ return static_shape<ShapeValues...>::get();
+ }
+ KFR_MEM_INTRINSIC constexpr static shape<dims> shapeof() { return static_shape<ShapeValues...>::get(); }
+};
+
+inline namespace CMT_ARCH_NAME
+{
+
+template <typename Arg, typename Shape, index_t Axis, size_t N,
+ typename Traits = expression_traits<xfixshape<Arg, Shape>>,
+ typename T = typename Traits::value_type>
+KFR_INTRINSIC vec<T, N> get_elements(const xfixshape<Arg, Shape>& self, const shape<Traits::dims>& index,
+ const axis_params<Axis, N>& sh)
+{
+ return get_elements(self.first(), index, sh);
+}
+
+template <typename Arg, typename Shape, index_t Axis, size_t N,
+ typename Traits = expression_traits<xfixshape<Arg, Shape>>,
+ typename T = typename Traits::value_type>
+KFR_INTRINSIC void set_elements(xfixshape<Arg, Shape>& self, const shape<Traits::dims>& index,
+ const axis_params<Axis, N>& sh, const identity<vec<T, N>>& value)
{
- return reverse(get_elements(self.first(), self.input_shape - index - N, sh));
+ set_elements(self.first(), index, sh, value);
+}
+
+} // namespace CMT_ARCH_NAME
+
+// ----------------------------------------------------------------------------
+
+template <typename Arg, index_t OutDims>
+struct xreshape : public xwitharguments<Arg>
+{
+ using ArgTraits = typename xwitharguments<Arg>::first_arg_trait;
+ shape<ArgTraits::dims> in_shape;
+ shape<OutDims> out_shape;
+
+ KFR_MEM_INTRINSIC xreshape(Arg&& arg, const shape<OutDims>& out_shape)
+ : xwitharguments<Arg>{ std::forward<Arg>(arg) }, in_shape(ArgTraits::shapeof(arg)),
+ out_shape(out_shape)
+ {
+ }
+};
+
+template <typename Arg, index_t OutDims>
+KFR_INTRINSIC xreshape<Arg, OutDims> x_reshape(Arg&& arg, const shape<OutDims>& out_shape)
+{
+ return { std::forward<Arg>(arg), out_shape };
+}
+
+template <typename Arg, index_t OutDims>
+struct expression_traits<xreshape<Arg, OutDims>> : expression_traits_defaults
+{
+ using ArgTraits = expression_traits<Arg>;
+
+ using value_type = typename ArgTraits::value_type;
+ constexpr static size_t dims = OutDims;
+
+ KFR_MEM_INTRINSIC constexpr static shape<dims> shapeof(const xreshape<Arg, OutDims>& self)
+ {
+ return self.out_shape;
+ }
+ KFR_MEM_INTRINSIC constexpr static shape<dims> shapeof() { return { 0 }; }
+};
+
+inline namespace CMT_ARCH_NAME
+{
+
+namespace internal
+{
+} // namespace internal
+
+template <typename Arg, index_t outdims, index_t Axis, size_t N,
+ typename Traits = expression_traits<xreshape<Arg, outdims>>,
+ typename T = typename Traits::value_type>
+KFR_INTRINSIC vec<T, N> get_elements(const xreshape<Arg, outdims>& self, const shape<Traits::dims>& index,
+ const axis_params<Axis, N>& sh)
+{
+ using ArgTraits = typename Traits::ArgTraits;
+ constexpr index_t indims = ArgTraits::dims;
+ if constexpr (N == 1)
+ {
+ const shape<indims> idx = self.in_shape.from_flat(self.out_shape.to_flat(index));
+ return get_elements(self.first(), idx, axis_params<indims - 1, 1>{});
+ }
+ else
+ {
+ const shape<indims> first_idx = self.in_shape.from_flat(self.out_shape.to_flat(index));
+ const shape<indims> last_idx =
+ self.in_shape.from_flat(self.out_shape.to_flat(index.add_at(N - 1, cindex<Axis>)));
+
+ const shape<indims> diff_idx = last_idx.sub(first_idx);
+
+ vec<T, N> result;
+ bool done = false;
+
+ cforeach(cvalseq_t<index_t, indims, 0>{},
+ [&](auto n) CMT_INLINE_LAMBDA
+ {
+ constexpr index_t axis = val_of<decltype(n)>({});
+ if (!done && diff_idx[axis] == N - 1)
+ {
+ result = get_elements(self.first(), first_idx, axis_params<axis, N>{});
+ done = true;
+ }
+ });
+
+ if (!done)
+ {
+ portable_vec<T, N> tmp;
+ CMT_LOOP_NOUNROLL
+ for (size_t i = 0; i < N; ++i)
+ {
+ tmp[i] = get_elements(
+ self.first(),
+ self.in_shape.from_flat(self.out_shape.to_flat(index.add_at(i, cindex<Axis>))),
+ axis_params<indims - 1, 1>{})
+ .front();
+ }
+ result = tmp;
+ }
+ return result;
+ }
+}
+
+template <typename Arg, index_t outdims, index_t Axis, size_t N,
+ typename Traits = expression_traits<xreshape<Arg, outdims>>,
+ typename T = typename Traits::value_type>
+KFR_INTRINSIC void set_elements(xreshape<Arg, outdims>& self, const shape<Traits::dims>& index,
+ const axis_params<Axis, N>& sh, const identity<vec<T, N>>& value)
+{
+ using ArgTraits = typename Traits::ArgTraits;
+ constexpr index_t indims = ArgTraits::dims;
+ if constexpr (N == 1)
+ {
+ const shape<indims> idx = self.in_shape.from_flat(self.out_shape.to_flat(index));
+ set_elements(self.first(), idx, axis_params<indims - 1, 1>{}, value);
+ }
+ else
+ {
+ const shape<indims> first_idx = self.in_shape.from_flat(self.out_shape.to_flat(index));
+ const shape<indims> last_idx =
+ self.in_shape.from_flat(self.out_shape.to_flat(index.add_at(N - 1, cindex<Axis>)));
+
+ const shape<indims> diff_idx = last_idx.sub(first_idx);
+
+ bool done = false;
+
+ cforeach(cvalseq_t<index_t, indims, 0>{},
+ [&](auto n) CMT_INLINE_LAMBDA
+ {
+ constexpr index_t axis = val_of<decltype(n)>({});
+ if (!done && diff_idx[axis] == N - 1)
+ {
+ set_elements(self.first(), first_idx, axis_params<axis, N>{}, value);
+ done = true;
+ }
+ });
+
+ if (!done)
+ {
+ CMT_LOOP_NOUNROLL
+ for (size_t i = 0; i < N; ++i)
+ {
+ set_elements(self.first(),
+ self.in_shape.from_flat(self.out_shape.to_flat(index.add_at(i, cindex<Axis>))),
+ axis_params<indims - 1, 1>{}, vec<T, 1>{ value[i] });
+ }
+ }
+ }
}
} // namespace CMT_ARCH_NAME
diff --git a/include/kfr/base/shape.hpp b/include/kfr/base/shape.hpp
@@ -39,11 +39,23 @@ namespace kfr
#ifndef KFR_32BIT_INDICES
using index_t = size_t;
+#if SIZE_MAX == UINT64_MAX
+using signed_index_t = int64_t;
#else
-using index_t = uint32_t;
+using signed_index_t = int32_t;
+#endif
+#else
+using index_t = uint32_t;
+using signed_index_t = int32_t;
#endif
constexpr inline index_t max_index_t = std::numeric_limits<index_t>::max();
+template <index_t val>
+using cindex_t = cval_t<index_t, val>;
+
+template <index_t val>
+constexpr inline cindex_t<val> cindex{};
+
constexpr inline index_t infinite_size = max_index_t;
constexpr inline index_t maximum_dims = 8;
@@ -115,6 +127,13 @@ struct shape : static_array_base<index_t, csizeseq_t<dims>>
result.back() += value;
return result;
}
+ template <index_t Axis>
+ shape add_at(index_t value, cval_t<index_t, Axis> = {}) const
+ {
+ shape result = *this;
+ result[Axis] += value;
+ return result;
+ }
shape add(const shape& other) const { return **this + *other; }
shape sub(const shape& other) const { return **this - *other; }
@@ -137,25 +156,50 @@ struct shape : static_array_base<index_t, csizeseq_t<dims>>
KFR_MEM_INTRINSIC size_t to_flat(const shape<dims>& indices) const
{
- size_t result = 0;
- size_t scale = 1;
- for (size_t i = 0; i < dims; ++i)
+ if constexpr (dims == 1)
{
- result += scale * indices[dims - 1 - i];
- scale *= (*this)[dims - 1 - i];
+ return indices[0];
+ }
+ else if constexpr (dims == 2)
+ {
+ return (*this)[1] * indices[0] + indices[1];
+ }
+ else
+ {
+ size_t result = 0;
+ size_t scale = 1;
+ CMT_LOOP_UNROLL
+ for (size_t i = 0; i < dims; ++i)
+ {
+ result += scale * indices[dims - 1 - i];
+ scale *= (*this)[dims - 1 - i];
+ }
+ return result;
}
- return result;
}
KFR_MEM_INTRINSIC shape<dims> from_flat(size_t index) const
{
- shape<dims> indices;
- for (size_t i = 0; i < dims; ++i)
+ if constexpr (dims == 1)
{
- size_t sz = (*this)[dims - 1 - i];
- indices[dims - 1 - i] = index % sz;
- index /= sz;
+ return { index };
+ }
+ else if constexpr (dims == 2)
+ {
+ index_t sz = (*this)[1];
+ return { index / sz, index % sz };
+ }
+ else
+ {
+ shape<dims> indices;
+ CMT_LOOP_UNROLL
+ for (size_t i = 0; i < dims; ++i)
+ {
+ size_t sz = (*this)[dims - 1 - i];
+ indices[dims - 1 - i] = index % sz;
+ index /= sz;
+ }
+ return indices;
}
- return indices;
}
KFR_MEM_INTRINSIC index_t dot(const shape& other) const
@@ -401,19 +445,43 @@ constexpr KFR_INTRINSIC shape<outdims> compact_shape(const shape<dims>& in)
template <index_t dims1, index_t dims2, index_t outdims = const_max(dims1, dims2)>
bool can_assign_from(const shape<dims1>& dst_shape, const shape<dims2>& src_shape)
{
- for (size_t i = 0; i < outdims; ++i)
+ if constexpr (dims2 == 0)
{
- index_t dst_size = dst_shape.revindex(i);
- index_t src_size = src_shape.revindex(i);
- if (src_size == 1 || src_size == infinite_size || src_size == dst_size)
+ return true;
+ }
+ else
+ {
+ if constexpr (outdims >= 2)
{
+ vec<index_t, outdims> dst = padlow<outdims - dims1>(*dst_shape, 1);
+ vec<index_t, outdims> src = padlow<outdims - dims2>(*src_shape, 1);
+
+ mask<index_t, outdims> match = src + 1 <= 2 || src == dst;
+ return all(match);
}
else
{
- return false;
+ for (size_t i = 0; i < outdims; ++i)
+ {
+ index_t dst_size = dst_shape.revindex(i);
+ index_t src_size = src_shape.revindex(i);
+ if (src_size == 1 || src_size == infinite_size || src_size == dst_size)
+ {
+ }
+ else
+ {
+ return false;
+ }
+ }
+ return true;
}
}
- return true;
+}
+
+template <index_t dims>
+constexpr shape<dims> common_shape(const shape<dims>& shape)
+{
+ return shape;
}
template <index_t dims1, index_t dims2, index_t outdims = const_max(dims1, dims2)>
@@ -573,4 +641,16 @@ constexpr KFR_INTRINSIC index_t size_of_shape(const shape<dims>& shape)
return n;
}
+template <index_t Axis, index_t N>
+struct axis_params
+{
+ constexpr static index_t axis = Axis;
+ constexpr static index_t width = N;
+
+ constexpr axis_params() = default;
+};
+
+template <index_t Axis, index_t N>
+constexpr inline const axis_params<Axis, N> axis_params_v{};
+
} // namespace kfr
diff --git a/include/kfr/base/tensor.hpp b/include/kfr/base/tensor.hpp
@@ -694,10 +694,10 @@ private:
}
T* m_data;
- index_t m_size;
- bool m_is_contiguous;
- shape_type m_shape;
- shape_type m_strides;
+ const index_t m_size;
+ const bool m_is_contiguous;
+ const shape_type m_shape;
+ const shape_type m_strides;
memory_finalizer m_finalizer;
};
@@ -740,33 +740,24 @@ struct expression_traits<tensor<T, Dims>> : expression_traits_defaults
inline namespace CMT_ARCH_NAME
{
-template <typename T, index_t NDims, size_t N>
-KFR_INTRINSIC vec<T, N> get_elements(const tensor<T, NDims>& self, const shape<NDims>& index, csize_t<N>)
+template <typename T, index_t NDims, index_t Axis, size_t N>
+KFR_INTRINSIC vec<T, N> get_elements(const tensor<T, NDims>& self, const shape<NDims>& index,
+ const axis_params<Axis, N>&)
{
const T* data = self.data() + self.calc_index(index);
- if (self.is_last_contiguous())
- {
+ if (self.strides()[Axis] == 1)
return read<N>(data);
- }
- else
- {
- return gather_stride<N>(data, self.strides().back());
- }
+ return gather_stride<N>(data, self.strides()[Axis]);
}
-template <typename T, index_t NDims, size_t N>
-KFR_INTRINSIC void set_elements(const tensor<T, NDims>& self, const shape<NDims>& index, csize_t<N>,
- const identity<vec<T, N>>& value)
+template <typename T, index_t NDims, index_t Axis, size_t N>
+KFR_INTRINSIC void set_elements(const tensor<T, NDims>& self, const shape<NDims>& index,
+ const axis_params<Axis, N>&, const identity<vec<T, N>>& value)
{
T* data = self.data() + self.calc_index(index);
- if (self.is_last_contiguous())
- {
- write(data, value);
- }
- else
- {
- scatter_stride(data, value, self.strides().back());
- }
+ if (self.strides()[Axis] == 1)
+ return write(data, value);
+ scatter_stride(data, value, self.strides()[Axis]);
}
template <typename T, index_t dims1, index_t dims2, typename Fn, index_t outdims = const_max(dims1, dims2)>
diff --git a/tests/unit/base/tensor.cpp b/tests/unit/base/tensor.cpp
@@ -36,14 +36,10 @@ TEST(shape)
CHECK(internal_generic::strides_for_shape(shape{ 2, 3, 4 }, 10) == shape{ 120, 40, 10 });
- CHECK(increment_indices_return(shape{ 0, 0, 0 }, shape{ 0, 0, 0 }, shape{ 2, 3, 4 }) ==
- shape{ 0, 0, 1 });
- CHECK(increment_indices_return(shape{ 0, 0, 3 }, shape{ 0, 0, 0 }, shape{ 2, 3, 4 }) ==
- shape{ 0, 1, 0 });
- CHECK(increment_indices_return(shape{ 0, 2, 0 }, shape{ 0, 0, 0 }, shape{ 2, 3, 4 }) ==
- shape{ 0, 2, 1 });
- CHECK(increment_indices_return(shape{ 0, 2, 3 }, shape{ 0, 0, 0 }, shape{ 2, 3, 4 }) ==
- shape{ 1, 0, 0 });
+ CHECK(increment_indices_return(shape{ 0, 0, 0 }, shape{ 0, 0, 0 }, shape{ 2, 3, 4 }) == shape{ 0, 0, 1 });
+ CHECK(increment_indices_return(shape{ 0, 0, 3 }, shape{ 0, 0, 0 }, shape{ 2, 3, 4 }) == shape{ 0, 1, 0 });
+ CHECK(increment_indices_return(shape{ 0, 2, 0 }, shape{ 0, 0, 0 }, shape{ 2, 3, 4 }) == shape{ 0, 2, 1 });
+ CHECK(increment_indices_return(shape{ 0, 2, 3 }, shape{ 0, 0, 0 }, shape{ 2, 3, 4 }) == shape{ 1, 0, 0 });
CHECK(increment_indices_return(shape{ 1, 2, 3 }, shape{ 0, 0, 0 }, shape{ 2, 3, 4 }) ==
shape{ null_index, null_index, null_index });
@@ -223,7 +219,7 @@ TEST(tensor_broadcast)
tensor<float, 2> t2{ shape{ 5, 1 }, { 10.f, 20.f, 30.f, 40.f, 50.f } };
tensor<float, 1> t4{ shape{ 5 }, { 1.f, 2.f, 3.f, 4.f, 5.f } };
tensor<float, 2> tresult{ shape{ 5, 5 }, { 11, 12, 13, 14, 15, 21, 22, 23, 24, 25, 31, 32, 33,
- 34, 35, 41, 42, 43, 44, 45, 51, 52, 53, 54, 55 } };
+ 34, 35, 41, 42, 43, 44, 45, 51, 52, 53, 54, 55 } };
tensor<float, 2> t3 = tapply(t1, t2, fn::add{});
@@ -249,7 +245,10 @@ template <typename T, index_t Dims = 1>
struct tcounter
{
T start;
- std::array<T, Dims> steps;
+ T steps[Dims];
+
+ T back() const { return steps[Dims - 1]; }
+ T front() const { return steps[0]; }
};
template <typename T, index_t Dims>
@@ -284,52 +283,68 @@ struct expression_traits<std::array<std::array<T, N1>, N2>> : expression_traits_
inline namespace CMT_ARCH_NAME
{
-template <typename T, size_t N>
-KFR_INTRINSIC vec<T, N> get_elements(const tcounter<T, 1>& self, const shape<1>& index, csize_t<N> sh)
+template <typename T, index_t Axis, size_t N>
+KFR_INTRINSIC vec<T, N> get_elements(const tcounter<T, 1>& self, const shape<1>& index,
+ const axis_params<Axis, N>&)
{
T acc = self.start;
- acc += static_cast<T>(index.front()) * self.steps.front();
- return acc + enumerate(vec_shape<T, N>(), self.steps.back());
+ acc += static_cast<T>(index.back()) * self.back();
+ return acc + enumerate(vec_shape<T, N>(), self.back());
}
-template <typename T, index_t dims, size_t N>
-KFR_INTRINSIC vec<T, N> get_elements(const tcounter<T, dims>& self, const shape<dims>& index, csize_t<N> sh)
+template <typename T, index_t dims, index_t Axis, size_t N>
+KFR_INTRINSIC vec<T, N> get_elements(const tcounter<T, dims>& self, const shape<dims>& index,
+ const axis_params<Axis, N>&)
{
T acc = self.start;
vec<T, dims> tindices = cast<T>(*index);
cfor(csize<0>, csize<dims>, [&](auto i) CMT_INLINE_LAMBDA { acc += tindices[i] * self.steps[i]; });
- return acc + enumerate(vec_shape<T, N>(), self.steps.back());
+ return acc + enumerate(vec_shape<T, N>(), self.steps[Axis]);
}
-template <typename T, size_t N1, size_t N>
+template <typename T, size_t N1, index_t Axis, size_t N>
KFR_INTRINSIC vec<T, N> get_elements(const std::array<T, N1>& CMT_RESTRICT self, const shape<1>& index,
- csize_t<N> sh)
+ const axis_params<Axis, N>&)
{
const T* CMT_RESTRICT const data = self.data();
- return read<N>(data + std::min(index[0], static_cast<index_t>(N1 - 1)));
+ return read<N>(data + index[0]);
}
-template <typename T, size_t N1, size_t N>
-KFR_INTRINSIC void set_elements(std::array<T, N1>& CMT_RESTRICT self, const shape<1>& index, csize_t<N>,
- const identity<vec<T, N>>& val)
+template <typename T, size_t N1, index_t Axis, size_t N>
+KFR_INTRINSIC void set_elements(std::array<T, N1>& CMT_RESTRICT self, const shape<1>& index,
+ const axis_params<Axis, N>&, const identity<vec<T, N>>& val)
{
T* CMT_RESTRICT const data = self.data();
- write(data + std::min(index[0], static_cast<index_t>(N1 - 1)), val);
+ write(data + index[0], val);
}
-template <typename T, size_t N1, size_t N2, size_t N>
+template <typename T, size_t N1, size_t N2, index_t Axis, size_t N>
KFR_INTRINSIC vec<T, N> get_elements(const std::array<std::array<T, N1>, N2>& CMT_RESTRICT self,
- const shape<2>& index, csize_t<N> sh)
+ const shape<2>& index, const axis_params<Axis, N>&)
{
- const T* CMT_RESTRICT const data = self[std::min(index[0], static_cast<index_t>(N2 - 1))].data();
- return read<N>(data + std::min(index[1], static_cast<index_t>(N1 - 1)));
+ const T* CMT_RESTRICT const data = self.front().data() + index.front() * N1 + index.back();
+ if constexpr (Axis == 1)
+ {
+ return read<N>(data);
+ }
+ else
+ {
+ return gather_stride<N>(data, N1);
+ }
}
-template <typename T, size_t N1, size_t N2, size_t N>
+template <typename T, size_t N1, size_t N2, index_t Axis, size_t N>
KFR_INTRINSIC void set_elements(std::array<std::array<T, N1>, N2>& CMT_RESTRICT self, const shape<2>& index,
- csize_t<N>, const identity<vec<T, N>>& val)
+ const axis_params<Axis, N>&, const identity<vec<T, N>>& val)
{
- T* CMT_RESTRICT const data = self[std::min(index[0], static_cast<index_t>(N2 - 1))].data();
- write(data + std::min(index[1], static_cast<index_t>(N1 - 1)), val);
+ T* CMT_RESTRICT data = self.front().data() + index.front() * N1 + index.back();
+ if constexpr (Axis == 1)
+ {
+ write(data, val);
+ }
+ else
+ {
+ scatter_stride(data, val, N1);
+ }
}
TEST(tensor_expressions2)
@@ -337,12 +352,12 @@ TEST(tensor_expressions2)
auto aa = std::array<std::array<double, 2>, 2>{ { { { 1, 2 } }, { { 3, 4 } } } };
static_assert(expression_traits<decltype(aa)>::dims == 2);
CHECK(expression_traits<decltype(aa)>::shapeof(aa) == shape{ 2, 2 });
- CHECK(get_elements(aa, { 1, 1 }, csize_t<1>{}) == vec{ 4. });
- CHECK(get_elements(aa, { 1, 0 }, csize_t<2>{}) == vec{ 3., 4. });
+ CHECK(get_elements(aa, { 1, 1 }, axis_params<1, 1>{}) == vec{ 4. });
+ CHECK(get_elements(aa, { 1, 0 }, axis_params<1, 2>{}) == vec{ 3., 4. });
static_assert(expression_traits<decltype(1234.f)>::dims == 0);
CHECK(expression_traits<decltype(1234.f)>::shapeof(1234.f) == shape{});
- CHECK(get_elements(1234.f, {}, csize_t<3>{}) == vec{ 1234.f, 1234.f, 1234.f });
+ CHECK(get_elements(1234.f, {}, axis_params<0, 3>{}) == vec{ 1234.f, 1234.f, 1234.f });
tprocess(aa, 123.45f);
@@ -360,13 +375,13 @@ TEST(tensor_counter)
{
std::array<double, 6> x;
- tprocess(x, tcounter<double>{ 0.0, { { 0.5 } } });
+ tprocess(x, tcounter<double>{ 0.0, { 0.5 } });
CHECK(x == std::array<double, 6>{ { 0.0, 0.5, 1.0, 1.5, 2.0, 2.5 } });
std::array<std::array<double, 4>, 3> y;
- tprocess(y, tcounter<double, 2>{ 100.0, { { 1.0, 10.0 } } });
+ tprocess(y, tcounter<double, 2>{ 100.0, { 1.0, 10.0 } });
CHECK(y == std::array<std::array<double, 4>, 3>{ {
{ { 100.0, 110.0, 120.0, 130.0 } },
@@ -379,7 +394,7 @@ DTEST(tensor_dims)
{
tensor<double, 6> t12{ shape{ 2, 3, 4, 5, 6, 7 } };
- tprocess(t12, tcounter<double, 6>{ 0, { { 1, 10, 100, 1000, 10000, 100000 } } });
+ tprocess(t12, tcounter<double, 6>{ 0, { 1, 10, 100, 1000, 10000, 100000 } });
auto t1 = t12(1, 2, 3, tall(), 5, 6);
CHECK(render(t1) == univector<double>{ 650321, 651321, 652321, 653321, 654321 });
@@ -444,6 +459,13 @@ TEST(xfunction_test)
{ { 501.f, 502.f, 503.f, 504.f, 505.f } } } });
}
+TEST(xreshape)
+{
+ std::array<float, 12> x;
+ tprocess(x_reshape(x, shape{ 3, 4 }), tcounter<float, 2>{ 0, { 10, 1 } });
+ CHECK(x == std::array<float, 12>{ { 0, 1, 2, 3, 10, 11, 12, 13, 20, 21, 22, 23 } });
+}
+
} // namespace CMT_ARCH_NAME
#ifdef _MSC_VER
@@ -467,17 +489,17 @@ extern "C" __declspec(dllexport) bool assembly_test3(std::array<double, 16>& x)
extern "C" __declspec(dllexport) bool assembly_test4(std::array<double, 16>& x)
{
- return tprocess(x, tcounter<double>{ 1000.0, { { 1.0 } } }).front() > 0;
+ return tprocess(x, tcounter<double>{ 1000.0, { 1.0 } }).front() > 0;
}
extern "C" __declspec(dllexport) bool assembly_test5(const tensor<double, 3>& x)
{
- return tprocess(x, tcounter<double, 3>{ 1000.0, { { 1.0, 2.0, 3.0 } } }).front() > 0;
+ return tprocess(x, tcounter<double, 3>{ 1000.0, { 1.0, 2.0, 3.0 } }).front() > 0;
}
extern "C" __declspec(dllexport) bool assembly_test6(const tensor<double, 2>& x)
{
- return tprocess(x, tcounter<double, 2>{ 1000.0, { { 1.0, 2.0 } } }).front() > 0;
+ return tprocess(x, tcounter<double, 2>{ 1000.0, { 1.0, 2.0 } }).front() > 0;
}
extern "C" __declspec(dllexport) bool assembly_test7(const tensor<double, 2>& x)
@@ -529,6 +551,74 @@ extern "C" __declspec(dllexport) void assembly_test13(const tensor<float, 1>& x,
// static_assert(sh2 == shape{ 4, 4 });
tprocess(x, y * 0.5f);
}
+
+template <typename T, size_t N1, size_t N2>
+using array2d = std::array<std::array<T, N2>, N1>;
+
+extern "C" __declspec(dllexport) void assembly_test14(std::array<float, 32>& x,
+ const std::array<float, 32>& y)
+{
+ tprocess(x, x_reverse(y));
+}
+
+extern "C" __declspec(dllexport) void assembly_test15(array2d<float, 32, 32>& x,
+ const array2d<float, 32, 32>& y)
+{
+ tprocess(x, x_reverse(y));
+}
+
+extern "C" __declspec(dllexport) void assembly_test16a(array2d<double, 8, 2>& x,
+ const array2d<double, 8, 2>& y)
+{
+ tprocess<8, 0>(x, y * y);
+}
+extern "C" __declspec(dllexport) void assembly_test16b(array2d<double, 8, 2>& x,
+ const array2d<double, 8, 2>& y)
+{
+ tprocess<2, 1>(x, y * y);
+}
+
+extern "C" __declspec(dllexport) void assembly_test17a(const tensor<double, 2>& x, const tensor<double, 2>& y)
+{
+ xfunction ysqr = xfunction{ xwitharguments{ y }, fn::sqr{} };
+ tprocess<8, 0>(x, ysqr);
+}
+extern "C" __declspec(dllexport) void assembly_test17b(const tensor<double, 2>& x, const tensor<double, 2>& y)
+{
+ xfunction ysqr = xfunction{ xwitharguments{ y }, fn::sqr{} };
+ tprocess<2, 1>(x, ysqr);
+}
+
+extern "C" __declspec(dllexport) void assembly_test18a(const tensor<double, 2>& x, const tensor<double, 2>& y)
+{
+ xfunction ysqr = xfunction{ xwitharguments{ y }, fn::sqr{} };
+ tprocess<8, 0>(x_fixshape(x, static_shape<8, 2>{}), x_fixshape(ysqr, static_shape<8, 2>{}));
+}
+extern "C" __declspec(dllexport) void assembly_test18b(const tensor<double, 2>& x, const tensor<double, 2>& y)
+{
+ xfunction ysqr = xfunction{ xwitharguments{ y }, fn::sqr{} };
+ tprocess<2, 1>(x_fixshape(x, static_shape<8, 2>{}), x_fixshape(ysqr, static_shape<8, 2>{}));
+}
+
+extern "C" __declspec(dllexport) void assembly_test19(const tensor<double, 2>& x,
+ const xreshape<tensor<double, 2>, 2>& y)
+{
+ tprocess(x, y);
+}
+
+extern "C" __declspec(dllexport) shape<2> assembly_test20_2(const shape<2>& x, size_t fl)
+{
+ return x.from_flat(fl);
+}
+extern "C" __declspec(dllexport) shape<4> assembly_test20_4(const shape<4>& x, size_t fl)
+{
+ return x.from_flat(fl);
+}
+
+extern "C" __declspec(dllexport) shape<4> assembly_test21(const shape<4>& x, size_t fl)
+{
+ return x.from_flat(fl);
+}
#endif
struct val