add a meta unroller for the triangular solver (only for vectors as rhs)

This commit is contained in:
Gael Guennebaud 2009-07-10 11:30:46 +02:00
parent 1a1b2e9f27
commit b47dea8b7a
2 changed files with 72 additions and 24 deletions

View File

@ -26,28 +26,25 @@
#define EIGEN_SOLVETRIANGULAR_H #define EIGEN_SOLVETRIANGULAR_H
template<typename Lhs, typename Rhs, template<typename Lhs, typename Rhs,
int Mode, // Upper/Lower | UnitDiag int Mode, // can be Upper/Lower | UnitDiag
int UpLo = (Mode & LowerTriangularBit) int Unrolling = Rhs::IsVectorAtCompileTime && Rhs::SizeAtCompileTime <= 8 // FIXME
? LowerTriangular ? CompleteUnrolling : NoUnrolling,
: (Mode & UpperTriangularBit)
? UpperTriangular
: -1,
int StorageOrder = int(Lhs::Flags) & RowMajorBit int StorageOrder = int(Lhs::Flags) & RowMajorBit
> >
struct ei_triangular_solver_selector; struct ei_triangular_solver_selector;
// forward substitution, row-major // forward substitution, row-major
template<typename Lhs, typename Rhs, int Mode, int UpLo> template<typename Lhs, typename Rhs, int Mode>
struct ei_triangular_solver_selector<Lhs,Rhs,Mode,UpLo,RowMajor> struct ei_triangular_solver_selector<Lhs,Rhs,Mode,NoUnrolling,RowMajor>
{ {
typedef typename Rhs::Scalar Scalar; typedef typename Rhs::Scalar Scalar;
typedef ei_product_factor_traits<Lhs> LhsProductTraits; typedef ei_product_factor_traits<Lhs> LhsProductTraits;
typedef typename LhsProductTraits::ActualXprType ActualLhsType; typedef typename LhsProductTraits::ActualXprType ActualLhsType;
enum { enum {
IsLowerTriangular = (UpLo==LowerTriangular) IsLowerTriangular = ((Mode&LowerTriangularBit)==LowerTriangularBit)
}; };
static void run(const Lhs& lhs, Rhs& other) static void run(const Lhs& lhs, Rhs& other)
{//std::cerr << "row maj " << LhsProductTraits::NeedToConjugate << "\n"; {
static const int PanelWidth = EIGEN_TUNE_TRSV_PANEL_WIDTH; static const int PanelWidth = EIGEN_TUNE_TRSV_PANEL_WIDTH;
const ActualLhsType& actualLhs = LhsProductTraits::extract(lhs); const ActualLhsType& actualLhs = LhsProductTraits::extract(lhs);
@ -90,12 +87,12 @@ struct ei_triangular_solver_selector<Lhs,Rhs,Mode,UpLo,RowMajor>
}; };
// Implements the following configurations: // Implements the following configurations:
// - inv(LowerTriangular, ColMajor) * Column vector // - inv(LowerTriangular, ColMajor) * Column vectors
// - inv(LowerTriangular,UnitDiag,ColMajor) * Column vector // - inv(LowerTriangular,UnitDiag,ColMajor) * Column vectors
// - inv(UpperTriangular, ColMajor) * Column vector // - inv(UpperTriangular, ColMajor) * Column vectors
// - inv(UpperTriangular,UnitDiag,ColMajor) * Column vector // - inv(UpperTriangular,UnitDiag,ColMajor) * Column vectors
template<typename Lhs, typename Rhs, int Mode, int UpLo> template<typename Lhs, typename Rhs, int Mode>
struct ei_triangular_solver_selector<Lhs,Rhs,Mode,UpLo,ColMajor> struct ei_triangular_solver_selector<Lhs,Rhs,Mode,NoUnrolling,ColMajor>
{ {
typedef typename Rhs::Scalar Scalar; typedef typename Rhs::Scalar Scalar;
typedef typename ei_packet_traits<Scalar>::type Packet; typedef typename ei_packet_traits<Scalar>::type Packet;
@ -103,11 +100,11 @@ struct ei_triangular_solver_selector<Lhs,Rhs,Mode,UpLo,ColMajor>
typedef typename LhsProductTraits::ActualXprType ActualLhsType; typedef typename LhsProductTraits::ActualXprType ActualLhsType;
enum { enum {
PacketSize = ei_packet_traits<Scalar>::size, PacketSize = ei_packet_traits<Scalar>::size,
IsLowerTriangular = (UpLo==LowerTriangular) IsLowerTriangular = ((Mode&LowerTriangularBit)==LowerTriangularBit)
}; };
static void run(const Lhs& lhs, Rhs& other) static void run(const Lhs& lhs, Rhs& other)
{//std::cerr << "col maj " << LhsProductTraits::NeedToConjugate << "\n"; {
static const int PanelWidth = EIGEN_TUNE_TRSV_PANEL_WIDTH; static const int PanelWidth = EIGEN_TUNE_TRSV_PANEL_WIDTH;
const ActualLhsType& actualLhs = LhsProductTraits::extract(lhs); const ActualLhsType& actualLhs = LhsProductTraits::extract(lhs);
@ -154,6 +151,49 @@ struct ei_triangular_solver_selector<Lhs,Rhs,Mode,UpLo,ColMajor>
} }
}; };
/***************************************************************************
* meta-unrolling implementation
***************************************************************************/
/** \internal
  * Primary declaration of the compile-time unrolled triangular solver.
  *
  * \c Index is the current unrolling step and \c Size the total number of steps.
  * The last parameter, \c Stop, defaults to (Index == Size) and is only used to
  * dispatch between the partial specializations: \c false for a regular step,
  * \c true for the end of the recursion.
  */
