mirror of
https://gitlab.com/libeigen/eigen.git
synced 2024-12-15 07:10:37 +08:00
add a meta unroller for the triangular solver (only for vectors as rhs)
This commit is contained in:
parent
1a1b2e9f27
commit
b47dea8b7a
@ -26,28 +26,25 @@
|
||||
#define EIGEN_SOLVETRIANGULAR_H
|
||||
|
||||
template<typename Lhs, typename Rhs,
|
||||
int Mode, // Upper/Lower | UnitDiag
|
||||
int UpLo = (Mode & LowerTriangularBit)
|
||||
? LowerTriangular
|
||||
: (Mode & UpperTriangularBit)
|
||||
? UpperTriangular
|
||||
: -1,
|
||||
int Mode, // can be Upper/Lower | UnitDiag
|
||||
int Unrolling = Rhs::IsVectorAtCompileTime && Rhs::SizeAtCompileTime <= 8 // FIXME
|
||||
? CompleteUnrolling : NoUnrolling,
|
||||
int StorageOrder = int(Lhs::Flags) & RowMajorBit
|
||||
>
|
||||
struct ei_triangular_solver_selector;
|
||||
|
||||
// forward substitution, row-major
|
||||
template<typename Lhs, typename Rhs, int Mode, int UpLo>
|
||||
struct ei_triangular_solver_selector<Lhs,Rhs,Mode,UpLo,RowMajor>
|
||||
template<typename Lhs, typename Rhs, int Mode>
|
||||
struct ei_triangular_solver_selector<Lhs,Rhs,Mode,NoUnrolling,RowMajor>
|
||||
{
|
||||
typedef typename Rhs::Scalar Scalar;
|
||||
typedef ei_product_factor_traits<Lhs> LhsProductTraits;
|
||||
typedef typename LhsProductTraits::ActualXprType ActualLhsType;
|
||||
enum {
|
||||
IsLowerTriangular = (UpLo==LowerTriangular)
|
||||
IsLowerTriangular = ((Mode&LowerTriangularBit)==LowerTriangularBit)
|
||||
};
|
||||
static void run(const Lhs& lhs, Rhs& other)
|
||||
{//std::cerr << "row maj " << LhsProductTraits::NeedToConjugate << "\n";
|
||||
{
|
||||
static const int PanelWidth = EIGEN_TUNE_TRSV_PANEL_WIDTH;
|
||||
const ActualLhsType& actualLhs = LhsProductTraits::extract(lhs);
|
||||
|
||||
@ -90,12 +87,12 @@ struct ei_triangular_solver_selector<Lhs,Rhs,Mode,UpLo,RowMajor>
|
||||
};
|
||||
|
||||
// Implements the following configurations:
|
||||
// - inv(LowerTriangular, ColMajor) * Column vector
|
||||
// - inv(LowerTriangular,UnitDiag,ColMajor) * Column vector
|
||||
// - inv(UpperTriangular, ColMajor) * Column vector
|
||||
// - inv(UpperTriangular,UnitDiag,ColMajor) * Column vector
|
||||
template<typename Lhs, typename Rhs, int Mode, int UpLo>
|
||||
struct ei_triangular_solver_selector<Lhs,Rhs,Mode,UpLo,ColMajor>
|
||||
// - inv(LowerTriangular, ColMajor) * Column vectors
|
||||
// - inv(LowerTriangular,UnitDiag,ColMajor) * Column vectors
|
||||
// - inv(UpperTriangular, ColMajor) * Column vectors
|
||||
// - inv(UpperTriangular,UnitDiag,ColMajor) * Column vectors
|
||||
template<typename Lhs, typename Rhs, int Mode>
|
||||
struct ei_triangular_solver_selector<Lhs,Rhs,Mode,NoUnrolling,ColMajor>
|
||||
{
|
||||
typedef typename Rhs::Scalar Scalar;
|
||||
typedef typename ei_packet_traits<Scalar>::type Packet;
|
||||
@ -103,11 +100,11 @@ struct ei_triangular_solver_selector<Lhs,Rhs,Mode,UpLo,ColMajor>
|
||||
typedef typename LhsProductTraits::ActualXprType ActualLhsType;
|
||||
enum {
|
||||
PacketSize = ei_packet_traits<Scalar>::size,
|
||||
IsLowerTriangular = (UpLo==LowerTriangular)
|
||||
IsLowerTriangular = ((Mode&LowerTriangularBit)==LowerTriangularBit)
|
||||
};
|
||||
|
||||
static void run(const Lhs& lhs, Rhs& other)
|
||||
{//std::cerr << "col maj " << LhsProductTraits::NeedToConjugate << "\n";
|
||||
{
|
||||
static const int PanelWidth = EIGEN_TUNE_TRSV_PANEL_WIDTH;
|
||||
const ActualLhsType& actualLhs = LhsProductTraits::extract(lhs);
|
||||
|
||||
@ -154,6 +151,49 @@ struct ei_triangular_solver_selector<Lhs,Rhs,Mode,UpLo,ColMajor>
|
||||
}
|
||||
};
|
||||
|
||||
/***************************************************************************
|
||||
* meta-unrolling implementation
|
||||
***************************************************************************/
|
||||
|
||||
template<typename Lhs, typename Rhs, int Mode, int Index, int Size,
|
||||
bool Stop = Index==Size>
|
||||
struct ei_triangular_solver_unroller;
|
||||
|
||||
template<typename Lhs, typename Rhs, int Mode, int Index, int Size>
|
||||
struct ei_triangular_solver_unroller<Lhs,Rhs,Mode,Index,Size,false> {
|
||||
enum {
|
||||
IsLowerTriangular = ((Mode&LowerTriangularBit)==LowerTriangularBit),
|
||||
I = IsLowerTriangular ? Index : Size - Index - 1,
|
||||
S = IsLowerTriangular ? 0 : I+1
|
||||
};
|
||||
static void run(const Lhs& lhs, Rhs& rhs)
|
||||
{
|
||||
if (Index>0)
|
||||
rhs.coeffRef(I) -= ((lhs.row(I).template segment<Index>(S).transpose())
|
||||
.cwise()*(rhs.template segment<Index>(S))).sum();
|
||||
|
||||
if(!(Mode & UnitDiagBit))
|
||||
rhs.coeffRef(I) /= lhs.coeff(I,I);
|
||||
|
||||
ei_triangular_solver_unroller<Lhs,Rhs,Mode,Index+1,Size>::run(lhs,rhs);
|
||||
}
|
||||
};
|
||||
|
||||
template<typename Lhs, typename Rhs, int Mode, int Index, int Size>
|
||||
struct ei_triangular_solver_unroller<Lhs,Rhs,Mode,Index,Size,true> {
|
||||
static void run(const Lhs& lhs, Rhs& rhs) {}
|
||||
};
|
||||
|
||||
template<typename Lhs, typename Rhs, int Mode, int StorageOrder>
|
||||
struct ei_triangular_solver_selector<Lhs,Rhs,Mode,CompleteUnrolling,StorageOrder> {
|
||||
static void run(const Lhs& lhs, Rhs& rhs)
|
||||
{ ei_triangular_solver_unroller<Lhs,Rhs,Mode,0,Rhs::SizeAtCompileTime>::run(lhs,rhs); }
|
||||
};
|
||||
|
||||
/***************************************************************************
|
||||
* TriangularView methods
|
||||
***************************************************************************/
|
||||
|
||||
/** "in-place" version of MatrixBase::solveTriangular() where the result is written in \a other
|
||||
*
|
||||
* \nonstableyet
|
||||
@ -161,7 +201,7 @@ struct ei_triangular_solver_selector<Lhs,Rhs,Mode,UpLo,ColMajor>
|
||||
* \warning The parameter is only marked 'const' to make the C++ compiler accept a temporary expression here.
|
||||
* This function will const_cast it, so constness isn't honored here.
|
||||
*
|
||||
* See MatrixBase:solveTriangular() for the details.
|
||||
* See TriangularView:solve() for the details.
|
||||
*/
|
||||
template<typename MatrixType, unsigned int Mode>
|
||||
template<typename RhsDerived>
|
||||
@ -198,8 +238,6 @@ void TriangularView<MatrixType,Mode>::solveInPlace(const MatrixBase<RhsDerived>&
|
||||
* can be done by marked(), and that is automatically the case with expressions such as those returned
|
||||
* by extract().
|
||||
*
|
||||
* \addexample SolveTriangular \label How to solve a triangular system (aka. how to multiply the inverse of a triangular matrix by another one)
|
||||
*
|
||||
* Example: \include MatrixBase_marked.cpp
|
||||
* Output: \verbinclude MatrixBase_marked.out
|
||||
*
|
||||
@ -213,10 +251,10 @@ void TriangularView<MatrixType,Mode>::solveInPlace(const MatrixBase<RhsDerived>&
|
||||
*
|
||||
* \b Tips: to perform a \em "right-inverse-multiply" you can simply transpose the operation, e.g.:
|
||||
* \code
|
||||
* M * T^1 <=> T.transpose().solveTriangularInPlace(M.transpose());
|
||||
* M * T^1 <=> T.transpose().solveInPlace(M.transpose());
|
||||
* \endcode
|
||||
*
|
||||
* \sa solveTriangularInPlace()
|
||||
* \sa TriangularView::solveInPlace()
|
||||
*/
|
||||
template<typename Derived, unsigned int Mode>
|
||||
template<typename RhsDerived>
|
||||
|
@ -86,7 +86,17 @@ template<typename MatrixType> void triangular(const MatrixType& m)
|
||||
while (ei_abs2(m1(i,i))<1e-3) m1(i,i) = ei_random<Scalar>();
|
||||
|
||||
Transpose<MatrixType> trm4(m4);
|
||||
// test back and forward subsitution
|
||||
// test back and forward subsitution with a vector as the rhs
|
||||
m3 = m1.template triangularView<Eigen::UpperTriangular>();
|
||||
VERIFY(v2.isApprox(m3.adjoint() * (m1.adjoint().template triangularView<Eigen::LowerTriangular>().solve(v2)), largerEps));
|
||||
m3 = m1.template triangularView<Eigen::LowerTriangular>();
|
||||
VERIFY(v2.isApprox(m3.transpose() * (m1.transpose().template triangularView<Eigen::UpperTriangular>().solve(v2)), largerEps));
|
||||
m3 = m1.template triangularView<Eigen::UpperTriangular>();
|
||||
VERIFY(v2.isApprox(m3 * (m1.template triangularView<Eigen::UpperTriangular>().solve(v2)), largerEps));
|
||||
m3 = m1.template triangularView<Eigen::LowerTriangular>();
|
||||
VERIFY(v2.isApprox(m3.conjugate() * (m1.conjugate().template triangularView<Eigen::LowerTriangular>().solve(v2)), largerEps));
|
||||
|
||||
// test back and forward subsitution with a matrix as the rhs
|
||||
m3 = m1.template triangularView<Eigen::UpperTriangular>();
|
||||
VERIFY(m2.isApprox(m3.adjoint() * (m1.adjoint().template triangularView<Eigen::LowerTriangular>().solve(m2)), largerEps));
|
||||
m3 = m1.template triangularView<Eigen::LowerTriangular>();
|
||||
|
Loading…
Reference in New Issue
Block a user