mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-03-07 18:27:40 +08:00
Vectorize row-by-row gebp loop iterations on 16 packets as well
Signed-off-by: Gustavo Lima Chaves <gustavo.lima.chaves@intel.com> Signed-off-by: Mark D. Ryan <mark.d.ryan@intel.com>
This commit is contained in:
parent
9d318b92c6
commit
4ad359237a
@ -891,6 +891,86 @@ struct gebp_kernel
|
||||
Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0);
|
||||
};
|
||||
|
||||
template<typename LhsScalar, typename RhsScalar, typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs,
|
||||
int SwappedLhsProgress = gebp_traits<RhsScalar,LhsScalar,ConjugateRhs,ConjugateLhs>::LhsProgress>
|
||||
struct last_row_process_16_packets
|
||||
{
|
||||
typedef gebp_traits<LhsScalar,RhsScalar,ConjugateLhs,ConjugateRhs> Traits;
|
||||
typedef gebp_traits<RhsScalar,LhsScalar,ConjugateRhs,ConjugateLhs> SwappedTraits;
|
||||
|
||||
typedef typename Traits::ResScalar ResScalar;
|
||||
typedef typename SwappedTraits::LhsPacket SLhsPacket;
|
||||
typedef typename SwappedTraits::RhsPacket SRhsPacket;
|
||||
typedef typename SwappedTraits::ResPacket SResPacket;
|
||||
typedef typename SwappedTraits::AccPacket SAccPacket;
|
||||
|
||||
EIGEN_STRONG_INLINE void operator()(const DataMapper& res, SwappedTraits &straits, const LhsScalar* blA,
|
||||
const RhsScalar* blB, Index depth, const Index endk, Index i, Index j2,
|
||||
ResScalar alpha, SAccPacket &C0)
|
||||
{
|
||||
EIGEN_UNUSED_VARIABLE(res);
|
||||
EIGEN_UNUSED_VARIABLE(straits);
|
||||
EIGEN_UNUSED_VARIABLE(blA);
|
||||
EIGEN_UNUSED_VARIABLE(blB);
|
||||
EIGEN_UNUSED_VARIABLE(depth);
|
||||
EIGEN_UNUSED_VARIABLE(endk);
|
||||
EIGEN_UNUSED_VARIABLE(i);
|
||||
EIGEN_UNUSED_VARIABLE(j2);
|
||||
EIGEN_UNUSED_VARIABLE(alpha);
|
||||
EIGEN_UNUSED_VARIABLE(C0);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
template<typename LhsScalar, typename RhsScalar, typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
|
||||
struct last_row_process_16_packets<LhsScalar, RhsScalar, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs, 16> {
|
||||
typedef gebp_traits<LhsScalar,RhsScalar,ConjugateLhs,ConjugateRhs> Traits;
|
||||
typedef gebp_traits<RhsScalar,LhsScalar,ConjugateRhs,ConjugateLhs> SwappedTraits;
|
||||
|
||||
typedef typename Traits::ResScalar ResScalar;
|
||||
typedef typename SwappedTraits::LhsPacket SLhsPacket;
|
||||
typedef typename SwappedTraits::RhsPacket SRhsPacket;
|
||||
typedef typename SwappedTraits::ResPacket SResPacket;
|
||||
typedef typename SwappedTraits::AccPacket SAccPacket;
|
||||
|
||||
EIGEN_STRONG_INLINE void operator()(const DataMapper& res, SwappedTraits &straits, const LhsScalar* blA,
|
||||
const RhsScalar* blB, Index depth, const Index endk, Index i, Index j2,
|
||||
ResScalar alpha, SAccPacket &C0)
|
||||
{
|
||||
typedef typename unpacket_traits<typename unpacket_traits<SResPacket>::half>::half SResPacketQuarter;
|
||||
typedef typename unpacket_traits<typename unpacket_traits<SLhsPacket>::half>::half SLhsPacketQuarter;
|
||||
typedef typename unpacket_traits<typename unpacket_traits<SRhsPacket>::half>::half SRhsPacketQuarter;
|
||||
typedef typename unpacket_traits<typename unpacket_traits<SAccPacket>::half>::half SAccPacketQuarter;
|
||||
|
||||
SResPacketQuarter R = res.template gatherPacket<SResPacketQuarter>(i, j2);
|
||||
SResPacketQuarter alphav = pset1<SResPacketQuarter>(alpha);
|
||||
|
||||
if (depth - endk > 0)
|
||||
{
|
||||
// We have to handle the last row(s) of the rhs, which
|
||||
// correspond to a half-packet
|
||||
SAccPacketQuarter c0 = predux_half_dowto4(predux_half_dowto4(C0));
|
||||
|
||||
for (Index kk = endk; kk < depth; kk++)
|
||||
{
|
||||
SLhsPacketQuarter a0;
|
||||
SRhsPacketQuarter b0;
|
||||
straits.loadLhsUnaligned(blB, a0);
|
||||
straits.loadRhs(blA, b0);
|
||||
straits.madd(a0,b0,c0,b0);
|
||||
blB += SwappedTraits::LhsProgress/4;
|
||||
blA += 1;
|
||||
}
|
||||
straits.acc(c0, alphav, R);
|
||||
}
|
||||
else
|
||||
{
|
||||
straits.acc(predux_half_dowto4(predux_half_dowto4(C0)), alphav, R);
|
||||
}
|
||||
res.scatterPacket(i, j2, R);
|
||||
}
|
||||
};
|
||||
|
||||
template<typename LhsScalar, typename RhsScalar, typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
|
||||
EIGEN_DONT_INLINE
|
||||
void gebp_kernel<LhsScalar,RhsScalar,Index,DataMapper,mr,nr,ConjugateLhs,ConjugateRhs>
|
||||
@ -1527,13 +1607,15 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,DataMapper,mr,nr,ConjugateLhs,Conjuga
|
||||
prefetch(&blA[0]);
|
||||
const RhsScalar* blB = &blockB[j2*strideB+offsetB*nr];
|
||||
|
||||
// The following piece of code won't work for 512 bit registers
|
||||
// Moreover, if LhsProgress==8 it assumes that there is a half packet of the same size
|
||||
// as nr (which is currently 4) for the return type.
|
||||
// If LhsProgress is 8 or 16, it assumes that there is a
|
||||
// half or quarter packet, respectively, of the same size as
|
||||
// nr (which is currently 4) for the return type.
|
||||
const int SResPacketHalfSize = unpacket_traits<typename unpacket_traits<SResPacket>::half>::size;
|
||||
const int SResPacketQuarterSize = unpacket_traits<typename unpacket_traits<typename unpacket_traits<SResPacket>::half>::half>::size;
|
||||
if ((SwappedTraits::LhsProgress % 4) == 0 &&
|
||||
(SwappedTraits::LhsProgress <= 8) &&
|
||||
(SwappedTraits::LhsProgress!=8 || SResPacketHalfSize==nr))
|
||||
(SwappedTraits::LhsProgress<=16) &&
|
||||
(SwappedTraits::LhsProgress!=8 || SResPacketHalfSize==nr) &&
|
||||
(SwappedTraits::LhsProgress!=16 || SResPacketQuarterSize==nr))
|
||||
{
|
||||
SAccPacket C0, C1, C2, C3;
|
||||
straits.initAcc(C0);
|
||||
@ -1610,6 +1692,15 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,DataMapper,mr,nr,ConjugateLhs,Conjuga
|
||||
}
|
||||
res.scatterPacket(i, j2, R);
|
||||
}
|
||||
else if (SwappedTraits::LhsProgress==16)
|
||||
{
|
||||
// Special case where we have to first reduce the
|
||||
// accumulation register C0. We specialize the block in
|
||||
// template form, so that LhsProgress < 16 paths don't
|
||||
// fail to compile
|
||||
last_row_process_16_packets<LhsScalar, RhsScalar, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> p;
|
||||
p(res, straits, blA, blB, depth, endk, i, j2,alpha, C0);
|
||||
}
|
||||
else
|
||||
{
|
||||
SResPacket R = res.template gatherPacket<SResPacket>(i, j2);
|
||||
|
Loading…
Reference in New Issue
Block a user