Relax dependency on MKL for EIGEN_USE_BLAS

This commit is contained in:
Gael Guennebaud 2016-04-11 15:17:14 +02:00
parent 4e8e5888d7
commit 6a9ca88e7e
9 changed files with 98 additions and 132 deletions

View File

@ -50,25 +50,26 @@ template <typename Index, int LhsStorageOrder, bool ConjugateLhs, \
struct general_matrix_matrix_triangular_product<Index,Scalar,LhsStorageOrder,ConjugateLhs, \
Scalar,RhsStorageOrder,ConjugateRhs,ColMajor,UpLo,Specialized> { \
static EIGEN_STRONG_INLINE void run(Index size, Index depth,const Scalar* lhs, Index lhsStride, \
const Scalar* rhs, Index rhsStride, Scalar* res, Index resStride, Scalar alpha) \
const Scalar* rhs, Index rhsStride, Scalar* res, Index resStride, Scalar alpha, level3_blocking<Scalar, Scalar>& blocking) \
{ \
if (lhs==rhs) { \
general_matrix_matrix_rankupdate<Index,Scalar,LhsStorageOrder,ConjugateLhs,ColMajor,UpLo> \
::run(size,depth,lhs,lhsStride,rhs,rhsStride,res,resStride,alpha); \
::run(size,depth,lhs,lhsStride,rhs,rhsStride,res,resStride,alpha,blocking); \
} else { \
general_matrix_matrix_triangular_product<Index, \
Scalar, LhsStorageOrder, ConjugateLhs, \
Scalar, RhsStorageOrder, ConjugateRhs, \
ColMajor, UpLo, BuiltIn> \
::run(size,depth,lhs,lhsStride,rhs,rhsStride,res,resStride,alpha); \
::run(size,depth,lhs,lhsStride,rhs,rhsStride,res,resStride,alpha,blocking); \
} \
} \
};
EIGEN_MKL_RANKUPDATE_SPECIALIZE(double)
//EIGEN_MKL_RANKUPDATE_SPECIALIZE(dcomplex)
EIGEN_MKL_RANKUPDATE_SPECIALIZE(float)
//EIGEN_MKL_RANKUPDATE_SPECIALIZE(scomplex)
// TODO handle complex cases
// EIGEN_MKL_RANKUPDATE_SPECIALIZE(dcomplex)
// EIGEN_MKL_RANKUPDATE_SPECIALIZE(scomplex)
// SYRK for float/double
#define EIGEN_MKL_RANKUPDATE_R(EIGTYPE, MKLTYPE, MKLFUNC) \
@ -80,7 +81,7 @@ struct general_matrix_matrix_rankupdate<Index,EIGTYPE,AStorageOrder,ConjugateA,C
conjA = ((AStorageOrder==ColMajor) && ConjugateA) ? 1 : 0 \
}; \
static EIGEN_STRONG_INLINE void run(Index size, Index depth,const EIGTYPE* lhs, Index lhsStride, \
const EIGTYPE* rhs, Index rhsStride, EIGTYPE* res, Index resStride, EIGTYPE alpha) \
const EIGTYPE* rhs, Index rhsStride, EIGTYPE* res, Index resStride, EIGTYPE alpha, level3_blocking<EIGTYPE, EIGTYPE>& /*blocking*/) \
{ \
/* typedef Matrix<EIGTYPE, Dynamic, Dynamic, RhsStorageOrder> MatrixRhs;*/ \
\
@ -105,7 +106,7 @@ struct general_matrix_matrix_rankupdate<Index,EIGTYPE,AStorageOrder,ConjugateA,C
conjA = (((AStorageOrder==ColMajor) && ConjugateA) || ((AStorageOrder==RowMajor) && !ConjugateA)) ? 1 : 0 \
}; \
static EIGEN_STRONG_INLINE void run(Index size, Index depth,const EIGTYPE* lhs, Index lhsStride, \
const EIGTYPE* rhs, Index rhsStride, EIGTYPE* res, Index resStride, EIGTYPE alpha) \
const EIGTYPE* rhs, Index rhsStride, EIGTYPE* res, Index resStride, EIGTYPE alpha, level3_blocking<EIGTYPE, EIGTYPE>& /*blocking*/) \
{ \
typedef Matrix<EIGTYPE, Dynamic, Dynamic, AStorageOrder> MatrixType; \
\
@ -132,11 +133,12 @@ struct general_matrix_matrix_rankupdate<Index,EIGTYPE,AStorageOrder,ConjugateA,C
};
EIGEN_MKL_RANKUPDATE_R(double, double, dsyrk)
EIGEN_MKL_RANKUPDATE_R(float, float, ssyrk)
EIGEN_MKL_RANKUPDATE_R(double, double, dsyrk_)
EIGEN_MKL_RANKUPDATE_R(float, float, ssyrk_)
//EIGEN_MKL_RANKUPDATE_C(dcomplex, MKL_Complex16, double, zherk)
//EIGEN_MKL_RANKUPDATE_C(scomplex, MKL_Complex8, double, cherk)
// TODO hanlde complex cases
// EIGEN_MKL_RANKUPDATE_C(dcomplex, double, double, zherk_)
// EIGEN_MKL_RANKUPDATE_C(scomplex, float, float, cherk_)
} // end namespace internal

