From 2c03e6fccc274db665c6e4708f2cbde14813e826 Mon Sep 17 00:00:00 2001 From: Gael Guennebaud Date: Thu, 22 Dec 2011 14:01:06 +0100 Subject: [PATCH] evaluate 1D sparse expressions into SparseVector and make the sparse operator<< and dot honor nested types --- Eigen/src/SparseCore/SparseDot.h | 12 +++++++-- Eigen/src/SparseCore/SparseMatrixBase.h | 13 +++++++--- Eigen/src/SparseCore/SparseUtil.h | 34 ++++++++++++++++++++----- Eigen/src/SparseCore/SparseVector.h | 4 +-- test/sparse_vector.cpp | 9 +++++-- 5 files changed, 56 insertions(+), 16 deletions(-) diff --git a/Eigen/src/SparseCore/SparseDot.h b/Eigen/src/SparseCore/SparseDot.h index 132bb47f3..4c600362d 100644 --- a/Eigen/src/SparseCore/SparseDot.h +++ b/Eigen/src/SparseCore/SparseDot.h @@ -62,8 +62,16 @@ SparseMatrixBase::dot(const SparseMatrixBase& other) cons eigen_assert(size() == other.size()); - typename Derived::InnerIterator i(derived(),0); - typename OtherDerived::InnerIterator j(other.derived(),0); + typedef typename Derived::Nested Nested; + typedef typename OtherDerived::Nested OtherNested; + typedef typename internal::remove_all::type NestedCleaned; + typedef typename internal::remove_all::type OtherNestedCleaned; + + const Nested nthis(derived()); + const OtherNested nother(other.derived()); + + typename NestedCleaned::InnerIterator i(nthis,0); + typename OtherNestedCleaned::InnerIterator j(nother,0); Scalar res(0); while (i && j) { diff --git a/Eigen/src/SparseCore/SparseMatrixBase.h b/Eigen/src/SparseCore/SparseMatrixBase.h index 536b6d596..6485a5227 100644 --- a/Eigen/src/SparseCore/SparseMatrixBase.h +++ b/Eigen/src/SparseCore/SparseMatrixBase.h @@ -274,12 +274,16 @@ template class SparseMatrixBase : public EigenBase friend std::ostream & operator << (std::ostream & s, const SparseMatrixBase& m) { + typedef typename Derived::Nested Nested; + typedef typename internal::remove_all::type NestedCleaned; + if (Flags&RowMajorBit) { - for (Index row=0; row class SparseMatrixBase : public EigenBase } else { + const Nested nm(m.derived()); if (m.cols() == 1) { Index row = 0; - for (typename Derived::InnerIterator it(m.derived(), 0); it; ++it) + for (typename NestedCleaned::InnerIterator it(nm.derived(), 0); it; ++it) { for ( ; row class SparseMatrixBase : public EigenBase } else { - SparseMatrix trans = m.derived(); + SparseMatrix trans = m; s << static_cast >&>(trans); } } diff --git a/Eigen/src/SparseCore/SparseUtil.h b/Eigen/src/SparseCore/SparseUtil.h index db9ae98e7..a03cdbd5b 100644 --- a/Eigen/src/SparseCore/SparseUtil.h +++ b/Eigen/src/SparseCore/SparseUtil.h @@ -103,17 +103,39 @@ template::Cols namespace internal { -template struct eval -{ - typedef typename traits::Scalar _Scalar; - enum { - _Flags = traits::Flags - }; +template struct sparse_eval; +template struct eval + : public sparse_eval::RowsAtCompileTime,traits::ColsAtCompileTime> +{}; + +template struct sparse_eval { + typedef typename traits::Scalar _Scalar; + enum { _Flags = traits::Flags| RowMajorBit }; + public: + typedef SparseVector<_Scalar, _Flags> type; +}; + +template struct sparse_eval { + typedef typename traits::Scalar _Scalar; + enum { _Flags = traits::Flags & (~RowMajorBit) }; + public: + typedef SparseVector<_Scalar, _Flags> type; +}; + +template struct sparse_eval { + typedef typename traits::Scalar _Scalar; + enum { _Flags = traits::Flags }; public: typedef SparseMatrix<_Scalar, _Flags> type; }; +template struct sparse_eval { + typedef typename traits::Scalar _Scalar; + public: + typedef Matrix<_Scalar, 1, 1> type; +}; + template struct plain_matrix_type { typedef typename traits::Scalar _Scalar; diff --git a/Eigen/src/SparseCore/SparseVector.h b/Eigen/src/SparseCore/SparseVector.h index ea83eab1a..e027c5d01 100644 --- a/Eigen/src/SparseCore/SparseVector.h +++ b/Eigen/src/SparseCore/SparseVector.h @@ -47,7 +47,7 @@ struct traits > typedef Sparse StorageKind; typedef MatrixXpr XprKind; enum { - IsColVector = _Options & RowMajorBit ? 0 : 1, + IsColVector = (_Options & RowMajorBit) ? 0 : 1, RowsAtCompileTime = IsColVector ? Dynamic : 1, ColsAtCompileTime = IsColVector ? 1 : Dynamic, @@ -320,7 +320,7 @@ protected: const bool needToTranspose = (Flags & RowMajorBit) != (OtherDerived::Flags & RowMajorBit); if(needToTranspose) { - Index size = other.innerSize(); + Index size = other.size(); Index nnz = other.nonZeros(); resize(size); reserve(nnz); diff --git a/test/sparse_vector.cpp b/test/sparse_vector.cpp index 5be4f5d9a..09d36a51b 100644 --- a/test/sparse_vector.cpp +++ b/test/sparse_vector.cpp @@ -34,9 +34,9 @@ template void sparse_vector(int rows, int cols) typedef SparseMatrix SparseMatrixType; Scalar eps = 1e-6; - SparseMatrixType m1(rows,cols); + SparseMatrixType m1(rows,rows); SparseVectorType v1(rows), v2(rows), v3(rows); - DenseMatrix refM1 = DenseMatrix::Zero(rows, cols); + DenseMatrix refM1 = DenseMatrix::Zero(rows, rows); DenseVector refV1 = DenseVector::Random(rows), refV2 = DenseVector::Random(rows), refV3 = DenseVector::Random(rows); @@ -86,6 +86,11 @@ template void sparse_vector(int rows, int cols) VERIFY_IS_APPROX(v1.dot(v2), refV1.dot(refV2)); VERIFY_IS_APPROX(v1.dot(refV2), refV1.dot(refV2)); + VERIFY_IS_APPROX(v1.dot(m1*v2), refV1.dot(refM1*refV2)); + int i = internal::random(0,rows-1); + VERIFY_IS_APPROX(v1.dot(m1.col(i)), refV1.dot(refM1.col(i))); + + VERIFY_IS_APPROX(v1.squaredNorm(), refV1.squaredNorm()); }