Vectorized the loop peeling of the inner loop of the block-panel matrix multiplication code. This speeds up the multiplication of matrices which size is not a multiple of the packet size.

This commit is contained in:
Benoit Steiner 2014-03-28 12:11:23 -07:00
parent 39bfbd43f0
commit ad59ade116

View File

@ -206,6 +206,11 @@ public:
dest = pload<LhsPacket>(a);
}
EIGEN_STRONG_INLINE void loadLhsUnaligned(const LhsScalar* a, LhsPacket& dest) const
{
dest = ploadu<LhsPacket>(a);
}
EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacket& b, AccPacket& c, AccPacket& tmp) const
{
// It would be a lot cleaner to call pmadd all the time. Unfortunately if we
@ -278,7 +283,12 @@ public:
{
dest = pload<LhsPacket>(a);
}
EIGEN_STRONG_INLINE void loadLhsUnaligned(const LhsScalar* a, LhsPacket& dest) const
{
dest = ploadu<LhsPacket>(a);
}
EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1, RhsPacket& b2, RhsPacket& b3)
{
pbroadcast4(b, b0, b1, b2, b3);
@ -334,7 +344,9 @@ public:
&& packet_traits<Scalar>::Vectorizable,
RealPacketSize = Vectorizable ? packet_traits<RealScalar>::size : 1,
ResPacketSize = Vectorizable ? packet_traits<ResScalar>::size : 1,
LhsPacketSize = Vectorizable ? packet_traits<LhsScalar>::size : 1,
RhsPacketSize = Vectorizable ? packet_traits<RhsScalar>::size : 1,
// FIXME: should depend on NumberOfRegisters
nr = 4,
mr = ResPacketSize,
@ -402,6 +414,11 @@ public:
dest = pload<LhsPacket>((const typename unpacket_traits<LhsPacket>::type*)(a));
}
EIGEN_STRONG_INLINE void loadLhsUnaligned(const LhsScalar* a, LhsPacket& dest) const
{
dest = ploadu<LhsPacket>((const typename unpacket_traits<LhsPacket>::type*)(a));
}
EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacket& b, DoublePacket& c, RhsPacket& /*tmp*/) const
{
c.first = padd(pmul(a,b.first), c.first);
@ -509,6 +526,11 @@ public:
dest = ploaddup<LhsPacket>(a);
}
EIGEN_STRONG_INLINE void loadLhsUnaligned(const LhsScalar* a, LhsPacket& dest) const
{
dest = ploaddup<LhsPacket>(a);
}
EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacket& b, AccPacket& c, RhsPacket& tmp) const
{
madd_impl(a, b, c, tmp, typename conditional<Vectorizable,true_type,false_type>::type());
@ -706,49 +728,84 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,mr,nr,ConjugateLhs,ConjugateRhs>
const LhsScalar* blA = &blockA[i*strideA+offsetA];
prefetch(&blA[0]);
// gets a 1 x 8 res block as registers
ResScalar C0(0), C1(0), C2(0), C3(0), C4(0), C5(0), C6(0), C7(0);
// FIXME directly use blockB ???
const RhsScalar* blB = &blockB[j2*strideB+offsetB*8];
// TODO peel this loop
for(Index k=0; k<depth; k++)
{
LhsScalar A0;
RhsScalar B_0, B_1;
A0 = blA[k];
B_0 = blB[0];
B_1 = blB[1];
MADD(cj,A0,B_0,C0, B_0);
MADD(cj,A0,B_1,C1, B_1);
B_0 = blB[2];
B_1 = blB[3];
MADD(cj,A0,B_0,C2, B_0);
MADD(cj,A0,B_1,C3, B_1);
B_0 = blB[4];
B_1 = blB[5];
MADD(cj,A0,B_0,C4, B_0);
MADD(cj,A0,B_1,C5, B_1);
B_0 = blB[6];
B_1 = blB[7];
MADD(cj,A0,B_0,C6, B_0);
MADD(cj,A0,B_1,C7, B_1);
if(nr == Traits::RhsPacketSize)
{
EIGEN_ASM_COMMENT("begin_vectorized_multiplication_of_last_rows");
blB += 8;
}
res[(j2+0)*resStride + i] += alpha*C0;
res[(j2+1)*resStride + i] += alpha*C1;
res[(j2+2)*resStride + i] += alpha*C2;
res[(j2+3)*resStride + i] += alpha*C3;
res[(j2+4)*resStride + i] += alpha*C4;
res[(j2+5)*resStride + i] += alpha*C5;
res[(j2+6)*resStride + i] += alpha*C6;
res[(j2+7)*resStride + i] += alpha*C7;
}
typedef gebp_traits<RhsScalar,LhsScalar,ConjugateRhs,ConjugateLhs> SwappedTraits;
typedef typename SwappedTraits::ResScalar SResScalar;
typedef typename SwappedTraits::LhsPacket SLhsPacket;
typedef typename SwappedTraits::RhsPacket SRhsPacket;
typedef typename SwappedTraits::ResPacket SResPacket;
typedef typename SwappedTraits::AccPacket SAccPacket;
SwappedTraits straits;
SAccPacket C0;
straits.initAcc(C0);
for(Index k=0; k<depth; k++)
{
SLhsPacket A0;
straits.loadLhsUnaligned(blB, A0);
SRhsPacket B_0;
straits.loadRhs(&blA[k], B_0);
SRhsPacket T0;
straits.madd(A0,B_0,C0,T0);
blB += nr;
}
SResPacket R = pgather<SResScalar, SResPacket>(&res[j2*resStride + i], resStride);
SResPacket alphav = pset1<SResPacket>(alpha);
straits.acc(C0, alphav, R);
pscatter(&res[j2*resStride + i], R, resStride);
EIGEN_ASM_COMMENT("end_vectorized_multiplication_of_last_rows");
}
else
{
// gets a 1 x 8 res block as registers
ResScalar C0(0), C1(0), C2(0), C3(0), C4(0), C5(0), C6(0), C7(0);
for(Index k=0; k<depth; k++)
{
LhsScalar A0;
RhsScalar B_0, B_1;
A0 = blA[k];
B_0 = blB[0];
B_1 = blB[1];
MADD(cj,A0,B_0,C0, B_0);
MADD(cj,A0,B_1,C1, B_1);
B_0 = blB[2];
B_1 = blB[3];
MADD(cj,A0,B_0,C2, B_0);
MADD(cj,A0,B_1,C3, B_1);
B_0 = blB[4];
B_1 = blB[5];
MADD(cj,A0,B_0,C4, B_0);
MADD(cj,A0,B_1,C5, B_1);
B_0 = blB[6];
B_1 = blB[7];
MADD(cj,A0,B_0,C6, B_0);
MADD(cj,A0,B_1,C7, B_1);
blB += 8;
}
res[(j2+0)*resStride + i] += alpha*C0;
res[(j2+1)*resStride + i] += alpha*C1;
res[(j2+2)*resStride + i] += alpha*C2;
res[(j2+3)*resStride + i] += alpha*C3;
res[(j2+4)*resStride + i] += alpha*C4;
res[(j2+5)*resStride + i] += alpha*C5;
res[(j2+6)*resStride + i] += alpha*C6;
res[(j2+7)*resStride + i] += alpha*C7;
}
}
}
}
@ -839,35 +896,68 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,mr,nr,ConjugateLhs,ConjugateRhs>
const LhsScalar* blA = &blockA[i*strideA+offsetA];
prefetch(&blA[0]);
// gets a 1 x 4 res block as registers
ResScalar C0(0), C1(0), C2(0), C3(0);
// FIXME directly use blockB ???
const RhsScalar* blB = &blockB[j2*strideB+offsetB*4];
// TODO peel this loop
for(Index k=0; k<depth; k++)
{
LhsScalar A0;
RhsScalar B_0, B_1;
A0 = blA[k];
B_0 = blB[0];
B_1 = blB[1];
MADD(cj,A0,B_0,C0, B_0);
MADD(cj,A0,B_1,C1, B_1);
B_0 = blB[2];
B_1 = blB[3];
MADD(cj,A0,B_0,C2, B_0);
MADD(cj,A0,B_1,C3, B_1);
if(nr == Traits::RhsPacketSize)
{
EIGEN_ASM_COMMENT("begin_vectorized_multiplication_of_last_rows");
blB += 4;
}
res[(j2+0)*resStride + i] += alpha*C0;
res[(j2+1)*resStride + i] += alpha*C1;
res[(j2+2)*resStride + i] += alpha*C2;
res[(j2+3)*resStride + i] += alpha*C3;
}
typedef gebp_traits<RhsScalar,LhsScalar,ConjugateRhs,ConjugateLhs> SwappedTraits;
typedef typename SwappedTraits::ResScalar SResScalar;
typedef typename SwappedTraits::LhsPacket SLhsPacket;
typedef typename SwappedTraits::RhsPacket SRhsPacket;
typedef typename SwappedTraits::ResPacket SResPacket;
typedef typename SwappedTraits::AccPacket SAccPacket;
SwappedTraits straits;
SAccPacket C0;
straits.initAcc(C0);
for(Index k=0; k<depth; k++)
{
SLhsPacket A0;
straits.loadLhsUnaligned(blB, A0);
SRhsPacket B_0;
straits.loadRhs(&blA[k], B_0);
SRhsPacket T0;
straits.madd(A0,B_0,C0,T0);
blB += nr;
}
SResPacket R = pgather<SResScalar, SResPacket>(&res[j2*resStride + i], resStride);
SResPacket alphav = pset1<SResPacket>(alpha);
straits.acc(C0, alphav, R);
pscatter(&res[j2*resStride + i], R, resStride);
EIGEN_ASM_COMMENT("end_vectorized_multiplication_of_last_rows");
} else {
// gets a 1 x 4 res block as registers
ResScalar C0(0), C1(0), C2(0), C3(0);
for(Index k=0; k<depth; k++)
{
LhsScalar A0;
RhsScalar B_0, B_1;
A0 = blA[k];
B_0 = blB[0];
B_1 = blB[1];
MADD(cj,A0,B_0,C0, B_0);
MADD(cj,A0,B_1,C1, B_1);
B_0 = blB[2];
B_1 = blB[3];
MADD(cj,A0,B_0,C2, B_0);
MADD(cj,A0,B_1,C3, B_1);
blB += 4;
}
res[(j2+0)*resStride + i] += alpha*C0;
res[(j2+1)*resStride + i] += alpha*C1;
res[(j2+2)*resStride + i] += alpha*C2;
res[(j2+3)*resStride + i] += alpha*C3;
}
}
}
}