View File

@ -68,9 +68,8 @@ static void run(Index rows, Index cols, Index depth, \
char transa, transb; \
MKL_INT m, n, k, lda, ldb, ldc; \
const EIGTYPE *a, *b; \
MKLTYPE alpha_, beta_; \
EIGTYPE beta(1); \
MatrixX##EIGPREFIX a_tmp, b_tmp; \
EIGTYPE myone(1);\
\
/* Set transpose options */ \
transa = (LhsStorageOrder==RowMajor) ? ((ConjugateLhs) ? 'C' : 'T') : 'N'; \
@ -81,10 +80,6 @@ static void run(Index rows, Index cols, Index depth, \
n = (MKL_INT)cols; \
k = (MKL_INT)depth; \
\
/* Set alpha_ & beta_ */ \
assign_scalar_eig2mkl(alpha_, alpha); \
assign_scalar_eig2mkl(beta_, myone); \
\
/* Set lda, ldb, ldc */ \
lda = (MKL_INT)lhsStride; \
ldb = (MKL_INT)rhsStride; \
@ -105,13 +100,13 @@ static void run(Index rows, Index cols, Index depth, \
ldb = b_tmp.outerStride(); \
} else b = _rhs; \
\
MKLPREFIX##gemm(&transa, &transb, &m, &n, &k, &alpha_, (const MKLTYPE*)a, &lda, (const MKLTYPE*)b, &ldb, &beta_, (MKLTYPE*)res, &ldc); \
MKLPREFIX##gemm_(&transa, &transb, &m, &n, &k, &numext::real_ref(alpha), (const MKLTYPE*)a, &lda, (const MKLTYPE*)b, &ldb, &numext::real_ref(beta), (MKLTYPE*)res, &ldc); \
}};
GEMM_SPECIALIZATION(double, d, double, d)
GEMM_SPECIALIZATION(float, f, float, s)
GEMM_SPECIALIZATION(dcomplex, cd, MKL_Complex16, z)
GEMM_SPECIALIZATION(scomplex, cf, MKL_Complex8, c)
GEMM_SPECIALIZATION(double, d, double, d)
GEMM_SPECIALIZATION(float, f, float, s)
GEMM_SPECIALIZATION(dcomplex, cd, double, z)
GEMM_SPECIALIZATION(scomplex, cf, float, c)
} // end namespase internal

View File

@ -98,15 +98,13 @@ static void run( \
EIGTYPE* res, Index resIncr, EIGTYPE alpha) \
{ \
MKL_INT m=rows, n=cols, lda=lhsStride, incx=rhsIncr, incy=resIncr; \
MKLTYPE alpha_, beta_; \
const EIGTYPE *x_ptr, myone(1); \
const EIGTYPE beta(1); \
const EIGTYPE *x_ptr; \
char trans=(LhsStorageOrder==ColMajor) ? 'N' : (ConjugateLhs) ? 'C' : 'T'; \
if (LhsStorageOrder==RowMajor) { \
m=cols; \
n=rows; \
m = cols; \
n = rows; \
}\
assign_scalar_eig2mkl(alpha_, alpha); \
assign_scalar_eig2mkl(beta_, myone); \
GEMVVector x_tmp; \
if (ConjugateRhs) { \
Map<const GEMVVector, 0, InnerStride<> > map_x(rhs,cols,1,InnerStride<>(incx)); \
@ -114,14 +112,14 @@ static void run( \
x_ptr=x_tmp.data(); \
incx=1; \
} else x_ptr=rhs; \
MKLPREFIX##gemv(&trans, &m, &n, &alpha_, (const MKLTYPE*)lhs, &lda, (const MKLTYPE*)x_ptr, &incx, &beta_, (MKLTYPE*)res, &incy); \
MKLPREFIX##gemv_(&trans, &m, &n, &numext::real_ref(alpha), (const MKLTYPE*)lhs, &lda, (const MKLTYPE*)x_ptr, &incx, &numext::real_ref(beta), (MKLTYPE*)res, &incy); \
}\
};
EIGEN_MKL_GEMV_SPECIALIZATION(double, double, d)
EIGEN_MKL_GEMV_SPECIALIZATION(float, float, s)
EIGEN_MKL_GEMV_SPECIALIZATION(dcomplex, MKL_Complex16, z)
EIGEN_MKL_GEMV_SPECIALIZATION(scomplex, MKL_Complex8, c)
EIGEN_MKL_GEMV_SPECIALIZATION(double, double, d)
EIGEN_MKL_GEMV_SPECIALIZATION(float, float, s)
EIGEN_MKL_GEMV_SPECIALIZATION(dcomplex, double, z)
EIGEN_MKL_GEMV_SPECIALIZATION(scomplex, float, c)
} // end namespase internal

