Vectorize row-by-row gebp loop iterations on 16 packets as well

Signed-off-by: Gustavo Lima Chaves <gustavo.lima.chaves@intel.com>
Signed-off-by: Mark D. Ryan <mark.d.ryan@intel.com>
This commit is contained in:
Gustavo Lima Chaves 2018-11-06 10:48:42 -08:00
parent 9d318b92c6
commit 4ad359237a

View File

@ -891,6 +891,86 @@ struct gebp_kernel
Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0);
};
template<typename LhsScalar, typename RhsScalar, typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs,
int SwappedLhsProgress = gebp_traits<RhsScalar,LhsScalar,ConjugateRhs,ConjugateLhs>::LhsProgress>
struct last_row_process_16_packets
{
typedef gebp_traits<LhsScalar,RhsScalar,ConjugateLhs,ConjugateRhs> Traits;
typedef gebp_traits<RhsScalar,LhsScalar,ConjugateRhs,ConjugateLhs> SwappedTraits;
typedef typename Traits::ResScalar ResScalar;
typedef typename SwappedTraits::LhsPacket SLhsPacket;
typedef typename SwappedTraits::RhsPacket SRhsPacket;
typedef typename SwappedTraits::ResPacket SResPacket;
typedef typename SwappedTraits::AccPacket SAccPacket;
EIGEN_STRONG_INLINE void operator()(const DataMapper& res, SwappedTraits &straits, const LhsScalar* blA,
const RhsScalar* blB, Index depth, const Index endk, Index i, Index j2,
ResScalar alpha, SAccPacket &C0)
{
EIGEN_UNUSED_VARIABLE(res);
EIGEN_UNUSED_VARIABLE(straits);
EIGEN_UNUSED_VARIABLE(blA);
EIGEN_UNUSED_VARIABLE(blB);
EIGEN_UNUSED_VARIABLE(depth);
EIGEN_UNUSED_VARIABLE(endk);
EIGEN_UNUSED_VARIABLE(i);
EIGEN_UNUSED_VARIABLE(j2);
EIGEN_UNUSED_VARIABLE(alpha);
EIGEN_UNUSED_VARIABLE(C0);
}
};
template<typename LhsScalar, typename RhsScalar, typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
struct last_row_process_16_packets<LhsScalar, RhsScalar, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs, 16> {
typedef gebp_traits<LhsScalar,RhsScalar,ConjugateLhs,ConjugateRhs> Traits;
typedef gebp_traits<RhsScalar,LhsScalar,ConjugateRhs,ConjugateLhs> SwappedTraits;
typedef typename Traits::ResScalar ResScalar;
typedef typename SwappedTraits::LhsPacket SLhsPacket;
typedef typename SwappedTraits::RhsPacket SRhsPacket;
typedef typename SwappedTraits::ResPacket SResPacket;
typedef typename SwappedTraits::AccPacket SAccPacket;
EIGEN_STRONG_INLINE void operator()(const DataMapper& res, SwappedTraits &straits, const LhsScalar* blA,
const RhsScalar* blB, Index depth, const Index endk, Index i, Index j2,
ResScalar alpha, SAccPacket &C0)
{
typedef typename unpacket_traits<typename unpacket_traits<SResPacket>::half>::half SResPacketQuarter;
typedef typename unpacket_traits<typename unpacket_traits<SLhsPacket>::half>::half SLhsPacketQuarter;
typedef typename unpacket_traits<typename unpacket_traits<SRhsPacket>::half>::half SRhsPacketQuarter;
typedef typename unpacket_traits<typename unpacket_traits<SAccPacket>::half>::half SAccPacketQuarter;
SResPacketQuarter R = res.template gatherPacket<SResPacketQuarter>(i, j2);
SResPacketQuarter alphav = pset1<SResPacketQuarter>(alpha);
if (depth - endk > 0)
{
// We have to handle the last row(s) of the rhs, which
// correspond to a half-packet
SAccPacketQuarter c0 = predux_half_dowto4(predux_half_dowto4(C0));
for (Index kk = endk; kk < depth; kk++)
{
SLhsPacketQuarter a0;
SRhsPacketQuarter b0;
straits.loadLhsUnaligned(blB, a0);
straits.loadRhs(blA, b0);
straits.madd(a0,b0,c0,b0);
blB += SwappedTraits::LhsProgress/4;
blA += 1;
}
straits.acc(c0, alphav, R);
}
else
{
straits.acc(predux_half_dowto4(predux_half_dowto4(C0)), alphav, R);
}
res.scatterPacket(i, j2, R);
}
};
template<typename LhsScalar, typename RhsScalar, typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
EIGEN_DONT_INLINE
void gebp_kernel<LhsScalar,RhsScalar,Index,DataMapper,mr,nr,ConjugateLhs,ConjugateRhs>
@ -1527,13 +1607,15 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,DataMapper,mr,nr,ConjugateLhs,Conjuga
prefetch(&blA[0]);
const RhsScalar* blB = &blockB[j2*strideB+offsetB*nr];
// The following piece of code won't work for 512 bit registers
// Moreover, if LhsProgress==8 it assumes that there is a half packet of the same size
// as nr (which is currently 4) for the return type.
// If LhsProgress is 8 or 16, it assumes that there is a
// half or quarter packet, respectively, of the same size as
// nr (which is currently 4) for the return type.
const int SResPacketHalfSize = unpacket_traits<typename unpacket_traits<SResPacket>::half>::size;
const int SResPacketQuarterSize = unpacket_traits<typename unpacket_traits<typename unpacket_traits<SResPacket>::half>::half>::size;
if ((SwappedTraits::LhsProgress % 4) == 0 &&
(SwappedTraits::LhsProgress <= 8) &&
(SwappedTraits::LhsProgress!=8 || SResPacketHalfSize==nr))
(SwappedTraits::LhsProgress<=16) &&
(SwappedTraits::LhsProgress!=8 || SResPacketHalfSize==nr) &&
(SwappedTraits::LhsProgress!=16 || SResPacketQuarterSize==nr))
{
SAccPacket C0, C1, C2, C3;
straits.initAcc(C0);
@ -1610,6 +1692,15 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,DataMapper,mr,nr,ConjugateLhs,Conjuga
}
res.scatterPacket(i, j2, R);
}
else if (SwappedTraits::LhsProgress==16)
{
// Special case where we have to first reduce the
// accumulation register C0. We specialize the block in
// template form, so that LhsProgress < 16 paths don't
// fail to compile
last_row_process_16_packets<LhsScalar, RhsScalar, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> p;
p(res, straits, blA, blB, depth, endk, i, j2,alpha, C0);
}
else
{
SResPacket R = res.template gatherPacket<SResPacket>(i, j2);