mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-03-13 18:37:27 +08:00
Fix and optimize mixed products
This commit is contained in:
parent
0fa8290366
commit
11fbdcbc38
@ -180,14 +180,15 @@ public:
|
||||
|
||||
NumberOfRegisters = EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS,
|
||||
|
||||
// register block size along the N direction (must be either 2 or 4)
|
||||
nr = 4,//NumberOfRegisters/4,
|
||||
// register block size along the N direction must be 1 or 4
|
||||
nr = 4,
|
||||
|
||||
// register block size along the M direction (currently, this one cannot be modified)
|
||||
#ifdef __FMA__
|
||||
// we assume 16 registers
|
||||
mr = 3*LhsPacketSize,
|
||||
#else
|
||||
mr = 2*LhsPacketSize,
|
||||
mr = (EIGEN_PLAIN_ENUM_MIN(16,NumberOfRegisters)/2/nr)*LhsPacketSize,
|
||||
#endif
|
||||
|
||||
LhsProgress = LhsPacketSize,
|
||||
@ -209,15 +210,15 @@ public:
|
||||
p = pset1<ResPacket>(ResScalar(0));
|
||||
}
|
||||
|
||||
// Fills the four output packets b0..b3 from the RHS pointer b by forwarding
// everything, unchanged, to pbroadcast4.
// NOTE(review): presumably each of b[0..3] ends up replicated across one
// packet -- confirm against the definition of pbroadcast4.
// NOTE(review): an identical overload also appears commented out a few lines
// further down in this view; check which of the two is the current one.
EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1, RhsPacket& b2, RhsPacket& b3)
|
||||
{
|
||||
pbroadcast4(b, b0, b1, b2, b3);
|
||||
}
|
||||
|
||||
// Two-packet counterpart of the four-packet broadcastRhs overload: fills b0
// and b1 from the RHS pointer b by forwarding to pbroadcast2.
// NOTE(review): presumably b[0] and b[1] are each replicated across one
// packet -- confirm against the definition of pbroadcast2.
EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1)
|
||||
{
|
||||
pbroadcast2(b, b0, b1);
|
||||
}
|
||||
// EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1, RhsPacket& b2, RhsPacket& b3)
|
||||
// {
|
||||
// pbroadcast4(b, b0, b1, b2, b3);
|
||||
// }
|
||||
//
|
||||
// EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1)
|
||||
// {
|
||||
// pbroadcast2(b, b0, b1);
|
||||
// }
|
||||
|
||||
template<typename RhsPacketType>
|
||||
EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacketType& dest) const
|
||||
@ -290,8 +291,13 @@ public:
|
||||
ResPacketSize = Vectorizable ? packet_traits<ResScalar>::size : 1,
|
||||
|
||||
NumberOfRegisters = EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS,
|
||||
nr = NumberOfRegisters/2,
|
||||
mr = LhsPacketSize,
|
||||
nr = 4,
|
||||
#ifdef __FMA__
|
||||
// we assume 16 registers
|
||||
mr = 3*LhsPacketSize,
|
||||
#else
|
||||
mr = (EIGEN_PLAIN_ENUM_MIN(16,NumberOfRegisters)/2/nr)*LhsPacketSize,
|
||||
#endif
|
||||
|
||||
LhsProgress = LhsPacketSize,
|
||||
RhsProgress = 1
|
||||
@ -332,15 +338,15 @@ public:
|
||||
dest = ploadu<LhsPacket>(a);
|
||||
}
|
||||
|
||||
// Fills the four output packets b0..b3 from the RHS pointer b; the call is
// forwarded as-is to pbroadcast4.
// NOTE(review): the exact lane layout produced is defined by pbroadcast4 and
// is not visible here -- confirm before relying on it.
// NOTE(review): the same overload is shown commented out just below in this
// view; verify which version is live.
EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1, RhsPacket& b2, RhsPacket& b3)
|
||||
{
|
||||
pbroadcast4(b, b0, b1, b2, b3);
|
||||
}
|
||||
|
||||
// Fills the two output packets b0 and b1 from the RHS pointer b; the call is
// forwarded as-is to pbroadcast2.
// NOTE(review): the exact lane layout produced is defined by pbroadcast2 and
// is not visible here -- confirm before relying on it.
EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1)
|
||||
{
|
||||
pbroadcast2(b, b0, b1);
|
||||
}
|
||||
// EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1, RhsPacket& b2, RhsPacket& b3)
|
||||
// {
|
||||
// pbroadcast4(b, b0, b1, b2, b3);
|
||||
// }
|
||||
//
|
||||
// EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1)
|
||||
// {
|
||||
// pbroadcast2(b, b0, b1);
|
||||
// }
|
||||
|
||||
EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacket& b, AccPacket& c, RhsPacket& tmp) const
|
||||
{
|
||||
@ -566,7 +572,7 @@ public:
|
||||
NumberOfRegisters = EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS,
|
||||
// FIXME: should depend on NumberOfRegisters
|
||||
nr = 4,
|
||||
mr = ResPacketSize,
|
||||
mr = (EIGEN_PLAIN_ENUM_MIN(16,NumberOfRegisters)/2/nr)*ResPacketSize,
|
||||
|
||||
LhsProgress = ResPacketSize,
|
||||
RhsProgress = 1
|
||||
@ -593,19 +599,25 @@ public:
|
||||
}
|
||||
|
||||
// Four-packet broadcastRhs overload, declared without a body in this view.
// Per the original note below, that is deliberate: a call that survives
// optimization has nothing to link against, so misuse shows up as a link
// error rather than silently running.
// linking error if instantiated without being optimized out:
|
||||
void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1, RhsPacket& b2, RhsPacket& b3);
|
||||
|
||||
// Fills b0 from b[0] and b1 from b[1], each through its own
// pload1<RhsPacket> call (no combined broadcast primitive is used here).
// NOTE(review): presumably pload1 replicates the single scalar across the
// packet -- confirm against its definition.
EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1)
|
||||
{
|
||||
// FIXME not sure that's the best way to implement it!
|
||||
// first RHS coefficient -> b0
b0 = pload1<RhsPacket>(b+0);
|
||||
// second RHS coefficient -> b1
b1 = pload1<RhsPacket>(b+1);
|
||||
}
|
||||
// void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1, RhsPacket& b2, RhsPacket& b3);
|
||||
//
|
||||
// EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1)
|
||||
// {
|
||||
// // FIXME not sure that's the best way to implement it!
|
||||
// b0 = pload1<RhsPacket>(b+0);
|
||||
// b1 = pload1<RhsPacket>(b+1);
|
||||
// }
|
||||
|
||||
// Loads an LHS packet from pointer a into dest using ploaddup rather than a
// plain packet load.
// NOTE(review): the name suggests each loaded element is duplicated into
// adjacent lanes; confirm against the definition of ploaddup.
EIGEN_STRONG_INLINE void loadLhs(const LhsScalar* a, LhsPacket& dest) const
|
||||
{
|
||||
dest = ploaddup<LhsPacket>(a);
|
||||
}
|
||||
|
||||
// "Quad" variant of the RHS load: after checking that RhsPacket has at most
// four lanes, it simply delegates to loadRhs(b, dest).
EIGEN_STRONG_INLINE void loadRhsQuad(const RhsScalar* b, RhsPacket& dest) const
|
||||
{
|
||||
// Internal (debug-style) assertion: this fallback is only meant for packets
// whose unpacket_traits size is <= 4.
eigen_internal_assert(unpacket_traits<RhsPacket>::size<=4);
|
||||
loadRhs(b,dest);
|
||||
}
|
||||
|
||||
EIGEN_STRONG_INLINE void loadLhsUnaligned(const LhsScalar* a, LhsPacket& dest) const
|
||||
{
|
||||
@ -619,7 +631,13 @@ public:
|
||||
|
||||
// Packet overload of the multiply-accumulate step, selected by the trailing
// true_type tag argument: accumulates a times b into c.
//   a   : LHS packet, used directly as a multiplicand.
//   b   : RHS packet; its .v member is the multiplicand.
//   c   : accumulator, updated in place.
//   tmp : scratch packet, only written on the non-FMA path.
EIGEN_STRONG_INLINE void madd_impl(const LhsPacket& a, const RhsPacket& b, AccPacket& c, RhsPacket& tmp, const true_type&) const
|
||||
{
|
||||
#ifdef EIGEN_VECTORIZE_FMA
|
||||
// FMA path: tmp is not needed, so mark it used to avoid a warning.
EIGEN_UNUSED_VARIABLE(tmp);
|
||||
// Single pmadd on the wrapped .v members: c.v = a * b.v + c.v.
c.v = pmadd(a,b.v,c.v);
|
||||
#else
|
||||
// Non-FMA path: copy b into tmp, scale tmp.v by a, then add tmp into c.
tmp = b; tmp.v = pmul(a,tmp.v); c = padd(c,tmp);
|
||||
#endif
|
||||
|
||||
}
|
||||
|
||||
EIGEN_STRONG_INLINE void madd_impl(const LhsScalar& a, const RhsScalar& b, ResScalar& c, RhsScalar& /*tmp*/, const false_type&) const
|
||||
@ -956,7 +974,7 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,mr,nr,ConjugateLhs,ConjugateRhs>
|
||||
for(Index k=0; k<peeled_kc; k+=pk)
|
||||
{
|
||||
IACA_START
|
||||
EIGEN_ASM_COMMENT("begin gegp micro kernel 2p x 4");
|
||||
EIGEN_ASM_COMMENT("begin gegp micro kernel 2pX4");
|
||||
RhsPacket B_0, B1;
|
||||
|
||||
#define EIGEN_GEBGP_ONESTEP(K) \
|
||||
@ -1134,7 +1152,7 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,mr,nr,ConjugateLhs,ConjugateRhs>
|
||||
for(Index k=0; k<peeled_kc; k+=pk)
|
||||
{
|
||||
IACA_START
|
||||
EIGEN_ASM_COMMENT("begin gegp micro kernel 2p x 4");
|
||||
EIGEN_ASM_COMMENT("begin gegp micro kernel 1pX4");
|
||||
RhsPacket B_0, B1;
|
||||
|
||||
#define EIGEN_GEBGP_ONESTEP(K) \
|
||||
@ -1160,7 +1178,7 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,mr,nr,ConjugateLhs,ConjugateRhs>
|
||||
EIGEN_GEBGP_ONESTEP(7);
|
||||
|
||||
blB += pk*4*RhsProgress;
|
||||
blA += pk*(1*Traits::LhsProgress);
|
||||
blA += pk*1*LhsProgress;
|
||||
IACA_END
|
||||
}
|
||||
// process remaining peeled loop
|
||||
@ -1169,7 +1187,7 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,mr,nr,ConjugateLhs,ConjugateRhs>
|
||||
RhsPacket B_0, B1;
|
||||
EIGEN_GEBGP_ONESTEP(0);
|
||||
blB += 4*RhsProgress;
|
||||
blA += 1*Traits::LhsProgress;
|
||||
blA += 1*LhsProgress;
|
||||
}
|
||||
#undef EIGEN_GEBGP_ONESTEP
|
||||
|
||||
@ -1439,6 +1457,8 @@ EIGEN_DONT_INLINE void gemm_pack_lhs<Scalar, Index, Pack1, Pack2, ColMajor, Conj
|
||||
const Index peeled_mc3 = Pack1>=3*PacketSize ? (rows/(3*PacketSize))*(3*PacketSize) : 0;
|
||||
const Index peeled_mc2 = Pack1>=2*PacketSize ? peeled_mc3+((rows-peeled_mc3)/(2*PacketSize))*(2*PacketSize) : 0;
|
||||
const Index peeled_mc1 = Pack1>=1*PacketSize ? (rows/(1*PacketSize))*(1*PacketSize) : 0;
|
||||
const Index peeled_mc0 = Pack2>=1*PacketSize ? peeled_mc1
|
||||
: Pack2>1 ? (rows/Pack2)*Pack2 : 0;
|
||||
|
||||
// Pack 3 packets
|
||||
if(Pack1>=3*PacketSize)
|
||||
@ -1496,16 +1516,20 @@ EIGEN_DONT_INLINE void gemm_pack_lhs<Scalar, Index, Pack1, Pack2, ColMajor, Conj
|
||||
}
|
||||
}
|
||||
// Pack scalars
|
||||
// if(rows-peeled_mc>=Pack2)
|
||||
// {
|
||||
// if(PanelMode) count += Pack2*offset;
|
||||
// for(Index k=0; k<depth; k++)
|
||||
// for(Index w=0; w<Pack2; w++)
|
||||
// blockA[count++] = cj(lhs(peeled_mc+w, k));
|
||||
// if(PanelMode) count += Pack2 * (stride-offset-depth);
|
||||
// peeled_mc += Pack2;
|
||||
// }
|
||||
for(Index i=peeled_mc1; i<rows; i++)
|
||||
if(Pack2<PacketSize && Pack2>1)
|
||||
{
|
||||
for(Index i=peeled_mc1; i<peeled_mc0; i+=Pack2)
|
||||
{
|
||||
if(PanelMode) count += Pack2 * offset;
|
||||
|
||||
for(Index k=0; k<depth; k++)
|
||||
for(Index w=0; w<Pack2; w++)
|
||||
blockA[count++] = cj(lhs(i+w, k));
|
||||
|
||||
if(PanelMode) count += Pack2 * (stride-offset-depth);
|
||||
}
|
||||
}
|
||||
for(Index i=peeled_mc0; i<rows; i++)
|
||||
{
|
||||
if(PanelMode) count += offset;
|
||||
for(Index k=0; k<depth; k++)
|
||||
@ -1539,35 +1563,36 @@ EIGEN_DONT_INLINE void gemm_pack_lhs<Scalar, Index, Pack1, Pack2, RowMajor, Conj
|
||||
// const Index peeled_mc2 = Pack1>=2*PacketSize ? peeled_mc3+((rows-peeled_mc3)/(2*PacketSize))*(2*PacketSize) : 0;
|
||||
// const Index peeled_mc1 = Pack1>=1*PacketSize ? (rows/(1*PacketSize))*(1*PacketSize) : 0;
|
||||
|
||||
int pack_packets = Pack1/PacketSize;
|
||||
int pack = Pack1;
|
||||
Index i = 0;
|
||||
while(pack_packets>0)
|
||||
while(pack>0)
|
||||
{
|
||||
|
||||
Index remaining_rows = rows-i;
|
||||
Index peeled_mc = i+(remaining_rows/(pack_packets*PacketSize))*(pack_packets*PacketSize);
|
||||
// std::cout << "pack_packets = " << pack_packets << " from " << i << " to " << peeled_mc << "\n";
|
||||
for(; i<peeled_mc; i+=pack_packets*PacketSize)
|
||||
Index peeled_mc = i+(remaining_rows/pack)*pack;
|
||||
for(; i<peeled_mc; i+=pack)
|
||||
{
|
||||
if(PanelMode) count += (pack_packets*PacketSize) * offset;
|
||||
if(PanelMode) count += pack * offset;
|
||||
|
||||
const Index peeled_k = (depth/PacketSize)*PacketSize;
|
||||
Index k=0;
|
||||
for(; k<peeled_k; k+=PacketSize)
|
||||
if(pack>=PacketSize)
|
||||
{
|
||||
for (Index m = 0; m < (pack_packets*PacketSize); m += PacketSize)
|
||||
for(; k<peeled_k; k+=PacketSize)
|
||||
{
|
||||
Kernel<Packet> kernel;
|
||||
for (int p = 0; p < PacketSize; ++p) kernel.packet[p] = ploadu<Packet>(&lhs(i+p+m, k));
|
||||
ptranspose(kernel);
|
||||
for (int p = 0; p < PacketSize; ++p) pstore(blockA+count+m+(pack_packets*PacketSize)*p, cj.pconj(kernel.packet[p]));
|
||||
for (Index m = 0; m < pack; m += PacketSize)
|
||||
{
|
||||
Kernel<Packet> kernel;
|
||||
for (int p = 0; p < PacketSize; ++p) kernel.packet[p] = ploadu<Packet>(&lhs(i+p+m, k));
|
||||
ptranspose(kernel);
|
||||
for (int p = 0; p < PacketSize; ++p) pstore(blockA+count+m+(pack)*p, cj.pconj(kernel.packet[p]));
|
||||
}
|
||||
count += PacketSize*pack;
|
||||
}
|
||||
count += PacketSize*(pack_packets*PacketSize);
|
||||
}
|
||||
for(; k<depth; k++)
|
||||
{
|
||||
Index w=0;
|
||||
for(; w<(pack_packets*PacketSize)-3; w+=4)
|
||||
for(; w<pack-3; w+=4)
|
||||
{
|
||||
Scalar a(cj(lhs(i+w+0, k))),
|
||||
b(cj(lhs(i+w+1, k))),
|
||||
@ -1578,26 +1603,19 @@ EIGEN_DONT_INLINE void gemm_pack_lhs<Scalar, Index, Pack1, Pack2, RowMajor, Conj
|
||||
blockA[count++] = c;
|
||||
blockA[count++] = d;
|
||||
}
|
||||
if(PacketSize%4)
|
||||
for(;w<pack_packets*PacketSize;++w)
|
||||
if(pack%4)
|
||||
for(;w<pack;++w)
|
||||
blockA[count++] = cj(lhs(i+w, k));
|
||||
}
|
||||
|
||||
if(PanelMode) count += (pack_packets*PacketSize) * (stride-offset-depth);
|
||||
if(PanelMode) count += pack * (stride-offset-depth);
|
||||
}
|
||||
|
||||
pack_packets--;
|
||||
pack -= PacketSize;
|
||||
if(pack<Pack2 && (pack+PacketSize)!=Pack2)
|
||||
pack = Pack2;
|
||||
}
|
||||
|
||||
// if(rows-peeled_mc>=Pack2)
|
||||
// {
|
||||
// if(PanelMode) count += Pack2*offset;
|
||||
// for(Index k=0; k<depth; k++)
|
||||
// for(Index w=0; w<Pack2; w++)
|
||||
// blockA[count++] = cj(lhs(peeled_mc+w, k));
|
||||
// if(PanelMode) count += Pack2 * (stride-offset-depth);
|
||||
// peeled_mc += Pack2;
|
||||
// }
|
||||
for(; i<rows; i++)
|
||||
{
|
||||
if(PanelMode) count += offset;
|
||||
|
Loading…
x
Reference in New Issue
Block a user