View File

@ -52,24 +52,19 @@ struct product_selfadjoint_matrix<EIGTYPE,Index,LhsStorageOrder,true,ConjugateLh
const EIGTYPE* _lhs, Index lhsStride, \
const EIGTYPE* _rhs, Index rhsStride, \
EIGTYPE* res, Index resStride, \
EIGTYPE alpha) \
EIGTYPE alpha, level3_blocking<EIGTYPE, EIGTYPE>& /*blocking*/) \
{ \
char side='L', uplo='L'; \
MKL_INT m, n, lda, ldb, ldc; \
const EIGTYPE *a, *b; \
MKLTYPE alpha_, beta_; \
EIGTYPE beta(1); \
MatrixX##EIGPREFIX b_tmp; \
EIGTYPE myone(1);\
\
/* Set transpose options */ \
/* Set m, n, k */ \
m = (MKL_INT)rows; \
n = (MKL_INT)cols; \
\
/* Set alpha_ & beta_ */ \
assign_scalar_eig2mkl(alpha_, alpha); \
assign_scalar_eig2mkl(beta_, myone); \
\
/* Set lda, ldb, ldc */ \
lda = (MKL_INT)lhsStride; \
ldb = (MKL_INT)rhsStride; \
@ -86,7 +81,7 @@ struct product_selfadjoint_matrix<EIGTYPE,Index,LhsStorageOrder,true,ConjugateLh
ldb = b_tmp.outerStride(); \
} else b = _rhs; \
\
MKLPREFIX##symm(&side, &uplo, &m, &n, &alpha_, (const MKLTYPE*)a, &lda, (const MKLTYPE*)b, &ldb, &beta_, (MKLTYPE*)res, &ldc); \
MKLPREFIX##symm_(&side, &uplo, &m, &n, &numext::real_ref(alpha), (const MKLTYPE*)a, &lda, (const MKLTYPE*)b, &ldb, &numext::real_ref(beta), (MKLTYPE*)res, &ldc); \
\
} \
};
@ -103,25 +98,20 @@ struct product_selfadjoint_matrix<EIGTYPE,Index,LhsStorageOrder,true,ConjugateLh
const EIGTYPE* _lhs, Index lhsStride, \
const EIGTYPE* _rhs, Index rhsStride, \
EIGTYPE* res, Index resStride, \
EIGTYPE alpha) \
EIGTYPE alpha, level3_blocking<EIGTYPE, EIGTYPE>& /*blocking*/) \
{ \
char side='L', uplo='L'; \
MKL_INT m, n, lda, ldb, ldc; \
const EIGTYPE *a, *b; \
MKLTYPE alpha_, beta_; \
EIGTYPE beta(1); \
MatrixX##EIGPREFIX b_tmp; \
Matrix<EIGTYPE, Dynamic, Dynamic, LhsStorageOrder> a_tmp; \
EIGTYPE myone(1); \
\
/* Set transpose options */ \
/* Set m, n, k */ \
m = (MKL_INT)rows; \
n = (MKL_INT)cols; \
\
/* Set alpha_ & beta_ */ \
assign_scalar_eig2mkl(alpha_, alpha); \
assign_scalar_eig2mkl(beta_, myone); \
\
/* Set lda, ldb, ldc */ \
lda = (MKL_INT)lhsStride; \
ldb = (MKL_INT)rhsStride; \
@ -154,15 +144,15 @@ struct product_selfadjoint_matrix<EIGTYPE,Index,LhsStorageOrder,true,ConjugateLh
ldb = b_tmp.outerStride(); \
} \
\
MKLPREFIX##hemm(&side, &uplo, &m, &n, &alpha_, (const MKLTYPE*)a, &lda, (const MKLTYPE*)b, &ldb, &beta_, (MKLTYPE*)res, &ldc); \
MKLPREFIX##hemm_(&side, &uplo, &m, &n, &numext::real_ref(alpha), (const MKLTYPE*)a, &lda, (const MKLTYPE*)b, &ldb, &numext::real_ref(beta), (MKLTYPE*)res, &ldc); \
\
} \
};
EIGEN_MKL_SYMM_L(double, double, d, d)
EIGEN_MKL_SYMM_L(float, float, f, s)
EIGEN_MKL_HEMM_L(dcomplex, MKL_Complex16, cd, z)
EIGEN_MKL_HEMM_L(scomplex, MKL_Complex8, cf, c)
EIGEN_MKL_HEMM_L(dcomplex, double, cd, z)
EIGEN_MKL_HEMM_L(scomplex, float, cf, c)
/* Optimized matrix * selfadjoint matrix (?SYMM/?HEMM) product */
@ -179,23 +169,18 @@ struct product_selfadjoint_matrix<EIGTYPE,Index,LhsStorageOrder,false,ConjugateL
const EIGTYPE* _lhs, Index lhsStride, \
const EIGTYPE* _rhs, Index rhsStride, \
EIGTYPE* res, Index resStride, \
EIGTYPE alpha) \
EIGTYPE alpha, level3_blocking<EIGTYPE, EIGTYPE>& /*blocking*/) \
{ \
char side='R', uplo='L'; \
MKL_INT m, n, lda, ldb, ldc; \
const EIGTYPE *a, *b; \
MKLTYPE alpha_, beta_; \
EIGTYPE beta(1); \
MatrixX##EIGPREFIX b_tmp; \
EIGTYPE myone(1);\
\
/* Set m, n, k */ \
m = (MKL_INT)rows; \
n = (MKL_INT)cols; \
\
/* Set alpha_ & beta_ */ \
assign_scalar_eig2mkl(alpha_, alpha); \
assign_scalar_eig2mkl(beta_, myone); \
\
/* Set lda, ldb, ldc */ \
lda = (MKL_INT)rhsStride; \
ldb = (MKL_INT)lhsStride; \
@ -212,7 +197,7 @@ struct product_selfadjoint_matrix<EIGTYPE,Index,LhsStorageOrder,false,ConjugateL
ldb = b_tmp.outerStride(); \
} else b = _lhs; \
\
MKLPREFIX##symm(&side, &uplo, &m, &n, &alpha_, (const MKLTYPE*)a, &lda, (const MKLTYPE*)b, &ldb, &beta_, (MKLTYPE*)res, &ldc); \
MKLPREFIX##symm_(&side, &uplo, &m, &n, &numext::real_ref(alpha), (const MKLTYPE*)a, &lda, (const MKLTYPE*)b, &ldb, &numext::real_ref(beta), (MKLTYPE*)res, &ldc); \
\
} \
};
@ -229,24 +214,19 @@ struct product_selfadjoint_matrix<EIGTYPE,Index,LhsStorageOrder,false,ConjugateL
const EIGTYPE* _lhs, Index lhsStride, \
const EIGTYPE* _rhs, Index rhsStride, \
EIGTYPE* res, Index resStride, \
EIGTYPE alpha) \
EIGTYPE alpha, level3_blocking<EIGTYPE, EIGTYPE>& /*blocking*/) \
{ \
char side='R', uplo='L'; \
MKL_INT m, n, lda, ldb, ldc; \
const EIGTYPE *a, *b; \
MKLTYPE alpha_, beta_; \
EIGTYPE beta(1); \
MatrixX##EIGPREFIX b_tmp; \
Matrix<EIGTYPE, Dynamic, Dynamic, RhsStorageOrder> a_tmp; \
EIGTYPE myone(1); \
\
/* Set m, n, k */ \
m = (MKL_INT)rows; \
n = (MKL_INT)cols; \
\
/* Set alpha_ & beta_ */ \
assign_scalar_eig2mkl(alpha_, alpha); \
assign_scalar_eig2mkl(beta_, myone); \
\
/* Set lda, ldb, ldc */ \
lda = (MKL_INT)rhsStride; \
ldb = (MKL_INT)lhsStride; \
@ -279,14 +259,14 @@ struct product_selfadjoint_matrix<EIGTYPE,Index,LhsStorageOrder,false,ConjugateL
ldb = b_tmp.outerStride(); \
} \
\
MKLPREFIX##hemm(&side, &uplo, &m, &n, &alpha_, (const MKLTYPE*)a, &lda, (const MKLTYPE*)b, &ldb, &beta_, (MKLTYPE*)res, &ldc); \
MKLPREFIX##hemm_(&side, &uplo, &m, &n, &numext::real_ref(alpha), (const MKLTYPE*)a, &lda, (const MKLTYPE*)b, &ldb, &numext::real_ref(beta), (MKLTYPE*)res, &ldc); \
} \
};
EIGEN_MKL_SYMM_R(double, double, d, d)
EIGEN_MKL_SYMM_R(float, float, f, s)
EIGEN_MKL_HEMM_R(dcomplex, MKL_Complex16, cd, z)
EIGEN_MKL_HEMM_R(scomplex, MKL_Complex8, cf, c)
EIGEN_MKL_HEMM_R(dcomplex, double, cd, z)
EIGEN_MKL_HEMM_R(scomplex, float, cf, c)
} // end namespace internal

