From 896126588997f89d647ec857a4dd832e462a013b Mon Sep 17 00:00:00 2001 From: Gael Guennebaud Date: Wed, 21 Oct 2015 09:47:43 +0200 Subject: [PATCH] bug #1064: add support for Ref --- Eigen/src/SparseCore/SparseRef.h | 147 +++++++++++++++++++++++++++- Eigen/src/SparseCore/SparseVector.h | 3 + test/sparse_ref.cpp | 20 +++- 3 files changed, 165 insertions(+), 5 deletions(-) diff --git a/Eigen/src/SparseCore/SparseRef.h b/Eigen/src/SparseCore/SparseRef.h index e10bf6878..f9735fd1c 100644 --- a/Eigen/src/SparseCore/SparseRef.h +++ b/Eigen/src/SparseCore/SparseRef.h @@ -19,7 +19,7 @@ enum { namespace internal { template class SparseRefBase; - + template struct traits, _Options, _StrideType> > : public traits > @@ -27,7 +27,7 @@ struct traits, _Options, _Stride typedef SparseMatrix PlainObjectType; enum { Options = _Options, - Flags = traits >::Flags | CompressedAccessBit | NestByRefBit + Flags = traits::Flags | CompressedAccessBit | NestByRefBit }; template struct match { @@ -48,7 +48,35 @@ struct traits, _Options, _ Flags = (traits >::Flags | CompressedAccessBit | NestByRefBit) & ~LvalueBit }; }; - + +template +struct traits, _Options, _StrideType> > + : public traits > +{ + typedef SparseVector PlainObjectType; + enum { + Options = _Options, + Flags = traits::Flags | CompressedAccessBit | NestByRefBit + }; + + template struct match { + enum { + MatchAtCompileTime = (Derived::Flags&CompressedAccessBit) && Derived::IsVectorAtCompileTime + }; + typedef typename internal::conditional::type type; + }; + +}; + +template +struct traits, _Options, _StrideType> > + : public traits, _Options, _StrideType> > +{ + enum { + Flags = (traits >::Flags | CompressedAccessBit | NestByRefBit) & ~LvalueBit + }; +}; + template struct traits > : public traits {}; @@ -195,6 +223,99 @@ class Ref, Options, StrideType }; + +/** + * \ingroup Sparse_Module + * + * \brief A sparse vector expression referencing an existing sparse vector expression + * + * \tparam PlainObjectType the equivalent sparse matrix type of the referenced data + * \tparam Options Not used for SparseVector. + * \tparam StrideType Only used for dense Ref + * + * \sa class Ref + */ +template +class Ref, Options, StrideType > + : public internal::SparseRefBase, Options, StrideType > > +{ + typedef SparseVector PlainObjectType; + typedef internal::traits Traits; + template + inline Ref(const SparseVector& expr); + public: + + typedef internal::SparseRefBase Base; + EIGEN_SPARSE_PUBLIC_INTERFACE(Ref) + + #ifndef EIGEN_PARSED_BY_DOXYGEN + template + inline Ref(SparseVector& expr) + { + EIGEN_STATIC_ASSERT(bool(Traits::template match >::MatchAtCompileTime), STORAGE_LAYOUT_DOES_NOT_MATCH); + Base::construct(expr.derived()); + } + + template + inline Ref(const SparseCompressedBase& expr) + #else + template + inline Ref(SparseCompressedBase& expr) + #endif + { + EIGEN_STATIC_ASSERT(bool(internal::is_lvalue::value), THIS_EXPRESSION_IS_NOT_A_LVALUE__IT_IS_READ_ONLY); + EIGEN_STATIC_ASSERT(bool(Traits::template match::MatchAtCompileTime), STORAGE_LAYOUT_DOES_NOT_MATCH); + Base::construct(expr.const_cast_derived()); + } +}; + +// this is the const ref version +template +class Ref, Options, StrideType> + : public internal::SparseRefBase, Options, StrideType> > +{ + typedef SparseVector TPlainObjectType; + typedef internal::traits Traits; + public: + + typedef internal::SparseRefBase Base; + EIGEN_SPARSE_PUBLIC_INTERFACE(Ref) + + template + inline Ref(const SparseMatrixBase& expr) + { + construct(expr.derived(), typename Traits::template match::type()); + } + + inline Ref(const Ref& other) : Base(other) { + // copy constructor shall not copy the m_object, to avoid unnecessary malloc and copy + } + + template + inline Ref(const RefBase& other) { + construct(other.derived(), typename Traits::template match::type()); + } + + protected: + + template + void construct(const Expression& expr,internal::true_type) + { + Base::construct(expr); + } + + template + void construct(const Expression& expr, internal::false_type) + { + TPlainObjectType* obj = reinterpret_cast(m_object_bytes); + ::new (obj) TPlainObjectType(expr); + Base::construct(*obj); + } + + protected: + char m_object_bytes[sizeof(TPlainObjectType)]; +}; + namespace internal { template @@ -217,6 +338,26 @@ struct evaluator, Options, explicit evaluator(const XprType &mat) : Base(mat) {} }; +template +struct evaluator, Options, StrideType> > + : evaluator, Options, StrideType> > > +{ + typedef evaluator, Options, StrideType> > > Base; + typedef Ref, Options, StrideType> XprType; + evaluator() : Base() {} + explicit evaluator(const XprType &mat) : Base(mat) {} +}; + +template +struct evaluator, Options, StrideType> > + : evaluator, Options, StrideType> > > +{ + typedef evaluator, Options, StrideType> > > Base; + typedef Ref, Options, StrideType> XprType; + evaluator() : Base() {} + explicit evaluator(const XprType &mat) : Base(mat) {} +}; + } } // end namespace Eigen diff --git a/Eigen/src/SparseCore/SparseVector.h b/Eigen/src/SparseCore/SparseVector.h index f941fa5e1..94f8d0341 100644 --- a/Eigen/src/SparseCore/SparseVector.h +++ b/Eigen/src/SparseCore/SparseVector.h @@ -235,6 +235,9 @@ class SparseVector inline SparseVector(const SparseMatrixBase& other) : m_size(0) { + #ifdef EIGEN_SPARSE_CREATE_TEMPORARY_PLUGIN + EIGEN_SPARSE_CREATE_TEMPORARY_PLUGIN + #endif check_template_parameters(); *this = other.derived(); } diff --git a/test/sparse_ref.cpp b/test/sparse_ref.cpp index d2d475616..f4aefbb48 100644 --- a/test/sparse_ref.cpp +++ b/test/sparse_ref.cpp @@ -53,10 +53,14 @@ EIGEN_DONT_INLINE void call_ref_3(const Ref, StandardC VERIFY_IS_EQUAL(a.toDense(),b.toDense()); } +template +EIGEN_DONT_INLINE void call_ref_4(Ref > a, const B &b) { VERIFY_IS_EQUAL(a.toDense(),b.toDense()); } + +template +EIGEN_DONT_INLINE void call_ref_5(const Ref >& a, const B &b) { VERIFY_IS_EQUAL(a.toDense(),b.toDense()); } + void call_ref() { -// SparseVector > ca = VectorXcf::Random(10).sparseView(); -// SparseVector a = VectorXf::Random(10).sparseView(); SparseMatrix A = MatrixXf::Random(10,10).sparseView(0.5,1); SparseMatrix B = MatrixXf::Random(10,10).sparseView(0.5,1); SparseMatrix C = MatrixXf::Random(10,10).sparseView(0.5,1); @@ -111,6 +115,15 @@ void call_ref() VERIFY_EVALUATION_COUNT( call_ref_2(vr, vr.transpose()), 0); VERIFY_EVALUATION_COUNT( call_ref_2(A.block(1,1,3,3), A.block(1,1,3,3)), 1); // should be 0 (allocate starts/nnz only) + + VERIFY_EVALUATION_COUNT( call_ref_4(vc, vc), 0); + VERIFY_EVALUATION_COUNT( call_ref_4(vr, vr.transpose()), 0); + VERIFY_EVALUATION_COUNT( call_ref_5(vc, vc), 0); + VERIFY_EVALUATION_COUNT( call_ref_5(vr, vr.transpose()), 0); + VERIFY_EVALUATION_COUNT( call_ref_4(A.col(2), A.col(2)), 0); + VERIFY_EVALUATION_COUNT( call_ref_5(A.col(2), A.col(2)), 0); + // VERIFY_EVALUATION_COUNT( call_ref_4(A.row(2), A.row(2).transpose()), 1); // does not compile on purpose + VERIFY_EVALUATION_COUNT( call_ref_5(A.row(2), A.row(2).transpose()), 1); } void test_sparse_ref() @@ -119,5 +132,8 @@ void test_sparse_ref() CALL_SUBTEST_1( check_const_correctness(SparseMatrix()) ); CALL_SUBTEST_1( check_const_correctness(SparseMatrix()) ); CALL_SUBTEST_2( call_ref() ); + + CALL_SUBTEST_3( check_const_correctness(SparseVector()) ); + CALL_SUBTEST_3( check_const_correctness(SparseVector()) ); } }