Fix elimination tree and SparseQR with rows<cols

This commit is contained in:
Gael Guennebaud 2013-09-12 22:16:35 +02:00
parent 4612a1cd87
commit 1b4623e713
2 changed files with 9 additions and 6 deletions

View File

@ -63,6 +63,7 @@ int coletree(const MatrixType& mat, IndexVector& parent, IndexVector& firstRowEl
typedef typename MatrixType::Index Index;
Index nc = mat.cols(); // Number of columns
Index m = mat.rows();
Index diagSize = (std::min)(nc,m);
IndexVector root(nc); // root of subtree of etree
root.setZero();
IndexVector pp(nc); // disjoint sets
@ -72,7 +73,7 @@ int coletree(const MatrixType& mat, IndexVector& parent, IndexVector& firstRowEl
Index row,col;
firstRowElt.resize(m);
firstRowElt.setConstant(nc);
firstRowElt.segment(0, nc).setLinSpaced(nc, 0, nc-1);
firstRowElt.segment(0, diagSize).setLinSpaced(diagSize, 0, diagSize-1);
bool found_diag;
for (col = 0; col < nc; col++)
{
@ -91,7 +92,7 @@ int coletree(const MatrixType& mat, IndexVector& parent, IndexVector& firstRowEl
Index rset, cset, rroot;
for (col = 0; col < nc; col++)
{
found_diag = false;
found_diag = col>=m;
pp(col) = col;
cset = col;
root(cset) = col;
@ -105,6 +106,7 @@ int coletree(const MatrixType& mat, IndexVector& parent, IndexVector& firstRowEl
Index i = col;
if(it) i = it.index();
if (i == col) found_diag = true;
row = firstRowElt(i);
if (row >= col) continue;
rset = internal::etree_find(row, pp); // Find the name of the set containing row

View File

@ -161,8 +161,9 @@ class SparseQR
b = y;
// Solve with the triangular matrix R
y.resize((std::max)(cols(),Index(y.rows())),y.cols());
y.topRows(rank) = this->matrixR().topLeftCorner(rank, rank).template triangularView<Upper>().solve(b.topRows(rank));
y.bottomRows(y.size()-rank).setZero();
y.bottomRows(y.rows()-rank).setZero();
// Apply the column permutation
if (m_perm_c.size()) dest.topRows(cols()) = colsPermutation() * y.topRows(cols());
@ -246,7 +247,7 @@ class SparseQR
Index m_nonzeropivots; // Number of non zero pivots found
IndexVector m_etree; // Column elimination tree
IndexVector m_firstRowElt; // First element in each row
bool m_isQSorted; // whether Q is sorted or not
bool m_isQSorted; // whether Q is sorted or not
template <typename, typename > friend struct SparseQR_QProduct;
template <typename > friend struct SparseQRMatrixQReturnType;
@ -338,7 +339,7 @@ void SparseQR<MatrixType,OrderingType>::factorize(const MatrixType& mat)
Index nonzeroCol = 0; // Record the number of valid pivots
// Left looking rank-revealing QR factorization: compute a column of R and Q at a time
for (Index col = 0; col < n; ++col)
for (Index col = 0; col < (std::min)(n,m); ++col)
{
mark.setConstant(-1);
m_R.startVec(col);
@ -346,7 +347,7 @@ void SparseQR<MatrixType,OrderingType>::factorize(const MatrixType& mat)
mark(nonzeroCol) = col;
Qidx(0) = nonzeroCol;
nzcolR = 0; nzcolQ = 1;
found_diag = false;
found_diag = col>=m;
tval.setZero();
// Symbolic factorization: find the nonzero locations of the column k of the factors R and Q, i.e.,