Fixed sparse conservativeResize() when both num cols and rows decreased.

The previous implementation caused a buffer overflow trying to calculate non-
zero counts for columns that no longer exist.
This commit is contained in:
Adam Shapiro 2021-02-23 21:32:39 +00:00 committed by Antonio Sánchez
parent 10c77b0ff4
commit 2ac0b78739
2 changed files with 19 additions and 9 deletions

View File

@ -579,10 +579,12 @@ class SparseMatrix
else if (innerChange < 0)
{
// Inner size decreased: allocate a new m_innerNonZeros
m_innerNonZeros = static_cast<StorageIndex*>(std::malloc((m_outerSize+outerChange+1) * sizeof(StorageIndex)));
m_innerNonZeros = static_cast<StorageIndex*>(std::malloc((m_outerSize + outerChange) * sizeof(StorageIndex)));
if (!m_innerNonZeros) internal::throw_std_bad_alloc();
for(Index i = 0; i < m_outerSize; i++)
for(Index i = 0; i < m_outerSize + (std::min)(outerChange, Index(0)); i++)
m_innerNonZeros[i] = m_outerIndex[i+1] - m_outerIndex[i];
for(Index i = m_outerSize; i < m_outerSize + outerChange; i++)
m_innerNonZeros[i] = 0;
}
// Change the m_innerNonZeros in case of a decrease of inner size

View File

@ -587,30 +587,38 @@ template<typename SparseMatrixType> void sparse_basic(const SparseMatrixType& re
inc.push_back(std::pair<StorageIndex,StorageIndex>(3,2));
inc.push_back(std::pair<StorageIndex,StorageIndex>(3,0));
inc.push_back(std::pair<StorageIndex,StorageIndex>(0,3));
inc.push_back(std::pair<StorageIndex,StorageIndex>(0,-1));
inc.push_back(std::pair<StorageIndex,StorageIndex>(-1,0));
inc.push_back(std::pair<StorageIndex,StorageIndex>(-1,-1));
for(size_t i = 0; i< inc.size(); i++) {
StorageIndex incRows = inc[i].first;
StorageIndex incCols = inc[i].second;
SparseMatrixType m1(rows, cols);
DenseMatrix refMat1 = DenseMatrix::Zero(rows, cols);
initSparse<Scalar>(density, refMat1, m1);
SparseMatrixType m2 = m1;
m2.makeCompressed();
m1.conservativeResize(rows+incRows, cols+incCols);
m2.conservativeResize(rows+incRows, cols+incCols);
refMat1.conservativeResize(rows+incRows, cols+incCols);
if (incRows > 0) refMat1.bottomRows(incRows).setZero();
if (incCols > 0) refMat1.rightCols(incCols).setZero();
VERIFY_IS_APPROX(m1, refMat1);
VERIFY_IS_APPROX(m2, refMat1);
// Insert new values
if (incRows > 0)
m1.insert(m1.rows()-1, 0) = refMat1(refMat1.rows()-1, 0) = 1;
if (incCols > 0)
m1.insert(0, m1.cols()-1) = refMat1(0, refMat1.cols()-1) = 1;
VERIFY_IS_APPROX(m1, refMat1);
}
}