From 00d7f8e567667941ffded734dcee800f66b43bae Mon Sep 17 00:00:00 2001 From: Benoit Jacob Date: Sun, 25 Jan 2009 23:46:51 +0000 Subject: [PATCH] * solveTriangularInPlace(): take a const ref and const_cast it, to allow passing temporary xprs. * improvements, simplifications in LU::solve() * remove remnant of old norm2() --- Eigen/src/Core/MatrixBase.h | 3 +-- Eigen/src/Core/SolveTriangular.h | 6 +++++- Eigen/src/LU/LU.h | 32 +++++++++++--------------------- 3 files changed, 17 insertions(+), 24 deletions(-) diff --git a/Eigen/src/Core/MatrixBase.h b/Eigen/src/Core/MatrixBase.h index 5281e34fa..1fd02b0af 100644 --- a/Eigen/src/Core/MatrixBase.h +++ b/Eigen/src/Core/MatrixBase.h @@ -344,13 +344,12 @@ template class MatrixBase solveTriangular(const MatrixBase& other) const; template - void solveTriangularInPlace(MatrixBase& other) const; + void solveTriangularInPlace(const MatrixBase& other) const; template Scalar dot(const MatrixBase& other) const; RealScalar squaredNorm() const; - RealScalar norm2() const; RealScalar norm() const; const PlainMatrixType normalized() const; void normalize(); diff --git a/Eigen/src/Core/SolveTriangular.h b/Eigen/src/Core/SolveTriangular.h index 1c586b865..e39353275 100644 --- a/Eigen/src/Core/SolveTriangular.h +++ b/Eigen/src/Core/SolveTriangular.h @@ -221,13 +221,17 @@ struct ei_solve_triangular_selector }; /** "in-place" version of MatrixBase::solveTriangular() where the result is written in \a other + * + * The parameter is only marked 'const' to make the C++ compiler accept a temporary expression here. + * This function will const_cast it, so constness isn't honored here. * * See MatrixBase:solveTriangular() for the details. */ template template -void MatrixBase::solveTriangularInPlace(MatrixBase& other) const +void MatrixBase::solveTriangularInPlace(const MatrixBase& _other) const { + MatrixBase& other = _other.const_cast_derived(); ei_assert(derived().cols() == derived().rows()); ei_assert(derived().cols() == other.rows()); ei_assert(!(Flags & ZeroDiagBit)); diff --git a/Eigen/src/LU/LU.h b/Eigen/src/LU/LU.h index 00a6dcf64..a48ee8d1a 100644 --- a/Eigen/src/LU/LU.h +++ b/Eigen/src/LU/LU.h @@ -474,13 +474,13 @@ bool LU::solve( * So we proceed as follows: * Step 1: compute c = Pb. * Step 2: replace c by the solution x to Lx = c. Exists because L is invertible. - * Step 3: compute d such that Ud = c. Check if such d really exists. - * Step 4: result = Qd; + * Step 3: replace c by the solution x to Ux = c. Check if a solution really exists. + * Step 4: result = Qc; */ const int rows = m_lu.rows(), cols = m_lu.cols(); ei_assert(b.rows() == rows); - const int smalldim = std::min(rows, m_lu.cols()); + const int smalldim = std::min(rows, cols); typename OtherDerived::PlainMatrixType c(b.rows(), b.cols()); @@ -488,19 +488,13 @@ bool LU::solve( for(int i = 0; i < rows; ++i) c.row(m_p.coeff(i)) = b.row(i); // Step 2 - if(rows <= cols) - m_lu.corner(Eigen::TopLeft,rows,smalldim).template marked().solveTriangularInPlace(c); - else + m_lu.corner(Eigen::TopLeft,smalldim,smalldim).template marked() + .solveTriangularInPlace( + c.corner(Eigen::TopLeft, smalldim, c.cols())); + if(rows>cols) { - // construct the L matrix. We shouldn't do that everytime, it is a very large overhead in the case of vector solving. - // However the case rows>cols is rather unusual with LU so this is probably not a huge priority. - Matrix l(rows, rows); - l.setZero(); - l.corner(Eigen::TopLeft,rows,smalldim) = m_lu.corner(Eigen::TopLeft,rows,smalldim); - l.template marked().solveTriangularInPlace(c); + c.corner(Eigen::BottomLeft, rows-cols, c.cols()) + -= m_lu.corner(Eigen::BottomLeft, rows-cols, cols) * c.corner(Eigen::TopLeft, cols, c.cols()); } // Step 3 @@ -513,17 +507,13 @@ bool LU::solve( if(!ei_isMuchSmallerThan(c.coeff(row,col), biggest_in_c)) return false; } - Matrix - d(c.corner(TopLeft, m_rank, c.cols())); m_lu.corner(TopLeft, m_rank, m_rank) .template marked() - .solveTriangularInPlace(d); + .solveTriangularInPlace(c.corner(TopLeft, m_rank, c.cols())); // Step 4 result->resize(m_lu.cols(), b.cols()); - for(int i = 0; i < m_rank; ++i) result->row(m_q.coeff(i)) = d.row(i); + for(int i = 0; i < m_rank; ++i) result->row(m_q.coeff(i)) = c.row(i); for(int i = m_rank; i < m_lu.cols(); ++i) result->row(m_q.coeff(i)).setZero(); return true; }