View File

@ -86,25 +86,23 @@ const EIGTYPE* _rhs, EIGTYPE* res, EIGTYPE alpha) \
IsLower = UpLo == Lower ? 1 : 0 \
}; \
MKL_INT n=size, lda=lhsStride, incx=1, incy=1; \
MKLTYPE alpha_, beta_; \
const EIGTYPE *x_ptr, myone(1); \
EIGTYPE beta(1); \
const EIGTYPE *x_ptr; \
char uplo=(IsRowMajor) ? (IsLower ? 'U' : 'L') : (IsLower ? 'L' : 'U'); \
assign_scalar_eig2mkl(alpha_, alpha); \
assign_scalar_eig2mkl(beta_, myone); \
SYMVVector x_tmp; \
if (ConjugateRhs) { \
Map<const SYMVVector, 0 > map_x(_rhs,size,1); \
x_tmp=map_x.conjugate(); \
x_ptr=x_tmp.data(); \
} else x_ptr=_rhs; \
MKLFUNC(&uplo, &n, &alpha_, (const MKLTYPE*)lhs, &lda, (const MKLTYPE*)x_ptr, &incx, &beta_, (MKLTYPE*)res, &incy); \
MKLFUNC(&uplo, &n, &numext::real_ref(alpha), (const MKLTYPE*)lhs, &lda, (const MKLTYPE*)x_ptr, &incx, &numext::real_ref(beta), (MKLTYPE*)res, &incy); \
}\
};
EIGEN_MKL_SYMV_SPECIALIZATION(double, double, dsymv)
EIGEN_MKL_SYMV_SPECIALIZATION(float, float, ssymv)
EIGEN_MKL_SYMV_SPECIALIZATION(dcomplex, MKL_Complex16, zhemv)
EIGEN_MKL_SYMV_SPECIALIZATION(scomplex, MKL_Complex8, chemv)
EIGEN_MKL_SYMV_SPECIALIZATION(double, double, dsymv_)
EIGEN_MKL_SYMV_SPECIALIZATION(float, float, ssymv_)
EIGEN_MKL_SYMV_SPECIALIZATION(dcomplex, double, zhemv_)
EIGEN_MKL_SYMV_SPECIALIZATION(scomplex, float, chemv_)
} // end namespace internal

