commit 91a87c91fbf5e591943e9799eb99988080593515
parent 1e423705833f3d0fc0dcb5262901d61521019019
Author: d.levin256@gmail.com <d.levin256@gmail.com>
Date: Tue, 15 Nov 2022 06:44:23 +0000
Tensor small refactoring
Diffstat:
8 files changed, 233 insertions(+), 153 deletions(-)
diff --git a/include/kfr/base/basic_expressions.hpp b/include/kfr/base/basic_expressions.hpp
@@ -131,7 +131,7 @@ KFR_INTRINSIC vec<T, N> get_elements(const expression_counter<T, dims>& self, co
const axis_params<Axis, N>&)
{
T acc = self.start;
- vec<T, dims> tindices = cast<T>(*index);
+ vec<T, dims> tindices = cast<T>(to_vec(index));
cfor(csize<0>, csize<dims>, [&](auto i) CMT_INLINE_LAMBDA { acc += tindices[i] * self.steps[i]; });
return acc + enumerate(vec_shape<T, N>(), self.steps[Axis]);
}
diff --git a/include/kfr/base/expression.hpp b/include/kfr/base/expression.hpp
@@ -26,6 +26,7 @@
#pragma once
#include "../simd/platform.hpp"
+#include "../simd/read_write.hpp"
#include "../simd/shuffle.hpp"
#include "../simd/types.hpp"
#include "../simd/vec.hpp"
@@ -234,6 +235,13 @@ struct anything
inline namespace CMT_ARCH_NAME
{
+
+template <index_t Dims, typename U = unsigned_type<sizeof(index_t) * 8>>
+KFR_INTRINSIC vec<U, Dims> to_vec(const shape<Dims>& sh)
+{
+ return read<Dims>(reinterpret_cast<const U*>(sh.data()));
+}
+
namespace internal
{
template <size_t width, typename Fn>
@@ -306,11 +314,11 @@ struct expression_with_arguments
using Traits = expression_traits<nth<idx>>;
if constexpr (sizeof...(Args) <= 1 || Traits::dims == 0)
{
- return -1;
+ return dimset(-1);
}
else
{
- if constexpr (Traits::get_shape().cproduct() > 0)
+ if constexpr (Traits::get_shape().product() > 0)
{
return Traits::get_shape().tomask();
}
@@ -381,7 +389,7 @@ struct expression_with_arguments<Arg>
template <size_t idx>
KFR_MEM_INTRINSIC dimset getmask(csize_t<idx> = {}) const
{
- return -1;
+ return dimset(-1);
}
template <typename Fn>
diff --git a/include/kfr/base/impl/static_array.hpp b/include/kfr/base/impl/static_array.hpp
@@ -27,7 +27,6 @@
#include "../../cometa.hpp"
#include "../../kfr.h"
-#include "../../simd/read_write.hpp"
namespace kfr
{
@@ -39,6 +38,9 @@ using type_for = T;
template <typename T, typename indices_t>
struct static_array_base;
+template <typename T, size_t Size>
+using static_array_of_size = static_array_base<T, csizeseq_t<Size>>;
+
template <typename T, size_t... indices>
struct static_array_base<T, csizes_t<indices...>>
{
@@ -54,18 +56,22 @@ struct static_array_base<T, csizes_t<indices...>>
constexpr static size_t static_size = sizeof...(indices);
- constexpr static_array_base() : array{ (static_cast<void>(indices), 0)... } {}
- constexpr static_array_base(const static_array_base&) = default;
- constexpr static_array_base(static_array_base&&) = default;
+ constexpr static_array_base() noexcept : array{ (static_cast<void>(indices), 0)... } {}
+ constexpr static_array_base(const static_array_base&) noexcept = default;
+ constexpr static_array_base(static_array_base&&) noexcept = default;
- KFR_MEM_INTRINSIC constexpr static_array_base(type_for<value_type, indices>... args) : array{ args... } {}
+ KFR_MEM_INTRINSIC constexpr static_array_base(type_for<value_type, indices>... args) noexcept
+ : array{ args... }
+ {
+ }
template <typename U, typename otherindices_t>
friend struct static_array_base;
template <size_t... idx1, size_t... idx2>
- KFR_MEM_INTRINSIC constexpr static_array_base(const static_array_base<T, csizes_t<idx1...>>& first,
- const static_array_base<T, csizes_t<idx2...>>& second)
+ KFR_MEM_INTRINSIC constexpr static_array_base(
+ const static_array_base<T, csizes_t<idx1...>>& first,
+ const static_array_base<T, csizes_t<idx2...>>& second) noexcept
: array{ (indices >= sizeof...(idx1) ? second.array[indices - sizeof...(idx1)]
: first.array[indices])... }
{
@@ -73,13 +79,20 @@ struct static_array_base<T, csizes_t<indices...>>
}
template <size_t... idx>
- constexpr static_array_base<T, csizeseq_t<sizeof...(idx)>> shuffle(csizes_t<idx...>) const
+ constexpr static_array_base<T, csizeseq_t<sizeof...(idx)>> shuffle(csizes_t<idx...>) const noexcept
{
return static_array_base<T, csizeseq_t<sizeof...(idx)>>{ array[idx]... };
}
+ template <size_t... idx>
+ constexpr static_array_base<T, csizeseq_t<sizeof...(idx)>> shuffle(csizes_t<idx...>,
+ T filler) const noexcept
+ {
+ return static_array_base<T, csizeseq_t<sizeof...(idx)>>{ (idx >= static_size ? filler
+ : array[idx])... };
+ }
template <size_t start, size_t size>
- constexpr static_array_base<T, csizeseq_t<size>> slice() const
+ constexpr static_array_base<T, csizeseq_t<size>> slice() const noexcept
{
return shuffle(csizeseq<size, start>);
}
@@ -88,47 +101,128 @@ struct static_array_base<T, csizes_t<indices...>>
constexpr static_array_base& operator=(static_array_base&&) = default;
template <int dummy = 0, CMT_ENABLE_IF(dummy == 0 && static_size > 1)>
- KFR_MEM_INTRINSIC constexpr explicit static_array_base(value_type value)
+ KFR_MEM_INTRINSIC constexpr explicit static_array_base(value_type value) noexcept
: array{ (static_cast<void>(indices), value)... }
{
}
- KFR_MEM_INTRINSIC vec<T, static_size> operator*() const { return read<static_size>(data()); }
+ KFR_MEM_INTRINSIC constexpr const value_type* data() const noexcept { return std::data(array); }
+ KFR_MEM_INTRINSIC constexpr value_type* data() noexcept { return std::data(array); }
- KFR_MEM_INTRINSIC static_array_base(const vec<T, static_size>& v) { write(data(), v); }
+ KFR_MEM_INTRINSIC constexpr const_iterator begin() const noexcept { return std::begin(array); }
+ KFR_MEM_INTRINSIC constexpr iterator begin() noexcept { return std::begin(array); }
+ KFR_MEM_INTRINSIC constexpr const_iterator cbegin() const noexcept { return std::begin(array); }
- KFR_MEM_INTRINSIC constexpr const value_type* data() const { return std::data(array); }
- KFR_MEM_INTRINSIC constexpr value_type* data() { return std::data(array); }
+ KFR_MEM_INTRINSIC constexpr const_iterator end() const noexcept { return std::end(array); }
+ KFR_MEM_INTRINSIC constexpr iterator end() noexcept { return std::end(array); }
+ KFR_MEM_INTRINSIC constexpr const_iterator cend() const noexcept { return std::end(array); }
- KFR_MEM_INTRINSIC constexpr const_iterator begin() const { return std::begin(array); }
- KFR_MEM_INTRINSIC constexpr iterator begin() { return std::begin(array); }
- KFR_MEM_INTRINSIC constexpr const_iterator cbegin() const { return std::begin(array); }
-
- KFR_MEM_INTRINSIC constexpr const_iterator end() const { return std::end(array); }
- KFR_MEM_INTRINSIC constexpr iterator end() { return std::end(array); }
- KFR_MEM_INTRINSIC constexpr const_iterator cend() const { return std::end(array); }
-
- KFR_MEM_INTRINSIC constexpr const_reference operator[](size_t index) const { return array[index]; }
- KFR_MEM_INTRINSIC constexpr reference operator[](size_t index) { return array[index]; }
+ KFR_MEM_INTRINSIC constexpr const_reference operator[](size_t index) const noexcept
+ {
+ return array[index];
+ }
+ KFR_MEM_INTRINSIC constexpr reference operator[](size_t index) noexcept { return array[index]; }
- KFR_MEM_INTRINSIC constexpr const_reference front() const { return array[0]; }
- KFR_MEM_INTRINSIC constexpr reference front() { return array[0]; }
+ KFR_MEM_INTRINSIC constexpr const_reference front() const noexcept { return array[0]; }
+ KFR_MEM_INTRINSIC constexpr reference front() noexcept { return array[0]; }
- KFR_MEM_INTRINSIC constexpr const_reference back() const { return array[static_size - 1]; }
- KFR_MEM_INTRINSIC constexpr reference back() { return array[static_size - 1]; }
+ KFR_MEM_INTRINSIC constexpr const_reference back() const noexcept { return array[static_size - 1]; }
+ KFR_MEM_INTRINSIC constexpr reference back() noexcept { return array[static_size - 1]; }
- KFR_MEM_INTRINSIC constexpr bool empty() const { return false; }
+ KFR_MEM_INTRINSIC constexpr bool empty() const noexcept { return false; }
- KFR_MEM_INTRINSIC constexpr size_t size() const { return std::size(array); }
+ KFR_MEM_INTRINSIC constexpr size_t size() const noexcept { return std::size(array); }
- KFR_MEM_INTRINSIC constexpr bool operator==(const static_array_base& other) const
+ KFR_MEM_INTRINSIC constexpr bool operator==(const static_array_base& other) const noexcept
{
return ((array[indices] == other.array[indices]) && ...);
}
- KFR_MEM_INTRINSIC constexpr bool operator!=(const static_array_base& other) const
+ KFR_MEM_INTRINSIC constexpr bool operator!=(const static_array_base& other) const noexcept
{
return !operator==(other);
}
+ constexpr T minof() const noexcept
+ {
+ T result = std::numeric_limits<T>::max();
+ (static_cast<void>(result = std::min(result, array[indices])), ...);
+ return result;
+ }
+ constexpr T maxof() const noexcept
+ {
+ T result = std::numeric_limits<T>::lowest();
+ (static_cast<void>(result = std::max(result, array[indices])), ...);
+ return result;
+ }
+ constexpr T sum() const noexcept
+ {
+ T result = 0;
+ (static_cast<void>(result += array[indices]), ...);
+ return result;
+ }
+ constexpr T product() const noexcept
+ {
+ T result = 1;
+ (static_cast<void>(result *= array[indices]), ...);
+ return result;
+ }
+
+ constexpr static_array_base min(const static_array_base& y) const noexcept
+ {
+ return static_array_base{ std::min(array[indices], y.array[indices])... };
+ }
+ constexpr static_array_base max(const static_array_base& y) const noexcept
+ {
+ return static_array_base{ std::max(array[indices], y.array[indices])... };
+ }
+ template <typename Fn>
+ constexpr static_array_base bin(const static_array_base& y, Fn&& fn) const noexcept
+ {
+ return static_array_base{ fn(array[indices], y.array[indices])... };
+ }
+ template <typename Fn>
+ constexpr static_array_base un(Fn&& fn) const noexcept
+ {
+ return static_array_base{ fn(array[indices])... };
+ }
+ template <typename U>
+ constexpr static_array_base<U, csizes_t<indices...>> cast() const noexcept
+ {
+ return static_array_base<U, csizes_t<indices...>>{ static_cast<U>(array[indices])... };
+ }
+
+ constexpr static_array_base operator+(const static_array_base& y) const noexcept
+ {
+ return static_array_base{ (array[indices] + y.array[indices])... };
+ }
+ constexpr static_array_base operator-(const static_array_base& y) const noexcept
+ {
+ return static_array_base{ (array[indices] - y.array[indices])... };
+ }
+ constexpr static_array_base operator*(const static_array_base& y) const noexcept
+ {
+ return static_array_base{ (array[indices] * y.array[indices])... };
+ }
+ constexpr static_array_base operator&(const static_array_base& y) const noexcept
+ {
+ return static_array_base{ (array[indices] & y.array[indices])... };
+ }
+ constexpr static_array_base operator|(const static_array_base& y) const noexcept
+ {
+ return static_array_base{ (array[indices] | y.array[indices])... };
+ }
+ constexpr static_array_base operator^(const static_array_base& y) const noexcept
+ {
+ return static_array_base{ (array[indices] ^ y.array[indices])... };
+ }
+ constexpr static_array_base operator+(const T& y) const noexcept
+ {
+ return static_array_base{ (array[indices] + y)... };
+ }
+ constexpr static_array_base operator-(const T& y) const noexcept
+ {
+ return static_array_base{ (array[indices] - y)... };
+ }
+ constexpr T dot(const static_array_base& y) const noexcept { return (operator*(y)).sum(); }
private:
T array[static_size];
diff --git a/include/kfr/base/shape.hpp b/include/kfr/base/shape.hpp
@@ -33,7 +33,6 @@
#include "../simd/min_max.hpp"
#include "../simd/shuffle.hpp"
#include "../simd/types.hpp"
-#include "../simd/vec.hpp"
#include <bitset>
#include <optional>
@@ -85,7 +84,7 @@ CMT_INTRINSIC constexpr size_t size_min(size_t x, size_t y, Ts... rest) CMT_NOEX
return size_min(x < y ? x : y, rest...);
}
-using dimset = vec<i8, maximum_dims>;
+using dimset = static_array_of_size<i8, maximum_dims>; // std::array<i8, maximum_dims>;
template <index_t dims>
struct shape;
@@ -100,12 +99,11 @@ KFR_INTRINSIC bool increment_indices(shape<dims>& indices, const shape<dims>& st
template <index_t dims>
struct shape : static_array_base<index_t, csizeseq_t<dims>>
{
- using static_array_base<index_t, csizeseq_t<dims>>::static_array_base;
+ using base = static_array_base<index_t, csizeseq_t<dims>>;
- constexpr shape(const static_array_base<index_t, csizeseq_t<dims>>& a)
- : static_array_base<index_t, csizeseq_t<dims>>(a)
- {
- }
+ using base::base;
+
+ constexpr shape(const base& a) : base(a) {}
static_assert(dims < maximum_dims);
@@ -165,22 +163,22 @@ struct shape : static_array_base<index_t, csizeseq_t<dims>>
}
}
- shape add(index_t value) const
+ constexpr shape add(index_t value) const
{
shape result = *this;
result.back() += value;
return result;
}
template <index_t Axis>
- shape add_at(index_t value, cval_t<index_t, Axis> = {}) const
+ constexpr shape add_at(index_t value, cval_t<index_t, Axis> = {}) const
{
shape result = *this;
result[Axis] += value;
return result;
}
- shape add(const shape& other) const { return **this + *other; }
- shape sub(const shape& other) const { return **this - *other; }
- index_t sum() const { return hsum(**this); }
+ constexpr shape add(const shape& other) const { return **this + *other; }
+ constexpr shape sub(const shape& other) const { return **this - *other; }
+ constexpr index_t sum() const { return (*this)->sum(); }
constexpr bool has_infinity() const
{
@@ -192,40 +190,43 @@ struct shape : static_array_base<index_t, csizeseq_t<dims>>
return false;
}
- friend shape add_shape(const shape& lhs, const shape& rhs)
+ friend constexpr shape add_shape(const shape& lhs, const shape& rhs)
{
- vec<index_t, dims> x = *lhs;
- vec<index_t, dims> y = *rhs;
- mask<index_t, dims> inf = max(x, y) == infinite_size;
- return select(inf, infinite_size, x + y);
+ return lhs.bin(rhs, [](index_t x, index_t y) { return std::max(std::max(x, y), x + y); });
}
- friend shape sub_shape(const shape& lhs, const shape& rhs)
+ friend constexpr shape sub_shape(const shape& lhs, const shape& rhs)
{
- vec<index_t, dims> x = *lhs;
- vec<index_t, dims> y = *rhs;
- mask<index_t, dims> inf = max(x, y) == infinite_size;
- return select(inf, infinite_size, x - y);
+ return lhs.bin(rhs, [](index_t x, index_t y)
+ { return std::max(x, y) == infinite_size ? infinite_size : x - y; });
}
- friend shape add_shape_undef(const shape& lhs, const shape& rhs)
+ friend constexpr shape add_shape_undef(const shape& lhs, const shape& rhs)
{
- vec<index_t, dims> x = *lhs;
- vec<index_t, dims> y = *rhs;
- mask<index_t, dims> inf = max(x, y) == infinite_size;
- mask<index_t, dims> undef = min(x, y) == undefined_size;
- return select(inf, infinite_size, select(undef, undefined_size, x + y));
+ return lhs.bin(rhs,
+ [](index_t x, index_t y)
+ {
+ bool inf = std::max(x, y) == infinite_size;
+ bool undef = std::min(x, y) == undefined_size;
+ return inf ? infinite_size : undef ? undefined_size : x + y;
+ });
}
- friend shape sub_shape_undef(const shape& lhs, const shape& rhs)
+ friend constexpr shape sub_shape_undef(const shape& lhs, const shape& rhs)
{
- vec<index_t, dims> x = *lhs;
- vec<index_t, dims> y = *rhs;
- mask<index_t, dims> inf = max(x, y) == infinite_size;
- mask<index_t, dims> undef = min(x, y) == undefined_size;
- return select(inf, infinite_size, select(undef, undefined_size, x - y));
+ return lhs.bin(rhs,
+ [](index_t x, index_t y)
+ {
+ bool inf = std::max(x, y) == infinite_size;
+ bool undef = std::min(x, y) == undefined_size;
+ return inf ? infinite_size : undef ? undefined_size : x - y;
+ });
}
- friend shape min(const shape& x, const shape& y) { return kfr::min(*x, *y); }
+ friend constexpr shape min(const shape& x, const shape& y) { return x->min(*y); }
+
+ constexpr const base& operator*() const { return static_cast<const base&>(*this); }
+
+ constexpr const base* operator->() const { return static_cast<const base*>(this); }
- KFR_MEM_INTRINSIC size_t to_flat(const shape<dims>& indices) const
+ KFR_MEM_INTRINSIC constexpr size_t to_flat(const shape<dims>& indices) const
{
if constexpr (dims == 1)
{
@@ -248,7 +249,7 @@ struct shape : static_array_base<index_t, csizeseq_t<dims>>
return result;
}
}
- KFR_MEM_INTRINSIC shape<dims> from_flat(size_t index) const
+ KFR_MEM_INTRINSIC constexpr shape<dims> from_flat(size_t index) const
{
if constexpr (dims == 1)
{
@@ -273,41 +274,20 @@ struct shape : static_array_base<index_t, csizeseq_t<dims>>
}
}
- KFR_MEM_INTRINSIC index_t dot(const shape& other) const
- {
- if constexpr (dims == 1)
- {
- return (*this)[0] * other[0];
- }
- else if constexpr (dims == 2)
- {
- return (*this)[0] * other[0] + (*this)[1] * other[1];
- }
- else
- {
- return hdot(**this, *other);
- }
- }
+ KFR_MEM_INTRINSIC constexpr index_t dot(const shape& other) const { return (*this)->dot(*other); }
template <index_t indims>
- KFR_MEM_INTRINSIC shape adapt(const shape<indims>& other) const
+ KFR_MEM_INTRINSIC constexpr shape adapt(const shape<indims>& other) const
{
static_assert(indims >= dims);
- return min(other.template trim<dims>(), **this - 1);
+ return other.template trim<dims>()->min(**this - 1);
}
- KFR_MEM_INTRINSIC index_t product() const { return hproduct(**this); }
- KFR_MEM_INTRINSIC constexpr index_t cproduct() const
- {
- index_t result = this->front();
- for (index_t i = 1; i < dims; i++)
- result *= this->operator[](i);
- return result;
- }
+ KFR_MEM_INTRINSIC constexpr index_t product() const { return (*this)->product(); }
- KFR_MEM_INTRINSIC dimset tomask() const
+ KFR_MEM_INTRINSIC constexpr dimset tomask() const
{
- dimset result = 0;
+ dimset result(0);
for (index_t i = 0; i < dims; ++i)
{
result[i + maximum_dims - dims] = this->operator[](i) == 1 ? 0 : -1;
@@ -389,7 +369,7 @@ struct shape<0>
KFR_MEM_INTRINSIC index_t product() const { return 0; }
- KFR_MEM_INTRINSIC dimset tomask() const { return -1; }
+ KFR_MEM_INTRINSIC dimset tomask() const { return dimset(-1); }
template <index_t new_dims>
constexpr KFR_MEM_INTRINSIC shape<new_dims> extend(index_t value = infinite_size) const
@@ -430,8 +410,9 @@ KFR_MEM_INTRINSIC shape<outdims> adapt(const shape<indims>& in, const dimset& se
}
else
{
- const vec<std::make_signed_t<index_t>, maximum_dims> eset = cast<std::make_signed_t<index_t>>(set);
- return slice<indims - outdims, outdims>(*in) & slice<maximum_dims - outdims, outdims>(eset);
+ const static_array_of_size<index_t, maximum_dims> eset = set.template cast<index_t>();
+ return in->template slice<indims - outdims, outdims>() &
+ eset.template slice<maximum_dims - outdims, outdims>();
}
}
template <index_t outdims>
@@ -518,7 +499,7 @@ constexpr KFR_INTRINSIC shape<outdims> compact_shape(const shape<dims>& in)
}
template <index_t dims1, index_t dims2, index_t outdims = const_max(dims1, dims2)>
-bool can_assign_from(const shape<dims1>& dst_shape, const shape<dims2>& src_shape)
+constexpr bool can_assign_from(const shape<dims1>& dst_shape, const shape<dims2>& src_shape)
{
if constexpr (dims2 == 0)
{
@@ -526,31 +507,20 @@ bool can_assign_from(const shape<dims1>& dst_shape, const shape<dims2>& src_shap
}
else
{
- if constexpr (outdims >= 2)
+ for (size_t i = 0; i < outdims; ++i)
{
- vec<index_t, outdims> dst = padlow<outdims - dims1>(*dst_shape, 1);
- vec<index_t, outdims> src = padlow<outdims - dims2>(*src_shape, 1);
-
- mask<index_t, outdims> match = src + 1 <= 2 || src == dst || dst == infinite_size;
- return all(match);
- }
- else
- {
- for (size_t i = 0; i < outdims; ++i)
+ index_t dst_size = dst_shape.revindex(i);
+ index_t src_size = src_shape.revindex(i);
+ if (CMT_LIKELY(src_size == 1 || src_size == infinite_size || src_size == dst_size ||
+ dst_size == infinite_size))
{
- index_t dst_size = dst_shape.revindex(i);
- index_t src_size = src_shape.revindex(i);
- if (CMT_LIKELY(src_size == 1 || src_size == infinite_size || src_size == dst_size ||
- dst_size == infinite_size))
- {
- }
- else
- {
- return false;
- }
}
- return true;
+ else
+ {
+ return false;
+ }
}
+ return true;
}
}
diff --git a/include/kfr/base/tensor.hpp b/include/kfr/base/tensor.hpp
@@ -245,6 +245,13 @@ public:
std::copy(values.begin(), values.end(), begin());
}
+ template <typename Input, KFR_ACCEPT_EXPRESSIONS(Input)>
+ KFR_MEM_INTRINSIC tensor(Input&& input) : tensor(get_shape(input))
+ {
+ static_assert(expression_traits<Input>::dims == dims);
+ process(*this, input);
+ }
+
KFR_INTRINSIC pointer data() const { return m_data; }
KFR_INTRINSIC size_type size() const { return m_size; }
@@ -307,6 +314,8 @@ public:
tensor(const tensor&) = default;
tensor(tensor&&) = default;
+ tensor(tensor& other) : tensor(const_cast<const tensor&>(other)) {}
+ tensor(const tensor&& other) : tensor(static_cast<const tensor&>(other)) {}
#if defined(CMT_COMPILER_IS_MSVC)
tensor& operator=(const tensor& src) &
@@ -743,18 +752,24 @@ public:
KFR_MEM_INTRINSIC memory_finalizer finalizer() const { return m_finalizer; }
- template <typename Input, index_t Dims = expression_traits<Input>::dims>
+ template <typename Input, KFR_ACCEPT_EXPRESSIONS(Input)>
KFR_MEM_INTRINSIC const tensor& operator=(Input&& input) const&
{
process(*this, input);
return *this;
}
- template <typename Input, index_t Dims = expression_traits<Input>::dims>
+ template <typename Input, KFR_ACCEPT_EXPRESSIONS(Input)>
KFR_MEM_INTRINSIC tensor& operator=(Input&& input) &&
{
process(*this, input);
return *this;
}
+ template <typename Input, KFR_ACCEPT_EXPRESSIONS(Input)>
+ KFR_MEM_INTRINSIC tensor& operator=(Input&& input) &
+ {
+ process(*this, input);
+ return *this;
+ }
bool operator==(const tensor& other) const
{
@@ -875,25 +890,6 @@ KFR_INTRINSIC void set_elements(const tensor<T, NDims>& self, const shape<NDims>
}
}
-template <typename T, index_t dims1, index_t dims2, typename Fn, index_t outdims = const_max(dims1, dims2)>
-tensor<T, outdims> tapply(const tensor<T, dims1>& x, const tensor<T, dims2>& y, Fn&& fn)
-{
- shape<outdims> xyshape = internal_generic::common_shape(x.shape(), y.shape());
-
- tensor<T, outdims> result(xyshape);
-
- shape<outdims> xshape = padlow<outdims - dims1>(*x.shape(), 1);
- shape<outdims> yshape = padlow<outdims - dims2>(*y.shape(), 1);
-
- tensor<T, outdims> xx = x.reshape(xshape);
- tensor<T, outdims> yy = y.reshape(yshape);
-
- result.iterate([&](T& val, const shape<outdims>& index)
- { val = fn(xx.access(xshape.adapt(index)), yy.access(yshape.adapt(index))); });
-
- return result;
-}
-
template <size_t width = 0, index_t Axis = infinite_size, typename E, typename Traits = expression_traits<E>>
tensor<typename Traits::value_type, Traits::dims> trender(const E& expr)
{
diff --git a/include/kfr/simd/read_write.hpp b/include/kfr/simd/read_write.hpp
@@ -187,7 +187,7 @@ struct stride_pointer<const T, groupsize>
};
template <typename T, size_t N>
-KFR_INTRINSIC vec<T, N> v(const std::array<T, N>& a)
+KFR_INTRINSIC vec<T, N> to_vec(const std::array<T, N>& a)
{
return read<N>(a.data());
}
diff --git a/include/kfr/simd/vec.hpp b/include/kfr/simd/vec.hpp
@@ -1441,7 +1441,7 @@ template <typename T, size_t N1, size_t N2>
constexpr inline size_t vec_rank<vec<vec<T, N1>, N2>> = 2;
template <typename T, size_t N>
-KFR_INTRINSIC vec<T, N> v(const portable_vec<T, N>& pv)
+KFR_INTRINSIC vec<T, N> to_vec(const portable_vec<T, N>& pv)
{
return pv;
}
diff --git a/tests/unit/base/tensor.cpp b/tests/unit/base/tensor.cpp
@@ -147,12 +147,12 @@ TEST(tensor_broadcast)
tensor<float, 2> tresult{ shape{ 5, 5 }, { 11, 12, 13, 14, 15, 21, 22, 23, 24, 25, 31, 32, 33,
34, 35, 41, 42, 43, 44, 45, 51, 52, 53, 54, 55 } };
- tensor<float, 2> t3 = tapply(t1, t2, fn::add{});
+ tensor<float, 2> t3 = t1 + t2;
CHECK(t3.shape() == shape{ 5, 5 });
CHECK(t3 == tresult);
- tensor<float, 2> t5 = tapply(t4, t2, fn::add{});
+ tensor<float, 2> t5 = t4 + t2;
// tensor<float, 2> t5 = t4 + t2;
CHECK(t5 == tresult);
}
@@ -785,15 +785,27 @@ TEST(from_ilist)
CHECK(t4 == tensor<float, 3>(shape{ 2, 2, 2 }, { 10, 20, 30, 40, 50, 60, 70, 80 }));
}
+TEST(sharing_data)
+{
+ tensor<int, 2> t{ { 1, 2, 3 }, { 4, 5, 6 }, { 7, 8, 9 } };
+ auto t2 = t; // share data
+ t2(0, 0) = 10;
+ CHECK(t == tensor<int, 2>{ { 10, 2, 3 }, { 4, 5, 6 }, { 7, 8, 9 } });
+ auto t3 = t(0, tall());
+ CHECK(t3 == tensor<int, 1>{ 10, 2, 3 });
+ t3 *= 10;
+ CHECK(t3 == tensor<int, 1>{ 100, 20, 30 });
+ CHECK(t == tensor<int, 2>{ { 100, 20, 30 }, { 4, 5, 6 }, { 7, 8, 9 } });
+ t(trange(0, 2), trange(0, 2)) = 0;
+ CHECK(t == tensor<int, 2>{ { 0, 0, 30 }, { 0, 0, 6 }, { 7, 8, 9 } });
+}
+
TEST(tensor_from_container)
{
std::vector<int> a{ 1, 2, 3 };
auto t = tensor_from_container(a);
CHECK(t.shape() == shape{ 3 });
CHECK(t == tensor<int, 1>{ 1, 2, 3 });
- auto t2 = t; // share data
- t(0) = 100;
- CHECK(t2 == tensor<int, 1>{ 100, 2, 3 });
}
} // namespace CMT_ARCH_NAME