Implement evaluators for sparse times diagonal products.

This commit is contained in:
Gael Guennebaud 2014-06-27 15:54:44 +02:00
parent ae039dde13
commit 73e686c6a4
3 changed files with 139 additions and 3 deletions

View File

@ -48,6 +48,7 @@ struct Sparse {};
#include "src/SparseCore/SparseDot.h"
#include "src/SparseCore/SparseRedux.h"
#include "src/SparseCore/SparseView.h"
#include "src/SparseCore/SparseDiagonalProduct.h"
#ifndef EIGEN_TEST_EVALUATORS
#include "src/SparseCore/SparsePermutation.h"
#include "src/SparseCore/SparseFuzzy.h"
@ -55,7 +56,6 @@ struct Sparse {};
#include "src/SparseCore/SparseSparseProductWithPruning.h"
#include "src/SparseCore/SparseProduct.h"
#include "src/SparseCore/SparseDenseProduct.h"
#include "src/SparseCore/SparseDiagonalProduct.h"
#include "src/SparseCore/SparseTriangularView.h"
#include "src/SparseCore/SparseSelfAdjointView.h"
#include "src/SparseCore/TriangularSolver.h"

View File

@ -24,8 +24,10 @@ namespace Eigen {
// for that particular case
// The two other cases are symmetric.
#ifndef EIGEN_TEST_EVALUATORS
namespace internal {
template<typename Lhs, typename Rhs>
struct traits<SparseDiagonalProduct<Lhs, Rhs> >
{
@ -100,9 +102,14 @@ class SparseDiagonalProduct
LhsNested m_lhs;
RhsNested m_rhs;
};
#endif
namespace internal {
#ifndef EIGEN_TEST_EVALUATORS
template<typename Lhs, typename Rhs, typename SparseDiagonalProductType>
class sparse_diagonal_product_inner_iterator_selector
<Lhs,Rhs,SparseDiagonalProductType,SDP_IsDiagonal,SDP_IsSparseRowMajor>
@ -179,10 +186,124 @@ class sparse_diagonal_product_inner_iterator_selector
inline Index row() const { return m_outer; }
};
#else // EIGEN_TEST_EVALUATORS
enum {
SDP_AsScalarProduct,
SDP_AsCwiseProduct
};
template<typename SparseXprType, typename DiagonalCoeffType, int SDP_Tag>
struct sparse_diagonal_product_evaluator;
template<typename Lhs, typename Rhs, int Options, int ProductTag>
struct product_evaluator<Product<Lhs, Rhs, Options>, ProductTag, DiagonalShape, SparseShape, typename Lhs::Scalar, typename Rhs::Scalar>
: public sparse_diagonal_product_evaluator<Rhs, typename Lhs::DiagonalVectorType, Rhs::Flags&RowMajorBit?SDP_AsScalarProduct:SDP_AsCwiseProduct>
{
typedef Product<Lhs, Rhs, Options> XprType;
typedef evaluator<XprType> type;
typedef evaluator<XprType> nestedType;
enum { CoeffReadCost = Dynamic, Flags = Rhs::Flags&RowMajorBit }; // FIXME CoeffReadCost & Flags
typedef sparse_diagonal_product_evaluator<Rhs, typename Lhs::DiagonalVectorType, Rhs::Flags&RowMajorBit?SDP_AsScalarProduct:SDP_AsCwiseProduct> Base;
product_evaluator(const XprType& xpr) : Base(xpr.rhs(), xpr.lhs().diagonal()) {}
};
template<typename Lhs, typename Rhs, int Options, int ProductTag>
struct product_evaluator<Product<Lhs, Rhs, Options>, ProductTag, SparseShape, DiagonalShape, typename Lhs::Scalar, typename Rhs::Scalar>
: public sparse_diagonal_product_evaluator<Lhs, Transpose<const typename Rhs::DiagonalVectorType>, Lhs::Flags&RowMajorBit?SDP_AsCwiseProduct:SDP_AsScalarProduct>
{
typedef Product<Lhs, Rhs, Options> XprType;
typedef evaluator<XprType> type;
typedef evaluator<XprType> nestedType;
enum { CoeffReadCost = Dynamic, Flags = Lhs::Flags&RowMajorBit }; // FIXME CoeffReadCost & Flags
typedef sparse_diagonal_product_evaluator<Lhs, Transpose<const typename Rhs::DiagonalVectorType>, Lhs::Flags&RowMajorBit?SDP_AsCwiseProduct:SDP_AsScalarProduct> Base;
product_evaluator(const XprType& xpr) : Base(xpr.lhs(), xpr.rhs().diagonal()) {}
};
template<typename SparseXprType, typename DiagonalCoeffType>
struct sparse_diagonal_product_evaluator<SparseXprType, DiagonalCoeffType, SDP_AsScalarProduct>
{
protected:
typedef typename evaluator<SparseXprType>::InnerIterator SparseXprInnerIterator;
typedef typename SparseXprType::Scalar Scalar;
typedef typename SparseXprType::Index Index;
public:
class InnerIterator : public SparseXprInnerIterator
{
public:
InnerIterator(const sparse_diagonal_product_evaluator &xprEval, Index outer)
: SparseXprInnerIterator(xprEval.m_sparseXprImpl, outer),
m_coeff(xprEval.m_diagCoeffImpl.coeff(outer))
{}
EIGEN_STRONG_INLINE Scalar value() const { return m_coeff * SparseXprInnerIterator::value(); }
protected:
typename DiagonalCoeffType::Scalar m_coeff;
};
sparse_diagonal_product_evaluator(const SparseXprType &sparseXpr, const DiagonalCoeffType &diagCoeff)
: m_sparseXprImpl(sparseXpr), m_diagCoeffImpl(diagCoeff)
{}
protected:
typename evaluator<SparseXprType>::nestedType m_sparseXprImpl;
typename evaluator<DiagonalCoeffType>::nestedType m_diagCoeffImpl;
};
template<typename SparseXprType, typename DiagCoeffType>
struct sparse_diagonal_product_evaluator<SparseXprType, DiagCoeffType, SDP_AsCwiseProduct>
{
typedef typename SparseXprType::Scalar Scalar;
typedef typename SparseXprType::Index Index;
typedef CwiseBinaryOp<scalar_product_op<Scalar>,
const typename SparseXprType::ConstInnerVectorReturnType,
const DiagCoeffType> CwiseProductType;
typedef typename evaluator<CwiseProductType>::type CwiseProductEval;
typedef typename evaluator<CwiseProductType>::InnerIterator CwiseProductIterator;
class InnerIterator : public CwiseProductIterator
{
public:
InnerIterator(const sparse_diagonal_product_evaluator &xprEval, Index outer)
: CwiseProductIterator(CwiseProductEval(xprEval.m_sparseXprNested.innerVector(outer).cwiseProduct(xprEval.m_diagCoeffNested)),0),
m_cwiseEval(xprEval.m_sparseXprNested.innerVector(outer).cwiseProduct(xprEval.m_diagCoeffNested)),
m_outer(outer)
{
::new (static_cast<CwiseProductIterator*>(this)) CwiseProductIterator(m_cwiseEval,0);
}
inline Index outer() const { return m_outer; }
inline Index col() const { return SparseXprType::IsRowMajor ? CwiseProductIterator::index() : m_outer; }
inline Index row() const { return SparseXprType::IsRowMajor ? m_outer : CwiseProductIterator::index(); }
protected:
Index m_outer;
CwiseProductEval m_cwiseEval;
};
sparse_diagonal_product_evaluator(const SparseXprType &sparseXpr, const DiagCoeffType &diagCoeff)
: m_sparseXprNested(sparseXpr), m_diagCoeffNested(diagCoeff)
{}
protected:
typename nested_eval<SparseXprType,1>::type m_sparseXprNested;
typename nested_eval<DiagCoeffType,SparseXprType::IsRowMajor ? SparseXprType::RowsAtCompileTime
: SparseXprType::ColsAtCompileTime>::type m_diagCoeffNested;
};
#endif // EIGEN_TEST_EVALUATORS
} // end namespace internal
// SparseMatrixBase functions
#ifndef EIGEN_TEST_EVALUATORS
// SparseMatrixBase functions
template<typename Derived>
template<typename OtherDerived>
const SparseDiagonalProduct<Derived,OtherDerived>
@ -190,6 +311,7 @@ SparseMatrixBase<Derived>::operator*(const DiagonalBase<OtherDerived> &other) co
{
return SparseDiagonalProduct<Derived,OtherDerived>(this->derived(), other.derived());
}
#endif // EIGEN_TEST_EVALUATORS
} // end namespace Eigen

