add support for mixing type in trsv

This commit is contained in:
Gael Guennebaud 2010-07-13 16:03:49 +02:00
parent 36d9b51a44
commit 1dc9aaaf36
2 changed files with 73 additions and 11 deletions

View File

@ -53,7 +53,8 @@ struct ei_triangular_solver_selector;
template<typename Lhs, typename Rhs, int Mode>
struct ei_triangular_solver_selector<Lhs,Rhs,OnTheLeft,Mode,NoUnrolling,RowMajor,1>
{
typedef typename Rhs::Scalar Scalar;
typedef typename Lhs::Scalar LhsScalar;
typedef typename Rhs::Scalar RhsScalar;
typedef ei_blas_traits<Lhs> LhsProductTraits;
typedef typename LhsProductTraits::ExtractType ActualLhsType;
typedef typename Lhs::Index Index;
@ -81,12 +82,12 @@ struct ei_triangular_solver_selector<Lhs,Rhs,OnTheLeft,Mode,NoUnrolling,RowMajor
Index startRow = IsLower ? pi : pi-actualPanelWidth;
Index startCol = IsLower ? 0 : pi;
ei_general_matrix_vector_product<Index,Scalar,RowMajor,LhsProductTraits::NeedToConjugate,Scalar,false>::run(
ei_general_matrix_vector_product<Index,LhsScalar,RowMajor,LhsProductTraits::NeedToConjugate,RhsScalar,false>::run(
actualPanelWidth, r,
&(actualLhs.const_cast_derived().coeffRef(startRow,startCol)), actualLhs.outerStride(),
&(other.coeffRef(startCol)), other.innerStride(),
&other.coeffRef(startRow), other.innerStride(),
Scalar(-1));
RhsScalar(-1));
}
for(Index k=0; k<actualPanelWidth; ++k)
@ -107,13 +108,12 @@ struct ei_triangular_solver_selector<Lhs,Rhs,OnTheLeft,Mode,NoUnrolling,RowMajor
template<typename Lhs, typename Rhs, int Mode>
struct ei_triangular_solver_selector<Lhs,Rhs,OnTheLeft,Mode,NoUnrolling,ColMajor,1>
{
typedef typename Rhs::Scalar Scalar;
typedef typename ei_packet_traits<Scalar>::type Packet;
typedef typename Lhs::Scalar LhsScalar;
typedef typename Rhs::Scalar RhsScalar;
typedef ei_blas_traits<Lhs> LhsProductTraits;
typedef typename LhsProductTraits::ExtractType ActualLhsType;
typedef typename Lhs::Index Index;
enum {
PacketSize = ei_packet_traits<Scalar>::size,
IsLower = ((Mode&Lower)==Lower)
};
@ -148,11 +148,11 @@ struct ei_triangular_solver_selector<Lhs,Rhs,OnTheLeft,Mode,NoUnrolling,ColMajor
// let's directly call the low level product function because:
// 1 - it is faster to compile
// 2 - it is slighlty faster at runtime
ei_general_matrix_vector_product<Index,Scalar,ColMajor,LhsProductTraits::NeedToConjugate,Scalar,false>::run(
ei_general_matrix_vector_product<Index,LhsScalar,ColMajor,LhsProductTraits::NeedToConjugate,RhsScalar,false>::run(
r, actualPanelWidth,
&(actualLhs.const_cast_derived().coeffRef(endBlock,startBlock)), actualLhs.outerStride(),
&other.coeff(startBlock), other.innerStride(),
&(other.coeffRef(endBlock, 0)), other.innerStride(), Scalar(-1));
&(other.coeffRef(endBlock, 0)), other.innerStride(), RhsScalar(-1));
}
}
}

View File

@ -170,6 +170,66 @@ template<typename MatrixType> void cholesky(const MatrixType& m)
}
template<typename MatrixType> void cholesky_cplx(const MatrixType& m)
{
// classic test
cholesky(m);
// test mixing real/scalar types
typedef typename MatrixType::Index Index;
Index rows = m.rows();
Index cols = m.cols();
typedef typename MatrixType::Scalar Scalar;
typedef typename NumTraits<Scalar>::Real RealScalar;
typedef Matrix<RealScalar, MatrixType::RowsAtCompileTime, MatrixType::RowsAtCompileTime> RealMatrixType;
typedef Matrix<Scalar, MatrixType::RowsAtCompileTime, 1> VectorType;
RealMatrixType a0 = RealMatrixType::Random(rows,cols);
VectorType vecB = VectorType::Random(rows), vecX(rows);
MatrixType matB = MatrixType::Random(rows,cols), matX(rows,cols);
RealMatrixType symm = a0 * a0.adjoint();
// let's make sure the matrix is not singular or near singular
for (int k=0; k<3; ++k)
{
RealMatrixType a1 = RealMatrixType::Random(rows,cols);
symm += a1 * a1.adjoint();
}
{
RealMatrixType symmLo = symm.template triangularView<Lower>();
LLT<RealMatrixType,Lower> chollo(symmLo);
VERIFY_IS_APPROX(symm, chollo.reconstructedMatrix());
vecX = chollo.solve(vecB);
VERIFY_IS_APPROX(symm * vecX, vecB);
// matX = chollo.solve(matB);
// VERIFY_IS_APPROX(symm * matX, matB);
}
// LDLT
{
int sign = ei_random<int>()%2 ? 1 : -1;
if(sign == -1)
{
symm = -symm; // test a negative matrix
}
RealMatrixType symmLo = symm.template triangularView<Lower>();
LDLT<RealMatrixType,Lower> ldltlo(symmLo);
VERIFY_IS_APPROX(symm, ldltlo.reconstructedMatrix());
vecX = ldltlo.solve(vecB);
VERIFY_IS_APPROX(symm * vecX, vecB);
// matX = ldltlo.solve(matB);
// VERIFY_IS_APPROX(symm * matX, matB);
}
}
template<typename MatrixType> void cholesky_verify_assert()
{
MatrixType tmp;
@ -192,14 +252,16 @@ template<typename MatrixType> void cholesky_verify_assert()
void test_cholesky()
{
int s;
for(int i = 0; i < g_repeat; i++) {
CALL_SUBTEST_1( cholesky(Matrix<double,1,1>()) );
CALL_SUBTEST_2( cholesky(MatrixXd(1,1)) );
CALL_SUBTEST_3( cholesky(Matrix2d()) );
CALL_SUBTEST_4( cholesky(Matrix3f()) );
CALL_SUBTEST_5( cholesky(Matrix4d()) );
CALL_SUBTEST_2( cholesky(MatrixXd(200,200)) );
CALL_SUBTEST_6( cholesky(MatrixXcd(100,100)) );
s = ei_random<int>(1,200);
CALL_SUBTEST_2( cholesky(MatrixXd(s,s)) );
s = ei_random<int>(1,100);
CALL_SUBTEST_6( cholesky_cplx(MatrixXcd(s,s)) );
}
CALL_SUBTEST_4( cholesky_verify_assert<Matrix3f>() );