Implement binaryop and transpose evaluators for sparse matrices

This commit is contained in:
Gael Guennebaud 2014-06-23 10:40:03 +02:00
parent ec0a8b2e6d
commit 3849cc65ee
11 changed files with 396 additions and 19 deletions

View File

@ -42,10 +42,10 @@ struct Sparse {};
#include "src/SparseCore/MappedSparseMatrix.h"
#include "src/SparseCore/SparseVector.h"
#include "src/SparseCore/SparseCwiseUnaryOp.h"
#include "src/SparseCore/SparseCwiseBinaryOp.h"
#include "src/SparseCore/SparseTranspose.h"
#ifndef EIGEN_TEST_EVALUATORS
#include "src/SparseCore/SparseBlock.h"
#include "src/SparseCore/SparseTranspose.h"
#include "src/SparseCore/SparseCwiseBinaryOp.h"
#include "src/SparseCore/SparseDot.h"
#include "src/SparseCore/SparsePermutation.h"
#include "src/SparseCore/SparseRedux.h"

View File

@ -2,7 +2,7 @@
// for linear algebra.
//
// Copyright (C) 2011 Benoit Jacob <jacob.benoit.1@gmail.com>
// Copyright (C) 2011-2013 Gael Guennebaud <gael.guennebaud@inria.fr>
// Copyright (C) 2011-2014 Gael Guennebaud <gael.guennebaud@inria.fr>
// Copyright (C) 2011-2012 Jitse Niesen <jitse@maths.leeds.ac.uk>
//
// This Source Code Form is subject to the terms of the Mozilla
@ -738,8 +738,10 @@ void call_assignment_no_alias(Dst& dst, const Src& src, const Func& func)
&& int(Dst::SizeAtCompileTime) != 1
};
dst.resize(NeedToTranspose ? src.cols() : src.rows(),
NeedToTranspose ? src.rows() : src.cols());
typename Dst::Index dstRows = NeedToTranspose ? src.cols() : src.rows();
typename Dst::Index dstCols = NeedToTranspose ? src.rows() : src.cols();
if((dst.rows()!=dstRows) || (dst.cols()!=dstCols))
dst.resize(dstRows, dstCols);
typedef typename internal::conditional<NeedToTranspose, Transpose<Dst>, Dst>::type ActualDstTypeCleaned;
typedef typename internal::conditional<NeedToTranspose, Transpose<Dst>, Dst&>::type ActualDstType;
@ -749,7 +751,7 @@ void call_assignment_no_alias(Dst& dst, const Src& src, const Func& func)
EIGEN_STATIC_ASSERT_LVALUE(Dst)
EIGEN_STATIC_ASSERT_SAME_MATRIX_SIZE(ActualDstTypeCleaned,Src)
EIGEN_CHECK_BINARY_COMPATIBILIY(Func,typename ActualDstTypeCleaned::Scalar,typename Src::Scalar);
Assignment<ActualDstTypeCleaned,Src,Func>::run(actualDst, src, func);
}
template<typename Dst, typename Src>

View File

@ -2,7 +2,7 @@
// for linear algebra.
//
// Copyright (C) 2011 Benoit Jacob <jacob.benoit.1@gmail.com>
// Copyright (C) 2011 Gael Guennebaud <gael.guennebaud@inria.fr>
// Copyright (C) 2011-2014 Gael Guennebaud <gael.guennebaud@inria.fr>
// Copyright (C) 2011-2012 Jitse Niesen <jitse@maths.leeds.ac.uk>
//
// This Source Code Form is subject to the terms of the Mozilla
@ -253,7 +253,7 @@ struct evaluator<Array<Scalar, Rows, Cols, Options, MaxRows, MaxCols> >
// -------------------- Transpose --------------------
template<typename ArgType>
struct unary_evaluator<Transpose<ArgType> >
struct unary_evaluator<Transpose<ArgType>, IndexBased>
: evaluator_base<Transpose<ArgType> >
{
typedef Transpose<ArgType> XprType;
@ -440,7 +440,7 @@ struct evaluator<CwiseBinaryOp<BinaryOp, Lhs, Rhs> >
};
template<typename BinaryOp, typename Lhs, typename Rhs>
struct binary_evaluator<CwiseBinaryOp<BinaryOp, Lhs, Rhs> >
struct binary_evaluator<CwiseBinaryOp<BinaryOp, Lhs, Rhs>, IndexBased, IndexBased>
: evaluator_base<CwiseBinaryOp<BinaryOp, Lhs, Rhs> >
{
typedef CwiseBinaryOp<BinaryOp, Lhs, Rhs> XprType;

View File

@ -1,7 +1,7 @@
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2008-2009 Gael Guennebaud <gael.guennebaud@inria.fr>
// Copyright (C) 2008-2014 Gael Guennebaud <gael.guennebaud@inria.fr>
// Copyright (C) 2006-2008 Benoit Jacob <jacob.benoit.1@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla

View File

@ -1,7 +1,7 @@
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2008-2010 Gael Guennebaud <gael.guennebaud@inria.fr>
// Copyright (C) 2008-2014 Gael Guennebaud <gael.guennebaud@inria.fr>
// Copyright (C) 2006-2008 Benoit Jacob <jacob.benoit.1@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
@ -135,7 +135,7 @@ class CwiseUnaryOpImpl<UnaryOp,XprType,Dense>
return derived().functor().packetOp(derived().nestedExpression().template packet<LoadMode>(index));
}
};
#else
#else // EIGEN_TEST_EVALUATORS
// Generic API dispatcher
template<typename UnaryOp, typename XprType, typename StorageKind>
class CwiseUnaryOpImpl
@ -144,7 +144,7 @@ class CwiseUnaryOpImpl
public:
typedef typename internal::generic_xpr_base<CwiseUnaryOp<UnaryOp, XprType> >::type Base;
};
#endif
#endif // EIGEN_TEST_EVALUATORS
} // end namespace Eigen

