Fix compilation of MKL Pardiso support

This commit is contained in:
Gael Guennebaud 2015-06-24 14:53:43 +02:00
parent 2a33075aeb
commit 95e19be381

49
Eigen/src/PardisoSupport/PardisoSupport.h Normal file → Executable file
View File

@ -54,7 +54,7 @@ namespace internal
template<>
struct pardiso_run_selector<long long int>
{
typedef long long int IndexTypeType;
typedef long long int IndexType;
static IndexType run( _MKL_DSS_HANDLE_t pt, IndexType maxfct, IndexType mnum, IndexType type, IndexType phase, IndexType n, void *a,
IndexType *ia, IndexType *ja, IndexType *perm, IndexType nrhs, IndexType *iparm, IndexType msglvl, void *b, void *x)
{
@ -93,19 +93,19 @@ namespace internal
typedef typename _MatrixType::StorageIndex StorageIndex;
};
}
} // end namespace internal
template<class Derived>
class PardisoImpl : public SparseSolveBase<PardisoImpl<Derived>
class PardisoImpl : public SparseSolverBase<Derived>
{
protected:
typedef SparseSolveBase<PardisoImpl<Derived> Base;
typedef SparseSolverBase<Derived> Base;
using Base::derived;
using Base::m_isInitialized;
typedef internal::pardiso_traits<Derived> Traits;
public:
using base::_solve_impl;
using Base::_solve_impl;
typedef typename Traits::MatrixType MatrixType;
typedef typename Traits::Scalar Scalar;
@ -173,16 +173,17 @@ class PardisoImpl : public SparseSolveBase<PardisoImpl<Derived>
Derived& compute(const MatrixType& matrix);
template<typename BDerived, typename XDerived>
bool _solve_impl(const MatrixBase<BDerived> &b, MatrixBase<XDerived>& x) const;
template<typename Rhs,typename Dest>
void _solve_impl(const MatrixBase<Rhs> &b, MatrixBase<Dest> &dest) const;
protected:
void pardisoRelease()
{
if(m_isInitialized) // Factorization ran at least once
{
internal::pardiso_run_selector<StorageIndex>::run(m_pt, 1, 1, m_type, -1, m_size, 0, 0, 0, m_perm.data(), 0,
m_iparm.data(), m_msglvl, 0, 0);
internal::pardiso_run_selector<StorageIndex>::run(m_pt, 1, 1, m_type, -1, m_size,0, 0, 0, m_perm.data(), 0,
m_iparm.data(), m_msglvl, NULL, NULL);
m_isInitialized = false;
}
}
@ -217,12 +218,14 @@ class PardisoImpl : public SparseSolveBase<PardisoImpl<Derived>
m_iparm[27] = (sizeof(RealScalar) == 4) ? 1 : 0;
m_iparm[34] = 1; // C indexing
m_iparm[59] = 1; // Automatic switch between In-Core and Out-of-Core modes
memset(m_pt, 0, sizeof(m_pt));
}
protected:
// cached data to reduce reallocation, etc.
void manageErrorCode(Index error)
void manageErrorCode(Index error) const
{
switch(error)
{
@ -239,7 +242,7 @@ class PardisoImpl : public SparseSolveBase<PardisoImpl<Derived>
}
mutable SparseMatrixType m_matrix;
ComputationInfo m_info;
mutable ComputationInfo m_info;
bool m_analysisIsOk, m_factorizationIsOk;
Index m_type, m_msglvl;
mutable void *m_pt[64];
@ -256,7 +259,6 @@ Derived& PardisoImpl<Derived>::compute(const MatrixType& a)
eigen_assert(a.rows() == a.cols());
pardisoRelease();
memset(m_pt, 0, sizeof(m_pt));
m_perm.setZero(m_size);
derived().getMatrix(a);
@ -279,7 +281,6 @@ Derived& PardisoImpl<Derived>::analyzePattern(const MatrixType& a)
eigen_assert(m_size == a.cols());
pardisoRelease();
memset(m_pt, 0, sizeof(m_pt));
m_perm.setZero(m_size);
derived().getMatrix(a);
@ -313,12 +314,15 @@ Derived& PardisoImpl<Derived>::factorize(const MatrixType& a)
return derived();
}
template<class Base>
template<class Derived>
template<typename BDerived,typename XDerived>
bool PardisoImpl<Base>::_solve_impl(const MatrixBase<BDerived> &b, MatrixBase<XDerived>& x) const
void PardisoImpl<Derived>::_solve_impl(const MatrixBase<BDerived> &b, MatrixBase<XDerived>& x) const
{
if(m_iparm[0] == 0) // Factorization was not computed
return false;
{
m_info = InvalidInput;
return;
}
//Index n = m_matrix.rows();
Index nrhs = Index(b.cols());
@ -353,7 +357,7 @@ bool PardisoImpl<Base>::_solve_impl(const MatrixBase<BDerived> &b, MatrixBase<XD
m_perm.data(), nrhs, m_iparm.data(), m_msglvl,
rhs_ptr, x.derived().data());
return error==0;
manageErrorCode(error);
}
@ -373,7 +377,7 @@ template<typename MatrixType>
class PardisoLU : public PardisoImpl< PardisoLU<MatrixType> >
{
protected:
typedef PardisoImpl< PardisoLU<MatrixType> > Base;
typedef PardisoImpl<PardisoLU> Base;
typedef typename Base::Scalar Scalar;
typedef typename Base::RealScalar RealScalar;
using Base::pardisoInit;
@ -401,6 +405,7 @@ class PardisoLU : public PardisoImpl< PardisoLU<MatrixType> >
void getMatrix(const MatrixType& matrix)
{
m_matrix = matrix;
m_matrix.makeCompressed();
}
};
@ -424,7 +429,6 @@ class PardisoLLT : public PardisoImpl< PardisoLLT<MatrixType,_UpLo> >
protected:
typedef PardisoImpl< PardisoLLT<MatrixType,_UpLo> > Base;
typedef typename Base::Scalar Scalar;
typedef typename Base::StorageIndex StorageIndex;
typedef typename Base::RealScalar RealScalar;
using Base::pardisoInit;
using Base::m_matrix;
@ -432,9 +436,9 @@ class PardisoLLT : public PardisoImpl< PardisoLLT<MatrixType,_UpLo> >
public:
typedef typename Base::StorageIndex StorageIndex;
enum { UpLo = _UpLo };
using Base::compute;
using Base::solve;
PardisoLLT()
: Base()
@ -457,6 +461,7 @@ class PardisoLLT : public PardisoImpl< PardisoLLT<MatrixType,_UpLo> >
PermutationMatrix<Dynamic,Dynamic,StorageIndex> p_null;
m_matrix.resize(matrix.rows(), matrix.cols());
m_matrix.template selfadjointView<Upper>() = matrix.template selfadjointView<UpLo>().twistedBy(p_null);
m_matrix.makeCompressed();
}
};
@ -482,7 +487,6 @@ class PardisoLDLT : public PardisoImpl< PardisoLDLT<MatrixType,Options> >
protected:
typedef PardisoImpl< PardisoLDLT<MatrixType,Options> > Base;
typedef typename Base::Scalar Scalar;
typedef typename Base::StorageIndex StorageIndex;
typedef typename Base::RealScalar RealScalar;
using Base::pardisoInit;
using Base::m_matrix;
@ -490,8 +494,8 @@ class PardisoLDLT : public PardisoImpl< PardisoLDLT<MatrixType,Options> >
public:
typedef typename Base::StorageIndex StorageIndex;
using Base::compute;
using Base::solve;
enum { UpLo = Options&(Upper|Lower) };
PardisoLDLT()
@ -513,6 +517,7 @@ class PardisoLDLT : public PardisoImpl< PardisoLDLT<MatrixType,Options> >
PermutationMatrix<Dynamic,Dynamic,StorageIndex> p_null;
m_matrix.resize(matrix.rows(), matrix.cols());
m_matrix.template selfadjointView<Upper>() = matrix.template selfadjointView<UpLo>().twistedBy(p_null);
m_matrix.makeCompressed();
}
};