Improve performance of row-major-dense-matrix * vector products for recent CPUs.

This revised version does not bother about aligned loads/stores,
and rather processes 8 rows at ones for better instruction pipelining.
This commit is contained in:
Gael Guennebaud 2016-12-05 13:02:01 +01:00
parent 3abc827354
commit e3f613cbd4

View File

@ -242,253 +242,161 @@ EIGEN_DONT_INLINE static void run(
template<typename Index, typename LhsScalar, typename LhsMapper, bool ConjugateLhs, typename RhsScalar, typename RhsMapper, bool ConjugateRhs, int Version> template<typename Index, typename LhsScalar, typename LhsMapper, bool ConjugateLhs, typename RhsScalar, typename RhsMapper, bool ConjugateRhs, int Version>
EIGEN_DONT_INLINE void general_matrix_vector_product<Index,LhsScalar,LhsMapper,RowMajor,ConjugateLhs,RhsScalar,RhsMapper,ConjugateRhs,Version>::run( EIGEN_DONT_INLINE void general_matrix_vector_product<Index,LhsScalar,LhsMapper,RowMajor,ConjugateLhs,RhsScalar,RhsMapper,ConjugateRhs,Version>::run(
Index rows, Index cols, Index rows, Index cols,
const LhsMapper& lhs, const LhsMapper& alhs,
const RhsMapper& rhs, const RhsMapper& rhs,
ResScalar* res, Index resIncr, ResScalar* res, Index resIncr,
ResScalar alpha) ResScalar alpha)
{ {
// The following copy tells the compiler that lhs's attributes are not modified outside this function
// This helps GCC to generate propoer code.
LhsMapper lhs(alhs);
eigen_internal_assert(rhs.stride()==1); eigen_internal_assert(rhs.stride()==1);
#ifdef _EIGEN_ACCUMULATE_PACKETS
#error _EIGEN_ACCUMULATE_PACKETS has already been defined
#endif
#define _EIGEN_ACCUMULATE_PACKETS(Alignment0,Alignment13,Alignment2) {\
RhsPacket b = rhs.getVectorMapper(j, 0).template load<RhsPacket, Aligned>(0); \
ptmp0 = pcj.pmadd(lhs0.template load<LhsPacket, Alignment0>(j), b, ptmp0); \
ptmp1 = pcj.pmadd(lhs1.template load<LhsPacket, Alignment13>(j), b, ptmp1); \
ptmp2 = pcj.pmadd(lhs2.template load<LhsPacket, Alignment2>(j), b, ptmp2); \
ptmp3 = pcj.pmadd(lhs3.template load<LhsPacket, Alignment13>(j), b, ptmp3); }
conj_helper<LhsScalar,RhsScalar,ConjugateLhs,ConjugateRhs> cj; conj_helper<LhsScalar,RhsScalar,ConjugateLhs,ConjugateRhs> cj;
conj_helper<LhsPacket,RhsPacket,ConjugateLhs,ConjugateRhs> pcj; conj_helper<LhsPacket,RhsPacket,ConjugateLhs,ConjugateRhs> pcj;
typedef typename LhsMapper::VectorMapper LhsScalars; // TODO: fine tune the following heuristic. The rationale is that if the matrix is very large,
// processing 8 rows at once might be counter productive wrt cache.
const Index n8 = lhs.stride()*sizeof(LhsScalar)>32000 ? 0 : rows-7;
const Index n4 = rows-3;
const Index n2 = rows-1;
enum { AllAligned=0, EvenAligned=1, FirstAligned=2, NoneAligned=3 }; // TODO: for padded aligned inputs, we could enable aligned reads
const Index rowsAtOnce = 4; enum { LhsAlignment = Unaligned };
const Index peels = 2;
const Index RhsPacketAlignedMask = RhsPacketSize-1;
const Index LhsPacketAlignedMask = LhsPacketSize-1;
const Index depth = cols;
const Index lhsStride = lhs.stride();
// How many coeffs of the result do we have to skip to be aligned. Index i=0;
// Here we assume data are at least aligned on the base scalar type for(; i<n8; i+=8)
// if that's not the case then vectorization is discarded, see below.
Index alignedStart = rhs.firstAligned(depth);
Index alignedSize = RhsPacketSize>1 ? alignedStart + ((depth-alignedStart) & ~RhsPacketAlignedMask) : 0;
const Index peeledSize = alignedSize - RhsPacketSize*peels - RhsPacketSize + 1;
const Index alignmentStep = LhsPacketSize>1 ? (LhsPacketSize - lhsStride % LhsPacketSize) & LhsPacketAlignedMask : 0;
Index alignmentPattern = alignmentStep==0 ? AllAligned
: alignmentStep==(LhsPacketSize/2) ? EvenAligned
: FirstAligned;
// we cannot assume the first element is aligned because of sub-matrices
const Index lhsAlignmentOffset = lhs.firstAligned(depth);
const Index rhsAlignmentOffset = rhs.firstAligned(rows);
// find how many rows do we have to skip to be aligned with rhs (if possible)
Index skipRows = 0;
// if the data cannot be aligned (TODO add some compile time tests when possible, e.g. for floats)
if( (sizeof(LhsScalar)!=sizeof(RhsScalar)) ||
(lhsAlignmentOffset < 0) || (lhsAlignmentOffset == depth) ||
(rhsAlignmentOffset < 0) || (rhsAlignmentOffset == rows) )
{ {
alignedSize = 0; ResPacket c0 = pset1<ResPacket>(ResScalar(0)),
alignedStart = 0; c1 = pset1<ResPacket>(ResScalar(0)),
alignmentPattern = NoneAligned; c2 = pset1<ResPacket>(ResScalar(0)),
} c3 = pset1<ResPacket>(ResScalar(0)),
else if(LhsPacketSize > 4) c4 = pset1<ResPacket>(ResScalar(0)),
{ c5 = pset1<ResPacket>(ResScalar(0)),
// TODO: extend the code to support aligned loads whenever possible when LhsPacketSize > 4. c6 = pset1<ResPacket>(ResScalar(0)),
alignmentPattern = NoneAligned; c7 = pset1<ResPacket>(ResScalar(0));
}
else if (LhsPacketSize>1)
{
// eigen_internal_assert(size_t(firstLhs+lhsAlignmentOffset)%sizeof(LhsPacket)==0 || depth<LhsPacketSize);
while (skipRows<LhsPacketSize && Index j=0;
alignedStart != ((lhsAlignmentOffset + alignmentStep*skipRows)%LhsPacketSize)) for(; j+LhsPacketSize<=cols; j+=LhsPacketSize)
++skipRows;
if (skipRows==LhsPacketSize)
{ {
// nothing can be aligned, no need to skip any column RhsPacket b0 = rhs.template load<RhsPacket, Unaligned>(j,0);
alignmentPattern = NoneAligned;
skipRows = 0; c0 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+0,j),b0,c0);
c1 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+1,j),b0,c1);
c2 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+2,j),b0,c2);
c3 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+3,j),b0,c3);
c4 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+4,j),b0,c4);
c5 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+5,j),b0,c5);
c6 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+6,j),b0,c6);
c7 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+7,j),b0,c7);
} }
else ResScalar cc0 = predux(c0);
ResScalar cc1 = predux(c1);
ResScalar cc2 = predux(c2);
ResScalar cc3 = predux(c3);
ResScalar cc4 = predux(c4);
ResScalar cc5 = predux(c5);
ResScalar cc6 = predux(c6);
ResScalar cc7 = predux(c7);
for(; j<cols; ++j)
{ {
skipRows = (std::min)(skipRows,Index(rows)); RhsScalar b0 = rhs(j,0);
// note that the skiped columns are processed later.
cc0 += cj.pmul(lhs(i+0,j), b0);
cc1 += cj.pmul(lhs(i+1,j), b0);
cc2 += cj.pmul(lhs(i+2,j), b0);
cc3 += cj.pmul(lhs(i+3,j), b0);
cc4 += cj.pmul(lhs(i+4,j), b0);
cc5 += cj.pmul(lhs(i+5,j), b0);
cc6 += cj.pmul(lhs(i+6,j), b0);
cc7 += cj.pmul(lhs(i+7,j), b0);
} }
/* eigen_internal_assert( alignmentPattern==NoneAligned res[(i+0)*resIncr] += alpha*cc0;
|| LhsPacketSize==1 res[(i+1)*resIncr] += alpha*cc1;
|| (skipRows + rowsAtOnce >= rows) res[(i+2)*resIncr] += alpha*cc2;
|| LhsPacketSize > depth res[(i+3)*resIncr] += alpha*cc3;
|| (size_t(firstLhs+alignedStart+lhsStride*skipRows)%sizeof(LhsPacket))==0);*/ res[(i+4)*resIncr] += alpha*cc4;
res[(i+5)*resIncr] += alpha*cc5;
res[(i+6)*resIncr] += alpha*cc6;
res[(i+7)*resIncr] += alpha*cc7;
} }
else if(Vectorizable) for(; i<n4; i+=4)
{ {
alignedStart = 0; ResPacket c0 = pset1<ResPacket>(ResScalar(0)),
alignedSize = depth; c1 = pset1<ResPacket>(ResScalar(0)),
alignmentPattern = AllAligned; c2 = pset1<ResPacket>(ResScalar(0)),
c3 = pset1<ResPacket>(ResScalar(0));
Index j=0;
for(; j+LhsPacketSize<=cols; j+=LhsPacketSize)
{
RhsPacket b0 = rhs.template load<RhsPacket, Unaligned>(j,0);
c0 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+0,j),b0,c0);
c1 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+1,j),b0,c1);
c2 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+2,j),b0,c2);
c3 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+3,j),b0,c3);
}
ResScalar cc0 = predux(c0);
ResScalar cc1 = predux(c1);
ResScalar cc2 = predux(c2);
ResScalar cc3 = predux(c3);
for(; j<cols; ++j)
{
RhsScalar b0 = rhs(j,0);
cc0 += cj.pmul(lhs(i+0,j), b0);
cc1 += cj.pmul(lhs(i+1,j), b0);
cc2 += cj.pmul(lhs(i+2,j), b0);
cc3 += cj.pmul(lhs(i+3,j), b0);
}
res[(i+0)*resIncr] += alpha*cc0;
res[(i+1)*resIncr] += alpha*cc1;
res[(i+2)*resIncr] += alpha*cc2;
res[(i+3)*resIncr] += alpha*cc3;
} }
for(; i<n2; i+=2)
const Index offset1 = (FirstAligned && alignmentStep==1)?3:1;
const Index offset3 = (FirstAligned && alignmentStep==1)?1:3;
Index rowBound = ((rows-skipRows)/rowsAtOnce)*rowsAtOnce + skipRows;
for (Index i=skipRows; i<rowBound; i+=rowsAtOnce)
{ {
// FIXME: what is the purpose of this EIGEN_ALIGN_DEFAULT ?? ResPacket c0 = pset1<ResPacket>(ResScalar(0)),
EIGEN_ALIGN_MAX ResScalar tmp0 = ResScalar(0); c1 = pset1<ResPacket>(ResScalar(0));
ResScalar tmp1 = ResScalar(0), tmp2 = ResScalar(0), tmp3 = ResScalar(0);
// this helps the compiler generating good binary code Index j=0;
const LhsScalars lhs0 = lhs.getVectorMapper(i+0, 0), lhs1 = lhs.getVectorMapper(i+offset1, 0), for(; j+LhsPacketSize<=cols; j+=LhsPacketSize)
lhs2 = lhs.getVectorMapper(i+2, 0), lhs3 = lhs.getVectorMapper(i+offset3, 0);
if (Vectorizable)
{ {
/* explicit vectorization */ RhsPacket b0 = rhs.template load<RhsPacket, Unaligned>(j,0);
ResPacket ptmp0 = pset1<ResPacket>(ResScalar(0)), ptmp1 = pset1<ResPacket>(ResScalar(0)),
ptmp2 = pset1<ResPacket>(ResScalar(0)), ptmp3 = pset1<ResPacket>(ResScalar(0));
// process initial unaligned coeffs c0 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+0,j),b0,c0);
// FIXME this loop get vectorized by the compiler ! c1 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+1,j),b0,c1);
for (Index j=0; j<alignedStart; ++j)
{
RhsScalar b = rhs(j, 0);
tmp0 += cj.pmul(lhs0(j),b); tmp1 += cj.pmul(lhs1(j),b);
tmp2 += cj.pmul(lhs2(j),b); tmp3 += cj.pmul(lhs3(j),b);
}
if (alignedSize>alignedStart)
{
switch(alignmentPattern)
{
case AllAligned:
for (Index j = alignedStart; j<alignedSize; j+=RhsPacketSize)
_EIGEN_ACCUMULATE_PACKETS(Aligned,Aligned,Aligned);
break;
case EvenAligned:
for (Index j = alignedStart; j<alignedSize; j+=RhsPacketSize)
_EIGEN_ACCUMULATE_PACKETS(Aligned,Unaligned,Aligned);
break;
case FirstAligned:
{
Index j = alignedStart;
if (peels>1)
{
/* Here we proccess 4 rows with with two peeled iterations to hide
* the overhead of unaligned loads. Moreover unaligned loads are handled
* using special shift/move operations between the two aligned packets
* overlaping the desired unaligned packet. This is *much* more efficient
* than basic unaligned loads.
*/
LhsPacket A01, A02, A03, A11, A12, A13;
A01 = lhs1.template load<LhsPacket, Aligned>(alignedStart-1);
A02 = lhs2.template load<LhsPacket, Aligned>(alignedStart-2);
A03 = lhs3.template load<LhsPacket, Aligned>(alignedStart-3);
for (; j<peeledSize; j+=peels*RhsPacketSize)
{
RhsPacket b = rhs.getVectorMapper(j, 0).template load<RhsPacket, Aligned>(0);
A11 = lhs1.template load<LhsPacket, Aligned>(j-1+LhsPacketSize); palign<1>(A01,A11);
A12 = lhs2.template load<LhsPacket, Aligned>(j-2+LhsPacketSize); palign<2>(A02,A12);
A13 = lhs3.template load<LhsPacket, Aligned>(j-3+LhsPacketSize); palign<3>(A03,A13);
ptmp0 = pcj.pmadd(lhs0.template load<LhsPacket, Aligned>(j), b, ptmp0);
ptmp1 = pcj.pmadd(A01, b, ptmp1);
A01 = lhs1.template load<LhsPacket, Aligned>(j-1+2*LhsPacketSize); palign<1>(A11,A01);
ptmp2 = pcj.pmadd(A02, b, ptmp2);
A02 = lhs2.template load<LhsPacket, Aligned>(j-2+2*LhsPacketSize); palign<2>(A12,A02);
ptmp3 = pcj.pmadd(A03, b, ptmp3);
A03 = lhs3.template load<LhsPacket, Aligned>(j-3+2*LhsPacketSize); palign<3>(A13,A03);
b = rhs.getVectorMapper(j+RhsPacketSize, 0).template load<RhsPacket, Aligned>(0);
ptmp0 = pcj.pmadd(lhs0.template load<LhsPacket, Aligned>(j+LhsPacketSize), b, ptmp0);
ptmp1 = pcj.pmadd(A11, b, ptmp1);
ptmp2 = pcj.pmadd(A12, b, ptmp2);
ptmp3 = pcj.pmadd(A13, b, ptmp3);
}
}
for (; j<alignedSize; j+=RhsPacketSize)
_EIGEN_ACCUMULATE_PACKETS(Aligned,Unaligned,Unaligned);
break;
}
default:
for (Index j = alignedStart; j<alignedSize; j+=RhsPacketSize)
_EIGEN_ACCUMULATE_PACKETS(Unaligned,Unaligned,Unaligned);
break;
}
tmp0 += predux(ptmp0);
tmp1 += predux(ptmp1);
tmp2 += predux(ptmp2);
tmp3 += predux(ptmp3);
}
} // end explicit vectorization
// process remaining coeffs (or all if no explicit vectorization)
// FIXME this loop get vectorized by the compiler !
for (Index j=alignedSize; j<depth; ++j)
{
RhsScalar b = rhs(j, 0);
tmp0 += cj.pmul(lhs0(j),b); tmp1 += cj.pmul(lhs1(j),b);
tmp2 += cj.pmul(lhs2(j),b); tmp3 += cj.pmul(lhs3(j),b);
} }
res[i*resIncr] += alpha*tmp0; ResScalar cc0 = predux(c0);
res[(i+offset1)*resIncr] += alpha*tmp1; ResScalar cc1 = predux(c1);
res[(i+2)*resIncr] += alpha*tmp2; for(; j<cols; ++j)
res[(i+offset3)*resIncr] += alpha*tmp3; {
RhsScalar b0 = rhs(j,0);
cc0 += cj.pmul(lhs(i+0,j), b0);
cc1 += cj.pmul(lhs(i+1,j), b0);
}
res[(i+0)*resIncr] += alpha*cc0;
res[(i+1)*resIncr] += alpha*cc1;
} }
for(; i<rows; ++i)
// process remaining first and last rows (at most columnsAtOnce-1)
Index end = rows;
Index start = rowBound;
do
{ {
for (Index i=start; i<end; ++i) ResPacket c0 = pset1<ResPacket>(ResScalar(0));
Index j=0;
for(; j+LhsPacketSize<=cols; j+=LhsPacketSize)
{ {
EIGEN_ALIGN_MAX ResScalar tmp0 = ResScalar(0); RhsPacket b0 = rhs.template load<RhsPacket,Unaligned>(j,0);
ResPacket ptmp0 = pset1<ResPacket>(tmp0); c0 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i,j),b0,c0);
const LhsScalars lhs0 = lhs.getVectorMapper(i, 0);
// process first unaligned result's coeffs
// FIXME this loop get vectorized by the compiler !
for (Index j=0; j<alignedStart; ++j)
tmp0 += cj.pmul(lhs0(j), rhs(j, 0));
if (alignedSize>alignedStart)
{
// process aligned rhs coeffs
if (lhs0.template aligned<LhsPacket>(alignedStart))
for (Index j = alignedStart;j<alignedSize;j+=RhsPacketSize)
ptmp0 = pcj.pmadd(lhs0.template load<LhsPacket, Aligned>(j), rhs.getVectorMapper(j, 0).template load<RhsPacket, Aligned>(0), ptmp0);
else
for (Index j = alignedStart;j<alignedSize;j+=RhsPacketSize)
ptmp0 = pcj.pmadd(lhs0.template load<LhsPacket, Unaligned>(j), rhs.getVectorMapper(j, 0).template load<RhsPacket, Aligned>(0), ptmp0);
tmp0 += predux(ptmp0);
}
// process remaining scalars
// FIXME this loop get vectorized by the compiler !
for (Index j=alignedSize; j<depth; ++j)
tmp0 += cj.pmul(lhs0(j), rhs(j, 0));
res[i*resIncr] += alpha*tmp0;
} }
if (skipRows) ResScalar cc0 = predux(c0);
for(; j<cols; ++j)
{ {
start = 0; cc0 += cj.pmul(lhs(i,j), rhs(j,0));
end = skipRows;
skipRows = 0;
} }
else res[i*resIncr] += alpha*cc0;
break; }
} while(Vectorizable); #endif
#undef _EIGEN_ACCUMULATE_PACKETS
} }
} // end namespace internal } // end namespace internal