View File

@ -2,7 +2,7 @@
// for linear algebra.
//
// Copyright (C) 2006-2008 Benoit Jacob <jacob.benoit.1@gmail.com>
// Copyright (C) 2009-2010 Gael Guennebaud <gael.guennebaud@inria.fr>
// Copyright (C) 2009-2014 Gael Guennebaud <gael.guennebaud@inria.fr>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
@ -108,6 +108,17 @@ struct TransposeImpl_base<MatrixType, false>
} // end namespace internal
#ifdef EIGEN_TEST_EVALUATORS
// Generic API dispatcher
template<typename XprType, typename StorageKind>
class TransposeImpl
: public internal::generic_xpr_base<Transpose<XprType> >::type
{
public:
typedef typename internal::generic_xpr_base<Transpose<XprType> >::type Base;
};
#endif
template<typename MatrixType> class TransposeImpl<MatrixType,Dense>
: public internal::TransposeImpl_base<MatrixType>::type
{

View File

@ -1,7 +1,7 @@
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2008 Gael Guennebaud <gael.guennebaud@inria.fr>
// Copyright (C) 2014 Gael Guennebaud <gael.guennebaud@inria.fr>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
@ -44,6 +44,8 @@ class sparse_cwise_binary_op_inner_iterator_selector;
} // end namespace internal
#ifndef EIGEN_TEST_EVALUATORS
template<typename BinaryOp, typename Lhs, typename Rhs>
class CwiseBinaryOpImpl<BinaryOp, Lhs, Rhs, Sparse>
: public SparseMatrixBase<CwiseBinaryOp<BinaryOp, Lhs, Rhs> >
@ -291,6 +293,314 @@ class sparse_cwise_binary_op_inner_iterator_selector<scalar_product_op<T>, Lhs,
} // end namespace internal
#else // EIGEN_TEST_EVALUATORS
namespace internal {
// Generic "sparse OP sparse"
template<typename BinaryOp, typename Lhs, typename Rhs>
struct binary_evaluator<CwiseBinaryOp<BinaryOp, Lhs, Rhs>, IteratorBased, IteratorBased>
: evaluator_base<CwiseBinaryOp<BinaryOp, Lhs, Rhs> >
{
protected:
typedef typename evaluator<Lhs>::InnerIterator LhsIterator;
typedef typename evaluator<Rhs>::InnerIterator RhsIterator;
public:
typedef CwiseBinaryOp<BinaryOp, Lhs, Rhs> XprType;
class ReverseInnerIterator;
class InnerIterator
{
typedef typename traits<XprType>::Scalar Scalar;
typedef typename XprType::Index Index;
public:
EIGEN_STRONG_INLINE InnerIterator(const binary_evaluator& aEval, Index outer)
: m_lhsIter(aEval.m_lhsImpl,outer), m_rhsIter(aEval.m_rhsImpl,outer), m_functor(aEval.m_functor)
{
this->operator++();
}
EIGEN_STRONG_INLINE InnerIterator& operator++()
{
if (m_lhsIter && m_rhsIter && (m_lhsIter.index() == m_rhsIter.index()))
{
m_id = m_lhsIter.index();
m_value = m_functor(m_lhsIter.value(), m_rhsIter.value());
++m_lhsIter;
++m_rhsIter;
}
else if (m_lhsIter && (!m_rhsIter || (m_lhsIter.index() < m_rhsIter.index())))
{
m_id = m_lhsIter.index();
m_value = m_functor(m_lhsIter.value(), Scalar(0));
++m_lhsIter;
}
else if (m_rhsIter && (!m_lhsIter || (m_lhsIter.index() > m_rhsIter.index())))
{
m_id = m_rhsIter.index();
m_value = m_functor(Scalar(0), m_rhsIter.value());
++m_rhsIter;
}
else
{
m_value = 0; // this is to avoid a compilation warning
m_id = -1;
}
return *this;
}
EIGEN_STRONG_INLINE Scalar value() const { return m_value; }
EIGEN_STRONG_INLINE Index index() const { return m_id; }
EIGEN_STRONG_INLINE Index row() const { return Lhs::IsRowMajor ? m_lhsIter.row() : index(); }
EIGEN_STRONG_INLINE Index col() const { return Lhs::IsRowMajor ? index() : m_lhsIter.col(); }
EIGEN_STRONG_INLINE operator bool() const { return m_id>=0; }
protected:
LhsIterator m_lhsIter;
RhsIterator m_rhsIter;
const BinaryOp& m_functor;
Scalar m_value;
Index m_id;
};
enum {
CoeffReadCost = evaluator<Lhs>::CoeffReadCost + evaluator<Rhs>::CoeffReadCost + functor_traits<BinaryOp>::Cost,
Flags = XprType::Flags
};
binary_evaluator(const XprType& xpr)
: m_functor(xpr.functor()),
m_lhsImpl(xpr.lhs()),
m_rhsImpl(xpr.rhs())
{ }
protected:
const BinaryOp m_functor;
typename evaluator<Lhs>::nestedType m_lhsImpl;
typename evaluator<Rhs>::nestedType m_rhsImpl;
};
// "sparse .* sparse"
template<typename T, typename Lhs, typename Rhs>
struct binary_evaluator<CwiseBinaryOp<scalar_product_op<T>, Lhs, Rhs>, IteratorBased, IteratorBased>
: evaluator_base<CwiseBinaryOp<scalar_product_op<T>, Lhs, Rhs> >
{
protected:
typedef scalar_product_op<T> BinaryOp;
typedef typename evaluator<Lhs>::InnerIterator LhsIterator;
typedef typename evaluator<Rhs>::InnerIterator RhsIterator;
public:
typedef CwiseBinaryOp<BinaryOp, Lhs, Rhs> XprType;
class ReverseInnerIterator;
class InnerIterator
{
typedef typename traits<XprType>::Scalar Scalar;
typedef typename XprType::Index Index;
public:
EIGEN_STRONG_INLINE InnerIterator(const binary_evaluator& aEval, Index outer)
: m_lhsIter(aEval.m_lhsImpl,outer), m_rhsIter(aEval.m_rhsImpl,outer), m_functor(aEval.m_functor)
{
while (m_lhsIter && m_rhsIter && (m_lhsIter.index() != m_rhsIter.index()))
{
if (m_lhsIter.index() < m_rhsIter.index())
++m_lhsIter;
else
++m_rhsIter;
}
}
EIGEN_STRONG_INLINE InnerIterator& operator++()
{
++m_lhsIter;
++m_rhsIter;
while (m_lhsIter && m_rhsIter && (m_lhsIter.index() != m_rhsIter.index()))
{
if (m_lhsIter.index() < m_rhsIter.index())
++m_lhsIter;
else
++m_rhsIter;
}
return *this;
}
EIGEN_STRONG_INLINE Scalar value() const { return m_functor(m_lhsIter.value(), m_rhsIter.value()); }
EIGEN_STRONG_INLINE Index index() const { return m_lhsIter.index(); }
EIGEN_STRONG_INLINE Index row() const { return m_lhsIter.row(); }
EIGEN_STRONG_INLINE Index col() const { return m_lhsIter.col(); }
EIGEN_STRONG_INLINE operator bool() const { return (m_lhsIter && m_rhsIter); }
protected:
LhsIterator m_lhsIter;
RhsIterator m_rhsIter;
const BinaryOp& m_functor;
};
enum {
CoeffReadCost = evaluator<Lhs>::CoeffReadCost + evaluator<Rhs>::CoeffReadCost + functor_traits<BinaryOp>::Cost,
Flags = XprType::Flags
};
binary_evaluator(const XprType& xpr)
: m_functor(xpr.functor()),
m_lhsImpl(xpr.lhs()),
m_rhsImpl(xpr.rhs())
{ }
protected:
const BinaryOp m_functor;
typename evaluator<Lhs>::nestedType m_lhsImpl;
typename evaluator<Rhs>::nestedType m_rhsImpl;
};
// "dense .* sparse"
template<typename T, typename Lhs, typename Rhs>
struct binary_evaluator<CwiseBinaryOp<scalar_product_op<T>, Lhs, Rhs>, IndexBased, IteratorBased>
: evaluator_base<CwiseBinaryOp<scalar_product_op<T>, Lhs, Rhs> >
{
protected:
typedef scalar_product_op<T> BinaryOp;
typedef typename evaluator<Lhs>::type LhsEvaluator;
typedef typename evaluator<Rhs>::InnerIterator RhsIterator;
public:
typedef CwiseBinaryOp<BinaryOp, Lhs, Rhs> XprType;
class ReverseInnerIterator;
class InnerIterator
{
typedef typename traits<XprType>::Scalar Scalar;
typedef typename XprType::Index Index;
enum { IsRowMajor = (int(Rhs::Flags)&RowMajorBit)==RowMajorBit };
public:
EIGEN_STRONG_INLINE InnerIterator(const binary_evaluator& aEval, Index outer)
: m_lhsEval(aEval.m_lhsImpl), m_rhsIter(aEval.m_rhsImpl,outer), m_functor(aEval.m_functor), m_outer(outer)
{}
EIGEN_STRONG_INLINE InnerIterator& operator++()
{
++m_rhsIter;
return *this;
}
EIGEN_STRONG_INLINE Scalar value() const
{ return m_functor(m_lhsEval.coeff(IsRowMajor?m_outer:m_rhsIter.index(),IsRowMajor?m_rhsIter.index():m_outer), m_rhsIter.value()); }
EIGEN_STRONG_INLINE Index index() const { return m_rhsIter.index(); }
EIGEN_STRONG_INLINE Index row() const { return m_rhsIter.row(); }
EIGEN_STRONG_INLINE Index col() const { return m_rhsIter.col(); }
EIGEN_STRONG_INLINE operator bool() const { return m_rhsIter; }
protected:
const LhsEvaluator &m_lhsEval;
RhsIterator m_rhsIter;
const BinaryOp& m_functor;
const Index m_outer;
};
enum {
CoeffReadCost = evaluator<Lhs>::CoeffReadCost + evaluator<Rhs>::CoeffReadCost + functor_traits<BinaryOp>::Cost,
Flags = XprType::Flags
};
binary_evaluator(const XprType& xpr)
: m_functor(xpr.functor()),
m_lhsImpl(xpr.lhs()),
m_rhsImpl(xpr.rhs())
{ }
protected:
const BinaryOp m_functor;
typename evaluator<Lhs>::nestedType m_lhsImpl;
typename evaluator<Rhs>::nestedType m_rhsImpl;
};
// "sparse .* dense"
template<typename T, typename Lhs, typename Rhs>
struct binary_evaluator<CwiseBinaryOp<scalar_product_op<T>, Lhs, Rhs>, IteratorBased, IndexBased>
: evaluator_base<CwiseBinaryOp<scalar_product_op<T>, Lhs, Rhs> >
{
protected:
typedef scalar_product_op<T> BinaryOp;
typedef typename evaluator<Lhs>::InnerIterator LhsIterator;
typedef typename evaluator<Rhs>::type RhsEvaluator;
public:
typedef CwiseBinaryOp<BinaryOp, Lhs, Rhs> XprType;
class ReverseInnerIterator;
class InnerIterator
{
typedef typename traits<XprType>::Scalar Scalar;
typedef typename XprType::Index Index;
enum { IsRowMajor = (int(Lhs::Flags)&RowMajorBit)==RowMajorBit };
public:
EIGEN_STRONG_INLINE InnerIterator(const binary_evaluator& aEval, Index outer)
: m_lhsIter(aEval.m_lhsImpl,outer), m_rhsEval(aEval.m_rhsImpl), m_functor(aEval.m_functor), m_outer(outer)
{}
EIGEN_STRONG_INLINE InnerIterator& operator++()
{
++m_lhsIter;
return *this;
}
EIGEN_STRONG_INLINE Scalar value() const
{ return m_functor(m_lhsIter.value(),
m_rhsEval.coeff(IsRowMajor?m_outer:m_lhsIter.index(),IsRowMajor?m_lhsIter.index():m_outer)); }
EIGEN_STRONG_INLINE Index index() const { return m_lhsIter.index(); }
EIGEN_STRONG_INLINE Index row() const { return m_lhsIter.row(); }
EIGEN_STRONG_INLINE Index col() const { return m_lhsIter.col(); }
EIGEN_STRONG_INLINE operator bool() const { return m_lhsIter; }
protected:
LhsIterator m_lhsIter;
const RhsEvaluator &m_rhsEval;
const BinaryOp& m_functor;
const Index m_outer;
};
enum {
CoeffReadCost = evaluator<Lhs>::CoeffReadCost + evaluator<Rhs>::CoeffReadCost + functor_traits<BinaryOp>::Cost,
Flags = XprType::Flags
};
binary_evaluator(const XprType& xpr)
: m_functor(xpr.functor()),
m_lhsImpl(xpr.lhs()),
m_rhsImpl(xpr.rhs())
{ }
protected:
const BinaryOp m_functor;
typename evaluator<Lhs>::nestedType m_lhsImpl;
typename evaluator<Rhs>::nestedType m_rhsImpl;
};
}
#endif
/***************************************************************************
* Implementation of SparseMatrixBase and SparseCwise functions/operators
***************************************************************************/

