Improve performance of row-major-dense-matrix * vector products for recent CPUs.

This revised version does not bother with aligned loads/stores,
and instead processes 8 rows at once for better instruction pipelining.
Gael Guennebaud 2016-12-05 13:02:01 +01:00
parent 3abc827354
commit e3f613cbd4
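The gist of the change, as a minimal scalar sketch (toy code for illustration only, not the Eigen kernel below, which works on SIMD packets through the lhs/rhs mappers): keeping several independent per-row accumulators in flight gives the CPU several independent multiply-add chains to overlap, instead of one long serial chain per row.

    #include <cstddef>

    // Toy row-major gemv: res += alpha * mat * vec, processing 4 rows per pass.
    // acc0..acc3 are independent dependency chains, so consecutive
    // multiply-adds can overlap in the pipeline.
    void gemv_rowmajor_blocked(const float* mat, std::size_t stride,
                               std::size_t rows, std::size_t cols,
                               const float* vec, float alpha, float* res)
    {
      std::size_t i = 0;
      for(; i + 4 <= rows; i += 4)
      {
        float acc0 = 0, acc1 = 0, acc2 = 0, acc3 = 0;
        for(std::size_t j = 0; j < cols; ++j)
        {
          float b = vec[j];
          acc0 += mat[(i + 0) * stride + j] * b;
          acc1 += mat[(i + 1) * stride + j] * b;
          acc2 += mat[(i + 2) * stride + j] * b;
          acc3 += mat[(i + 3) * stride + j] * b;
        }
        res[i + 0] += alpha * acc0;
        res[i + 1] += alpha * acc1;
        res[i + 2] += alpha * acc2;
        res[i + 3] += alpha * acc3;
      }
      for(; i < rows; ++i)  // leftover rows, one at a time
      {
        float acc = 0;
        for(std::size_t j = 0; j < cols; ++j)
          acc += mat[i * stride + j] * vec[j];
        res[i] += alpha * acc;
      }
    }

The kernel in this commit applies the same idea with 8, then 4, 2 and 1 rows per pass, using unaligned packet loads and a final predux (horizontal sum) per row.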


@@ -242,253 +242,161 @@ EIGEN_DONT_INLINE static void run(
template<typename Index, typename LhsScalar, typename LhsMapper, bool ConjugateLhs, typename RhsScalar, typename RhsMapper, bool ConjugateRhs, int Version>
EIGEN_DONT_INLINE void general_matrix_vector_product<Index,LhsScalar,LhsMapper,RowMajor,ConjugateLhs,RhsScalar,RhsMapper,ConjugateRhs,Version>::run(
Index rows, Index cols,
const LhsMapper& alhs,
const RhsMapper& rhs,
ResScalar* res, Index resIncr,
ResScalar alpha)
{
// The following copy tells the compiler that lhs's attributes are not modified outside this function.
// This helps GCC to generate proper code.
LhsMapper lhs(alhs);
eigen_internal_assert(rhs.stride()==1);
#ifdef _EIGEN_ACCUMULATE_PACKETS
#error _EIGEN_ACCUMULATE_PACKETS has already been defined
#endif
#define _EIGEN_ACCUMULATE_PACKETS(Alignment0,Alignment13,Alignment2) {\
RhsPacket b = rhs.getVectorMapper(j, 0).template load<RhsPacket, Aligned>(0); \
ptmp0 = pcj.pmadd(lhs0.template load<LhsPacket, Alignment0>(j), b, ptmp0); \
ptmp1 = pcj.pmadd(lhs1.template load<LhsPacket, Alignment13>(j), b, ptmp1); \
ptmp2 = pcj.pmadd(lhs2.template load<LhsPacket, Alignment2>(j), b, ptmp2); \
ptmp3 = pcj.pmadd(lhs3.template load<LhsPacket, Alignment13>(j), b, ptmp3); }
conj_helper<LhsScalar,RhsScalar,ConjugateLhs,ConjugateRhs> cj;
conj_helper<LhsPacket,RhsPacket,ConjugateLhs,ConjugateRhs> pcj;
typedef typename LhsMapper::VectorMapper LhsScalars;
// TODO: fine-tune the following heuristic. The rationale is that if the matrix is very large,
// processing 8 rows at once might be counter-productive wrt the cache.
const Index n8 = lhs.stride()*sizeof(LhsScalar)>32000 ? 0 : rows-7;
const Index n4 = rows-3;
const Index n2 = rows-1;
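// For instance, with float entries and a stride of 4000 columns a row occupies
// 4000*4 = 16000 bytes <= 32000, so the 8-rows-at-once path is taken; with a
// stride of 10000 columns (40000 bytes per row) n8 is 0 and the kernel starts
// directly with the 4-row blocks below.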
enum { AllAligned=0, EvenAligned=1, FirstAligned=2, NoneAligned=3 };
const Index rowsAtOnce = 4;
const Index peels = 2;
const Index RhsPacketAlignedMask = RhsPacketSize-1;
const Index LhsPacketAlignedMask = LhsPacketSize-1;
const Index depth = cols;
const Index lhsStride = lhs.stride();
// TODO: for padded aligned inputs, we could enable aligned reads
enum { LhsAlignment = Unaligned };
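// All lhs packet loads below are issued as unaligned loads; no attempt is made
// to reach an aligned address first (see the commit message above).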
// How many coeffs of the result do we have to skip to be aligned.
// Here we assume data are at least aligned on the base scalar type
// if that's not the case then vectorization is discarded, see below.
Index alignedStart = rhs.firstAligned(depth);
Index alignedSize = RhsPacketSize>1 ? alignedStart + ((depth-alignedStart) & ~RhsPacketAlignedMask) : 0;
const Index peeledSize = alignedSize - RhsPacketSize*peels - RhsPacketSize + 1;
const Index alignmentStep = LhsPacketSize>1 ? (LhsPacketSize - lhsStride % LhsPacketSize) & LhsPacketAlignedMask : 0;
Index alignmentPattern = alignmentStep==0 ? AllAligned
: alignmentStep==(LhsPacketSize/2) ? EvenAligned
: FirstAligned;
// we cannot assume the first element is aligned because of sub-matrices
const Index lhsAlignmentOffset = lhs.firstAligned(depth);
const Index rhsAlignmentOffset = rhs.firstAligned(rows);
// find how many rows do we have to skip to be aligned with rhs (if possible)
Index skipRows = 0;
Index i=0;
for(; i<n8; i+=8)
{
ResPacket c0 = pset1<ResPacket>(ResScalar(0)),
          c1 = pset1<ResPacket>(ResScalar(0)),
          c2 = pset1<ResPacket>(ResScalar(0)),
          c3 = pset1<ResPacket>(ResScalar(0)),
          c4 = pset1<ResPacket>(ResScalar(0)),
          c5 = pset1<ResPacket>(ResScalar(0)),
          c6 = pset1<ResPacket>(ResScalar(0)),
          c7 = pset1<ResPacket>(ResScalar(0));
Index j=0;
for(; j+LhsPacketSize<=cols; j+=LhsPacketSize)
{
RhsPacket b0 = rhs.template load<RhsPacket, Unaligned>(j,0);
c0 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+0,j),b0,c0);
c1 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+1,j),b0,c1);
c2 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+2,j),b0,c2);
c3 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+3,j),b0,c3);
c4 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+4,j),b0,c4);
c5 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+5,j),b0,c5);
c6 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+6,j),b0,c6);
c7 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+7,j),b0,c7);
}
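// Each ck holds per-lane partial sums for row i+k; predux() reduces them
// horizontally to a single scalar before the scalar tail loop over j.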
ResScalar cc0 = predux(c0);
ResScalar cc1 = predux(c1);
ResScalar cc2 = predux(c2);
ResScalar cc3 = predux(c3);
ResScalar cc4 = predux(c4);
ResScalar cc5 = predux(c5);
ResScalar cc6 = predux(c6);
ResScalar cc7 = predux(c7);
for(; j<cols; ++j)
{
RhsScalar b0 = rhs(j,0);
cc0 += cj.pmul(lhs(i+0,j), b0);
cc1 += cj.pmul(lhs(i+1,j), b0);
cc2 += cj.pmul(lhs(i+2,j), b0);
cc3 += cj.pmul(lhs(i+3,j), b0);
cc4 += cj.pmul(lhs(i+4,j), b0);
cc5 += cj.pmul(lhs(i+5,j), b0);
cc6 += cj.pmul(lhs(i+6,j), b0);
cc7 += cj.pmul(lhs(i+7,j), b0);
}
res[(i+0)*resIncr] += alpha*cc0;
res[(i+1)*resIncr] += alpha*cc1;
res[(i+2)*resIncr] += alpha*cc2;
res[(i+3)*resIncr] += alpha*cc3;
res[(i+4)*resIncr] += alpha*cc4;
res[(i+5)*resIncr] += alpha*cc5;
res[(i+6)*resIncr] += alpha*cc6;
res[(i+7)*resIncr] += alpha*cc7;
}
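// The remaining rows are handled with the same accumulate-then-reduce pattern,
// using blocks of 4, then 2, then a single row per pass.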
for(; i<n4; i+=4)
{
ResPacket c0 = pset1<ResPacket>(ResScalar(0)),
c1 = pset1<ResPacket>(ResScalar(0)),
c2 = pset1<ResPacket>(ResScalar(0)),
c3 = pset1<ResPacket>(ResScalar(0));
Index j=0;
for(; j+LhsPacketSize<=cols; j+=LhsPacketSize)
{
RhsPacket b0 = rhs.template load<RhsPacket, Unaligned>(j,0);
c0 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+0,j),b0,c0);
c1 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+1,j),b0,c1);
c2 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+2,j),b0,c2);
c3 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+3,j),b0,c3);
}
ResScalar cc0 = predux(c0);
ResScalar cc1 = predux(c1);
ResScalar cc2 = predux(c2);
ResScalar cc3 = predux(c3);
for(; j<cols; ++j)
{
RhsScalar b0 = rhs(j,0);
cc0 += cj.pmul(lhs(i+0,j), b0);
cc1 += cj.pmul(lhs(i+1,j), b0);
cc2 += cj.pmul(lhs(i+2,j), b0);
cc3 += cj.pmul(lhs(i+3,j), b0);
}
res[(i+0)*resIncr] += alpha*cc0;
res[(i+1)*resIncr] += alpha*cc1;
res[(i+2)*resIncr] += alpha*cc2;
res[(i+3)*resIncr] += alpha*cc3;
}
for(; i<n2; i+=2)
{
ResPacket c0 = pset1<ResPacket>(ResScalar(0)),
          c1 = pset1<ResPacket>(ResScalar(0));

Index j=0;
for(; j+LhsPacketSize<=cols; j+=LhsPacketSize)
{
RhsPacket b0 = rhs.template load<RhsPacket, Unaligned>(j,0);
c0 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+0,j),b0,c0);
c1 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+1,j),b0,c1);
}
ResScalar cc0 = predux(c0);
ResScalar cc1 = predux(c1);
for(; j<cols; ++j)
{
RhsScalar b0 = rhs(j,0);
cc0 += cj.pmul(lhs(i+0,j), b0);
cc1 += cj.pmul(lhs(i+1,j), b0);
}
res[(i+0)*resIncr] += alpha*cc0;
res[(i+1)*resIncr] += alpha*cc1;
}
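// Last rows, one at a time, with a single packet accumulator per row.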
for(; i<rows; ++i)
{
ResPacket c0 = pset1<ResPacket>(ResScalar(0));
Index j=0;
for(; j+LhsPacketSize<=cols; j+=LhsPacketSize)
{
RhsPacket b0 = rhs.template load<RhsPacket,Unaligned>(j,0);
c0 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i,j),b0,c0);
}
ResScalar cc0 = predux(c0);
for(; j<cols; ++j)
{
cc0 += cj.pmul(lhs(i,j), rhs(j,0));
}
res[i*resIncr] += alpha*cc0;
}
#undef _EIGEN_ACCUMULATE_PACKETS
}
} // end namespace internal
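For context, a hedged user-level example of the kind of expression that is expected to end up in this row-major gemv kernel (the sizes and names here are illustrative):

    #include <Eigen/Dense>

    int main()
    {
      using RowMajorMatrixXf =
          Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;

      RowMajorMatrixXf A = RowMajorMatrixXf::Random(1024, 512);
      Eigen::VectorXf x = Eigen::VectorXf::Random(512);
      Eigen::VectorXf y = Eigen::VectorXf::Zero(1024);

      // Row-major dense matrix * vector product, dispatched to
      // general_matrix_vector_product<..., RowMajor, ...>::run.
      y.noalias() += 2.0f * A * x;
      return 0;
    }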