mirror of
https://gitlab.com/libeigen/eigen.git
synced 2024-12-21 07:19:46 +08:00
Implement evaluators for sparse * sparse products
This commit is contained in:
parent
0ad7a644df
commit
441f97b2df
@ -49,12 +49,12 @@ struct Sparse {};
|
|||||||
#include "src/SparseCore/SparseRedux.h"
|
#include "src/SparseCore/SparseRedux.h"
|
||||||
#include "src/SparseCore/SparseView.h"
|
#include "src/SparseCore/SparseView.h"
|
||||||
#include "src/SparseCore/SparseDiagonalProduct.h"
|
#include "src/SparseCore/SparseDiagonalProduct.h"
|
||||||
|
#include "src/SparseCore/ConservativeSparseSparseProduct.h"
|
||||||
|
#include "src/SparseCore/SparseProduct.h"
|
||||||
#ifndef EIGEN_TEST_EVALUATORS
|
#ifndef EIGEN_TEST_EVALUATORS
|
||||||
#include "src/SparseCore/SparsePermutation.h"
|
#include "src/SparseCore/SparsePermutation.h"
|
||||||
#include "src/SparseCore/SparseFuzzy.h"
|
#include "src/SparseCore/SparseFuzzy.h"
|
||||||
#include "src/SparseCore/ConservativeSparseSparseProduct.h"
|
|
||||||
#include "src/SparseCore/SparseSparseProductWithPruning.h"
|
#include "src/SparseCore/SparseSparseProductWithPruning.h"
|
||||||
#include "src/SparseCore/SparseProduct.h"
|
|
||||||
#include "src/SparseCore/SparseDenseProduct.h"
|
#include "src/SparseCore/SparseDenseProduct.h"
|
||||||
#include "src/SparseCore/SparseTriangularView.h"
|
#include "src/SparseCore/SparseTriangularView.h"
|
||||||
#include "src/SparseCore/SparseSelfAdjointView.h"
|
#include "src/SparseCore/SparseSelfAdjointView.h"
|
||||||
|
@ -37,6 +37,11 @@ static void conservative_sparse_sparse_product_impl(const Lhs& lhs, const Rhs& r
|
|||||||
// Therefore, we have nnz(lhs*rhs) = nnz(lhs) + nnz(rhs)
|
// Therefore, we have nnz(lhs*rhs) = nnz(lhs) + nnz(rhs)
|
||||||
Index estimated_nnz_prod = lhs.nonZeros() + rhs.nonZeros();
|
Index estimated_nnz_prod = lhs.nonZeros() + rhs.nonZeros();
|
||||||
|
|
||||||
|
#ifdef EIGEN_TEST_EVALUATORS
|
||||||
|
typename evaluator<Lhs>::type lhsEval(lhs);
|
||||||
|
typename evaluator<Rhs>::type rhsEval(rhs);
|
||||||
|
#endif
|
||||||
|
|
||||||
res.setZero();
|
res.setZero();
|
||||||
res.reserve(Index(estimated_nnz_prod));
|
res.reserve(Index(estimated_nnz_prod));
|
||||||
// we compute each column of the result, one after the other
|
// we compute each column of the result, one after the other
|
||||||
@ -45,11 +50,19 @@ static void conservative_sparse_sparse_product_impl(const Lhs& lhs, const Rhs& r
|
|||||||
|
|
||||||
res.startVec(j);
|
res.startVec(j);
|
||||||
Index nnz = 0;
|
Index nnz = 0;
|
||||||
|
#ifndef EIGEN_TEST_EVALUATORS
|
||||||
for (typename Rhs::InnerIterator rhsIt(rhs, j); rhsIt; ++rhsIt)
|
for (typename Rhs::InnerIterator rhsIt(rhs, j); rhsIt; ++rhsIt)
|
||||||
|
#else
|
||||||
|
for (typename evaluator<Rhs>::InnerIterator rhsIt(rhsEval, j); rhsIt; ++rhsIt)
|
||||||
|
#endif
|
||||||
{
|
{
|
||||||
Scalar y = rhsIt.value();
|
Scalar y = rhsIt.value();
|
||||||
Index k = rhsIt.index();
|
Index k = rhsIt.index();
|
||||||
|
#ifndef EIGEN_TEST_EVALUATORS
|
||||||
for (typename Lhs::InnerIterator lhsIt(lhs, k); lhsIt; ++lhsIt)
|
for (typename Lhs::InnerIterator lhsIt(lhs, k); lhsIt; ++lhsIt)
|
||||||
|
#else
|
||||||
|
for (typename evaluator<Lhs>::InnerIterator lhsIt(lhsEval, k); lhsIt; ++lhsIt)
|
||||||
|
#endif
|
||||||
{
|
{
|
||||||
Index i = lhsIt.index();
|
Index i = lhsIt.index();
|
||||||
Scalar x = lhsIt.value();
|
Scalar x = lhsIt.value();
|
||||||
|
@ -190,8 +190,10 @@ template<typename Derived> class SparseMatrixBase : public EigenBase<Derived>
|
|||||||
|
|
||||||
public:
|
public:
|
||||||
|
|
||||||
|
#ifndef EIGEN_TEST_EVALUATORS
|
||||||
template<typename Lhs, typename Rhs>
|
template<typename Lhs, typename Rhs>
|
||||||
inline Derived& operator=(const SparseSparseProduct<Lhs,Rhs>& product);
|
inline Derived& operator=(const SparseSparseProduct<Lhs,Rhs>& product);
|
||||||
|
#endif
|
||||||
|
|
||||||
friend std::ostream & operator << (std::ostream & s, const SparseMatrixBase& m)
|
friend std::ostream & operator << (std::ostream & s, const SparseMatrixBase& m)
|
||||||
{
|
{
|
||||||
@ -264,12 +266,12 @@ template<typename Derived> class SparseMatrixBase : public EigenBase<Derived>
|
|||||||
EIGEN_STRONG_INLINE const EIGEN_SPARSE_CWISE_PRODUCT_RETURN_TYPE
|
EIGEN_STRONG_INLINE const EIGEN_SPARSE_CWISE_PRODUCT_RETURN_TYPE
|
||||||
cwiseProduct(const MatrixBase<OtherDerived> &other) const;
|
cwiseProduct(const MatrixBase<OtherDerived> &other) const;
|
||||||
|
|
||||||
|
#ifndef EIGEN_TEST_EVALUATORS
|
||||||
// sparse * sparse
|
// sparse * sparse
|
||||||
template<typename OtherDerived>
|
template<typename OtherDerived>
|
||||||
const typename SparseSparseProductReturnType<Derived,OtherDerived>::Type
|
const typename SparseSparseProductReturnType<Derived,OtherDerived>::Type
|
||||||
operator*(const SparseMatrixBase<OtherDerived> &other) const;
|
operator*(const SparseMatrixBase<OtherDerived> &other) const;
|
||||||
|
|
||||||
#ifndef EIGEN_TEST_EVALUATORS
|
|
||||||
// sparse * diagonal
|
// sparse * diagonal
|
||||||
template<typename OtherDerived>
|
template<typename OtherDerived>
|
||||||
const SparseDiagonalProduct<Derived,OtherDerived>
|
const SparseDiagonalProduct<Derived,OtherDerived>
|
||||||
@ -292,6 +294,11 @@ template<typename Derived> class SparseMatrixBase : public EigenBase<Derived>
|
|||||||
const Product<OtherDerived,Derived>
|
const Product<OtherDerived,Derived>
|
||||||
operator*(const DiagonalBase<OtherDerived> &lhs, const SparseMatrixBase& rhs)
|
operator*(const DiagonalBase<OtherDerived> &lhs, const SparseMatrixBase& rhs)
|
||||||
{ return Product<OtherDerived,Derived>(lhs.derived(), rhs.derived()); }
|
{ return Product<OtherDerived,Derived>(lhs.derived(), rhs.derived()); }
|
||||||
|
|
||||||
|
// sparse * sparse
|
||||||
|
template<typename OtherDerived>
|
||||||
|
const Product<Derived,OtherDerived>
|
||||||
|
operator*(const SparseMatrixBase<OtherDerived> &other) const;
|
||||||
#endif // EIGEN_TEST_EVALUATORS
|
#endif // EIGEN_TEST_EVALUATORS
|
||||||
|
|
||||||
/** dense * sparse (return a dense object unless it is an outer product) */
|
/** dense * sparse (return a dense object unless it is an outer product) */
|
||||||
|
@ -12,6 +12,8 @@
|
|||||||
|
|
||||||
namespace Eigen {
|
namespace Eigen {
|
||||||
|
|
||||||
|
#ifndef EIGEN_TEST_EVALUATORS
|
||||||
|
|
||||||
template<typename Lhs, typename Rhs>
|
template<typename Lhs, typename Rhs>
|
||||||
struct SparseSparseProductReturnType
|
struct SparseSparseProductReturnType
|
||||||
{
|
{
|
||||||
@ -183,6 +185,68 @@ SparseMatrixBase<Derived>::operator*(const SparseMatrixBase<OtherDerived> &other
|
|||||||
return typename SparseSparseProductReturnType<Derived,OtherDerived>::Type(derived(), other.derived());
|
return typename SparseSparseProductReturnType<Derived,OtherDerived>::Type(derived(), other.derived());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#else // EIGEN_TEST_EVALUATORS
|
||||||
|
|
||||||
|
|
||||||
|
/** \returns an expression of the product of two sparse matrices.
|
||||||
|
* By default a conservative product preserving the symbolic non zeros is performed.
|
||||||
|
* The automatic pruning of the small values can be achieved by calling the pruned() function
|
||||||
|
* in which case a totally different product algorithm is employed:
|
||||||
|
* \code
|
||||||
|
* C = (A*B).pruned(); // supress numerical zeros (exact)
|
||||||
|
* C = (A*B).pruned(ref);
|
||||||
|
* C = (A*B).pruned(ref,epsilon);
|
||||||
|
* \endcode
|
||||||
|
* where \c ref is a meaningful non zero reference value.
|
||||||
|
* */
|
||||||
|
template<typename Derived>
|
||||||
|
template<typename OtherDerived>
|
||||||
|
inline const Product<Derived,OtherDerived>
|
||||||
|
SparseMatrixBase<Derived>::operator*(const SparseMatrixBase<OtherDerived> &other) const
|
||||||
|
{
|
||||||
|
return Product<Derived,OtherDerived>(derived(), other.derived());
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace internal {
|
||||||
|
|
||||||
|
template<typename Lhs, typename Rhs, int ProductType>
|
||||||
|
struct generic_product_impl<Lhs, Rhs, SparseShape, SparseShape, ProductType>
|
||||||
|
{
|
||||||
|
template<typename Dest>
|
||||||
|
static void evalTo(Dest& dst, const Lhs& lhs, const Rhs& rhs)
|
||||||
|
{
|
||||||
|
typedef typename nested_eval<Lhs,Dynamic>::type LhsNested;
|
||||||
|
typedef typename nested_eval<Rhs,Dynamic>::type RhsNested;
|
||||||
|
LhsNested lhsNested(lhs);
|
||||||
|
RhsNested rhsNested(rhs);
|
||||||
|
internal::conservative_sparse_sparse_product_selector<typename remove_all<LhsNested>::type,
|
||||||
|
typename remove_all<RhsNested>::type, Dest>::run(lhsNested,rhsNested,dst);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template<typename Lhs, typename Rhs, int ProductTag>
|
||||||
|
struct product_evaluator<Product<Lhs, Rhs, DefaultProduct>, ProductTag, SparseShape, SparseShape, typename Lhs::Scalar, typename Rhs::Scalar>
|
||||||
|
: public evaluator<typename Product<Lhs, Rhs, DefaultProduct>::PlainObject>::type
|
||||||
|
{
|
||||||
|
typedef Product<Lhs, Rhs, DefaultProduct> XprType;
|
||||||
|
typedef typename XprType::PlainObject PlainObject;
|
||||||
|
typedef typename evaluator<PlainObject>::type Base;
|
||||||
|
|
||||||
|
product_evaluator(const XprType& xpr)
|
||||||
|
: m_result(xpr.rows(), xpr.cols())
|
||||||
|
{
|
||||||
|
::new (static_cast<Base*>(this)) Base(m_result);
|
||||||
|
generic_product_impl<Lhs, Rhs, SparseShape, SparseShape, ProductTag>::evalTo(m_result, xpr.lhs(), xpr.rhs());
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
PlainObject m_result;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // end namespace internal
|
||||||
|
|
||||||
|
#endif // EIGEN_TEST_EVALUATORS
|
||||||
|
|
||||||
} // end namespace Eigen
|
} // end namespace Eigen
|
||||||
|
|
||||||
#endif // EIGEN_SPARSEPRODUCT_H
|
#endif // EIGEN_SPARSEPRODUCT_H
|
||||||
|
Loading…
Reference in New Issue
Block a user