diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h index 2a59530a1..aad1647c2 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h @@ -27,6 +27,10 @@ struct traits> : public traits static constexpr int NumDimensions = XprTraits::NumDimensions; static constexpr int Layout = XprTraits::Layout; typedef typename XprTraits::PointerType PointerType; + enum { + // Broadcast is read-only. + Flags = traits::Flags & ~LvalueBit + }; }; template diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorRef.h b/unsupported/Eigen/CXX11/src/Tensor/TensorRef.h index 7061f5120..e12923da0 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorRef.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorRef.h @@ -52,7 +52,13 @@ class TensorLazyEvaluatorReadOnly typedef TensorEvaluator EvalType; TensorLazyEvaluatorReadOnly(const Expr& expr, const Device& device) : m_impl(expr, device), m_dummy(Scalar(0)) { - m_dims = m_impl.dimensions(); + EIGEN_STATIC_ASSERT( + internal::array_size::value == internal::array_size::value, + "Dimension sizes must match."); + const auto& other_dims = m_impl.dimensions(); + for (std::size_t i = 0; i < m_dims.size(); ++i) { + m_dims[i] = other_dims[i]; + } m_impl.evalSubExprsIfNeeded(NULL); } virtual ~TensorLazyEvaluatorReadOnly() { m_impl.cleanup(); } @@ -86,14 +92,12 @@ class TensorLazyEvaluatorWritable : public TensorLazyEvaluatorReadOnlym_impl.coeffRef(index); } }; -template -class TensorLazyEvaluator : public std::conditional_t::value), - TensorLazyEvaluatorWritable, - TensorLazyEvaluatorReadOnly > { +template +class TensorLazyEvaluator : public std::conditional_t, + TensorLazyEvaluatorReadOnly> { public: - typedef std::conditional_t::value), - TensorLazyEvaluatorWritable, - TensorLazyEvaluatorReadOnly > + typedef std::conditional_t, + TensorLazyEvaluatorReadOnly> Base; typedef typename Base::Scalar Scalar; @@ -101,24 +105,15 @@ class TensorLazyEvaluator : public std::conditional_t -class TensorRef : public TensorBase > { +template +class TensorRefBase : public TensorBase { public: - typedef TensorRef Self; + typedef typename traits::PlainObjectType PlainObjectType; typedef typename PlainObjectType::Base Base; - typedef typename Eigen::internal::nested::type Nested; - typedef typename internal::traits::StorageKind StorageKind; - typedef typename internal::traits::Index Index; - typedef typename internal::traits::Scalar Scalar; + typedef typename Eigen::internal::nested::type Nested; + typedef typename traits::StorageKind StorageKind; + typedef typename traits::Index Index; + typedef typename traits::Scalar Scalar; typedef typename NumTraits::Real RealScalar; typedef typename Base::CoeffReturnType CoeffReturnType; typedef Scalar* PointerType; @@ -138,33 +133,17 @@ class TensorRef : public TensorBase > { }; //===- Tensor block evaluation strategy (see TensorBlock.h) -----------===// - typedef internal::TensorBlockNotImplemented TensorBlock; + typedef TensorBlockNotImplemented TensorBlock; //===------------------------------------------------------------------===// - EIGEN_STRONG_INLINE TensorRef() : m_evaluator(NULL) {} + EIGEN_STRONG_INLINE TensorRefBase() : m_evaluator(NULL) {} - template - EIGEN_STRONG_INLINE TensorRef(const Expression& expr) - : m_evaluator(new internal::TensorLazyEvaluator(expr, DefaultDevice())) { - m_evaluator->incrRefCount(); - } - - template - EIGEN_STRONG_INLINE TensorRef& operator=(const Expression& expr) { - unrefEvaluator(); - m_evaluator = new internal::TensorLazyEvaluator(expr, DefaultDevice()); - m_evaluator->incrRefCount(); - return *this; - } - - ~TensorRef() { unrefEvaluator(); } - - TensorRef(const TensorRef& other) : TensorBase >(other), m_evaluator(other.m_evaluator) { + TensorRefBase(const TensorRefBase& other) : TensorBase(other), m_evaluator(other.m_evaluator) { eigen_assert(m_evaluator->refCount() > 0); m_evaluator->incrRefCount(); } - TensorRef& operator=(const TensorRef& other) { + TensorRefBase& operator=(const TensorRefBase& other) { if (this != &other) { unrefEvaluator(); m_evaluator = other.m_evaluator; @@ -174,6 +153,28 @@ class TensorRef : public TensorBase > { return *this; } + template , Derived>::value>> + EIGEN_STRONG_INLINE TensorRefBase(const Expression& expr) + : m_evaluator(new TensorLazyEvaluator::value && + bool(is_lvalue::value)>(expr, DefaultDevice())) { + m_evaluator->incrRefCount(); + } + + template , Derived>::value>> + EIGEN_STRONG_INLINE TensorRefBase& operator=(const Expression& expr) { + unrefEvaluator(); + m_evaluator = new TensorLazyEvaluator < Dimensions, Expression, DefaultDevice, + /*IsWritable=*/!std::is_const::value&& bool(is_lvalue::value) > + (expr, DefaultDevice()); + m_evaluator->incrRefCount(); + return *this; + } + + ~TensorRefBase() { unrefEvaluator(); } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index rank() const { return m_evaluator->dimensions().size(); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index dimension(Index n) const { return m_evaluator->dimensions()[n]; } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_evaluator->dimensions(); } @@ -188,12 +189,6 @@ class TensorRef : public TensorBase > { const array indices{{firstIndex, otherIndices...}}; return coeff(indices); } - template - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index firstIndex, IndexTypes... otherIndices) { - const std::size_t num_indices = (sizeof...(otherIndices) + 1); - const array indices{{firstIndex, otherIndices...}}; - return coeffRef(indices); - } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar coeff(const array& indices) const { @@ -212,6 +207,70 @@ class TensorRef : public TensorBase > { } return m_evaluator->coeff(index); } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar coeff(Index index) const { return m_evaluator->coeff(index); } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index index) { return m_evaluator->coeffRef(index); } + + protected: + TensorLazyBaseEvaluator* evaluator() { return m_evaluator; } + + private: + EIGEN_STRONG_INLINE void unrefEvaluator() { + if (m_evaluator) { + m_evaluator->decrRefCount(); + if (m_evaluator->refCount() == 0) { + delete m_evaluator; + } + } + } + + TensorLazyBaseEvaluator* m_evaluator; +}; + +} // namespace internal + +/** + * \ingroup CXX11_Tensor_Module + * + * \brief A reference to a tensor expression + * The expression will be evaluated lazily (as much as possible). + * + */ +template +class TensorRef : public internal::TensorRefBase> { + typedef internal::TensorRefBase> Base; + + public: + using Scalar = typename Base::Scalar; + using Dimensions = typename Base::Dimensions; + + EIGEN_STRONG_INLINE TensorRef() : Base() {} + + template + EIGEN_STRONG_INLINE TensorRef(const Expression& expr) : Base(expr) { + EIGEN_STATIC_ASSERT(internal::is_lvalue::value, + "Expression must be mutable to create a mutable TensorRef. Did you mean " + "TensorRef?)"); + } + + template + EIGEN_STRONG_INLINE TensorRef& operator=(const Expression& expr) { + EIGEN_STATIC_ASSERT(internal::is_lvalue::value, + "Expression must be mutable to create a mutable TensorRef. Did you mean " + "TensorRef?)"); + return Base::operator=(expr).derived(); + } + + TensorRef& operator=(const TensorRef& other) { return Base::operator=(other).derived(); } + + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index firstIndex, IndexTypes... otherIndices) { + const std::size_t num_indices = (sizeof...(otherIndices) + 1); + const array indices{{firstIndex, otherIndices...}}; + return coeffRef(indices); + } + template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(const array& indices) { const Dimensions& dims = this->dimensions(); @@ -227,24 +286,37 @@ class TensorRef : public TensorBase > { index = index * dims[i] + indices[i]; } } - return m_evaluator->coeffRef(index); + return Base::evaluator()->coeffRef(index); } - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar coeff(Index index) const { return m_evaluator->coeff(index); } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index index) { return Base::evaluator()->coeffRef(index); } +}; - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index index) { return m_evaluator->coeffRef(index); } +/** + * \ingroup CXX11_Tensor_Module + * + * \brief A reference to a constant tensor expression + * The expression will be evaluated lazily (as much as possible). + * + */ +template +class TensorRef : public internal::TensorRefBase> { + typedef internal::TensorRefBase> Base; - private: - EIGEN_STRONG_INLINE void unrefEvaluator() { - if (m_evaluator) { - m_evaluator->decrRefCount(); - if (m_evaluator->refCount() == 0) { - delete m_evaluator; - } - } + public: + EIGEN_STRONG_INLINE TensorRef() : Base() {} + + template + EIGEN_STRONG_INLINE TensorRef(const Expression& expr) : Base(expr) {} + + template + EIGEN_STRONG_INLINE TensorRef& operator=(const Expression& expr) { + return Base::operator=(expr).derived(); } - internal::TensorLazyBaseEvaluator* m_evaluator; + TensorRef(const TensorRef& other) : Base(other) {} + + TensorRef& operator=(const TensorRef& other) { return Base::operator=(other).derived(); } }; // evaluator for rvalues diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorTraits.h b/unsupported/Eigen/CXX11/src/Tensor/TensorTraits.h index 017b4ff63..f5954d6f3 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorTraits.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorTraits.h @@ -95,8 +95,9 @@ struct traits > : public trai typedef typename MakePointer::Type PointerType; }; -template -struct traits > : public traits { +template +struct traits > : public traits { + typedef PlainObjectType_ PlainObjectType; typedef traits BaseTraits; typedef typename BaseTraits::Scalar Scalar; typedef typename BaseTraits::StorageKind StorageKind; diff --git a/unsupported/test/cxx11_tensor_morphing.cpp b/unsupported/test/cxx11_tensor_morphing.cpp index 0672572fe..55d42918b 100644 --- a/unsupported/test/cxx11_tensor_morphing.cpp +++ b/unsupported/test/cxx11_tensor_morphing.cpp @@ -138,7 +138,7 @@ static void test_const_slice() { TensorMap> m(b, 1); DSizes offsets; offsets[0] = 0; - TensorRef> slice_ref(m.slice(offsets, m.dimensions())); + TensorRef> slice_ref(m.slice(offsets, m.dimensions())); VERIFY_IS_EQUAL(slice_ref(0), 42); } diff --git a/unsupported/test/cxx11_tensor_ref.cpp b/unsupported/test/cxx11_tensor_ref.cpp index d5ff19617..cf097499d 100644 --- a/unsupported/test/cxx11_tensor_ref.cpp +++ b/unsupported/test/cxx11_tensor_ref.cpp @@ -49,8 +49,8 @@ static void test_simple_rvalue_ref() { Tensor input2(6); input2.setRandom(); - TensorRef> ref3(input1 + input2); - TensorRef> ref4 = input1 + input2; + TensorRef> ref3(input1 + input2); + TensorRef> ref4 = input1 + input2; VERIFY_IS_NOT_EQUAL(ref3.data(), input1.data()); VERIFY_IS_NOT_EQUAL(ref4.data(), input1.data()); @@ -144,7 +144,7 @@ static void test_ref_in_expr() { Tensor result(3, 5, 7); result.setRandom(); - TensorRef> result_ref(result); + TensorRef> result_ref(result); Tensor bias(3, 5, 7); bias.setRandom(); @@ -192,7 +192,7 @@ static void test_nested_ops_with_ref() { paddings[2] = std::make_pair(3, 4); paddings[3] = std::make_pair(0, 0); DSizes shuffle_dims(0, 1, 2, 3); - TensorRef> ref(m.pad(paddings)); + TensorRef> ref(m.pad(paddings)); array, 4> trivial; trivial[0] = std::make_pair(0, 0); trivial[1] = std::make_pair(0, 0);