mirror of
https://gitlab.com/libeigen/eigen.git
synced 2024-12-21 07:19:46 +08:00
Currently, the binding of LLT to Lapacke is done using a large macro. This factors out a large part of the functionality of the macro and implement them explicitly.
This commit is contained in:
parent
ec4efbd696
commit
b8b6566f0f
@ -39,60 +39,111 @@ namespace Eigen {
|
|||||||
|
|
||||||
namespace internal {
|
namespace internal {
|
||||||
|
|
||||||
template<typename Scalar> struct lapacke_llt;
|
namespace lapacke_llt_helpers {
|
||||||
|
|
||||||
#define EIGEN_LAPACKE_LLT(EIGTYPE, BLASTYPE, LAPACKE_PREFIX) \
|
// -------------------------------------------------------------------------------------------------------------------
|
||||||
template<> struct lapacke_llt<EIGTYPE> \
|
// Translation from Eigen to Lapacke types
|
||||||
{ \
|
// -------------------------------------------------------------------------------------------------------------------
|
||||||
template<typename MatrixType> \
|
|
||||||
static inline Index potrf(MatrixType& m, char uplo) \
|
|
||||||
{ \
|
|
||||||
lapack_int matrix_order; \
|
|
||||||
lapack_int size, lda, info, StorageOrder; \
|
|
||||||
EIGTYPE* a; \
|
|
||||||
eigen_assert(m.rows()==m.cols()); \
|
|
||||||
/* Set up parameters for ?potrf */ \
|
|
||||||
size = convert_index<lapack_int>(m.rows()); \
|
|
||||||
StorageOrder = MatrixType::Flags&RowMajorBit?RowMajor:ColMajor; \
|
|
||||||
matrix_order = StorageOrder==RowMajor ? LAPACK_ROW_MAJOR : LAPACK_COL_MAJOR; \
|
|
||||||
a = &(m.coeffRef(0,0)); \
|
|
||||||
lda = convert_index<lapack_int>(m.outerStride()); \
|
|
||||||
\
|
|
||||||
info = LAPACKE_##LAPACKE_PREFIX##potrf( matrix_order, uplo, size, (BLASTYPE*)a, lda ); \
|
|
||||||
info = (info==0) ? -1 : info>0 ? info-1 : size; \
|
|
||||||
return info; \
|
|
||||||
} \
|
|
||||||
}; \
|
|
||||||
template<> struct llt_inplace<EIGTYPE, Lower> \
|
|
||||||
{ \
|
|
||||||
template<typename MatrixType> \
|
|
||||||
static Index blocked(MatrixType& m) \
|
|
||||||
{ \
|
|
||||||
return lapacke_llt<EIGTYPE>::potrf(m, 'L'); \
|
|
||||||
} \
|
|
||||||
template<typename MatrixType, typename VectorType> \
|
|
||||||
static Index rankUpdate(MatrixType& mat, const VectorType& vec, const typename MatrixType::RealScalar& sigma) \
|
|
||||||
{ return Eigen::internal::llt_rank_update_lower(mat, vec, sigma); } \
|
|
||||||
}; \
|
|
||||||
template<> struct llt_inplace<EIGTYPE, Upper> \
|
|
||||||
{ \
|
|
||||||
template<typename MatrixType> \
|
|
||||||
static Index blocked(MatrixType& m) \
|
|
||||||
{ \
|
|
||||||
return lapacke_llt<EIGTYPE>::potrf(m, 'U'); \
|
|
||||||
} \
|
|
||||||
template<typename MatrixType, typename VectorType> \
|
|
||||||
static Index rankUpdate(MatrixType& mat, const VectorType& vec, const typename MatrixType::RealScalar& sigma) \
|
|
||||||
{ \
|
|
||||||
Transpose<MatrixType> matt(mat); \
|
|
||||||
return llt_inplace<EIGTYPE, Lower>::rankUpdate(matt, vec.conjugate(), sigma); \
|
|
||||||
} \
|
|
||||||
};
|
|
||||||
|
|
||||||
EIGEN_LAPACKE_LLT(double, double, d)
|
// For complex numbers, the types in Eigen and Lapacke are different, but layout compatible.
|
||||||
EIGEN_LAPACKE_LLT(float, float, s)
|
template<typename Scalar> struct translate_type;
|
||||||
EIGEN_LAPACKE_LLT(dcomplex, lapack_complex_double, z)
|
template<> struct translate_type<float> { using type = float; };
|
||||||
EIGEN_LAPACKE_LLT(scomplex, lapack_complex_float, c)
|
template<> struct translate_type<double> { using type = double; };
|
||||||
|
template<> struct translate_type<dcomplex> { using type = lapack_complex_double; };
|
||||||
|
template<> struct translate_type<scomplex> { using type = lapack_complex_float; };
|
||||||
|
|
||||||
|
// -------------------------------------------------------------------------------------------------------------------
|
||||||
|
// Dispatch for potrf handling double, float, complex double, complex float types
|
||||||
|
// -------------------------------------------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
inline lapack_int potrf(lapack_int matrix_order, char uplo, lapack_int size, double* a, lapack_int lda) {
|
||||||
|
return LAPACKE_dpotrf( matrix_order, uplo, size, a, lda );
|
||||||
|
}
|
||||||
|
|
||||||
|
inline lapack_int potrf(lapack_int matrix_order, char uplo, lapack_int size, float* a, lapack_int lda) {
|
||||||
|
return LAPACKE_spotrf( matrix_order, uplo, size, a, lda );
|
||||||
|
}
|
||||||
|
|
||||||
|
inline lapack_int potrf(lapack_int matrix_order, char uplo, lapack_int size, lapack_complex_double* a, lapack_int lda) {
|
||||||
|
return LAPACKE_zpotrf( matrix_order, uplo, size, a, lda );
|
||||||
|
}
|
||||||
|
|
||||||
|
inline lapack_int potrf(lapack_int matrix_order, char uplo, lapack_int size, lapack_complex_float* a, lapack_int lda) {
|
||||||
|
return LAPACKE_cpotrf( matrix_order, uplo, size, a, lda );
|
||||||
|
}
|
||||||
|
|
||||||
|
// -------------------------------------------------------------------------------------------------------------------
|
||||||
|
// Dispatch for rank update handling upper and lower parts
|
||||||
|
// -------------------------------------------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
template<unsigned Mode>
|
||||||
|
struct rank_update {};
|
||||||
|
|
||||||
|
template<>
|
||||||
|
struct rank_update<Lower> {
|
||||||
|
template<typename MatrixType, typename VectorType>
|
||||||
|
static Index run(MatrixType &mat, const VectorType &vec, const typename MatrixType::RealScalar &sigma) {
|
||||||
|
return Eigen::internal::llt_rank_update_lower(mat, vec, sigma);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template<>
|
||||||
|
struct rank_update<Upper> {
|
||||||
|
template<typename MatrixType, typename VectorType>
|
||||||
|
static Index run(MatrixType &mat, const VectorType &vec, const typename MatrixType::RealScalar &sigma) {
|
||||||
|
Transpose<MatrixType> matt(mat);
|
||||||
|
return Eigen::internal::llt_rank_update_lower(matt, vec.conjugate(), sigma);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// -------------------------------------------------------------------------------------------------------------------
|
||||||
|
// Generic lapacke llt implementation that hands of to the dispatches
|
||||||
|
// -------------------------------------------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
template<typename Scalar, unsigned Mode>
|
||||||
|
struct lapacke_llt {
|
||||||
|
using BlasType = typename translate_type<Scalar>::type;
|
||||||
|
template<typename MatrixType>
|
||||||
|
static Index blocked(MatrixType& m)
|
||||||
|
{
|
||||||
|
eigen_assert(m.rows()==m.cols());
|
||||||
|
/* Set up parameters for ?potrf */
|
||||||
|
lapack_int size = convert_index<lapack_int>(m.rows());
|
||||||
|
lapack_int StorageOrder = MatrixType::Flags&RowMajorBit?RowMajor:ColMajor;
|
||||||
|
lapack_int matrix_order = StorageOrder==RowMajor ? LAPACK_ROW_MAJOR : LAPACK_COL_MAJOR;
|
||||||
|
Scalar* a = &(m.coeffRef(0,0));
|
||||||
|
lapack_int lda = convert_index<lapack_int>(m.outerStride());
|
||||||
|
|
||||||
|
lapack_int info = potrf( matrix_order, Mode == Lower ? 'L' : 'U', size, (BlasType*)a, lda );
|
||||||
|
info = (info==0) ? -1 : info>0 ? info-1 : size;
|
||||||
|
return info;
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename MatrixType, typename VectorType>
|
||||||
|
static Index rankUpdate(MatrixType& mat, const VectorType& vec, const typename MatrixType::RealScalar& sigma)
|
||||||
|
{
|
||||||
|
return rank_update<Mode>::run(mat, vec, sigma);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
// end namespace lapacke_llt_helpers
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Here, we just put the generic implementation from lapacke_llt into a full specialization of the llt_inplace
|
||||||
|
* type. By being a full specialization, the versions defined here thus get precedence over the generic implementation
|
||||||
|
* in LLT.h for double, float and complex double, complex float types.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#define EIGEN_LAPACKE_LLT(EIGTYPE) \
|
||||||
|
template<> struct llt_inplace<EIGTYPE, Lower> : public lapacke_llt_helpers::lapacke_llt<EIGTYPE, Lower> {}; \
|
||||||
|
template<> struct llt_inplace<EIGTYPE, Upper> : public lapacke_llt_helpers::lapacke_llt<EIGTYPE, Upper> {};
|
||||||
|
|
||||||
|
EIGEN_LAPACKE_LLT(double)
|
||||||
|
EIGEN_LAPACKE_LLT(float)
|
||||||
|
EIGEN_LAPACKE_LLT(dcomplex)
|
||||||
|
EIGEN_LAPACKE_LLT(scomplex)
|
||||||
|
|
||||||
|
#undef EIGEN_LAPACKE_LLT
|
||||||
|
|
||||||
} // end namespace internal
|
} // end namespace internal
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user