KLU: truely disable unimplemented code, add proper static assertions in solve

This commit is contained in:
Gael Guennebaud 2017-11-10 14:09:01 +01:00
parent 6365f937d6
commit b82cd93c01

View File

@ -27,24 +27,24 @@ namespace Eigen {
*
* \implsparsesolverconcept
*
* \sa \ref TutorialSparseSolverConcept, class SparseLU
* \sa \ref TutorialSparseSolverConcept, class UmfPackLU, class SparseLU
*/
inline int klu_solve(klu_symbolic *Symbolic, klu_numeric *Numeric, int ldim, int nrhs, double B [ ], klu_common *Common, double) {
return klu_solve(Symbolic, Numeric, ldim, nrhs, B, Common);
inline int klu_solve(klu_symbolic *Symbolic, klu_numeric *Numeric, Index ldim, Index nrhs, double B [ ], klu_common *Common, double) {
return klu_solve(Symbolic, Numeric, internal::convert_index<int>(ldim), internal::convert_index<int>(nrhs), B, Common);
}
inline int klu_solve(klu_symbolic *Symbolic, klu_numeric *Numeric, int ldim, int nrhs, std::complex<double>B[], klu_common *Common, std::complex<double>) {
return klu_z_solve(Symbolic, Numeric, ldim, nrhs, &numext::real_ref(B[0]), Common);
inline int klu_solve(klu_symbolic *Symbolic, klu_numeric *Numeric, Index ldim, Index nrhs, std::complex<double>B[], klu_common *Common, std::complex<double>) {
return klu_z_solve(Symbolic, Numeric, internal::convert_index<int>(ldim), internal::convert_index<int>(nrhs), &numext::real_ref(B[0]), Common);
}
inline int klu_tsolve(klu_symbolic *Symbolic, klu_numeric *Numeric, int ldim, int nrhs, double B[], klu_common *Common, double) {
return klu_tsolve(Symbolic, Numeric, ldim, nrhs, B, Common);
inline int klu_tsolve(klu_symbolic *Symbolic, klu_numeric *Numeric, Index ldim, Index nrhs, double B[], klu_common *Common, double) {
return klu_tsolve(Symbolic, Numeric, internal::convert_index<int>(ldim), internal::convert_index<int>(nrhs), B, Common);
}
inline int klu_tsolve(klu_symbolic *Symbolic, klu_numeric *Numeric, int ldim, int nrhs, std::complex<double>B[], klu_common *Common, std::complex<double>) {
return klu_z_tsolve(Symbolic, Numeric, ldim, nrhs, &numext::real_ref(B[0]), 0, Common);
inline int klu_tsolve(klu_symbolic *Symbolic, klu_numeric *Numeric, Index ldim, Index nrhs, std::complex<double>B[], klu_common *Common, std::complex<double>) {
return klu_z_tsolve(Symbolic, Numeric, internal::convert_index<int>(ldim), internal::convert_index<int>(nrhs), &numext::real_ref(B[0]), 0, Common);
}
inline klu_numeric* klu_factor(int Ap [ ], int Ai [ ], double Ax [ ], klu_symbolic *Symbolic, klu_common *Common, double) {
@ -114,7 +114,7 @@ class KLU : public SparseSolverBase<KLU<_MatrixType> >
eigen_assert(m_isInitialized && "Decomposition is not initialized.");
return m_info;
}
#if 0 // not implemented yet
inline const LUMatrixType& matrixL() const
{
if (m_extractedDataAreDirty) extractData();
@ -138,7 +138,7 @@ class KLU : public SparseSolverBase<KLU<_MatrixType> >
if (m_extractedDataAreDirty) extractData();
return m_q;
}
#endif
/** Computes the sparse Cholesky decomposition of \a matrix
* Note that the matrix should be column-major, and in compressed format for best performance.
* \sa SparseMatrix::makeCompressed().
@ -213,9 +213,11 @@ class KLU : public SparseSolverBase<KLU<_MatrixType> >
template<typename BDerived,typename XDerived>
bool _solve_impl(const MatrixBase<BDerived> &b, MatrixBase<XDerived> &x) const;
#if 0 // not implemented yet
Scalar determinant() const;
void extractData() const;
#endif
protected:
@ -275,11 +277,12 @@ class KLU : public SparseSolverBase<KLU<_MatrixType> >
}
// cached data to reduce reallocation, etc.
#if 0 // not implemented yet
mutable LUMatrixType m_l;
mutable LUMatrixType m_u;
mutable IntColVectorType m_p;
mutable IntRowVectorType m_q;
#endif
KLUMatrixType m_dummy;
KLUMatrixRef mp_matrix;
@ -296,7 +299,7 @@ class KLU : public SparseSolverBase<KLU<_MatrixType> >
KLU(const KLU& ) { }
};
#if 0 // not implemented yet
template<typename MatrixType>
void KLU<MatrixType>::extractData() const
{
@ -304,26 +307,26 @@ void KLU<MatrixType>::extractData() const
{
eigen_assert(false && "KLU: extractData Not Yet Implemented");
// // get size of the data
// int lnz, unz, rows, cols, nz_udiag;
// umfpack_get_lunz(&lnz, &unz, &rows, &cols, &nz_udiag, m_numeric, Scalar());
//
// // allocate data
// m_l.resize(rows,(std::min)(rows,cols));
// m_l.resizeNonZeros(lnz);
//
// m_u.resize((std::min)(rows,cols),cols);
// m_u.resizeNonZeros(unz);
//
// m_p.resize(rows);
// m_q.resize(cols);
//
// // extract
// umfpack_get_numeric(m_l.outerIndexPtr(), m_l.innerIndexPtr(), m_l.valuePtr(),
// m_u.outerIndexPtr(), m_u.innerIndexPtr(), m_u.valuePtr(),
// m_p.data(), m_q.data(), 0, 0, 0, m_numeric);
//
// m_extractedDataAreDirty = false;
// get size of the data
int lnz, unz, rows, cols, nz_udiag;
umfpack_get_lunz(&lnz, &unz, &rows, &cols, &nz_udiag, m_numeric, Scalar());
// allocate data
m_l.resize(rows,(std::min)(rows,cols));
m_l.resizeNonZeros(lnz);
m_u.resize((std::min)(rows,cols),cols);
m_u.resizeNonZeros(unz);
m_p.resize(rows);
m_q.resize(cols);
// extract
umfpack_get_numeric(m_l.outerIndexPtr(), m_l.innerIndexPtr(), m_l.valuePtr(),
m_u.outerIndexPtr(), m_u.innerIndexPtr(), m_u.valuePtr(),
m_p.data(), m_q.data(), 0, 0, 0, m_numeric);
m_extractedDataAreDirty = false;
}
}
@ -333,27 +336,18 @@ typename KLU<MatrixType>::Scalar KLU<MatrixType>::determinant() const
eigen_assert(false && "KLU: extractData Not Yet Implemented");
return Scalar();
}
#endif
template<typename MatrixType>
template<typename BDerived,typename XDerived>
bool KLU<MatrixType>::_solve_impl(const MatrixBase<BDerived> &b, MatrixBase<XDerived> &x) const
{
Index rhsCols = b.cols();
eigen_assert((BDerived::Flags&RowMajorBit)==0 && "KLU backend does not support non col-major rhs yet");
eigen_assert((XDerived::Flags&RowMajorBit)==0 && "KLU backend does not support non col-major result yet");
eigen_assert(b.derived().data() != x.derived().data() && " KLU does not support inplace solve");
EIGEN_STATIC_ASSERT((XDerived::Flags&RowMajorBit)==0, THIS_METHOD_IS_ONLY_FOR_COLUMN_MAJOR_MATRICES);
eigen_assert(m_factorizationIsOk && "The decomposition is not in a valid state for solving, you must first call either compute() or analyzePattern()/factorize()");
x = b;
int info = 0;
if (true/*(MatrixType::Flags&RowMajorBit) == 0*/)
{
info = klu_solve(m_symbolic, m_numeric, b.rows(), rhsCols, x.const_cast_derived().data(), const_cast<klu_common*>(&m_common), Scalar());
}
else
{
info = klu_tsolve(m_symbolic, m_numeric, b.rows(), rhsCols, x.const_cast_derived().data(), const_cast<klu_common*>(&m_common), Scalar());
}
int info = klu_solve(m_symbolic, m_numeric, b.rows(), rhsCols, x.const_cast_derived().data(), const_cast<klu_common*>(&m_common), Scalar());
m_info = info!=0 ? Success : NumericalIssue;
return true;