Implement evaluators for sparse*dense products

This commit is contained in:
Gael Guennebaud 2014-07-01 17:53:18 +02:00
parent 1e6f53e070
commit 7390af91b6
3 changed files with 252 additions and 118 deletions

View File

@ -52,10 +52,10 @@ struct Sparse {};
#include "src/SparseCore/ConservativeSparseSparseProduct.h"
#include "src/SparseCore/SparseSparseProductWithPruning.h"
#include "src/SparseCore/SparseProduct.h"
#include "src/SparseCore/SparseDenseProduct.h"
#ifndef EIGEN_TEST_EVALUATORS
#include "src/SparseCore/SparsePermutation.h"
#include "src/SparseCore/SparseFuzzy.h"
#include "src/SparseCore/SparseDenseProduct.h"
#include "src/SparseCore/SparseTriangularView.h"
#include "src/SparseCore/SparseSelfAdjointView.h"
#include "src/SparseCore/TriangularSolver.h"

View File

@ -1,7 +1,7 @@
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2008-2010 Gael Guennebaud <gael.guennebaud@inria.fr>
// Copyright (C) 2008-2014 Gael Guennebaud <gael.guennebaud@inria.fr>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
@ -12,6 +12,152 @@
namespace Eigen {
namespace internal {
template<typename SparseLhsType, typename DenseRhsType, typename DenseResType,
typename AlphaType,
int LhsStorageOrder = ((SparseLhsType::Flags&RowMajorBit)==RowMajorBit) ? RowMajor : ColMajor,
bool ColPerCol = ((DenseRhsType::Flags&RowMajorBit)==0) || DenseRhsType::ColsAtCompileTime==1>
struct sparse_time_dense_product_impl;
template<typename SparseLhsType, typename DenseRhsType, typename DenseResType>
struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, typename DenseResType::Scalar, RowMajor, true>
{
typedef typename internal::remove_all<SparseLhsType>::type Lhs;
typedef typename internal::remove_all<DenseRhsType>::type Rhs;
typedef typename internal::remove_all<DenseResType>::type Res;
typedef typename Lhs::Index Index;
#ifndef EIGEN_TEST_EVALUATORS
typedef typename Lhs::InnerIterator LhsInnerIterator;
#else
typedef typename evaluator<Lhs>::InnerIterator LhsInnerIterator;
#endif
static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const typename Res::Scalar& alpha)
{
#ifndef EIGEN_TEST_EVALUATORS
const Lhs &lhsEval(lhs);
#else
typename evaluator<Lhs>::type lhsEval(lhs);
#endif
for(Index c=0; c<rhs.cols(); ++c)
{
Index n = lhs.outerSize();
for(Index j=0; j<n; ++j)
{
typename Res::Scalar tmp(0);
for(LhsInnerIterator it(lhsEval,j); it ;++it)
tmp += it.value() * rhs.coeff(it.index(),c);
res.coeffRef(j,c) = alpha * tmp;
}
}
}
};
template<typename T1, typename T2/*, int _Options, typename _StrideType*/>
struct scalar_product_traits<T1, Ref<T2/*, _Options, _StrideType*/> >
{
enum {
Defined = 1
};
typedef typename CwiseUnaryOp<scalar_multiple2_op<T1, typename T2::Scalar>, T2>::PlainObject ReturnType;
};
template<typename SparseLhsType, typename DenseRhsType, typename DenseResType, typename AlphaType>
struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, AlphaType, ColMajor, true>
{
typedef typename internal::remove_all<SparseLhsType>::type Lhs;
typedef typename internal::remove_all<DenseRhsType>::type Rhs;
typedef typename internal::remove_all<DenseResType>::type Res;
typedef typename Lhs::Index Index;
#ifndef EIGEN_TEST_EVALUATORS
typedef typename Lhs::InnerIterator LhsInnerIterator;
#else
typedef typename evaluator<Lhs>::InnerIterator LhsInnerIterator;
#endif
static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const AlphaType& alpha)
{
#ifndef EIGEN_TEST_EVALUATORS
const Lhs &lhsEval(lhs);
#else
typename evaluator<Lhs>::type lhsEval(lhs);
#endif
for(Index c=0; c<rhs.cols(); ++c)
{
for(Index j=0; j<lhs.outerSize(); ++j)
{
// typename Res::Scalar rhs_j = alpha * rhs.coeff(j,c);
typename internal::scalar_product_traits<AlphaType, typename Rhs::Scalar>::ReturnType rhs_j(alpha * rhs.coeff(j,c));
for(LhsInnerIterator it(lhsEval,j); it ;++it)
res.coeffRef(it.index(),c) += it.value() * rhs_j;
}
}
}
};
template<typename SparseLhsType, typename DenseRhsType, typename DenseResType>
struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, typename DenseResType::Scalar, RowMajor, false>
{
typedef typename internal::remove_all<SparseLhsType>::type Lhs;
typedef typename internal::remove_all<DenseRhsType>::type Rhs;
typedef typename internal::remove_all<DenseResType>::type Res;
typedef typename Lhs::Index Index;
#ifndef EIGEN_TEST_EVALUATORS
typedef typename Lhs::InnerIterator LhsInnerIterator;
#else
typedef typename evaluator<Lhs>::InnerIterator LhsInnerIterator;
#endif
static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const typename Res::Scalar& alpha)
{
#ifndef EIGEN_TEST_EVALUATORS
const Lhs &lhsEval(lhs);
#else
typename evaluator<Lhs>::type lhsEval(lhs);
#endif
for(Index j=0; j<lhs.outerSize(); ++j)
{
typename Res::RowXpr res_j(res.row(j));
for(LhsInnerIterator it(lhsEval,j); it ;++it)
res_j += (alpha*it.value()) * rhs.row(it.index());
}
}
};
template<typename SparseLhsType, typename DenseRhsType, typename DenseResType>
struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, typename DenseResType::Scalar, ColMajor, false>
{
typedef typename internal::remove_all<SparseLhsType>::type Lhs;
typedef typename internal::remove_all<DenseRhsType>::type Rhs;
typedef typename internal::remove_all<DenseResType>::type Res;
typedef typename Lhs::Index Index;
#ifndef EIGEN_TEST_EVALUATORS
typedef typename Lhs::InnerIterator LhsInnerIterator;
#else
typedef typename evaluator<Lhs>::InnerIterator LhsInnerIterator;
#endif
static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const typename Res::Scalar& alpha)
{
#ifndef EIGEN_TEST_EVALUATORS
const Lhs &lhsEval(lhs);
#else
typename evaluator<Lhs>::type lhsEval(lhs);
#endif
for(Index j=0; j<lhs.outerSize(); ++j)
{
typename Rhs::ConstRowXpr rhs_j(rhs.row(j));
for(LhsInnerIterator it(lhsEval,j); it ;++it)
res.row(it.index()) += (alpha*it.value()) * rhs_j;
}
}
};
template<typename SparseLhsType, typename DenseRhsType, typename DenseResType,typename AlphaType>
inline void sparse_time_dense_product(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const AlphaType& alpha)
{
sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, AlphaType>::run(lhs, rhs, res, alpha);
}
} // end namespace internal
#ifndef EIGEN_TEST_EVALUATORS
template<typename Lhs, typename Rhs, int InnerSize> struct SparseDenseProductReturnType
{
typedef SparseTimeDenseProduct<Lhs,Rhs> Type;
@ -138,111 +284,6 @@ struct traits<SparseTimeDenseProduct<Lhs,Rhs> >
typedef MatrixXpr XprKind;
};
template<typename SparseLhsType, typename DenseRhsType, typename DenseResType,
typename AlphaType,
int LhsStorageOrder = ((SparseLhsType::Flags&RowMajorBit)==RowMajorBit) ? RowMajor : ColMajor,
bool ColPerCol = ((DenseRhsType::Flags&RowMajorBit)==0) || DenseRhsType::ColsAtCompileTime==1>
struct sparse_time_dense_product_impl;
template<typename SparseLhsType, typename DenseRhsType, typename DenseResType>
struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, typename DenseResType::Scalar, RowMajor, true>
{
typedef typename internal::remove_all<SparseLhsType>::type Lhs;
typedef typename internal::remove_all<DenseRhsType>::type Rhs;
typedef typename internal::remove_all<DenseResType>::type Res;
typedef typename Lhs::Index Index;
typedef typename Lhs::InnerIterator LhsInnerIterator;
static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const typename Res::Scalar& alpha)
{
for(Index c=0; c<rhs.cols(); ++c)
{
Index n = lhs.outerSize();
for(Index j=0; j<n; ++j)
{
typename Res::Scalar tmp(0);
for(LhsInnerIterator it(lhs,j); it ;++it)
tmp += it.value() * rhs.coeff(it.index(),c);
res.coeffRef(j,c) = alpha * tmp;
}
}
}
};
template<typename T1, typename T2/*, int _Options, typename _StrideType*/>
struct scalar_product_traits<T1, Ref<T2/*, _Options, _StrideType*/> >
{
enum {
Defined = 1
};
typedef typename CwiseUnaryOp<scalar_multiple2_op<T1, typename T2::Scalar>, T2>::PlainObject ReturnType;
};
template<typename SparseLhsType, typename DenseRhsType, typename DenseResType, typename AlphaType>
struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, AlphaType, ColMajor, true>
{
typedef typename internal::remove_all<SparseLhsType>::type Lhs;
typedef typename internal::remove_all<DenseRhsType>::type Rhs;
typedef typename internal::remove_all<DenseResType>::type Res;
typedef typename Lhs::InnerIterator LhsInnerIterator;
typedef typename Lhs::Index Index;
static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const AlphaType& alpha)
{
for(Index c=0; c<rhs.cols(); ++c)
{
for(Index j=0; j<lhs.outerSize(); ++j)
{
// typename Res::Scalar rhs_j = alpha * rhs.coeff(j,c);
typename internal::scalar_product_traits<AlphaType, typename Rhs::Scalar>::ReturnType rhs_j(alpha * rhs.coeff(j,c));
for(LhsInnerIterator it(lhs,j); it ;++it)
res.coeffRef(it.index(),c) += it.value() * rhs_j;
}
}
}
};
template<typename SparseLhsType, typename DenseRhsType, typename DenseResType>
struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, typename DenseResType::Scalar, RowMajor, false>
{
typedef typename internal::remove_all<SparseLhsType>::type Lhs;
typedef typename internal::remove_all<DenseRhsType>::type Rhs;
typedef typename internal::remove_all<DenseResType>::type Res;
typedef typename Lhs::InnerIterator LhsInnerIterator;
typedef typename Lhs::Index Index;
static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const typename Res::Scalar& alpha)
{
for(Index j=0; j<lhs.outerSize(); ++j)
{
typename Res::RowXpr res_j(res.row(j));
for(LhsInnerIterator it(lhs,j); it ;++it)
res_j += (alpha*it.value()) * rhs.row(it.index());
}
}
};
template<typename SparseLhsType, typename DenseRhsType, typename DenseResType>
struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, typename DenseResType::Scalar, ColMajor, false>
{
typedef typename internal::remove_all<SparseLhsType>::type Lhs;
typedef typename internal::remove_all<DenseRhsType>::type Rhs;
typedef typename internal::remove_all<DenseResType>::type Res;
typedef typename Lhs::InnerIterator LhsInnerIterator;
typedef typename Lhs::Index Index;
static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const typename Res::Scalar& alpha)
{
for(Index j=0; j<lhs.outerSize(); ++j)
{
typename Rhs::ConstRowXpr rhs_j(rhs.row(j));
for(LhsInnerIterator it(lhs,j); it ;++it)
res.row(it.index()) += (alpha*it.value()) * rhs_j;
}
}
};
template<typename SparseLhsType, typename DenseRhsType, typename DenseResType,typename AlphaType>
inline void sparse_time_dense_product(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const AlphaType& alpha)
{
sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, AlphaType>::run(lhs, rhs, res, alpha);
}
} // end namespace internal
template<typename Lhs, typename Rhs>
@ -305,6 +346,87 @@ SparseMatrixBase<Derived>::operator*(const MatrixBase<OtherDerived> &other) cons
{
return typename SparseDenseProductReturnType<Derived,OtherDerived>::Type(derived(), other.derived());
}
#endif // EIGEN_TEST_EVALUATORS
#ifdef EIGEN_TEST_EVALUATORS
namespace internal {
template<typename Lhs, typename Rhs, int ProductType>
struct generic_product_impl<Lhs, Rhs, SparseShape, DenseShape, ProductType>
{
template<typename Dest>
static void evalTo(Dest& dst, const Lhs& lhs, const Rhs& rhs)
{
typedef typename nested_eval<Lhs,Dynamic>::type LhsNested;
typedef typename nested_eval<Rhs,Dynamic>::type RhsNested;
LhsNested lhsNested(lhs);
RhsNested rhsNested(rhs);
dst.setZero();
internal::sparse_time_dense_product(lhsNested, rhsNested, dst, typename Dest::Scalar(1));
}
};
template<typename Lhs, typename Rhs, int ProductType>
struct generic_product_impl<Lhs, Rhs, DenseShape, SparseShape, ProductType>
{
template<typename Dest>
static void evalTo(Dest& dst, const Lhs& lhs, const Rhs& rhs)
{
typedef typename nested_eval<Lhs,Dynamic>::type LhsNested;
typedef typename nested_eval<Rhs,Dynamic>::type RhsNested;
LhsNested lhsNested(lhs);
RhsNested rhsNested(rhs);
dst.setZero();
// transpoe everything
Transpose<Dest> dstT(dst);
internal::sparse_time_dense_product(rhsNested.transpose(), lhsNested.transpose(), dstT, typename Dest::Scalar(1));
}
};
template<typename Lhs, typename Rhs, int ProductTag>
struct product_evaluator<Product<Lhs, Rhs, DefaultProduct>, ProductTag, SparseShape, DenseShape, typename Lhs::Scalar, typename Rhs::Scalar>
: public evaluator<typename Product<Lhs, Rhs, DefaultProduct>::PlainObject>::type
{
typedef Product<Lhs, Rhs, DefaultProduct> XprType;
typedef typename XprType::PlainObject PlainObject;
typedef typename evaluator<PlainObject>::type Base;
product_evaluator(const XprType& xpr)
: m_result(xpr.rows(), xpr.cols())
{
::new (static_cast<Base*>(this)) Base(m_result);
generic_product_impl<Lhs, Rhs, SparseShape, DenseShape, ProductTag>::evalTo(m_result, xpr.lhs(), xpr.rhs());
}
protected:
PlainObject m_result;
};
template<typename Lhs, typename Rhs, int ProductTag>
struct product_evaluator<Product<Lhs, Rhs, DefaultProduct>, ProductTag, DenseShape, SparseShape, typename Lhs::Scalar, typename Rhs::Scalar>
: public evaluator<typename Product<Lhs, Rhs, DefaultProduct>::PlainObject>::type
{
typedef Product<Lhs, Rhs, DefaultProduct> XprType;
typedef typename XprType::PlainObject PlainObject;
typedef typename evaluator<PlainObject>::type Base;
product_evaluator(const XprType& xpr)
: m_result(xpr.rows(), xpr.cols())
{
::new (static_cast<Base*>(this)) Base(m_result);
generic_product_impl<Lhs, Rhs, DenseShape, SparseShape, ProductTag>::evalTo(m_result, xpr.lhs(), xpr.rhs());
}
protected:
PlainObject m_result;
};
} // end namespace internal
#endif // EIGEN_TEST_EVALUATORS
} // end namespace Eigen

