bug #897: fix regression in BiCGSTAB(mat) ctor (an all other iterative solvers).

Add respective regression unit test.
This commit is contained in:
Gael Guennebaud 2015-02-16 17:05:10 +01:00
parent 470d26d580
commit 0f464d9d87
2 changed files with 23 additions and 4 deletions

View File

@ -54,9 +54,10 @@ public:
*/
template<typename SparseMatrixDerived>
explicit IterativeSolverBase(const SparseMatrixBase<SparseMatrixDerived>& A)
: mp_matrix(A)
{
init();
compute(A.derived());
compute(mp_matrix);
}
~IterativeSolverBase() {}
@ -69,7 +70,7 @@ public:
template<typename SparseMatrixDerived>
Derived& analyzePattern(const SparseMatrixBase<SparseMatrixDerived>& A)
{
grab(A);
grab(A.derived());
m_preconditioner.analyzePattern(mp_matrix);
m_isInitialized = true;
m_analysisIsOk = true;
@ -90,7 +91,7 @@ public:
Derived& factorize(const SparseMatrixBase<SparseMatrixDerived>& A)
{
eigen_assert(m_analysisIsOk && "You must first call analyzePattern()");
grab(A);
grab(A.derived());
m_preconditioner.factorize(mp_matrix);
m_factorizationIsOk = true;
m_info = Success;
@ -110,7 +111,7 @@ public:
template<typename SparseMatrixDerived>
Derived& compute(const SparseMatrixBase<SparseMatrixDerived>& A)
{
grab(A);
grab(A.derived());
m_preconditioner.compute(mp_matrix);
m_isInitialized = true;
m_analysisIsOk = true;
@ -229,6 +230,15 @@ protected:
::new (&mp_matrix) Ref<const MatrixType>(A);
}
void grab(const Ref<const MatrixType> &A)
{
if(&(A.derived()) != &mp_matrix)
{
mp_matrix.~Ref<const MatrixType>();
::new (&mp_matrix) Ref<const MatrixType>(A);
}
}
MatrixType m_dummy;
Ref<const MatrixType> mp_matrix;
Preconditioner m_preconditioner;

View File

@ -83,6 +83,15 @@ void check_sparse_solving(Solver& solver, const typename Solver::MatrixType& A,
VERIFY(xm.isApprox(refX,test_precision<Scalar>()));
}
// test initialization ctor
{
Rhs x(b.rows(), b.cols());
Solver solver2(A);
VERIFY(solver2.info() == Success);
x = solver2.solve(b);
VERIFY(x.isApprox(refX,test_precision<Scalar>()));
}
// test dense Block as the result and rhs:
{
DenseRhs x(db.rows(), db.cols());