add selfadjointView from a trinagularView

This commit is contained in:
Gael Guennebaud 2009-07-31 17:35:55 +02:00
parent 2796bcabb1
commit 18429156a1
2 changed files with 36 additions and 50 deletions

View File

@ -156,7 +156,9 @@ template<typename _MatrixType, unsigned int _Mode> class TriangularView
typedef typename ei_traits<TriangularView>::Scalar Scalar; typedef typename ei_traits<TriangularView>::Scalar Scalar;
typedef _MatrixType MatrixType; typedef _MatrixType MatrixType;
typedef typename MatrixType::PlainMatrixType PlainMatrixType; typedef typename MatrixType::PlainMatrixType PlainMatrixType;
typedef typename MatrixType::Nested MatrixTypeNested;
typedef typename ei_cleantype<MatrixTypeNested>::type _MatrixTypeNested;
enum { enum {
Mode = _Mode, Mode = _Mode,
TransposeMode = (Mode & UpperTriangularBit ? LowerTriangularBit : 0) TransposeMode = (Mode & UpperTriangularBit ? LowerTriangularBit : 0)
@ -286,6 +288,17 @@ template<typename _MatrixType, unsigned int _Mode> class TriangularView
void solveInPlace(const MatrixBase<OtherDerived>& other) const void solveInPlace(const MatrixBase<OtherDerived>& other) const
{ return solveInPlace<OnTheLeft>(other); } { return solveInPlace<OnTheLeft>(other); }
const SelfAdjointView<_MatrixTypeNested,Mode> selfadjointView() const
{
EIGEN_STATIC_ASSERT((Mode&UnitDiagBit)==0,PROGRAMMING_ERROR);
return SelfAdjointView<_MatrixTypeNested,Mode>(m_matrix);
}
SelfAdjointView<_MatrixTypeNested,Mode> selfadjointView()
{
EIGEN_STATIC_ASSERT((Mode&UnitDiagBit)==0,PROGRAMMING_ERROR);
return SelfAdjointView<_MatrixTypeNested,Mode>(m_matrix);
}
template<typename OtherDerived> template<typename OtherDerived>
void swap(const TriangularBase<OtherDerived>& other) void swap(const TriangularBase<OtherDerived>& other)
{ {
@ -300,7 +313,7 @@ template<typename _MatrixType, unsigned int _Mode> class TriangularView
protected: protected:
const typename MatrixType::Nested m_matrix; const MatrixTypeNested m_matrix;
}; };
/*************************************************************************** /***************************************************************************
@ -562,6 +575,10 @@ void TriangularBase<Derived>::evalToDenseLazy(MatrixBase<DenseDerived> &other) c
>::run(other.derived(), derived()._expression()); >::run(other.derived(), derived()._expression());
} }
/***************************************************************************
* Implementation of TriangularView methods
***************************************************************************/
/*************************************************************************** /***************************************************************************
* Implementation of MatrixBase methods * Implementation of MatrixBase methods
***************************************************************************/ ***************************************************************************/

View File

@ -24,12 +24,11 @@
#include "main.h" #include "main.h"
template<typename Lhs, typename Rhs> #define VERIFY_TRSM(TRI,XB) { \
void solve_ref(const Lhs& lhs, Rhs& rhs) XB.setRandom(); ref = XB; \
{ TRI.template solveInPlace(XB); \
for (int j=0; j<rhs.cols(); ++j) VERIFY_IS_APPROX(TRI.toDense() * XB, ref); \
lhs.solveInPlace(rhs.col(j)); }
}
template<typename Scalar> void trsm(int size,int cols) template<typename Scalar> void trsm(int size,int cols)
{ {
@ -37,53 +36,23 @@ template<typename Scalar> void trsm(int size,int cols)
Matrix<Scalar,Dynamic,Dynamic,ColMajor> cmLhs(size,size); Matrix<Scalar,Dynamic,Dynamic,ColMajor> cmLhs(size,size);
Matrix<Scalar,Dynamic,Dynamic,RowMajor> rmLhs(size,size); Matrix<Scalar,Dynamic,Dynamic,RowMajor> rmLhs(size,size);
Matrix<Scalar,Dynamic,Dynamic,ColMajor> cmRef(size,cols), cmRhs(size,cols);
Matrix<Scalar,Dynamic,Dynamic,RowMajor> rmRef(size,cols), rmRhs(size,cols);
cmLhs.setRandom(); cmLhs.diagonal().cwise() += 10; Matrix<Scalar,Dynamic,Dynamic,ColMajor> cmRhs(size,cols), ref(size,cols);
rmLhs.setRandom(); rmLhs.diagonal().cwise() += 10; Matrix<Scalar,Dynamic,Dynamic,RowMajor> rmRhs(size,cols);
cmRhs.setRandom(); cmRef = cmRhs; cmLhs.setRandom(); cmLhs *= 0.1; cmLhs.diagonal().cwise() += 1;
cmLhs.conjugate().template triangularView<LowerTriangular>().solveInPlace(cmRhs); rmLhs.setRandom(); rmLhs *= 0.1; rmLhs.diagonal().cwise() += 1;
solve_ref(cmLhs.conjugate().template triangularView<LowerTriangular>(),cmRef);
VERIFY_IS_APPROX(cmRhs, cmRef);
cmRhs.setRandom(); cmRef = cmRhs; VERIFY_TRSM(cmLhs.conjugate().template triangularView<LowerTriangular>(), cmRhs);
cmLhs.conjugate().template triangularView<UpperTriangular>().solveInPlace(cmRhs); VERIFY_TRSM(cmLhs .template triangularView<UpperTriangular>(), cmRhs);
solve_ref(cmLhs.conjugate().template triangularView<UpperTriangular>(),cmRef); VERIFY_TRSM(cmLhs .template triangularView<LowerTriangular>(), rmRhs);
VERIFY_IS_APPROX(cmRhs, cmRef); VERIFY_TRSM(cmLhs.conjugate().template triangularView<UpperTriangular>(), rmRhs);
rmRhs.setRandom(); rmRef = rmRhs;
cmLhs.template triangularView<LowerTriangular>().solveInPlace(rmRhs);
solve_ref(cmLhs.template triangularView<LowerTriangular>(),rmRef);
VERIFY_IS_APPROX(rmRhs, rmRef);
rmRhs.setRandom(); rmRef = rmRhs; VERIFY_TRSM(cmLhs.conjugate().template triangularView<UnitLowerTriangular>(), cmRhs);
cmLhs.template triangularView<UpperTriangular>().solveInPlace(rmRhs); VERIFY_TRSM(cmLhs .template triangularView<UnitUpperTriangular>(), rmRhs);
solve_ref(cmLhs.template triangularView<UpperTriangular>(),rmRef);
VERIFY_IS_APPROX(rmRhs, rmRef);
VERIFY_TRSM(rmLhs .template triangularView<LowerTriangular>(), cmRhs);
cmRhs.setRandom(); cmRef = cmRhs; VERIFY_TRSM(rmLhs.conjugate().template triangularView<UnitUpperTriangular>(), rmRhs);
rmLhs.template triangularView<UnitLowerTriangular>().solveInPlace(cmRhs);
solve_ref(rmLhs.template triangularView<UnitLowerTriangular>(),cmRef);
VERIFY_IS_APPROX(cmRhs, cmRef);
cmRhs.setRandom(); cmRef = cmRhs;
rmLhs.template triangularView<UnitUpperTriangular>().solveInPlace(cmRhs);
solve_ref(rmLhs.template triangularView<UnitUpperTriangular>(),cmRef);
VERIFY_IS_APPROX(cmRhs, cmRef);
rmRhs.setRandom(); rmRef = rmRhs;
rmLhs.template triangularView<LowerTriangular>().solveInPlace(rmRhs);
solve_ref(rmLhs.template triangularView<LowerTriangular>(),rmRef);
VERIFY_IS_APPROX(rmRhs, rmRef);
rmRhs.setRandom(); rmRef = rmRhs;
rmLhs.template triangularView<UpperTriangular>().solveInPlace(rmRhs);
solve_ref(rmLhs.template triangularView<UpperTriangular>(),rmRef);
VERIFY_IS_APPROX(rmRhs, rmRef);
} }
void test_product_trsm() void test_product_trsm()