mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-01-18 14:34:17 +08:00
implement optimized path for selfadjoint rank 1 update (safe regarding dynamic alloc)
This commit is contained in:
parent
3eb74cf9fc
commit
a486d5590a
@ -28,9 +28,109 @@
|
||||
/**********************************************************************
* This file implements a self adjoint product: C += A A^T updating only
* half of the selfadjoint matrix C.
* It corresponds to the level 3 SYRK and level 2 SYR Blas routines.
**********************************************************************/
|
||||
|
||||
template<typename Scalar, typename Index, int StorageOrder, int UpLo, bool ConjLhs, bool ConjRhs>
|
||||
struct selfadjoint_rank1_update;
|
||||
|
||||
template<typename Scalar, typename Index, int UpLo, bool ConjLhs, bool ConjRhs>
|
||||
struct selfadjoint_rank1_update<Scalar,Index,ColMajor,UpLo,ConjLhs,ConjRhs>
|
||||
{
|
||||
static void run(Index size, Scalar* mat, Index stride, const Scalar* vec, Scalar alpha)
|
||||
{
|
||||
internal::conj_if<ConjRhs> cj;
|
||||
typedef Map<const Matrix<Scalar,Dynamic,1> > OtherMap;
|
||||
typedef typename internal::conditional<ConjLhs,typename OtherMap::ConjugateReturnType,const OtherMap&>::type ConjRhsType;
|
||||
for (Index i=0; i<size; ++i)
|
||||
{
|
||||
Map<Matrix<Scalar,Dynamic,1> >(mat+stride*i+(UpLo==Lower ? i : 0), (UpLo==Lower ? size-i : (i+1)))
|
||||
+= (alpha * cj(vec[i])) * ConjRhsType(OtherMap(vec+(UpLo==Lower ? i : 0),UpLo==Lower ? size-i : (i+1)));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template<typename Scalar, typename Index, int UpLo, bool ConjLhs, bool ConjRhs>
|
||||
struct selfadjoint_rank1_update<Scalar,Index,RowMajor,UpLo,ConjLhs,ConjRhs>
|
||||
{
|
||||
static void run(Index size, Scalar* mat, Index stride, const Scalar* vec, Scalar alpha)
|
||||
{
|
||||
selfadjoint_rank1_update<Scalar,Index,ColMajor,UpLo==Lower?Upper:Lower,ConjRhs,ConjLhs>::run(size,mat,stride,vec,alpha);
|
||||
}
|
||||
};
|
||||
|
||||
template<typename MatrixType, typename OtherType, int UpLo, bool OtherIsVector = OtherType::IsVectorAtCompileTime>
|
||||
struct selfadjoint_product_selector;
|
||||
|
||||
template<typename MatrixType, typename OtherType, int UpLo>
|
||||
struct selfadjoint_product_selector<MatrixType,OtherType,UpLo,true>
|
||||
{
|
||||
static void run(MatrixType& mat, const OtherType& other, typename MatrixType::Scalar alpha)
|
||||
{
|
||||
typedef typename MatrixType::Scalar Scalar;
|
||||
typedef typename MatrixType::Index Index;
|
||||
typedef internal::blas_traits<OtherType> OtherBlasTraits;
|
||||
typedef typename OtherBlasTraits::DirectLinearAccessType ActualOtherType;
|
||||
typedef typename internal::remove_all<ActualOtherType>::type _ActualOtherType;
|
||||
const ActualOtherType actualOther = OtherBlasTraits::extract(other.derived());
|
||||
|
||||
Scalar actualAlpha = alpha * OtherBlasTraits::extractScalarFactor(other.derived());
|
||||
|
||||
enum {
|
||||
StorageOrder = (internal::traits<MatrixType>::Flags&RowMajorBit) ? RowMajor : ColMajor,
|
||||
UseOtherDirectly = _ActualOtherType::InnerStrideAtCompileTime==1
|
||||
};
|
||||
internal::gemv_static_vector_if<Scalar,OtherType::SizeAtCompileTime,OtherType::MaxSizeAtCompileTime,!UseOtherDirectly> static_other;
|
||||
|
||||
bool freeOtherPtr = false;
|
||||
Scalar* actualOtherPtr;
|
||||
if(UseOtherDirectly)
|
||||
actualOtherPtr = const_cast<Scalar*>(actualOther.data());
|
||||
else
|
||||
{
|
||||
if((actualOtherPtr=static_other.data())==0)
|
||||
{
|
||||
freeOtherPtr = true;
|
||||
actualOtherPtr = ei_aligned_stack_new(Scalar,other.size());
|
||||
}
|
||||
Map<typename _ActualOtherType::PlainObject>(actualOtherPtr, actualOther.size()) = actualOther;
|
||||
}
|
||||
|
||||
selfadjoint_rank1_update<Scalar,Index,StorageOrder,UpLo,
|
||||
OtherBlasTraits::NeedToConjugate && NumTraits<Scalar>::IsComplex,
|
||||
(!OtherBlasTraits::NeedToConjugate) && NumTraits<Scalar>::IsComplex>
|
||||
::run(other.size(), mat.data(), mat.outerStride(), actualOtherPtr, actualAlpha);
|
||||
|
||||
if((!UseOtherDirectly) && freeOtherPtr) ei_aligned_stack_delete(Scalar, actualOtherPtr, other.size());
|
||||
}
|
||||
};
|
||||
|
||||
template<typename MatrixType, typename OtherType, int UpLo>
|
||||
struct selfadjoint_product_selector<MatrixType,OtherType,UpLo,false>
|
||||
{
|
||||
static void run(MatrixType& mat, const OtherType& other, typename MatrixType::Scalar alpha)
|
||||
{
|
||||
typedef typename MatrixType::Scalar Scalar;
|
||||
typedef typename MatrixType::Index Index;
|
||||
typedef internal::blas_traits<OtherType> OtherBlasTraits;
|
||||
typedef typename OtherBlasTraits::DirectLinearAccessType ActualOtherType;
|
||||
typedef typename internal::remove_all<ActualOtherType>::type _ActualOtherType;
|
||||
const ActualOtherType actualOther = OtherBlasTraits::extract(other.derived());
|
||||
|
||||
Scalar actualAlpha = alpha * OtherBlasTraits::extractScalarFactor(other.derived());
|
||||
|
||||
enum { IsRowMajor = (internal::traits<MatrixType>::Flags&RowMajorBit) ? 1 : 0 };
|
||||
|
||||
internal::general_matrix_matrix_triangular_product<Index,
|
||||
Scalar, _ActualOtherType::Flags&RowMajorBit ? RowMajor : ColMajor, OtherBlasTraits::NeedToConjugate && NumTraits<Scalar>::IsComplex,
|
||||
Scalar, _ActualOtherType::Flags&RowMajorBit ? ColMajor : RowMajor, (!OtherBlasTraits::NeedToConjugate) && NumTraits<Scalar>::IsComplex,
|
||||
MatrixType::Flags&RowMajorBit ? RowMajor : ColMajor, UpLo>
|
||||
::run(mat.cols(), actualOther.cols(),
|
||||
&actualOther.coeffRef(0,0), actualOther.outerStride(), &actualOther.coeffRef(0,0), actualOther.outerStride(),
|
||||
mat.data(), mat.outerStride(), actualAlpha);
|
||||
}
|
||||
};
|
||||
|
||||
// high level API
|
||||
|
||||
template<typename MatrixType, unsigned int UpLo>
|
||||
@ -38,22 +138,7 @@ template<typename DerivedU>
|
||||
SelfAdjointView<MatrixType,UpLo>& SelfAdjointView<MatrixType,UpLo>
|
||||
::rankUpdate(const MatrixBase<DerivedU>& u, Scalar alpha)
|
||||
{
|
||||
typedef internal::blas_traits<DerivedU> UBlasTraits;
|
||||
typedef typename UBlasTraits::DirectLinearAccessType ActualUType;
|
||||
typedef typename internal::remove_all<ActualUType>::type _ActualUType;
|
||||
const ActualUType actualU = UBlasTraits::extract(u.derived());
|
||||
|
||||
Scalar actualAlpha = alpha * UBlasTraits::extractScalarFactor(u.derived());
|
||||
|
||||
enum { IsRowMajor = (internal::traits<MatrixType>::Flags&RowMajorBit) ? 1 : 0 };
|
||||
|
||||
internal::general_matrix_matrix_triangular_product<Index,
|
||||
Scalar, _ActualUType::Flags&RowMajorBit ? RowMajor : ColMajor, UBlasTraits::NeedToConjugate && NumTraits<Scalar>::IsComplex,
|
||||
Scalar, _ActualUType::Flags&RowMajorBit ? ColMajor : RowMajor, (!UBlasTraits::NeedToConjugate) && NumTraits<Scalar>::IsComplex,
|
||||
MatrixType::Flags&RowMajorBit ? RowMajor : ColMajor, UpLo>
|
||||
::run(_expression().cols(), actualU.cols(),
|
||||
&actualU.coeffRef(0,0), actualU.outerStride(), &actualU.coeffRef(0,0), actualU.outerStride(),
|
||||
_expression().const_cast_derived().data(), _expression().outerStride(), actualAlpha);
|
||||
selfadjoint_product_selector<MatrixType,DerivedU,UpLo>::run(_expression().const_cast_derived(), u.derived(), alpha);
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
@ -44,6 +44,8 @@ template<typename MatrixType> void syrk(const MatrixType& m)
|
||||
Rhs3 rhs3 = Rhs3::Random(internal::random<int>(1,320), rows);
|
||||
|
||||
Scalar s1 = internal::random<Scalar>();
|
||||
|
||||
Index c = internal::random<Index>(0,cols-1);
|
||||
|
||||
m2.setZero();
|
||||
VERIFY_IS_APPROX((m2.template selfadjointView<Lower>().rankUpdate(rhs2,s1)._expression()),
|
||||
@ -68,6 +70,30 @@ template<typename MatrixType> void syrk(const MatrixType& m)
|
||||
m2.setZero();
|
||||
VERIFY_IS_APPROX(m2.template selfadjointView<Upper>().rankUpdate(rhs3.adjoint(),s1)._expression(),
|
||||
(s1 * rhs3.adjoint() * rhs3).eval().template triangularView<Upper>().toDenseMatrix());
|
||||
|
||||
m2.setZero();
|
||||
VERIFY_IS_APPROX((m2.template selfadjointView<Lower>().rankUpdate(m1.col(c),s1)._expression()),
|
||||
((s1 * m1.col(c) * m1.col(c).adjoint()).eval().template triangularView<Lower>().toDenseMatrix()));
|
||||
|
||||
m2.setZero();
|
||||
VERIFY_IS_APPROX((m2.template selfadjointView<Upper>().rankUpdate(m1.col(c),s1)._expression()),
|
||||
((s1 * m1.col(c) * m1.col(c).adjoint()).eval().template triangularView<Upper>().toDenseMatrix()));
|
||||
|
||||
m2.setZero();
|
||||
VERIFY_IS_APPROX((m2.template selfadjointView<Lower>().rankUpdate(m1.col(c).conjugate(),s1)._expression()),
|
||||
((s1 * m1.col(c).conjugate() * m1.col(c).conjugate().adjoint()).eval().template triangularView<Lower>().toDenseMatrix()));
|
||||
|
||||
m2.setZero();
|
||||
VERIFY_IS_APPROX((m2.template selfadjointView<Upper>().rankUpdate(m1.col(c).conjugate(),s1)._expression()),
|
||||
((s1 * m1.col(c).conjugate() * m1.col(c).conjugate().adjoint()).eval().template triangularView<Upper>().toDenseMatrix()));
|
||||
|
||||
m2.setZero();
|
||||
VERIFY_IS_APPROX((m2.template selfadjointView<Lower>().rankUpdate(m1.row(c),s1)._expression()),
|
||||
((s1 * m1.row(c).transpose() * m1.row(c).transpose().adjoint()).eval().template triangularView<Lower>().toDenseMatrix()));
|
||||
|
||||
m2.setZero();
|
||||
VERIFY_IS_APPROX((m2.template selfadjointView<Upper>().rankUpdate(m1.row(c).adjoint(),s1)._expression()),
|
||||
((s1 * m1.row(c).adjoint() * m1.row(c).adjoint().adjoint()).eval().template triangularView<Upper>().toDenseMatrix()));
|
||||
}
|
||||
|
||||
void test_product_syrk()
|
||||
|
Loading…
Reference in New Issue
Block a user