commit 858d79d45000c5be1490859822e20d0fd23e9ed3
parent 50b19a74b2d3bc8d32a8efc1dc455cc8251593ae
Author: d.levin256@gmail.com <d.levin256@gmail.com>
Date: Thu, 10 Aug 2023 03:53:50 +0100
Fix shape::adapt
Diffstat:
2 files changed, 10 insertions(+), 7 deletions(-)
diff --git a/include/kfr/base/expression.hpp b/include/kfr/base/expression.hpp
@@ -759,7 +759,7 @@ static auto process(Out&& out, In&& in, shape<outdims> start = shape<outdims>(0)
in_size = inshape[in_axis];
begin_pass(out, start, stop);
- begin_pass(in, inshape.adapt(start), inshape.adapt(stop));
+ begin_pass(in, inshape.adapt(start), inshape.adapt(stop, ctrue));
shape<outdims> outidx;
if constexpr (outdims == 1)
@@ -830,7 +830,7 @@ static auto process(Out&& out, In&& in, shape<outdims> start = shape<outdims>(0)
outidx[out_axis] = stop[out_axis] - 1;
} while (internal_generic::increment_indices(outidx, start, stop));
}
- end_pass(in, inshape.adapt(start), inshape.adapt(stop));
+ end_pass(in, inshape.adapt(start), inshape.adapt(stop, ctrue));
end_pass(out, start, stop);
return stop;
}
diff --git a/include/kfr/base/shape.hpp b/include/kfr/base/shape.hpp
@@ -276,11 +276,14 @@ struct shape : static_array_base<index_t, csizeseq_t<dims>>
KFR_MEM_INTRINSIC constexpr index_t dot(const shape& other) const { return (*this)->dot(*other); }
- template <index_t indims>
- KFR_MEM_INTRINSIC constexpr shape adapt(const shape<indims>& other) const
+ template <index_t indims, bool stop = false>
+ KFR_MEM_INTRINSIC constexpr shape adapt(const shape<indims>& other, cbool_t<stop> = {}) const
{
static_assert(indims >= dims);
- return other.template trim<dims>()->min(**this - 1);
+ if constexpr (stop)
+ return other.template trim<dims>()->min(**this);
+ else
+ return other.template trim<dims>()->min(**this - 1);
}
KFR_MEM_INTRINSIC constexpr index_t product() const { return (*this)->product(); }
@@ -357,8 +360,8 @@ struct shape<0>
KFR_MEM_INTRINSIC size_t to_flat(const shape<0>& indices) const { return 0; }
KFR_MEM_INTRINSIC shape<0> from_flat(size_t index) const { return {}; }
- template <index_t odims>
- KFR_MEM_INTRINSIC shape<0> adapt(const shape<odims>& other) const
+ template <index_t odims, bool stop = false>
+ KFR_MEM_INTRINSIC shape<0> adapt(const shape<odims>& other, cbool_t<stop> = {}) const
{
return {};
}