mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-01-24 14:45:14 +08:00
addd matrix * self adjoint high level API
This commit is contained in:
parent
f696efc00e
commit
ddb3ac98a2
@ -53,8 +53,9 @@ struct ei_traits<SelfAdjointView<MatrixType, TriangularPart> > : ei_traits<Matri
|
||||
};
|
||||
};
|
||||
|
||||
template<typename Lhs,typename Rhs,bool RhsIsVector=Rhs::IsVectorAtCompileTime>
|
||||
struct ei_selfadjoint_matrix_product_returntype;
|
||||
template <typename Lhs, int LhsMode, bool LhsIsVector,
|
||||
typename Rhs, int RhsMode, bool RhsIsVector>
|
||||
struct ei_selfadjoint_product_returntype;
|
||||
|
||||
// FIXME could also be called SelfAdjointWrapper to be consistent with DiagonalWrapper ??
|
||||
template<typename MatrixType, unsigned int UpLo> class SelfAdjointView
|
||||
@ -99,10 +100,22 @@ template<typename MatrixType, unsigned int UpLo> class SelfAdjointView
|
||||
|
||||
/** Efficient self-adjoint matrix times vector/matrix product */
|
||||
template<typename OtherDerived>
|
||||
ei_selfadjoint_matrix_product_returntype<SelfAdjointView,OtherDerived>
|
||||
ei_selfadjoint_product_returntype<MatrixType,Mode,false,OtherDerived,0,OtherDerived::IsVectorAtCompileTime>
|
||||
operator*(const MatrixBase<OtherDerived>& rhs) const
|
||||
{
|
||||
return ei_selfadjoint_matrix_product_returntype<SelfAdjointView,OtherDerived>(*this, rhs.derived());
|
||||
return ei_selfadjoint_product_returntype
|
||||
<MatrixType,Mode,false,OtherDerived,0,OtherDerived::IsVectorAtCompileTime>
|
||||
(m_matrix, rhs.derived());
|
||||
}
|
||||
|
||||
/** Efficient vector/matrix times self-adjoint matrix product */
|
||||
template<typename OtherDerived> friend
|
||||
ei_selfadjoint_product_returntype<OtherDerived,0,OtherDerived::IsVectorAtCompileTime,MatrixType,Mode,false>
|
||||
operator*(const MatrixBase<OtherDerived>& lhs, const SelfAdjointView& rhs)
|
||||
{
|
||||
return ei_selfadjoint_product_returntype
|
||||
<OtherDerived,0,OtherDerived::IsVectorAtCompileTime,MatrixType,Mode,false>
|
||||
(lhs.derived(),rhs.m_matrix);
|
||||
}
|
||||
|
||||
/** Perform a symmetric rank 2 update of the selfadjoint matrix \c *this:
|
||||
@ -125,6 +138,14 @@ template<typename MatrixType, unsigned int UpLo> class SelfAdjointView
|
||||
const typename MatrixType::Nested m_matrix;
|
||||
};
|
||||
|
||||
|
||||
// template<typename OtherDerived, typename MatrixType, unsigned int UpLo>
|
||||
// ei_selfadjoint_matrix_product_returntype<OtherDerived,SelfAdjointView<MatrixType,UpLo> >
|
||||
// operator*(const MatrixBase<OtherDerived>& lhs, const SelfAdjointView<MatrixType,UpLo>& rhs)
|
||||
// {
|
||||
// return ei_matrix_selfadjoint_product_returntype<OtherDerived,SelfAdjointView<MatrixType,UpLo> >(lhs.derived(),rhs);
|
||||
// }
|
||||
|
||||
template<typename Derived1, typename Derived2, int UnrollCount, bool ClearOpposite>
|
||||
struct ei_triangular_assignment_selector<Derived1, Derived2, SelfAdjoint, UnrollCount, ClearOpposite>
|
||||
{
|
||||
@ -163,14 +184,14 @@ struct ei_triangular_assignment_selector<Derived1, Derived2, SelfAdjoint, Dynami
|
||||
* Wrapper to ei_product_selfadjoint_vector
|
||||
***************************************************************************/
|
||||
|
||||
template<typename Lhs,typename Rhs>
|
||||
struct ei_selfadjoint_matrix_product_returntype<Lhs,Rhs,true>
|
||||
: public ReturnByValue<ei_selfadjoint_matrix_product_returntype<Lhs,Rhs,true>,
|
||||
template<typename Lhs, int LhsMode, typename Rhs, int RhsMode>
|
||||
struct ei_selfadjoint_product_returntype<Lhs,LhsMode,false,Rhs,RhsMode,true>
|
||||
: public ReturnByValue<ei_selfadjoint_product_returntype<Lhs,LhsMode,false,Rhs,RhsMode,true>,
|
||||
Matrix<typename ei_traits<Rhs>::Scalar,
|
||||
Rhs::RowsAtCompileTime,Rhs::ColsAtCompileTime> >
|
||||
{
|
||||
typedef typename ei_cleantype<typename Rhs::Nested>::type RhsNested;
|
||||
ei_selfadjoint_matrix_product_returntype(const Lhs& lhs, const Rhs& rhs)
|
||||
ei_selfadjoint_product_returntype(const Lhs& lhs, const Rhs& rhs)
|
||||
: m_lhs(lhs), m_rhs(rhs)
|
||||
{}
|
||||
|
||||
@ -178,10 +199,10 @@ struct ei_selfadjoint_matrix_product_returntype<Lhs,Rhs,true>
|
||||
{
|
||||
dst.resize(m_rhs.rows(), m_rhs.cols());
|
||||
ei_product_selfadjoint_vector<typename Lhs::Scalar,ei_traits<Lhs>::Flags&RowMajorBit,
|
||||
Lhs::Mode&(UpperTriangularBit|LowerTriangularBit)>
|
||||
LhsMode&(UpperTriangularBit|LowerTriangularBit)>
|
||||
(
|
||||
m_lhs.rows(), // size
|
||||
m_lhs._expression().data(), // lhs
|
||||
m_lhs.data(), // lhs
|
||||
m_lhs.stride(), // lhsStride,
|
||||
m_rhs.data(), // rhs
|
||||
// int rhsIncr,
|
||||
@ -189,7 +210,7 @@ struct ei_selfadjoint_matrix_product_returntype<Lhs,Rhs,true>
|
||||
);
|
||||
}
|
||||
|
||||
const Lhs m_lhs;
|
||||
const typename Lhs::Nested m_lhs;
|
||||
const typename Rhs::Nested m_rhs;
|
||||
};
|
||||
|
||||
@ -197,25 +218,36 @@ struct ei_selfadjoint_matrix_product_returntype<Lhs,Rhs,true>
|
||||
* Wrapper to ei_product_selfadjoint_matrix
|
||||
***************************************************************************/
|
||||
|
||||
template<typename Lhs,typename Rhs>
|
||||
struct ei_selfadjoint_matrix_product_returntype<Lhs,Rhs,false>
|
||||
: public ReturnByValue<ei_selfadjoint_matrix_product_returntype<Lhs,Rhs,false>,
|
||||
template<typename Lhs, int LhsMode, typename Rhs, int RhsMode>
|
||||
struct ei_selfadjoint_product_returntype<Lhs,LhsMode,false,Rhs,RhsMode,false>
|
||||
: public ReturnByValue<ei_selfadjoint_product_returntype<Lhs,LhsMode,false,Rhs,RhsMode,false>,
|
||||
Matrix<typename ei_traits<Rhs>::Scalar,
|
||||
Rhs::RowsAtCompileTime,Rhs::ColsAtCompileTime> >
|
||||
Lhs::RowsAtCompileTime,Rhs::ColsAtCompileTime> >
|
||||
{
|
||||
ei_selfadjoint_matrix_product_returntype(const Lhs& lhs, const Rhs& rhs)
|
||||
ei_selfadjoint_product_returntype(const Lhs& lhs, const Rhs& rhs)
|
||||
: m_lhs(lhs), m_rhs(rhs)
|
||||
{}
|
||||
|
||||
typedef typename Lhs::Scalar Scalar;
|
||||
|
||||
typedef typename Lhs::Nested LhsNested;
|
||||
typedef typename ei_cleantype<LhsNested>::type _LhsNested;
|
||||
typedef ei_blas_traits<_LhsNested> LhsBlasTraits;
|
||||
typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
|
||||
typedef typename ei_cleantype<ActualLhsType>::type _ActualLhsType;
|
||||
|
||||
typedef typename Rhs::Nested RhsNested;
|
||||
typedef typename ei_cleantype<RhsNested>::type _RhsNested;
|
||||
typedef ei_blas_traits<_RhsNested> RhsBlasTraits;
|
||||
typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
|
||||
typedef typename ei_cleantype<ActualRhsType>::type _ActualRhsType;
|
||||
|
||||
typedef typename ei_traits<Lhs>::ExpressionType LhsExpr;
|
||||
typedef typename LhsExpr::Nested LhsNested;
|
||||
typedef typename ei_cleantype<LhsNested>::type _LhsNested;
|
||||
|
||||
enum { UpLo = ei_traits<Lhs>::Mode&(UpperTriangularBit|LowerTriangularBit) };
|
||||
enum {
|
||||
LhsUpLo = LhsMode&(UpperTriangularBit|LowerTriangularBit),
|
||||
LhsIsSelfAdjoint = (LhsMode&SelfAdjointBit)==SelfAdjointBit,
|
||||
RhsUpLo = RhsMode&(UpperTriangularBit|LowerTriangularBit),
|
||||
RhsIsSelfAdjoint = (RhsMode&SelfAdjointBit)==SelfAdjointBit
|
||||
};
|
||||
|
||||
template<typename Dest> inline void _addTo(Dest& dst) const
|
||||
{ evalTo(dst,1); }
|
||||
@ -231,26 +263,19 @@ struct ei_selfadjoint_matrix_product_returntype<Lhs,Rhs,false>
|
||||
|
||||
template<typename Dest> void evalTo(Dest& dst, Scalar alpha) const
|
||||
{
|
||||
typedef ei_blas_traits<_LhsNested> LhsBlasTraits;
|
||||
typedef ei_blas_traits<_RhsNested> RhsBlasTraits;
|
||||
|
||||
typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
|
||||
typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
|
||||
|
||||
typedef typename ei_cleantype<ActualLhsType>::type _ActualLhsType;
|
||||
typedef typename ei_cleantype<ActualRhsType>::type _ActualRhsType;
|
||||
|
||||
const ActualLhsType lhs = LhsBlasTraits::extract(m_lhs._expression());
|
||||
const ActualLhsType lhs = LhsBlasTraits::extract(m_lhs);
|
||||
const ActualRhsType rhs = RhsBlasTraits::extract(m_rhs);
|
||||
|
||||
Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(m_lhs._expression())
|
||||
Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(m_lhs)
|
||||
* RhsBlasTraits::extractScalarFactor(m_rhs);
|
||||
|
||||
ei_product_selfadjoint_matrix<Scalar,
|
||||
EIGEN_LOGICAL_XOR(UpLo==UpperTriangular,
|
||||
ei_traits<Lhs>::Flags &RowMajorBit) ? RowMajor : ColMajor, true,
|
||||
NumTraits<Scalar>::IsComplex && EIGEN_LOGICAL_XOR(UpLo==UpperTriangular,bool(LhsBlasTraits::NeedToConjugate)),
|
||||
ei_traits<Rhs>::Flags &RowMajorBit ? RowMajor : ColMajor, false, bool(RhsBlasTraits::NeedToConjugate),
|
||||
EIGEN_LOGICAL_XOR(LhsUpLo==UpperTriangular,
|
||||
ei_traits<Lhs>::Flags &RowMajorBit) ? RowMajor : ColMajor, LhsIsSelfAdjoint,
|
||||
NumTraits<Scalar>::IsComplex && EIGEN_LOGICAL_XOR(LhsUpLo==UpperTriangular,bool(LhsBlasTraits::NeedToConjugate)),
|
||||
EIGEN_LOGICAL_XOR(RhsUpLo==UpperTriangular,
|
||||
ei_traits<Rhs>::Flags &RowMajorBit) ? RowMajor : ColMajor, RhsIsSelfAdjoint,
|
||||
NumTraits<Scalar>::IsComplex && EIGEN_LOGICAL_XOR(RhsUpLo==UpperTriangular,bool(RhsBlasTraits::NeedToConjugate)),
|
||||
ei_traits<Dest>::Flags&RowMajorBit ? RowMajor : ColMajor>
|
||||
::run(
|
||||
lhs.rows(), rhs.cols(), // sizes
|
||||
@ -261,7 +286,7 @@ struct ei_selfadjoint_matrix_product_returntype<Lhs,Rhs,false>
|
||||
);
|
||||
}
|
||||
|
||||
const Lhs m_lhs;
|
||||
const LhsNested m_lhs;
|
||||
const RhsNested m_rhs;
|
||||
};
|
||||
|
||||
|
@ -138,6 +138,13 @@ template<typename MatrixType> void symm(const MatrixType& m)
|
||||
m2 = m1.template triangularView<UpperTriangular>();
|
||||
VERIFY_IS_APPROX(rhs32 = (s1*m2.adjoint()).template selfadjointView<LowerTriangular>() * (s2*rhs3).conjugate(),
|
||||
rhs33 = (s1*m1.adjoint()) * (s2*rhs3).conjugate());
|
||||
|
||||
// test matrix * selfadjoint
|
||||
m2 = m1.template triangularView<LowerTriangular>();
|
||||
VERIFY_IS_APPROX(rhs22 = (rhs2) * (m2).template selfadjointView<LowerTriangular>(),
|
||||
rhs23 = (rhs2) * (m1));
|
||||
VERIFY_IS_APPROX(rhs22 = (s2*rhs2) * (s1*m2).template selfadjointView<LowerTriangular>(),
|
||||
rhs23 = (s2*rhs2) * (s1*m1));
|
||||
}
|
||||
void test_product_selfadjoint()
|
||||
{
|
||||
|
Loading…
Reference in New Issue
Block a user