add ReverseInnerIterators to loop over the elements in reverse order,

and partly fix bug #356 (issue in trisolve for upper-column major))
This commit is contained in:
Gael Guennebaud 2011-12-03 23:49:37 +01:00
parent a09cc5d4c0
commit 91e392a042
6 changed files with 142 additions and 34 deletions

View File

@ -32,39 +32,61 @@ class CwiseUnaryOpImpl<UnaryOp,MatrixType,Sparse>
public:
class InnerIterator;
// typedef typename internal::remove_reference<LhsNested>::type _LhsNested;
class ReverseInnerIterator;
typedef CwiseUnaryOp<UnaryOp, MatrixType> Derived;
EIGEN_SPARSE_PUBLIC_INTERFACE(Derived)
protected:
typedef typename internal::traits<Derived>::_XprTypeNested _MatrixTypeNested;
typedef typename _MatrixTypeNested::InnerIterator MatrixTypeIterator;
typedef typename _MatrixTypeNested::ReverseInnerIterator MatrixTypeReverseIterator;
};
template<typename UnaryOp, typename MatrixType>
class CwiseUnaryOpImpl<UnaryOp,MatrixType,Sparse>::InnerIterator
: public CwiseUnaryOpImpl<UnaryOp,MatrixType,Sparse>::MatrixTypeIterator
{
typedef typename CwiseUnaryOpImpl::Scalar Scalar;
typedef typename internal::traits<Derived>::_XprTypeNested _MatrixTypeNested;
typedef typename _MatrixTypeNested::InnerIterator MatrixTypeIterator;
typedef typename MatrixType::Index Index;
typedef typename CwiseUnaryOpImpl<UnaryOp,MatrixType,Sparse>::MatrixTypeIterator Base;
public:
EIGEN_STRONG_INLINE InnerIterator(const CwiseUnaryOpImpl& unaryOp, Index outer)
: m_iter(unaryOp.derived().nestedExpression(),outer), m_functor(unaryOp.derived().functor())
: Base(unaryOp.derived().nestedExpression(),outer), m_functor(unaryOp.derived().functor())
{}
EIGEN_STRONG_INLINE InnerIterator& operator++()
{ ++m_iter; return *this; }
{ Base::operator++(); return *this; }
EIGEN_STRONG_INLINE Scalar value() const { return m_functor(m_iter.value()); }
EIGEN_STRONG_INLINE Index index() const { return m_iter.index(); }
EIGEN_STRONG_INLINE Index row() const { return m_iter.row(); }
EIGEN_STRONG_INLINE Index col() const { return m_iter.col(); }
EIGEN_STRONG_INLINE operator bool() const { return m_iter; }
EIGEN_STRONG_INLINE Scalar value() const { return m_functor(Base::value()); }
protected:
MatrixTypeIterator m_iter;
const UnaryOp m_functor;
private:
Scalar& valueRef();
};
template<typename UnaryOp, typename MatrixType>
class CwiseUnaryOpImpl<UnaryOp,MatrixType,Sparse>::ReverseInnerIterator
: public CwiseUnaryOpImpl<UnaryOp,MatrixType,Sparse>::MatrixTypeReverseIterator
{
typedef typename CwiseUnaryOpImpl::Scalar Scalar;
typedef typename CwiseUnaryOpImpl<UnaryOp,MatrixType,Sparse>::MatrixTypeReverseIterator Base;
public:
EIGEN_STRONG_INLINE ReverseInnerIterator(const CwiseUnaryOpImpl& unaryOp, Index outer)
: Base(unaryOp.derived().nestedExpression(),outer), m_functor(unaryOp.derived().functor())
{}
EIGEN_STRONG_INLINE ReverseInnerIterator& operator--()
{ Base::operator--(); return *this; }
EIGEN_STRONG_INLINE Scalar value() const { return m_functor(Base::value()); }
protected:
const UnaryOp m_functor;
private:
Scalar& valueRef();
};
template<typename ViewOp, typename MatrixType>
@ -74,39 +96,58 @@ class CwiseUnaryViewImpl<ViewOp,MatrixType,Sparse>
public:
class InnerIterator;
// typedef typename internal::remove_reference<LhsNested>::type _LhsNested;
class ReverseInnerIterator;
typedef CwiseUnaryView<ViewOp, MatrixType> Derived;
EIGEN_SPARSE_PUBLIC_INTERFACE(Derived)
protected:
typedef typename internal::traits<Derived>::_MatrixTypeNested _MatrixTypeNested;
typedef typename _MatrixTypeNested::InnerIterator MatrixTypeIterator;
typedef typename _MatrixTypeNested::ReverseInnerIterator MatrixTypeReverseIterator;
};
template<typename ViewOp, typename MatrixType>
class CwiseUnaryViewImpl<ViewOp,MatrixType,Sparse>::InnerIterator
: public CwiseUnaryViewImpl<ViewOp,MatrixType,Sparse>::MatrixTypeIterator
{
typedef typename CwiseUnaryViewImpl::Scalar Scalar;
typedef typename internal::traits<Derived>::_MatrixTypeNested _MatrixTypeNested;
typedef typename _MatrixTypeNested::InnerIterator MatrixTypeIterator;
typedef typename MatrixType::Index Index;
typedef typename CwiseUnaryViewImpl<ViewOp,MatrixType,Sparse>::MatrixTypeIterator Base;
public:
EIGEN_STRONG_INLINE InnerIterator(const CwiseUnaryViewImpl& unaryView, Index outer)
: m_iter(unaryView.derived().nestedExpression(),outer), m_functor(unaryView.derived().functor())
EIGEN_STRONG_INLINE InnerIterator(const CwiseUnaryViewImpl& unaryOp, Index outer)
: Base(unaryOp.derived().nestedExpression(),outer), m_functor(unaryOp.derived().functor())
{}
EIGEN_STRONG_INLINE InnerIterator& operator++()
{ ++m_iter; return *this; }
{ Base::operator++(); return *this; }
EIGEN_STRONG_INLINE Scalar value() const { return m_functor(m_iter.value()); }
EIGEN_STRONG_INLINE Scalar& valueRef() { return m_functor(m_iter.valueRef()); }
EIGEN_STRONG_INLINE Index index() const { return m_iter.index(); }
EIGEN_STRONG_INLINE Index row() const { return m_iter.row(); }
EIGEN_STRONG_INLINE Index col() const { return m_iter.col(); }
EIGEN_STRONG_INLINE operator bool() const { return m_iter; }
EIGEN_STRONG_INLINE Scalar value() const { return m_functor(Base::value()); }
EIGEN_STRONG_INLINE Scalar& valueRef() { return m_functor(Base::valueRef()); }
protected:
const ViewOp m_functor;
};
template<typename ViewOp, typename MatrixType>
class CwiseUnaryViewImpl<ViewOp,MatrixType,Sparse>::ReverseInnerIterator
: public CwiseUnaryViewImpl<ViewOp,MatrixType,Sparse>::MatrixTypeReverseIterator
{
typedef typename CwiseUnaryViewImpl::Scalar Scalar;
typedef typename CwiseUnaryViewImpl<ViewOp,MatrixType,Sparse>::MatrixTypeReverseIterator Base;
public:
EIGEN_STRONG_INLINE ReverseInnerIterator(const CwiseUnaryViewImpl& unaryOp, Index outer)
: Base(unaryOp.derived().nestedExpression(),outer), m_functor(unaryOp.derived().functor())
{}
EIGEN_STRONG_INLINE ReverseInnerIterator& operator--()
{ Base::operator--(); return *this; }
EIGEN_STRONG_INLINE Scalar value() const { return m_functor(Base::value()); }
EIGEN_STRONG_INLINE Scalar& valueRef() { return m_functor(Base::valueRef()); }
protected:
MatrixTypeIterator m_iter;
const ViewOp m_functor;
};

