Currently, the binding of LLT to Lapacke is done using a large macro. This factors out a large part of the functionality of the macro and implement them explicitly.

This commit is contained in:
Erik Schultheis 2021-11-25 16:11:25 +00:00 committed by David Tellenbach
parent ec4efbd696
commit b8b6566f0f

View File

@ -39,60 +39,111 @@ namespace Eigen {
namespace internal { namespace internal {
template<typename Scalar> struct lapacke_llt; namespace lapacke_llt_helpers {
#define EIGEN_LAPACKE_LLT(EIGTYPE, BLASTYPE, LAPACKE_PREFIX) \ // -------------------------------------------------------------------------------------------------------------------
template<> struct lapacke_llt<EIGTYPE> \ // Translation from Eigen to Lapacke types
{ \ // -------------------------------------------------------------------------------------------------------------------
template<typename MatrixType> \
static inline Index potrf(MatrixType& m, char uplo) \ // For complex numbers, the types in Eigen and Lapacke are different, but layout compatible.
{ \ template<typename Scalar> struct translate_type;
lapack_int matrix_order; \ template<> struct translate_type<float> { using type = float; };
lapack_int size, lda, info, StorageOrder; \ template<> struct translate_type<double> { using type = double; };
EIGTYPE* a; \ template<> struct translate_type<dcomplex> { using type = lapack_complex_double; };
eigen_assert(m.rows()==m.cols()); \ template<> struct translate_type<scomplex> { using type = lapack_complex_float; };
/* Set up parameters for ?potrf */ \
size = convert_index<lapack_int>(m.rows()); \ // -------------------------------------------------------------------------------------------------------------------
StorageOrder = MatrixType::Flags&RowMajorBit?RowMajor:ColMajor; \ // Dispatch for potrf handling double, float, complex double, complex float types
matrix_order = StorageOrder==RowMajor ? LAPACK_ROW_MAJOR : LAPACK_COL_MAJOR; \ // -------------------------------------------------------------------------------------------------------------------
a = &(m.coeffRef(0,0)); \
lda = convert_index<lapack_int>(m.outerStride()); \ inline lapack_int potrf(lapack_int matrix_order, char uplo, lapack_int size, double* a, lapack_int lda) {
\ return LAPACKE_dpotrf( matrix_order, uplo, size, a, lda );
info = LAPACKE_##LAPACKE_PREFIX##potrf( matrix_order, uplo, size, (BLASTYPE*)a, lda ); \ }
info = (info==0) ? -1 : info>0 ? info-1 : size; \
return info; \ inline lapack_int potrf(lapack_int matrix_order, char uplo, lapack_int size, float* a, lapack_int lda) {
} \ return LAPACKE_spotrf( matrix_order, uplo, size, a, lda );
}; \ }
template<> struct llt_inplace<EIGTYPE, Lower> \
{ \ inline lapack_int potrf(lapack_int matrix_order, char uplo, lapack_int size, lapack_complex_double* a, lapack_int lda) {
template<typename MatrixType> \ return LAPACKE_zpotrf( matrix_order, uplo, size, a, lda );
static Index blocked(MatrixType& m) \ }
{ \
return lapacke_llt<EIGTYPE>::potrf(m, 'L'); \ inline lapack_int potrf(lapack_int matrix_order, char uplo, lapack_int size, lapack_complex_float* a, lapack_int lda) {
} \ return LAPACKE_cpotrf( matrix_order, uplo, size, a, lda );
template<typename MatrixType, typename VectorType> \ }
static Index rankUpdate(MatrixType& mat, const VectorType& vec, const typename MatrixType::RealScalar& sigma) \
{ return Eigen::internal::llt_rank_update_lower(mat, vec, sigma); } \ // -------------------------------------------------------------------------------------------------------------------
}; \ // Dispatch for rank update handling upper and lower parts
template<> struct llt_inplace<EIGTYPE, Upper> \ // -------------------------------------------------------------------------------------------------------------------
{ \
template<typename MatrixType> \ template<unsigned Mode>
static Index blocked(MatrixType& m) \ struct rank_update {};
{ \
return lapacke_llt<EIGTYPE>::potrf(m, 'U'); \ template<>
} \ struct rank_update<Lower> {
template<typename MatrixType, typename VectorType> \ template<typename MatrixType, typename VectorType>
static Index rankUpdate(MatrixType& mat, const VectorType& vec, const typename MatrixType::RealScalar& sigma) \ static Index run(MatrixType &mat, const VectorType &vec, const typename MatrixType::RealScalar &sigma) {
{ \ return Eigen::internal::llt_rank_update_lower(mat, vec, sigma);
Transpose<MatrixType> matt(mat); \ }
return llt_inplace<EIGTYPE, Lower>::rankUpdate(matt, vec.conjugate(), sigma); \
} \
}; };
EIGEN_LAPACKE_LLT(double, double, d) template<>
EIGEN_LAPACKE_LLT(float, float, s) struct rank_update<Upper> {
EIGEN_LAPACKE_LLT(dcomplex, lapack_complex_double, z) template<typename MatrixType, typename VectorType>
EIGEN_LAPACKE_LLT(scomplex, lapack_complex_float, c) static Index run(MatrixType &mat, const VectorType &vec, const typename MatrixType::RealScalar &sigma) {
Transpose<MatrixType> matt(mat);
return Eigen::internal::llt_rank_update_lower(matt, vec.conjugate(), sigma);
}
};
// -------------------------------------------------------------------------------------------------------------------
// Generic lapacke llt implementation that hands of to the dispatches
// -------------------------------------------------------------------------------------------------------------------
template<typename Scalar, unsigned Mode>
struct lapacke_llt {
using BlasType = typename translate_type<Scalar>::type;
template<typename MatrixType>
static Index blocked(MatrixType& m)
{
eigen_assert(m.rows()==m.cols());
/* Set up parameters for ?potrf */
lapack_int size = convert_index<lapack_int>(m.rows());
lapack_int StorageOrder = MatrixType::Flags&RowMajorBit?RowMajor:ColMajor;
lapack_int matrix_order = StorageOrder==RowMajor ? LAPACK_ROW_MAJOR : LAPACK_COL_MAJOR;
Scalar* a = &(m.coeffRef(0,0));
lapack_int lda = convert_index<lapack_int>(m.outerStride());
lapack_int info = potrf( matrix_order, Mode == Lower ? 'L' : 'U', size, (BlasType*)a, lda );
info = (info==0) ? -1 : info>0 ? info-1 : size;
return info;
}
template<typename MatrixType, typename VectorType>
static Index rankUpdate(MatrixType& mat, const VectorType& vec, const typename MatrixType::RealScalar& sigma)
{
return rank_update<Mode>::run(mat, vec, sigma);
}
};
}
// end namespace lapacke_llt_helpers
/*
* Here, we just put the generic implementation from lapacke_llt into a full specialization of the llt_inplace
* type. By being a full specialization, the versions defined here thus get precedence over the generic implementation
* in LLT.h for double, float and complex double, complex float types.
*/
#define EIGEN_LAPACKE_LLT(EIGTYPE) \
template<> struct llt_inplace<EIGTYPE, Lower> : public lapacke_llt_helpers::lapacke_llt<EIGTYPE, Lower> {}; \
template<> struct llt_inplace<EIGTYPE, Upper> : public lapacke_llt_helpers::lapacke_llt<EIGTYPE, Upper> {};
EIGEN_LAPACKE_LLT(double)
EIGEN_LAPACKE_LLT(float)
EIGEN_LAPACKE_LLT(dcomplex)
EIGEN_LAPACKE_LLT(scomplex)
#undef EIGEN_LAPACKE_LLT
} // end namespace internal } // end namespace internal