View File

@ -109,7 +109,8 @@ struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,true, \
/* Non-square case - doesn't fit to MKL ?TRMM. Fall to default triangular product or call MKL ?GEMM*/ \
if (rows != depth) { \
\
int nthr = mkl_domain_get_max_threads(EIGEN_MKL_DOMAIN_BLAS); \
/* FIXME handle mkl_domain_get_max_threads */ \
/*int nthr = mkl_domain_get_max_threads(EIGEN_MKL_DOMAIN_BLAS);*/ int nthr = 1;\
\
if (((nthr==1) && (((std::max)(rows,depth)-diagSize)/(double)diagSize < 0.5))) { \
/* Most likely no benefit to call TRMM or GEMM from MKL*/ \
@ -134,10 +135,6 @@ struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,true, \
EIGTYPE *b; \
const EIGTYPE *a; \
MKL_INT m, n, lda, ldb; \
MKLTYPE alpha_; \
\
/* Set alpha_*/ \
assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(alpha_, alpha); \
\
/* Set m, n */ \
m = (MKL_INT)diagSize; \
@ -175,7 +172,7 @@ struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,true, \
} \
/*std::cout << "TRMM_L: A is square! Go to MKL TRMM implementation! \n";*/ \
/* call ?trmm*/ \
MKLPREFIX##trmm(&side, &uplo, &transa, &diag, &m, &n, &alpha_, (const MKLTYPE*)a, &lda, (MKLTYPE*)b, &ldb); \
MKLPREFIX##trmm_(&side, &uplo, &transa, &diag, &m, &n, &numext::real_ref(alpha), (const MKLTYPE*)a, &lda, (MKLTYPE*)b, &ldb); \
\
/* Add op(a_triangular)*b into res*/ \
Map<MatrixX##EIGPREFIX, 0, OuterStride<> > res_tmp(res,rows,cols,OuterStride<>(resStride)); \
@ -184,9 +181,9 @@ struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,true, \
};
EIGEN_MKL_TRMM_L(double, double, d, d)
EIGEN_MKL_TRMM_L(dcomplex, MKL_Complex16, cd, z)
EIGEN_MKL_TRMM_L(dcomplex, double, cd, z)
EIGEN_MKL_TRMM_L(float, float, f, s)
EIGEN_MKL_TRMM_L(scomplex, MKL_Complex8, cf, c)
EIGEN_MKL_TRMM_L(scomplex, float, cf, c)
// implements col-major += alpha * op(general) * op(triangular)
#define EIGEN_MKL_TRMM_R(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX) \
@ -223,7 +220,7 @@ struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,false, \
/* Non-square case - doesn't fit to MKL ?TRMM. Fall to default triangular product or call MKL ?GEMM*/ \
if (cols != depth) { \
\
int nthr = mkl_domain_get_max_threads(EIGEN_MKL_DOMAIN_BLAS); \
int nthr = 1 /*mkl_domain_get_max_threads(EIGEN_MKL_DOMAIN_BLAS)*/; \
\
if ((nthr==1) && (((std::max)(cols,depth)-diagSize)/(double)diagSize < 0.5)) { \
/* Most likely no benefit to call TRMM or GEMM from MKL*/ \
@ -248,10 +245,6 @@ struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,false, \
EIGTYPE *b; \
const EIGTYPE *a; \
MKL_INT m, n, lda, ldb; \
MKLTYPE alpha_; \
\
/* Set alpha_*/ \
assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(alpha_, alpha); \
\
/* Set m, n */ \
m = (MKL_INT)rows; \
@ -289,7 +282,7 @@ struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,false, \
} \
/*std::cout << "TRMM_R: A is square! Go to MKL TRMM implementation! \n";*/ \
/* call ?trmm*/ \
MKLPREFIX##trmm(&side, &uplo, &transa, &diag, &m, &n, &alpha_, (const MKLTYPE*)a, &lda, (MKLTYPE*)b, &ldb); \
MKLPREFIX##trmm_(&side, &uplo, &transa, &diag, &m, &n, &numext::real_ref(alpha), (const MKLTYPE*)a, &lda, (MKLTYPE*)b, &ldb); \
\
/* Add op(a_triangular)*b into res*/ \
Map<MatrixX##EIGPREFIX, 0, OuterStride<> > res_tmp(res,rows,cols,OuterStride<>(resStride)); \
@ -298,9 +291,9 @@ struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,false, \
};
EIGEN_MKL_TRMM_R(double, double, d, d)
EIGEN_MKL_TRMM_R(dcomplex, MKL_Complex16, cd, z)
EIGEN_MKL_TRMM_R(dcomplex, double, cd, z)
EIGEN_MKL_TRMM_R(float, float, f, s)
EIGEN_MKL_TRMM_R(scomplex, MKL_Complex8, cf, c)
EIGEN_MKL_TRMM_R(scomplex, float, cf, c)
} // end namespace internal