View File

@ -210,6 +210,7 @@ class SparseMatrix
public:
class InnerIterator;
class ReverseInnerIterator;
/** Removes all non zeros but keep allocated memory */
inline void setZero()
@ -889,4 +890,37 @@ class SparseMatrix<Scalar,_Options,_Index>::InnerIterator
Index m_end;
};
template<typename Scalar, int _Options, typename _Index>
class SparseMatrix<Scalar,_Options,_Index>::ReverseInnerIterator
{
public:
ReverseInnerIterator(const SparseMatrix& mat, Index outer)
: m_values(mat._valuePtr()), m_indices(mat._innerIndexPtr()), m_outer(outer), m_start(mat.m_outerIndex[outer])
{
if(mat.compressed())
m_id = mat.m_outerIndex[outer+1];
else
m_id = m_start + mat.m_innerNonZeros[outer];
}
inline ReverseInnerIterator& operator--() { --m_id; return *this; }
inline const Scalar& value() const { return m_values[m_id-1]; }
inline Scalar& valueRef() { return const_cast<Scalar&>(m_values[m_id-1]); }
inline Index index() const { return m_indices[m_id-1]; }
inline Index outer() const { return m_outer; }
inline Index row() const { return IsRowMajor ? m_outer : index(); }
inline Index col() const { return IsRowMajor ? index() : m_outer; }
inline operator bool() const { return (m_id > m_start); }
protected:
const Scalar* m_values;
const Index* m_indices;
const Index m_outer;
Index m_id;
const Index m_start;
};
#endif // EIGEN_SPARSEMATRIX_H

