mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-01-18 14:34:17 +08:00
SVD::solve() : port to new API and improvements
This commit is contained in:
parent
6b48e932e9
commit
f975b9bd3e
@ -200,7 +200,7 @@ template<typename MatrixType> class FullPivLU
|
||||
return ei_fullpivlu_image_impl<MatrixType>(*this, originalMatrix.derived());
|
||||
}
|
||||
|
||||
/** This method returns a solution x to the equation Ax=b, where A is the matrix of which
|
||||
/** \return a solution x to the equation Ax=b, where A is the matrix of which
|
||||
* *this is the LU decomposition.
|
||||
*
|
||||
* \param b the right-hand-side of the equation to solve. Can be a vector or a matrix,
|
||||
|
@ -25,6 +25,8 @@
|
||||
#ifndef EIGEN_SVD_H
|
||||
#define EIGEN_SVD_H
|
||||
|
||||
template<typename MatrixType, typename Rhs> struct ei_svd_solve_impl;
|
||||
|
||||
/** \ingroup SVD_Module
|
||||
* \nonstableyet
|
||||
*
|
||||
@ -40,24 +42,24 @@
|
||||
*/
|
||||
template<typename MatrixType> class SVD
|
||||
{
|
||||
private:
|
||||
public:
|
||||
typedef typename MatrixType::Scalar Scalar;
|
||||
typedef typename NumTraits<typename MatrixType::Scalar>::Real RealScalar;
|
||||
|
||||
enum {
|
||||
RowsAtCompileTime = MatrixType::RowsAtCompileTime,
|
||||
ColsAtCompileTime = MatrixType::ColsAtCompileTime,
|
||||
PacketSize = ei_packet_traits<Scalar>::size,
|
||||
AlignmentMask = int(PacketSize)-1,
|
||||
MinSize = EIGEN_ENUM_MIN(MatrixType::RowsAtCompileTime, MatrixType::ColsAtCompileTime)
|
||||
MinSize = EIGEN_ENUM_MIN(RowsAtCompileTime, ColsAtCompileTime)
|
||||
};
|
||||
|
||||
typedef Matrix<Scalar, MatrixType::RowsAtCompileTime, 1> ColVector;
|
||||
typedef Matrix<Scalar, MatrixType::ColsAtCompileTime, 1> RowVector;
|
||||
typedef Matrix<Scalar, RowsAtCompileTime, 1> ColVector;
|
||||
typedef Matrix<Scalar, ColsAtCompileTime, 1> RowVector;
|
||||
|
||||
typedef Matrix<Scalar, MatrixType::RowsAtCompileTime, MatrixType::RowsAtCompileTime> MatrixUType;
|
||||
typedef Matrix<Scalar, MatrixType::ColsAtCompileTime, MatrixType::ColsAtCompileTime> MatrixVType;
|
||||
typedef Matrix<Scalar, MatrixType::ColsAtCompileTime, 1> SingularValuesType;
|
||||
|
||||
public:
|
||||
typedef Matrix<Scalar, RowsAtCompileTime, RowsAtCompileTime> MatrixUType;
|
||||
typedef Matrix<Scalar, ColsAtCompileTime, ColsAtCompileTime> MatrixVType;
|
||||
typedef Matrix<Scalar, ColsAtCompileTime, 1> SingularValuesType;
|
||||
|
||||
/**
|
||||
* \brief Default Constructor.
|
||||
@ -76,8 +78,24 @@ template<typename MatrixType> class SVD
|
||||
compute(matrix);
|
||||
}
|
||||
|
||||
template<typename OtherDerived, typename ResultType>
|
||||
bool solve(const MatrixBase<OtherDerived> &b, ResultType* result) const;
|
||||
/** \returns a solution of \f$ A x = b \f$ using the current SVD decomposition of A.
|
||||
*
|
||||
* \param b the right-hand-side of the equation to solve.
|
||||
*
|
||||
* \note_about_checking_solutions
|
||||
*
|
||||
* \note_about_arbitrary_choice_of_solution
|
||||
* \note_about_using_kernel_to_study_multiple_solutions
|
||||
*
|
||||
* \sa MatrixBase::svd(),
|
||||
*/
|
||||
template<typename Rhs>
|
||||
inline const ei_svd_solve_impl<MatrixType, Rhs>
|
||||
solve(const MatrixBase<Rhs>& b) const
|
||||
{
|
||||
ei_assert(m_isInitialized && "SVD is not initialized.");
|
||||
return ei_svd_solve_impl<MatrixType, Rhs>(*this, b.derived());
|
||||
}
|
||||
|
||||
const MatrixUType& matrixU() const
|
||||
{
|
||||
@ -108,6 +126,18 @@ template<typename MatrixType> class SVD
|
||||
template<typename ScalingType, typename RotationType>
|
||||
void computeScalingRotation(ScalingType *positive, RotationType *unitary) const;
|
||||
|
||||
inline int rows() const
|
||||
{
|
||||
ei_assert(m_isInitialized && "SVD is not initialized.");
|
||||
return m_rows;
|
||||
}
|
||||
|
||||
inline int cols() const
|
||||
{
|
||||
ei_assert(m_isInitialized && "SVD is not initialized.");
|
||||
return m_cols;
|
||||
}
|
||||
|
||||
protected:
|
||||
// Computes (a^2 + b^2)^(1/2) without destructive underflow or overflow.
|
||||
inline static Scalar pythag(Scalar a, Scalar b)
|
||||
@ -133,6 +163,7 @@ template<typename MatrixType> class SVD
|
||||
/** \internal */
|
||||
SingularValuesType m_sigma;
|
||||
bool m_isInitialized;
|
||||
int m_rows, m_cols;
|
||||
};
|
||||
|
||||
/** Computes / recomputes the SVD decomposition A = U S V^* of \a matrix
|
||||
@ -144,8 +175,8 @@ template<typename MatrixType> class SVD
|
||||
template<typename MatrixType>
|
||||
SVD<MatrixType>& SVD<MatrixType>::compute(const MatrixType& matrix)
|
||||
{
|
||||
const int m = matrix.rows();
|
||||
const int n = matrix.cols();
|
||||
const int m = m_rows = matrix.rows();
|
||||
const int n = m_cols = matrix.cols();
|
||||
|
||||
m_matU.resize(m, m);
|
||||
m_matU.setZero();
|
||||
@ -397,40 +428,57 @@ SVD<MatrixType>& SVD<MatrixType>::compute(const MatrixType& matrix)
|
||||
return *this;
|
||||
}
|
||||
|
||||
/** \returns the solution of \f$ A x = b \f$ using the current SVD decomposition of A.
|
||||
* The parts of the solution corresponding to zero singular values are ignored.
|
||||
*
|
||||
* \sa MatrixBase::svd(), LU::solve(), LLT::solve()
|
||||
*/
|
||||
template<typename MatrixType>
|
||||
template<typename OtherDerived, typename ResultType>
|
||||
bool SVD<MatrixType>::solve(const MatrixBase<OtherDerived> &b, ResultType* result) const
|
||||
template<typename MatrixType,typename Rhs>
|
||||
struct ei_traits<ei_svd_solve_impl<MatrixType,Rhs> >
|
||||
{
|
||||
ei_assert(m_isInitialized && "SVD is not initialized.");
|
||||
typedef Matrix<typename Rhs::Scalar,
|
||||
MatrixType::ColsAtCompileTime,
|
||||
Rhs::ColsAtCompileTime,
|
||||
Rhs::PlainMatrixType::Options,
|
||||
MatrixType::MaxColsAtCompileTime,
|
||||
Rhs::MaxColsAtCompileTime> ReturnMatrixType;
|
||||
};
|
||||
|
||||
const int rows = m_matU.rows();
|
||||
ei_assert(b.rows() == rows);
|
||||
template<typename MatrixType, typename Rhs>
|
||||
struct ei_svd_solve_impl : public ReturnByValue<ei_svd_solve_impl<MatrixType, Rhs> >
|
||||
{
|
||||
typedef typename ei_cleantype<typename Rhs::Nested>::type RhsNested;
|
||||
typedef SVD<MatrixType> SVDType;
|
||||
typedef typename MatrixType::RealScalar RealScalar;
|
||||
typedef typename MatrixType::Scalar Scalar;
|
||||
const SVDType& m_svd;
|
||||
const typename Rhs::Nested m_rhs;
|
||||
|
||||
result->resize(m_matV.rows(), b.cols());
|
||||
ei_svd_solve_impl(const SVDType& svd, const Rhs& rhs)
|
||||
: m_svd(svd), m_rhs(rhs)
|
||||
{}
|
||||
|
||||
Scalar maxVal = m_sigma.cwise().abs().maxCoeff();
|
||||
for (int j=0; j<b.cols(); ++j)
|
||||
inline int rows() const { return m_svd.cols(); }
|
||||
inline int cols() const { return m_rhs.cols(); }
|
||||
|
||||
template<typename Dest> void evalTo(Dest& dst) const
|
||||
{
|
||||
Matrix<Scalar,MatrixUType::RowsAtCompileTime,1> aux = m_matU.transpose() * b.col(j);
|
||||
ei_assert(m_rhs.rows() == m_svd.rows());
|
||||
|
||||
for (int i = 0; i <m_matU.cols(); ++i)
|
||||
dst.resize(rows(), cols());
|
||||
|
||||
for (int j=0; j<cols(); ++j)
|
||||
{
|
||||
Scalar si = m_sigma.coeff(i);
|
||||
if (ei_isMuchSmallerThan(ei_abs(si),maxVal))
|
||||
aux.coeffRef(i) = 0;
|
||||
else
|
||||
aux.coeffRef(i) /= si;
|
||||
}
|
||||
Matrix<Scalar,SVDType::RowsAtCompileTime,1> aux = m_svd.matrixU().adjoint() * m_rhs.col(j);
|
||||
|
||||
result->col(j) = m_matV * aux;
|
||||
for (int i = 0; i <m_svd.rows(); ++i)
|
||||
{
|
||||
Scalar si = m_svd.singularValues().coeff(i);
|
||||
if(si == RealScalar(0))
|
||||
aux.coeffRef(i) = Scalar(0);
|
||||
else
|
||||
aux.coeffRef(i) /= si;
|
||||
}
|
||||
|
||||
dst.col(j) = m_svd.matrixV() * aux;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
/** Computes the polar decomposition of the matrix, as a product unitary x positive.
|
||||
*
|
||||
|
@ -59,7 +59,7 @@ template<typename MatrixType> void svd(const MatrixType& m)
|
||||
a += a * a.adjoint() + a1 * a1.adjoint();
|
||||
}
|
||||
SVD<MatrixType> svd(a);
|
||||
svd.solve(b, &x);
|
||||
x = svd.solve(b);
|
||||
VERIFY_IS_APPROX(a * x,b);
|
||||
}
|
||||
|
||||
@ -87,7 +87,7 @@ template<typename MatrixType> void svd_verify_assert()
|
||||
MatrixType tmp;
|
||||
|
||||
SVD<MatrixType> svd;
|
||||
VERIFY_RAISES_ASSERT(svd.solve(tmp, &tmp))
|
||||
VERIFY_RAISES_ASSERT(svd.solve(tmp))
|
||||
VERIFY_RAISES_ASSERT(svd.matrixU())
|
||||
VERIFY_RAISES_ASSERT(svd.singularValues())
|
||||
VERIFY_RAISES_ASSERT(svd.matrixV())
|
||||
|
Loading…
Reference in New Issue
Block a user