commit working versions of triangular solvers naturally

handling conjugated expressions. Still have to bench whether it
is faster (runtime and compile time) to directly call the
cache-friendly functions, hence all the commented-out pieces of code...
This commit is contained in:
Gael Guennebaud 2009-07-09 23:59:18 +02:00
parent fa60c72398
commit 8885d56928
3 changed files with 64 additions and 110 deletions

View File

@ -129,6 +129,19 @@ template<typename Scalar, typename NestedXpr> struct ei_product_factor_traits<Cw
{ return - Base::extractScalarFactor(x._expression()); }
};
// pop NestByValue: a NestByValue<NestedXpr> operand is viewed as a const NestedXpr&
// and every query is forwarded to the traits of that nested expression type.
template<typename NestedXpr>
struct ei_product_factor_traits<NestByValue<NestedXpr> >
  : ei_product_factor_traits<NestedXpr>
{
  typedef ei_product_factor_traits<NestedXpr> Base;
  typedef NestByValue<NestedXpr> XprType;
  typedef typename NestedXpr::Scalar Scalar;
  typedef typename Base::ActualXprType ActualXprType;

  // Returns what Base::extract yields for x once seen as a const NestedXpr&.
  static inline const ActualXprType& extract(const XprType& x)
  {
    const NestedXpr& nested = static_cast<const NestedXpr&>(x);
    return Base::extract(nested);
  }

  // Returns the scalar factor Base reports for x once seen as a const NestedXpr&.
  static inline Scalar extractScalarFactor(const XprType& x)
  {
    const NestedXpr& nested = static_cast<const NestedXpr&>(x);
    return Base::extractScalarFactor(nested);
  }
};
/* Helper class to determine the type of the product, can be either:
* - NormalProduct
* - CacheFriendlyProduct

View File

@ -43,123 +43,49 @@ struct ei_triangular_solver_selector<Lhs,Rhs,Mode,/*ConjugateLhs,ConjugateRhs,*/
{
typedef typename Rhs::Scalar Scalar;
static void run(const Lhs& lhs, Rhs& other)
{std::cerr << "here\n";
#if NOTDEF
const bool IsLowerTriangular = (UpLo==LowerTriangular);
{//std::cerr << "row maj " << ConjugateLhs << " , " << ConjugateRhs
// << " " << typeid(Lhs).name() << "\n";
static const int PanelWidth = 40; // TODO make this a user definable constant
static const bool IsLowerTriangular = (UpLo==LowerTriangular);
const int size = lhs.cols();
for(int c=0 ; c<other.cols() ; ++c)
{
const int PanelWidth = 4;
for(int pi=IsLowerTriangular ? 0 : size;
IsLowerTriangular ? pi<size : pi>0;
IsLowerTriangular ? pi+=PanelWidth : pi-=PanelWidth)
{
int actualPanelWidth = std::min(IsLowerTriangular ? size - pi : pi, PanelWidth);
int startBlock = IsLowerTriangular ? pi : pi-actualPanelWidth;
int endBlock = IsLowerTriangular ? pi + actualPanelWidth : 0;
if (pi > 0)
int r = IsLowerTriangular ? pi : size - pi; // remaining size
if (r > 0)
{
int r = IsLowerTriangular ? size - endBlock : startBlock; // remaining size
ei_cache_friendly_product_colmajor_times_vector<false,false>(
r,
&(lhs.const_cast_derived().coeffRef(endBlock,startBlock)), lhs.stride(),
other.col(c).segment(startBlock, actualPanelWidth),
&(other.coeffRef(endBlock, c)),
Scalar(-1));
int startRow = IsLowerTriangular ? pi : pi-actualPanelWidth;
int startCol = IsLowerTriangular ? 0 : pi;
// Block<Rhs,Dynamic,1> target(other,startRow,c,actualPanelWidth,1);
// ei_cache_friendly_product_rowmajor_times_vector<ConjugateLhs,ConjugateRhs>(
// &(lhs.const_cast_derived().coeffRef(startRow,startCol)), lhs.stride(),
// &(other.coeffRef(startCol, c)), r,
// target, Scalar(-1));
other.col(c).segment(startRow,actualPanelWidth) -=
lhs.block(startRow,startCol,actualPanelWidth,r)
* other.col(c).segment(startCol,r);
}
for(int k=0; k<actualPanelWidth; ++k)
{
int i = IsLowerTriangular ? pi+k : pi-k-1;
int s = IsLowerTriangular ? pi : i+1;
if (k>0)
other.coeffRef(i,c) -= ((lhs.row(i).segment(s,k).transpose())
.cwise()*(other.col(c).segment(s,k))).sum();
if(!(Mode & UnitDiagBit))
other.coeffRef(i,c) /= lhs.coeff(i,i);
int r = actualPanelWidth - k - 1; // remaining size
if (r>0)
{
other.col(c).segment((IsLowerTriangular ? i+1 : i-r), r) -=
other.coeffRef(i,c)
* Block<Lhs,Dynamic,1>(lhs, (IsLowerTriangular ? i+1 : i-r), i, r, 1);
}
}
}
}
#else
const bool IsLowerTriangular = (UpLo==LowerTriangular);
const int size = lhs.cols();
/* We perform the inverse product per block of 4 rows such that we perfectly match
* our optimized matrix * vector product. blockyStart represents the number of rows
* we have to process first using the non-block version.
*/
int blockyStart = (std::max(size-5,0)/4)*4;
if (IsLowerTriangular)
blockyStart = size - blockyStart;
else
blockyStart -= 1;
for(int c=0 ; c<other.cols() ; ++c)
{
// process first rows using the non block version
if(!(Mode & UnitDiagBit))
{
if (IsLowerTriangular)
other.coeffRef(0,c) = other.coeff(0,c)/lhs.coeff(0, 0);
else
other.coeffRef(size-1,c) = other.coeff(size-1, c)/lhs.coeff(size-1, size-1);
}
for(int i=(IsLowerTriangular ? 1 : size-2); IsLowerTriangular ? i<blockyStart : i>blockyStart; i += (IsLowerTriangular ? 1 : -1) )
{
Scalar tmp = other.coeff(i,c)
- (IsLowerTriangular ? ((lhs.row(i).start(i)) * other.col(c).start(i)).coeff(0,0)
: ((lhs.row(i).end(size-i-1)) * other.col(c).end(size-i-1)).coeff(0,0));
if (Mode & UnitDiagBit)
other.coeffRef(i,c) = tmp;
else
other.coeffRef(i,c) = tmp/lhs.coeff(i,i);
}
// now let's process the remaining rows 4 at once
for(int i=blockyStart; IsLowerTriangular ? i<size : i>0; )
{
int startBlock = i;
int endBlock = startBlock + (IsLowerTriangular ? 4 : -4);
/* Process the i cols times 4 rows block, and keep the result in a temporary vector */
// FIXME use fixed size block but take care to small fixed size matrices...
Matrix<Scalar,Dynamic,1> btmp(4);
if (IsLowerTriangular)
btmp = lhs.block(startBlock,0,4,i) * other.col(c).start(i);
else
btmp = lhs.block(i-3,i+1,4,size-1-i) * other.col(c).end(size-1-i);
/* Let's process the 4x4 sub-matrix as usual.
* btmp stores the diagonal coefficients used to update the remaining part of the result.
*/
{
Scalar tmp = other.coeff(startBlock,c)-btmp.coeff(IsLowerTriangular?0:3);
if (Mode & UnitDiagBit)
other.coeffRef(i,c) = tmp;
else
other.coeffRef(i,c) = tmp/lhs.coeff(i,i);
}
i += IsLowerTriangular ? 1 : -1;
for (;IsLowerTriangular ? i<endBlock : i>endBlock; i += IsLowerTriangular ? 1 : -1)
{
int remainingSize = IsLowerTriangular ? i-startBlock : startBlock-i;
Scalar tmp = other.coeff(i,c)
- btmp.coeff(IsLowerTriangular ? remainingSize : 3-remainingSize)
- ( lhs.row(i).segment(IsLowerTriangular ? startBlock : i+1, remainingSize)
* other.col(c).segment(IsLowerTriangular ? startBlock : i+1, remainingSize)).coeff(0,0);
if (Mode & UnitDiagBit)
other.coeffRef(i,c) = tmp;
else
other.coeffRef(i,c) = tmp/lhs.coeff(i,i);
}
}
}
#endif
}
};
@ -168,15 +94,15 @@ struct ei_triangular_solver_selector<Lhs,Rhs,Mode,/*ConjugateLhs,ConjugateRhs,*/
// - inv(LowerTriangular,UnitDiag,ColMajor) * Column vector
// - inv(UpperTriangular, ColMajor) * Column vector
// - inv(UpperTriangular,UnitDiag,ColMajor) * Column vector
template<typename Lhs, typename Rhs, int Mode, int UpLo>
struct ei_triangular_solver_selector<Lhs,Rhs,Mode,UpLo,ColMajor>
template<typename Lhs, typename Rhs, int Mode, /*bool ConjugateLhs, bool ConjugateRhs,*/ int UpLo>
struct ei_triangular_solver_selector<Lhs,Rhs,Mode,/*ConjugateLhs,ConjugateRhs,*/UpLo,ColMajor>
{
typedef typename Rhs::Scalar Scalar;
typedef typename ei_packet_traits<Scalar>::type Packet;
enum { PacketSize = ei_packet_traits<Scalar>::size };
static void run(const Lhs& lhs, Rhs& other)
{
{//std::cerr << "col maj " << ConjugateLhs << " , " << ConjugateRhs << "\n";
static const int PanelWidth = 4; // TODO make this a user definable constant
static const bool IsLowerTriangular = (UpLo==LowerTriangular);
const int size = lhs.cols();
@ -207,12 +133,16 @@ struct ei_triangular_solver_selector<Lhs,Rhs,Mode,UpLo,ColMajor>
int r = IsLowerTriangular ? size - endBlock : startBlock; // remaining size
if (r > 0)
{
ei_cache_friendly_product_colmajor_times_vector<false,false>(
r,
&(lhs.const_cast_derived().coeffRef(endBlock,startBlock)), lhs.stride(),
other.col(c).segment(startBlock, actualPanelWidth),
&(other.coeffRef(endBlock, c)),
Scalar(-1));
// ei_cache_friendly_product_colmajor_times_vector<ConjugateLhs,ConjugateRhs>(
// r,
// &(lhs.const_cast_derived().coeffRef(endBlock,startBlock)), lhs.stride(),
// other.col(c).segment(startBlock, actualPanelWidth),
// &(other.coeffRef(endBlock, c)),
// Scalar(-1));
other.col(c).segment(endBlock,r) -=
lhs.block(endBlock,startBlock,r,actualPanelWidth)
* other.col(c).segment(startBlock,actualPanelWidth);
}
}
}
@ -238,13 +168,21 @@ void TriangularView<MatrixType,Mode>::solveInPlace(const MatrixBase<RhsDerived>&
ei_assert(!(Mode & ZeroDiagBit));
ei_assert(Mode & (UpperTriangularBit|LowerTriangularBit));
enum { copy = ei_traits<RhsDerived>::Flags & RowMajorBit };
// typedef ei_product_factor_traits<MatrixType> LhsProductTraits;
// typedef ei_product_factor_traits<RhsDerived> RhsProductTraits;
// typedef typename LhsProductTraits::ActualXprType ActualLhsType;
// typedef typename RhsProductTraits::ActualXprType ActualRhsType;
// const ActualLhsType& actualLhs = LhsProductTraits::extract(_expression());
// ActualRhsType& actualRhs = const_cast<ActualRhsType&>(RhsProductTraits::extract(rhs));
enum { copy = ei_traits<RhsDerived>::Flags & RowMajorBit };
// std::cerr << typeid(MatrixType).name() << "\n";
typedef typename ei_meta_if<copy,
typename ei_plain_matrix_type_column_major<RhsDerived>::type, RhsDerived&>::ret RhsCopy;
RhsCopy rhsCopy(rhs);
ei_triangular_solver_selector<MatrixType, typename ei_unref<RhsCopy>::type, Mode>::run(_expression(), rhsCopy);
ei_triangular_solver_selector<MatrixType, typename ei_unref<RhsCopy>::type,
Mode/*, LhsProductTraits::NeedToConjugate,RhsProductTraits::NeedToConjugate*/>::run(_expression(), rhsCopy);
if (copy)
rhs = rhsCopy;

View File

@ -307,8 +307,11 @@ static EIGEN_DONT_INLINE void ei_cache_friendly_product_rowmajor_times_vector(
skipRows = std::min(skipRows,res.size());
// note that the skipped columns are processed later.
}
ei_internal_assert((alignmentPattern==NoneAligned) || PacketSize==1
|| (size_t(lhs+alignedStart+lhsStride*skipRows)%sizeof(Packet))==0);
ei_internal_assert( alignmentPattern==NoneAligned
|| PacketSize==1
|| (skipRows + rowsAtOnce >= res.size())
|| PacketSize > rhsSize
|| (size_t(lhs+alignedStart+lhsStride*skipRows)%sizeof(Packet))==0);
}
int offset1 = (FirstAligned && alignmentStep==1?3:1);