mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-03-07 18:27:40 +08:00
Implement evaluators for sparse times diagonal products.
This commit is contained in:
parent
ae039dde13
commit
73e686c6a4
@ -48,6 +48,7 @@ struct Sparse {};
|
||||
#include "src/SparseCore/SparseDot.h"
|
||||
#include "src/SparseCore/SparseRedux.h"
|
||||
#include "src/SparseCore/SparseView.h"
|
||||
#include "src/SparseCore/SparseDiagonalProduct.h"
|
||||
#ifndef EIGEN_TEST_EVALUATORS
|
||||
#include "src/SparseCore/SparsePermutation.h"
|
||||
#include "src/SparseCore/SparseFuzzy.h"
|
||||
@ -55,7 +56,6 @@ struct Sparse {};
|
||||
#include "src/SparseCore/SparseSparseProductWithPruning.h"
|
||||
#include "src/SparseCore/SparseProduct.h"
|
||||
#include "src/SparseCore/SparseDenseProduct.h"
|
||||
#include "src/SparseCore/SparseDiagonalProduct.h"
|
||||
#include "src/SparseCore/SparseTriangularView.h"
|
||||
#include "src/SparseCore/SparseSelfAdjointView.h"
|
||||
#include "src/SparseCore/TriangularSolver.h"
|
||||
|
@ -24,8 +24,10 @@ namespace Eigen {
|
||||
// for that particular case
|
||||
// The two other cases are symmetric.
|
||||
|
||||
#ifndef EIGEN_TEST_EVALUATORS
|
||||
|
||||
namespace internal {
|
||||
|
||||
|
||||
template<typename Lhs, typename Rhs>
|
||||
struct traits<SparseDiagonalProduct<Lhs, Rhs> >
|
||||
{
|
||||
@ -100,9 +102,14 @@ class SparseDiagonalProduct
|
||||
LhsNested m_lhs;
|
||||
RhsNested m_rhs;
|
||||
};
|
||||
#endif
|
||||
|
||||
namespace internal {
|
||||
|
||||
#ifndef EIGEN_TEST_EVALUATORS
|
||||
|
||||
|
||||
|
||||
template<typename Lhs, typename Rhs, typename SparseDiagonalProductType>
|
||||
class sparse_diagonal_product_inner_iterator_selector
|
||||
<Lhs,Rhs,SparseDiagonalProductType,SDP_IsDiagonal,SDP_IsSparseRowMajor>
|
||||
@ -179,10 +186,124 @@ class sparse_diagonal_product_inner_iterator_selector
|
||||
inline Index row() const { return m_outer; }
|
||||
};
|
||||
|
||||
#else // EIGEN_TEST_EVALUATORS
|
||||
enum {
|
||||
SDP_AsScalarProduct,
|
||||
SDP_AsCwiseProduct
|
||||
};
|
||||
|
||||
template<typename SparseXprType, typename DiagonalCoeffType, int SDP_Tag>
|
||||
struct sparse_diagonal_product_evaluator;
|
||||
|
||||
template<typename Lhs, typename Rhs, int Options, int ProductTag>
|
||||
struct product_evaluator<Product<Lhs, Rhs, Options>, ProductTag, DiagonalShape, SparseShape, typename Lhs::Scalar, typename Rhs::Scalar>
|
||||
: public sparse_diagonal_product_evaluator<Rhs, typename Lhs::DiagonalVectorType, Rhs::Flags&RowMajorBit?SDP_AsScalarProduct:SDP_AsCwiseProduct>
|
||||
{
|
||||
typedef Product<Lhs, Rhs, Options> XprType;
|
||||
typedef evaluator<XprType> type;
|
||||
typedef evaluator<XprType> nestedType;
|
||||
enum { CoeffReadCost = Dynamic, Flags = Rhs::Flags&RowMajorBit }; // FIXME CoeffReadCost & Flags
|
||||
|
||||
typedef sparse_diagonal_product_evaluator<Rhs, typename Lhs::DiagonalVectorType, Rhs::Flags&RowMajorBit?SDP_AsScalarProduct:SDP_AsCwiseProduct> Base;
|
||||
product_evaluator(const XprType& xpr) : Base(xpr.rhs(), xpr.lhs().diagonal()) {}
|
||||
};
|
||||
|
||||
template<typename Lhs, typename Rhs, int Options, int ProductTag>
|
||||
struct product_evaluator<Product<Lhs, Rhs, Options>, ProductTag, SparseShape, DiagonalShape, typename Lhs::Scalar, typename Rhs::Scalar>
|
||||
: public sparse_diagonal_product_evaluator<Lhs, Transpose<const typename Rhs::DiagonalVectorType>, Lhs::Flags&RowMajorBit?SDP_AsCwiseProduct:SDP_AsScalarProduct>
|
||||
{
|
||||
typedef Product<Lhs, Rhs, Options> XprType;
|
||||
typedef evaluator<XprType> type;
|
||||
typedef evaluator<XprType> nestedType;
|
||||
enum { CoeffReadCost = Dynamic, Flags = Lhs::Flags&RowMajorBit }; // FIXME CoeffReadCost & Flags
|
||||
|
||||
typedef sparse_diagonal_product_evaluator<Lhs, Transpose<const typename Rhs::DiagonalVectorType>, Lhs::Flags&RowMajorBit?SDP_AsCwiseProduct:SDP_AsScalarProduct> Base;
|
||||
product_evaluator(const XprType& xpr) : Base(xpr.lhs(), xpr.rhs().diagonal()) {}
|
||||
};
|
||||
|
||||
template<typename SparseXprType, typename DiagonalCoeffType>
|
||||
struct sparse_diagonal_product_evaluator<SparseXprType, DiagonalCoeffType, SDP_AsScalarProduct>
|
||||
{
|
||||
protected:
|
||||
typedef typename evaluator<SparseXprType>::InnerIterator SparseXprInnerIterator;
|
||||
typedef typename SparseXprType::Scalar Scalar;
|
||||
typedef typename SparseXprType::Index Index;
|
||||
|
||||
public:
|
||||
class InnerIterator : public SparseXprInnerIterator
|
||||
{
|
||||
public:
|
||||
InnerIterator(const sparse_diagonal_product_evaluator &xprEval, Index outer)
|
||||
: SparseXprInnerIterator(xprEval.m_sparseXprImpl, outer),
|
||||
m_coeff(xprEval.m_diagCoeffImpl.coeff(outer))
|
||||
{}
|
||||
|
||||
EIGEN_STRONG_INLINE Scalar value() const { return m_coeff * SparseXprInnerIterator::value(); }
|
||||
protected:
|
||||
typename DiagonalCoeffType::Scalar m_coeff;
|
||||
};
|
||||
|
||||
sparse_diagonal_product_evaluator(const SparseXprType &sparseXpr, const DiagonalCoeffType &diagCoeff)
|
||||
: m_sparseXprImpl(sparseXpr), m_diagCoeffImpl(diagCoeff)
|
||||
{}
|
||||
|
||||
protected:
|
||||
typename evaluator<SparseXprType>::nestedType m_sparseXprImpl;
|
||||
typename evaluator<DiagonalCoeffType>::nestedType m_diagCoeffImpl;
|
||||
};
|
||||
|
||||
|
||||
template<typename SparseXprType, typename DiagCoeffType>
|
||||
struct sparse_diagonal_product_evaluator<SparseXprType, DiagCoeffType, SDP_AsCwiseProduct>
|
||||
{
|
||||
typedef typename SparseXprType::Scalar Scalar;
|
||||
typedef typename SparseXprType::Index Index;
|
||||
|
||||
typedef CwiseBinaryOp<scalar_product_op<Scalar>,
|
||||
const typename SparseXprType::ConstInnerVectorReturnType,
|
||||
const DiagCoeffType> CwiseProductType;
|
||||
|
||||
typedef typename evaluator<CwiseProductType>::type CwiseProductEval;
|
||||
typedef typename evaluator<CwiseProductType>::InnerIterator CwiseProductIterator;
|
||||
|
||||
class InnerIterator : public CwiseProductIterator
|
||||
{
|
||||
public:
|
||||
InnerIterator(const sparse_diagonal_product_evaluator &xprEval, Index outer)
|
||||
: CwiseProductIterator(CwiseProductEval(xprEval.m_sparseXprNested.innerVector(outer).cwiseProduct(xprEval.m_diagCoeffNested)),0),
|
||||
m_cwiseEval(xprEval.m_sparseXprNested.innerVector(outer).cwiseProduct(xprEval.m_diagCoeffNested)),
|
||||
m_outer(outer)
|
||||
{
|
||||
::new (static_cast<CwiseProductIterator*>(this)) CwiseProductIterator(m_cwiseEval,0);
|
||||
}
|
||||
|
||||
inline Index outer() const { return m_outer; }
|
||||
inline Index col() const { return SparseXprType::IsRowMajor ? CwiseProductIterator::index() : m_outer; }
|
||||
inline Index row() const { return SparseXprType::IsRowMajor ? m_outer : CwiseProductIterator::index(); }
|
||||
|
||||
protected:
|
||||
Index m_outer;
|
||||
CwiseProductEval m_cwiseEval;
|
||||
};
|
||||
|
||||
sparse_diagonal_product_evaluator(const SparseXprType &sparseXpr, const DiagCoeffType &diagCoeff)
|
||||
: m_sparseXprNested(sparseXpr), m_diagCoeffNested(diagCoeff)
|
||||
{}
|
||||
|
||||
protected:
|
||||
typename nested_eval<SparseXprType,1>::type m_sparseXprNested;
|
||||
typename nested_eval<DiagCoeffType,SparseXprType::IsRowMajor ? SparseXprType::RowsAtCompileTime
|
||||
: SparseXprType::ColsAtCompileTime>::type m_diagCoeffNested;
|
||||
};
|
||||
|
||||
#endif // EIGEN_TEST_EVALUATORS
|
||||
|
||||
|
||||
} // end namespace internal
|
||||
|
||||
// SparseMatrixBase functions
|
||||
#ifndef EIGEN_TEST_EVALUATORS
|
||||
|
||||
// SparseMatrixBase functions
|
||||
template<typename Derived>
|
||||
template<typename OtherDerived>
|
||||
const SparseDiagonalProduct<Derived,OtherDerived>
|
||||
@ -190,6 +311,7 @@ SparseMatrixBase<Derived>::operator*(const DiagonalBase<OtherDerived> &other) co
|
||||
{
|
||||
return SparseDiagonalProduct<Derived,OtherDerived>(this->derived(), other.derived());
|
||||
}
|
||||
#endif // EIGEN_TEST_EVALUATORS
|
||||
|
||||
} // end namespace Eigen
|
||||
|
||||
|
@ -269,6 +269,7 @@ template<typename Derived> class SparseMatrixBase : public EigenBase<Derived>
|
||||
const typename SparseSparseProductReturnType<Derived,OtherDerived>::Type
|
||||
operator*(const SparseMatrixBase<OtherDerived> &other) const;
|
||||
|
||||
#ifndef EIGEN_TEST_EVALUATORS
|
||||
// sparse * diagonal
|
||||
template<typename OtherDerived>
|
||||
const SparseDiagonalProduct<Derived,OtherDerived>
|
||||
@ -279,6 +280,19 @@ template<typename Derived> class SparseMatrixBase : public EigenBase<Derived>
|
||||
const SparseDiagonalProduct<OtherDerived,Derived>
|
||||
operator*(const DiagonalBase<OtherDerived> &lhs, const SparseMatrixBase& rhs)
|
||||
{ return SparseDiagonalProduct<OtherDerived,Derived>(lhs.derived(), rhs.derived()); }
|
||||
#else // EIGEN_TEST_EVALUATORS
|
||||
// sparse * diagonal
|
||||
template<typename OtherDerived>
|
||||
const Product<Derived,OtherDerived>
|
||||
operator*(const DiagonalBase<OtherDerived> &other) const
|
||||
{ return Product<Derived,OtherDerived>(derived(), other.derived()); }
|
||||
|
||||
// diagonal * sparse
|
||||
template<typename OtherDerived> friend
|
||||
const Product<OtherDerived,Derived>
|
||||
operator*(const DiagonalBase<OtherDerived> &lhs, const SparseMatrixBase& rhs)
|
||||
{ return Product<OtherDerived,Derived>(lhs.derived(), rhs.derived()); }
|
||||
#endif // EIGEN_TEST_EVALUATORS
|
||||
|
||||
/** dense * sparse (return a dense object unless it is an outer product) */
|
||||
template<typename OtherDerived> friend
|
||||
|
Loading…
Reference in New Issue
Block a user