template<typename Lhs, typename Rhs, int Mode, int Index, int Size, bool Stop = (Index == Size)>
struct ei_triangular_solver_unroller;
/** \internal
  * One step of the unrolled substitution (Stop == false).
  *
  * Step \c Index solves for coefficient \c I of \a rhs in place:
  *  - when Mode has LowerTriangularBit set, I == Index and the already solved
  *    coefficients start at S == 0 (rows are processed top to bottom);
  *  - otherwise I == Size-Index-1 and the already solved coefficients start at
  *    S == I+1 (rows are processed bottom to top).
  * In both cases exactly \c Index coefficients have been solved before this step.
  */
template<typename Lhs, typename Rhs, int Mode, int Index, int Size>
struct ei_triangular_solver_unroller<Lhs,Rhs,Mode,Index,Size,false>
{
  enum {
    IsLowerTriangular = ((Mode&LowerTriangularBit)==LowerTriangularBit),
    I = IsLowerTriangular ? Index : Size - Index - 1,
    S = IsLowerTriangular ? 0 : I+1
  };

  static void run(const Lhs& lhs, Rhs& rhs)
  {
    typedef typename Rhs::Scalar Scalar;
    typedef ei_triangular_solver_unroller<Lhs,Rhs,Mode,Index+1,Size> NextStep;

    // Remove the contribution of the Index coefficients solved so far
    // (nothing to remove at the very first step).
    if (Index>0)
    {
      const Scalar solvedPart =
        ((lhs.row(I).template segment<Index>(S).transpose())
          .cwise()*(rhs.template segment<Index>(S))).sum();
      rhs.coeffRef(I) -= solvedPart;
    }

    // The division by the diagonal coefficient is skipped when UnitDiagBit is set in Mode.
    if ((Mode & UnitDiagBit) == 0)
    {
      rhs.coeffRef(I) /= lhs.coeff(I,I);
    }

    // Continue with the next step; the recursion ends when Index+1 == Size.
    NextStep::run(lhs,rhs);
  }
};
/** \internal
  * End of the unrolling recursion (Stop == true, i.e. Index == Size):
  * every coefficient has been processed, so there is nothing left to do.
  */
