Port Cholesky module to evaluators

This commit is contained in:
Gael Guennebaud 2014-03-11 13:33:44 +01:00
parent 9be72cda2a
commit 082f7ddc37
4 changed files with 93 additions and 32 deletions

View File

@ -181,6 +181,17 @@ template<typename _MatrixType, int _UpLo> class LDLT
*
* \sa MatrixBase::ldlt()
*/
#ifdef EIGEN_TEST_EVALUATORS
template<typename Rhs>
inline const Solve<LDLT, Rhs>
solve(const MatrixBase<Rhs>& b) const
{
eigen_assert(m_isInitialized && "LDLT is not initialized.");
eigen_assert(m_matrix.rows()==b.rows()
&& "LDLT::solve(): invalid number of rows of the right hand side matrix b");
return Solve<LDLT, Rhs>(*this, b.derived());
}
#else
template<typename Rhs>
inline const internal::solve_retval<LDLT, Rhs>
solve(const MatrixBase<Rhs>& b) const
@ -190,6 +201,7 @@ template<typename _MatrixType, int _UpLo> class LDLT
&& "LDLT::solve(): invalid number of rows of the right hand side matrix b");
return internal::solve_retval<LDLT, Rhs>(*this, b.derived());
}
#endif
#ifdef EIGEN2_SUPPORT
template<typename OtherDerived, typename ResultType>
@ -234,6 +246,12 @@ template<typename _MatrixType, int _UpLo> class LDLT
return Success;
}
#ifndef EIGEN_PARSED_BY_DOXYGEN
template<typename RhsType, typename DstType>
EIGEN_DEVICE_FUNC
void _solve_impl(const RhsType &rhs, DstType &dst) const;
#endif
protected:
/** \internal
@ -492,7 +510,44 @@ LDLT<MatrixType,_UpLo>& LDLT<MatrixType,_UpLo>::rankUpdate(const MatrixBase<Deri
return *this;
}
#ifndef EIGEN_PARSED_BY_DOXYGEN
template<typename _MatrixType, int _UpLo>
template<typename RhsType, typename DstType>
void LDLT<_MatrixType,_UpLo>::_solve_impl(const RhsType &rhs, DstType &dst) const
{
eigen_assert(rhs.rows() == rows());
// dst = P b
dst = m_transpositions * rhs;
// dst = L^-1 (P b)
matrixL().solveInPlace(dst);
// dst = D^-1 (L^-1 P b)
// more precisely, use pseudo-inverse of D (see bug 241)
using std::abs;
EIGEN_USING_STD_MATH(max);
const Diagonal<const MatrixType> vecD = vectorD();
RealScalar tolerance = (max)( vecD.array().abs().maxCoeff() * NumTraits<Scalar>::epsilon(),
RealScalar(1) / NumTraits<RealScalar>::highest()); // motivated by LAPACK's xGELSS
for (Index i = 0; i < vecD.size(); ++i)
{
if(abs(vecD(i)) > tolerance)
dst.row(i) /= vecD(i);
else
dst.row(i).setZero();
}
// dst = L^-T (D^-1 L^-1 P b)
matrixU().solveInPlace(dst);
// dst = P^-1 (L^-T D^-1 L^-1 P b) = A^-1 b
dst = m_transpositions.transpose() * dst;
}
#endif
namespace internal {
#ifndef EIGEN_TEST_EVALUATORS
template<typename _MatrixType, int _UpLo, typename Rhs>
struct solve_retval<LDLT<_MatrixType,_UpLo>, Rhs>
: solve_retval_base<LDLT<_MatrixType,_UpLo>, Rhs>
@ -502,37 +557,10 @@ struct solve_retval<LDLT<_MatrixType,_UpLo>, Rhs>
template<typename Dest> void evalTo(Dest& dst) const
{
eigen_assert(rhs().rows() == dec().matrixLDLT().rows());
// dst = P b
dst = dec().transpositionsP() * rhs();
// dst = L^-1 (P b)
dec().matrixL().solveInPlace(dst);
// dst = D^-1 (L^-1 P b)
// more precisely, use pseudo-inverse of D (see bug 241)
using std::abs;
EIGEN_USING_STD_MATH(max);
typedef typename LDLTType::MatrixType MatrixType;
typedef typename LDLTType::Scalar Scalar;
typedef typename LDLTType::RealScalar RealScalar;
const Diagonal<const MatrixType> vectorD = dec().vectorD();
RealScalar tolerance = (max)(vectorD.array().abs().maxCoeff() * NumTraits<Scalar>::epsilon(),
RealScalar(1) / NumTraits<RealScalar>::highest()); // motivated by LAPACK's xGELSS
for (Index i = 0; i < vectorD.size(); ++i) {
if(abs(vectorD(i)) > tolerance)
dst.row(i) /= vectorD(i);
else
dst.row(i).setZero();
}
// dst = L^-T (D^-1 L^-1 P b)
dec().matrixU().solveInPlace(dst);
// dst = P^-1 (L^-T D^-1 L^-1 P b) = A^-1 b
dst = dec().transpositionsP().transpose() * dst;
dec()._solve_impl(rhs(),dst);
}
};
#endif
}
/** \internal use x = ldlt_object.solve(x);

View File

@ -117,6 +117,17 @@ template<typename _MatrixType, int _UpLo> class LLT
*
* \sa solveInPlace(), MatrixBase::llt()
*/
#ifdef EIGEN_TEST_EVALUATORS
template<typename Rhs>
inline const Solve<LLT, Rhs>
solve(const MatrixBase<Rhs>& b) const
{
eigen_assert(m_isInitialized && "LLT is not initialized.");
eigen_assert(m_matrix.rows()==b.rows()
&& "LLT::solve(): invalid number of rows of the right hand side matrix b");
return Solve<LLT, Rhs>(*this, b.derived());
}
#else
template<typename Rhs>
inline const internal::solve_retval<LLT, Rhs>
solve(const MatrixBase<Rhs>& b) const
@ -126,6 +137,7 @@ template<typename _MatrixType, int _UpLo> class LLT
&& "LLT::solve(): invalid number of rows of the right hand side matrix b");
return internal::solve_retval<LLT, Rhs>(*this, b.derived());
}
#endif
#ifdef EIGEN2_SUPPORT
template<typename OtherDerived, typename ResultType>
@ -173,6 +185,12 @@ template<typename _MatrixType, int _UpLo> class LLT
template<typename VectorType>
LLT rankUpdate(const VectorType& vec, const RealScalar& sigma = 1);
#ifndef EIGEN_PARSED_BY_DOXYGEN
template<typename RhsType, typename DstType>
EIGEN_DEVICE_FUNC
void _solve_impl(const RhsType &rhs, DstType &dst) const;
#endif
protected:
/** \internal
* Used to compute and store L
@ -416,7 +434,18 @@ LLT<_MatrixType,_UpLo> LLT<_MatrixType,_UpLo>::rankUpdate(const VectorType& v, c
return *this;
}
#ifndef EIGEN_PARSED_BY_DOXYGEN
template<typename _MatrixType,int _UpLo>
template<typename RhsType, typename DstType>
void LLT<_MatrixType,_UpLo>::_solve_impl(const RhsType &rhs, DstType &dst) const
{
dst = rhs;
solveInPlace(dst);
}
#endif
namespace internal {
#ifndef EIGEN_TEST_EVALUATORS
template<typename _MatrixType, int UpLo, typename Rhs>
struct solve_retval<LLT<_MatrixType, UpLo>, Rhs>
: solve_retval_base<LLT<_MatrixType, UpLo>, Rhs>
@ -430,6 +459,7 @@ struct solve_retval<LLT<_MatrixType, UpLo>, Rhs>
dec().solveInPlace(dst);
}
};
#endif
}
/** \internal use x = llt_object.solve(x);

View File

@ -39,8 +39,11 @@ struct traits<SelfAdjointView<MatrixType, UpLo> > : traits<MatrixType>
enum {
Mode = UpLo | SelfAdjoint,
Flags = MatrixTypeNestedCleaned::Flags & (HereditaryBits)
& (~(PacketAccessBit | DirectAccessBit | LinearAccessBit)), // FIXME these flags should be preserved
& (~(PacketAccessBit | DirectAccessBit | LinearAccessBit)) // FIXME these flags should be preserved
#ifndef EIGEN_TEST_EVALUATORS
,
CoeffReadCost = MatrixTypeNestedCleaned::CoeffReadCost
#endif
};
};
}

View File

@ -348,7 +348,7 @@ template<typename T, int n=1, typename PlainObject = typename eval<T>::type> str
// When using evaluators, we never evaluate when assembling the expression!!
// TODO: get rid of this nested class since it's just an alias for ref_selector.
template<typename T, int n=1, typename PlainObject = typename eval<T>::type> struct nested
template<typename T, int n=1, typename PlainObject = void> struct nested
{
typedef typename ref_selector<T>::type type;
};