From 6b4e215710dd5c12ad1fe8e820875674bdd849c8 Mon Sep 17 00:00:00 2001 From: Jitse Niesen Date: Sat, 7 May 2011 22:57:46 +0100 Subject: [PATCH] Implement matrix square root for complex matrices. I hope to implement the real case soon, but it's a bit more complicated due to the 2-by-2 blocks in the real Schur decomposition. --- unsupported/Eigen/MatrixFunctions | 1 + .../src/MatrixFunctions/MatrixFunction.h | 8 +- .../src/MatrixFunctions/MatrixSquareRoot.h | 135 ++++++++++++++++++ unsupported/test/CMakeLists.txt | 1 + unsupported/test/matrix_square_root.cpp | 46 ++++++ 5 files changed, 187 insertions(+), 4 deletions(-) create mode 100644 unsupported/Eigen/src/MatrixFunctions/MatrixSquareRoot.h create mode 100644 unsupported/test/matrix_square_root.cpp diff --git a/unsupported/Eigen/MatrixFunctions b/unsupported/Eigen/MatrixFunctions index d39c49e53..79340e206 100644 --- a/unsupported/Eigen/MatrixFunctions +++ b/unsupported/Eigen/MatrixFunctions @@ -69,6 +69,7 @@ namespace Eigen { #include "src/MatrixFunctions/MatrixExponential.h" #include "src/MatrixFunctions/MatrixFunction.h" +#include "src/MatrixFunctions/MatrixSquareRoot.h" diff --git a/unsupported/Eigen/src/MatrixFunctions/MatrixFunction.h b/unsupported/Eigen/src/MatrixFunctions/MatrixFunction.h index 4b9d8a102..b343f38bc 100644 --- a/unsupported/Eigen/src/MatrixFunctions/MatrixFunction.h +++ b/unsupported/Eigen/src/MatrixFunctions/MatrixFunction.h @@ -68,8 +68,8 @@ class MatrixFunction }; -/** \ingroup MatrixFunctions_Module - * \brief Partial specialization of MatrixFunction for real matrices \internal +/** \internal \ingroup MatrixFunctions_Module + * \brief Partial specialization of MatrixFunction for real matrices */ template class MatrixFunction @@ -124,8 +124,8 @@ class MatrixFunction }; -/** \ingroup MatrixFunctions_Module - * \brief Partial specialization of MatrixFunction for complex matrices \internal +/** \internal \ingroup MatrixFunctions_Module + * \brief Partial specialization of MatrixFunction for complex matrices */ template class MatrixFunction diff --git a/unsupported/Eigen/src/MatrixFunctions/MatrixSquareRoot.h b/unsupported/Eigen/src/MatrixFunctions/MatrixSquareRoot.h new file mode 100644 index 000000000..5eeda11ec --- /dev/null +++ b/unsupported/Eigen/src/MatrixFunctions/MatrixSquareRoot.h @@ -0,0 +1,135 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2011 Jitse Niesen +// +// Eigen is free software; you can redistribute it and/or +// modify it under the terms of the GNU Lesser General Public +// License as published by the Free Software Foundation; either +// version 3 of the License, or (at your option) any later version. +// +// Alternatively, you can redistribute it and/or +// modify it under the terms of the GNU General Public License as +// published by the Free Software Foundation; either version 2 of +// the License, or (at your option) any later version. +// +// Eigen is distributed in the hope that it will be useful, but WITHOUT ANY +// WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS +// FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License or the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public +// License and a copy of the GNU General Public License along with +// Eigen. If not, see . + +#ifndef EIGEN_MATRIX_SQUARE_ROOT +#define EIGEN_MATRIX_SQUARE_ROOT + +/** \ingroup MatrixFunctions_Module + * \brief Class for computing matrix square roots. + * \tparam MatrixType type of the argument of the matrix square root, + * expected to be an instantiation of the Matrix class template. + */ +template ::Scalar>::IsComplex> +class MatrixSquareRoot +{ + public: + + /** \brief Constructor. + * + * \param[in] A matrix whose square root is to be computed. + * + * The class stores a reference to \p A, so it should not be + * changed (or destroyed) before compute() is called. + */ + MatrixSquareRoot(const MatrixType& A); + + /** \brief Compute the matrix square root + * + * \param[out] result square root of \p A, as specified in the constructor. + * + * See MatrixBase::sqrt() for details on how this computation + * is implemented. + */ + template + void compute(ResultType &result); +}; + + +// ********** Partial specialization for real matrices ********** + +template +class MatrixSquareRoot +{ + public: + MatrixSquareRoot(const MatrixType& A) + : m_A(A) + { + eigen_assert(A.rows() == A.cols()); + } + + template void compute(ResultType &result); + + private: + const MatrixType& m_A; +}; + +template +template +void MatrixSquareRoot::compute(ResultType &result) +{ + eigen_assert("Square root of real matrices is not implemented!"); +} + + +// ********** Partial specialization for complex matrices ********** + +template +class MatrixSquareRoot +{ + public: + MatrixSquareRoot(const MatrixType& A) + : m_A(A) + { + eigen_assert(A.rows() == A.cols()); + } + + template void compute(ResultType &result); + + private: + const MatrixType& m_A; +}; + +template +template +void MatrixSquareRoot::compute(ResultType &result) +{ + // Compute Schur decomposition of m_A + const ComplexSchur schurOfA(m_A); + const MatrixType& T = schurOfA.matrixT(); + const MatrixType& U = schurOfA.matrixU(); + + // Compute square root of T and store it in upper triangular part of result + // This uses that the square root of triangular matrices can be computed directly. + result.resize(m_A.rows(), m_A.cols()); + typedef typename MatrixType::Index Index; + for (Index i = 0; i < m_A.rows(); i++) { + result.coeffRef(i,i) = internal::sqrt(T.coeff(i,i)); + } + for (Index j = 1; j < m_A.cols(); j++) { + for (Index i = j-1; i >= 0; i--) { + typedef typename MatrixType::Scalar Scalar; + // if i = j-1, then segment has length 0 so tmp = 0 + Scalar tmp = result.row(i).segment(i+1,j-i-1) * result.col(j).segment(i+1,j-i-1); + // denominator may be zero if original matrix is singular + result.coeffRef(i,j) = (T.coeff(i,j) - tmp) / (result.coeff(i,i) + result.coeff(j,j)); + } + } + + // Compute square root of m_A as U * result * U.adjoint() + MatrixType tmp; + tmp.noalias() = U * result.template triangularView(); + result.noalias() = tmp * U.adjoint(); +} + +#endif // EIGEN_MATRIX_FUNCTION diff --git a/unsupported/test/CMakeLists.txt b/unsupported/test/CMakeLists.txt index e2f5b1c60..5452dd819 100644 --- a/unsupported/test/CMakeLists.txt +++ b/unsupported/test/CMakeLists.txt @@ -75,6 +75,7 @@ endif() ei_add_test(matrix_exponential) ei_add_test(matrix_function) +ei_add_test(matrix_square_root) ei_add_test(alignedvector3) ei_add_test(FFT) diff --git a/unsupported/test/matrix_square_root.cpp b/unsupported/test/matrix_square_root.cpp new file mode 100644 index 000000000..cd2c6cfc4 --- /dev/null +++ b/unsupported/test/matrix_square_root.cpp @@ -0,0 +1,46 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2011 Jitse Niesen +// +// Eigen is free software; you can redistribute it and/or +// modify it under the terms of the GNU Lesser General Public +// License as published by the Free Software Foundation; either +// version 3 of the License, or (at your option) any later version. +// +// Alternatively, you can redistribute it and/or +// modify it under the terms of the GNU General Public License as +// published by the Free Software Foundation; either version 2 of +// the License, or (at your option) any later version. +// +// Eigen is distributed in the hope that it will be useful, but WITHOUT ANY +// WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS +// FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License or the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public +// License and a copy of the GNU General Public License along with +// Eigen. If not, see . + +#include "main.h" +#include + +template +void testMatrixSqrt(const MatrixType& m) +{ + typedef typename MatrixType::Index Index; + const Index size = m.rows(); + MatrixType A = MatrixType::Random(size, size); + MatrixSquareRoot msr(A); + MatrixType S; + msr.compute(S); + VERIFY_IS_APPROX(S*S, A); +} + +void test_matrix_square_root() +{ + for (int i = 0; i < g_repeat; i++) { + CALL_SUBTEST_1(testMatrixSqrt(Matrix3cf())); + CALL_SUBTEST_2(testMatrixSqrt(MatrixXcd(12,12))); + } +}