mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-01-30 17:40:05 +08:00
Speed up GEMV on AVX-512 builds, just as done for GEBP previously.
We take advantage of smaller SIMD registers as well, in that case. Gains up to 3x for select input sizes.
This commit is contained in:
parent
665ac22cc6
commit
d4dcb71bcb
@ -15,13 +15,13 @@ namespace Eigen {
|
||||
|
||||
namespace internal {
|
||||
|
||||
enum PacketSizeType {
|
||||
PacketFull = 0,
|
||||
PacketHalf,
|
||||
PacketQuarter
|
||||
enum GEBPPacketSizeType {
|
||||
GEBPPacketFull = 0,
|
||||
GEBPPacketHalf,
|
||||
GEBPPacketQuarter
|
||||
};
|
||||
|
||||
template<typename _LhsScalar, typename _RhsScalar, bool _ConjLhs=false, bool _ConjRhs=false, int Arch=Architecture::Target, int _PacketSize=PacketFull>
|
||||
template<typename _LhsScalar, typename _RhsScalar, bool _ConjLhs=false, bool _ConjRhs=false, int Arch=Architecture::Target, int _PacketSize=GEBPPacketFull>
|
||||
class gebp_traits;
|
||||
|
||||
|
||||
@ -375,10 +375,10 @@ template <int N, typename T1, typename T2, typename T3>
|
||||
struct packet_conditional { typedef T3 type; };
|
||||
|
||||
template <typename T1, typename T2, typename T3>
|
||||
struct packet_conditional<PacketFull, T1, T2, T3> { typedef T1 type; };
|
||||
struct packet_conditional<GEBPPacketFull, T1, T2, T3> { typedef T1 type; };
|
||||
|
||||
template <typename T1, typename T2, typename T3>
|
||||
struct packet_conditional<PacketHalf, T1, T2, T3> { typedef T2 type; };
|
||||
struct packet_conditional<GEBPPacketHalf, T1, T2, T3> { typedef T2 type; };
|
||||
|
||||
#define PACKET_DECL_COND_PREFIX(prefix, name, packet_size) \
|
||||
typedef typename packet_conditional<packet_size, \
|
||||
@ -1054,8 +1054,8 @@ protected:
|
||||
#if EIGEN_ARCH_ARM64 && defined EIGEN_VECTORIZE_NEON
|
||||
|
||||
template<>
|
||||
struct gebp_traits <float, float, false, false,Architecture::NEON,PacketFull>
|
||||
: gebp_traits<float,float,false,false,Architecture::Generic,PacketFull>
|
||||
struct gebp_traits <float, float, false, false,Architecture::NEON,GEBPPacketFull>
|
||||
: gebp_traits<float,float,false,false,Architecture::Generic,GEBPPacketFull>
|
||||
{
|
||||
typedef float RhsPacket;
|
||||
|
||||
@ -1203,8 +1203,8 @@ template<typename LhsScalar, typename RhsScalar, typename Index, typename DataMa
|
||||
struct gebp_kernel
|
||||
{
|
||||
typedef gebp_traits<LhsScalar,RhsScalar,ConjugateLhs,ConjugateRhs,Architecture::Target> Traits;
|
||||
typedef gebp_traits<LhsScalar,RhsScalar,ConjugateLhs,ConjugateRhs,Architecture::Target,PacketHalf> HalfTraits;
|
||||
typedef gebp_traits<LhsScalar,RhsScalar,ConjugateLhs,ConjugateRhs,Architecture::Target,PacketQuarter> QuarterTraits;
|
||||
typedef gebp_traits<LhsScalar,RhsScalar,ConjugateLhs,ConjugateRhs,Architecture::Target,GEBPPacketHalf> HalfTraits;
|
||||
typedef gebp_traits<LhsScalar,RhsScalar,ConjugateLhs,ConjugateRhs,Architecture::Target,GEBPPacketQuarter> QuarterTraits;
|
||||
|
||||
typedef typename Traits::ResScalar ResScalar;
|
||||
typedef typename Traits::LhsPacket LhsPacket;
|
||||
|
@ -14,6 +14,54 @@ namespace Eigen {
|
||||
|
||||
namespace internal {
|
||||
|
||||
enum GEMVPacketSizeType {
|
||||
GEMVPacketFull = 0,
|
||||
GEMVPacketHalf,
|
||||
GEMVPacketQuarter
|
||||
};
|
||||
|
||||
template <int N, typename T1, typename T2, typename T3>
|
||||
struct gemv_packet_cond { typedef T3 type; };
|
||||
|
||||
template <typename T1, typename T2, typename T3>
|
||||
struct gemv_packet_cond<GEMVPacketFull, T1, T2, T3> { typedef T1 type; };
|
||||
|
||||
template <typename T1, typename T2, typename T3>
|
||||
struct gemv_packet_cond<GEMVPacketHalf, T1, T2, T3> { typedef T2 type; };
|
||||
|
||||
template<typename LhsScalar, typename RhsScalar, int _PacketSize=GEMVPacketFull>
|
||||
class gemv_traits
|
||||
{
|
||||
typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar;
|
||||
|
||||
#define PACKET_DECL_COND_PREFIX(prefix, name, packet_size) \
|
||||
typedef typename gemv_packet_cond<packet_size, \
|
||||
typename packet_traits<name ## Scalar>::type, \
|
||||
typename packet_traits<name ## Scalar>::half, \
|
||||
typename unpacket_traits<typename packet_traits<name ## Scalar>::half>::half>::type \
|
||||
prefix ## name ## Packet
|
||||
|
||||
PACKET_DECL_COND_PREFIX(_, Lhs, _PacketSize);
|
||||
PACKET_DECL_COND_PREFIX(_, Rhs, _PacketSize);
|
||||
PACKET_DECL_COND_PREFIX(_, Res, _PacketSize);
|
||||
#undef PACKET_DECL_COND_PREFIX
|
||||
|
||||
public:
|
||||
enum {
|
||||
Vectorizable = unpacket_traits<_LhsPacket>::vectorizable &&
|
||||
unpacket_traits<_RhsPacket>::vectorizable &&
|
||||
int(unpacket_traits<_LhsPacket>::size)==int(unpacket_traits<_RhsPacket>::size),
|
||||
LhsPacketSize = Vectorizable ? unpacket_traits<_LhsPacket>::size : 1,
|
||||
RhsPacketSize = Vectorizable ? unpacket_traits<_RhsPacket>::size : 1,
|
||||
ResPacketSize = Vectorizable ? unpacket_traits<_ResPacket>::size : 1
|
||||
};
|
||||
|
||||
typedef typename conditional<Vectorizable,_LhsPacket,LhsScalar>::type LhsPacket;
|
||||
typedef typename conditional<Vectorizable,_RhsPacket,RhsScalar>::type RhsPacket;
|
||||
typedef typename conditional<Vectorizable,_ResPacket,ResScalar>::type ResPacket;
|
||||
};
|
||||
|
||||
|
||||
/* Optimized col-major matrix * vector product:
|
||||
* This algorithm processes the matrix per vertical panels,
|
||||
* which are then processed horizontaly per chunck of 8*PacketSize x 1 vertical segments.
|
||||
@ -30,23 +78,23 @@ namespace internal {
|
||||
template<typename Index, typename LhsScalar, typename LhsMapper, bool ConjugateLhs, typename RhsScalar, typename RhsMapper, bool ConjugateRhs, int Version>
|
||||
struct general_matrix_vector_product<Index,LhsScalar,LhsMapper,ColMajor,ConjugateLhs,RhsScalar,RhsMapper,ConjugateRhs,Version>
|
||||
{
|
||||
typedef gemv_traits<LhsScalar,RhsScalar> Traits;
|
||||
typedef gemv_traits<LhsScalar,RhsScalar,GEMVPacketHalf> HalfTraits;
|
||||
typedef gemv_traits<LhsScalar,RhsScalar,GEMVPacketQuarter> QuarterTraits;
|
||||
|
||||
typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar;
|
||||
|
||||
enum {
|
||||
Vectorizable = packet_traits<LhsScalar>::Vectorizable && packet_traits<RhsScalar>::Vectorizable
|
||||
&& int(packet_traits<LhsScalar>::size)==int(packet_traits<RhsScalar>::size),
|
||||
LhsPacketSize = Vectorizable ? packet_traits<LhsScalar>::size : 1,
|
||||
RhsPacketSize = Vectorizable ? packet_traits<RhsScalar>::size : 1,
|
||||
ResPacketSize = Vectorizable ? packet_traits<ResScalar>::size : 1
|
||||
};
|
||||
typedef typename Traits::LhsPacket LhsPacket;
|
||||
typedef typename Traits::RhsPacket RhsPacket;
|
||||
typedef typename Traits::ResPacket ResPacket;
|
||||
|
||||
typedef typename packet_traits<LhsScalar>::type _LhsPacket;
|
||||
typedef typename packet_traits<RhsScalar>::type _RhsPacket;
|
||||
typedef typename packet_traits<ResScalar>::type _ResPacket;
|
||||
typedef typename HalfTraits::LhsPacket LhsPacketHalf;
|
||||
typedef typename HalfTraits::RhsPacket RhsPacketHalf;
|
||||
typedef typename HalfTraits::ResPacket ResPacketHalf;
|
||||
|
||||
typedef typename conditional<Vectorizable,_LhsPacket,LhsScalar>::type LhsPacket;
|
||||
typedef typename conditional<Vectorizable,_RhsPacket,RhsScalar>::type RhsPacket;
|
||||
typedef typename conditional<Vectorizable,_ResPacket,ResScalar>::type ResPacket;
|
||||
typedef typename QuarterTraits::LhsPacket LhsPacketQuarter;
|
||||
typedef typename QuarterTraits::RhsPacket RhsPacketQuarter;
|
||||
typedef typename QuarterTraits::ResPacket ResPacketQuarter;
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE static void run(
|
||||
Index rows, Index cols,
|
||||
@ -73,19 +121,33 @@ EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void general_matrix_vector_product<Index,Lhs
|
||||
|
||||
conj_helper<LhsScalar,RhsScalar,ConjugateLhs,ConjugateRhs> cj;
|
||||
conj_helper<LhsPacket,RhsPacket,ConjugateLhs,ConjugateRhs> pcj;
|
||||
conj_helper<LhsPacketHalf,RhsPacketHalf,ConjugateLhs,ConjugateRhs> pcj_half;
|
||||
conj_helper<LhsPacketQuarter,RhsPacketQuarter,ConjugateLhs,ConjugateRhs> pcj_quarter;
|
||||
|
||||
const Index lhsStride = lhs.stride();
|
||||
// TODO: for padded aligned inputs, we could enable aligned reads
|
||||
enum { LhsAlignment = Unaligned };
|
||||
enum { LhsAlignment = Unaligned,
|
||||
ResPacketSize = Traits::ResPacketSize,
|
||||
ResPacketSizeHalf = HalfTraits::ResPacketSize,
|
||||
ResPacketSizeQuarter = QuarterTraits::ResPacketSize,
|
||||
LhsPacketSize = Traits::LhsPacketSize,
|
||||
HasHalf = (int)ResPacketSizeHalf < (int)ResPacketSize,
|
||||
HasQuarter = (int)ResPacketSizeQuarter < (int)ResPacketSizeHalf
|
||||
};
|
||||
|
||||
const Index n8 = rows-8*ResPacketSize+1;
|
||||
const Index n4 = rows-4*ResPacketSize+1;
|
||||
const Index n3 = rows-3*ResPacketSize+1;
|
||||
const Index n2 = rows-2*ResPacketSize+1;
|
||||
const Index n1 = rows-1*ResPacketSize+1;
|
||||
const Index n_half = rows-1*ResPacketSizeHalf+1;
|
||||
const Index n_quarter = rows-1*ResPacketSizeQuarter+1;
|
||||
|
||||
// TODO: improve the following heuristic:
|
||||
const Index block_cols = cols<128 ? cols : (lhsStride*sizeof(LhsScalar)<32000?16:4);
|
||||
ResPacket palpha = pset1<ResPacket>(alpha);
|
||||
ResPacketHalf palpha_half = pset1<ResPacketHalf>(alpha);
|
||||
ResPacketQuarter palpha_quarter = pset1<ResPacketQuarter>(alpha);
|
||||
|
||||
for(Index j2=0; j2<cols; j2+=block_cols)
|
||||
{
|
||||
@ -190,6 +252,28 @@ EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void general_matrix_vector_product<Index,Lhs
|
||||
pstoreu(res+i+ResPacketSize*0, pmadd(c0,palpha,ploadu<ResPacket>(res+i+ResPacketSize*0)));
|
||||
i+=ResPacketSize;
|
||||
}
|
||||
if(HasHalf && i<n_half)
|
||||
{
|
||||
ResPacketHalf c0 = pset1<ResPacketHalf>(ResScalar(0));
|
||||
for(Index j=j2; j<jend; j+=1)
|
||||
{
|
||||
RhsPacketHalf b0 = pset1<RhsPacketHalf>(rhs(j,0));
|
||||
c0 = pcj_half.pmadd(lhs.template load<LhsPacketHalf,LhsAlignment>(i+0,j),b0,c0);
|
||||
}
|
||||
pstoreu(res+i+ResPacketSizeHalf*0, pmadd(c0,palpha_half,ploadu<ResPacketHalf>(res+i+ResPacketSizeHalf*0)));
|
||||
i+=ResPacketSizeHalf;
|
||||
}
|
||||
if(HasQuarter && i<n_quarter)
|
||||
{
|
||||
ResPacketQuarter c0 = pset1<ResPacketQuarter>(ResScalar(0));
|
||||
for(Index j=j2; j<jend; j+=1)
|
||||
{
|
||||
RhsPacketQuarter b0 = pset1<RhsPacketQuarter>(rhs(j,0));
|
||||
c0 = pcj_quarter.pmadd(lhs.template load<LhsPacketQuarter,LhsAlignment>(i+0,j),b0,c0);
|
||||
}
|
||||
pstoreu(res+i+ResPacketSizeQuarter*0, pmadd(c0,palpha_quarter,ploadu<ResPacketQuarter>(res+i+ResPacketSizeQuarter*0)));
|
||||
i+=ResPacketSizeQuarter;
|
||||
}
|
||||
for(;i<rows;++i)
|
||||
{
|
||||
ResScalar c0(0);
|
||||
@ -213,23 +297,24 @@ EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void general_matrix_vector_product<Index,Lhs
|
||||
template<typename Index, typename LhsScalar, typename LhsMapper, bool ConjugateLhs, typename RhsScalar, typename RhsMapper, bool ConjugateRhs, int Version>
|
||||
struct general_matrix_vector_product<Index,LhsScalar,LhsMapper,RowMajor,ConjugateLhs,RhsScalar,RhsMapper,ConjugateRhs,Version>
|
||||
{
|
||||
typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar;
|
||||
typedef gemv_traits<LhsScalar,RhsScalar> Traits;
|
||||
typedef gemv_traits<LhsScalar,RhsScalar,GEMVPacketHalf> HalfTraits;
|
||||
typedef gemv_traits<LhsScalar,RhsScalar,GEMVPacketQuarter> QuarterTraits;
|
||||
|
||||
enum {
|
||||
Vectorizable = packet_traits<LhsScalar>::Vectorizable && packet_traits<RhsScalar>::Vectorizable
|
||||
&& int(packet_traits<LhsScalar>::size)==int(packet_traits<RhsScalar>::size),
|
||||
LhsPacketSize = Vectorizable ? packet_traits<LhsScalar>::size : 1,
|
||||
RhsPacketSize = Vectorizable ? packet_traits<RhsScalar>::size : 1,
|
||||
ResPacketSize = Vectorizable ? packet_traits<ResScalar>::size : 1
|
||||
};
|
||||
typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar;
|
||||
|
||||
typedef typename packet_traits<LhsScalar>::type _LhsPacket;
|
||||
typedef typename packet_traits<RhsScalar>::type _RhsPacket;
|
||||
typedef typename packet_traits<ResScalar>::type _ResPacket;
|
||||
typedef typename Traits::LhsPacket LhsPacket;
|
||||
static const Index LhsPacketSize = Traits::LhsPacketSize;
|
||||
typedef typename Traits::RhsPacket RhsPacket;
|
||||
typedef typename Traits::ResPacket ResPacket;
|
||||
|
||||
typedef typename conditional<Vectorizable,_LhsPacket,LhsScalar>::type LhsPacket;
|
||||
typedef typename conditional<Vectorizable,_RhsPacket,RhsScalar>::type RhsPacket;
|
||||
typedef typename conditional<Vectorizable,_ResPacket,ResScalar>::type ResPacket;
|
||||
typedef typename HalfTraits::LhsPacket LhsPacketHalf;
|
||||
typedef typename HalfTraits::RhsPacket RhsPacketHalf;
|
||||
typedef typename HalfTraits::ResPacket ResPacketHalf;
|
||||
|
||||
typedef typename QuarterTraits::LhsPacket LhsPacketQuarter;
|
||||
typedef typename QuarterTraits::RhsPacket RhsPacketQuarter;
|
||||
typedef typename QuarterTraits::ResPacket ResPacketQuarter;
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE static void run(
|
||||
Index rows, Index cols,
|
||||
@ -254,6 +339,8 @@ EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void general_matrix_vector_product<Index,Lhs
|
||||
eigen_internal_assert(rhs.stride()==1);
|
||||
conj_helper<LhsScalar,RhsScalar,ConjugateLhs,ConjugateRhs> cj;
|
||||
conj_helper<LhsPacket,RhsPacket,ConjugateLhs,ConjugateRhs> pcj;
|
||||
conj_helper<LhsPacketHalf,RhsPacketHalf,ConjugateLhs,ConjugateRhs> pcj_half;
|
||||
conj_helper<LhsPacketQuarter,RhsPacketQuarter,ConjugateLhs,ConjugateRhs> pcj_quarter;
|
||||
|
||||
// TODO: fine tune the following heuristic. The rationale is that if the matrix is very large,
|
||||
// processing 8 rows at once might be counter productive wrt cache.
|
||||
@ -262,7 +349,16 @@ EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void general_matrix_vector_product<Index,Lhs
|
||||
const Index n2 = rows-1;
|
||||
|
||||
// TODO: for padded aligned inputs, we could enable aligned reads
|
||||
enum { LhsAlignment = Unaligned };
|
||||
enum { LhsAlignment = Unaligned,
|
||||
ResPacketSize = Traits::ResPacketSize,
|
||||
ResPacketSizeHalf = HalfTraits::ResPacketSize,
|
||||
ResPacketSizeQuarter = QuarterTraits::ResPacketSize,
|
||||
LhsPacketSize = Traits::LhsPacketSize,
|
||||
LhsPacketSizeHalf = HalfTraits::LhsPacketSize,
|
||||
LhsPacketSizeQuarter = QuarterTraits::LhsPacketSize,
|
||||
HasHalf = (int)ResPacketSizeHalf < (int)ResPacketSize,
|
||||
HasQuarter = (int)ResPacketSizeQuarter < (int)ResPacketSizeHalf
|
||||
};
|
||||
|
||||
Index i=0;
|
||||
for(; i<n8; i+=8)
|
||||
@ -383,6 +479,8 @@ EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void general_matrix_vector_product<Index,Lhs
|
||||
for(; i<rows; ++i)
|
||||
{
|
||||
ResPacket c0 = pset1<ResPacket>(ResScalar(0));
|
||||
ResPacketHalf c0_h = pset1<ResPacketHalf>(ResScalar(0));
|
||||
ResPacketQuarter c0_q = pset1<ResPacketQuarter>(ResScalar(0));
|
||||
Index j=0;
|
||||
for(; j+LhsPacketSize<=cols; j+=LhsPacketSize)
|
||||
{
|
||||
@ -390,6 +488,22 @@ EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void general_matrix_vector_product<Index,Lhs
|
||||
c0 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i,j),b0,c0);
|
||||
}
|
||||
ResScalar cc0 = predux(c0);
|
||||
if (HasHalf) {
|
||||
for(; j+LhsPacketSizeHalf<=cols; j+=LhsPacketSizeHalf)
|
||||
{
|
||||
RhsPacketHalf b0 = rhs.template load<RhsPacketHalf,Unaligned>(j,0);
|
||||
c0_h = pcj_half.pmadd(lhs.template load<LhsPacketHalf,LhsAlignment>(i,j),b0,c0_h);
|
||||
}
|
||||
cc0 += predux(c0_h);
|
||||
}
|
||||
if (HasQuarter) {
|
||||
for(; j+LhsPacketSizeQuarter<=cols; j+=LhsPacketSizeQuarter)
|
||||
{
|
||||
RhsPacketQuarter b0 = rhs.template load<RhsPacketQuarter,Unaligned>(j,0);
|
||||
c0_q = pcj_quarter.pmadd(lhs.template load<LhsPacketQuarter,LhsAlignment>(i,j),b0,c0_q);
|
||||
}
|
||||
cc0 += predux(c0_q);
|
||||
}
|
||||
for(; j<cols; ++j)
|
||||
{
|
||||
cc0 += cj.pmul(lhs(i,j), rhs(j,0));
|
||||
|
Loading…
Reference in New Issue
Block a user