mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-03-07 18:27:40 +08:00
bug #632: implement general coefficient-wise "dense op sparse" operations through specialized evaluators instead of using SparseView.
This permits to deal with arbitrary storage order, and to by-pass the more complex iterator of the sparse-sparse case.
This commit is contained in:
parent
699634890a
commit
8ed1553d20
@ -526,22 +526,21 @@ template <typename A> struct promote_storage_type<const A, A>
|
||||
* the functor.
|
||||
* The default rules are as follows:
|
||||
* \code
|
||||
* A op A -> A
|
||||
* A op dense -> dense
|
||||
* dense op B -> dense
|
||||
* A * dense -> A
|
||||
* dense * B -> B
|
||||
* A op A -> A
|
||||
* A op dense -> dense
|
||||
* dense op B -> dense
|
||||
* sparse op dense -> sparse
|
||||
* dense op sparse -> sparse
|
||||
* \endcode
|
||||
*/
|
||||
template <typename A, typename B, typename Functor> struct cwise_promote_storage_type;
|
||||
|
||||
template <typename A, typename Functor> struct cwise_promote_storage_type<A,A,Functor> { typedef A ret; };
|
||||
template <typename Functor> struct cwise_promote_storage_type<Dense,Dense,Functor> { typedef Dense ret; };
|
||||
template <typename ScalarA, typename ScalarB> struct cwise_promote_storage_type<Dense,Dense,scalar_product_op<ScalarA,ScalarB> > { typedef Dense ret; };
|
||||
template <typename A, typename Functor> struct cwise_promote_storage_type<A,Dense,Functor> { typedef Dense ret; };
|
||||
template <typename B, typename Functor> struct cwise_promote_storage_type<Dense,B,Functor> { typedef Dense ret; };
|
||||
template <typename A, typename ScalarA, typename ScalarB> struct cwise_promote_storage_type<A,Dense,scalar_product_op<ScalarA,ScalarB> > { typedef A ret; };
|
||||
template <typename B, typename ScalarA, typename ScalarB> struct cwise_promote_storage_type<Dense,B,scalar_product_op<ScalarA,ScalarB> > { typedef B ret; };
|
||||
template <typename A, typename Functor> struct cwise_promote_storage_type<A,A,Functor> { typedef A ret; };
|
||||
template <typename Functor> struct cwise_promote_storage_type<Dense,Dense,Functor> { typedef Dense ret; };
|
||||
template <typename A, typename Functor> struct cwise_promote_storage_type<A,Dense,Functor> { typedef Dense ret; };
|
||||
template <typename B, typename Functor> struct cwise_promote_storage_type<Dense,B,Functor> { typedef Dense ret; };
|
||||
template <typename Functor> struct cwise_promote_storage_type<Sparse,Dense,Functor> { typedef Sparse ret; };
|
||||
template <typename Functor> struct cwise_promote_storage_type<Dense,Sparse,Functor> { typedef Sparse ret; };
|
||||
|
||||
/** \internal Specify the "storage kind" of multiplying an expression of kind A with kind B.
|
||||
* The template parameter ProductTag permits to specialize the resulting storage kind wrt to
|
||||
|
@ -49,15 +49,6 @@ class CwiseBinaryOpImpl<BinaryOp, Lhs, Rhs, Sparse>
|
||||
|
||||
namespace internal {
|
||||
|
||||
template<typename BinaryOp, typename Lhs, typename Rhs, typename Derived,
|
||||
typename _LhsStorageMode = typename traits<Lhs>::StorageKind,
|
||||
typename _RhsStorageMode = typename traits<Rhs>::StorageKind>
|
||||
class sparse_cwise_binary_op_inner_iterator_selector;
|
||||
|
||||
} // end namespace internal
|
||||
|
||||
namespace internal {
|
||||
|
||||
|
||||
// Generic "sparse OP sparse"
|
||||
template<typename XprType> struct binary_sparse_evaluator;
|
||||
@ -155,6 +146,182 @@ protected:
|
||||
evaluator<Rhs> m_rhsImpl;
|
||||
};
|
||||
|
||||
// dense op sparse
|
||||
template<typename BinaryOp, typename Lhs, typename Rhs>
|
||||
struct binary_evaluator<CwiseBinaryOp<BinaryOp, Lhs, Rhs>, IndexBased, IteratorBased>
|
||||
: evaluator_base<CwiseBinaryOp<BinaryOp, Lhs, Rhs> >
|
||||
{
|
||||
protected:
|
||||
typedef typename evaluator<Rhs>::InnerIterator RhsIterator;
|
||||
typedef CwiseBinaryOp<BinaryOp, Lhs, Rhs> XprType;
|
||||
typedef typename traits<XprType>::Scalar Scalar;
|
||||
typedef typename XprType::StorageIndex StorageIndex;
|
||||
public:
|
||||
|
||||
class ReverseInnerIterator;
|
||||
class InnerIterator
|
||||
{
|
||||
enum { IsRowMajor = (int(Rhs::Flags)&RowMajorBit)==RowMajorBit };
|
||||
public:
|
||||
|
||||
EIGEN_STRONG_INLINE InnerIterator(const binary_evaluator& aEval, Index outer)
|
||||
: m_lhsEval(aEval.m_lhsImpl), m_rhsIter(aEval.m_rhsImpl,outer), m_functor(aEval.m_functor), m_id(-1), m_innerSize(aEval.m_expr.rhs().innerSize())
|
||||
{
|
||||
this->operator++();
|
||||
}
|
||||
|
||||
EIGEN_STRONG_INLINE InnerIterator& operator++()
|
||||
{
|
||||
++m_id;
|
||||
if(m_id<m_innerSize)
|
||||
{
|
||||
Scalar lhsVal = m_lhsEval.coeff(IsRowMajor?m_rhsIter.outer():m_id,
|
||||
IsRowMajor?m_id:m_rhsIter.outer());
|
||||
if(m_rhsIter && m_rhsIter.index()==m_id)
|
||||
{
|
||||
m_value = m_functor(lhsVal, m_rhsIter.value());
|
||||
++m_rhsIter;
|
||||
}
|
||||
else
|
||||
m_value = m_functor(lhsVal, Scalar(0));
|
||||
}
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
EIGEN_STRONG_INLINE Scalar value() const { return m_value; }
|
||||
|
||||
EIGEN_STRONG_INLINE StorageIndex index() const { return m_id; }
|
||||
EIGEN_STRONG_INLINE Index row() const { return IsRowMajor ? m_rhsIter.outer() : m_id; }
|
||||
EIGEN_STRONG_INLINE Index col() const { return IsRowMajor ? m_id : m_rhsIter.outer(); }
|
||||
|
||||
EIGEN_STRONG_INLINE operator bool() const { return m_id<m_innerSize; }
|
||||
|
||||
protected:
|
||||
const evaluator<Lhs> &m_lhsEval;
|
||||
RhsIterator m_rhsIter;
|
||||
const BinaryOp& m_functor;
|
||||
Scalar m_value;
|
||||
StorageIndex m_id;
|
||||
StorageIndex m_innerSize;
|
||||
};
|
||||
|
||||
|
||||
enum {
|
||||
CoeffReadCost = evaluator<Lhs>::CoeffReadCost + evaluator<Rhs>::CoeffReadCost + functor_traits<BinaryOp>::Cost,
|
||||
// Expose storage order of the sparse expression
|
||||
Flags = (XprType::Flags & ~RowMajorBit) | (int(Rhs::Flags)&RowMajorBit)
|
||||
};
|
||||
|
||||
explicit binary_evaluator(const XprType& xpr)
|
||||
: m_functor(xpr.functor()),
|
||||
m_lhsImpl(xpr.lhs()),
|
||||
m_rhsImpl(xpr.rhs()),
|
||||
m_expr(xpr)
|
||||
{
|
||||
EIGEN_INTERNAL_CHECK_COST_VALUE(functor_traits<BinaryOp>::Cost);
|
||||
EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost);
|
||||
}
|
||||
|
||||
inline Index nonZerosEstimate() const {
|
||||
return m_expr.size();
|
||||
}
|
||||
|
||||
protected:
|
||||
const BinaryOp m_functor;
|
||||
evaluator<Lhs> m_lhsImpl;
|
||||
evaluator<Rhs> m_rhsImpl;
|
||||
const XprType &m_expr;
|
||||
};
|
||||
|
||||
// sparse op dense
|
||||
template<typename BinaryOp, typename Lhs, typename Rhs>
|
||||
struct binary_evaluator<CwiseBinaryOp<BinaryOp, Lhs, Rhs>, IteratorBased, IndexBased>
|
||||
: evaluator_base<CwiseBinaryOp<BinaryOp, Lhs, Rhs> >
|
||||
{
|
||||
protected:
|
||||
typedef typename evaluator<Lhs>::InnerIterator LhsIterator;
|
||||
typedef CwiseBinaryOp<BinaryOp, Lhs, Rhs> XprType;
|
||||
typedef typename traits<XprType>::Scalar Scalar;
|
||||
typedef typename XprType::StorageIndex StorageIndex;
|
||||
public:
|
||||
|
||||
class ReverseInnerIterator;
|
||||
class InnerIterator
|
||||
{
|
||||
enum { IsRowMajor = (int(Lhs::Flags)&RowMajorBit)==RowMajorBit };
|
||||
public:
|
||||
|
||||
EIGEN_STRONG_INLINE InnerIterator(const binary_evaluator& aEval, Index outer)
|
||||
: m_lhsIter(aEval.m_lhsImpl,outer), m_rhsEval(aEval.m_rhsImpl), m_functor(aEval.m_functor), m_id(-1), m_innerSize(aEval.m_expr.lhs().innerSize())
|
||||
{
|
||||
this->operator++();
|
||||
}
|
||||
|
||||
EIGEN_STRONG_INLINE InnerIterator& operator++()
|
||||
{
|
||||
++m_id;
|
||||
if(m_id<m_innerSize)
|
||||
{
|
||||
Scalar rhsVal = m_rhsEval.coeff(IsRowMajor?m_lhsIter.outer():m_id,
|
||||
IsRowMajor?m_id:m_lhsIter.outer());
|
||||
if(m_lhsIter && m_lhsIter.index()==m_id)
|
||||
{
|
||||
m_value = m_functor(m_lhsIter.value(), rhsVal);
|
||||
++m_lhsIter;
|
||||
}
|
||||
else
|
||||
m_value = m_functor(Scalar(0),rhsVal);
|
||||
}
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
EIGEN_STRONG_INLINE Scalar value() const { return m_value; }
|
||||
|
||||
EIGEN_STRONG_INLINE StorageIndex index() const { return m_id; }
|
||||
EIGEN_STRONG_INLINE Index row() const { return IsRowMajor ? m_lhsIter.outer() : m_id; }
|
||||
EIGEN_STRONG_INLINE Index col() const { return IsRowMajor ? m_id : m_lhsIter.outer(); }
|
||||
|
||||
EIGEN_STRONG_INLINE operator bool() const { return m_id<m_innerSize; }
|
||||
|
||||
protected:
|
||||
LhsIterator m_lhsIter;
|
||||
const evaluator<Rhs> &m_rhsEval;
|
||||
const BinaryOp& m_functor;
|
||||
Scalar m_value;
|
||||
StorageIndex m_id;
|
||||
StorageIndex m_innerSize;
|
||||
};
|
||||
|
||||
|
||||
enum {
|
||||
CoeffReadCost = evaluator<Lhs>::CoeffReadCost + evaluator<Rhs>::CoeffReadCost + functor_traits<BinaryOp>::Cost,
|
||||
// Expose storage order of the sparse expression
|
||||
Flags = (XprType::Flags & ~RowMajorBit) | (int(Lhs::Flags)&RowMajorBit)
|
||||
};
|
||||
|
||||
explicit binary_evaluator(const XprType& xpr)
|
||||
: m_functor(xpr.functor()),
|
||||
m_lhsImpl(xpr.lhs()),
|
||||
m_rhsImpl(xpr.rhs()),
|
||||
m_expr(xpr)
|
||||
{
|
||||
EIGEN_INTERNAL_CHECK_COST_VALUE(functor_traits<BinaryOp>::Cost);
|
||||
EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost);
|
||||
}
|
||||
|
||||
inline Index nonZerosEstimate() const {
|
||||
return m_expr.size();
|
||||
}
|
||||
|
||||
protected:
|
||||
const BinaryOp m_functor;
|
||||
evaluator<Lhs> m_lhsImpl;
|
||||
evaluator<Rhs> m_rhsImpl;
|
||||
const XprType &m_expr;
|
||||
};
|
||||
|
||||
// "sparse .* sparse"
|
||||
template<typename T, typename Lhs, typename Rhs>
|
||||
struct binary_evaluator<CwiseBinaryOp<scalar_product_op<T>, Lhs, Rhs>, IteratorBased, IteratorBased>
|
||||
@ -289,7 +456,8 @@ public:
|
||||
|
||||
enum {
|
||||
CoeffReadCost = evaluator<Lhs>::CoeffReadCost + evaluator<Rhs>::CoeffReadCost + functor_traits<BinaryOp>::Cost,
|
||||
Flags = XprType::Flags
|
||||
// Expose storage order of the sparse expression
|
||||
Flags = (XprType::Flags & ~RowMajorBit) | (int(Rhs::Flags)&RowMajorBit)
|
||||
};
|
||||
|
||||
explicit binary_evaluator(const XprType& xpr)
|
||||
@ -362,7 +530,8 @@ public:
|
||||
|
||||
enum {
|
||||
CoeffReadCost = evaluator<Lhs>::CoeffReadCost + evaluator<Rhs>::CoeffReadCost + functor_traits<BinaryOp>::Cost,
|
||||
Flags = XprType::Flags
|
||||
// Expose storage order of the sparse expression
|
||||
Flags = (XprType::Flags & ~RowMajorBit) | (int(Lhs::Flags)&RowMajorBit)
|
||||
};
|
||||
|
||||
explicit binary_evaluator(const XprType& xpr)
|
||||
@ -430,32 +599,32 @@ SparseMatrixBase<Derived>::cwiseProduct(const MatrixBase<OtherDerived> &other) c
|
||||
return typename CwiseProductDenseReturnType<OtherDerived>::Type(derived(), other.derived());
|
||||
}
|
||||
|
||||
template<typename SparseDerived, typename DenseDerived>
|
||||
EIGEN_STRONG_INLINE const CwiseBinaryOp<internal::scalar_sum_op<typename DenseDerived::Scalar>, const SparseDerived, const SparseView<DenseDerived,true> >
|
||||
operator+(const SparseMatrixBase<SparseDerived> &a, const MatrixBase<DenseDerived> &b)
|
||||
{
|
||||
return a.derived() + SparseView<DenseDerived,true>(b.derived());
|
||||
}
|
||||
|
||||
template<typename DenseDerived, typename SparseDerived>
|
||||
EIGEN_STRONG_INLINE const CwiseBinaryOp<internal::scalar_sum_op<typename DenseDerived::Scalar>, const SparseView<DenseDerived,true>, const SparseDerived>
|
||||
EIGEN_STRONG_INLINE const CwiseBinaryOp<internal::scalar_sum_op<typename DenseDerived::Scalar>, const DenseDerived, const SparseDerived>
|
||||
operator+(const MatrixBase<DenseDerived> &a, const SparseMatrixBase<SparseDerived> &b)
|
||||
{
|
||||
return SparseView<DenseDerived,true>(a.derived()) + b.derived();
|
||||
return CwiseBinaryOp<internal::scalar_sum_op<typename DenseDerived::Scalar>, const DenseDerived, const SparseDerived>(a.derived(), b.derived());
|
||||
}
|
||||
|
||||
template<typename SparseDerived, typename DenseDerived>
|
||||
EIGEN_STRONG_INLINE const CwiseBinaryOp<internal::scalar_difference_op<typename DenseDerived::Scalar>, const SparseDerived, const SparseView<DenseDerived,true> >
|
||||
operator-(const SparseMatrixBase<SparseDerived> &a, const MatrixBase<DenseDerived> &b)
|
||||
EIGEN_STRONG_INLINE const CwiseBinaryOp<internal::scalar_sum_op<typename DenseDerived::Scalar>, const SparseDerived, const DenseDerived>
|
||||
operator+(const SparseMatrixBase<SparseDerived> &a, const MatrixBase<DenseDerived> &b)
|
||||
{
|
||||
return a.derived() - SparseView<DenseDerived,true>(b.derived());
|
||||
return CwiseBinaryOp<internal::scalar_sum_op<typename DenseDerived::Scalar>, const SparseDerived, const DenseDerived>(a.derived(), b.derived());
|
||||
}
|
||||
|
||||
template<typename DenseDerived, typename SparseDerived>
|
||||
EIGEN_STRONG_INLINE const CwiseBinaryOp<internal::scalar_difference_op<typename DenseDerived::Scalar>, const SparseView<DenseDerived,true>, const SparseDerived>
|
||||
EIGEN_STRONG_INLINE const CwiseBinaryOp<internal::scalar_difference_op<typename DenseDerived::Scalar>, const DenseDerived, const SparseDerived>
|
||||
operator-(const MatrixBase<DenseDerived> &a, const SparseMatrixBase<SparseDerived> &b)
|
||||
{
|
||||
return SparseView<DenseDerived,true>(a.derived()) - b.derived();
|
||||
return CwiseBinaryOp<internal::scalar_difference_op<typename DenseDerived::Scalar>, const DenseDerived, const SparseDerived>(a.derived(), b.derived());
|
||||
}
|
||||
|
||||
template<typename SparseDerived, typename DenseDerived>
|
||||
EIGEN_STRONG_INLINE const CwiseBinaryOp<internal::scalar_difference_op<typename DenseDerived::Scalar>, const SparseDerived, const DenseDerived>
|
||||
operator-(const SparseMatrixBase<SparseDerived> &a, const MatrixBase<DenseDerived> &b)
|
||||
{
|
||||
return CwiseBinaryOp<internal::scalar_difference_op<typename DenseDerived::Scalar>, const SparseDerived, const DenseDerived>(a.derived(), b.derived());
|
||||
}
|
||||
|
||||
} // end namespace Eigen
|
||||
|
Loading…
Reference in New Issue
Block a user