* solveTriangularInPlace(): take a const ref and const_cast it, to allow passing temporary xprs.

* improvements, simplifications in LU::solve()
* remove remnant of old norm2()
This commit is contained in:
Benoit Jacob 2009-01-25 23:46:51 +00:00
parent 414ee1db4b
commit 00d7f8e567
3 changed files with 17 additions and 24 deletions

View File

@ -344,13 +344,12 @@ template<typename Derived> class MatrixBase
solveTriangular(const MatrixBase<OtherDerived>& other) const;
template<typename OtherDerived>
void solveTriangularInPlace(MatrixBase<OtherDerived>& other) const;
void solveTriangularInPlace(const MatrixBase<OtherDerived>& other) const;
template<typename OtherDerived>
Scalar dot(const MatrixBase<OtherDerived>& other) const;
RealScalar squaredNorm() const;
RealScalar norm2() const;
RealScalar norm() const;
const PlainMatrixType normalized() const;
void normalize();

View File

@ -221,13 +221,17 @@ struct ei_solve_triangular_selector<Lhs,Rhs,UpLo,ColMajor|IsDense>
};
/** "in-place" version of MatrixBase::solveTriangular() where the result is written in \a other
*
* The parameter is only marked 'const' to make the C++ compiler accept a temporary expression here.
* This function will const_cast it, so constness isn't honored here.
*
* See MatrixBase:solveTriangular() for the details.
*/
template<typename Derived>
template<typename OtherDerived>
void MatrixBase<Derived>::solveTriangularInPlace(MatrixBase<OtherDerived>& other) const
void MatrixBase<Derived>::solveTriangularInPlace(const MatrixBase<OtherDerived>& _other) const
{
MatrixBase<OtherDerived>& other = _other.const_cast_derived();
ei_assert(derived().cols() == derived().rows());
ei_assert(derived().cols() == other.rows());
ei_assert(!(Flags & ZeroDiagBit));

View File

@ -474,13 +474,13 @@ bool LU<MatrixType>::solve(
* So we proceed as follows:
* Step 1: compute c = Pb.
* Step 2: replace c by the solution x to Lx = c. Exists because L is invertible.
* Step 3: compute d such that Ud = c. Check if such d really exists.
* Step 4: result = Qd;
* Step 3: replace c by the solution x to Ux = c. Check if a solution really exists.
* Step 4: result = Qc;
*/
const int rows = m_lu.rows(), cols = m_lu.cols();
ei_assert(b.rows() == rows);
const int smalldim = std::min(rows, m_lu.cols());
const int smalldim = std::min(rows, cols);
typename OtherDerived::PlainMatrixType c(b.rows(), b.cols());
@ -488,19 +488,13 @@ bool LU<MatrixType>::solve(
for(int i = 0; i < rows; ++i) c.row(m_p.coeff(i)) = b.row(i);
// Step 2
if(rows <= cols)
m_lu.corner(Eigen::TopLeft,rows,smalldim).template marked<UnitLowerTriangular>().solveTriangularInPlace(c);
else
m_lu.corner(Eigen::TopLeft,smalldim,smalldim).template marked<UnitLowerTriangular>()
.solveTriangularInPlace(
c.corner(Eigen::TopLeft, smalldim, c.cols()));
if(rows>cols)
{
// construct the L matrix. We shouldn't do that everytime, it is a very large overhead in the case of vector solving.
// However the case rows>cols is rather unusual with LU so this is probably not a huge priority.
Matrix<Scalar, MatrixType::RowsAtCompileTime, MatrixType::RowsAtCompileTime,
MatrixType::Options,
MatrixType::MaxRowsAtCompileTime,
MatrixType::MaxRowsAtCompileTime> l(rows, rows);
l.setZero();
l.corner(Eigen::TopLeft,rows,smalldim) = m_lu.corner(Eigen::TopLeft,rows,smalldim);
l.template marked<UnitLowerTriangular>().solveTriangularInPlace(c);
c.corner(Eigen::BottomLeft, rows-cols, c.cols())
-= m_lu.corner(Eigen::BottomLeft, rows-cols, cols) * c.corner(Eigen::TopLeft, cols, c.cols());
}
// Step 3
@ -513,17 +507,13 @@ bool LU<MatrixType>::solve(
if(!ei_isMuchSmallerThan(c.coeff(row,col), biggest_in_c))
return false;
}
Matrix<Scalar, Dynamic, OtherDerived::ColsAtCompileTime,
MatrixType::Options,
MatrixType::MaxRowsAtCompileTime, OtherDerived::MaxColsAtCompileTime>
d(c.corner(TopLeft, m_rank, c.cols()));
m_lu.corner(TopLeft, m_rank, m_rank)
.template marked<UpperTriangular>()
.solveTriangularInPlace(d);
.solveTriangularInPlace(c.corner(TopLeft, m_rank, c.cols()));
// Step 4
result->resize(m_lu.cols(), b.cols());
for(int i = 0; i < m_rank; ++i) result->row(m_q.coeff(i)) = d.row(i);
for(int i = 0; i < m_rank; ++i) result->row(m_q.coeff(i)) = c.row(i);
for(int i = m_rank; i < m_lu.cols(); ++i) result->row(m_q.coeff(i)).setZero();
return true;
}