template<typename Lhs, typename Rhs, int Mode, int Index, int Size>
struct ei_triangular_solver_unroller<Lhs,Rhs,Mode,Index,Size,true>
{
  // The parameters are intentionally left unnamed: they are unused here, and naming
  // them triggers unused-parameter warnings in every file instantiating this template.
  static void run(const Lhs&, Rhs&) {}
};
template<typename Lhs, typename Rhs, int Mode, int StorageOrder>
struct ei_triangular_solver_selector<Lhs,Rhs,Mode,CompleteUnrolling,StorageOrder> {
static void run(const Lhs& lhs, Rhs& rhs)
{ ei_triangular_solver_unroller<Lhs,Rhs,Mode,0,Rhs::SizeAtCompileTime>::run(lhs,rhs); }
};
/***************************************************************************
* TriangularView methods
***************************************************************************/
/** "in-place" version of MatrixBase::solveTriangular() where the result is written in \a other /** "in-place" version of MatrixBase::solveTriangular() where the result is written in \a other
* *
* \nonstableyet * \nonstableyet
@ -161,7 +201,7 @@ struct ei_triangular_solver_selector<Lhs,Rhs,Mode,UpLo,ColMajor>
* \warning The parameter is only marked 'const' to make the C++ compiler accept a temporary expression here. * \warning The parameter is only marked 'const' to make the C++ compiler accept a temporary expression here.
* This function will const_cast it, so constness isn't honored here. * This function will const_cast it, so constness isn't honored here.
* *
* See MatrixBase:solveTriangular() for the details. * See TriangularView:solve() for the details.
*/ */
template<typename MatrixType, unsigned int Mode> template<typename MatrixType, unsigned int Mode>
template<typename RhsDerived> template<typename RhsDerived>
@ -198,8 +238,6 @@ void TriangularView<MatrixType,Mode>::solveInPlace(const MatrixBase<RhsDerived>&
* can be done by marked(), and that is automatically the case with expressions such as those returned * can be done by marked(), and that is automatically the case with expressions such as those returned
* by extract(). * by extract().
* *
* \addexample SolveTriangular \label How to solve a triangular system (aka. how to multiply the inverse of a triangular matrix by another one)
*
* Example: \include MatrixBase_marked.cpp * Example: \include MatrixBase_marked.cpp
* Output: \verbinclude MatrixBase_marked.out * Output: \verbinclude MatrixBase_marked.out
* *
@ -213,10 +251,10 @@ void TriangularView<MatrixType,Mode>::solveInPlace(const MatrixBase<RhsDerived>&
* *
* \b Tips: to perform a \em "right-inverse-multiply" you can simply transpose the operation, e.g.: * \b Tips: to perform a \em "right-inverse-multiply" you can simply transpose the operation, e.g.:
* \code * \code
* M * T^-1 <=> T.transpose().solveTriangularInPlace(M.transpose()); * M * T^-1 <=> T.transpose().solveInPlace(M.transpose());
* \endcode * \endcode
* *
* \sa solveTriangularInPlace() * \sa TriangularView::solveInPlace()
*/ */
template<typename Derived, unsigned int Mode> template<typename Derived, unsigned int Mode>
template<typename RhsDerived> template<typename RhsDerived>

View File

@ -86,7 +86,17 @@ template<typename MatrixType> void triangular(const MatrixType& m)
while (ei_abs2(m1(i,i))<1e-3) m1(i,i) = ei_random<Scalar>(); while (ei_abs2(m1(i,i))<1e-3) m1(i,i) = ei_random<Scalar>();
Transpose<MatrixType> trm4(m4); Transpose<MatrixType> trm4(m4);
// test back and forward substitution // test back and forward substitution with a vector as the rhs
m3 = m1.template triangularView<Eigen::UpperTriangular>();
VERIFY(v2.isApprox(m3.adjoint() * (m1.adjoint().template triangularView<Eigen::LowerTriangular>().solve(v2)), largerEps));
m3 = m1.template triangularView<Eigen::LowerTriangular>();
VERIFY(v2.isApprox(m3.transpose() * (m1.transpose().template triangularView<Eigen::UpperTriangular>().solve(v2)), largerEps));
m3 = m1.template triangularView<Eigen::UpperTriangular>();
VERIFY(v2.isApprox(m3 * (m1.template triangularView<Eigen::UpperTriangular>().solve(v2)), largerEps));
m3 = m1.template triangularView<Eigen::LowerTriangular>();
VERIFY(v2.isApprox(m3.conjugate() * (m1.conjugate().template triangularView<Eigen::LowerTriangular>().solve(v2)), largerEps));
// test back and forward substitution with a matrix as the rhs
m3 = m1.template triangularView<Eigen::UpperTriangular>(); m3 = m1.template triangularView<Eigen::UpperTriangular>();
VERIFY(m2.isApprox(m3.adjoint() * (m1.adjoint().template triangularView<Eigen::LowerTriangular>().solve(m2)), largerEps)); VERIFY(m2.isApprox(m3.adjoint() * (m1.adjoint().template triangularView<Eigen::LowerTriangular>().solve(m2)), largerEps));
m3 = m1.template triangularView<Eigen::LowerTriangular>(); m3 = m1.template triangularView<Eigen::LowerTriangular>();