View File

@ -107,9 +107,7 @@ struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,
char trans, uplo, diag; \
MKL_INT m, n, lda, incx, incy; \
EIGTYPE const *a; \
MKLTYPE alpha_, beta_; \
assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(alpha_, alpha); \
assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(beta_, EIGTYPE(1)); \
EIGTYPE beta(1); \
\
/* Set m, n */ \
n = (MKL_INT)size; \
@ -123,10 +121,10 @@ struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,
diag = IsUnitDiag ? 'U' : 'N'; \
\
/* call ?TRMV*/ \
MKLPREFIX##trmv(&uplo, &trans, &diag, &n, (const MKLTYPE*)_lhs, &lda, (MKLTYPE*)x, &incx); \
MKLPREFIX##trmv_(&uplo, &trans, &diag, &n, (const MKLTYPE*)_lhs, &lda, (MKLTYPE*)x, &incx); \
\
/* Add op(a_tr)rhs into res*/ \
MKLPREFIX##axpy(&n, &alpha_,(const MKLTYPE*)x, &incx, (MKLTYPE*)_res, &incy); \
MKLPREFIX##axpy_(&n, &numext::real_ref(alpha),(const MKLTYPE*)x, &incx, (MKLTYPE*)_res, &incy); \
/* Non-square case - doesn't fit to MKL ?TRMV. Fall to default triangular product*/ \
if (size<(std::max)(rows,cols)) { \
if (ConjRhs) x_tmp = rhs.conjugate(); else x_tmp = rhs; \
@ -144,15 +142,15 @@ struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,
m = size; \
n = cols-size; \
} \
MKLPREFIX##gemv(&trans, &m, &n, &alpha_, (const MKLTYPE*)a, &lda, (const MKLTYPE*)x, &incx, &beta_, (MKLTYPE*)y, &incy); \
MKLPREFIX##gemv_(&trans, &m, &n, &numext::real_ref(alpha), (const MKLTYPE*)a, &lda, (const MKLTYPE*)x, &incx, &numext::real_ref(beta), (MKLTYPE*)y, &incy); \
} \
} \
};
EIGEN_MKL_TRMV_CM(double, double, d, d)
EIGEN_MKL_TRMV_CM(dcomplex, MKL_Complex16, cd, z)
EIGEN_MKL_TRMV_CM(float, float, f, s)
EIGEN_MKL_TRMV_CM(scomplex, MKL_Complex8, cf, c)
EIGEN_MKL_TRMV_CM(double, double, d, d)
EIGEN_MKL_TRMV_CM(dcomplex, double, cd, z)
EIGEN_MKL_TRMV_CM(float, float, f, s)
EIGEN_MKL_TRMV_CM(scomplex, float, cf, c)
// implements row-major: res += alpha * op(triangular) * vector
#define EIGEN_MKL_TRMV_RM(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX) \
@ -191,9 +189,7 @@ struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,
char trans, uplo, diag; \
MKL_INT m, n, lda, incx, incy; \
EIGTYPE const *a; \
MKLTYPE alpha_, beta_; \
assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(alpha_, alpha); \
assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(beta_, EIGTYPE(1)); \
EIGTYPE beta(1); \
\
/* Set m, n */ \
n = (MKL_INT)size; \
@ -207,10 +203,10 @@ struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,
diag = IsUnitDiag ? 'U' : 'N'; \
\
/* call ?TRMV*/ \
MKLPREFIX##trmv(&uplo, &trans, &diag, &n, (const MKLTYPE*)_lhs, &lda, (MKLTYPE*)x, &incx); \
MKLPREFIX##trmv_(&uplo, &trans, &diag, &n, (const MKLTYPE*)_lhs, &lda, (MKLTYPE*)x, &incx); \
\
/* Add op(a_tr)rhs into res*/ \
MKLPREFIX##axpy(&n, &alpha_,(const MKLTYPE*)x, &incx, (MKLTYPE*)_res, &incy); \
MKLPREFIX##axpy_(&n, &numext::real_ref(alpha),(const MKLTYPE*)x, &incx, (MKLTYPE*)_res, &incy); \
/* Non-square case - doesn't fit to MKL ?TRMV. Fall to default triangular product*/ \
if (size<(std::max)(rows,cols)) { \
if (ConjRhs) x_tmp = rhs.conjugate(); else x_tmp = rhs; \
@ -228,15 +224,15 @@ struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,
m = size; \
n = cols-size; \
} \
MKLPREFIX##gemv(&trans, &n, &m, &alpha_, (const MKLTYPE*)a, &lda, (const MKLTYPE*)x, &incx, &beta_, (MKLTYPE*)y, &incy); \
MKLPREFIX##gemv_(&trans, &n, &m, &numext::real_ref(alpha), (const MKLTYPE*)a, &lda, (const MKLTYPE*)x, &incx, &numext::real_ref(beta), (MKLTYPE*)y, &incy); \
} \
} \
};
EIGEN_MKL_TRMV_RM(double, double, d, d)
EIGEN_MKL_TRMV_RM(dcomplex, MKL_Complex16, cd, z)
EIGEN_MKL_TRMV_RM(float, float, f, s)
EIGEN_MKL_TRMV_RM(scomplex, MKL_Complex8, cf, c)
EIGEN_MKL_TRMV_RM(double, double, d, d)
EIGEN_MKL_TRMV_RM(dcomplex, double, cd, z)
EIGEN_MKL_TRMV_RM(float, float, f, s)
EIGEN_MKL_TRMV_RM(scomplex, float, cf, c)
} // end namespase internal

