From b47dea8b7aeab10cf584f2d3275192d90d8df2ed Mon Sep 17 00:00:00 2001 From: Gael Guennebaud Date: Fri, 10 Jul 2009 11:30:46 +0200 Subject: [PATCH] add a meta unroller for the triangular solver (only for vectors as rhs) --- Eigen/src/Core/SolveTriangular.h | 84 +++++++++++++++++++++++--------- test/triangular.cpp | 12 ++++- 2 files changed, 72 insertions(+), 24 deletions(-) diff --git a/Eigen/src/Core/SolveTriangular.h b/Eigen/src/Core/SolveTriangular.h index 452d40a4c..3a65a8b27 100644 --- a/Eigen/src/Core/SolveTriangular.h +++ b/Eigen/src/Core/SolveTriangular.h @@ -26,28 +26,25 @@ #define EIGEN_SOLVETRIANGULAR_H template struct ei_triangular_solver_selector; // forward substitution, row-major -template -struct ei_triangular_solver_selector +template +struct ei_triangular_solver_selector { typedef typename Rhs::Scalar Scalar; typedef ei_product_factor_traits LhsProductTraits; typedef typename LhsProductTraits::ActualXprType ActualLhsType; enum { - IsLowerTriangular = (UpLo==LowerTriangular) + IsLowerTriangular = ((Mode&LowerTriangularBit)==LowerTriangularBit) }; static void run(const Lhs& lhs, Rhs& other) - {//std::cerr << "row maj " << LhsProductTraits::NeedToConjugate << "\n"; + { static const int PanelWidth = EIGEN_TUNE_TRSV_PANEL_WIDTH; const ActualLhsType& actualLhs = LhsProductTraits::extract(lhs); @@ -90,12 +87,12 @@ struct ei_triangular_solver_selector }; // Implements the following configurations: -// - inv(LowerTriangular, ColMajor) * Column vector -// - inv(LowerTriangular,UnitDiag,ColMajor) * Column vector -// - inv(UpperTriangular, ColMajor) * Column vector -// - inv(UpperTriangular,UnitDiag,ColMajor) * Column vector -template -struct ei_triangular_solver_selector +// - inv(LowerTriangular, ColMajor) * Column vectors +// - inv(LowerTriangular,UnitDiag,ColMajor) * Column vectors +// - inv(UpperTriangular, ColMajor) * Column vectors +// - inv(UpperTriangular,UnitDiag,ColMajor) * Column vectors +template +struct ei_triangular_solver_selector { typedef 
typename Rhs::Scalar Scalar; typedef typename ei_packet_traits::type Packet; @@ -103,11 +100,11 @@ struct ei_triangular_solver_selector typedef typename LhsProductTraits::ActualXprType ActualLhsType; enum { PacketSize = ei_packet_traits::size, - IsLowerTriangular = (UpLo==LowerTriangular) + IsLowerTriangular = ((Mode&LowerTriangularBit)==LowerTriangularBit) }; static void run(const Lhs& lhs, Rhs& other) - {//std::cerr << "col maj " << LhsProductTraits::NeedToConjugate << "\n"; + { static const int PanelWidth = EIGEN_TUNE_TRSV_PANEL_WIDTH; const ActualLhsType& actualLhs = LhsProductTraits::extract(lhs); @@ -154,6 +151,49 @@ struct ei_triangular_solver_selector } }; +/*************************************************************************** +* meta-unrolling implementation +***************************************************************************/ + +template +struct ei_triangular_solver_unroller; + +template +struct ei_triangular_solver_unroller { + enum { + IsLowerTriangular = ((Mode&LowerTriangularBit)==LowerTriangularBit), + I = IsLowerTriangular ? Index : Size - Index - 1, + S = IsLowerTriangular ? 
0 : I+1 + }; + static void run(const Lhs& lhs, Rhs& rhs) + { + if (Index>0) + rhs.coeffRef(I) -= ((lhs.row(I).template segment(S).transpose()) + .cwise()*(rhs.template segment(S))).sum(); + + if(!(Mode & UnitDiagBit)) + rhs.coeffRef(I) /= lhs.coeff(I,I); + + ei_triangular_solver_unroller::run(lhs,rhs); + } +}; + +template +struct ei_triangular_solver_unroller { + static void run(const Lhs& lhs, Rhs& rhs) {} +}; + +template +struct ei_triangular_solver_selector { + static void run(const Lhs& lhs, Rhs& rhs) + { ei_triangular_solver_unroller::run(lhs,rhs); } +}; + +/*************************************************************************** +* TriangularView methods +***************************************************************************/ + /** "in-place" version of MatrixBase::solveTriangular() where the result is written in \a other * * \nonstableyet @@ -161,7 +201,7 @@ struct ei_triangular_solver_selector * \warning The parameter is only marked 'const' to make the C++ compiler accept a temporary expression here. * This function will const_cast it, so constness isn't honored here. * - * See MatrixBase:solveTriangular() for the details. + * See TriangularView:solve() for the details. */ template template @@ -198,8 +238,6 @@ void TriangularView::solveInPlace(const MatrixBase& * can be done by marked(), and that is automatically the case with expressions such as those returned * by extract(). * - * \addexample SolveTriangular \label How to solve a triangular system (aka. 
how to multiply the inverse of a triangular matrix by another one) - * * Example: \include MatrixBase_marked.cpp * Output: \verbinclude MatrixBase_marked.out * @@ -213,10 +251,10 @@ void TriangularView::solveInPlace(const MatrixBase& * * \b Tips: to perform a \em "right-inverse-multiply" you can simply transpose the operation, e.g.: * \code - * M * T^1 <=> T.transpose().solveTriangularInPlace(M.transpose()); + * M * T^-1 <=> T.transpose().solveInPlace(M.transpose()); * \endcode * - * \sa solveTriangularInPlace() + * \sa TriangularView::solveInPlace() */ template template diff --git a/test/triangular.cpp b/test/triangular.cpp index 0c03e987e..7c680a8ed 100644 --- a/test/triangular.cpp +++ b/test/triangular.cpp @@ -86,7 +86,17 @@ template void triangular(const MatrixType& m) while (ei_abs2(m1(i,i))<1e-3) m1(i,i) = ei_random(); Transpose trm4(m4); - // test back and forward subsitution + // test back and forward substitution with a vector as the rhs + m3 = m1.template triangularView(); + VERIFY(v2.isApprox(m3.adjoint() * (m1.adjoint().template triangularView().solve(v2)), largerEps)); + m3 = m1.template triangularView(); + VERIFY(v2.isApprox(m3.transpose() * (m1.transpose().template triangularView().solve(v2)), largerEps)); + m3 = m1.template triangularView(); + VERIFY(v2.isApprox(m3 * (m1.template triangularView().solve(v2)), largerEps)); + m3 = m1.template triangularView(); + VERIFY(v2.isApprox(m3.conjugate() * (m1.conjugate().template triangularView().solve(v2)), largerEps)); + + // test back and forward substitution with a matrix as the rhs m3 = m1.template triangularView(); VERIFY(m2.isApprox(m3.adjoint() * (m1.adjoint().template triangularView().solve(m2)), largerEps)); m3 = m1.template triangularView();