View File

@ -269,6 +269,7 @@ template<typename Derived> class SparseMatrixBase : public EigenBase<Derived>
const typename SparseSparseProductReturnType<Derived,OtherDerived>::Type
operator*(const SparseMatrixBase<OtherDerived> &other) const;
#ifndef EIGEN_TEST_EVALUATORS
// sparse * diagonal
template<typename OtherDerived>
const SparseDiagonalProduct<Derived,OtherDerived>
@ -279,6 +280,19 @@ template<typename Derived> class SparseMatrixBase : public EigenBase<Derived>
const SparseDiagonalProduct<OtherDerived,Derived>
operator*(const DiagonalBase<OtherDerived> &lhs, const SparseMatrixBase& rhs)
{ return SparseDiagonalProduct<OtherDerived,Derived>(lhs.derived(), rhs.derived()); }
#else // EIGEN_TEST_EVALUATORS
// sparse * diagonal
template<typename OtherDerived>
const Product<Derived,OtherDerived>
operator*(const DiagonalBase<OtherDerived> &other) const
{ return Product<Derived,OtherDerived>(derived(), other.derived()); }
// diagonal * sparse
template<typename OtherDerived> friend
const Product<OtherDerived,Derived>
operator*(const DiagonalBase<OtherDerived> &lhs, const SparseMatrixBase& rhs)
{ return Product<OtherDerived,Derived>(lhs.derived(), rhs.derived()); }
#endif // EIGEN_TEST_EVALUATORS
/** dense * sparse (return a dense object unless it is an outer product) */
template<typename OtherDerived> friend