View File

@ -56,9 +56,7 @@ struct triangular_solve_matrix<EIGTYPE,Index,OnTheLeft,Mode,Conjugate,TriStorage
MKL_INT m = size, n = otherSize, lda, ldb; \
char side = 'L', uplo, diag='N', transa; \
/* Set alpha_ */ \
MKLTYPE alpha; \
EIGTYPE myone(1); \
assign_scalar_eig2mkl(alpha, myone); \
EIGTYPE alpha(1); \
ldb = otherStride;\
\
const EIGTYPE *a; \
@ -82,14 +80,14 @@ struct triangular_solve_matrix<EIGTYPE,Index,OnTheLeft,Mode,Conjugate,TriStorage
} \
if (IsUnitDiag) diag='U'; \
/* call ?trsm*/ \
MKLPREFIX##trsm(&side, &uplo, &transa, &diag, &m, &n, &alpha, (const MKLTYPE*)a, &lda, (MKLTYPE*)_other, &ldb); \
MKLPREFIX##trsm_(&side, &uplo, &transa, &diag, &m, &n, &numext::real_ref(alpha), (const MKLTYPE*)a, &lda, (MKLTYPE*)_other, &ldb); \
} \
};
EIGEN_MKL_TRSM_L(double, double, d)
EIGEN_MKL_TRSM_L(dcomplex, MKL_Complex16, z)
EIGEN_MKL_TRSM_L(float, float, s)
EIGEN_MKL_TRSM_L(scomplex, MKL_Complex8, c)
EIGEN_MKL_TRSM_L(double, double, d)
EIGEN_MKL_TRSM_L(dcomplex, double, z)
EIGEN_MKL_TRSM_L(float, float, s)
EIGEN_MKL_TRSM_L(scomplex, float, c)
// implements RightSide general * op(triangular)^-1
@ -111,9 +109,7 @@ struct triangular_solve_matrix<EIGTYPE,Index,OnTheRight,Mode,Conjugate,TriStorag
MKL_INT m = otherSize, n = size, lda, ldb; \
char side = 'R', uplo, diag='N', transa; \
/* Set alpha_ */ \
MKLTYPE alpha; \
EIGTYPE myone(1); \
assign_scalar_eig2mkl(alpha, myone); \
EIGTYPE alpha(1); \
ldb = otherStride;\
\
const EIGTYPE *a; \
@ -137,15 +133,15 @@ struct triangular_solve_matrix<EIGTYPE,Index,OnTheRight,Mode,Conjugate,TriStorag
} \
if (IsUnitDiag) diag='U'; \
/* call ?trsm*/ \
MKLPREFIX##trsm(&side, &uplo, &transa, &diag, &m, &n, &alpha, (const MKLTYPE*)a, &lda, (MKLTYPE*)_other, &ldb); \
MKLPREFIX##trsm_(&side, &uplo, &transa, &diag, &m, &n, &numext::real_ref(alpha), (const MKLTYPE*)a, &lda, (MKLTYPE*)_other, &ldb); \
/*std::cout << "TRMS_L specialization!\n";*/ \
} \
};
EIGEN_MKL_TRSM_R(double, double, d)
EIGEN_MKL_TRSM_R(dcomplex, MKL_Complex16, z)
EIGEN_MKL_TRSM_R(float, float, s)
EIGEN_MKL_TRSM_R(scomplex, MKL_Complex8, c)
EIGEN_MKL_TRSM_R(double, double, d)
EIGEN_MKL_TRSM_R(dcomplex, double, z)
EIGEN_MKL_TRSM_R(float, float, s)
EIGEN_MKL_TRSM_R(scomplex, float, c)
} // end namespace internal

