mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-01-24 14:45:14 +08:00
Improve performance of row-major-dense-matrix * vector products for recent CPUs.
This revised version does not bother about aligned loads/stores, and rather processes 8 rows at ones for better instruction pipelining.
This commit is contained in:
parent
3abc827354
commit
e3f613cbd4
@ -242,253 +242,161 @@ EIGEN_DONT_INLINE static void run(
|
||||
template<typename Index, typename LhsScalar, typename LhsMapper, bool ConjugateLhs, typename RhsScalar, typename RhsMapper, bool ConjugateRhs, int Version>
|
||||
EIGEN_DONT_INLINE void general_matrix_vector_product<Index,LhsScalar,LhsMapper,RowMajor,ConjugateLhs,RhsScalar,RhsMapper,ConjugateRhs,Version>::run(
|
||||
Index rows, Index cols,
|
||||
const LhsMapper& lhs,
|
||||
const LhsMapper& alhs,
|
||||
const RhsMapper& rhs,
|
||||
ResScalar* res, Index resIncr,
|
||||
ResScalar alpha)
|
||||
{
|
||||
// The following copy tells the compiler that lhs's attributes are not modified outside this function
|
||||
// This helps GCC to generate propoer code.
|
||||
LhsMapper lhs(alhs);
|
||||
|
||||
eigen_internal_assert(rhs.stride()==1);
|
||||
|
||||
#ifdef _EIGEN_ACCUMULATE_PACKETS
|
||||
#error _EIGEN_ACCUMULATE_PACKETS has already been defined
|
||||
#endif
|
||||
|
||||
#define _EIGEN_ACCUMULATE_PACKETS(Alignment0,Alignment13,Alignment2) {\
|
||||
RhsPacket b = rhs.getVectorMapper(j, 0).template load<RhsPacket, Aligned>(0); \
|
||||
ptmp0 = pcj.pmadd(lhs0.template load<LhsPacket, Alignment0>(j), b, ptmp0); \
|
||||
ptmp1 = pcj.pmadd(lhs1.template load<LhsPacket, Alignment13>(j), b, ptmp1); \
|
||||
ptmp2 = pcj.pmadd(lhs2.template load<LhsPacket, Alignment2>(j), b, ptmp2); \
|
||||
ptmp3 = pcj.pmadd(lhs3.template load<LhsPacket, Alignment13>(j), b, ptmp3); }
|
||||
|
||||
conj_helper<LhsScalar,RhsScalar,ConjugateLhs,ConjugateRhs> cj;
|
||||
conj_helper<LhsPacket,RhsPacket,ConjugateLhs,ConjugateRhs> pcj;
|
||||
|
||||
typedef typename LhsMapper::VectorMapper LhsScalars;
|
||||
// TODO: fine tune the following heuristic. The rationale is that if the matrix is very large,
|
||||
// processing 8 rows at once might be counter productive wrt cache.
|
||||
const Index n8 = lhs.stride()*sizeof(LhsScalar)>32000 ? 0 : rows-7;
|
||||
const Index n4 = rows-3;
|
||||
const Index n2 = rows-1;
|
||||
|
||||
enum { AllAligned=0, EvenAligned=1, FirstAligned=2, NoneAligned=3 };
|
||||
const Index rowsAtOnce = 4;
|
||||
const Index peels = 2;
|
||||
const Index RhsPacketAlignedMask = RhsPacketSize-1;
|
||||
const Index LhsPacketAlignedMask = LhsPacketSize-1;
|
||||
const Index depth = cols;
|
||||
const Index lhsStride = lhs.stride();
|
||||
// TODO: for padded aligned inputs, we could enable aligned reads
|
||||
enum { LhsAlignment = Unaligned };
|
||||
|
||||
// How many coeffs of the result do we have to skip to be aligned.
|
||||
// Here we assume data are at least aligned on the base scalar type
|
||||
// if that's not the case then vectorization is discarded, see below.
|
||||
Index alignedStart = rhs.firstAligned(depth);
|
||||
Index alignedSize = RhsPacketSize>1 ? alignedStart + ((depth-alignedStart) & ~RhsPacketAlignedMask) : 0;
|
||||
const Index peeledSize = alignedSize - RhsPacketSize*peels - RhsPacketSize + 1;
|
||||
|
||||
const Index alignmentStep = LhsPacketSize>1 ? (LhsPacketSize - lhsStride % LhsPacketSize) & LhsPacketAlignedMask : 0;
|
||||
Index alignmentPattern = alignmentStep==0 ? AllAligned
|
||||
: alignmentStep==(LhsPacketSize/2) ? EvenAligned
|
||||
: FirstAligned;
|
||||
|
||||
// we cannot assume the first element is aligned because of sub-matrices
|
||||
const Index lhsAlignmentOffset = lhs.firstAligned(depth);
|
||||
const Index rhsAlignmentOffset = rhs.firstAligned(rows);
|
||||
|
||||
// find how many rows do we have to skip to be aligned with rhs (if possible)
|
||||
Index skipRows = 0;
|
||||
// if the data cannot be aligned (TODO add some compile time tests when possible, e.g. for floats)
|
||||
if( (sizeof(LhsScalar)!=sizeof(RhsScalar)) ||
|
||||
(lhsAlignmentOffset < 0) || (lhsAlignmentOffset == depth) ||
|
||||
(rhsAlignmentOffset < 0) || (rhsAlignmentOffset == rows) )
|
||||
Index i=0;
|
||||
for(; i<n8; i+=8)
|
||||
{
|
||||
alignedSize = 0;
|
||||
alignedStart = 0;
|
||||
alignmentPattern = NoneAligned;
|
||||
}
|
||||
else if(LhsPacketSize > 4)
|
||||
{
|
||||
// TODO: extend the code to support aligned loads whenever possible when LhsPacketSize > 4.
|
||||
alignmentPattern = NoneAligned;
|
||||
}
|
||||
else if (LhsPacketSize>1)
|
||||
{
|
||||
// eigen_internal_assert(size_t(firstLhs+lhsAlignmentOffset)%sizeof(LhsPacket)==0 || depth<LhsPacketSize);
|
||||
ResPacket c0 = pset1<ResPacket>(ResScalar(0)),
|
||||
c1 = pset1<ResPacket>(ResScalar(0)),
|
||||
c2 = pset1<ResPacket>(ResScalar(0)),
|
||||
c3 = pset1<ResPacket>(ResScalar(0)),
|
||||
c4 = pset1<ResPacket>(ResScalar(0)),
|
||||
c5 = pset1<ResPacket>(ResScalar(0)),
|
||||
c6 = pset1<ResPacket>(ResScalar(0)),
|
||||
c7 = pset1<ResPacket>(ResScalar(0));
|
||||
|
||||
while (skipRows<LhsPacketSize &&
|
||||
alignedStart != ((lhsAlignmentOffset + alignmentStep*skipRows)%LhsPacketSize))
|
||||
++skipRows;
|
||||
if (skipRows==LhsPacketSize)
|
||||
Index j=0;
|
||||
for(; j+LhsPacketSize<=cols; j+=LhsPacketSize)
|
||||
{
|
||||
// nothing can be aligned, no need to skip any column
|
||||
alignmentPattern = NoneAligned;
|
||||
skipRows = 0;
|
||||
RhsPacket b0 = rhs.template load<RhsPacket, Unaligned>(j,0);
|
||||
|
||||
c0 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+0,j),b0,c0);
|
||||
c1 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+1,j),b0,c1);
|
||||
c2 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+2,j),b0,c2);
|
||||
c3 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+3,j),b0,c3);
|
||||
c4 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+4,j),b0,c4);
|
||||
c5 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+5,j),b0,c5);
|
||||
c6 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+6,j),b0,c6);
|
||||
c7 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+7,j),b0,c7);
|
||||
}
|
||||
else
|
||||
ResScalar cc0 = predux(c0);
|
||||
ResScalar cc1 = predux(c1);
|
||||
ResScalar cc2 = predux(c2);
|
||||
ResScalar cc3 = predux(c3);
|
||||
ResScalar cc4 = predux(c4);
|
||||
ResScalar cc5 = predux(c5);
|
||||
ResScalar cc6 = predux(c6);
|
||||
ResScalar cc7 = predux(c7);
|
||||
for(; j<cols; ++j)
|
||||
{
|
||||
skipRows = (std::min)(skipRows,Index(rows));
|
||||
// note that the skiped columns are processed later.
|
||||
RhsScalar b0 = rhs(j,0);
|
||||
|
||||
cc0 += cj.pmul(lhs(i+0,j), b0);
|
||||
cc1 += cj.pmul(lhs(i+1,j), b0);
|
||||
cc2 += cj.pmul(lhs(i+2,j), b0);
|
||||
cc3 += cj.pmul(lhs(i+3,j), b0);
|
||||
cc4 += cj.pmul(lhs(i+4,j), b0);
|
||||
cc5 += cj.pmul(lhs(i+5,j), b0);
|
||||
cc6 += cj.pmul(lhs(i+6,j), b0);
|
||||
cc7 += cj.pmul(lhs(i+7,j), b0);
|
||||
}
|
||||
/* eigen_internal_assert( alignmentPattern==NoneAligned
|
||||
|| LhsPacketSize==1
|
||||
|| (skipRows + rowsAtOnce >= rows)
|
||||
|| LhsPacketSize > depth
|
||||
|| (size_t(firstLhs+alignedStart+lhsStride*skipRows)%sizeof(LhsPacket))==0);*/
|
||||
res[(i+0)*resIncr] += alpha*cc0;
|
||||
res[(i+1)*resIncr] += alpha*cc1;
|
||||
res[(i+2)*resIncr] += alpha*cc2;
|
||||
res[(i+3)*resIncr] += alpha*cc3;
|
||||
res[(i+4)*resIncr] += alpha*cc4;
|
||||
res[(i+5)*resIncr] += alpha*cc5;
|
||||
res[(i+6)*resIncr] += alpha*cc6;
|
||||
res[(i+7)*resIncr] += alpha*cc7;
|
||||
}
|
||||
else if(Vectorizable)
|
||||
for(; i<n4; i+=4)
|
||||
{
|
||||
alignedStart = 0;
|
||||
alignedSize = depth;
|
||||
alignmentPattern = AllAligned;
|
||||
ResPacket c0 = pset1<ResPacket>(ResScalar(0)),
|
||||
c1 = pset1<ResPacket>(ResScalar(0)),
|
||||
c2 = pset1<ResPacket>(ResScalar(0)),
|
||||
c3 = pset1<ResPacket>(ResScalar(0));
|
||||
|
||||
Index j=0;
|
||||
for(; j+LhsPacketSize<=cols; j+=LhsPacketSize)
|
||||
{
|
||||
RhsPacket b0 = rhs.template load<RhsPacket, Unaligned>(j,0);
|
||||
|
||||
c0 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+0,j),b0,c0);
|
||||
c1 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+1,j),b0,c1);
|
||||
c2 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+2,j),b0,c2);
|
||||
c3 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+3,j),b0,c3);
|
||||
}
|
||||
ResScalar cc0 = predux(c0);
|
||||
ResScalar cc1 = predux(c1);
|
||||
ResScalar cc2 = predux(c2);
|
||||
ResScalar cc3 = predux(c3);
|
||||
for(; j<cols; ++j)
|
||||
{
|
||||
RhsScalar b0 = rhs(j,0);
|
||||
|
||||
cc0 += cj.pmul(lhs(i+0,j), b0);
|
||||
cc1 += cj.pmul(lhs(i+1,j), b0);
|
||||
cc2 += cj.pmul(lhs(i+2,j), b0);
|
||||
cc3 += cj.pmul(lhs(i+3,j), b0);
|
||||
}
|
||||
res[(i+0)*resIncr] += alpha*cc0;
|
||||
res[(i+1)*resIncr] += alpha*cc1;
|
||||
res[(i+2)*resIncr] += alpha*cc2;
|
||||
res[(i+3)*resIncr] += alpha*cc3;
|
||||
}
|
||||
|
||||
const Index offset1 = (FirstAligned && alignmentStep==1)?3:1;
|
||||
const Index offset3 = (FirstAligned && alignmentStep==1)?1:3;
|
||||
|
||||
Index rowBound = ((rows-skipRows)/rowsAtOnce)*rowsAtOnce + skipRows;
|
||||
for (Index i=skipRows; i<rowBound; i+=rowsAtOnce)
|
||||
for(; i<n2; i+=2)
|
||||
{
|
||||
// FIXME: what is the purpose of this EIGEN_ALIGN_DEFAULT ??
|
||||
EIGEN_ALIGN_MAX ResScalar tmp0 = ResScalar(0);
|
||||
ResScalar tmp1 = ResScalar(0), tmp2 = ResScalar(0), tmp3 = ResScalar(0);
|
||||
ResPacket c0 = pset1<ResPacket>(ResScalar(0)),
|
||||
c1 = pset1<ResPacket>(ResScalar(0));
|
||||
|
||||
// this helps the compiler generating good binary code
|
||||
const LhsScalars lhs0 = lhs.getVectorMapper(i+0, 0), lhs1 = lhs.getVectorMapper(i+offset1, 0),
|
||||
lhs2 = lhs.getVectorMapper(i+2, 0), lhs3 = lhs.getVectorMapper(i+offset3, 0);
|
||||
|
||||
if (Vectorizable)
|
||||
Index j=0;
|
||||
for(; j+LhsPacketSize<=cols; j+=LhsPacketSize)
|
||||
{
|
||||
/* explicit vectorization */
|
||||
ResPacket ptmp0 = pset1<ResPacket>(ResScalar(0)), ptmp1 = pset1<ResPacket>(ResScalar(0)),
|
||||
ptmp2 = pset1<ResPacket>(ResScalar(0)), ptmp3 = pset1<ResPacket>(ResScalar(0));
|
||||
RhsPacket b0 = rhs.template load<RhsPacket, Unaligned>(j,0);
|
||||
|
||||
// process initial unaligned coeffs
|
||||
// FIXME this loop get vectorized by the compiler !
|
||||
for (Index j=0; j<alignedStart; ++j)
|
||||
{
|
||||
RhsScalar b = rhs(j, 0);
|
||||
tmp0 += cj.pmul(lhs0(j),b); tmp1 += cj.pmul(lhs1(j),b);
|
||||
tmp2 += cj.pmul(lhs2(j),b); tmp3 += cj.pmul(lhs3(j),b);
|
||||
}
|
||||
|
||||
if (alignedSize>alignedStart)
|
||||
{
|
||||
switch(alignmentPattern)
|
||||
{
|
||||
case AllAligned:
|
||||
for (Index j = alignedStart; j<alignedSize; j+=RhsPacketSize)
|
||||
_EIGEN_ACCUMULATE_PACKETS(Aligned,Aligned,Aligned);
|
||||
break;
|
||||
case EvenAligned:
|
||||
for (Index j = alignedStart; j<alignedSize; j+=RhsPacketSize)
|
||||
_EIGEN_ACCUMULATE_PACKETS(Aligned,Unaligned,Aligned);
|
||||
break;
|
||||
case FirstAligned:
|
||||
{
|
||||
Index j = alignedStart;
|
||||
if (peels>1)
|
||||
{
|
||||
/* Here we proccess 4 rows with with two peeled iterations to hide
|
||||
* the overhead of unaligned loads. Moreover unaligned loads are handled
|
||||
* using special shift/move operations between the two aligned packets
|
||||
* overlaping the desired unaligned packet. This is *much* more efficient
|
||||
* than basic unaligned loads.
|
||||
*/
|
||||
LhsPacket A01, A02, A03, A11, A12, A13;
|
||||
A01 = lhs1.template load<LhsPacket, Aligned>(alignedStart-1);
|
||||
A02 = lhs2.template load<LhsPacket, Aligned>(alignedStart-2);
|
||||
A03 = lhs3.template load<LhsPacket, Aligned>(alignedStart-3);
|
||||
|
||||
for (; j<peeledSize; j+=peels*RhsPacketSize)
|
||||
{
|
||||
RhsPacket b = rhs.getVectorMapper(j, 0).template load<RhsPacket, Aligned>(0);
|
||||
A11 = lhs1.template load<LhsPacket, Aligned>(j-1+LhsPacketSize); palign<1>(A01,A11);
|
||||
A12 = lhs2.template load<LhsPacket, Aligned>(j-2+LhsPacketSize); palign<2>(A02,A12);
|
||||
A13 = lhs3.template load<LhsPacket, Aligned>(j-3+LhsPacketSize); palign<3>(A03,A13);
|
||||
|
||||
ptmp0 = pcj.pmadd(lhs0.template load<LhsPacket, Aligned>(j), b, ptmp0);
|
||||
ptmp1 = pcj.pmadd(A01, b, ptmp1);
|
||||
A01 = lhs1.template load<LhsPacket, Aligned>(j-1+2*LhsPacketSize); palign<1>(A11,A01);
|
||||
ptmp2 = pcj.pmadd(A02, b, ptmp2);
|
||||
A02 = lhs2.template load<LhsPacket, Aligned>(j-2+2*LhsPacketSize); palign<2>(A12,A02);
|
||||
ptmp3 = pcj.pmadd(A03, b, ptmp3);
|
||||
A03 = lhs3.template load<LhsPacket, Aligned>(j-3+2*LhsPacketSize); palign<3>(A13,A03);
|
||||
|
||||
b = rhs.getVectorMapper(j+RhsPacketSize, 0).template load<RhsPacket, Aligned>(0);
|
||||
ptmp0 = pcj.pmadd(lhs0.template load<LhsPacket, Aligned>(j+LhsPacketSize), b, ptmp0);
|
||||
ptmp1 = pcj.pmadd(A11, b, ptmp1);
|
||||
ptmp2 = pcj.pmadd(A12, b, ptmp2);
|
||||
ptmp3 = pcj.pmadd(A13, b, ptmp3);
|
||||
}
|
||||
}
|
||||
for (; j<alignedSize; j+=RhsPacketSize)
|
||||
_EIGEN_ACCUMULATE_PACKETS(Aligned,Unaligned,Unaligned);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
for (Index j = alignedStart; j<alignedSize; j+=RhsPacketSize)
|
||||
_EIGEN_ACCUMULATE_PACKETS(Unaligned,Unaligned,Unaligned);
|
||||
break;
|
||||
}
|
||||
tmp0 += predux(ptmp0);
|
||||
tmp1 += predux(ptmp1);
|
||||
tmp2 += predux(ptmp2);
|
||||
tmp3 += predux(ptmp3);
|
||||
}
|
||||
} // end explicit vectorization
|
||||
|
||||
// process remaining coeffs (or all if no explicit vectorization)
|
||||
// FIXME this loop get vectorized by the compiler !
|
||||
for (Index j=alignedSize; j<depth; ++j)
|
||||
{
|
||||
RhsScalar b = rhs(j, 0);
|
||||
tmp0 += cj.pmul(lhs0(j),b); tmp1 += cj.pmul(lhs1(j),b);
|
||||
tmp2 += cj.pmul(lhs2(j),b); tmp3 += cj.pmul(lhs3(j),b);
|
||||
c0 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+0,j),b0,c0);
|
||||
c1 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+1,j),b0,c1);
|
||||
}
|
||||
res[i*resIncr] += alpha*tmp0;
|
||||
res[(i+offset1)*resIncr] += alpha*tmp1;
|
||||
res[(i+2)*resIncr] += alpha*tmp2;
|
||||
res[(i+offset3)*resIncr] += alpha*tmp3;
|
||||
ResScalar cc0 = predux(c0);
|
||||
ResScalar cc1 = predux(c1);
|
||||
for(; j<cols; ++j)
|
||||
{
|
||||
RhsScalar b0 = rhs(j,0);
|
||||
|
||||
cc0 += cj.pmul(lhs(i+0,j), b0);
|
||||
cc1 += cj.pmul(lhs(i+1,j), b0);
|
||||
}
|
||||
res[(i+0)*resIncr] += alpha*cc0;
|
||||
res[(i+1)*resIncr] += alpha*cc1;
|
||||
}
|
||||
|
||||
// process remaining first and last rows (at most columnsAtOnce-1)
|
||||
Index end = rows;
|
||||
Index start = rowBound;
|
||||
do
|
||||
for(; i<rows; ++i)
|
||||
{
|
||||
for (Index i=start; i<end; ++i)
|
||||
ResPacket c0 = pset1<ResPacket>(ResScalar(0));
|
||||
Index j=0;
|
||||
for(; j+LhsPacketSize<=cols; j+=LhsPacketSize)
|
||||
{
|
||||
EIGEN_ALIGN_MAX ResScalar tmp0 = ResScalar(0);
|
||||
ResPacket ptmp0 = pset1<ResPacket>(tmp0);
|
||||
const LhsScalars lhs0 = lhs.getVectorMapper(i, 0);
|
||||
// process first unaligned result's coeffs
|
||||
// FIXME this loop get vectorized by the compiler !
|
||||
for (Index j=0; j<alignedStart; ++j)
|
||||
tmp0 += cj.pmul(lhs0(j), rhs(j, 0));
|
||||
|
||||
if (alignedSize>alignedStart)
|
||||
{
|
||||
// process aligned rhs coeffs
|
||||
if (lhs0.template aligned<LhsPacket>(alignedStart))
|
||||
for (Index j = alignedStart;j<alignedSize;j+=RhsPacketSize)
|
||||
ptmp0 = pcj.pmadd(lhs0.template load<LhsPacket, Aligned>(j), rhs.getVectorMapper(j, 0).template load<RhsPacket, Aligned>(0), ptmp0);
|
||||
else
|
||||
for (Index j = alignedStart;j<alignedSize;j+=RhsPacketSize)
|
||||
ptmp0 = pcj.pmadd(lhs0.template load<LhsPacket, Unaligned>(j), rhs.getVectorMapper(j, 0).template load<RhsPacket, Aligned>(0), ptmp0);
|
||||
tmp0 += predux(ptmp0);
|
||||
}
|
||||
|
||||
// process remaining scalars
|
||||
// FIXME this loop get vectorized by the compiler !
|
||||
for (Index j=alignedSize; j<depth; ++j)
|
||||
tmp0 += cj.pmul(lhs0(j), rhs(j, 0));
|
||||
res[i*resIncr] += alpha*tmp0;
|
||||
RhsPacket b0 = rhs.template load<RhsPacket,Unaligned>(j,0);
|
||||
c0 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i,j),b0,c0);
|
||||
}
|
||||
if (skipRows)
|
||||
ResScalar cc0 = predux(c0);
|
||||
for(; j<cols; ++j)
|
||||
{
|
||||
start = 0;
|
||||
end = skipRows;
|
||||
skipRows = 0;
|
||||
cc0 += cj.pmul(lhs(i,j), rhs(j,0));
|
||||
}
|
||||
else
|
||||
break;
|
||||
} while(Vectorizable);
|
||||
|
||||
#undef _EIGEN_ACCUMULATE_PACKETS
|
||||
res[i*resIncr] += alpha*cc0;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
} // end namespace internal
|
||||
|
Loading…
Reference in New Issue
Block a user