mirror of
https://gitlab.com/libeigen/eigen.git
synced 2024-12-21 07:19:46 +08:00
Currently, the binding of LLT to Lapacke is done using a large macro. This factors out a large part of the functionality of the macro and implement them explicitly.
This commit is contained in:
parent
ec4efbd696
commit
b8b6566f0f
@ -39,60 +39,111 @@ namespace Eigen {
|
||||
|
||||
namespace internal {
|
||||
|
||||
template<typename Scalar> struct lapacke_llt;
|
||||
namespace lapacke_llt_helpers {
|
||||
|
||||
#define EIGEN_LAPACKE_LLT(EIGTYPE, BLASTYPE, LAPACKE_PREFIX) \
|
||||
template<> struct lapacke_llt<EIGTYPE> \
|
||||
{ \
|
||||
template<typename MatrixType> \
|
||||
static inline Index potrf(MatrixType& m, char uplo) \
|
||||
{ \
|
||||
lapack_int matrix_order; \
|
||||
lapack_int size, lda, info, StorageOrder; \
|
||||
EIGTYPE* a; \
|
||||
eigen_assert(m.rows()==m.cols()); \
|
||||
/* Set up parameters for ?potrf */ \
|
||||
size = convert_index<lapack_int>(m.rows()); \
|
||||
StorageOrder = MatrixType::Flags&RowMajorBit?RowMajor:ColMajor; \
|
||||
matrix_order = StorageOrder==RowMajor ? LAPACK_ROW_MAJOR : LAPACK_COL_MAJOR; \
|
||||
a = &(m.coeffRef(0,0)); \
|
||||
lda = convert_index<lapack_int>(m.outerStride()); \
|
||||
\
|
||||
info = LAPACKE_##LAPACKE_PREFIX##potrf( matrix_order, uplo, size, (BLASTYPE*)a, lda ); \
|
||||
info = (info==0) ? -1 : info>0 ? info-1 : size; \
|
||||
return info; \
|
||||
} \
|
||||
}; \
|
||||
template<> struct llt_inplace<EIGTYPE, Lower> \
|
||||
{ \
|
||||
template<typename MatrixType> \
|
||||
static Index blocked(MatrixType& m) \
|
||||
{ \
|
||||
return lapacke_llt<EIGTYPE>::potrf(m, 'L'); \
|
||||
} \
|
||||
template<typename MatrixType, typename VectorType> \
|
||||
static Index rankUpdate(MatrixType& mat, const VectorType& vec, const typename MatrixType::RealScalar& sigma) \
|
||||
{ return Eigen::internal::llt_rank_update_lower(mat, vec, sigma); } \
|
||||
}; \
|
||||
template<> struct llt_inplace<EIGTYPE, Upper> \
|
||||
{ \
|
||||
template<typename MatrixType> \
|
||||
static Index blocked(MatrixType& m) \
|
||||
{ \
|
||||
return lapacke_llt<EIGTYPE>::potrf(m, 'U'); \
|
||||
} \
|
||||
template<typename MatrixType, typename VectorType> \
|
||||
static Index rankUpdate(MatrixType& mat, const VectorType& vec, const typename MatrixType::RealScalar& sigma) \
|
||||
{ \
|
||||
Transpose<MatrixType> matt(mat); \
|
||||
return llt_inplace<EIGTYPE, Lower>::rankUpdate(matt, vec.conjugate(), sigma); \
|
||||
} \
|
||||
};
|
||||
// -------------------------------------------------------------------------------------------------------------------
|
||||
// Translation from Eigen to Lapacke types
|
||||
// -------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
EIGEN_LAPACKE_LLT(double, double, d)
|
||||
EIGEN_LAPACKE_LLT(float, float, s)
|
||||
EIGEN_LAPACKE_LLT(dcomplex, lapack_complex_double, z)
|
||||
EIGEN_LAPACKE_LLT(scomplex, lapack_complex_float, c)
|
||||
// For complex numbers, the types in Eigen and Lapacke are different, but layout compatible.
|
||||
template<typename Scalar> struct translate_type;
|
||||
template<> struct translate_type<float> { using type = float; };
|
||||
template<> struct translate_type<double> { using type = double; };
|
||||
template<> struct translate_type<dcomplex> { using type = lapack_complex_double; };
|
||||
template<> struct translate_type<scomplex> { using type = lapack_complex_float; };
|
||||
|
||||
// -------------------------------------------------------------------------------------------------------------------
|
||||
// Dispatch for potrf handling double, float, complex double, complex float types
|
||||
// -------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
inline lapack_int potrf(lapack_int matrix_order, char uplo, lapack_int size, double* a, lapack_int lda) {
|
||||
return LAPACKE_dpotrf( matrix_order, uplo, size, a, lda );
|
||||
}
|
||||
|
||||
inline lapack_int potrf(lapack_int matrix_order, char uplo, lapack_int size, float* a, lapack_int lda) {
|
||||
return LAPACKE_spotrf( matrix_order, uplo, size, a, lda );
|
||||
}
|
||||
|
||||
inline lapack_int potrf(lapack_int matrix_order, char uplo, lapack_int size, lapack_complex_double* a, lapack_int lda) {
|
||||
return LAPACKE_zpotrf( matrix_order, uplo, size, a, lda );
|
||||
}
|
||||
|
||||
inline lapack_int potrf(lapack_int matrix_order, char uplo, lapack_int size, lapack_complex_float* a, lapack_int lda) {
|
||||
return LAPACKE_cpotrf( matrix_order, uplo, size, a, lda );
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------------------------------------------------------
|
||||
// Dispatch for rank update handling upper and lower parts
|
||||
// -------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
template<unsigned Mode>
|
||||
struct rank_update {};
|
||||
|
||||
template<>
|
||||
struct rank_update<Lower> {
|
||||
template<typename MatrixType, typename VectorType>
|
||||
static Index run(MatrixType &mat, const VectorType &vec, const typename MatrixType::RealScalar &sigma) {
|
||||
return Eigen::internal::llt_rank_update_lower(mat, vec, sigma);
|
||||
}
|
||||
};
|
||||
|
||||
template<>
|
||||
struct rank_update<Upper> {
|
||||
template<typename MatrixType, typename VectorType>
|
||||
static Index run(MatrixType &mat, const VectorType &vec, const typename MatrixType::RealScalar &sigma) {
|
||||
Transpose<MatrixType> matt(mat);
|
||||
return Eigen::internal::llt_rank_update_lower(matt, vec.conjugate(), sigma);
|
||||
}
|
||||
};
|
||||
|
||||
// -------------------------------------------------------------------------------------------------------------------
|
||||
// Generic lapacke llt implementation that hands of to the dispatches
|
||||
// -------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
template<typename Scalar, unsigned Mode>
|
||||
struct lapacke_llt {
|
||||
using BlasType = typename translate_type<Scalar>::type;
|
||||
template<typename MatrixType>
|
||||
static Index blocked(MatrixType& m)
|
||||
{
|
||||
eigen_assert(m.rows()==m.cols());
|
||||
/* Set up parameters for ?potrf */
|
||||
lapack_int size = convert_index<lapack_int>(m.rows());
|
||||
lapack_int StorageOrder = MatrixType::Flags&RowMajorBit?RowMajor:ColMajor;
|
||||
lapack_int matrix_order = StorageOrder==RowMajor ? LAPACK_ROW_MAJOR : LAPACK_COL_MAJOR;
|
||||
Scalar* a = &(m.coeffRef(0,0));
|
||||
lapack_int lda = convert_index<lapack_int>(m.outerStride());
|
||||
|
||||
lapack_int info = potrf( matrix_order, Mode == Lower ? 'L' : 'U', size, (BlasType*)a, lda );
|
||||
info = (info==0) ? -1 : info>0 ? info-1 : size;
|
||||
return info;
|
||||
}
|
||||
|
||||
template<typename MatrixType, typename VectorType>
|
||||
static Index rankUpdate(MatrixType& mat, const VectorType& vec, const typename MatrixType::RealScalar& sigma)
|
||||
{
|
||||
return rank_update<Mode>::run(mat, vec, sigma);
|
||||
}
|
||||
};
|
||||
}
|
||||
// end namespace lapacke_llt_helpers
|
||||
|
||||
/*
|
||||
* Here, we just put the generic implementation from lapacke_llt into a full specialization of the llt_inplace
|
||||
* type. By being a full specialization, the versions defined here thus get precedence over the generic implementation
|
||||
* in LLT.h for double, float and complex double, complex float types.
|
||||
*/
|
||||
|
||||
#define EIGEN_LAPACKE_LLT(EIGTYPE) \
|
||||
template<> struct llt_inplace<EIGTYPE, Lower> : public lapacke_llt_helpers::lapacke_llt<EIGTYPE, Lower> {}; \
|
||||
template<> struct llt_inplace<EIGTYPE, Upper> : public lapacke_llt_helpers::lapacke_llt<EIGTYPE, Upper> {};
|
||||
|
||||
EIGEN_LAPACKE_LLT(double)
|
||||
EIGEN_LAPACKE_LLT(float)
|
||||
EIGEN_LAPACKE_LLT(dcomplex)
|
||||
EIGEN_LAPACKE_LLT(scomplex)
|
||||
|
||||
#undef EIGEN_LAPACKE_LLT
|
||||
|
||||
} // end namespace internal
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user