Fix and optimize mixed products

Gael Guennebaud 2014-04-17 16:04:30 +02:00
parent 0fa8290366
commit 11fbdcbc38


@@ -180,14 +180,15 @@ public:
NumberOfRegisters = EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS,
// register block size along the N direction (must be either 2 or 4)
nr = 4,//NumberOfRegisters/4,
// register block size along the N direction must be 1 or 4
nr = 4,
// register block size along the M direction (currently, this one cannot be modified)
#ifdef __FMA__
// we assume 16 registers
mr = 3*LhsPacketSize,
#else
mr = 2*LhsPacketSize,
mr = (EIGEN_PLAIN_ENUM_MIN(16,NumberOfRegisters)/2/nr)*LhsPacketSize,
#endif
LhsProgress = LhsPacketSize,
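
Aside, not part of the commit: the replacement formula for mr caps the register budget at 16 and sizes the M-direction block so that the mr x nr accumulator tile uses about half of the registers, leaving the rest for lhs/rhs operands. A minimal sketch of that arithmetic, with illustrative register counts and packet sizes (the target list below is an assumption, not taken from the commit):

#include <algorithm>
#include <cstdio>

int main() {
  const int nr = 4; // register block size along N, as set in this hunk
  // Illustrative (assumed) targets: {description, SIMD register count, LhsPacketSize}
  const struct { const char* name; int regs; int lhs_packet; } targets[] = {
    {"SSE float, 16 registers",              16, 4},
    {"AVX float, 16 registers",              16, 8},
    {"SSE float, 8 registers (32-bit x86)",   8, 4},
  };
  for (const auto& t : targets) {
    // mr = (EIGEN_PLAIN_ENUM_MIN(16, NumberOfRegisters) / 2 / nr) * LhsPacketSize
    const int mr = (std::min(16, t.regs) / 2 / nr) * t.lhs_packet;
    std::printf("%-36s mr = %2d (%d packet(s) along M)\n", t.name, mr, mr / t.lhs_packet);
  }
  return 0;
}

With 16 registers and nr = 4 the formula reproduces the previous hard-coded 2*LhsPacketSize, while the __FMA__ branch keeps the wider 3*LhsPacketSize block under its stated 16-register assumption.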
@@ -209,15 +210,15 @@ public:
p = pset1<ResPacket>(ResScalar(0));
}
EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1, RhsPacket& b2, RhsPacket& b3)
{
pbroadcast4(b, b0, b1, b2, b3);
}
EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1)
{
pbroadcast2(b, b0, b1);
}
// EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1, RhsPacket& b2, RhsPacket& b3)
// {
// pbroadcast4(b, b0, b1, b2, b3);
// }
//
// EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1)
// {
// pbroadcast2(b, b0, b1);
// }
template<typename RhsPacketType>
EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacketType& dest) const
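
For illustration only, not part of the commit: the broadcastRhs helpers commented out in this hunk were thin wrappers around pbroadcast4/pbroadcast2, which read four (or two) consecutive rhs scalars and splat each into its own packet. A scalar emulation of that broadcast on a 4-wide "packet" represented as a plain array (the packet type and width here are illustrative):

#include <array>
#include <cstdio>

using Packet4 = std::array<float, 4>;   // stand-in for a 4-wide SIMD packet

// Splat one scalar into every lane, like pset1.
static Packet4 set1(float x) { return {x, x, x, x}; }

// Like pbroadcast4: read 4 consecutive scalars and splat each into its own packet.
static void broadcast4(const float* b, Packet4& b0, Packet4& b1, Packet4& b2, Packet4& b3) {
  b0 = set1(b[0]);
  b1 = set1(b[1]);
  b2 = set1(b[2]);
  b3 = set1(b[3]);
}

int main() {
  const float rhs[4] = {1, 2, 3, 4};
  Packet4 b0, b1, b2, b3;
  broadcast4(rhs, b0, b1, b2, b3);
  std::printf("b2 = [%g %g %g %g]\n", b2[0], b2[1], b2[2], b2[3]);
  return 0;
}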
@@ -290,8 +291,13 @@ public:
ResPacketSize = Vectorizable ? packet_traits<ResScalar>::size : 1,
NumberOfRegisters = EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS,
nr = NumberOfRegisters/2,
mr = LhsPacketSize,
nr = 4,
#ifdef __FMA__
// we assume 16 registers
mr = 3*LhsPacketSize,
#else
mr = (EIGEN_PLAIN_ENUM_MIN(16,NumberOfRegisters)/2/nr)*LhsPacketSize,
#endif
LhsProgress = LhsPacketSize,
RhsProgress = 1
@@ -332,15 +338,15 @@ public:
dest = ploadu<LhsPacket>(a);
}
EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1, RhsPacket& b2, RhsPacket& b3)
{
pbroadcast4(b, b0, b1, b2, b3);
}
EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1)
{
pbroadcast2(b, b0, b1);
}
// EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1, RhsPacket& b2, RhsPacket& b3)
// {
// pbroadcast4(b, b0, b1, b2, b3);
// }
//
// EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1)
// {
// pbroadcast2(b, b0, b1);
// }
EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacket& b, AccPacket& c, RhsPacket& tmp) const
{
@@ -566,7 +572,7 @@ public:
NumberOfRegisters = EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS,
// FIXME: should depend on NumberOfRegisters
nr = 4,
mr = ResPacketSize,
mr = (EIGEN_PLAIN_ENUM_MIN(16,NumberOfRegisters)/2/nr)*ResPacketSize,
LhsProgress = ResPacketSize,
RhsProgress = 1
@@ -593,19 +599,25 @@ public:
}
// linking error if instantiated without being optimized out:
void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1, RhsPacket& b2, RhsPacket& b3);
EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1)
{
// FIXME not sure that's the best way to implement it!
b0 = pload1<RhsPacket>(b+0);
b1 = pload1<RhsPacket>(b+1);
}
// void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1, RhsPacket& b2, RhsPacket& b3);
//
// EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1)
// {
// // FIXME not sure that's the best way to implement it!
// b0 = pload1<RhsPacket>(b+0);
// b1 = pload1<RhsPacket>(b+1);
// }
EIGEN_STRONG_INLINE void loadLhs(const LhsScalar* a, LhsPacket& dest) const
{
dest = ploaddup<LhsPacket>(a);
}
EIGEN_STRONG_INLINE void loadRhsQuad(const RhsScalar* b, RhsPacket& dest) const
{
eigen_internal_assert(unpacket_traits<RhsPacket>::size<=4);
loadRhs(b,dest);
}
EIGEN_STRONG_INLINE void loadLhsUnaligned(const LhsScalar* a, LhsPacket& dest) const
{
@@ -619,7 +631,13 @@ public:
EIGEN_STRONG_INLINE void madd_impl(const LhsPacket& a, const RhsPacket& b, AccPacket& c, RhsPacket& tmp, const true_type&) const
{
#ifdef EIGEN_VECTORIZE_FMA
EIGEN_UNUSED_VARIABLE(tmp);
c.v = pmadd(a,b.v,c.v);
#else
tmp = b; tmp.v = pmul(a,tmp.v); c = padd(c,tmp);
#endif
}
EIGEN_STRONG_INLINE void madd_impl(const LhsScalar& a, const RhsScalar& b, ResScalar& c, RhsScalar& /*tmp*/, const false_type&) const
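
Aside, not part of the commit: the madd_impl change above issues a single fused multiply-add when EIGEN_VECTORIZE_FMA is available, and otherwise keeps the separate multiply and add through the tmp register. A minimal scalar sketch of the same accumulation pattern, with std::fma standing in for pmadd (names and sizes are illustrative):

#include <cmath>
#include <cstdio>

// Accumulate c += a*b for a small tile, fused in one instruction.
void madd_fused(const float* a, const float* b, float* c, int n) {
  for (int i = 0; i < n; ++i)
    c[i] = std::fma(a[i], b[i], c[i]);   // one rounding, like pmadd on FMA targets
}

// Same accumulation as an explicit multiply then add, mirroring the non-FMA path.
void madd_split(const float* a, const float* b, float* c, int n) {
  for (int i = 0; i < n; ++i) {
    float tmp = a[i] * b[i];             // separate multiply (pmul)
    c[i] = c[i] + tmp;                   // separate add (padd)
  }
}

int main() {
  float a[4] = {1, 2, 3, 4}, b[4] = {0.5f, 0.5f, 0.5f, 0.5f};
  float c1[4] = {0}, c2[4] = {0};
  madd_fused(a, b, c1, 4);
  madd_split(a, b, c2, 4);
  for (int i = 0; i < 4; ++i)
    std::printf("c_fused=%g c_split=%g\n", c1[i], c2[i]);
  return 0;
}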
@@ -956,7 +974,7 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,mr,nr,ConjugateLhs,ConjugateRhs>
for(Index k=0; k<peeled_kc; k+=pk)
{
IACA_START
EIGEN_ASM_COMMENT("begin gegp micro kernel 2p x 4");
EIGEN_ASM_COMMENT("begin gegp micro kernel 2pX4");
RhsPacket B_0, B1;
#define EIGEN_GEBGP_ONESTEP(K) \
@@ -1134,7 +1152,7 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,mr,nr,ConjugateLhs,ConjugateRhs>
for(Index k=0; k<peeled_kc; k+=pk)
{
IACA_START
EIGEN_ASM_COMMENT("begin gegp micro kernel 2p x 4");
EIGEN_ASM_COMMENT("begin gegp micro kernel 1pX4");
RhsPacket B_0, B1;
#define EIGEN_GEBGP_ONESTEP(K) \
@@ -1160,7 +1178,7 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,mr,nr,ConjugateLhs,ConjugateRhs>
EIGEN_GEBGP_ONESTEP(7);
blB += pk*4*RhsProgress;
blA += pk*(1*Traits::LhsProgress);
blA += pk*1*LhsProgress;
IACA_END
}
// process remaining peeled loop
@@ -1169,7 +1187,7 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,mr,nr,ConjugateLhs,ConjugateRhs>
RhsPacket B_0, B1;
EIGEN_GEBGP_ONESTEP(0);
blB += 4*RhsProgress;
blA += 1*Traits::LhsProgress;
blA += 1*LhsProgress;
}
#undef EIGEN_GEBGP_ONESTEP
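
For illustration only, not part of the commit: the 1pX4 kernel in the hunks above walks the k dimension in peeled groups of pk steps (EIGEN_GEBGP_ONESTEP(0..7)), advances the packed blA/blB pointers by exactly what was consumed, and then finishes the remaining depth one step at a time. A scalar sketch of that loop structure, accumulating one lhs value against four rhs columns per step (pk, the accumulator layout, and the data are illustrative assumptions):

#include <cstdio>
#include <vector>

int main() {
  const int depth = 19;                    // k dimension (deliberately not a multiple of pk)
  const int pk = 8;                        // peeling factor, as in the kernel
  std::vector<float> blA(depth, 1.0f);     // packed lhs: 1 value per k step
  std::vector<float> blB(depth * 4, 0.5f); // packed rhs: 4 values per k step
  float C0 = 0, C1 = 0, C2 = 0, C3 = 0;    // four accumulators (the "X4")

  const float* A = blA.data();
  const float* B = blB.data();
  auto onestep = [&](int K) {              // plays the role of EIGEN_GEBGP_ONESTEP(K)
    const float a = A[K];
    C0 += a * B[4*K + 0];
    C1 += a * B[4*K + 1];
    C2 += a * B[4*K + 2];
    C3 += a * B[4*K + 3];
  };

  const int peeled_kc = (depth / pk) * pk;
  int k = 0;
  for (; k < peeled_kc; k += pk) {         // peeled loop: pk steps per iteration
    for (int s = 0; s < pk; ++s) onestep(s);
    B += pk * 4;                           // blB += pk*4*RhsProgress
    A += pk * 1;                           // blA += pk*1*LhsProgress
  }
  for (; k < depth; ++k) {                 // remaining peeled loop: one step at a time
    onestep(0);
    B += 4;                                // blB += 4*RhsProgress
    A += 1;                                // blA += 1*LhsProgress
  }
  std::printf("C = [%g %g %g %g]\n", C0, C1, C2, C3);
  return 0;
}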
@@ -1439,6 +1457,8 @@ EIGEN_DONT_INLINE void gemm_pack_lhs<Scalar, Index, Pack1, Pack2, ColMajor, Conj
const Index peeled_mc3 = Pack1>=3*PacketSize ? (rows/(3*PacketSize))*(3*PacketSize) : 0;
const Index peeled_mc2 = Pack1>=2*PacketSize ? peeled_mc3+((rows-peeled_mc3)/(2*PacketSize))*(2*PacketSize) : 0;
const Index peeled_mc1 = Pack1>=1*PacketSize ? (rows/(1*PacketSize))*(1*PacketSize) : 0;
const Index peeled_mc0 = Pack2>=1*PacketSize ? peeled_mc1
: Pack2>1 ? (rows/Pack2)*Pack2 : 0;
// Pack 3 packets
if(Pack1>=3*PacketSize)
@@ -1496,16 +1516,20 @@ EIGEN_DONT_INLINE void gemm_pack_lhs<Scalar, Index, Pack1, Pack2, ColMajor, Conj
}
}
// Pack scalars
// if(rows-peeled_mc>=Pack2)
// {
// if(PanelMode) count += Pack2*offset;
// for(Index k=0; k<depth; k++)
// for(Index w=0; w<Pack2; w++)
// blockA[count++] = cj(lhs(peeled_mc+w, k));
// if(PanelMode) count += Pack2 * (stride-offset-depth);
// peeled_mc += Pack2;
// }
for(Index i=peeled_mc1; i<rows; i++)
if(Pack2<PacketSize && Pack2>1)
{
for(Index i=peeled_mc1; i<peeled_mc0; i+=Pack2)
{
if(PanelMode) count += Pack2 * offset;
for(Index k=0; k<depth; k++)
for(Index w=0; w<Pack2; w++)
blockA[count++] = cj(lhs(i+w, k));
if(PanelMode) count += Pack2 * (stride-offset-depth);
}
}
for(Index i=peeled_mc0; i<rows; i++)
{
if(PanelMode) count += offset;
for(Index k=0; k<depth; k++)
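
Aside, not part of the commit: the new branch in this hunk packs the rows left over after the packet-sized peels in groups of Pack2 scalars, interleaving the Pack2 rows column by column, before falling back to one row at a time. A minimal sketch of that layout on a plain column-major array (the sizes, the omitted packet-sized peels over the first peeled_mc1 rows, and the missing PanelMode/conjugation handling are simplifications):

#include <cstdio>
#include <vector>

int main() {
  const int rows = 7, depth = 3;
  const int PacketSize = 4, Pack2 = 2;
  std::vector<float> lhs(rows * depth);                 // column-major: lhs[r + c*rows]
  for (int c = 0; c < depth; ++c)
    for (int r = 0; r < rows; ++r)
      lhs[r + c * rows] = 10.0f * r + c;

  std::vector<float> blockA;
  const int peeled_mc1 = (rows / PacketSize) * PacketSize;            // rows handled by 1-packet peel
  const int peeled_mc0 = (Pack2 >= PacketSize) ? peeled_mc1
                       : (Pack2 > 1) ? (rows / Pack2) * Pack2 : 0;    // rows handled by Pack2 groups

  // Rows [peeled_mc1, peeled_mc0): packed Pack2 at a time, interleaved per column.
  if (Pack2 < PacketSize && Pack2 > 1)
    for (int i = peeled_mc1; i < peeled_mc0; i += Pack2)
      for (int k = 0; k < depth; ++k)
        for (int w = 0; w < Pack2; ++w)
          blockA.push_back(lhs[(i + w) + k * rows]);

  // Remaining rows: one at a time.
  for (int i = peeled_mc0; i < rows; ++i)
    for (int k = 0; k < depth; ++k)
      blockA.push_back(lhs[i + k * rows]);

  for (float v : blockA) std::printf("%g ", v);
  std::printf("\n");
  return 0;
}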
@@ -1539,35 +1563,36 @@ EIGEN_DONT_INLINE void gemm_pack_lhs<Scalar, Index, Pack1, Pack2, RowMajor, Conj
// const Index peeled_mc2 = Pack1>=2*PacketSize ? peeled_mc3+((rows-peeled_mc3)/(2*PacketSize))*(2*PacketSize) : 0;
// const Index peeled_mc1 = Pack1>=1*PacketSize ? (rows/(1*PacketSize))*(1*PacketSize) : 0;
int pack_packets = Pack1/PacketSize;
int pack = Pack1;
Index i = 0;
while(pack_packets>0)
while(pack>0)
{
Index remaining_rows = rows-i;
Index peeled_mc = i+(remaining_rows/(pack_packets*PacketSize))*(pack_packets*PacketSize);
// std::cout << "pack_packets = " << pack_packets << " from " << i << " to " << peeled_mc << "\n";
for(; i<peeled_mc; i+=pack_packets*PacketSize)
Index peeled_mc = i+(remaining_rows/pack)*pack;
for(; i<peeled_mc; i+=pack)
{
if(PanelMode) count += (pack_packets*PacketSize) * offset;
if(PanelMode) count += pack * offset;
const Index peeled_k = (depth/PacketSize)*PacketSize;
Index k=0;
for(; k<peeled_k; k+=PacketSize)
if(pack>=PacketSize)
{
for (Index m = 0; m < (pack_packets*PacketSize); m += PacketSize)
for(; k<peeled_k; k+=PacketSize)
{
Kernel<Packet> kernel;
for (int p = 0; p < PacketSize; ++p) kernel.packet[p] = ploadu<Packet>(&lhs(i+p+m, k));
ptranspose(kernel);
for (int p = 0; p < PacketSize; ++p) pstore(blockA+count+m+(pack_packets*PacketSize)*p, cj.pconj(kernel.packet[p]));
for (Index m = 0; m < pack; m += PacketSize)
{
Kernel<Packet> kernel;
for (int p = 0; p < PacketSize; ++p) kernel.packet[p] = ploadu<Packet>(&lhs(i+p+m, k));
ptranspose(kernel);
for (int p = 0; p < PacketSize; ++p) pstore(blockA+count+m+(pack)*p, cj.pconj(kernel.packet[p]));
}
count += PacketSize*pack;
}
count += PacketSize*(pack_packets*PacketSize);
}
for(; k<depth; k++)
{
Index w=0;
for(; w<(pack_packets*PacketSize)-3; w+=4)
for(; w<pack-3; w+=4)
{
Scalar a(cj(lhs(i+w+0, k))),
b(cj(lhs(i+w+1, k))),
@@ -1578,26 +1603,19 @@ EIGEN_DONT_INLINE void gemm_pack_lhs<Scalar, Index, Pack1, Pack2, RowMajor, Conj
blockA[count++] = c;
blockA[count++] = d;
}
if(PacketSize%4)
for(;w<pack_packets*PacketSize;++w)
if(pack%4)
for(;w<pack;++w)
blockA[count++] = cj(lhs(i+w, k));
}
if(PanelMode) count += (pack_packets*PacketSize) * (stride-offset-depth);
if(PanelMode) count += pack * (stride-offset-depth);
}
pack_packets--;
pack -= PacketSize;
if(pack<Pack2 && (pack+PacketSize)!=Pack2)
pack = Pack2;
}
// if(rows-peeled_mc>=Pack2)
// {
// if(PanelMode) count += Pack2*offset;
// for(Index k=0; k<depth; k++)
// for(Index w=0; w<Pack2; w++)
// blockA[count++] = cj(lhs(peeled_mc+w, k));
// if(PanelMode) count += Pack2 * (stride-offset-depth);
// peeled_mc += Pack2;
// }
for(; i<rows; i++)
{
if(PanelMode) count += offset;
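
For illustration only, not part of the commit: the row-major packing path rewritten above loads PacketSize consecutive rows, transposes each PacketSize x PacketSize block in registers (ptranspose), and stores the result so that all values for a given k end up contiguous across the packed rows, handling the leftover depth scalar by scalar. A minimal scalar sketch of that block-transpose packing on plain arrays (one pack-sized group of rows, with conjugation and PanelMode omitted; the sizes are illustrative):

#include <cstdio>
#include <vector>

int main() {
  const int rows = 4, depth = 6;           // one pack-sized group of rows
  const int PacketSize = 4;
  std::vector<float> lhs(rows * depth);    // row-major: lhs[r*depth + c]
  for (int r = 0; r < rows; ++r)
    for (int c = 0; c < depth; ++c)
      lhs[r * depth + c] = 10.0f * r + c;

  std::vector<float> blockA;
  const int peeled_k = (depth / PacketSize) * PacketSize;
  int k = 0;
  // Full PacketSize x PacketSize blocks: after the transpose, the values for one
  // k index sit contiguously across the packed rows.
  for (; k < peeled_k; k += PacketSize)
    for (int p = 0; p < PacketSize; ++p)   // p indexes k within the block
      for (int r = 0; r < rows; ++r)
        blockA.push_back(lhs[r * depth + (k + p)]);
  // Remaining depth: plain scalar interleave.
  for (; k < depth; ++k)
    for (int r = 0; r < rows; ++r)
      blockA.push_back(lhs[r * depth + k]);

  for (float v : blockA) std::printf("%g ", v);
  std::printf("\n");
  return 0;
}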