SVD::solve() : port to new API and improvements

This commit is contained in:
Benoit Jacob 2009-10-30 08:51:33 -04:00
parent 6b48e932e9
commit f975b9bd3e
3 changed files with 89 additions and 41 deletions

View File

@ -200,7 +200,7 @@ template<typename MatrixType> class FullPivLU
return ei_fullpivlu_image_impl<MatrixType>(*this, originalMatrix.derived());
}
/** This method returns a solution x to the equation Ax=b, where A is the matrix of which
/** \return a solution x to the equation Ax=b, where A is the matrix of which
* *this is the LU decomposition.
*
* \param b the right-hand-side of the equation to solve. Can be a vector or a matrix,

View File

@ -25,6 +25,8 @@
#ifndef EIGEN_SVD_H
#define EIGEN_SVD_H
template<typename MatrixType, typename Rhs> struct ei_svd_solve_impl;
/** \ingroup SVD_Module
* \nonstableyet
*
@ -40,24 +42,24 @@
*/
template<typename MatrixType> class SVD
{
private:
public:
typedef typename MatrixType::Scalar Scalar;
typedef typename NumTraits<typename MatrixType::Scalar>::Real RealScalar;
enum {
RowsAtCompileTime = MatrixType::RowsAtCompileTime,
ColsAtCompileTime = MatrixType::ColsAtCompileTime,
PacketSize = ei_packet_traits<Scalar>::size,
AlignmentMask = int(PacketSize)-1,
MinSize = EIGEN_ENUM_MIN(MatrixType::RowsAtCompileTime, MatrixType::ColsAtCompileTime)
MinSize = EIGEN_ENUM_MIN(RowsAtCompileTime, ColsAtCompileTime)
};
typedef Matrix<Scalar, MatrixType::RowsAtCompileTime, 1> ColVector;
typedef Matrix<Scalar, MatrixType::ColsAtCompileTime, 1> RowVector;
typedef Matrix<Scalar, RowsAtCompileTime, 1> ColVector;
typedef Matrix<Scalar, ColsAtCompileTime, 1> RowVector;
typedef Matrix<Scalar, MatrixType::RowsAtCompileTime, MatrixType::RowsAtCompileTime> MatrixUType;
typedef Matrix<Scalar, MatrixType::ColsAtCompileTime, MatrixType::ColsAtCompileTime> MatrixVType;
typedef Matrix<Scalar, MatrixType::ColsAtCompileTime, 1> SingularValuesType;
public:
typedef Matrix<Scalar, RowsAtCompileTime, RowsAtCompileTime> MatrixUType;
typedef Matrix<Scalar, ColsAtCompileTime, ColsAtCompileTime> MatrixVType;
typedef Matrix<Scalar, ColsAtCompileTime, 1> SingularValuesType;
/**
* \brief Default Constructor.
@ -76,8 +78,24 @@ template<typename MatrixType> class SVD
compute(matrix);
}
template<typename OtherDerived, typename ResultType>
bool solve(const MatrixBase<OtherDerived> &b, ResultType* result) const;
/** \returns a solution of \f$ A x = b \f$ using the current SVD decomposition of A.
*
* \param b the right-hand-side of the equation to solve.
*
* \note_about_checking_solutions
*
* \note_about_arbitrary_choice_of_solution
* \note_about_using_kernel_to_study_multiple_solutions
*
* \sa MatrixBase::svd(),
*/
template<typename Rhs>
inline const ei_svd_solve_impl<MatrixType, Rhs>
solve(const MatrixBase<Rhs>& b) const
{
ei_assert(m_isInitialized && "SVD is not initialized.");
return ei_svd_solve_impl<MatrixType, Rhs>(*this, b.derived());
}
const MatrixUType& matrixU() const
{
@ -108,6 +126,18 @@ template<typename MatrixType> class SVD
template<typename ScalingType, typename RotationType>
void computeScalingRotation(ScalingType *positive, RotationType *unitary) const;
inline int rows() const
{
ei_assert(m_isInitialized && "SVD is not initialized.");
return m_rows;
}
inline int cols() const
{
ei_assert(m_isInitialized && "SVD is not initialized.");
return m_cols;
}
protected:
// Computes (a^2 + b^2)^(1/2) without destructive underflow or overflow.
inline static Scalar pythag(Scalar a, Scalar b)
@ -133,6 +163,7 @@ template<typename MatrixType> class SVD
/** \internal */
SingularValuesType m_sigma;
bool m_isInitialized;
int m_rows, m_cols;
};
/** Computes / recomputes the SVD decomposition A = U S V^* of \a matrix
@ -144,8 +175,8 @@ template<typename MatrixType> class SVD
template<typename MatrixType>
SVD<MatrixType>& SVD<MatrixType>::compute(const MatrixType& matrix)
{
const int m = matrix.rows();
const int n = matrix.cols();
const int m = m_rows = matrix.rows();
const int n = m_cols = matrix.cols();
m_matU.resize(m, m);
m_matU.setZero();
@ -397,40 +428,57 @@ SVD<MatrixType>& SVD<MatrixType>::compute(const MatrixType& matrix)
return *this;
}
/** \returns the solution of \f$ A x = b \f$ using the current SVD decomposition of A.
* The parts of the solution corresponding to zero singular values are ignored.
*
* \sa MatrixBase::svd(), LU::solve(), LLT::solve()
*/
template<typename MatrixType>
template<typename OtherDerived, typename ResultType>
bool SVD<MatrixType>::solve(const MatrixBase<OtherDerived> &b, ResultType* result) const
template<typename MatrixType,typename Rhs>
struct ei_traits<ei_svd_solve_impl<MatrixType,Rhs> >
{
ei_assert(m_isInitialized && "SVD is not initialized.");
typedef Matrix<typename Rhs::Scalar,
MatrixType::ColsAtCompileTime,
Rhs::ColsAtCompileTime,
Rhs::PlainMatrixType::Options,
MatrixType::MaxColsAtCompileTime,
Rhs::MaxColsAtCompileTime> ReturnMatrixType;
};
const int rows = m_matU.rows();
ei_assert(b.rows() == rows);
template<typename MatrixType, typename Rhs>
struct ei_svd_solve_impl : public ReturnByValue<ei_svd_solve_impl<MatrixType, Rhs> >
{
typedef typename ei_cleantype<typename Rhs::Nested>::type RhsNested;
typedef SVD<MatrixType> SVDType;
typedef typename MatrixType::RealScalar RealScalar;
typedef typename MatrixType::Scalar Scalar;
const SVDType& m_svd;
const typename Rhs::Nested m_rhs;
result->resize(m_matV.rows(), b.cols());
ei_svd_solve_impl(const SVDType& svd, const Rhs& rhs)
: m_svd(svd), m_rhs(rhs)
{}
Scalar maxVal = m_sigma.cwise().abs().maxCoeff();
for (int j=0; j<b.cols(); ++j)
inline int rows() const { return m_svd.cols(); }
inline int cols() const { return m_rhs.cols(); }
template<typename Dest> void evalTo(Dest& dst) const
{
Matrix<Scalar,MatrixUType::RowsAtCompileTime,1> aux = m_matU.transpose() * b.col(j);
ei_assert(m_rhs.rows() == m_svd.rows());
for (int i = 0; i <m_matU.cols(); ++i)
dst.resize(rows(), cols());
for (int j=0; j<cols(); ++j)
{
Scalar si = m_sigma.coeff(i);
if (ei_isMuchSmallerThan(ei_abs(si),maxVal))
aux.coeffRef(i) = 0;
else
aux.coeffRef(i) /= si;
}
Matrix<Scalar,SVDType::RowsAtCompileTime,1> aux = m_svd.matrixU().adjoint() * m_rhs.col(j);
result->col(j) = m_matV * aux;
for (int i = 0; i <m_svd.rows(); ++i)
{
Scalar si = m_svd.singularValues().coeff(i);
if(si == RealScalar(0))
aux.coeffRef(i) = Scalar(0);
else
aux.coeffRef(i) /= si;
}
dst.col(j) = m_svd.matrixV() * aux;
}
}
return true;
}
};
/** Computes the polar decomposition of the matrix, as a product unitary x positive.
*

View File

@ -59,7 +59,7 @@ template<typename MatrixType> void svd(const MatrixType& m)
a += a * a.adjoint() + a1 * a1.adjoint();
}
SVD<MatrixType> svd(a);
svd.solve(b, &x);
x = svd.solve(b);
VERIFY_IS_APPROX(a * x,b);
}
@ -87,7 +87,7 @@ template<typename MatrixType> void svd_verify_assert()
MatrixType tmp;
SVD<MatrixType> svd;
VERIFY_RAISES_ASSERT(svd.solve(tmp, &tmp))
VERIFY_RAISES_ASSERT(svd.solve(tmp))
VERIFY_RAISES_ASSERT(svd.matrixU())
VERIFY_RAISES_ASSERT(svd.singularValues())
VERIFY_RAISES_ASSERT(svd.matrixV())