From c0f76ce2cf1780c7b638071d45cba2a3d04f2f6f Mon Sep 17 00:00:00 2001 From: Gael Guennebaud Date: Fri, 11 Jul 2014 16:25:36 +0200 Subject: [PATCH] Fix bug #838: fix dense * sparse and sparse * dense outer products --- Eigen/src/SparseCore/SparseDenseProduct.h | 33 +++++++++++---- test/sparse_product.cpp | 51 +++++++++++------------ 2 files changed, 50 insertions(+), 34 deletions(-) diff --git a/Eigen/src/SparseCore/SparseDenseProduct.h b/Eigen/src/SparseCore/SparseDenseProduct.h index 610833f3b..a3772c255 100644 --- a/Eigen/src/SparseCore/SparseDenseProduct.h +++ b/Eigen/src/SparseCore/SparseDenseProduct.h @@ -19,7 +19,10 @@ template struct SparseDenseProductRet template struct SparseDenseProductReturnType { - typedef SparseDenseOuterProduct Type; + typedef typename internal::conditional< + Lhs::IsRowMajor, + SparseDenseOuterProduct, + SparseDenseOuterProduct >::type Type; }; template struct DenseSparseProductReturnType @@ -29,7 +32,10 @@ template struct DenseSparseProductRet template struct DenseSparseProductReturnType { - typedef SparseDenseOuterProduct Type; + typedef typename internal::conditional< + Rhs::IsRowMajor, + SparseDenseOuterProduct, + SparseDenseOuterProduct >::type Type; }; namespace internal { @@ -114,17 +120,30 @@ class SparseDenseOuterProduct::InnerIterator : public _LhsNes typedef typename SparseDenseOuterProduct::Index Index; public: EIGEN_STRONG_INLINE InnerIterator(const SparseDenseOuterProduct& prod, Index outer) - : Base(prod.lhs(), 0), m_outer(outer), m_factor(prod.rhs().coeff(outer)) - { - } + : Base(prod.lhs(), 0), m_outer(outer), m_factor(get(prod.rhs(), outer, typename internal::traits::StorageKind() )) + { } inline Index outer() const { return m_outer; } - inline Index row() const { return Transpose ? Base::row() : m_outer; } - inline Index col() const { return Transpose ? m_outer : Base::row(); } + inline Index row() const { return Transpose ? m_outer : Base::index(); } + inline Index col() const { return Transpose ? Base::index() : m_outer; } inline Scalar value() const { return Base::value() * m_factor; } protected: + static Scalar get(const _RhsNested &rhs, Index outer, Dense = Dense()) + { + return rhs.coeff(outer); + } + + static Scalar get(const _RhsNested &rhs, Index outer, Sparse = Sparse()) + { + typename Traits::_RhsNested::InnerIterator it(rhs, outer); + if (it && it.index()==0) + return it.value(); + + return Scalar(0); + } + Index m_outer; Scalar m_factor; }; diff --git a/test/sparse_product.cpp b/test/sparse_product.cpp index a2ea9d5b7..2c8b97e5b 100644 --- a/test/sparse_product.cpp +++ b/test/sparse_product.cpp @@ -9,32 +9,6 @@ #include "sparse.h" -template struct test_outer; - -template struct test_outer { - static void run(SparseMatrixType& m2, SparseMatrixType& m4, DenseMatrix& refMat2, DenseMatrix& refMat4) { - typedef typename SparseMatrixType::Index Index; - Index c = internal::random(0,m2.cols()-1); - Index c1 = internal::random(0,m2.cols()-1); - VERIFY_IS_APPROX(m4=m2.col(c)*refMat2.col(c1).transpose(), refMat4=refMat2.col(c)*refMat2.col(c1).transpose()); - VERIFY_IS_APPROX(m4=refMat2.col(c1)*m2.col(c).transpose(), refMat4=refMat2.col(c1)*refMat2.col(c).transpose()); - } -}; - -template struct test_outer { - static void run(SparseMatrixType& m2, SparseMatrixType& m4, DenseMatrix& refMat2, DenseMatrix& refMat4) { - typedef typename SparseMatrixType::Index Index; - Index r = internal::random(0,m2.rows()-1); - Index c1 = internal::random(0,m2.cols()-1); - VERIFY_IS_APPROX(m4=m2.row(r).transpose()*refMat2.col(c1).transpose(), refMat4=refMat2.row(r).transpose()*refMat2.col(c1).transpose()); - VERIFY_IS_APPROX(m4=refMat2.col(c1)*m2.row(r), refMat4=refMat2.col(c1)*refMat2.row(r)); - } -}; - -// (m2,m4,refMat2,refMat4,dv1); -// VERIFY_IS_APPROX(m4=m2.innerVector(c)*dv1.transpose(), refMat4=refMat2.colVector(c)*dv1.transpose()); -// VERIFY_IS_APPROX(m4=dv1*mcm.col(c).transpose(), refMat4=dv1*refMat2.col(c).transpose()); - template void sparse_product() { typedef typename SparseMatrixType::Index Index; @@ -119,7 +93,30 @@ template void sparse_product() VERIFY_IS_APPROX(dm4=refMat2t.transpose()*m3t.transpose(), refMat4=refMat2t.transpose()*refMat3t.transpose()); // sparse * dense and dense * sparse outer product - test_outer::run(m2,m4,refMat2,refMat4); + { + Index c = internal::random(0,depth-1); + Index r = internal::random(0,rows-1); + Index c1 = internal::random(0,cols-1); + Index r1 = internal::random(0,depth-1); + + VERIFY_IS_APPROX( m4=m2.col(c)*refMat3.col(c1).transpose(), refMat4=refMat2.col(c)*refMat3.col(c1).transpose()); + VERIFY_IS_APPROX(dm4=m2.col(c)*refMat3.col(c1).transpose(), refMat4=refMat2.col(c)*refMat3.col(c1).transpose()); + + VERIFY_IS_APPROX(m4=refMat3.col(c1)*m2.col(c).transpose(), refMat4=refMat3.col(c1)*refMat2.col(c).transpose()); + VERIFY_IS_APPROX(dm4=refMat3.col(c1)*m2.col(c).transpose(), refMat4=refMat3.col(c1)*refMat2.col(c).transpose()); + + VERIFY_IS_APPROX( m4=refMat3.row(r1).transpose()*m2.col(c).transpose(), refMat4=refMat3.row(r1).transpose()*refMat2.col(c).transpose()); + VERIFY_IS_APPROX(dm4=refMat3.row(r1).transpose()*m2.col(c).transpose(), refMat4=refMat3.row(r1).transpose()*refMat2.col(c).transpose()); + + VERIFY_IS_APPROX( m4=m2.row(r).transpose()*refMat3.col(c1).transpose(), refMat4=refMat2.row(r).transpose()*refMat3.col(c1).transpose()); + VERIFY_IS_APPROX(dm4=m2.row(r).transpose()*refMat3.col(c1).transpose(), refMat4=refMat2.row(r).transpose()*refMat3.col(c1).transpose()); + + VERIFY_IS_APPROX( m4=refMat3.col(c1)*m2.row(r), refMat4=refMat3.col(c1)*refMat2.row(r)); + VERIFY_IS_APPROX(dm4=refMat3.col(c1)*m2.row(r), refMat4=refMat3.col(c1)*refMat2.row(r)); + + VERIFY_IS_APPROX( m4=refMat3.row(r1).transpose()*m2.row(r), refMat4=refMat3.row(r1).transpose()*refMat2.row(r)); + VERIFY_IS_APPROX(dm4=refMat3.row(r1).transpose()*m2.row(r), refMat4=refMat3.row(r1).transpose()*refMat2.row(r)); + } VERIFY_IS_APPROX(m6=m6*m6, refMat6=refMat6*refMat6);