View File

@ -49,7 +49,7 @@
#define EIGEN_USE_LAPACKE
#endif
#if defined(EIGEN_USE_BLAS) || defined(EIGEN_USE_LAPACKE) || defined(EIGEN_USE_MKL_VML)
#if defined(EIGEN_USE_LAPACKE) || defined(EIGEN_USE_MKL_VML)
#define EIGEN_USE_MKL
#endif
@ -64,7 +64,6 @@
# ifndef EIGEN_USE_MKL
/*If the MKL version is too old, undef everything*/
# undef EIGEN_USE_MKL_ALL
# undef EIGEN_USE_BLAS
# undef EIGEN_USE_LAPACKE
# undef EIGEN_USE_MKL_VML
# undef EIGEN_USE_LAPACKE_STRICT
@ -107,12 +106,17 @@
#else
#define EIGEN_MKL_DOMAIN_PARDISO MKL_PARDISO
#endif
#endif
namespace Eigen {
typedef std::complex<double> dcomplex;
typedef std::complex<float> scomplex;
#if defined(EIGEN_USE_BLAS) && !defined(EIGEN_USE_MKL)
typedef int MKL_INT;
#endif
namespace internal {
template<typename MKLType, typename EigenType>
@ -125,6 +129,7 @@ static inline void assign_conj_scalar_eig2mkl(MKLType& mklScalar, const EigenTyp
mklScalar=eigenScalar;
}
#ifdef EIGEN_USE_MKL
template <>
inline void assign_scalar_eig2mkl<MKL_Complex16,dcomplex>(MKL_Complex16& mklScalar, const dcomplex& eigenScalar) {
mklScalar.real=eigenScalar.real();
@ -148,11 +153,14 @@ inline void assign_conj_scalar_eig2mkl<MKL_Complex8,scomplex>(MKL_Complex8& mklS
mklScalar.real=eigenScalar.real();
mklScalar.imag=-eigenScalar.imag();
}
#endif
} // end namespace internal
} // end namespace Eigen
#if defined(EIGEN_USE_BLAS) && !defined(EIGEN_USE_MKL)
#include "../../misc/blas.h"
#endif
#endif // EIGEN_MKL_SUPPORT_H