addd matrix * self adjoint high level API

This commit is contained in:
Gael Guennebaud 2009-07-23 10:05:38 +02:00
parent f696efc00e
commit ddb3ac98a2
2 changed files with 69 additions and 37 deletions

View File

@ -53,8 +53,9 @@ struct ei_traits<SelfAdjointView<MatrixType, TriangularPart> > : ei_traits<Matri
};
};
template<typename Lhs,typename Rhs,bool RhsIsVector=Rhs::IsVectorAtCompileTime>
struct ei_selfadjoint_matrix_product_returntype;
template <typename Lhs, int LhsMode, bool LhsIsVector,
typename Rhs, int RhsMode, bool RhsIsVector>
struct ei_selfadjoint_product_returntype;
// FIXME could also be called SelfAdjointWrapper to be consistent with DiagonalWrapper ??
template<typename MatrixType, unsigned int UpLo> class SelfAdjointView
@ -99,10 +100,22 @@ template<typename MatrixType, unsigned int UpLo> class SelfAdjointView
/** Efficient self-adjoint matrix times vector/matrix product */
template<typename OtherDerived>
ei_selfadjoint_matrix_product_returntype<SelfAdjointView,OtherDerived>
ei_selfadjoint_product_returntype<MatrixType,Mode,false,OtherDerived,0,OtherDerived::IsVectorAtCompileTime>
operator*(const MatrixBase<OtherDerived>& rhs) const
{
return ei_selfadjoint_matrix_product_returntype<SelfAdjointView,OtherDerived>(*this, rhs.derived());
return ei_selfadjoint_product_returntype
<MatrixType,Mode,false,OtherDerived,0,OtherDerived::IsVectorAtCompileTime>
(m_matrix, rhs.derived());
}
/** Efficient vector/matrix times self-adjoint matrix product */
template<typename OtherDerived> friend
ei_selfadjoint_product_returntype<OtherDerived,0,OtherDerived::IsVectorAtCompileTime,MatrixType,Mode,false>
operator*(const MatrixBase<OtherDerived>& lhs, const SelfAdjointView& rhs)
{
return ei_selfadjoint_product_returntype
<OtherDerived,0,OtherDerived::IsVectorAtCompileTime,MatrixType,Mode,false>
(lhs.derived(),rhs.m_matrix);
}
/** Perform a symmetric rank 2 update of the selfadjoint matrix \c *this:
@ -125,6 +138,14 @@ template<typename MatrixType, unsigned int UpLo> class SelfAdjointView
const typename MatrixType::Nested m_matrix;
};
// template<typename OtherDerived, typename MatrixType, unsigned int UpLo>
// ei_selfadjoint_matrix_product_returntype<OtherDerived,SelfAdjointView<MatrixType,UpLo> >
// operator*(const MatrixBase<OtherDerived>& lhs, const SelfAdjointView<MatrixType,UpLo>& rhs)
// {
// return ei_matrix_selfadjoint_product_returntype<OtherDerived,SelfAdjointView<MatrixType,UpLo> >(lhs.derived(),rhs);
// }
template<typename Derived1, typename Derived2, int UnrollCount, bool ClearOpposite>
struct ei_triangular_assignment_selector<Derived1, Derived2, SelfAdjoint, UnrollCount, ClearOpposite>
{
@ -163,14 +184,14 @@ struct ei_triangular_assignment_selector<Derived1, Derived2, SelfAdjoint, Dynami
* Wrapper to ei_product_selfadjoint_vector
***************************************************************************/
template<typename Lhs,typename Rhs>
struct ei_selfadjoint_matrix_product_returntype<Lhs,Rhs,true>
: public ReturnByValue<ei_selfadjoint_matrix_product_returntype<Lhs,Rhs,true>,
template<typename Lhs, int LhsMode, typename Rhs, int RhsMode>
struct ei_selfadjoint_product_returntype<Lhs,LhsMode,false,Rhs,RhsMode,true>
: public ReturnByValue<ei_selfadjoint_product_returntype<Lhs,LhsMode,false,Rhs,RhsMode,true>,
Matrix<typename ei_traits<Rhs>::Scalar,
Rhs::RowsAtCompileTime,Rhs::ColsAtCompileTime> >
{
typedef typename ei_cleantype<typename Rhs::Nested>::type RhsNested;
ei_selfadjoint_matrix_product_returntype(const Lhs& lhs, const Rhs& rhs)
ei_selfadjoint_product_returntype(const Lhs& lhs, const Rhs& rhs)
: m_lhs(lhs), m_rhs(rhs)
{}
@ -178,10 +199,10 @@ struct ei_selfadjoint_matrix_product_returntype<Lhs,Rhs,true>
{
dst.resize(m_rhs.rows(), m_rhs.cols());
ei_product_selfadjoint_vector<typename Lhs::Scalar,ei_traits<Lhs>::Flags&RowMajorBit,
Lhs::Mode&(UpperTriangularBit|LowerTriangularBit)>
LhsMode&(UpperTriangularBit|LowerTriangularBit)>
(
m_lhs.rows(), // size
m_lhs._expression().data(), // lhs
m_lhs.data(), // lhs
m_lhs.stride(), // lhsStride,
m_rhs.data(), // rhs
// int rhsIncr,
@ -189,7 +210,7 @@ struct ei_selfadjoint_matrix_product_returntype<Lhs,Rhs,true>
);
}
const Lhs m_lhs;
const typename Lhs::Nested m_lhs;
const typename Rhs::Nested m_rhs;
};
@ -197,25 +218,36 @@ struct ei_selfadjoint_matrix_product_returntype<Lhs,Rhs,true>
* Wrapper to ei_product_selfadjoint_matrix
***************************************************************************/
template<typename Lhs,typename Rhs>
struct ei_selfadjoint_matrix_product_returntype<Lhs,Rhs,false>
: public ReturnByValue<ei_selfadjoint_matrix_product_returntype<Lhs,Rhs,false>,
template<typename Lhs, int LhsMode, typename Rhs, int RhsMode>
struct ei_selfadjoint_product_returntype<Lhs,LhsMode,false,Rhs,RhsMode,false>
: public ReturnByValue<ei_selfadjoint_product_returntype<Lhs,LhsMode,false,Rhs,RhsMode,false>,
Matrix<typename ei_traits<Rhs>::Scalar,
Rhs::RowsAtCompileTime,Rhs::ColsAtCompileTime> >
Lhs::RowsAtCompileTime,Rhs::ColsAtCompileTime> >
{
ei_selfadjoint_matrix_product_returntype(const Lhs& lhs, const Rhs& rhs)
ei_selfadjoint_product_returntype(const Lhs& lhs, const Rhs& rhs)
: m_lhs(lhs), m_rhs(rhs)
{}
typedef typename Lhs::Scalar Scalar;
typedef typename Lhs::Nested LhsNested;
typedef typename ei_cleantype<LhsNested>::type _LhsNested;
typedef ei_blas_traits<_LhsNested> LhsBlasTraits;
typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
typedef typename ei_cleantype<ActualLhsType>::type _ActualLhsType;
typedef typename Rhs::Nested RhsNested;
typedef typename ei_cleantype<RhsNested>::type _RhsNested;
typedef ei_blas_traits<_RhsNested> RhsBlasTraits;
typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
typedef typename ei_cleantype<ActualRhsType>::type _ActualRhsType;
typedef typename ei_traits<Lhs>::ExpressionType LhsExpr;
typedef typename LhsExpr::Nested LhsNested;
typedef typename ei_cleantype<LhsNested>::type _LhsNested;
enum { UpLo = ei_traits<Lhs>::Mode&(UpperTriangularBit|LowerTriangularBit) };
enum {
LhsUpLo = LhsMode&(UpperTriangularBit|LowerTriangularBit),
LhsIsSelfAdjoint = (LhsMode&SelfAdjointBit)==SelfAdjointBit,
RhsUpLo = RhsMode&(UpperTriangularBit|LowerTriangularBit),
RhsIsSelfAdjoint = (RhsMode&SelfAdjointBit)==SelfAdjointBit
};
template<typename Dest> inline void _addTo(Dest& dst) const
{ evalTo(dst,1); }
@ -231,26 +263,19 @@ struct ei_selfadjoint_matrix_product_returntype<Lhs,Rhs,false>
template<typename Dest> void evalTo(Dest& dst, Scalar alpha) const
{
typedef ei_blas_traits<_LhsNested> LhsBlasTraits;
typedef ei_blas_traits<_RhsNested> RhsBlasTraits;
typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
typedef typename ei_cleantype<ActualLhsType>::type _ActualLhsType;
typedef typename ei_cleantype<ActualRhsType>::type _ActualRhsType;
const ActualLhsType lhs = LhsBlasTraits::extract(m_lhs._expression());
const ActualLhsType lhs = LhsBlasTraits::extract(m_lhs);
const ActualRhsType rhs = RhsBlasTraits::extract(m_rhs);
Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(m_lhs._expression())
Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(m_lhs)
* RhsBlasTraits::extractScalarFactor(m_rhs);
ei_product_selfadjoint_matrix<Scalar,
EIGEN_LOGICAL_XOR(UpLo==UpperTriangular,
ei_traits<Lhs>::Flags &RowMajorBit) ? RowMajor : ColMajor, true,
NumTraits<Scalar>::IsComplex && EIGEN_LOGICAL_XOR(UpLo==UpperTriangular,bool(LhsBlasTraits::NeedToConjugate)),
ei_traits<Rhs>::Flags &RowMajorBit ? RowMajor : ColMajor, false, bool(RhsBlasTraits::NeedToConjugate),
EIGEN_LOGICAL_XOR(LhsUpLo==UpperTriangular,
ei_traits<Lhs>::Flags &RowMajorBit) ? RowMajor : ColMajor, LhsIsSelfAdjoint,
NumTraits<Scalar>::IsComplex && EIGEN_LOGICAL_XOR(LhsUpLo==UpperTriangular,bool(LhsBlasTraits::NeedToConjugate)),
EIGEN_LOGICAL_XOR(RhsUpLo==UpperTriangular,
ei_traits<Rhs>::Flags &RowMajorBit) ? RowMajor : ColMajor, RhsIsSelfAdjoint,
NumTraits<Scalar>::IsComplex && EIGEN_LOGICAL_XOR(RhsUpLo==UpperTriangular,bool(RhsBlasTraits::NeedToConjugate)),
ei_traits<Dest>::Flags&RowMajorBit ? RowMajor : ColMajor>
::run(
lhs.rows(), rhs.cols(), // sizes
@ -261,7 +286,7 @@ struct ei_selfadjoint_matrix_product_returntype<Lhs,Rhs,false>
);
}
const Lhs m_lhs;
const LhsNested m_lhs;
const RhsNested m_rhs;
};

View File

@ -138,6 +138,13 @@ template<typename MatrixType> void symm(const MatrixType& m)
m2 = m1.template triangularView<UpperTriangular>();
VERIFY_IS_APPROX(rhs32 = (s1*m2.adjoint()).template selfadjointView<LowerTriangular>() * (s2*rhs3).conjugate(),
rhs33 = (s1*m1.adjoint()) * (s2*rhs3).conjugate());
// test matrix * selfadjoint
m2 = m1.template triangularView<LowerTriangular>();
VERIFY_IS_APPROX(rhs22 = (rhs2) * (m2).template selfadjointView<LowerTriangular>(),
rhs23 = (rhs2) * (m1));
VERIFY_IS_APPROX(rhs22 = (s2*rhs2) * (s1*m2).template selfadjointView<LowerTriangular>(),
rhs23 = (s2*rhs2) * (s1*m1));
}
void test_product_selfadjoint()
{