Add support to directly evaluate the product of two sparse matrices within a dense matrix.

This commit is contained in:
Gael Guennebaud 2015-10-26 18:20:00 +01:00
parent a5324a131f
commit e6f8c5c325
5 changed files with 132 additions and 8 deletions

View File

@ -1,7 +1,7 @@
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2008-2014 Gael Guennebaud <gael.guennebaud@inria.fr>
// Copyright (C) 2008-2015 Gael Guennebaud <gael.guennebaud@inria.fr>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
@ -255,6 +255,89 @@ struct conservative_sparse_sparse_product_selector<Lhs,Rhs,ResultType,RowMajor,R
}
};
} // end namespace internal
namespace internal {
template<typename Lhs, typename Rhs, typename ResultType>
static void sparse_sparse_to_dense_product_impl(const Lhs& lhs, const Rhs& rhs, ResultType& res)
{
typedef typename remove_all<Lhs>::type::Scalar Scalar;
Index cols = rhs.outerSize();
eigen_assert(lhs.outerSize() == rhs.innerSize());
evaluator<Lhs> lhsEval(lhs);
evaluator<Rhs> rhsEval(rhs);
for (Index j=0; j<cols; ++j)
{
for (typename evaluator<Rhs>::InnerIterator rhsIt(rhsEval, j); rhsIt; ++rhsIt)
{
Scalar y = rhsIt.value();
Index k = rhsIt.index();
for (typename evaluator<Lhs>::InnerIterator lhsIt(lhsEval, k); lhsIt; ++lhsIt)
{
Index i = lhsIt.index();
Scalar x = lhsIt.value();
res.coeffRef(i,j) += x * y;
}
}
}
}
} // end namespace internal
namespace internal {
template<typename Lhs, typename Rhs, typename ResultType,
int LhsStorageOrder = (traits<Lhs>::Flags&RowMajorBit) ? RowMajor : ColMajor,
int RhsStorageOrder = (traits<Rhs>::Flags&RowMajorBit) ? RowMajor : ColMajor>
struct sparse_sparse_to_dense_product_selector;
template<typename Lhs, typename Rhs, typename ResultType>
struct sparse_sparse_to_dense_product_selector<Lhs,Rhs,ResultType,ColMajor,ColMajor>
{
static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
{
internal::sparse_sparse_to_dense_product_impl<Lhs,Rhs,ResultType>(lhs, rhs, res);
}
};
template<typename Lhs, typename Rhs, typename ResultType>
struct sparse_sparse_to_dense_product_selector<Lhs,Rhs,ResultType,RowMajor,ColMajor>
{
static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
{
typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename ResultType::StorageIndex> ColMajorMatrix;
ColMajorMatrix lhsCol(lhs);
internal::sparse_sparse_to_dense_product_impl<ColMajorMatrix,Rhs,ResultType>(lhsCol, rhs, res);
}
};
template<typename Lhs, typename Rhs, typename ResultType>
struct sparse_sparse_to_dense_product_selector<Lhs,Rhs,ResultType,ColMajor,RowMajor>
{
static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
{
typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename ResultType::StorageIndex> ColMajorMatrix;
ColMajorMatrix rhsCol(rhs);
internal::sparse_sparse_to_dense_product_impl<Lhs,ColMajorMatrix,ResultType>(lhs, rhsCol, res);
}
};
template<typename Lhs, typename Rhs, typename ResultType>
struct sparse_sparse_to_dense_product_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor>
{
static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
{
Transpose<ResultType> trRes(res);
internal::sparse_sparse_to_dense_product_impl<Rhs,Lhs,Transpose<ResultType> >(rhs, lhs, trRes);
}
};
} // end namespace internal
} // end namespace Eigen

View File

