From ddb3ac98a20d5f56146a53d485b1899a46b9f912 Mon Sep 17 00:00:00 2001 From: Gael Guennebaud Date: Thu, 23 Jul 2009 10:05:38 +0200 Subject: [PATCH] addd matrix * self adjoint high level API --- Eigen/src/Core/SelfAdjointView.h | 99 ++++++++++++++++++++------------ test/product_selfadjoint.cpp | 7 +++ 2 files changed, 69 insertions(+), 37 deletions(-) diff --git a/Eigen/src/Core/SelfAdjointView.h b/Eigen/src/Core/SelfAdjointView.h index 7f5fd7533..c64ebc174 100644 --- a/Eigen/src/Core/SelfAdjointView.h +++ b/Eigen/src/Core/SelfAdjointView.h @@ -53,8 +53,9 @@ struct ei_traits > : ei_traits -struct ei_selfadjoint_matrix_product_returntype; +template +struct ei_selfadjoint_product_returntype; // FIXME could also be called SelfAdjointWrapper to be consistent with DiagonalWrapper ?? template class SelfAdjointView @@ -99,10 +100,22 @@ template class SelfAdjointView /** Efficient self-adjoint matrix times vector/matrix product */ template - ei_selfadjoint_matrix_product_returntype + ei_selfadjoint_product_returntype operator*(const MatrixBase& rhs) const { - return ei_selfadjoint_matrix_product_returntype(*this, rhs.derived()); + return ei_selfadjoint_product_returntype + + (m_matrix, rhs.derived()); + } + + /** Efficient vector/matrix times self-adjoint matrix product */ + template friend + ei_selfadjoint_product_returntype + operator*(const MatrixBase& lhs, const SelfAdjointView& rhs) + { + return ei_selfadjoint_product_returntype + + (lhs.derived(),rhs.m_matrix); } /** Perform a symmetric rank 2 update of the selfadjoint matrix \c *this: @@ -125,6 +138,14 @@ template class SelfAdjointView const typename MatrixType::Nested m_matrix; }; + +// template +// ei_selfadjoint_matrix_product_returntype > +// operator*(const MatrixBase& lhs, const SelfAdjointView& rhs) +// { +// return ei_matrix_selfadjoint_product_returntype >(lhs.derived(),rhs); +// } + template struct ei_triangular_assignment_selector { @@ -163,14 +184,14 @@ struct ei_triangular_assignment_selector -struct ei_selfadjoint_matrix_product_returntype - : public ReturnByValue, +template +struct ei_selfadjoint_product_returntype + : public ReturnByValue, Matrix::Scalar, Rhs::RowsAtCompileTime,Rhs::ColsAtCompileTime> > { typedef typename ei_cleantype::type RhsNested; - ei_selfadjoint_matrix_product_returntype(const Lhs& lhs, const Rhs& rhs) + ei_selfadjoint_product_returntype(const Lhs& lhs, const Rhs& rhs) : m_lhs(lhs), m_rhs(rhs) {} @@ -178,10 +199,10 @@ struct ei_selfadjoint_matrix_product_returntype { dst.resize(m_rhs.rows(), m_rhs.cols()); ei_product_selfadjoint_vector::Flags&RowMajorBit, - Lhs::Mode&(UpperTriangularBit|LowerTriangularBit)> + LhsMode&(UpperTriangularBit|LowerTriangularBit)> ( m_lhs.rows(), // size - m_lhs._expression().data(), // lhs + m_lhs.data(), // lhs m_lhs.stride(), // lhsStride, m_rhs.data(), // rhs // int rhsIncr, @@ -189,7 +210,7 @@ struct ei_selfadjoint_matrix_product_returntype ); } - const Lhs m_lhs; + const typename Lhs::Nested m_lhs; const typename Rhs::Nested m_rhs; }; @@ -197,25 +218,36 @@ struct ei_selfadjoint_matrix_product_returntype * Wrapper to ei_product_selfadjoint_matrix ***************************************************************************/ -template -struct ei_selfadjoint_matrix_product_returntype - : public ReturnByValue, +template +struct ei_selfadjoint_product_returntype + : public ReturnByValue, Matrix::Scalar, - Rhs::RowsAtCompileTime,Rhs::ColsAtCompileTime> > + Lhs::RowsAtCompileTime,Rhs::ColsAtCompileTime> > { - ei_selfadjoint_matrix_product_returntype(const Lhs& lhs, const Rhs& rhs) + ei_selfadjoint_product_returntype(const Lhs& lhs, const Rhs& rhs) : m_lhs(lhs), m_rhs(rhs) {} typedef typename Lhs::Scalar Scalar; + + typedef typename Lhs::Nested LhsNested; + typedef typename ei_cleantype::type _LhsNested; + typedef ei_blas_traits<_LhsNested> LhsBlasTraits; + typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType; + typedef typename ei_cleantype::type _ActualLhsType; + typedef typename Rhs::Nested RhsNested; typedef typename ei_cleantype::type _RhsNested; + typedef ei_blas_traits<_RhsNested> RhsBlasTraits; + typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType; + typedef typename ei_cleantype::type _ActualRhsType; - typedef typename ei_traits::ExpressionType LhsExpr; - typedef typename LhsExpr::Nested LhsNested; - typedef typename ei_cleantype::type _LhsNested; - - enum { UpLo = ei_traits::Mode&(UpperTriangularBit|LowerTriangularBit) }; + enum { + LhsUpLo = LhsMode&(UpperTriangularBit|LowerTriangularBit), + LhsIsSelfAdjoint = (LhsMode&SelfAdjointBit)==SelfAdjointBit, + RhsUpLo = RhsMode&(UpperTriangularBit|LowerTriangularBit), + RhsIsSelfAdjoint = (RhsMode&SelfAdjointBit)==SelfAdjointBit + }; template inline void _addTo(Dest& dst) const { evalTo(dst,1); } @@ -231,26 +263,19 @@ struct ei_selfadjoint_matrix_product_returntype template void evalTo(Dest& dst, Scalar alpha) const { - typedef ei_blas_traits<_LhsNested> LhsBlasTraits; - typedef ei_blas_traits<_RhsNested> RhsBlasTraits; - - typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType; - typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType; - - typedef typename ei_cleantype::type _ActualLhsType; - typedef typename ei_cleantype::type _ActualRhsType; - - const ActualLhsType lhs = LhsBlasTraits::extract(m_lhs._expression()); + const ActualLhsType lhs = LhsBlasTraits::extract(m_lhs); const ActualRhsType rhs = RhsBlasTraits::extract(m_rhs); - Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(m_lhs._expression()) + Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(m_lhs) * RhsBlasTraits::extractScalarFactor(m_rhs); ei_product_selfadjoint_matrix::Flags &RowMajorBit) ? RowMajor : ColMajor, true, - NumTraits::IsComplex && EIGEN_LOGICAL_XOR(UpLo==UpperTriangular,bool(LhsBlasTraits::NeedToConjugate)), - ei_traits::Flags &RowMajorBit ? RowMajor : ColMajor, false, bool(RhsBlasTraits::NeedToConjugate), + EIGEN_LOGICAL_XOR(LhsUpLo==UpperTriangular, + ei_traits::Flags &RowMajorBit) ? RowMajor : ColMajor, LhsIsSelfAdjoint, + NumTraits::IsComplex && EIGEN_LOGICAL_XOR(LhsUpLo==UpperTriangular,bool(LhsBlasTraits::NeedToConjugate)), + EIGEN_LOGICAL_XOR(RhsUpLo==UpperTriangular, + ei_traits::Flags &RowMajorBit) ? RowMajor : ColMajor, RhsIsSelfAdjoint, + NumTraits::IsComplex && EIGEN_LOGICAL_XOR(RhsUpLo==UpperTriangular,bool(RhsBlasTraits::NeedToConjugate)), ei_traits::Flags&RowMajorBit ? RowMajor : ColMajor> ::run( lhs.rows(), rhs.cols(), // sizes @@ -261,7 +286,7 @@ struct ei_selfadjoint_matrix_product_returntype ); } - const Lhs m_lhs; + const LhsNested m_lhs; const RhsNested m_rhs; }; diff --git a/test/product_selfadjoint.cpp b/test/product_selfadjoint.cpp index 814d542e4..44bafad93 100644 --- a/test/product_selfadjoint.cpp +++ b/test/product_selfadjoint.cpp @@ -138,6 +138,13 @@ template void symm(const MatrixType& m) m2 = m1.template triangularView(); VERIFY_IS_APPROX(rhs32 = (s1*m2.adjoint()).template selfadjointView() * (s2*rhs3).conjugate(), rhs33 = (s1*m1.adjoint()) * (s2*rhs3).conjugate()); + + // test matrix * selfadjoint + m2 = m1.template triangularView(); + VERIFY_IS_APPROX(rhs22 = (rhs2) * (m2).template selfadjointView(), + rhs23 = (rhs2) * (m1)); + VERIFY_IS_APPROX(rhs22 = (s2*rhs2) * (s1*m2).template selfadjointView(), + rhs23 = (s2*rhs2) * (s1*m1)); } void test_product_selfadjoint() {