View File

@ -282,6 +282,17 @@ template<typename Derived> class SparseMatrixBase : public EigenBase<Derived>
const SparseDiagonalProduct<OtherDerived,Derived>
operator*(const DiagonalBase<OtherDerived> &lhs, const SparseMatrixBase& rhs)
{ return SparseDiagonalProduct<OtherDerived,Derived>(lhs.derived(), rhs.derived()); }
/** dense * sparse (return a dense object unless it is an outer product) */
template<typename OtherDerived> friend
const typename DenseSparseProductReturnType<OtherDerived,Derived>::Type
operator*(const MatrixBase<OtherDerived>& lhs, const Derived& rhs)
{ return typename DenseSparseProductReturnType<OtherDerived,Derived>::Type(lhs.derived(),rhs); }
/** sparse * dense (returns a dense object unless it is an outer product) */
template<typename OtherDerived>
const typename SparseDenseProductReturnType<Derived,OtherDerived>::Type
operator*(const MatrixBase<OtherDerived> &other) const;
#else // EIGEN_TEST_EVALUATORS
// sparse * diagonal
template<typename OtherDerived>
@ -299,18 +310,19 @@ template<typename Derived> class SparseMatrixBase : public EigenBase<Derived>
template<typename OtherDerived>
const Product<Derived,OtherDerived>
operator*(const SparseMatrixBase<OtherDerived> &other) const;
#endif // EIGEN_TEST_EVALUATORS
/** dense * sparse (return a dense object unless it is an outer product) */
template<typename OtherDerived> friend
const typename DenseSparseProductReturnType<OtherDerived,Derived>::Type
operator*(const MatrixBase<OtherDerived>& lhs, const Derived& rhs)
{ return typename DenseSparseProductReturnType<OtherDerived,Derived>::Type(lhs.derived(),rhs); }
/** sparse * dense (returns a dense object unless it is an outer product) */
// sparse * dense
template<typename OtherDerived>
const typename SparseDenseProductReturnType<Derived,OtherDerived>::Type
operator*(const MatrixBase<OtherDerived> &other) const;
const Product<Derived,OtherDerived>
operator*(const MatrixBase<OtherDerived> &other) const
{ return Product<Derived,OtherDerived>(derived(), other.derived()); }
// dense * sparse
template<typename OtherDerived> friend
const Product<OtherDerived,Derived>
operator*(const MatrixBase<OtherDerived> &lhs, const SparseMatrixBase& rhs)
{ return Product<OtherDerived,Derived>(lhs.derived(), rhs.derived()); }
#endif // EIGEN_TEST_EVALUATORS
/** \returns an expression of P H P^-1 where H is the matrix represented by \c *this */
SparseSymmetricPermutationProduct<Derived,Upper|Lower> twistedBy(const PermutationMatrix<Dynamic,Dynamic,Index>& perm) const