View File

@ -1,7 +1,7 @@
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2008-2010 Gael Guennebaud <gael.guennebaud@inria.fr>
// Copyright (C) 2008-2014 Gael Guennebaud <gael.guennebaud@inria.fr>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed

View File

@ -1,7 +1,7 @@
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2008-2010 Gael Guennebaud <gael.guennebaud@inria.fr>
// Copyright (C) 2008-2014 Gael Guennebaud <gael.guennebaud@inria.fr>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed

View File

@ -1,7 +1,7 @@
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2008-2011 Gael Guennebaud <gael.guennebaud@inria.fr>
// Copyright (C) 2008-2014 Gael Guennebaud <gael.guennebaud@inria.fr>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
@ -79,10 +79,12 @@ template<typename Derived> class SparseMatrixBase : public EigenBase<Derived>
* constructed from this one. See the \ref flags "list of flags".
*/
#ifndef EIGEN_TEST_EVALUATORS
CoeffReadCost = internal::traits<Derived>::CoeffReadCost,
/**< This is a rough measure of how expensive it is to read one coefficient from
* this expression.
*/
#endif
IsRowMajor = Flags&RowMajorBit ? 1 : 0,

View File

@ -1,7 +1,7 @@
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2008-2009 Gael Guennebaud <gael.guennebaud@inria.fr>
// Copyright (C) 2008-2014 Gael Guennebaud <gael.guennebaud@inria.fr>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
@ -12,6 +12,7 @@
namespace Eigen {
#ifndef EIGEN_TEST_EVALUATORS
template<typename MatrixType> class TransposeImpl<MatrixType,Sparse>
: public SparseMatrixBase<Transpose<MatrixType> >
{
@ -58,6 +59,57 @@ template<typename MatrixType> class TransposeImpl<MatrixType,Sparse>::ReverseInn
Index col() const { return Base::row(); }
};
#else // EIGEN_TEST_EVALUATORS
namespace internal {
template<typename ArgType>
struct unary_evaluator<Transpose<ArgType>, IteratorBased>
: public evaluator_base<Transpose<ArgType> >
{
typedef typename evaluator<ArgType>::InnerIterator EvalIterator;
typedef typename evaluator<ArgType>::ReverseInnerIterator EvalReverseIterator;
public:
typedef Transpose<ArgType> XprType;
typedef typename XprType::Index Index;
class InnerIterator : public EvalIterator
{
public:
EIGEN_STRONG_INLINE InnerIterator(const unary_evaluator& unaryOp, typename XprType::Index outer)
: EvalIterator(unaryOp.m_argImpl,outer)
{}
Index row() const { return EvalIterator::col(); }
Index col() const { return EvalIterator::row(); }
};
class ReverseInnerIterator : public EvalReverseIterator
{
public:
EIGEN_STRONG_INLINE ReverseInnerIterator(const unary_evaluator& unaryOp, typename XprType::Index outer)
: EvalReverseIterator(unaryOp.m_argImpl,outer)
{}
Index row() const { return EvalReverseIterator::col(); }
Index col() const { return EvalReverseIterator::row(); }
};
enum {
CoeffReadCost = evaluator<ArgType>::CoeffReadCost,
Flags = XprType::Flags
};
unary_evaluator(const XprType& op) :m_argImpl(op.nestedExpression()) {}
protected:
typename evaluator<ArgType>::nestedType m_argImpl;
};
} // end namespace internal
#endif // EIGEN_TEST_EVALUATORS
} // end namespace Eigen
#endif // EIGEN_SPARSETRANSPOSE_H