Simplify and document Sylvester equation solver in MatrixFunction.

This commit is contained in:
Jitse Niesen 2009-12-27 18:09:50 +00:00
parent 72b6c05bf0
commit a25c9b1e46

View File

@ -168,7 +168,7 @@ class MatrixFunction<MatrixType, 1, 1>
void swapEntriesInSchur(int index, MatrixType& T, MatrixType& U);
void computeTriangular(const MatrixType& T, MatrixType& result, const IntVectorType& blockSize);
void computeBlockAtomic(const MatrixType& T, MatrixType& result, const IntVectorType& blockSize);
MatrixType solveSylvester(const MatrixType& A, const MatrixType& B, const MatrixType& C);
MatrixType solveTriangularSylvester(const MatrixType& A, const MatrixType& B, const MatrixType& C);
MatrixType computeAtomic(const MatrixType& T);
void divideInBlocks(const VectorType& v, listOfLists* result);
void constructPermutation(const VectorType& diag, const listOfLists& blocks,
@ -264,47 +264,76 @@ void MatrixFunction<MatrixType,1,1>::computeTriangular(const MatrixType& T, Matr
C += result.block(blockStart(blockIndex), blockStart(k), blockSize(blockIndex), blockSize(k)) * T.block(blockStart(k), blockStart(blockIndex+diagIndex), blockSize(k), blockSize(blockIndex+diagIndex));
C -= T.block(blockStart(blockIndex), blockStart(k), blockSize(blockIndex), blockSize(k)) * result.block(blockStart(k), blockStart(blockIndex+diagIndex), blockSize(k), blockSize(blockIndex+diagIndex));
}
result.block(blockStart(blockIndex), blockStart(blockIndex+diagIndex), blockSize(blockIndex), blockSize(blockIndex+diagIndex)) = solveSylvester(A, B, C);
result.block(blockStart(blockIndex), blockStart(blockIndex+diagIndex), blockSize(blockIndex), blockSize(blockIndex+diagIndex)) = solveTriangularSylvester(A, B, C);
}
}
}
// solve AX + XB = C <=> U* A' U X V V* + U* U X V B' V* = U* U C V V* <=> A' U X V + U X V B' = U C V
// Schur: A* = U A'* U* (so A = U* A' U), B = V B' V*, define: X' = U X V, C' = U C V, to get: A' X' + X' B' = C'
// A is m-by-m, B is n-by-n, X is m-by-n, C is m-by-n, U is m-by-m, V is n-by-n
/** \brief Solve a triangular Sylvester equation AX + XB = C
*
* \param[in] A The matrix A; should be square and upper triangular
* \param[in] B The matrix B; should be square and upper triangular
* \param[in] C The matrix C; should have correct size.
*
* \returns The solution X.
*
* If A is m-by-m and B is n-by-n, then both C and X are m-by-n.
* The (i,j)-th component of the Sylvester equation is
* \f[
* \sum_{k=i}^m A_{ik} X_{kj} + \sum_{k=1}^j X_{ik} B_{kj} = C_{ij}.
* \f]
* This can be re-arranged to yield:
* \f[
* X_{ij} = \frac{1}{A_{ii} + B_{jj}} \Bigl( C_{ij}
* - \sum_{k=i+1}^m A_{ik} X_{kj} - \sum_{k=1}^{j-1} X_{ik} B_{kj} \Bigr).
* \f]
* It is assumed that A and B are such that the numerator is never
* zero (otherwise the Sylvester equation does not have a unique
* solution). In that case, these equations can be evaluated in the
* order \f$ i=m,\ldots,1 \f$ and \f$ j=1,\ldots,n \f$.
*/
template <typename MatrixType>
MatrixType MatrixFunction<MatrixType,1,1>::solveSylvester(const MatrixType& A, const MatrixType& B, const MatrixType& C)
MatrixType MatrixFunction<MatrixType,1,1>::solveTriangularSylvester(
const MatrixType& A,
const MatrixType& B,
const MatrixType& C)
{
MatrixType U = MatrixType::Zero(A.rows(), A.rows());
for (int i = 0; i < A.rows(); i++) {
U(i, A.rows() - 1 - i) = static_cast<Scalar>(1);
}
MatrixType Aprime = U * A * U;
ei_assert(A.rows() == A.cols());
ei_assert(A.isUpperTriangular());
ei_assert(B.rows() == B.cols());
ei_assert(B.isUpperTriangular());
ei_assert(C.rows() == A.rows());
ei_assert(C.cols() == B.rows());
MatrixType Bprime = B;
MatrixType V = MatrixType::Identity(B.rows(), B.rows());
int m = A.rows();
int n = B.rows();
MatrixType X(m, n);
MatrixType Cprime = U * C * V;
MatrixType Xprime(A.rows(), B.rows());
for (int l = 0; l < B.rows(); l++) {
for (int k = 0; k < A.rows(); k++) {
Scalar tmp1, tmp2;
if (k == 0) {
tmp1 = 0;
for (int i = m - 1; i >= 0; --i) {
for (int j = 0; j < n; ++j) {
// Compute AX = \sum_{k=i+1}^m A_{ik} X_{kj}
Scalar AX;
if (i == m - 1) {
AX = 0;
} else {
Matrix<Scalar,1,1> tmp1matrix = Aprime.row(k).start(k) * Xprime.col(l).start(k);
tmp1 = tmp1matrix(0,0);
Matrix<Scalar,1,1> AXmatrix = A.row(i).end(m-1-i) * X.col(j).end(m-1-i);
AX = AXmatrix(0,0);
}
if (l == 0) {
tmp2 = 0;
// Compute XB = \sum_{k=1}^{j-1} X_{ik} B_{kj}
Scalar XB;
if (j == 0) {
XB = 0;
} else {
Matrix<Scalar,1,1> tmp2matrix = Xprime.row(k).start(l) * Bprime.col(l).start(l);
tmp2 = tmp2matrix(0,0);
Matrix<Scalar,1,1> XBmatrix = X.row(i).start(j) * B.col(j).start(j);
XB = XBmatrix(0,0);
}
Xprime(k,l) = (Cprime(k,l) - tmp1 - tmp2) / (Aprime(k,k) + Bprime(l,l));
X(i,j) = (C(i,j) - AX - XB) / (A(i,i) + B(j,j));
}
}
return U.adjoint() * Xprime * V.adjoint();
return X;
}