@ -133,8 +133,8 @@ struct Assignment<DstXprType, SrcXprType, Functor, Sparse2Sparse, Scalar>
};
// Sparse to Dense assignment
template< typename DstXprType, typename SrcXprType, typename Functor, typename Scalar>
struct Assignment<DstXprType, SrcXprType, Functor, Sparse2Dense, Scalar>
template< typename DstXprType, typename SrcXprType, typename Functor>
struct Assignment<DstXprType, SrcXprType, Functor, Sparse2Dense>
{
static void run(DstXprType &dst, const SrcXprType &src, const Functor &func)
{
@ -149,8 +149,8 @@ struct Assignment<DstXprType, SrcXprType, Functor, Sparse2Dense, Scalar>
}
};
template< typename DstXprType, typename SrcXprType, typename Scalar>
struct Assignment<DstXprType, SrcXprType, internal::assign_op<typename DstXprType::Scalar>, Sparse2Dense, Scalar>
template< typename DstXprType, typename SrcXprType>
struct Assignment<DstXprType, SrcXprType, internal::assign_op<typename DstXprType::Scalar>, Sparse2Dense>
{
static void run(DstXprType &dst, const SrcXprType &src, const internal::assign_op<typename DstXprType::Scalar> &)
{

View File

@ -281,7 +281,7 @@ template<typename Derived> class SparseMatrixBase : public EigenBase<Derived>
// sparse * sparse
template<typename OtherDerived>
const Product<Derived,OtherDerived>
const Product<Derived,OtherDerived,AliasFreeProduct>
operator*(const SparseMatrixBase<OtherDerived> &other) const;
// sparse * dense

View File

@ -25,10 +25,10 @@ namespace Eigen {
* */
template<typename Derived>
template<typename OtherDerived>
inline const Product<Derived,OtherDerived>
inline const Product<Derived,OtherDerived,AliasFreeProduct>
SparseMatrixBase<Derived>::operator*(const SparseMatrixBase<OtherDerived> &other) const
{
return Product<Derived,OtherDerived>(derived(), other.derived());
return Product<Derived,OtherDerived,AliasFreeProduct>(derived(), other.derived());
}
namespace internal {
@ -61,6 +61,36 @@ struct generic_product_impl<Lhs, Rhs, SparseTriangularShape, SparseShape, Produc
: public generic_product_impl<Lhs, Rhs, SparseShape, SparseShape, ProductType>
{};
// Dense = sparse * sparse
template< typename DstXprType, typename Lhs, typename Rhs, int Options/*, typename Scalar*/>
struct Assignment<DstXprType, Product<Lhs,Rhs,Options>, internal::assign_op<typename DstXprType::Scalar>, Sparse2Dense/*,
typename enable_if<(Options==DefaultProduct || Options==AliasFreeProduct),Scalar>::type*/>
{
typedef Product<Lhs,Rhs,Options> SrcXprType;
static void run(DstXprType &dst, const SrcXprType &src, const internal::assign_op<typename DstXprType::Scalar> &)
{
dst.setZero();
dst += src;
}
};
// Dense += sparse * sparse
template< typename DstXprType, typename Lhs, typename Rhs, int Options>
struct Assignment<DstXprType, Product<Lhs,Rhs,Options>, internal::add_assign_op<typename DstXprType::Scalar>, Sparse2Dense/*,
typename enable_if<(Options==DefaultProduct || Options==AliasFreeProduct),Scalar>::type*/>
{
typedef Product<Lhs,Rhs,Options> SrcXprType;
static void run(DstXprType &dst, const SrcXprType &src, const internal::add_assign_op<typename DstXprType::Scalar> &)
{
typedef typename nested_eval<Lhs,Dynamic>::type LhsNested;
typedef typename nested_eval<Rhs,Dynamic>::type RhsNested;
LhsNested lhsNested(src.lhs());
RhsNested rhsNested(src.rhs());
internal::sparse_sparse_to_dense_product_selector<typename remove_all<LhsNested>::type,
typename remove_all<RhsNested>::type, DstXprType>::run(lhsNested,rhsNested,dst);
}
};
template<typename Lhs, typename Rhs, int Options>
struct evaluator<SparseView<Product<Lhs, Rhs, Options> > >
: public evaluator<typename Product<Lhs, Rhs, DefaultProduct>::PlainObject>

View File

@ -76,6 +76,17 @@ template<typename SparseMatrixType> void sparse_product()
VERIFY_IS_APPROX(m4=(m2t.transpose()*m3t.transpose()).pruned(0), refMat4=refMat2t.transpose()*refMat3t.transpose());
VERIFY_IS_APPROX(m4=(m2*m3t.transpose()).pruned(0), refMat4=refMat2*refMat3t.transpose());
// dense ?= sparse * sparse
VERIFY_IS_APPROX(dm4 =m2*m3, refMat4 =refMat2*refMat3);
VERIFY_IS_APPROX(dm4+=m2*m3, refMat4+=refMat2*refMat3);
VERIFY_IS_APPROX(dm4 =m2t.transpose()*m3, refMat4 =refMat2t.transpose()*refMat3);
VERIFY_IS_APPROX(dm4+=m2t.transpose()*m3, refMat4+=refMat2t.transpose()*refMat3);
VERIFY_IS_APPROX(dm4 =m2t.transpose()*m3t.transpose(), refMat4 =refMat2t.transpose()*refMat3t.transpose());
VERIFY_IS_APPROX(dm4+=m2t.transpose()*m3t.transpose(), refMat4+=refMat2t.transpose()*refMat3t.transpose());
VERIFY_IS_APPROX(dm4 =m2*m3t.transpose(), refMat4 =refMat2*refMat3t.transpose());
VERIFY_IS_APPROX(dm4+=m2*m3t.transpose(), refMat4+=refMat2*refMat3t.transpose());
VERIFY_IS_APPROX(dm4 = m2*m3*s1, refMat4 = refMat2*refMat3*s1);
// test aliasing
m4 = m2; refMat4 = refMat2;
VERIFY_IS_APPROX(m4=m4*m3, refMat4=refMat4*refMat3);