mirror of
https://gitlab.com/libeigen/eigen.git
synced 2024-12-15 07:10:37 +08:00
add a meta unroller for the triangular solver (only for vectors as rhs)
This commit is contained in:
parent
1a1b2e9f27
commit
b47dea8b7a
@ -26,28 +26,25 @@
|
|||||||
#define EIGEN_SOLVETRIANGULAR_H
|
#define EIGEN_SOLVETRIANGULAR_H
|
||||||
|
|
||||||
template<typename Lhs, typename Rhs,
|
template<typename Lhs, typename Rhs,
|
||||||
int Mode, // Upper/Lower | UnitDiag
|
int Mode, // can be Upper/Lower | UnitDiag
|
||||||
int UpLo = (Mode & LowerTriangularBit)
|
int Unrolling = Rhs::IsVectorAtCompileTime && Rhs::SizeAtCompileTime <= 8 // FIXME
|
||||||
? LowerTriangular
|
? CompleteUnrolling : NoUnrolling,
|
||||||
: (Mode & UpperTriangularBit)
|
|
||||||
? UpperTriangular
|
|
||||||
: -1,
|
|
||||||
int StorageOrder = int(Lhs::Flags) & RowMajorBit
|
int StorageOrder = int(Lhs::Flags) & RowMajorBit
|
||||||
>
|
>
|
||||||
struct ei_triangular_solver_selector;
|
struct ei_triangular_solver_selector;
|
||||||
|
|
||||||
// forward substitution, row-major
|
// forward substitution, row-major
|
||||||
template<typename Lhs, typename Rhs, int Mode, int UpLo>
|
template<typename Lhs, typename Rhs, int Mode>
|
||||||
struct ei_triangular_solver_selector<Lhs,Rhs,Mode,UpLo,RowMajor>
|
struct ei_triangular_solver_selector<Lhs,Rhs,Mode,NoUnrolling,RowMajor>
|
||||||
{
|
{
|
||||||
typedef typename Rhs::Scalar Scalar;
|
typedef typename Rhs::Scalar Scalar;
|
||||||
typedef ei_product_factor_traits<Lhs> LhsProductTraits;
|
typedef ei_product_factor_traits<Lhs> LhsProductTraits;
|
||||||
typedef typename LhsProductTraits::ActualXprType ActualLhsType;
|
typedef typename LhsProductTraits::ActualXprType ActualLhsType;
|
||||||
enum {
|
enum {
|
||||||
IsLowerTriangular = (UpLo==LowerTriangular)
|
IsLowerTriangular = ((Mode&LowerTriangularBit)==LowerTriangularBit)
|
||||||
};
|
};
|
||||||
static void run(const Lhs& lhs, Rhs& other)
|
static void run(const Lhs& lhs, Rhs& other)
|
||||||
{//std::cerr << "row maj " << LhsProductTraits::NeedToConjugate << "\n";
|
{
|
||||||
static const int PanelWidth = EIGEN_TUNE_TRSV_PANEL_WIDTH;
|
static const int PanelWidth = EIGEN_TUNE_TRSV_PANEL_WIDTH;
|
||||||
const ActualLhsType& actualLhs = LhsProductTraits::extract(lhs);
|
const ActualLhsType& actualLhs = LhsProductTraits::extract(lhs);
|
||||||
|
|
||||||
@ -90,12 +87,12 @@ struct ei_triangular_solver_selector<Lhs,Rhs,Mode,UpLo,RowMajor>
|
|||||||
};
|
};
|
||||||
|
|
||||||
// Implements the following configurations:
|
// Implements the following configurations:
|
||||||
// - inv(LowerTriangular, ColMajor) * Column vector
|
// - inv(LowerTriangular, ColMajor) * Column vectors
|
||||||
// - inv(LowerTriangular,UnitDiag,ColMajor) * Column vector
|
// - inv(LowerTriangular,UnitDiag,ColMajor) * Column vectors
|
||||||
// - inv(UpperTriangular, ColMajor) * Column vector
|
// - inv(UpperTriangular, ColMajor) * Column vectors
|
||||||
// - inv(UpperTriangular,UnitDiag,ColMajor) * Column vector
|
// - inv(UpperTriangular,UnitDiag,ColMajor) * Column vectors
|
||||||
template<typename Lhs, typename Rhs, int Mode, int UpLo>
|
template<typename Lhs, typename Rhs, int Mode>
|
||||||
struct ei_triangular_solver_selector<Lhs,Rhs,Mode,UpLo,ColMajor>
|
struct ei_triangular_solver_selector<Lhs,Rhs,Mode,NoUnrolling,ColMajor>
|
||||||
{
|
{
|
||||||
typedef typename Rhs::Scalar Scalar;
|
typedef typename Rhs::Scalar Scalar;
|
||||||
typedef typename ei_packet_traits<Scalar>::type Packet;
|
typedef typename ei_packet_traits<Scalar>::type Packet;
|
||||||
@ -103,11 +100,11 @@ struct ei_triangular_solver_selector<Lhs,Rhs,Mode,UpLo,ColMajor>
|
|||||||
typedef typename LhsProductTraits::ActualXprType ActualLhsType;
|
typedef typename LhsProductTraits::ActualXprType ActualLhsType;
|
||||||
enum {
|
enum {
|
||||||
PacketSize = ei_packet_traits<Scalar>::size,
|
PacketSize = ei_packet_traits<Scalar>::size,
|
||||||
IsLowerTriangular = (UpLo==LowerTriangular)
|
IsLowerTriangular = ((Mode&LowerTriangularBit)==LowerTriangularBit)
|
||||||
};
|
};
|
||||||
|
|
||||||
static void run(const Lhs& lhs, Rhs& other)
|
static void run(const Lhs& lhs, Rhs& other)
|
||||||
{//std::cerr << "col maj " << LhsProductTraits::NeedToConjugate << "\n";
|
{
|
||||||
static const int PanelWidth = EIGEN_TUNE_TRSV_PANEL_WIDTH;
|
static const int PanelWidth = EIGEN_TUNE_TRSV_PANEL_WIDTH;
|
||||||
const ActualLhsType& actualLhs = LhsProductTraits::extract(lhs);
|
const ActualLhsType& actualLhs = LhsProductTraits::extract(lhs);
|
||||||
|
|
||||||
@ -154,6 +151,49 @@ struct ei_triangular_solver_selector<Lhs,Rhs,Mode,UpLo,ColMajor>
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/***************************************************************************
|
||||||
|
* meta-unrolling implementation
|
||||||
|
***************************************************************************/
|
||||||
|
|
||||||
|
template<typename Lhs, typename Rhs, int Mode, int Index, int Size,
|
||||||
|
bool Stop = Index==Size>
|
||||||
|
struct ei_triangular_solver_unroller;
|
||||||
|
|
||||||
|
template<typename Lhs, typename Rhs, int Mode, int Index, int Size>
|
||||||
|
struct ei_triangular_solver_unroller<Lhs,Rhs,Mode,Index,Size,false> {
|
||||||
|
enum {
|
||||||
|
IsLowerTriangular = ((Mode&LowerTriangularBit)==LowerTriangularBit),
|
||||||
|
I = IsLowerTriangular ? Index : Size - Index - 1,
|
||||||
|
S = IsLowerTriangular ? 0 : I+1
|
||||||
|
};
|
||||||
|
static void run(const Lhs& lhs, Rhs& rhs)
|
||||||
|
{
|
||||||
|
if (Index>0)
|
||||||
|
rhs.coeffRef(I) -= ((lhs.row(I).template segment<Index>(S).transpose())
|
||||||
|
.cwise()*(rhs.template segment<Index>(S))).sum();
|
||||||
|
|
||||||
|
if(!(Mode & UnitDiagBit))
|
||||||
|
rhs.coeffRef(I) /= lhs.coeff(I,I);
|
||||||
|
|
||||||
|
ei_triangular_solver_unroller<Lhs,Rhs,Mode,Index+1,Size>::run(lhs,rhs);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template<typename Lhs, typename Rhs, int Mode, int Index, int Size>
|
||||||
|
struct ei_triangular_solver_unroller<Lhs,Rhs,Mode,Index,Size,true> {
|
||||||
|
static void run(const Lhs& lhs, Rhs& rhs) {}
|
||||||
|
};
|
||||||
|
|
||||||
|
template<typename Lhs, typename Rhs, int Mode, int StorageOrder>
|
||||||
|
struct ei_triangular_solver_selector<Lhs,Rhs,Mode,CompleteUnrolling,StorageOrder> {
|
||||||
|
static void run(const Lhs& lhs, Rhs& rhs)
|
||||||
|
{ ei_triangular_solver_unroller<Lhs,Rhs,Mode,0,Rhs::SizeAtCompileTime>::run(lhs,rhs); }
|
||||||
|
};
|
||||||
|
|
||||||
|
/***************************************************************************
|
||||||
|
* TriangularView methods
|
||||||
|
***************************************************************************/
|
||||||
|
|
||||||
/** "in-place" version of MatrixBase::solveTriangular() where the result is written in \a other
|
/** "in-place" version of MatrixBase::solveTriangular() where the result is written in \a other
|
||||||
*
|
*
|
||||||
* \nonstableyet
|
* \nonstableyet
|
||||||
@ -161,7 +201,7 @@ struct ei_triangular_solver_selector<Lhs,Rhs,Mode,UpLo,ColMajor>
|
|||||||
* \warning The parameter is only marked 'const' to make the C++ compiler accept a temporary expression here.
|
* \warning The parameter is only marked 'const' to make the C++ compiler accept a temporary expression here.
|
||||||
* This function will const_cast it, so constness isn't honored here.
|
* This function will const_cast it, so constness isn't honored here.
|
||||||
*
|
*
|
||||||
* See MatrixBase:solveTriangular() for the details.
|
* See TriangularView:solve() for the details.
|
||||||
*/
|
*/
|
||||||
template<typename MatrixType, unsigned int Mode>
|
template<typename MatrixType, unsigned int Mode>
|
||||||
template<typename RhsDerived>
|
template<typename RhsDerived>
|
||||||
@ -198,8 +238,6 @@ void TriangularView<MatrixType,Mode>::solveInPlace(const MatrixBase<RhsDerived>&
|
|||||||
* can be done by marked(), and that is automatically the case with expressions such as those returned
|
* can be done by marked(), and that is automatically the case with expressions such as those returned
|
||||||
* by extract().
|
* by extract().
|
||||||
*
|
*
|
||||||
* \addexample SolveTriangular \label How to solve a triangular system (aka. how to multiply the inverse of a triangular matrix by another one)
|
|
||||||
*
|
|
||||||
* Example: \include MatrixBase_marked.cpp
|
* Example: \include MatrixBase_marked.cpp
|
||||||
* Output: \verbinclude MatrixBase_marked.out
|
* Output: \verbinclude MatrixBase_marked.out
|
||||||
*
|
*
|
||||||
@ -213,10 +251,10 @@ void TriangularView<MatrixType,Mode>::solveInPlace(const MatrixBase<RhsDerived>&
|
|||||||
*
|
*
|
||||||
* \b Tips: to perform a \em "right-inverse-multiply" you can simply transpose the operation, e.g.:
|
* \b Tips: to perform a \em "right-inverse-multiply" you can simply transpose the operation, e.g.:
|
||||||
* \code
|
* \code
|
||||||
* M * T^1 <=> T.transpose().solveTriangularInPlace(M.transpose());
|
* M * T^1 <=> T.transpose().solveInPlace(M.transpose());
|
||||||
* \endcode
|
* \endcode
|
||||||
*
|
*
|
||||||
* \sa solveTriangularInPlace()
|
* \sa TriangularView::solveInPlace()
|
||||||
*/
|
*/
|
||||||
template<typename Derived, unsigned int Mode>
|
template<typename Derived, unsigned int Mode>
|
||||||
template<typename RhsDerived>
|
template<typename RhsDerived>
|
||||||
|
@ -86,7 +86,17 @@ template<typename MatrixType> void triangular(const MatrixType& m)
|
|||||||
while (ei_abs2(m1(i,i))<1e-3) m1(i,i) = ei_random<Scalar>();
|
while (ei_abs2(m1(i,i))<1e-3) m1(i,i) = ei_random<Scalar>();
|
||||||
|
|
||||||
Transpose<MatrixType> trm4(m4);
|
Transpose<MatrixType> trm4(m4);
|
||||||
// test back and forward subsitution
|
// test back and forward subsitution with a vector as the rhs
|
||||||
|
m3 = m1.template triangularView<Eigen::UpperTriangular>();
|
||||||
|
VERIFY(v2.isApprox(m3.adjoint() * (m1.adjoint().template triangularView<Eigen::LowerTriangular>().solve(v2)), largerEps));
|
||||||
|
m3 = m1.template triangularView<Eigen::LowerTriangular>();
|
||||||
|
VERIFY(v2.isApprox(m3.transpose() * (m1.transpose().template triangularView<Eigen::UpperTriangular>().solve(v2)), largerEps));
|
||||||
|
m3 = m1.template triangularView<Eigen::UpperTriangular>();
|
||||||
|
VERIFY(v2.isApprox(m3 * (m1.template triangularView<Eigen::UpperTriangular>().solve(v2)), largerEps));
|
||||||
|
m3 = m1.template triangularView<Eigen::LowerTriangular>();
|
||||||
|
VERIFY(v2.isApprox(m3.conjugate() * (m1.conjugate().template triangularView<Eigen::LowerTriangular>().solve(v2)), largerEps));
|
||||||
|
|
||||||
|
// test back and forward subsitution with a matrix as the rhs
|
||||||
m3 = m1.template triangularView<Eigen::UpperTriangular>();
|
m3 = m1.template triangularView<Eigen::UpperTriangular>();
|
||||||
VERIFY(m2.isApprox(m3.adjoint() * (m1.adjoint().template triangularView<Eigen::LowerTriangular>().solve(m2)), largerEps));
|
VERIFY(m2.isApprox(m3.adjoint() * (m1.adjoint().template triangularView<Eigen::LowerTriangular>().solve(m2)), largerEps));
|
||||||
m3 = m1.template triangularView<Eigen::LowerTriangular>();
|
m3 = m1.template triangularView<Eigen::LowerTriangular>();
|
||||||
|
Loading…
Reference in New Issue
Block a user