mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-03-31 19:00:35 +08:00
Implement evaluator for sparse outer products
This commit is contained in:
parent
36e6c9064f
commit
3eba5e1101
@ -109,7 +109,7 @@ struct evaluator_base
|
||||
typedef evaluator<ExpressionType> type;
|
||||
typedef evaluator<ExpressionType> nestedType;
|
||||
|
||||
typedef typename ExpressionType::Index Index;
|
||||
typedef typename traits<ExpressionType>::Index Index;
|
||||
// TODO that's not very nice to have to propagate all these traits. They are currently only needed to handle outer,inner indices.
|
||||
typedef traits<ExpressionType> ExpressionTraits;
|
||||
};
|
||||
|
@ -186,14 +186,14 @@ void assign_sparse_to_sparse(DstXprType &dst, const SrcXprType &src)
|
||||
SrcEvaluatorType srcEvaluator(src);
|
||||
|
||||
const bool transpose = (DstEvaluatorType::Flags & RowMajorBit) != (SrcEvaluatorType::Flags & RowMajorBit);
|
||||
const Index outerSize = (int(SrcEvaluatorType::Flags) & RowMajorBit) ? src.rows() : src.cols();
|
||||
const Index outerEvaluationSize = (SrcEvaluatorType::Flags&RowMajorBit) ? src.rows() : src.cols();
|
||||
if ((!transpose) && src.isRValue())
|
||||
{
|
||||
// eval without temporary
|
||||
dst.resize(src.rows(), src.cols());
|
||||
dst.setZero();
|
||||
dst.reserve((std::max)(src.rows(),src.cols())*2);
|
||||
for (Index j=0; j<outerSize; ++j)
|
||||
for (Index j=0; j<outerEvaluationSize; ++j)
|
||||
{
|
||||
dst.startVec(j);
|
||||
for (typename SrcEvaluatorType::InnerIterator it(srcEvaluator, j); it; ++it)
|
||||
@ -213,11 +213,11 @@ void assign_sparse_to_sparse(DstXprType &dst, const SrcXprType &src)
|
||||
|
||||
enum { Flip = (DstEvaluatorType::Flags & RowMajorBit) != (SrcEvaluatorType::Flags & RowMajorBit) };
|
||||
|
||||
const Index outerSize = src.outerSize();
|
||||
|
||||
DstXprType temp(src.rows(), src.cols());
|
||||
|
||||
temp.reserve((std::max)(src.rows(),src.cols())*2);
|
||||
for (Index j=0; j<outerSize; ++j)
|
||||
for (Index j=0; j<outerEvaluationSize; ++j)
|
||||
{
|
||||
temp.startVec(j);
|
||||
for (typename SrcEvaluatorType::InnerIterator it(srcEvaluator, j); it; ++it)
|
||||
@ -256,7 +256,8 @@ struct Assignment<DstXprType, SrcXprType, Functor, Sparse2Dense, Scalar>
|
||||
dst.setZero();
|
||||
typename internal::evaluator<SrcXprType>::type srcEval(src);
|
||||
typename internal::evaluator<DstXprType>::type dstEval(dst);
|
||||
for (Index j=0; j<src.outerSize(); ++j)
|
||||
const Index outerEvaluationSize = (internal::evaluator<SrcXprType>::Flags&RowMajorBit) ? src.rows() : src.cols();
|
||||
for (Index j=0; j<outerEvaluationSize; ++j)
|
||||
for (typename internal::evaluator<SrcXprType>::InnerIterator i(srcEval,j); i; ++i)
|
||||
dstEval.coeffRef(i.row(),i.col()) = i.value();
|
||||
}
|
||||
|
@ -13,7 +13,10 @@
|
||||
namespace Eigen {
|
||||
|
||||
namespace internal {
|
||||
|
||||
|
||||
template <> struct product_promote_storage_type<Sparse,Dense, OuterProduct> { typedef Sparse ret; };
|
||||
template <> struct product_promote_storage_type<Dense,Sparse, OuterProduct> { typedef Sparse ret; };
|
||||
|
||||
template<typename SparseLhsType, typename DenseRhsType, typename DenseResType,
|
||||
typename AlphaType,
|
||||
int LhsStorageOrder = ((SparseLhsType::Flags&RowMajorBit)==RowMajorBit) ? RowMajor : ColMajor,
|
||||
@ -445,6 +448,156 @@ protected:
|
||||
PlainObject m_result;
|
||||
};
|
||||
|
||||
|
||||
// template<typename Lhs, typename Rhs, bool Transpose, typename LhsIterator>
|
||||
// class sparse_dense_outer_product_iterator : public LhsIterator
|
||||
// {
|
||||
// typedef typename SparseDenseOuterProduct::Index Index;
|
||||
// public:
|
||||
// template<typename XprEval>
|
||||
// EIGEN_STRONG_INLINE InnerIterator(const XprEval& prod, Index outer)
|
||||
// : LhsIterator(prod.lhs(), 0),
|
||||
// m_outer(outer), m_empty(false), m_factor(get(prod.rhs(), outer, typename internal::traits<Rhs>::StorageKind() ))
|
||||
// {}
|
||||
//
|
||||
// inline Index outer() const { return m_outer; }
|
||||
// inline Index row() const { return Transpose ? m_outer : Base::index(); }
|
||||
// inline Index col() const { return Transpose ? Base::index() : m_outer; }
|
||||
//
|
||||
// inline Scalar value() const { return Base::value() * m_factor; }
|
||||
// inline operator bool() const { return Base::operator bool() && !m_empty; }
|
||||
//
|
||||
// protected:
|
||||
// Scalar get(const _RhsNested &rhs, Index outer, Dense = Dense()) const
|
||||
// {
|
||||
// return rhs.coeff(outer);
|
||||
// }
|
||||
//
|
||||
// Scalar get(const _RhsNested &rhs, Index outer, Sparse = Sparse())
|
||||
// {
|
||||
// typename Traits::_RhsNested::InnerIterator it(rhs, outer);
|
||||
// if (it && it.index()==0 && it.value()!=Scalar(0))
|
||||
// return it.value();
|
||||
// m_empty = true;
|
||||
// return Scalar(0);
|
||||
// }
|
||||
//
|
||||
// Index m_outer;
|
||||
// bool m_empty;
|
||||
// Scalar m_factor;
|
||||
// };
|
||||
|
||||
template<typename LhsT, typename RhsT, bool Transpose>
|
||||
struct sparse_dense_outer_product_evaluator
|
||||
{
|
||||
protected:
|
||||
typedef typename conditional<Transpose,RhsT,LhsT>::type Lhs1;
|
||||
typedef typename conditional<Transpose,LhsT,RhsT>::type Rhs;
|
||||
typedef Product<LhsT,RhsT> ProdXprType;
|
||||
|
||||
// if the actual left-hand side is a dense vector,
|
||||
// then build a sparse-view so that we can seamlessly iterator over it.
|
||||
typedef typename conditional<is_same<typename internal::traits<Lhs1>::StorageKind,Sparse>::value,
|
||||
Lhs1, SparseView<Lhs1> >::type Lhs;
|
||||
typedef typename conditional<is_same<typename internal::traits<Lhs1>::StorageKind,Sparse>::value,
|
||||
Lhs1 const&, SparseView<Lhs1> >::type LhsArg;
|
||||
|
||||
typedef typename evaluator<Lhs>::type LhsEval;
|
||||
typedef typename evaluator<Rhs>::type RhsEval;
|
||||
typedef typename evaluator<Lhs>::InnerIterator LhsIterator;
|
||||
typedef typename ProdXprType::Scalar Scalar;
|
||||
typedef typename ProdXprType::Index Index;
|
||||
|
||||
public:
|
||||
enum {
|
||||
Flags = Transpose ? RowMajorBit : 0,
|
||||
CoeffReadCost = Dynamic
|
||||
};
|
||||
|
||||
class InnerIterator : public LhsIterator
|
||||
{
|
||||
public:
|
||||
InnerIterator(const sparse_dense_outer_product_evaluator &xprEval, Index outer)
|
||||
: LhsIterator(xprEval.m_lhsXprImpl, 0),
|
||||
m_outer(outer),
|
||||
m_empty(false),
|
||||
m_factor(get(xprEval.m_rhsXprImpl, outer, typename internal::traits<Rhs>::StorageKind() ))
|
||||
{}
|
||||
|
||||
EIGEN_STRONG_INLINE Index outer() const { return m_outer; }
|
||||
EIGEN_STRONG_INLINE Index row() const { return Transpose ? m_outer : LhsIterator::index(); }
|
||||
EIGEN_STRONG_INLINE Index col() const { return Transpose ? LhsIterator::index() : m_outer; }
|
||||
|
||||
EIGEN_STRONG_INLINE Scalar value() const { return LhsIterator::value() * m_factor; }
|
||||
EIGEN_STRONG_INLINE operator bool() const { return LhsIterator::operator bool() && (!m_empty); }
|
||||
|
||||
|
||||
protected:
|
||||
Scalar get(const RhsEval &rhs, Index outer, Dense = Dense()) const
|
||||
{
|
||||
return rhs.coeff(outer);
|
||||
}
|
||||
|
||||
Scalar get(const RhsEval &rhs, Index outer, Sparse = Sparse())
|
||||
{
|
||||
typename RhsEval::InnerIterator it(rhs, outer);
|
||||
if (it && it.index()==0 && it.value()!=Scalar(0))
|
||||
return it.value();
|
||||
m_empty = true;
|
||||
return Scalar(0);
|
||||
}
|
||||
|
||||
Index m_outer;
|
||||
bool m_empty;
|
||||
Scalar m_factor;
|
||||
};
|
||||
|
||||
sparse_dense_outer_product_evaluator(const Lhs &lhs, const Rhs &rhs)
|
||||
: m_lhs(lhs), m_lhsXprImpl(m_lhs), m_rhsXprImpl(rhs)
|
||||
{}
|
||||
|
||||
// transpose case
|
||||
sparse_dense_outer_product_evaluator(const Rhs &rhs, const Lhs1 &lhs)
|
||||
: m_lhs(lhs), m_lhsXprImpl(m_lhs), m_rhsXprImpl(rhs)
|
||||
{}
|
||||
|
||||
protected:
|
||||
const LhsArg m_lhs;
|
||||
typename evaluator<Lhs>::nestedType m_lhsXprImpl;
|
||||
typename evaluator<Rhs>::nestedType m_rhsXprImpl;
|
||||
};
|
||||
|
||||
// sparse * dense outer product
|
||||
template<typename Lhs, typename Rhs>
|
||||
struct product_evaluator<Product<Lhs, Rhs, DefaultProduct>, OuterProduct, SparseShape, DenseShape, typename Lhs::Scalar, typename Rhs::Scalar>
|
||||
: sparse_dense_outer_product_evaluator<Lhs,Rhs, Lhs::IsRowMajor>
|
||||
{
|
||||
typedef sparse_dense_outer_product_evaluator<Lhs,Rhs, Lhs::IsRowMajor> Base;
|
||||
|
||||
typedef Product<Lhs, Rhs> XprType;
|
||||
typedef typename XprType::PlainObject PlainObject;
|
||||
|
||||
product_evaluator(const XprType& xpr)
|
||||
: Base(xpr.lhs(), xpr.rhs())
|
||||
{}
|
||||
|
||||
};
|
||||
|
||||
template<typename Lhs, typename Rhs>
|
||||
struct product_evaluator<Product<Lhs, Rhs, DefaultProduct>, OuterProduct, DenseShape, SparseShape, typename Lhs::Scalar, typename Rhs::Scalar>
|
||||
: sparse_dense_outer_product_evaluator<Lhs,Rhs, Rhs::IsRowMajor>
|
||||
{
|
||||
typedef sparse_dense_outer_product_evaluator<Lhs,Rhs, Rhs::IsRowMajor> Base;
|
||||
|
||||
typedef Product<Lhs, Rhs> XprType;
|
||||
typedef typename XprType::PlainObject PlainObject;
|
||||
|
||||
product_evaluator(const XprType& xpr)
|
||||
: Base(xpr.lhs(), xpr.rhs())
|
||||
{}
|
||||
|
||||
};
|
||||
|
||||
} // end namespace internal
|
||||
|
||||
#endif // EIGEN_TEST_EVALUATORS
|
||||
|
@ -1132,7 +1132,7 @@ EIGEN_DONT_INLINE SparseMatrix<Scalar,_Options,_Index>& SparseMatrix<Scalar,_Opt
|
||||
{
|
||||
EIGEN_STATIC_ASSERT((internal::is_same<Scalar, typename OtherDerived::Scalar>::value),
|
||||
YOU_MIXED_DIFFERENT_NUMERIC_TYPES__YOU_NEED_TO_USE_THE_CAST_METHOD_OF_MATRIXBASE_TO_CAST_NUMERIC_TYPES_EXPLICITLY)
|
||||
|
||||
|
||||
const bool needToTranspose = (Flags & RowMajorBit) != (internal::evaluator<OtherDerived>::Flags & RowMajorBit);
|
||||
if (needToTranspose)
|
||||
{
|
||||
@ -1140,7 +1140,7 @@ EIGEN_DONT_INLINE SparseMatrix<Scalar,_Options,_Index>& SparseMatrix<Scalar,_Opt
|
||||
// 1 - compute the number of coeffs per dest inner vector
|
||||
// 2 - do the actual copy/eval
|
||||
// Since each coeff of the rhs has to be evaluated twice, let's evaluate it if needed
|
||||
typedef typename internal::nested_eval<OtherDerived,2>::type OtherCopy;
|
||||
typedef typename internal::nested_eval<OtherDerived,2,typename internal::plain_matrix_type<OtherDerived>::type >::type OtherCopy;
|
||||
typedef typename internal::remove_all<OtherCopy>::type _OtherCopy;
|
||||
typedef internal::evaluator<_OtherCopy> OtherCopyEval;
|
||||
OtherCopy otherCopy(other.derived());
|
||||
|
Loading…
x
Reference in New Issue
Block a user