Use ReturnByValue to return result of ei_matrix_function(), ...

This commit is contained in:
Jitse Niesen 2010-02-16 16:43:11 +00:00
parent 25019f0836
commit 319bf3130b
6 changed files with 216 additions and 144 deletions

View File

@ -1,7 +1,7 @@
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2009 Jitse Niesen <jitse@maths.leeds.ac.uk>
// Copyright (C) 2009, 2010 Jitse Niesen <jitse@maths.leeds.ac.uk>
//
// Eigen is free software; you can redistribute it and/or
// modify it under the terms of the GNU Lesser General Public

View File

@ -28,66 +28,11 @@
#include "StemFunction.h"
#include "MatrixFunctionAtomic.h"
/** \ingroup MatrixFunctions_Module
*
* \brief Compute a matrix function.
*
* \param[in] M argument of matrix function, should be a square matrix.
* \param[in] f an entire function; \c f(x,n) should compute the n-th derivative of f at x.
* \param[out] result pointer to the matrix in which to store the result, \f$ f(M) \f$.
*
* This function computes \f$ f(A) \f$ and stores the result in the
* matrix pointed to by \p result.
*
* Suppose that \p M is a matrix whose entries have type \c Scalar.
* Then, the second argument, \p f, should be a function with prototype
* \code
* ComplexScalar f(ComplexScalar, int)
* \endcode
* where \c ComplexScalar = \c std::complex<Scalar> if \c Scalar is
* real (e.g., \c float or \c double) and \c ComplexScalar =
* \c Scalar if \c Scalar is complex. The return value of \c f(x,n)
* should be \f$ f^{(n)}(x) \f$, the n-th derivative of f at x.
*
* This routine uses the algorithm described in:
* Philip Davies and Nicholas J. Higham,
* "A Schur-Parlett algorithm for computing matrix functions",
* <em>SIAM J. %Matrix Anal. Applic.</em>, <b>25</b>:464&ndash;485, 2003.
*
* The actual work is done by the MatrixFunction class.
*
* Example: The following program checks that
* \f[ \exp \left[ \begin{array}{ccc}
* 0 & \frac14\pi & 0 \\
* -\frac14\pi & 0 & 0 \\
* 0 & 0 & 0
* \end{array} \right] = \left[ \begin{array}{ccc}
* \frac12\sqrt2 & -\frac12\sqrt2 & 0 \\
* \frac12\sqrt2 & \frac12\sqrt2 & 0 \\
* 0 & 0 & 1
* \end{array} \right]. \f]
* This corresponds to a rotation of \f$ \frac14\pi \f$ radians around
* the z-axis. This is the same example as used in the documentation
* of ei_matrix_exponential().
*
* \include MatrixFunction.cpp
* Output: \verbinclude MatrixFunction.out
*
* Note that the function \c expfn is defined for complex numbers
* \c x, even though the matrix \c A is over the reals. Instead of
* \c expfn, we could also have used StdStemFunctions::exp:
* \code
* ei_matrix_function(A, StdStemFunctions<std::complex<double> >::exp, &B);
* \endcode
*/
template <typename Derived>
EIGEN_STRONG_INLINE void ei_matrix_function(const MatrixBase<Derived>& M,
typename ei_stem_function<typename ei_traits<Derived>::Scalar>::type f,
typename MatrixBase<Derived>::PlainMatrixType* result);
/** \ingroup MatrixFunctions_Module
* \brief Helper class for computing matrix functions.
* \brief Class for computing matrix exponentials.
* \tparam MatrixType type of the argument of the matrix function,
* expected to be an instantiation of the Matrix class template.
*/
template <typename MatrixType, int IsComplex = NumTraits<typename ei_traits<MatrixType>::Scalar>::IsComplex>
class MatrixFunction
@ -99,18 +44,26 @@ class MatrixFunction
public:
/** \brief Constructor. Computes matrix function.
/** \brief Constructor.
*
* \param[in] A argument of matrix function, should be a square matrix.
* \param[in] f an entire function; \c f(x,n) should compute the n-th derivative of f at x.
* \param[out] result pointer to the matrix in which to store the result, \f$ f(A) \f$.
*
* This function computes \f$ f(A) \f$ and stores the result in
* the matrix pointed to by \p result.
*
* See ei_matrix_function() for details.
* The class stores a reference to \p A, so it should not be
* changed (or destroyed) before compute() is called.
*/
MatrixFunction(const MatrixType& A, StemFunction f, MatrixType* result);
MatrixFunction(const MatrixType& A, StemFunction f);
/** \brief Compute the matrix function.
*
* \param[out] result the function \p f applied to \p A, as
* specified in the constructor.
*
* See ei_matrix_function() for details on how this computation
* is implemented.
*/
template <typename ResultType>
void compute(ResultType &result);
};
@ -136,23 +89,36 @@ class MatrixFunction<MatrixType, 0>
public:
/** \brief Constructor. Computes matrix function.
/** \brief Constructor.
*
* \param[in] A argument of matrix function, should be a square matrix.
* \param[in] f an entire function; \c f(x,n) should compute the n-th derivative of f at x.
* \param[out] result pointer to the matrix in which to store the result, \f$ f(A) \f$.
*/
MatrixFunction(const MatrixType& A, StemFunction f) : m_A(A), m_f(f) { }
/** \brief Compute the matrix function.
*
* \param[out] result the function \p f applied to \p A, as
* specified in the constructor.
*
* This function converts the real matrix \c A to a complex matrix,
* uses MatrixFunction<MatrixType,1> and then converts the result back to
* a real matrix.
*/
MatrixFunction(const MatrixType& A, StemFunction f, MatrixType* result)
template <typename ResultType>
void compute(ResultType& result)
{
ComplexMatrix CA = A.template cast<ComplexScalar>();
ComplexMatrix CA = m_A.template cast<ComplexScalar>();
ComplexMatrix Cresult;
MatrixFunction<ComplexMatrix>(CA, f, &Cresult);
*result = Cresult.real();
MatrixFunction<ComplexMatrix> mf(CA, m_f);
mf.compute(Cresult);
result = Cresult.real();
}
private:
const MatrixType& m_A; /**< \brief Reference to argument of matrix function. */
StemFunction *m_f; /**< \brief Stem function for matrix function under consideration */
};
@ -179,17 +145,12 @@ class MatrixFunction<MatrixType, 1>
public:
/** \brief Constructor. Computes matrix function.
*
* \param[in] A argument of matrix function, should be a square matrix.
* \param[in] f an entire function; \c f(x,n) should compute the n-th derivative of f at x.
* \param[out] result pointer to the matrix in which to store the result, \f$ f(A) \f$.
*/
MatrixFunction(const MatrixType& A, StemFunction f, MatrixType* result);
MatrixFunction(const MatrixType& A, StemFunction f);
template <typename ResultType> void compute(ResultType& result);
private:
void computeSchurDecomposition(const MatrixType& A);
void computeSchurDecomposition();
void partitionEigenvalues();
typename ListOfClusters::iterator findCluster(Scalar key);
void computeClusterSize();
@ -202,6 +163,7 @@ class MatrixFunction<MatrixType, 1>
void computeOffDiagonal();
DynMatrixType solveTriangularSylvester(const DynMatrixType& A, const DynMatrixType& B, const DynMatrixType& C);
const MatrixType& m_A; /**< \brief Reference to argument of matrix function. */
StemFunction *m_f; /**< \brief Stem function for matrix function under consideration */
MatrixType m_T; /**< \brief Triangular part of Schur decomposition */
MatrixType m_U; /**< \brief Unitary part of Schur decomposition */
@ -221,11 +183,28 @@ class MatrixFunction<MatrixType, 1>
static const RealScalar separation() { return static_cast<RealScalar>(0.01); }
};
/** \brief Constructor.
*
* \param[in] A argument of matrix function, should be a square matrix.
* \param[in] f an entire function; \c f(x,n) should compute the n-th derivative of f at x.
*/
template <typename MatrixType>
MatrixFunction<MatrixType,1>::MatrixFunction(const MatrixType& A, StemFunction f, MatrixType* result) :
m_f(f)
MatrixFunction<MatrixType,1>::MatrixFunction(const MatrixType& A, StemFunction f) :
m_A(A), m_f(f)
{
computeSchurDecomposition(A);
/* empty body */
}
/** \brief Compute the matrix function.
*
* \param[out] result the function \p f applied to \p A, as
* specified in the constructor.
*/
template <typename MatrixType>
template <typename ResultType>
void MatrixFunction<MatrixType,1>::compute(ResultType& result)
{
computeSchurDecomposition();
partitionEigenvalues();
computeClusterSize();
computeBlockStart();
@ -233,14 +212,14 @@ MatrixFunction<MatrixType,1>::MatrixFunction(const MatrixType& A, StemFunction f
permuteSchur();
computeBlockAtomic();
computeOffDiagonal();
*result = m_U * m_fT * m_U.adjoint();
result = m_U * m_fT * m_U.adjoint();
}
/** \brief Store the Schur decomposition of \p A in #m_T and #m_U */
/** \brief Store the Schur decomposition of #m_A in #m_T and #m_U */
template <typename MatrixType>
void MatrixFunction<MatrixType,1>::computeSchurDecomposition(const MatrixType& A)
void MatrixFunction<MatrixType,1>::computeSchurDecomposition()
{
const ComplexSchur<MatrixType> schurOfA(A);
const ComplexSchur<MatrixType> schurOfA(m_A);
m_T = schurOfA.matrixT();
m_U = schurOfA.matrixU();
}
@ -498,23 +477,129 @@ typename MatrixFunction<MatrixType,1>::DynMatrixType MatrixFunction<MatrixType,1
return X;
}
/** \ingroup MatrixFunctions_Module
*
* \brief Proxy for the matrix function of some matrix (expression).
*
* \tparam Derived Type of the argument to the matrix function.
*
* This class holds the argument to the matrix function until it is
* assigned or evaluated for some other reason (so the argument
* should not be changed in the meantime). It is the return type of
* ei_matrix_function() and related functions and most of the time
* this is the only way it is used.
*/
template<typename Derived> class MatrixFunctionReturnValue
: public ReturnByValue<MatrixFunctionReturnValue<Derived> >
{
private:
typedef typename ei_traits<Derived>::Scalar Scalar;
typedef typename ei_stem_function<Scalar>::type StemFunction;
public:
/** \brief Constructor.
*
* \param[in] A %Matrix (expression) forming the argument of the
* matrix function.
* \param[in] f Stem function for matrix function under consideration.
*/
MatrixFunctionReturnValue(const Derived& A, StemFunction f) : m_A(A), m_f(f) { }
/** \brief Compute the matrix function.
*
* \param[out] result \p f applied to \p A, where \p f and \p A
* are as in the constructor.
*/
template <typename ResultType>
inline void evalTo(ResultType& result) const
{
const typename ei_eval<Derived>::type Aevaluated = m_A.eval();
MatrixFunction<typename Derived::PlainMatrixType> mf(Aevaluated, m_f);
mf.compute(result);
}
int rows() const { return m_A.rows(); }
int cols() const { return m_A.cols(); }
private:
const Derived& m_A;
StemFunction *m_f;
};
template<typename Derived>
struct ei_traits<MatrixFunctionReturnValue<Derived> >
{
typedef typename Derived::PlainMatrixType ReturnMatrixType;
};
/** \ingroup MatrixFunctions_Module
*
* \brief Compute a matrix function.
*
* \param[in] M argument of matrix function, should be a square matrix.
* \param[in] f an entire function; \c f(x,n) should compute the n-th
* derivative of f at x.
* \returns expression representing \p f applied to \p M.
*
* Suppose that \p M is a matrix whose entries have type \c Scalar.
* Then, the second argument, \p f, should be a function with prototype
* \code
* ComplexScalar f(ComplexScalar, int)
* \endcode
* where \c ComplexScalar = \c std::complex<Scalar> if \c Scalar is
* real (e.g., \c float or \c double) and \c ComplexScalar =
* \c Scalar if \c Scalar is complex. The return value of \c f(x,n)
* should be \f$ f^{(n)}(x) \f$, the n-th derivative of f at x.
*
* This routine uses the algorithm described in:
* Philip Davies and Nicholas J. Higham,
* "A Schur-Parlett algorithm for computing matrix functions",
* <em>SIAM J. %Matrix Anal. Applic.</em>, <b>25</b>:464&ndash;485, 2003.
*
* The actual work is done by the MatrixFunction class.
*
* Example: The following program checks that
* \f[ \exp \left[ \begin{array}{ccc}
* 0 & \frac14\pi & 0 \\
* -\frac14\pi & 0 & 0 \\
* 0 & 0 & 0
* \end{array} \right] = \left[ \begin{array}{ccc}
* \frac12\sqrt2 & -\frac12\sqrt2 & 0 \\
* \frac12\sqrt2 & \frac12\sqrt2 & 0 \\
* 0 & 0 & 1
* \end{array} \right]. \f]
* This corresponds to a rotation of \f$ \frac14\pi \f$ radians around
* the z-axis. This is the same example as used in the documentation
* of ei_matrix_exponential().
*
* \include MatrixFunction.cpp
* Output: \verbinclude MatrixFunction.out
*
* Note that the function \c expfn is defined for complex numbers
* \c x, even though the matrix \c A is over the reals. Instead of
* \c expfn, we could also have used StdStemFunctions::exp:
* \code
* ei_matrix_function(A, StdStemFunctions<std::complex<double> >::exp, &B);
* \endcode
*/
template <typename Derived>
EIGEN_STRONG_INLINE void ei_matrix_function(const MatrixBase<Derived>& M,
typename ei_stem_function<typename ei_traits<Derived>::Scalar>::type f,
typename MatrixBase<Derived>::PlainMatrixType* result)
MatrixFunctionReturnValue<Derived>
ei_matrix_function(const MatrixBase<Derived> &M,
typename ei_stem_function<typename ei_traits<Derived>::Scalar>::type f)
{
ei_assert(M.rows() == M.cols());
typedef typename MatrixBase<Derived>::PlainMatrixType PlainMatrixType;
MatrixFunction<PlainMatrixType>(M, f, result);
return MatrixFunctionReturnValue<Derived>(M.derived(), f);
}
/** \ingroup MatrixFunctions_Module
*
* \brief Compute the matrix sine.
*
* \param[in] M a square matrix.
* \param[out] result pointer to matrix in which to store the result, \f$ \sin(M) \f$
* \param[in] M a square matrix.
* \returns expression representing \f$ \sin(M) \f$.
*
* This function calls ei_matrix_function() with StdStemFunctions::sin().
*
@ -522,44 +607,42 @@ EIGEN_STRONG_INLINE void ei_matrix_function(const MatrixBase<Derived>& M,
* Output: \verbinclude MatrixSine.out
*/
template <typename Derived>
EIGEN_STRONG_INLINE void ei_matrix_sin(const MatrixBase<Derived>& M,
typename MatrixBase<Derived>::PlainMatrixType* result)
MatrixFunctionReturnValue<Derived>
ei_matrix_sin(const MatrixBase<Derived>& M)
{
ei_assert(M.rows() == M.cols());
typedef typename MatrixBase<Derived>::PlainMatrixType PlainMatrixType;
typedef typename ei_traits<PlainMatrixType>::Scalar Scalar;
typedef typename ei_traits<Derived>::Scalar Scalar;
typedef typename ei_stem_function<Scalar>::ComplexScalar ComplexScalar;
MatrixFunction<PlainMatrixType>(M, StdStemFunctions<ComplexScalar>::sin, result);
return MatrixFunctionReturnValue<Derived>(M.derived(), StdStemFunctions<ComplexScalar>::sin);
}
/** \ingroup MatrixFunctions_Module
*
* \brief Compute the matrix cosine.
*
* \param[in] M a square matrix.
* \param[out] result pointer to matrix in which to store the result, \f$ \cos(M) \f$
* \param[in] M a square matrix.
* \returns expression representing \f$ \cos(M) \f$.
*
* This function calls ei_matrix_function() with StdStemFunctions::cos().
*
* \sa ei_matrix_sin() for an example.
*/
template <typename Derived>
EIGEN_STRONG_INLINE void ei_matrix_cos(const MatrixBase<Derived>& M,
typename MatrixBase<Derived>::PlainMatrixType* result)
MatrixFunctionReturnValue<Derived>
ei_matrix_cos(const MatrixBase<Derived>& M)
{
ei_assert(M.rows() == M.cols());
typedef typename MatrixBase<Derived>::PlainMatrixType PlainMatrixType;
typedef typename ei_traits<PlainMatrixType>::Scalar Scalar;
typedef typename ei_traits<Derived>::Scalar Scalar;
typedef typename ei_stem_function<Scalar>::ComplexScalar ComplexScalar;
MatrixFunction<PlainMatrixType>(M, StdStemFunctions<ComplexScalar>::cos, result);
return MatrixFunctionReturnValue<Derived>(M.derived(), StdStemFunctions<ComplexScalar>::cos);
}
/** \ingroup MatrixFunctions_Module
*
* \brief Compute the matrix hyperbolic sine.
*
* \param[in] M a square matrix.
* \param[out] result pointer to matrix in which to store the result, \f$ \sinh(M) \f$
* \param[in] M a square matrix.
* \returns expression representing \f$ \sinh(M) \f$
*
* This function calls ei_matrix_function() with StdStemFunctions::sinh().
*
@ -567,36 +650,34 @@ EIGEN_STRONG_INLINE void ei_matrix_cos(const MatrixBase<Derived>& M,
* Output: \verbinclude MatrixSinh.out
*/
template <typename Derived>
EIGEN_STRONG_INLINE void ei_matrix_sinh(const MatrixBase<Derived>& M,
typename MatrixBase<Derived>::PlainMatrixType* result)
MatrixFunctionReturnValue<Derived>
ei_matrix_sinh(const MatrixBase<Derived>& M)
{
ei_assert(M.rows() == M.cols());
typedef typename MatrixBase<Derived>::PlainMatrixType PlainMatrixType;
typedef typename ei_traits<PlainMatrixType>::Scalar Scalar;
typedef typename ei_traits<Derived>::Scalar Scalar;
typedef typename ei_stem_function<Scalar>::ComplexScalar ComplexScalar;
MatrixFunction<PlainMatrixType>(M, StdStemFunctions<ComplexScalar>::sinh, result);
return MatrixFunctionReturnValue<Derived>(M.derived(), StdStemFunctions<ComplexScalar>::sinh);
}
/** \ingroup MatrixFunctions_Module
*
* \brief Compute the matrix hyberpolic cosine.
* \brief Compute the matrix hyberbolic cosine.
*
* \param[in] M a square matrix.
* \param[out] result pointer to matrix in which to store the result, \f$ \cosh(M) \f$
* \param[in] M a square matrix.
* \returns expression representing \f$ \cosh(M) \f$
*
* This function calls ei_matrix_function() with StdStemFunctions::cosh().
*
* \sa ei_matrix_sinh() for an example.
*/
template <typename Derived>
EIGEN_STRONG_INLINE void ei_matrix_cosh(const MatrixBase<Derived>& M,
typename MatrixBase<Derived>::PlainMatrixType* result)
MatrixFunctionReturnValue<Derived>
ei_matrix_cosh(const MatrixBase<Derived>& M)
{
ei_assert(M.rows() == M.cols());
typedef typename MatrixBase<Derived>::PlainMatrixType PlainMatrixType;
typedef typename ei_traits<PlainMatrixType>::Scalar Scalar;
typedef typename ei_traits<Derived>::Scalar Scalar;
typedef typename ei_stem_function<Scalar>::ComplexScalar ComplexScalar;
MatrixFunction<PlainMatrixType>(M, StdStemFunctions<ComplexScalar>::cosh, result);
return MatrixFunctionReturnValue<Derived>(M.derived(), StdStemFunctions<ComplexScalar>::cosh);
}
#endif // EIGEN_MATRIX_FUNCTION