View File

@ -38,12 +38,16 @@ template<typename MatrixType, int Mode> class SparseTriangularView
: public SparseMatrixBase<SparseTriangularView<MatrixType,Mode> >
{
enum { SkipFirst = (Mode==Lower && !(MatrixType::Flags&RowMajorBit))
|| (Mode==Upper && (MatrixType::Flags&RowMajorBit)) };
|| (Mode==Upper && (MatrixType::Flags&RowMajorBit)),
SkipLast = !SkipFirst
};
public:
EIGEN_SPARSE_PUBLIC_INTERFACE(SparseTriangularView)
class InnerIterator;
class ReverseInnerIterator;
inline Index rows() const { return m_matrix.rows(); }
inline Index cols() const { return m_matrix.cols(); }
@ -92,6 +96,28 @@ class SparseTriangularView<MatrixType,Mode>::InnerIterator : public MatrixType::
}
};
template<typename MatrixType, int Mode>
class SparseTriangularView<MatrixType,Mode>::ReverseInnerIterator : public MatrixType::ReverseInnerIterator
{
typedef typename MatrixType::ReverseInnerIterator Base;
public:
EIGEN_STRONG_INLINE ReverseInnerIterator(const SparseTriangularView& view, Index outer)
: Base(view.nestedExpression(), outer)
{
if(SkipLast)
while((*this) && this->index()>outer)
--(*this);
}
inline Index row() const { return Base::row(); }
inline Index col() const { return Base::col(); }
EIGEN_STRONG_INLINE operator bool() const
{
return SkipLast ? Base::operator bool() : (Base::operator bool() && this->index() >= this->outer());
}
};
template<typename Derived>
template<int Mode>
inline const SparseTriangularView<Derived, Mode>

View File

@ -156,9 +156,11 @@ struct sparse_solve_triangular_selector<Lhs,Rhs,Mode,Upper,ColMajor>
{
if(!(Mode & UnitDiag))
{
// FIXME lhs.coeff(i,i) might not be always efficient while it must simply be the
// last element of the column !
other.coeffRef(i,col) /= lhs.innerVector(i).lastCoeff();
typename Lhs::ReverseInnerIterator it(lhs, i);
while(it && it.index()!=i)
--it;
eigen_assert(it && it.index()==i);
other.coeffRef(i,col) /= it.value();
}
typename Lhs::InnerIterator it(lhs, i);
for(; it && it.index()<i; ++it)

View File

@ -198,6 +198,9 @@ template<typename SparseMatrixType> void sparse_basic(const SparseMatrixType& re
VERIFY_IS_APPROX(m1.col(0).dot(refM2.row(0)), refM1.col(0).dot(refM2.row(0)));
VERIFY_IS_APPROX(m1.conjugate(), refM1.conjugate());
VERIFY_IS_APPROX(m1.real(), refM1.real());
refM4.setRandom();
// sparse cwise* dense
VERIFY_IS_APPROX(m3.cwiseProduct(refM4), refM3.cwiseProduct(refM4));

View File

@ -72,6 +72,8 @@ template<typename Scalar> void sparse_solvers(int rows, int cols)
initSparse<Scalar>(density, refMat2, m2, ForceNonZeroDiag|MakeUpperTriangular, &zeroCoords, &nonzeroCoords);
VERIFY_IS_APPROX(refMat2.template triangularView<Upper>().solve(vec2),
m2.template triangularView<Upper>().solve(vec3));
VERIFY_IS_APPROX(refMat2.conjugate().template triangularView<Upper>().solve(vec2),
m2.conjugate().template triangularView<Upper>().solve(vec3));
// lower - transpose
initSparse<Scalar>(density, refMat2, m2, ForceNonZeroDiag|MakeLowerTriangular, &zeroCoords, &nonzeroCoords);