add a TridiagonalizationMatrixTReturnType class to make Tridiagonalization::matrixT() more efficient and future proof.

This commit is contained in:
Gael Guennebaud 2010-11-26 15:31:47 +01:00
parent 421b2b5ff7
commit 0d63212257
2 changed files with 62 additions and 26 deletions

View File

@ -26,6 +26,16 @@
#ifndef EIGEN_TRIDIAGONALIZATION_H
#define EIGEN_TRIDIAGONALIZATION_H
template<typename MatrixType> struct TridiagonalizationMatrixTReturnType;
namespace internal {
template<typename MatrixType>
struct traits<TridiagonalizationMatrixTReturnType<MatrixType> >
{
typedef typename MatrixType::PlainObject ReturnType;
};
}
namespace internal {
template<typename MatrixType, typename CoeffVectorType>
void tridiagonalization_inplace(MatrixType& matA, CoeffVectorType& hCoeffs);
@ -85,6 +95,7 @@ template<typename _MatrixType> class Tridiagonalization
typedef Matrix<Scalar, SizeMinusOne, 1, Options & ~RowMajor, MaxSizeMinusOne, 1> CoeffVectorType;
typedef typename internal::plain_col_type<MatrixType, RealScalar>::type DiagonalType;
typedef Matrix<RealScalar, SizeMinusOne, 1, Options & ~RowMajor, MaxSizeMinusOne, 1> SubDiagonalType;
typedef typename internal::remove_all<typename MatrixType::RealReturnType>::type MatrixTypeRealView;
typedef typename internal::conditional<NumTraits<Scalar>::IsComplex,
typename Diagonal<MatrixType,0>::RealReturnType,
@ -244,24 +255,28 @@ template<typename _MatrixType> class Tridiagonalization
return HouseholderSequenceType(m_matrix, m_hCoeffs.conjugate(), false, m_matrix.rows() - 1, 1);
}
/** \brief Constructs the tridiagonal matrix T in the decomposition
/** \brief Returns an expression of the tridiagonal matrix T in the decomposition
*
* \returns the matrix T
* \returns expression object representing the matrix T
*
* \pre Either the constructor Tridiagonalization(const MatrixType&) or
* the member function compute(const MatrixType&) has been called before
* to compute the tridiagonal decomposition of a matrix.
*
* This function copies the matrix T from internal data. The diagonal and
* subdiagonal of the packed matrix as returned by packedMatrix()
* represents the matrix T. It may sometimes be sufficient to directly use
* the packed matrix or the vector expressions returned by diagonal()
* and subDiagonal() instead of creating a new matrix with this function.
* Currently, this function can be used to extract the matrix T from internal
* data and copy it to a dense matrix object. In most cases, it may be
* sufficient to directly use the packed matrix or the vector expressions
* returned by diagonal() and subDiagonal() instead of creating a new
* dense copy matrix with this function.
*
* \sa Tridiagonalization(const MatrixType&) for an example,
* matrixQ(), packedMatrix(), diagonal(), subDiagonal()
*/
MatrixType matrixT() const;
TridiagonalizationMatrixTReturnType<MatrixTypeRealView> matrixT() const
{
eigen_assert(m_isInitialized && "Tridiagonalization is not initialized.");
return TridiagonalizationMatrixTReturnType<MatrixTypeRealView>(m_matrix.real());
}
/** \brief Returns the diagonal of the tridiagonal matrix T in the decomposition.
*
@ -314,24 +329,6 @@ Tridiagonalization<MatrixType>::subDiagonal() const
return Block<MatrixType,SizeMinusOne,SizeMinusOne>(m_matrix, 1, 0, n-1,n-1).diagonal();
}
template<typename MatrixType>
typename Tridiagonalization<MatrixType>::MatrixType
Tridiagonalization<MatrixType>::matrixT() const
{
// FIXME should this function (and other similar ones) rather take a matrix as argument
// and fill it ? (to avoid temporaries)
eigen_assert(m_isInitialized && "Tridiagonalization is not initialized.");
Index n = m_matrix.rows();
MatrixType matT = m_matrix;
matT.topRightCorner(n-1, n-1).diagonal() = subDiagonal().template cast<Scalar>().conjugate();
if (n>2)
{
matT.topRightCorner(n-2, n-2).template triangularView<Upper>().setZero();
matT.bottomLeftCorner(n-2, n-2).template triangularView<Lower>().setZero();
}
return matT;
}
namespace internal {
/** \internal
@ -530,4 +527,38 @@ struct tridiagonalization_inplace_selector<MatrixType,1,IsComplex>
} // end namespace internal
/** \eigenvalues_module \ingroup Eigenvalues_Module
*
*
* \brief Expression type for return value of Tridiagonalization::matrixT()
*
* \tparam MatrixType type of underlying dense matrix
*/
template<typename MatrixType> struct TridiagonalizationMatrixTReturnType
: public ReturnByValue<TridiagonalizationMatrixTReturnType<MatrixType> >
{
typedef typename MatrixType::Index Index;
public:
/** \brief Constructor.
*
* \param[in] mat The underlying dense matrix
*/
TridiagonalizationMatrixTReturnType(const MatrixType& mat) : m_matrix(mat) { }
template <typename ResultType>
inline void evalTo(ResultType& result) const
{
result.setZero();
result.template diagonal<1>() = m_matrix.template diagonal<-1>().conjugate();
result.template diagonal() = m_matrix.template diagonal();
result.template diagonal<-1>() = m_matrix.template diagonal<-1>();
}
Index rows() const { return m_matrix.rows(); }
Index cols() const { return m_matrix.cols(); }
protected:
const typename MatrixType::Nested m_matrix;
};
#endif // EIGEN_TRIDIAGONALIZATION_H

View File

@ -155,6 +155,11 @@ template<typename MatrixType> void selfadjointeigensolver(const MatrixType& m)
VERIFY_RAISES_ASSERT(eiSymmUninitialized.operatorSqrt());
VERIFY_RAISES_ASSERT(eiSymmUninitialized.operatorInverseSqrt());
// test Tridiagonalization's methods
Tridiagonalization<MatrixType> tridiag(symmA);
// FIXME tridiag.matrixQ().adjoint() does not work
VERIFY_IS_APPROX(MatrixType(symmA.template selfadjointView<Lower>()), tridiag.matrixQ() * tridiag.matrixT().eval() * MatrixType(tridiag.matrixQ()).adjoint());
if (rows > 1)
{
// Test matrix with NaN