View File

@ -15,9 +15,8 @@ int main()
A << 0, -pi/4, 0,
pi/4, 0, 0,
0, 0, 0;
std::cout << "The matrix A is:\n" << A << "\n\n";
MatrixXd B;
ei_matrix_function(A, expfn, &B);
std::cout << "The matrix exponential of A is:\n" << B << "\n\n";
std::cout << "The matrix A is:\n" << A << "\n\n";
std::cout << "The matrix exponential of A is:\n"
<< ei_matrix_function(A, expfn) << "\n\n";
}

View File

@ -7,12 +7,10 @@ int main()
MatrixXd A = MatrixXd::Random(3,3);
std::cout << "A = \n" << A << "\n\n";
MatrixXd sinA;
ei_matrix_sin(A, &sinA);
MatrixXd sinA = ei_matrix_sin(A);
std::cout << "sin(A) = \n" << sinA << "\n\n";
MatrixXd cosA;
ei_matrix_cos(A, &cosA);
MatrixXd cosA = ei_matrix_cos(A);
std::cout << "cos(A) = \n" << cosA << "\n\n";
// The matrix functions satisfy sin^2(A) + cos^2(A) = I,

View File

@ -7,12 +7,10 @@ int main()
MatrixXf A = MatrixXf::Random(3,3);
std::cout << "A = \n" << A << "\n\n";
MatrixXf sinhA;
ei_matrix_sinh(A, &sinhA);
MatrixXf sinhA = ei_matrix_sinh(A);
std::cout << "sinh(A) = \n" << sinhA << "\n\n";
MatrixXf coshA;
ei_matrix_cosh(A, &coshA);
MatrixXf coshA = ei_matrix_cosh(A);
std::cout << "cosh(A) = \n" << coshA << "\n\n";
// The matrix functions satisfy cosh^2(A) - sinh^2(A) = I,

