From 1d98cc5e5da88254c784c4f02517bf5a47f007bc Mon Sep 17 00:00:00 2001 From: Benoit Jacob Date: Tue, 25 Jan 2011 21:22:04 -0500 Subject: [PATCH] eigen2 support: implement part, mimic eigen2 behavior braindeadness-for-braindeadness --- Eigen/src/Core/MatrixBase.h | 8 ++-- Eigen/src/Core/SelfAdjointView.h | 24 ++++++++++++ Eigen/src/Core/TriangularMatrix.h | 45 ++++++++++++++++------- Eigen/src/Core/util/ForwardDeclarations.h | 3 ++ test/eigen2/eigen2_triangular.cpp | 29 +++++++++++++++ 5 files changed, 92 insertions(+), 17 deletions(-) diff --git a/Eigen/src/Core/MatrixBase.h b/Eigen/src/Core/MatrixBase.h index da4af6bfd..fbdc059cf 100644 --- a/Eigen/src/Core/MatrixBase.h +++ b/Eigen/src/Core/MatrixBase.h @@ -242,16 +242,16 @@ template class MatrixBase typename MatrixBase::template DiagonalIndexReturnType::Type diagonal(Index index); typename MatrixBase::template ConstDiagonalIndexReturnType::Type diagonal(Index index) const; - #ifdef EIGEN2_SUPPORT - template TriangularView part(); - template const TriangularView part() const; + //#ifdef EIGEN2_SUPPORT + template typename internal::eigen2_part_return_type::type part(); + template const typename internal::eigen2_part_return_type::type part() const; // huuuge hack. make Eigen2's matrix.part() work in eigen3. Problem: Diagonal is now a class template instead // of an integer constant. Solution: overload the part() method template wrt template parameters list. template class U> const DiagonalWrapper part() const { return diagonal().asDiagonal(); } - #endif // EIGEN2_SUPPORT + //#endif // EIGEN2_SUPPORT template struct TriangularViewReturnType { typedef TriangularView Type; }; template struct ConstTriangularViewReturnType { typedef const TriangularView Type; }; diff --git a/Eigen/src/Core/SelfAdjointView.h b/Eigen/src/Core/SelfAdjointView.h index 5d8468884..92d58b9f8 100644 --- a/Eigen/src/Core/SelfAdjointView.h +++ b/Eigen/src/Core/SelfAdjointView.h @@ -48,6 +48,7 @@ struct traits > : traits typedef typename nested::type MatrixTypeNested; typedef typename remove_reference::type _MatrixTypeNested; typedef MatrixType ExpressionType; + typedef typename MatrixType::PlainObject DenseMatrixType; enum { Mode = UpLo | SelfAdjoint, Flags = _MatrixTypeNested::Flags & (HereditaryBits) @@ -171,6 +172,29 @@ template class SelfAdjointView EigenvaluesReturnType eigenvalues() const; RealScalar operatorNorm() const; + + #ifdef EIGEN2_SUPPORT + template + SelfAdjointView& operator=(const MatrixBase& other) + { + enum { + OtherPart = UpLo == Upper ? StrictlyLower : StrictlyUpper + }; + m_matrix.const_cast_derived().template triangularView() = other; + m_matrix.const_cast_derived().template triangularView() = other.adjoint(); + return *this; + } + template + SelfAdjointView& operator=(const TriangularView& other) + { + enum { + OtherPart = UpLo == Upper ? StrictlyLower : StrictlyUpper + }; + m_matrix.const_cast_derived().template triangularView() = other.toDenseMatrix(); + m_matrix.const_cast_derived().template triangularView() = other.toDenseMatrix().adjoint(); + return *this; + } + #endif protected: const typename MatrixType::Nested m_matrix; diff --git a/Eigen/src/Core/TriangularMatrix.h b/Eigen/src/Core/TriangularMatrix.h index ce5b53631..714d56a5b 100644 --- a/Eigen/src/Core/TriangularMatrix.h +++ b/Eigen/src/Core/TriangularMatrix.h @@ -48,6 +48,7 @@ template class TriangularBase : public EigenBase typedef typename internal::traits::Scalar Scalar; typedef typename internal::traits::StorageKind StorageKind; typedef typename internal::traits::Index Index; + typedef typename internal::traits::DenseMatrixType DenseMatrixType; inline TriangularBase() { eigen_assert(!((Mode&UnitDiag) && (Mode&ZeroDiag))); } @@ -88,6 +89,13 @@ template class TriangularBase : public EigenBase template void evalToLazy(MatrixBase &other) const; + DenseMatrixType toDenseMatrix() const + { + DenseMatrixType res(rows(), cols()); + evalToLazy(res); + return res; + } + protected: void check_coordinates(Index row, Index col) const @@ -137,6 +145,7 @@ struct traits > : traits typedef typename nested::type MatrixTypeNested; typedef typename remove_reference::type _MatrixTypeNested; typedef MatrixType ExpressionType; + typedef typename MatrixType::PlainObject DenseMatrixType; enum { Mode = _Mode, Flags = (_MatrixTypeNested::Flags & (HereditaryBits) & (~(PacketAccessBit | DirectAccessBit | LinearAccessBit))) | Mode, @@ -159,7 +168,7 @@ template class TriangularView typedef typename internal::traits::Scalar Scalar; typedef _MatrixType MatrixType; - typedef typename MatrixType::PlainObject DenseMatrixType; + typedef typename internal::traits::DenseMatrixType DenseMatrixType; protected: typedef typename MatrixType::Nested MatrixTypeNested; @@ -269,13 +278,6 @@ template class TriangularView inline const TriangularView,TransposeMode> transpose() const { return m_matrix.transpose(); } - DenseMatrixType toDenseMatrix() const - { - DenseMatrixType res(rows(), cols()); - evalToLazy(res); - return res; - } - /** Efficient triangular matrix times vector/matrix product */ template TriangularProduct @@ -310,18 +312,18 @@ template class TriangularView const typename eigen2_product_return_type::type operator*(const TriangularView& rhs) const { - return toDenseMatrix() * rhs.toDenseMatrix(); + return this->toDenseMatrix() * rhs.toDenseMatrix(); } template bool isApprox(const TriangularView& other, typename NumTraits::Real precision = NumTraits::dummy_precision()) const { - return toDenseMatrix().isApprox(other.toDenseMatrix(), precision); + return this->toDenseMatrix().isApprox(other.toDenseMatrix(), precision); } template bool isApprox(const MatrixBase& other, typename NumTraits::Real precision = NumTraits::dummy_precision()) const { - return toDenseMatrix().isApprox(other, precision); + return this->toDenseMatrix().isApprox(other, precision); } #endif // EIGEN2_SUPPORT @@ -707,10 +709,27 @@ void TriangularBase::evalToLazy(MatrixBase &other) const ***************************************************************************/ #ifdef EIGEN2_SUPPORT + +// implementation of part<>(), including the SelfAdjoint case. + +namespace internal { +template +struct eigen2_part_return_type +{ + typedef TriangularView type; +}; + +template +struct eigen2_part_return_type +{ + typedef SelfAdjointView type; +}; +} + /** \deprecated use MatrixBase::triangularView() */ template template -const TriangularView MatrixBase::part() const +const typename internal::eigen2_part_return_type::type MatrixBase::part() const { return derived(); } @@ -718,7 +737,7 @@ const TriangularView MatrixBase::part() const /** \deprecated use MatrixBase::triangularView() */ template template -TriangularView MatrixBase::part() +typename internal::eigen2_part_return_type::type MatrixBase::part() { return derived(); } diff --git a/Eigen/src/Core/util/ForwardDeclarations.h b/Eigen/src/Core/util/ForwardDeclarations.h index 578f8d8e6..548da3986 100644 --- a/Eigen/src/Core/util/ForwardDeclarations.h +++ b/Eigen/src/Core/util/ForwardDeclarations.h @@ -268,6 +268,9 @@ template class Cwise; template class Minor; template class LU; template class QR; +namespace internal { +template struct eigen2_part_return_type; +} #endif #endif // EIGEN_FORWARDDECLARATIONS_H diff --git a/test/eigen2/eigen2_triangular.cpp b/test/eigen2/eigen2_triangular.cpp index c81fad0da..43b42e3a5 100644 --- a/test/eigen2/eigen2_triangular.cpp +++ b/test/eigen2/eigen2_triangular.cpp @@ -124,8 +124,37 @@ template void triangular(const MatrixType& m) } +void selfadjoint() +{ + Matrix2i m; + m << 1, 2, + 3, 4; + + Matrix2i m1 = Matrix2i::Zero(); + m1.part() = m; + Matrix2i ref1; + ref1 << 1, 2, + 2, 4; + VERIFY(m1 == ref1); + + Matrix2i m2 = Matrix2i::Zero(); + m2.part() = m.part(); + Matrix2i ref2; + ref2 << 1, 2, + 2, 4; + VERIFY(m2 == ref2); + + Matrix2i m3 = Matrix2i::Zero(); + m3.part() = m.part(); + Matrix2i ref3; + ref3 << 1, 0, + 0, 4; + VERIFY(m3 == ref3); +} + void test_eigen2_triangular() { + CALL_SUBTEST_8( selfadjoint() ); for(int i = 0; i < g_repeat ; i++) { CALL_SUBTEST_1( triangular(Matrix()) ); CALL_SUBTEST_2( triangular(Matrix()) );