From e6f8c5c325fca53b53436b6bd8d66749444216bb Mon Sep 17 00:00:00 2001 From: Gael Guennebaud Date: Mon, 26 Oct 2015 18:20:00 +0100 Subject: [PATCH] Add support to directly evaluate the product of two sparse matrices within a dense matrix. --- .../ConservativeSparseSparseProduct.h | 85 ++++++++++++++++++- Eigen/src/SparseCore/SparseAssign.h | 8 +- Eigen/src/SparseCore/SparseMatrixBase.h | 2 +- Eigen/src/SparseCore/SparseProduct.h | 34 +++++++- test/sparse_product.cpp | 11 +++ 5 files changed, 132 insertions(+), 8 deletions(-) diff --git a/Eigen/src/SparseCore/ConservativeSparseSparseProduct.h b/Eigen/src/SparseCore/ConservativeSparseSparseProduct.h index a61ceb7cc..0f6835846 100644 --- a/Eigen/src/SparseCore/ConservativeSparseSparseProduct.h +++ b/Eigen/src/SparseCore/ConservativeSparseSparseProduct.h @@ -1,7 +1,7 @@ // This file is part of Eigen, a lightweight C++ template library // for linear algebra. // -// Copyright (C) 2008-2014 Gael Guennebaud +// Copyright (C) 2008-2015 Gael Guennebaud // // This Source Code Form is subject to the terms of the Mozilla // Public License v. 2.0. If a copy of the MPL was not distributed @@ -255,6 +255,89 @@ struct conservative_sparse_sparse_product_selector +static void sparse_sparse_to_dense_product_impl(const Lhs& lhs, const Rhs& rhs, ResultType& res) +{ + typedef typename remove_all::type::Scalar Scalar; + Index cols = rhs.outerSize(); + eigen_assert(lhs.outerSize() == rhs.innerSize()); + + evaluator lhsEval(lhs); + evaluator rhsEval(rhs); + + for (Index j=0; j::InnerIterator rhsIt(rhsEval, j); rhsIt; ++rhsIt) + { + Scalar y = rhsIt.value(); + Index k = rhsIt.index(); + for (typename evaluator::InnerIterator lhsIt(lhsEval, k); lhsIt; ++lhsIt) + { + Index i = lhsIt.index(); + Scalar x = lhsIt.value(); + res.coeffRef(i,j) += x * y; + } + } + } +} + + +} // end namespace internal + +namespace internal { + +template::Flags&RowMajorBit) ? RowMajor : ColMajor, + int RhsStorageOrder = (traits::Flags&RowMajorBit) ? RowMajor : ColMajor> +struct sparse_sparse_to_dense_product_selector; + +template +struct sparse_sparse_to_dense_product_selector +{ + static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res) + { + internal::sparse_sparse_to_dense_product_impl(lhs, rhs, res); + } +}; + +template +struct sparse_sparse_to_dense_product_selector +{ + static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res) + { + typedef SparseMatrix ColMajorMatrix; + ColMajorMatrix lhsCol(lhs); + internal::sparse_sparse_to_dense_product_impl(lhsCol, rhs, res); + } +}; + +template +struct sparse_sparse_to_dense_product_selector +{ + static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res) + { + typedef SparseMatrix ColMajorMatrix; + ColMajorMatrix rhsCol(rhs); + internal::sparse_sparse_to_dense_product_impl(lhs, rhsCol, res); + } +}; + +template +struct sparse_sparse_to_dense_product_selector +{ + static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res) + { + Transpose trRes(res); + internal::sparse_sparse_to_dense_product_impl >(rhs, lhs, trRes); + } +}; + + } // end namespace internal } // end namespace Eigen diff --git a/Eigen/src/SparseCore/SparseAssign.h b/Eigen/src/SparseCore/SparseAssign.h index e984bbdb3..4b663a59e 100644 --- a/Eigen/src/SparseCore/SparseAssign.h +++ b/Eigen/src/SparseCore/SparseAssign.h @@ -133,8 +133,8 @@ struct Assignment }; // Sparse to Dense assignment -template< typename DstXprType, typename SrcXprType, typename Functor, typename Scalar> -struct Assignment +template< typename DstXprType, typename SrcXprType, typename Functor> +struct Assignment { static void run(DstXprType &dst, const SrcXprType &src, const Functor &func) { @@ -149,8 +149,8 @@ struct Assignment } }; -template< typename DstXprType, typename SrcXprType, typename Scalar> -struct Assignment, Sparse2Dense, Scalar> +template< typename DstXprType, typename SrcXprType> +struct Assignment, Sparse2Dense> { static void run(DstXprType &dst, const SrcXprType &src, const internal::assign_op &) { diff --git a/Eigen/src/SparseCore/SparseMatrixBase.h b/Eigen/src/SparseCore/SparseMatrixBase.h index 4e720904e..38eb1c37a 100644 --- a/Eigen/src/SparseCore/SparseMatrixBase.h +++ b/Eigen/src/SparseCore/SparseMatrixBase.h @@ -281,7 +281,7 @@ template class SparseMatrixBase : public EigenBase // sparse * sparse template - const Product + const Product operator*(const SparseMatrixBase &other) const; // sparse * dense diff --git a/Eigen/src/SparseCore/SparseProduct.h b/Eigen/src/SparseCore/SparseProduct.h index da8919ecc..26680b7a7 100644 --- a/Eigen/src/SparseCore/SparseProduct.h +++ b/Eigen/src/SparseCore/SparseProduct.h @@ -25,10 +25,10 @@ namespace Eigen { * */ template template -inline const Product +inline const Product SparseMatrixBase::operator*(const SparseMatrixBase &other) const { - return Product(derived(), other.derived()); + return Product(derived(), other.derived()); } namespace internal { @@ -61,6 +61,36 @@ struct generic_product_impl {}; +// Dense = sparse * sparse +template< typename DstXprType, typename Lhs, typename Rhs, int Options/*, typename Scalar*/> +struct Assignment, internal::assign_op, Sparse2Dense/*, + typename enable_if<(Options==DefaultProduct || Options==AliasFreeProduct),Scalar>::type*/> +{ + typedef Product SrcXprType; + static void run(DstXprType &dst, const SrcXprType &src, const internal::assign_op &) + { + dst.setZero(); + dst += src; + } +}; + +// Dense += sparse * sparse +template< typename DstXprType, typename Lhs, typename Rhs, int Options> +struct Assignment, internal::add_assign_op, Sparse2Dense/*, + typename enable_if<(Options==DefaultProduct || Options==AliasFreeProduct),Scalar>::type*/> +{ + typedef Product SrcXprType; + static void run(DstXprType &dst, const SrcXprType &src, const internal::add_assign_op &) + { + typedef typename nested_eval::type LhsNested; + typedef typename nested_eval::type RhsNested; + LhsNested lhsNested(src.lhs()); + RhsNested rhsNested(src.rhs()); + internal::sparse_sparse_to_dense_product_selector::type, + typename remove_all::type, DstXprType>::run(lhsNested,rhsNested,dst); + } +}; + template struct evaluator > > : public evaluator::PlainObject> diff --git a/test/sparse_product.cpp b/test/sparse_product.cpp index f1e5b8e4c..8c83f08d7 100644 --- a/test/sparse_product.cpp +++ b/test/sparse_product.cpp @@ -76,6 +76,17 @@ template void sparse_product() VERIFY_IS_APPROX(m4=(m2t.transpose()*m3t.transpose()).pruned(0), refMat4=refMat2t.transpose()*refMat3t.transpose()); VERIFY_IS_APPROX(m4=(m2*m3t.transpose()).pruned(0), refMat4=refMat2*refMat3t.transpose()); + // dense ?= sparse * sparse + VERIFY_IS_APPROX(dm4 =m2*m3, refMat4 =refMat2*refMat3); + VERIFY_IS_APPROX(dm4+=m2*m3, refMat4+=refMat2*refMat3); + VERIFY_IS_APPROX(dm4 =m2t.transpose()*m3, refMat4 =refMat2t.transpose()*refMat3); + VERIFY_IS_APPROX(dm4+=m2t.transpose()*m3, refMat4+=refMat2t.transpose()*refMat3); + VERIFY_IS_APPROX(dm4 =m2t.transpose()*m3t.transpose(), refMat4 =refMat2t.transpose()*refMat3t.transpose()); + VERIFY_IS_APPROX(dm4+=m2t.transpose()*m3t.transpose(), refMat4+=refMat2t.transpose()*refMat3t.transpose()); + VERIFY_IS_APPROX(dm4 =m2*m3t.transpose(), refMat4 =refMat2*refMat3t.transpose()); + VERIFY_IS_APPROX(dm4+=m2*m3t.transpose(), refMat4+=refMat2*refMat3t.transpose()); + VERIFY_IS_APPROX(dm4 = m2*m3*s1, refMat4 = refMat2*refMat3*s1); + // test aliasing m4 = m2; refMat4 = refMat2; VERIFY_IS_APPROX(m4=m4*m3, refMat4=refMat4*refMat3);