View File

@ -100,9 +100,8 @@ void testMatrixExponential(const MatrixType& A)
typedef std::complex<RealScalar> ComplexScalar;
for (int i = 0; i < g_repeat; i++) {
MatrixType expA;
ei_matrix_function(A, StdStemFunctions<ComplexScalar>::exp, &expA);
VERIFY_IS_APPROX(ei_matrix_exponential(A), expA);
VERIFY_IS_APPROX(ei_matrix_exponential(A),
ei_matrix_function(A, StdStemFunctions<ComplexScalar>::exp));
}
}
@ -110,9 +109,8 @@ template<typename MatrixType>
void testHyperbolicFunctions(const MatrixType& A)
{
for (int i = 0; i < g_repeat; i++) {
MatrixType sinhA, coshA;
ei_matrix_sinh(A, &sinhA);
ei_matrix_cosh(A, &coshA);
MatrixType sinhA = ei_matrix_sinh(A);
MatrixType coshA = ei_matrix_cosh(A);
MatrixType expA = ei_matrix_exponential(A);
VERIFY_IS_APPROX(sinhA, (expA - expA.inverse())/2);
VERIFY_IS_APPROX(coshA, (expA + expA.inverse())/2);
@ -137,13 +135,11 @@ void testGonioFunctions(const MatrixType& A)
ComplexMatrix exp_iA = ei_matrix_exponential(imagUnit * Ac);
MatrixType sinA;
ei_matrix_sin(A, &sinA);
MatrixType sinA = ei_matrix_sin(A);
ComplexMatrix sinAc = sinA.template cast<ComplexScalar>();
VERIFY_IS_APPROX(sinAc, (exp_iA - exp_iA.inverse()) / (two*imagUnit));
MatrixType cosA;
ei_matrix_cos(A, &cosA);
MatrixType cosA = ei_matrix_cos(A);
ComplexMatrix cosAc = cosA.template cast<ComplexScalar>();
VERIFY_IS_APPROX(cosAc, (exp_iA + exp_iA.inverse()) / 2);
}