Currently, the binding of LLT to Lapacke is done using a large macro. This factors out a large part of the functionality of the macro and implement them explicitly.

This commit is contained in:
Erik Schultheis 2021-11-25 16:11:25 +00:00 committed by David Tellenbach
parent ec4efbd696
commit b8b6566f0f

View File

@ -39,60 +39,111 @@ namespace Eigen {
namespace internal {
template<typename Scalar> struct lapacke_llt;
namespace lapacke_llt_helpers {
#define EIGEN_LAPACKE_LLT(EIGTYPE, BLASTYPE, LAPACKE_PREFIX) \
template<> struct lapacke_llt<EIGTYPE> \
{ \
template<typename MatrixType> \
static inline Index potrf(MatrixType& m, char uplo) \
{ \
lapack_int matrix_order; \
lapack_int size, lda, info, StorageOrder; \
EIGTYPE* a; \
eigen_assert(m.rows()==m.cols()); \
/* Set up parameters for ?potrf */ \
size = convert_index<lapack_int>(m.rows()); \
StorageOrder = MatrixType::Flags&RowMajorBit?RowMajor:ColMajor; \
matrix_order = StorageOrder==RowMajor ? LAPACK_ROW_MAJOR : LAPACK_COL_MAJOR; \
a = &(m.coeffRef(0,0)); \
lda = convert_index<lapack_int>(m.outerStride()); \
\
info = LAPACKE_##LAPACKE_PREFIX##potrf( matrix_order, uplo, size, (BLASTYPE*)a, lda ); \
info = (info==0) ? -1 : info>0 ? info-1 : size; \
return info; \
} \
}; \
template<> struct llt_inplace<EIGTYPE, Lower> \
{ \
template<typename MatrixType> \
static Index blocked(MatrixType& m) \
{ \
return lapacke_llt<EIGTYPE>::potrf(m, 'L'); \
} \
template<typename MatrixType, typename VectorType> \
static Index rankUpdate(MatrixType& mat, const VectorType& vec, const typename MatrixType::RealScalar& sigma) \
{ return Eigen::internal::llt_rank_update_lower(mat, vec, sigma); } \
}; \
template<> struct llt_inplace<EIGTYPE, Upper> \
{ \
template<typename MatrixType> \
static Index blocked(MatrixType& m) \
{ \
return lapacke_llt<EIGTYPE>::potrf(m, 'U'); \
} \
template<typename MatrixType, typename VectorType> \
static Index rankUpdate(MatrixType& mat, const VectorType& vec, const typename MatrixType::RealScalar& sigma) \
{ \
Transpose<MatrixType> matt(mat); \
return llt_inplace<EIGTYPE, Lower>::rankUpdate(matt, vec.conjugate(), sigma); \
} \
// -------------------------------------------------------------------------------------------------------------------
// Translation from Eigen to Lapacke types
// -------------------------------------------------------------------------------------------------------------------
// For complex numbers, the types in Eigen and Lapacke are different, but layout compatible.
template<typename Scalar> struct translate_type;
template<> struct translate_type<float> { using type = float; };
template<> struct translate_type<double> { using type = double; };
template<> struct translate_type<dcomplex> { using type = lapack_complex_double; };
template<> struct translate_type<scomplex> { using type = lapack_complex_float; };
// -------------------------------------------------------------------------------------------------------------------
// Dispatch for potrf handling double, float, complex double, complex float types
// -------------------------------------------------------------------------------------------------------------------
inline lapack_int potrf(lapack_int matrix_order, char uplo, lapack_int size, double* a, lapack_int lda) {
return LAPACKE_dpotrf( matrix_order, uplo, size, a, lda );
}
inline lapack_int potrf(lapack_int matrix_order, char uplo, lapack_int size, float* a, lapack_int lda) {
return LAPACKE_spotrf( matrix_order, uplo, size, a, lda );
}
inline lapack_int potrf(lapack_int matrix_order, char uplo, lapack_int size, lapack_complex_double* a, lapack_int lda) {
return LAPACKE_zpotrf( matrix_order, uplo, size, a, lda );
}
inline lapack_int potrf(lapack_int matrix_order, char uplo, lapack_int size, lapack_complex_float* a, lapack_int lda) {
return LAPACKE_cpotrf( matrix_order, uplo, size, a, lda );
}
// -------------------------------------------------------------------------------------------------------------------
// Dispatch for rank update handling upper and lower parts
// -------------------------------------------------------------------------------------------------------------------
template<unsigned Mode>
struct rank_update {};
template<>
struct rank_update<Lower> {
template<typename MatrixType, typename VectorType>
static Index run(MatrixType &mat, const VectorType &vec, const typename MatrixType::RealScalar &sigma) {
return Eigen::internal::llt_rank_update_lower(mat, vec, sigma);
}
};
EIGEN_LAPACKE_LLT(double, double, d)
EIGEN_LAPACKE_LLT(float, float, s)
EIGEN_LAPACKE_LLT(dcomplex, lapack_complex_double, z)
EIGEN_LAPACKE_LLT(scomplex, lapack_complex_float, c)
template<>
struct rank_update<Upper> {
template<typename MatrixType, typename VectorType>
static Index run(MatrixType &mat, const VectorType &vec, const typename MatrixType::RealScalar &sigma) {
Transpose<MatrixType> matt(mat);
return Eigen::internal::llt_rank_update_lower(matt, vec.conjugate(), sigma);
}
};
// -------------------------------------------------------------------------------------------------------------------
// Generic lapacke llt implementation that hands of to the dispatches
// -------------------------------------------------------------------------------------------------------------------
template<typename Scalar, unsigned Mode>
struct lapacke_llt {
using BlasType = typename translate_type<Scalar>::type;
template<typename MatrixType>
static Index blocked(MatrixType& m)
{
eigen_assert(m.rows()==m.cols());
/* Set up parameters for ?potrf */
lapack_int size = convert_index<lapack_int>(m.rows());
lapack_int StorageOrder = MatrixType::Flags&RowMajorBit?RowMajor:ColMajor;
lapack_int matrix_order = StorageOrder==RowMajor ? LAPACK_ROW_MAJOR : LAPACK_COL_MAJOR;
Scalar* a = &(m.coeffRef(0,0));
lapack_int lda = convert_index<lapack_int>(m.outerStride());
lapack_int info = potrf( matrix_order, Mode == Lower ? 'L' : 'U', size, (BlasType*)a, lda );
info = (info==0) ? -1 : info>0 ? info-1 : size;
return info;
}
template<typename MatrixType, typename VectorType>
static Index rankUpdate(MatrixType& mat, const VectorType& vec, const typename MatrixType::RealScalar& sigma)
{
return rank_update<Mode>::run(mat, vec, sigma);
}
};
}
// end namespace lapacke_llt_helpers
/*
* Here, we just put the generic implementation from lapacke_llt into a full specialization of the llt_inplace
* type. By being a full specialization, the versions defined here thus get precedence over the generic implementation
* in LLT.h for double, float and complex double, complex float types.
*/
#define EIGEN_LAPACKE_LLT(EIGTYPE) \
template<> struct llt_inplace<EIGTYPE, Lower> : public lapacke_llt_helpers::lapacke_llt<EIGTYPE, Lower> {}; \
template<> struct llt_inplace<EIGTYPE, Upper> : public lapacke_llt_helpers::lapacke_llt<EIGTYPE, Upper> {};
EIGEN_LAPACKE_LLT(double)
EIGEN_LAPACKE_LLT(float)
EIGEN_LAPACKE_LLT(dcomplex)
EIGEN_LAPACKE_LLT(scomplex)
#undef EIGEN_LAPACKE_LLT
} // end namespace internal