2
0
mirror of https://gitlab.com/libeigen/eigen.git synced 2025-04-24 19:40:45 +08:00

Allow Tensor trace to be passed to a TensorRef.

This commit is contained in:
Antonio Sanchez 2025-03-25 08:26:23 -07:00
parent 8e32cbf7da
commit 6579e36eb4
2 changed files with 9 additions and 0 deletions
unsupported
Eigen/CXX11/src/Tensor
test

@ -27,6 +27,10 @@ struct traits<TensorTraceOp<Dims, XprType> > : public traits<XprType> {
typedef std::remove_reference_t<Nested> Nested_;
static constexpr int NumDimensions = XprTraits::NumDimensions - array_size<Dims>::value;
static constexpr int Layout = XprTraits::Layout;
enum {
// Trace is read-only.
Flags = traits<XprType>::Flags & ~LvalueBit
};
};
template <typename Dims, typename XprType>
@ -203,6 +207,8 @@ struct TensorEvaluator<const TensorTraceOp<Dims, ArgType>, Device> {
return true;
}
EIGEN_DEVICE_FUNC EvaluatorPointerType data() const { return nullptr; }
EIGEN_STRONG_INLINE void cleanup() { m_impl.cleanup(); }
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {

@ -41,6 +41,9 @@ static void test_simple_lvalue_ref() {
for (int i = 0; i < 6; ++i) {
VERIFY_IS_EQUAL(input(i), -i * 2);
}
TensorRef<const Tensor<int, 1>> ref5(input.trace());
VERIFY_IS_EQUAL(ref5[0], input[0]);
}
static void test_simple_rvalue_ref() {