add a meta unroller for the triangular solver (only for vectors as rhs)

This commit is contained in:
Gael Guennebaud 2009-07-10 11:30:46 +02:00
parent 1a1b2e9f27
commit b47dea8b7a
2 changed files with 72 additions and 24 deletions

View File

@ -26,28 +26,25 @@
#define EIGEN_SOLVETRIANGULAR_H
// Forward declaration of the triangular solver dispatcher; the actual work is
// done by the static run(lhs, rhs) of its partial specializations.
template<typename Lhs, typename Rhs,
int Mode, // Upper/Lower | UnitDiag
int UpLo = (Mode & LowerTriangularBit)
? LowerTriangular
: (Mode & UpperTriangularBit)
? UpperTriangular
: -1,
int Mode, // can be Upper/Lower | UnitDiag
// complete unrolling is the default only when Rhs is a compile-time vector of at most 8 coefficients
int Unrolling = Rhs::IsVectorAtCompileTime && Rhs::SizeAtCompileTime <= 8 // FIXME
? CompleteUnrolling : NoUnrolling,
// by default, extracted from the RowMajorBit of the lhs flags
int StorageOrder = int(Lhs::Flags) & RowMajorBit
>
struct ei_triangular_solver_selector;
// forward substitution, row-major
template<typename Lhs, typename Rhs, int Mode, int UpLo>
struct ei_triangular_solver_selector<Lhs,Rhs,Mode,UpLo,RowMajor>
template<typename Lhs, typename Rhs, int Mode>
struct ei_triangular_solver_selector<Lhs,Rhs,Mode,NoUnrolling,RowMajor>
{
typedef typename Rhs::Scalar Scalar;
typedef ei_product_factor_traits<Lhs> LhsProductTraits;
typedef typename LhsProductTraits::ActualXprType ActualLhsType;
enum {
IsLowerTriangular = (UpLo==LowerTriangular)
IsLowerTriangular = ((Mode&LowerTriangularBit)==LowerTriangularBit)
};
static void run(const Lhs& lhs, Rhs& other)
{//std::cerr << "row maj " << LhsProductTraits::NeedToConjugate << "\n";
{
static const int PanelWidth = EIGEN_TUNE_TRSV_PANEL_WIDTH;
const ActualLhsType& actualLhs = LhsProductTraits::extract(lhs);
@ -90,12 +87,12 @@ struct ei_triangular_solver_selector<Lhs,Rhs,Mode,UpLo,RowMajor>
};
// Implements the following configurations:
// - inv(LowerTriangular, ColMajor) * Column vector
// - inv(LowerTriangular,UnitDiag,ColMajor) * Column vector
// - inv(UpperTriangular, ColMajor) * Column vector
// - inv(UpperTriangular,UnitDiag,ColMajor) * Column vector
template<typename Lhs, typename Rhs, int Mode, int UpLo>
struct ei_triangular_solver_selector<Lhs,Rhs,Mode,UpLo,ColMajor>
// - inv(LowerTriangular, ColMajor) * Column vectors
// - inv(LowerTriangular,UnitDiag,ColMajor) * Column vectors
// - inv(UpperTriangular, ColMajor) * Column vectors
// - inv(UpperTriangular,UnitDiag,ColMajor) * Column vectors
template<typename Lhs, typename Rhs, int Mode>
struct ei_triangular_solver_selector<Lhs,Rhs,Mode,NoUnrolling,ColMajor>
{
typedef typename Rhs::Scalar Scalar;
typedef typename ei_packet_traits<Scalar>::type Packet;
@ -103,11 +100,11 @@ struct ei_triangular_solver_selector<Lhs,Rhs,Mode,UpLo,ColMajor>
typedef typename LhsProductTraits::ActualXprType ActualLhsType;
enum {
PacketSize = ei_packet_traits<Scalar>::size,
IsLowerTriangular = (UpLo==LowerTriangular)
IsLowerTriangular = ((Mode&LowerTriangularBit)==LowerTriangularBit)
};
static void run(const Lhs& lhs, Rhs& other)
{//std::cerr << "col maj " << LhsProductTraits::NeedToConjugate << "\n";
{
static const int PanelWidth = EIGEN_TUNE_TRSV_PANEL_WIDTH;
const ActualLhsType& actualLhs = LhsProductTraits::extract(lhs);
@ -154,6 +151,49 @@ struct ei_triangular_solver_selector<Lhs,Rhs,Mode,UpLo,ColMajor>
}
};
/***************************************************************************
* meta-unrolling implementation
***************************************************************************/
// Primary template of the compile-time unroller: step number Index out of Size.
// Stop defaults to (Index==Size) and selects which partial specialization is
// used; only the specializations are defined.
template<typename Lhs, typename Rhs, int Mode, int Index, int Size,
bool Stop = Index==Size>
struct ei_triangular_solver_unroller;
// One step of the unrolled triangular solve (Stop == false).
// Step number Index updates coefficient I of rhs in place:
//  - lower triangular lhs: I == Index, and the Index coefficients touched by
//    the earlier steps form the segment starting at S == 0;
//  - otherwise: I == Size-Index-1, and the Index coefficients touched by the
//    earlier steps form the segment starting at S == I+1.
template<typename Lhs, typename Rhs, int Mode, int Index, int Size>
struct ei_triangular_solver_unroller<Lhs,Rhs,Mode,Index,Size,false> {
enum {
// true when all bits of LowerTriangularBit are set in Mode
IsLowerTriangular = ((Mode&LowerTriangularBit)==LowerTriangularBit),
// index of the rhs coefficient (and of the lhs row) handled by this step
I = IsLowerTriangular ? Index : Size - Index - 1,
// first index of the length-Index segment read by this step
S = IsLowerTriangular ? 0 : I+1
};
// Updates rhs.coeffRef(I), then chains to the step Index+1.
static void run(const Lhs& lhs, Rhs& rhs)
{
// subtract the sum of the coefficient-wise product between the segment of
// row I of lhs and the matching segment of rhs (skipped when Index is 0)
if (Index>0)
rhs.coeffRef(I) -= ((lhs.row(I).template segment<Index>(S).transpose())
.cwise()*(rhs.template segment<Index>(S))).sum();
// no division by the diagonal coefficient when Mode has UnitDiagBit set
if(!(Mode & UnitDiagBit))
rhs.coeffRef(I) /= lhs.coeff(I,I);
// next step; the Stop == true specialization terminates the recursion
ei_triangular_solver_unroller<Lhs,Rhs,Mode,Index+1,Size>::run(lhs,rhs);
}
};
// Terminal case of the unroller (Stop == true): there is no step left to
// perform, so run() is a no-op.
template<typename Lhs, typename Rhs, int Mode, int Index, int Size>
struct ei_triangular_solver_unroller<Lhs,Rhs,Mode,Index,Size,true> {
  // The parameters are deliberately left unnamed: they are not used, and
  // naming them raises unused-parameter warnings in every translation unit
  // that includes this header.
  static void run(const Lhs&, Rhs&) {}
};
template<typename Lhs, typename Rhs, int Mode, int StorageOrder>
struct ei_triangular_solver_selector<Lhs,Rhs,Mode,CompleteUnrolling,StorageOrder> {
static void run(const Lhs& lhs, Rhs& rhs)
{ ei_triangular_solver_unroller<Lhs,Rhs,Mode,0,Rhs::SizeAtCompileTime>::run(lhs,rhs); }
};
/***************************************************************************
* TriangularView methods
***************************************************************************/
/** "in-place" version of MatrixBase::solveTriangular() where the result is written in \a other
*
* \nonstableyet
@ -161,7 +201,7 @@ struct ei_triangular_solver_selector<Lhs,Rhs,Mode,UpLo,ColMajor>
* \warning The parameter is only marked 'const' to make the C++ compiler accept a temporary expression here.
* This function will const_cast it, so constness isn't honored here.
*
* See MatrixBase:solveTriangular() for the details.
* See TriangularView::solve() for the details.
*/
template<typename MatrixType, unsigned int Mode>
template<typename RhsDerived>
@ -198,8 +238,6 @@ void TriangularView<MatrixType,Mode>::solveInPlace(const MatrixBase<RhsDerived>&
* can be done by marked(), and that is automatically the case with expressions such as those returned
* by extract().
*
* \addexample SolveTriangular \label How to solve a triangular system (aka. how to multiply the inverse of a triangular matrix by another one)
*
* Example: \include MatrixBase_marked.cpp
* Output: \verbinclude MatrixBase_marked.out
*
@ -213,10 +251,10 @@ void TriangularView<MatrixType,Mode>::solveInPlace(const MatrixBase<RhsDerived>&
*
* \b Tips: to perform a \em "right-inverse-multiply" you can simply transpose the operation, e.g.:
* \code
* M * T^1 <=> T.transpose().solveTriangularInPlace(M.transpose());
* M * T^-1 <=> T.transpose().solveInPlace(M.transpose());
* \endcode
*
* \sa solveTriangularInPlace()
* \sa TriangularView::solveInPlace()
*/
template<typename Derived, unsigned int Mode>
template<typename RhsDerived>

View File

@ -86,7 +86,17 @@ template<typename MatrixType> void triangular(const MatrixType& m)
while (ei_abs2(m1(i,i))<1e-3) m1(i,i) = ei_random<Scalar>();
Transpose<MatrixType> trm4(m4);
// test back and forward subsitution
// test back and forward substitution with a vector as the rhs
m3 = m1.template triangularView<Eigen::UpperTriangular>();
VERIFY(v2.isApprox(m3.adjoint() * (m1.adjoint().template triangularView<Eigen::LowerTriangular>().solve(v2)), largerEps));
m3 = m1.template triangularView<Eigen::LowerTriangular>();
VERIFY(v2.isApprox(m3.transpose() * (m1.transpose().template triangularView<Eigen::UpperTriangular>().solve(v2)), largerEps));
m3 = m1.template triangularView<Eigen::UpperTriangular>();
VERIFY(v2.isApprox(m3 * (m1.template triangularView<Eigen::UpperTriangular>().solve(v2)), largerEps));
m3 = m1.template triangularView<Eigen::LowerTriangular>();
VERIFY(v2.isApprox(m3.conjugate() * (m1.conjugate().template triangularView<Eigen::LowerTriangular>().solve(v2)), largerEps));
// test back and forward substitution with a matrix as the rhs
m3 = m1.template triangularView<Eigen::UpperTriangular>();
VERIFY(m2.isApprox(m3.adjoint() * (m1.adjoint().template triangularView<Eigen::LowerTriangular>().solve(m2)), largerEps));
m3 = m1.template triangularView<Eigen::LowerTriangular>();