mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-03-13 18:37:27 +08:00
Add support to directly evaluate the product of two sparse matrices within a dense matrix.
This commit is contained in:
parent
a5324a131f
commit
e6f8c5c325
@ -1,7 +1,7 @@
|
||||
// This file is part of Eigen, a lightweight C++ template library
|
||||
// for linear algebra.
|
||||
//
|
||||
// Copyright (C) 2008-2014 Gael Guennebaud <gael.guennebaud@inria.fr>
|
||||
// Copyright (C) 2008-2015 Gael Guennebaud <gael.guennebaud@inria.fr>
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla
|
||||
// Public License v. 2.0. If a copy of the MPL was not distributed
|
||||
@ -255,6 +255,89 @@ struct conservative_sparse_sparse_product_selector<Lhs,Rhs,ResultType,RowMajor,R
|
||||
}
|
||||
};
|
||||
|
||||
} // end namespace internal
|
||||
|
||||
|
||||
namespace internal {
|
||||
|
||||
template<typename Lhs, typename Rhs, typename ResultType>
|
||||
static void sparse_sparse_to_dense_product_impl(const Lhs& lhs, const Rhs& rhs, ResultType& res)
|
||||
{
|
||||
typedef typename remove_all<Lhs>::type::Scalar Scalar;
|
||||
Index cols = rhs.outerSize();
|
||||
eigen_assert(lhs.outerSize() == rhs.innerSize());
|
||||
|
||||
evaluator<Lhs> lhsEval(lhs);
|
||||
evaluator<Rhs> rhsEval(rhs);
|
||||
|
||||
for (Index j=0; j<cols; ++j)
|
||||
{
|
||||
for (typename evaluator<Rhs>::InnerIterator rhsIt(rhsEval, j); rhsIt; ++rhsIt)
|
||||
{
|
||||
Scalar y = rhsIt.value();
|
||||
Index k = rhsIt.index();
|
||||
for (typename evaluator<Lhs>::InnerIterator lhsIt(lhsEval, k); lhsIt; ++lhsIt)
|
||||
{
|
||||
Index i = lhsIt.index();
|
||||
Scalar x = lhsIt.value();
|
||||
res.coeffRef(i,j) += x * y;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
} // end namespace internal
|
||||
|
||||
namespace internal {
|
||||
|
||||
template<typename Lhs, typename Rhs, typename ResultType,
|
||||
int LhsStorageOrder = (traits<Lhs>::Flags&RowMajorBit) ? RowMajor : ColMajor,
|
||||
int RhsStorageOrder = (traits<Rhs>::Flags&RowMajorBit) ? RowMajor : ColMajor>
|
||||
struct sparse_sparse_to_dense_product_selector;
|
||||
|
||||
template<typename Lhs, typename Rhs, typename ResultType>
|
||||
struct sparse_sparse_to_dense_product_selector<Lhs,Rhs,ResultType,ColMajor,ColMajor>
|
||||
{
|
||||
static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
|
||||
{
|
||||
internal::sparse_sparse_to_dense_product_impl<Lhs,Rhs,ResultType>(lhs, rhs, res);
|
||||
}
|
||||
};
|
||||
|
||||
template<typename Lhs, typename Rhs, typename ResultType>
|
||||
struct sparse_sparse_to_dense_product_selector<Lhs,Rhs,ResultType,RowMajor,ColMajor>
|
||||
{
|
||||
static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
|
||||
{
|
||||
typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename ResultType::StorageIndex> ColMajorMatrix;
|
||||
ColMajorMatrix lhsCol(lhs);
|
||||
internal::sparse_sparse_to_dense_product_impl<ColMajorMatrix,Rhs,ResultType>(lhsCol, rhs, res);
|
||||
}
|
||||
};
|
||||
|
||||
template<typename Lhs, typename Rhs, typename ResultType>
|
||||
struct sparse_sparse_to_dense_product_selector<Lhs,Rhs,ResultType,ColMajor,RowMajor>
|
||||
{
|
||||
static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
|
||||
{
|
||||
typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename ResultType::StorageIndex> ColMajorMatrix;
|
||||
ColMajorMatrix rhsCol(rhs);
|
||||
internal::sparse_sparse_to_dense_product_impl<Lhs,ColMajorMatrix,ResultType>(lhs, rhsCol, res);
|
||||
}
|
||||
};
|
||||
|
||||
template<typename Lhs, typename Rhs, typename ResultType>
|
||||
struct sparse_sparse_to_dense_product_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor>
|
||||
{
|
||||
static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
|
||||
{
|
||||
Transpose<ResultType> trRes(res);
|
||||
internal::sparse_sparse_to_dense_product_impl<Rhs,Lhs,Transpose<ResultType> >(rhs, lhs, trRes);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
} // end namespace internal
|
||||
|
||||
} // end namespace Eigen
|
||||
|
@ -133,8 +133,8 @@ struct Assignment<DstXprType, SrcXprType, Functor, Sparse2Sparse, Scalar>
|
||||
};
|
||||
|
||||
// Sparse to Dense assignment
|
||||
template< typename DstXprType, typename SrcXprType, typename Functor, typename Scalar>
|
||||
struct Assignment<DstXprType, SrcXprType, Functor, Sparse2Dense, Scalar>
|
||||
template< typename DstXprType, typename SrcXprType, typename Functor>
|
||||
struct Assignment<DstXprType, SrcXprType, Functor, Sparse2Dense>
|
||||
{
|
||||
static void run(DstXprType &dst, const SrcXprType &src, const Functor &func)
|
||||
{
|
||||
@ -149,8 +149,8 @@ struct Assignment<DstXprType, SrcXprType, Functor, Sparse2Dense, Scalar>
|
||||
}
|
||||
};
|
||||
|
||||
template< typename DstXprType, typename SrcXprType, typename Scalar>
|
||||
struct Assignment<DstXprType, SrcXprType, internal::assign_op<typename DstXprType::Scalar>, Sparse2Dense, Scalar>
|
||||
template< typename DstXprType, typename SrcXprType>
|
||||
struct Assignment<DstXprType, SrcXprType, internal::assign_op<typename DstXprType::Scalar>, Sparse2Dense>
|
||||
{
|
||||
static void run(DstXprType &dst, const SrcXprType &src, const internal::assign_op<typename DstXprType::Scalar> &)
|
||||
{
|
||||
|
@ -281,7 +281,7 @@ template<typename Derived> class SparseMatrixBase : public EigenBase<Derived>
|
||||
|
||||
// sparse * sparse
|
||||
template<typename OtherDerived>
|
||||
const Product<Derived,OtherDerived>
|
||||
const Product<Derived,OtherDerived,AliasFreeProduct>
|
||||
operator*(const SparseMatrixBase<OtherDerived> &other) const;
|
||||
|
||||
// sparse * dense
|
||||
|
@ -25,10 +25,10 @@ namespace Eigen {
|
||||
* */
|
||||
template<typename Derived>
|
||||
template<typename OtherDerived>
|
||||
inline const Product<Derived,OtherDerived>
|
||||
inline const Product<Derived,OtherDerived,AliasFreeProduct>
|
||||
SparseMatrixBase<Derived>::operator*(const SparseMatrixBase<OtherDerived> &other) const
|
||||
{
|
||||
return Product<Derived,OtherDerived>(derived(), other.derived());
|
||||
return Product<Derived,OtherDerived,AliasFreeProduct>(derived(), other.derived());
|
||||
}
|
||||
|
||||
namespace internal {
|
||||
@ -61,6 +61,36 @@ struct generic_product_impl<Lhs, Rhs, SparseTriangularShape, SparseShape, Produc
|
||||
: public generic_product_impl<Lhs, Rhs, SparseShape, SparseShape, ProductType>
|
||||
{};
|
||||
|
||||
// Dense = sparse * sparse
|
||||
template< typename DstXprType, typename Lhs, typename Rhs, int Options/*, typename Scalar*/>
|
||||
struct Assignment<DstXprType, Product<Lhs,Rhs,Options>, internal::assign_op<typename DstXprType::Scalar>, Sparse2Dense/*,
|
||||
typename enable_if<(Options==DefaultProduct || Options==AliasFreeProduct),Scalar>::type*/>
|
||||
{
|
||||
typedef Product<Lhs,Rhs,Options> SrcXprType;
|
||||
static void run(DstXprType &dst, const SrcXprType &src, const internal::assign_op<typename DstXprType::Scalar> &)
|
||||
{
|
||||
dst.setZero();
|
||||
dst += src;
|
||||
}
|
||||
};
|
||||
|
||||
// Dense += sparse * sparse
|
||||
template< typename DstXprType, typename Lhs, typename Rhs, int Options>
|
||||
struct Assignment<DstXprType, Product<Lhs,Rhs,Options>, internal::add_assign_op<typename DstXprType::Scalar>, Sparse2Dense/*,
|
||||
typename enable_if<(Options==DefaultProduct || Options==AliasFreeProduct),Scalar>::type*/>
|
||||
{
|
||||
typedef Product<Lhs,Rhs,Options> SrcXprType;
|
||||
static void run(DstXprType &dst, const SrcXprType &src, const internal::add_assign_op<typename DstXprType::Scalar> &)
|
||||
{
|
||||
typedef typename nested_eval<Lhs,Dynamic>::type LhsNested;
|
||||
typedef typename nested_eval<Rhs,Dynamic>::type RhsNested;
|
||||
LhsNested lhsNested(src.lhs());
|
||||
RhsNested rhsNested(src.rhs());
|
||||
internal::sparse_sparse_to_dense_product_selector<typename remove_all<LhsNested>::type,
|
||||
typename remove_all<RhsNested>::type, DstXprType>::run(lhsNested,rhsNested,dst);
|
||||
}
|
||||
};
|
||||
|
||||
template<typename Lhs, typename Rhs, int Options>
|
||||
struct evaluator<SparseView<Product<Lhs, Rhs, Options> > >
|
||||
: public evaluator<typename Product<Lhs, Rhs, DefaultProduct>::PlainObject>
|
||||
|
@ -76,6 +76,17 @@ template<typename SparseMatrixType> void sparse_product()
|
||||
VERIFY_IS_APPROX(m4=(m2t.transpose()*m3t.transpose()).pruned(0), refMat4=refMat2t.transpose()*refMat3t.transpose());
|
||||
VERIFY_IS_APPROX(m4=(m2*m3t.transpose()).pruned(0), refMat4=refMat2*refMat3t.transpose());
|
||||
|
||||
// dense ?= sparse * sparse
|
||||
VERIFY_IS_APPROX(dm4 =m2*m3, refMat4 =refMat2*refMat3);
|
||||
VERIFY_IS_APPROX(dm4+=m2*m3, refMat4+=refMat2*refMat3);
|
||||
VERIFY_IS_APPROX(dm4 =m2t.transpose()*m3, refMat4 =refMat2t.transpose()*refMat3);
|
||||
VERIFY_IS_APPROX(dm4+=m2t.transpose()*m3, refMat4+=refMat2t.transpose()*refMat3);
|
||||
VERIFY_IS_APPROX(dm4 =m2t.transpose()*m3t.transpose(), refMat4 =refMat2t.transpose()*refMat3t.transpose());
|
||||
VERIFY_IS_APPROX(dm4+=m2t.transpose()*m3t.transpose(), refMat4+=refMat2t.transpose()*refMat3t.transpose());
|
||||
VERIFY_IS_APPROX(dm4 =m2*m3t.transpose(), refMat4 =refMat2*refMat3t.transpose());
|
||||
VERIFY_IS_APPROX(dm4+=m2*m3t.transpose(), refMat4+=refMat2*refMat3t.transpose());
|
||||
VERIFY_IS_APPROX(dm4 = m2*m3*s1, refMat4 = refMat2*refMat3*s1);
|
||||
|
||||
// test aliasing
|
||||
m4 = m2; refMat4 = refMat2;
|
||||
VERIFY_IS_APPROX(m4=m4*m3, refMat4=refMat4*refMat3);
|
||||
|
Loading…
x
Reference in New Issue
Block a user