Generalize Cholmod support to hanlde any sparse type as the rhs and result of the solve method

This commit is contained in:
Gael Guennebaud 2016-11-06 20:29:23 +01:00
parent afc55b1885
commit 436a111792

View File

@ -55,7 +55,7 @@ template<> struct cholmod_configure_matrix<std::complex<double> > {
* Note that the data are shared. * Note that the data are shared.
*/ */
template<typename _Scalar, int _Options, typename _StorageIndex> template<typename _Scalar, int _Options, typename _StorageIndex>
cholmod_sparse viewAsCholmod(SparseMatrix<_Scalar,_Options,_StorageIndex>& mat) cholmod_sparse viewAsCholmod(Ref<SparseMatrix<_Scalar,_Options,_StorageIndex> > mat)
{ {
cholmod_sparse res; cholmod_sparse res;
res.nzmax = mat.nonZeros(); res.nzmax = mat.nonZeros();
@ -104,7 +104,14 @@ cholmod_sparse viewAsCholmod(SparseMatrix<_Scalar,_Options,_StorageIndex>& mat)
template<typename _Scalar, int _Options, typename _Index> template<typename _Scalar, int _Options, typename _Index>
const cholmod_sparse viewAsCholmod(const SparseMatrix<_Scalar,_Options,_Index>& mat) const cholmod_sparse viewAsCholmod(const SparseMatrix<_Scalar,_Options,_Index>& mat)
{ {
cholmod_sparse res = viewAsCholmod(mat.const_cast_derived()); cholmod_sparse res = viewAsCholmod(Ref<SparseMatrix<_Scalar,_Options,_Index> >(mat.const_cast_derived()));
return res;
}
template<typename _Scalar, int _Options, typename _Index>
const cholmod_sparse viewAsCholmod(const SparseVector<_Scalar,_Options,_Index>& mat)
{
cholmod_sparse res = viewAsCholmod(Ref<SparseMatrix<_Scalar,_Options,_Index> >(mat.const_cast_derived()));
return res; return res;
} }
@ -113,7 +120,7 @@ const cholmod_sparse viewAsCholmod(const SparseMatrix<_Scalar,_Options,_Index>&
template<typename _Scalar, int _Options, typename _Index, unsigned int UpLo> template<typename _Scalar, int _Options, typename _Index, unsigned int UpLo>
cholmod_sparse viewAsCholmod(const SparseSelfAdjointView<const SparseMatrix<_Scalar,_Options,_Index>, UpLo>& mat) cholmod_sparse viewAsCholmod(const SparseSelfAdjointView<const SparseMatrix<_Scalar,_Options,_Index>, UpLo>& mat)
{ {
cholmod_sparse res = viewAsCholmod(mat.matrix().const_cast_derived()); cholmod_sparse res = viewAsCholmod(Ref<SparseMatrix<_Scalar,_Options,_Index> >(mat.matrix().const_cast_derived()));
if(UpLo==Upper) res.stype = 1; if(UpLo==Upper) res.stype = 1;
if(UpLo==Lower) res.stype = -1; if(UpLo==Lower) res.stype = -1;
@ -298,8 +305,8 @@ class CholmodBase : public SparseSolverBase<Derived>
} }
/** \internal */ /** \internal */
template<typename RhsScalar, int RhsOptions, typename RhsIndex, typename DestScalar, int DestOptions, typename DestIndex> template<typename RhsDerived, typename DestDerived>
void _solve_impl(const SparseMatrix<RhsScalar,RhsOptions,RhsIndex> &b, SparseMatrix<DestScalar,DestOptions,DestIndex> &dest) const void _solve_impl(const SparseMatrixBase<RhsDerived> &b, SparseMatrixBase<DestDerived> &dest) const
{ {
eigen_assert(m_factorizationIsOk && "The decomposition is not in a valid state for solving, you must first call either compute() or symbolic()/numeric()"); eigen_assert(m_factorizationIsOk && "The decomposition is not in a valid state for solving, you must first call either compute() or symbolic()/numeric()");
const Index size = m_cholmodFactor->n; const Index size = m_cholmodFactor->n;
@ -307,7 +314,8 @@ class CholmodBase : public SparseSolverBase<Derived>
eigen_assert(size==b.rows()); eigen_assert(size==b.rows());
// note: cs stands for Cholmod Sparse // note: cs stands for Cholmod Sparse
cholmod_sparse b_cs = viewAsCholmod(b); Ref<SparseMatrix<typename RhsDerived::Scalar,ColMajor,typename RhsDerived::StorageIndex> > b_ref(b.const_cast_derived());
cholmod_sparse b_cs = viewAsCholmod(b_ref);
cholmod_sparse* x_cs = cholmod_spsolve(CHOLMOD_A, m_cholmodFactor, &b_cs, &m_cholmod); cholmod_sparse* x_cs = cholmod_spsolve(CHOLMOD_A, m_cholmodFactor, &b_cs, &m_cholmod);
if(!x_cs) if(!x_cs)
{ {
@ -315,7 +323,7 @@ class CholmodBase : public SparseSolverBase<Derived>
return; return;
} }
// TODO optimize this copy by swapping when possible (be careful with alignment, etc.) // TODO optimize this copy by swapping when possible (be careful with alignment, etc.)
dest = viewAsEigen<DestScalar,DestOptions,DestIndex>(*x_cs); dest.derived() = viewAsEigen<typename DestDerived::Scalar,ColMajor,typename DestDerived::StorageIndex>(*x_cs);
cholmod_free_sparse(&x_cs, &m_cholmod); cholmod_free_sparse(&x_cs, &m_cholmod);
} }
#endif // EIGEN_PARSED_BY_DOXYGEN #endif // EIGEN_PARSED_BY_DOXYGEN