diff --git a/unsupported/Eigen/src/MatrixFunctions/MatrixExponential.h b/unsupported/Eigen/src/MatrixFunctions/MatrixExponential.h index 147e21bc1..39c23cdc5 100644 --- a/unsupported/Eigen/src/MatrixFunctions/MatrixExponential.h +++ b/unsupported/Eigen/src/MatrixFunctions/MatrixExponential.h @@ -1,7 +1,7 @@ // This file is part of Eigen, a lightweight C++ template library // for linear algebra. // -// Copyright (C) 2009 Jitse Niesen +// Copyright (C) 2009, 2010 Jitse Niesen // // Eigen is free software; you can redistribute it and/or // modify it under the terms of the GNU Lesser General Public diff --git a/unsupported/Eigen/src/MatrixFunctions/MatrixFunction.h b/unsupported/Eigen/src/MatrixFunctions/MatrixFunction.h index fbc55507f..d63bcbce9 100644 --- a/unsupported/Eigen/src/MatrixFunctions/MatrixFunction.h +++ b/unsupported/Eigen/src/MatrixFunctions/MatrixFunction.h @@ -28,66 +28,11 @@ #include "StemFunction.h" #include "MatrixFunctionAtomic.h" + /** \ingroup MatrixFunctions_Module - * - * \brief Compute a matrix function. - * - * \param[in] M argument of matrix function, should be a square matrix. - * \param[in] f an entire function; \c f(x,n) should compute the n-th derivative of f at x. - * \param[out] result pointer to the matrix in which to store the result, \f$ f(M) \f$. - * - * This function computes \f$ f(A) \f$ and stores the result in the - * matrix pointed to by \p result. - * - * Suppose that \p M is a matrix whose entries have type \c Scalar. - * Then, the second argument, \p f, should be a function with prototype - * \code - * ComplexScalar f(ComplexScalar, int) - * \endcode - * where \c ComplexScalar = \c std::complex if \c Scalar is - * real (e.g., \c float or \c double) and \c ComplexScalar = - * \c Scalar if \c Scalar is complex. The return value of \c f(x,n) - * should be \f$ f^{(n)}(x) \f$, the n-th derivative of f at x. - * - * This routine uses the algorithm described in: - * Philip Davies and Nicholas J. Higham, - * "A Schur-Parlett algorithm for computing matrix functions", - * SIAM J. %Matrix Anal. Applic., 25:464–485, 2003. - * - * The actual work is done by the MatrixFunction class. - * - * Example: The following program checks that - * \f[ \exp \left[ \begin{array}{ccc} - * 0 & \frac14\pi & 0 \\ - * -\frac14\pi & 0 & 0 \\ - * 0 & 0 & 0 - * \end{array} \right] = \left[ \begin{array}{ccc} - * \frac12\sqrt2 & -\frac12\sqrt2 & 0 \\ - * \frac12\sqrt2 & \frac12\sqrt2 & 0 \\ - * 0 & 0 & 1 - * \end{array} \right]. \f] - * This corresponds to a rotation of \f$ \frac14\pi \f$ radians around - * the z-axis. This is the same example as used in the documentation - * of ei_matrix_exponential(). - * - * \include MatrixFunction.cpp - * Output: \verbinclude MatrixFunction.out - * - * Note that the function \c expfn is defined for complex numbers - * \c x, even though the matrix \c A is over the reals. Instead of - * \c expfn, we could also have used StdStemFunctions::exp: - * \code - * ei_matrix_function(A, StdStemFunctions >::exp, &B); - * \endcode - */ -template -EIGEN_STRONG_INLINE void ei_matrix_function(const MatrixBase& M, - typename ei_stem_function::Scalar>::type f, - typename MatrixBase::PlainMatrixType* result); - - -/** \ingroup MatrixFunctions_Module - * \brief Helper class for computing matrix functions. + * \brief Class for computing matrix exponentials. + * \tparam MatrixType type of the argument of the matrix function, + * expected to be an instantiation of the Matrix class template. */ template ::Scalar>::IsComplex> class MatrixFunction @@ -99,18 +44,26 @@ class MatrixFunction public: - /** \brief Constructor. Computes matrix function. + /** \brief Constructor. * * \param[in] A argument of matrix function, should be a square matrix. * \param[in] f an entire function; \c f(x,n) should compute the n-th derivative of f at x. - * \param[out] result pointer to the matrix in which to store the result, \f$ f(A) \f$. * - * This function computes \f$ f(A) \f$ and stores the result in - * the matrix pointed to by \p result. - * - * See ei_matrix_function() for details. + * The class stores a reference to \p A, so it should not be + * changed (or destroyed) before compute() is called. */ - MatrixFunction(const MatrixType& A, StemFunction f, MatrixType* result); + MatrixFunction(const MatrixType& A, StemFunction f); + + /** \brief Compute the matrix function. + * + * \param[out] result the function \p f applied to \p A, as + * specified in the constructor. + * + * See ei_matrix_function() for details on how this computation + * is implemented. + */ + template + void compute(ResultType &result); }; @@ -136,23 +89,36 @@ class MatrixFunction public: - /** \brief Constructor. Computes matrix function. + /** \brief Constructor. * * \param[in] A argument of matrix function, should be a square matrix. * \param[in] f an entire function; \c f(x,n) should compute the n-th derivative of f at x. - * \param[out] result pointer to the matrix in which to store the result, \f$ f(A) \f$. + */ + MatrixFunction(const MatrixType& A, StemFunction f) : m_A(A), m_f(f) { } + + /** \brief Compute the matrix function. + * + * \param[out] result the function \p f applied to \p A, as + * specified in the constructor. * * This function converts the real matrix \c A to a complex matrix, * uses MatrixFunction and then converts the result back to * a real matrix. */ - MatrixFunction(const MatrixType& A, StemFunction f, MatrixType* result) + template + void compute(ResultType& result) { - ComplexMatrix CA = A.template cast(); + ComplexMatrix CA = m_A.template cast(); ComplexMatrix Cresult; - MatrixFunction(CA, f, &Cresult); - *result = Cresult.real(); + MatrixFunction mf(CA, m_f); + mf.compute(Cresult); + result = Cresult.real(); } + + private: + + const MatrixType& m_A; /**< \brief Reference to argument of matrix function. */ + StemFunction *m_f; /**< \brief Stem function for matrix function under consideration */ }; @@ -179,17 +145,12 @@ class MatrixFunction public: - /** \brief Constructor. Computes matrix function. - * - * \param[in] A argument of matrix function, should be a square matrix. - * \param[in] f an entire function; \c f(x,n) should compute the n-th derivative of f at x. - * \param[out] result pointer to the matrix in which to store the result, \f$ f(A) \f$. - */ - MatrixFunction(const MatrixType& A, StemFunction f, MatrixType* result); + MatrixFunction(const MatrixType& A, StemFunction f); + template void compute(ResultType& result); private: - void computeSchurDecomposition(const MatrixType& A); + void computeSchurDecomposition(); void partitionEigenvalues(); typename ListOfClusters::iterator findCluster(Scalar key); void computeClusterSize(); @@ -202,6 +163,7 @@ class MatrixFunction void computeOffDiagonal(); DynMatrixType solveTriangularSylvester(const DynMatrixType& A, const DynMatrixType& B, const DynMatrixType& C); + const MatrixType& m_A; /**< \brief Reference to argument of matrix function. */ StemFunction *m_f; /**< \brief Stem function for matrix function under consideration */ MatrixType m_T; /**< \brief Triangular part of Schur decomposition */ MatrixType m_U; /**< \brief Unitary part of Schur decomposition */ @@ -221,11 +183,28 @@ class MatrixFunction static const RealScalar separation() { return static_cast(0.01); } }; +/** \brief Constructor. + * + * \param[in] A argument of matrix function, should be a square matrix. + * \param[in] f an entire function; \c f(x,n) should compute the n-th derivative of f at x. + */ template -MatrixFunction::MatrixFunction(const MatrixType& A, StemFunction f, MatrixType* result) : - m_f(f) +MatrixFunction::MatrixFunction(const MatrixType& A, StemFunction f) : + m_A(A), m_f(f) { - computeSchurDecomposition(A); + /* empty body */ +} + +/** \brief Compute the matrix function. + * + * \param[out] result the function \p f applied to \p A, as + * specified in the constructor. + */ +template +template +void MatrixFunction::compute(ResultType& result) +{ + computeSchurDecomposition(); partitionEigenvalues(); computeClusterSize(); computeBlockStart(); @@ -233,14 +212,14 @@ MatrixFunction::MatrixFunction(const MatrixType& A, StemFunction f permuteSchur(); computeBlockAtomic(); computeOffDiagonal(); - *result = m_U * m_fT * m_U.adjoint(); + result = m_U * m_fT * m_U.adjoint(); } -/** \brief Store the Schur decomposition of \p A in #m_T and #m_U */ +/** \brief Store the Schur decomposition of #m_A in #m_T and #m_U */ template -void MatrixFunction::computeSchurDecomposition(const MatrixType& A) +void MatrixFunction::computeSchurDecomposition() { - const ComplexSchur schurOfA(A); + const ComplexSchur schurOfA(m_A); m_T = schurOfA.matrixT(); m_U = schurOfA.matrixU(); } @@ -498,23 +477,129 @@ typename MatrixFunction::DynMatrixType MatrixFunction class MatrixFunctionReturnValue +: public ReturnByValue > +{ + private: + typedef typename ei_traits::Scalar Scalar; + typedef typename ei_stem_function::type StemFunction; + + public: + + /** \brief Constructor. + * + * \param[in] A %Matrix (expression) forming the argument of the + * matrix function. + * \param[in] f Stem function for matrix function under consideration. + */ + MatrixFunctionReturnValue(const Derived& A, StemFunction f) : m_A(A), m_f(f) { } + + /** \brief Compute the matrix function. + * + * \param[out] result \p f applied to \p A, where \p f and \p A + * are as in the constructor. + */ + template + inline void evalTo(ResultType& result) const + { + const typename ei_eval::type Aevaluated = m_A.eval(); + MatrixFunction mf(Aevaluated, m_f); + mf.compute(result); + } + + int rows() const { return m_A.rows(); } + int cols() const { return m_A.cols(); } + + private: + const Derived& m_A; + StemFunction *m_f; +}; + +template +struct ei_traits > +{ + typedef typename Derived::PlainMatrixType ReturnMatrixType; +}; + + +/** \ingroup MatrixFunctions_Module + * + * \brief Compute a matrix function. + * + * \param[in] M argument of matrix function, should be a square matrix. + * \param[in] f an entire function; \c f(x,n) should compute the n-th + * derivative of f at x. + * \returns expression representing \p f applied to \p M. + * + * Suppose that \p M is a matrix whose entries have type \c Scalar. + * Then, the second argument, \p f, should be a function with prototype + * \code + * ComplexScalar f(ComplexScalar, int) + * \endcode + * where \c ComplexScalar = \c std::complex if \c Scalar is + * real (e.g., \c float or \c double) and \c ComplexScalar = + * \c Scalar if \c Scalar is complex. The return value of \c f(x,n) + * should be \f$ f^{(n)}(x) \f$, the n-th derivative of f at x. + * + * This routine uses the algorithm described in: + * Philip Davies and Nicholas J. Higham, + * "A Schur-Parlett algorithm for computing matrix functions", + * SIAM J. %Matrix Anal. Applic., 25:464–485, 2003. + * + * The actual work is done by the MatrixFunction class. + * + * Example: The following program checks that + * \f[ \exp \left[ \begin{array}{ccc} + * 0 & \frac14\pi & 0 \\ + * -\frac14\pi & 0 & 0 \\ + * 0 & 0 & 0 + * \end{array} \right] = \left[ \begin{array}{ccc} + * \frac12\sqrt2 & -\frac12\sqrt2 & 0 \\ + * \frac12\sqrt2 & \frac12\sqrt2 & 0 \\ + * 0 & 0 & 1 + * \end{array} \right]. \f] + * This corresponds to a rotation of \f$ \frac14\pi \f$ radians around + * the z-axis. This is the same example as used in the documentation + * of ei_matrix_exponential(). + * + * \include MatrixFunction.cpp + * Output: \verbinclude MatrixFunction.out + * + * Note that the function \c expfn is defined for complex numbers + * \c x, even though the matrix \c A is over the reals. Instead of + * \c expfn, we could also have used StdStemFunctions::exp: + * \code + * ei_matrix_function(A, StdStemFunctions >::exp, &B); + * \endcode + */ template -EIGEN_STRONG_INLINE void ei_matrix_function(const MatrixBase& M, - typename ei_stem_function::Scalar>::type f, - typename MatrixBase::PlainMatrixType* result) +MatrixFunctionReturnValue +ei_matrix_function(const MatrixBase &M, + typename ei_stem_function::Scalar>::type f) { ei_assert(M.rows() == M.cols()); - typedef typename MatrixBase::PlainMatrixType PlainMatrixType; - MatrixFunction(M, f, result); + return MatrixFunctionReturnValue(M.derived(), f); } /** \ingroup MatrixFunctions_Module * * \brief Compute the matrix sine. * - * \param[in] M a square matrix. - * \param[out] result pointer to matrix in which to store the result, \f$ \sin(M) \f$ + * \param[in] M a square matrix. + * \returns expression representing \f$ \sin(M) \f$. * * This function calls ei_matrix_function() with StdStemFunctions::sin(). * @@ -522,44 +607,42 @@ EIGEN_STRONG_INLINE void ei_matrix_function(const MatrixBase& M, * Output: \verbinclude MatrixSine.out */ template -EIGEN_STRONG_INLINE void ei_matrix_sin(const MatrixBase& M, - typename MatrixBase::PlainMatrixType* result) +MatrixFunctionReturnValue +ei_matrix_sin(const MatrixBase& M) { ei_assert(M.rows() == M.cols()); - typedef typename MatrixBase::PlainMatrixType PlainMatrixType; - typedef typename ei_traits::Scalar Scalar; + typedef typename ei_traits::Scalar Scalar; typedef typename ei_stem_function::ComplexScalar ComplexScalar; - MatrixFunction(M, StdStemFunctions::sin, result); + return MatrixFunctionReturnValue(M.derived(), StdStemFunctions::sin); } /** \ingroup MatrixFunctions_Module * * \brief Compute the matrix cosine. * - * \param[in] M a square matrix. - * \param[out] result pointer to matrix in which to store the result, \f$ \cos(M) \f$ + * \param[in] M a square matrix. + * \returns expression representing \f$ \cos(M) \f$. * * This function calls ei_matrix_function() with StdStemFunctions::cos(). * * \sa ei_matrix_sin() for an example. */ template -EIGEN_STRONG_INLINE void ei_matrix_cos(const MatrixBase& M, - typename MatrixBase::PlainMatrixType* result) +MatrixFunctionReturnValue +ei_matrix_cos(const MatrixBase& M) { ei_assert(M.rows() == M.cols()); - typedef typename MatrixBase::PlainMatrixType PlainMatrixType; - typedef typename ei_traits::Scalar Scalar; + typedef typename ei_traits::Scalar Scalar; typedef typename ei_stem_function::ComplexScalar ComplexScalar; - MatrixFunction(M, StdStemFunctions::cos, result); + return MatrixFunctionReturnValue(M.derived(), StdStemFunctions::cos); } /** \ingroup MatrixFunctions_Module * * \brief Compute the matrix hyperbolic sine. * - * \param[in] M a square matrix. - * \param[out] result pointer to matrix in which to store the result, \f$ \sinh(M) \f$ + * \param[in] M a square matrix. + * \returns expression representing \f$ \sinh(M) \f$ * * This function calls ei_matrix_function() with StdStemFunctions::sinh(). * @@ -567,36 +650,34 @@ EIGEN_STRONG_INLINE void ei_matrix_cos(const MatrixBase& M, * Output: \verbinclude MatrixSinh.out */ template -EIGEN_STRONG_INLINE void ei_matrix_sinh(const MatrixBase& M, - typename MatrixBase::PlainMatrixType* result) +MatrixFunctionReturnValue +ei_matrix_sinh(const MatrixBase& M) { ei_assert(M.rows() == M.cols()); - typedef typename MatrixBase::PlainMatrixType PlainMatrixType; - typedef typename ei_traits::Scalar Scalar; + typedef typename ei_traits::Scalar Scalar; typedef typename ei_stem_function::ComplexScalar ComplexScalar; - MatrixFunction(M, StdStemFunctions::sinh, result); + return MatrixFunctionReturnValue(M.derived(), StdStemFunctions::sinh); } /** \ingroup MatrixFunctions_Module * - * \brief Compute the matrix hyberpolic cosine. + * \brief Compute the matrix hyberbolic cosine. * - * \param[in] M a square matrix. - * \param[out] result pointer to matrix in which to store the result, \f$ \cosh(M) \f$ + * \param[in] M a square matrix. + * \returns expression representing \f$ \cosh(M) \f$ * * This function calls ei_matrix_function() with StdStemFunctions::cosh(). * * \sa ei_matrix_sinh() for an example. */ template -EIGEN_STRONG_INLINE void ei_matrix_cosh(const MatrixBase& M, - typename MatrixBase::PlainMatrixType* result) +MatrixFunctionReturnValue +ei_matrix_cosh(const MatrixBase& M) { ei_assert(M.rows() == M.cols()); - typedef typename MatrixBase::PlainMatrixType PlainMatrixType; - typedef typename ei_traits::Scalar Scalar; + typedef typename ei_traits::Scalar Scalar; typedef typename ei_stem_function::ComplexScalar ComplexScalar; - MatrixFunction(M, StdStemFunctions::cosh, result); + return MatrixFunctionReturnValue(M.derived(), StdStemFunctions::cosh); } #endif // EIGEN_MATRIX_FUNCTION diff --git a/unsupported/doc/examples/MatrixFunction.cpp b/unsupported/doc/examples/MatrixFunction.cpp index c11cb821b..075fe7361 100644 --- a/unsupported/doc/examples/MatrixFunction.cpp +++ b/unsupported/doc/examples/MatrixFunction.cpp @@ -15,9 +15,8 @@ int main() A << 0, -pi/4, 0, pi/4, 0, 0, 0, 0, 0; - std::cout << "The matrix A is:\n" << A << "\n\n"; - MatrixXd B; - ei_matrix_function(A, expfn, &B); - std::cout << "The matrix exponential of A is:\n" << B << "\n\n"; + std::cout << "The matrix A is:\n" << A << "\n\n"; + std::cout << "The matrix exponential of A is:\n" + << ei_matrix_function(A, expfn) << "\n\n"; } diff --git a/unsupported/doc/examples/MatrixSine.cpp b/unsupported/doc/examples/MatrixSine.cpp index f8780ac92..2bbf99bbb 100644 --- a/unsupported/doc/examples/MatrixSine.cpp +++ b/unsupported/doc/examples/MatrixSine.cpp @@ -7,12 +7,10 @@ int main() MatrixXd A = MatrixXd::Random(3,3); std::cout << "A = \n" << A << "\n\n"; - MatrixXd sinA; - ei_matrix_sin(A, &sinA); + MatrixXd sinA = ei_matrix_sin(A); std::cout << "sin(A) = \n" << sinA << "\n\n"; - MatrixXd cosA; - ei_matrix_cos(A, &cosA); + MatrixXd cosA = ei_matrix_cos(A); std::cout << "cos(A) = \n" << cosA << "\n\n"; // The matrix functions satisfy sin^2(A) + cos^2(A) = I, diff --git a/unsupported/doc/examples/MatrixSinh.cpp b/unsupported/doc/examples/MatrixSinh.cpp index 488d95652..036534dea 100644 --- a/unsupported/doc/examples/MatrixSinh.cpp +++ b/unsupported/doc/examples/MatrixSinh.cpp @@ -7,12 +7,10 @@ int main() MatrixXf A = MatrixXf::Random(3,3); std::cout << "A = \n" << A << "\n\n"; - MatrixXf sinhA; - ei_matrix_sinh(A, &sinhA); + MatrixXf sinhA = ei_matrix_sinh(A); std::cout << "sinh(A) = \n" << sinhA << "\n\n"; - MatrixXf coshA; - ei_matrix_cosh(A, &coshA); + MatrixXf coshA = ei_matrix_cosh(A); std::cout << "cosh(A) = \n" << coshA << "\n\n"; // The matrix functions satisfy cosh^2(A) - sinh^2(A) = I, diff --git a/unsupported/test/matrix_function.cpp b/unsupported/test/matrix_function.cpp index 25134f21d..4ff6d7f1e 100644 --- a/unsupported/test/matrix_function.cpp +++ b/unsupported/test/matrix_function.cpp @@ -100,9 +100,8 @@ void testMatrixExponential(const MatrixType& A) typedef std::complex ComplexScalar; for (int i = 0; i < g_repeat; i++) { - MatrixType expA; - ei_matrix_function(A, StdStemFunctions::exp, &expA); - VERIFY_IS_APPROX(ei_matrix_exponential(A), expA); + VERIFY_IS_APPROX(ei_matrix_exponential(A), + ei_matrix_function(A, StdStemFunctions::exp)); } } @@ -110,9 +109,8 @@ template void testHyperbolicFunctions(const MatrixType& A) { for (int i = 0; i < g_repeat; i++) { - MatrixType sinhA, coshA; - ei_matrix_sinh(A, &sinhA); - ei_matrix_cosh(A, &coshA); + MatrixType sinhA = ei_matrix_sinh(A); + MatrixType coshA = ei_matrix_cosh(A); MatrixType expA = ei_matrix_exponential(A); VERIFY_IS_APPROX(sinhA, (expA - expA.inverse())/2); VERIFY_IS_APPROX(coshA, (expA + expA.inverse())/2); @@ -137,13 +135,11 @@ void testGonioFunctions(const MatrixType& A) ComplexMatrix exp_iA = ei_matrix_exponential(imagUnit * Ac); - MatrixType sinA; - ei_matrix_sin(A, &sinA); + MatrixType sinA = ei_matrix_sin(A); ComplexMatrix sinAc = sinA.template cast(); VERIFY_IS_APPROX(sinAc, (exp_iA - exp_iA.inverse()) / (two*imagUnit)); - MatrixType cosA; - ei_matrix_cos(A, &cosA); + MatrixType cosA = ei_matrix_cos(A); ComplexMatrix cosAc = cosA.template cast(); VERIFY_IS_APPROX(cosAc, (exp_iA + exp_iA.inverse()) / 2); }