mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-02-17 18:09:55 +08:00
Vectorized the loop peeling of the inner loop of the block-panel matrix multiplication code. This speeds up the multiplication of matrices which size is not a multiple of the packet size.
This commit is contained in:
parent
39bfbd43f0
commit
ad59ade116
@ -206,6 +206,11 @@ public:
|
||||
dest = pload<LhsPacket>(a);
|
||||
}
|
||||
|
||||
EIGEN_STRONG_INLINE void loadLhsUnaligned(const LhsScalar* a, LhsPacket& dest) const
|
||||
{
|
||||
dest = ploadu<LhsPacket>(a);
|
||||
}
|
||||
|
||||
EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacket& b, AccPacket& c, AccPacket& tmp) const
|
||||
{
|
||||
// It would be a lot cleaner to call pmadd all the time. Unfortunately if we
|
||||
@ -278,7 +283,12 @@ public:
|
||||
{
|
||||
dest = pload<LhsPacket>(a);
|
||||
}
|
||||
|
||||
|
||||
EIGEN_STRONG_INLINE void loadLhsUnaligned(const LhsScalar* a, LhsPacket& dest) const
|
||||
{
|
||||
dest = ploadu<LhsPacket>(a);
|
||||
}
|
||||
|
||||
EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1, RhsPacket& b2, RhsPacket& b3)
|
||||
{
|
||||
pbroadcast4(b, b0, b1, b2, b3);
|
||||
@ -334,7 +344,9 @@ public:
|
||||
&& packet_traits<Scalar>::Vectorizable,
|
||||
RealPacketSize = Vectorizable ? packet_traits<RealScalar>::size : 1,
|
||||
ResPacketSize = Vectorizable ? packet_traits<ResScalar>::size : 1,
|
||||
|
||||
LhsPacketSize = Vectorizable ? packet_traits<LhsScalar>::size : 1,
|
||||
RhsPacketSize = Vectorizable ? packet_traits<RhsScalar>::size : 1,
|
||||
|
||||
// FIXME: should depend on NumberOfRegisters
|
||||
nr = 4,
|
||||
mr = ResPacketSize,
|
||||
@ -402,6 +414,11 @@ public:
|
||||
dest = pload<LhsPacket>((const typename unpacket_traits<LhsPacket>::type*)(a));
|
||||
}
|
||||
|
||||
EIGEN_STRONG_INLINE void loadLhsUnaligned(const LhsScalar* a, LhsPacket& dest) const
|
||||
{
|
||||
dest = ploadu<LhsPacket>((const typename unpacket_traits<LhsPacket>::type*)(a));
|
||||
}
|
||||
|
||||
EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacket& b, DoublePacket& c, RhsPacket& /*tmp*/) const
|
||||
{
|
||||
c.first = padd(pmul(a,b.first), c.first);
|
||||
@ -509,6 +526,11 @@ public:
|
||||
dest = ploaddup<LhsPacket>(a);
|
||||
}
|
||||
|
||||
EIGEN_STRONG_INLINE void loadLhsUnaligned(const LhsScalar* a, LhsPacket& dest) const
|
||||
{
|
||||
dest = ploaddup<LhsPacket>(a);
|
||||
}
|
||||
|
||||
EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacket& b, AccPacket& c, RhsPacket& tmp) const
|
||||
{
|
||||
madd_impl(a, b, c, tmp, typename conditional<Vectorizable,true_type,false_type>::type());
|
||||
@ -706,49 +728,84 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,mr,nr,ConjugateLhs,ConjugateRhs>
|
||||
const LhsScalar* blA = &blockA[i*strideA+offsetA];
|
||||
prefetch(&blA[0]);
|
||||
|
||||
// gets a 1 x 8 res block as registers
|
||||
ResScalar C0(0), C1(0), C2(0), C3(0), C4(0), C5(0), C6(0), C7(0);
|
||||
// FIXME directly use blockB ???
|
||||
const RhsScalar* blB = &blockB[j2*strideB+offsetB*8];
|
||||
// TODO peel this loop
|
||||
for(Index k=0; k<depth; k++)
|
||||
{
|
||||
LhsScalar A0;
|
||||
RhsScalar B_0, B_1;
|
||||
|
||||
A0 = blA[k];
|
||||
|
||||
B_0 = blB[0];
|
||||
B_1 = blB[1];
|
||||
MADD(cj,A0,B_0,C0, B_0);
|
||||
MADD(cj,A0,B_1,C1, B_1);
|
||||
|
||||
B_0 = blB[2];
|
||||
B_1 = blB[3];
|
||||
MADD(cj,A0,B_0,C2, B_0);
|
||||
MADD(cj,A0,B_1,C3, B_1);
|
||||
|
||||
B_0 = blB[4];
|
||||
B_1 = blB[5];
|
||||
MADD(cj,A0,B_0,C4, B_0);
|
||||
MADD(cj,A0,B_1,C5, B_1);
|
||||
|
||||
B_0 = blB[6];
|
||||
B_1 = blB[7];
|
||||
MADD(cj,A0,B_0,C6, B_0);
|
||||
MADD(cj,A0,B_1,C7, B_1);
|
||||
if(nr == Traits::RhsPacketSize)
|
||||
{
|
||||
EIGEN_ASM_COMMENT("begin_vectorized_multiplication_of_last_rows");
|
||||
|
||||
blB += 8;
|
||||
}
|
||||
res[(j2+0)*resStride + i] += alpha*C0;
|
||||
res[(j2+1)*resStride + i] += alpha*C1;
|
||||
res[(j2+2)*resStride + i] += alpha*C2;
|
||||
res[(j2+3)*resStride + i] += alpha*C3;
|
||||
res[(j2+4)*resStride + i] += alpha*C4;
|
||||
res[(j2+5)*resStride + i] += alpha*C5;
|
||||
res[(j2+6)*resStride + i] += alpha*C6;
|
||||
res[(j2+7)*resStride + i] += alpha*C7;
|
||||
}
|
||||
typedef gebp_traits<RhsScalar,LhsScalar,ConjugateRhs,ConjugateLhs> SwappedTraits;
|
||||
typedef typename SwappedTraits::ResScalar SResScalar;
|
||||
typedef typename SwappedTraits::LhsPacket SLhsPacket;
|
||||
typedef typename SwappedTraits::RhsPacket SRhsPacket;
|
||||
typedef typename SwappedTraits::ResPacket SResPacket;
|
||||
typedef typename SwappedTraits::AccPacket SAccPacket;
|
||||
SwappedTraits straits;
|
||||
|
||||
SAccPacket C0;
|
||||
straits.initAcc(C0);
|
||||
for(Index k=0; k<depth; k++)
|
||||
{
|
||||
SLhsPacket A0;
|
||||
straits.loadLhsUnaligned(blB, A0);
|
||||
SRhsPacket B_0;
|
||||
straits.loadRhs(&blA[k], B_0);
|
||||
SRhsPacket T0;
|
||||
straits.madd(A0,B_0,C0,T0);
|
||||
blB += nr;
|
||||
}
|
||||
SResPacket R = pgather<SResScalar, SResPacket>(&res[j2*resStride + i], resStride);
|
||||
SResPacket alphav = pset1<SResPacket>(alpha);
|
||||
straits.acc(C0, alphav, R);
|
||||
pscatter(&res[j2*resStride + i], R, resStride);
|
||||
|
||||
EIGEN_ASM_COMMENT("end_vectorized_multiplication_of_last_rows");
|
||||
}
|
||||
else
|
||||
{
|
||||
// gets a 1 x 8 res block as registers
|
||||
ResScalar C0(0), C1(0), C2(0), C3(0), C4(0), C5(0), C6(0), C7(0);
|
||||
|
||||
for(Index k=0; k<depth; k++)
|
||||
{
|
||||
LhsScalar A0;
|
||||
RhsScalar B_0, B_1;
|
||||
|
||||
A0 = blA[k];
|
||||
|
||||
B_0 = blB[0];
|
||||
B_1 = blB[1];
|
||||
MADD(cj,A0,B_0,C0, B_0);
|
||||
MADD(cj,A0,B_1,C1, B_1);
|
||||
|
||||
B_0 = blB[2];
|
||||
B_1 = blB[3];
|
||||
MADD(cj,A0,B_0,C2, B_0);
|
||||
MADD(cj,A0,B_1,C3, B_1);
|
||||
|
||||
B_0 = blB[4];
|
||||
B_1 = blB[5];
|
||||
MADD(cj,A0,B_0,C4, B_0);
|
||||
MADD(cj,A0,B_1,C5, B_1);
|
||||
|
||||
B_0 = blB[6];
|
||||
B_1 = blB[7];
|
||||
MADD(cj,A0,B_0,C6, B_0);
|
||||
MADD(cj,A0,B_1,C7, B_1);
|
||||
|
||||
blB += 8;
|
||||
}
|
||||
res[(j2+0)*resStride + i] += alpha*C0;
|
||||
res[(j2+1)*resStride + i] += alpha*C1;
|
||||
res[(j2+2)*resStride + i] += alpha*C2;
|
||||
res[(j2+3)*resStride + i] += alpha*C3;
|
||||
res[(j2+4)*resStride + i] += alpha*C4;
|
||||
res[(j2+5)*resStride + i] += alpha*C5;
|
||||
res[(j2+6)*resStride + i] += alpha*C6;
|
||||
res[(j2+7)*resStride + i] += alpha*C7;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -839,35 +896,68 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,mr,nr,ConjugateLhs,ConjugateRhs>
|
||||
const LhsScalar* blA = &blockA[i*strideA+offsetA];
|
||||
prefetch(&blA[0]);
|
||||
|
||||
// gets a 1 x 4 res block as registers
|
||||
ResScalar C0(0), C1(0), C2(0), C3(0);
|
||||
// FIXME directly use blockB ???
|
||||
const RhsScalar* blB = &blockB[j2*strideB+offsetB*4];
|
||||
// TODO peel this loop
|
||||
for(Index k=0; k<depth; k++)
|
||||
{
|
||||
LhsScalar A0;
|
||||
RhsScalar B_0, B_1;
|
||||
|
||||
A0 = blA[k];
|
||||
|
||||
B_0 = blB[0];
|
||||
B_1 = blB[1];
|
||||
MADD(cj,A0,B_0,C0, B_0);
|
||||
MADD(cj,A0,B_1,C1, B_1);
|
||||
|
||||
B_0 = blB[2];
|
||||
B_1 = blB[3];
|
||||
MADD(cj,A0,B_0,C2, B_0);
|
||||
MADD(cj,A0,B_1,C3, B_1);
|
||||
if(nr == Traits::RhsPacketSize)
|
||||
{
|
||||
EIGEN_ASM_COMMENT("begin_vectorized_multiplication_of_last_rows");
|
||||
|
||||
blB += 4;
|
||||
}
|
||||
res[(j2+0)*resStride + i] += alpha*C0;
|
||||
res[(j2+1)*resStride + i] += alpha*C1;
|
||||
res[(j2+2)*resStride + i] += alpha*C2;
|
||||
res[(j2+3)*resStride + i] += alpha*C3;
|
||||
}
|
||||
typedef gebp_traits<RhsScalar,LhsScalar,ConjugateRhs,ConjugateLhs> SwappedTraits;
|
||||
typedef typename SwappedTraits::ResScalar SResScalar;
|
||||
typedef typename SwappedTraits::LhsPacket SLhsPacket;
|
||||
typedef typename SwappedTraits::RhsPacket SRhsPacket;
|
||||
typedef typename SwappedTraits::ResPacket SResPacket;
|
||||
typedef typename SwappedTraits::AccPacket SAccPacket;
|
||||
SwappedTraits straits;
|
||||
|
||||
SAccPacket C0;
|
||||
straits.initAcc(C0);
|
||||
for(Index k=0; k<depth; k++)
|
||||
{
|
||||
SLhsPacket A0;
|
||||
straits.loadLhsUnaligned(blB, A0);
|
||||
SRhsPacket B_0;
|
||||
straits.loadRhs(&blA[k], B_0);
|
||||
SRhsPacket T0;
|
||||
straits.madd(A0,B_0,C0,T0);
|
||||
blB += nr;
|
||||
}
|
||||
SResPacket R = pgather<SResScalar, SResPacket>(&res[j2*resStride + i], resStride);
|
||||
SResPacket alphav = pset1<SResPacket>(alpha);
|
||||
straits.acc(C0, alphav, R);
|
||||
pscatter(&res[j2*resStride + i], R, resStride);
|
||||
|
||||
EIGEN_ASM_COMMENT("end_vectorized_multiplication_of_last_rows");
|
||||
} else {
|
||||
// gets a 1 x 4 res block as registers
|
||||
ResScalar C0(0), C1(0), C2(0), C3(0);
|
||||
|
||||
for(Index k=0; k<depth; k++)
|
||||
{
|
||||
LhsScalar A0;
|
||||
RhsScalar B_0, B_1;
|
||||
|
||||
A0 = blA[k];
|
||||
|
||||
B_0 = blB[0];
|
||||
B_1 = blB[1];
|
||||
MADD(cj,A0,B_0,C0, B_0);
|
||||
MADD(cj,A0,B_1,C1, B_1);
|
||||
|
||||
B_0 = blB[2];
|
||||
B_1 = blB[3];
|
||||
MADD(cj,A0,B_0,C2, B_0);
|
||||
MADD(cj,A0,B_1,C3, B_1);
|
||||
|
||||
blB += 4;
|
||||
}
|
||||
res[(j2+0)*resStride + i] += alpha*C0;
|
||||
res[(j2+1)*resStride + i] += alpha*C1;
|
||||
res[(j2+2)*resStride + i] += alpha*C2;
|
||||
res[(j2+3)*resStride + i] += alpha*C3;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user