optimize predux if architecture is aarch64

This commit is contained in:
Han-Kuan Chen 2021-08-25 19:18:54 +00:00 committed by Rasmus Munk Larsen
parent 4011e4d258
commit ab28419298

View File

@ -2386,12 +2386,17 @@ template<> EIGEN_STRONG_INLINE Packet2f pldexp<Packet2f>(const Packet2f& a, cons
template<> EIGEN_STRONG_INLINE Packet4f pldexp<Packet4f>(const Packet4f& a, const Packet4f& exponent)
{ return pldexp_generic(a,exponent); }
#if EIGEN_ARCH_ARM64
template<> EIGEN_STRONG_INLINE float predux<Packet2f>(const Packet2f& a) { return vaddv_f32(a); }
template<> EIGEN_STRONG_INLINE float predux<Packet4f>(const Packet4f& a) { return vaddvq_f32(a); }
#else
template<> EIGEN_STRONG_INLINE float predux<Packet2f>(const Packet2f& a) { return vget_lane_f32(vpadd_f32(a,a), 0); }
template<> EIGEN_STRONG_INLINE float predux<Packet4f>(const Packet4f& a)
{
const float32x2_t sum = vadd_f32(vget_low_f32(a), vget_high_f32(a));
return vget_lane_f32(vpadd_f32(sum, sum), 0);
}
#endif
template<> EIGEN_STRONG_INLINE int8_t predux<Packet4c>(const Packet4c& a)
{
const int8x8_t a_dup = vreinterpret_s8_s32(vdup_n_s32(a));
@ -2399,6 +2404,10 @@ template<> EIGEN_STRONG_INLINE int8_t predux<Packet4c>(const Packet4c& a)
sum = vpadd_s8(sum, sum);
return vget_lane_s8(sum, 0);
}
#if EIGEN_ARCH_ARM64
template<> EIGEN_STRONG_INLINE int8_t predux<Packet8c>(const Packet8c& a) { return vaddv_s8(a); }
template<> EIGEN_STRONG_INLINE int8_t predux<Packet16c>(const Packet16c& a) { return vaddvq_s8(a); }
#else
template<> EIGEN_STRONG_INLINE int8_t predux<Packet8c>(const Packet8c& a)
{
int8x8_t sum = vpadd_s8(a,a);
@ -2414,6 +2423,7 @@ template<> EIGEN_STRONG_INLINE int8_t predux<Packet16c>(const Packet16c& a)
sum = vpadd_s8(sum, sum);
return vget_lane_s8(sum, 0);
}
#endif
template<> EIGEN_STRONG_INLINE uint8_t predux<Packet4uc>(const Packet4uc& a)
{
const uint8x8_t a_dup = vreinterpret_u8_u32(vdup_n_u32(a));
@ -2421,6 +2431,20 @@ template<> EIGEN_STRONG_INLINE uint8_t predux<Packet4uc>(const Packet4uc& a)
sum = vpadd_u8(sum, sum);
return vget_lane_u8(sum, 0);
}
#if EIGEN_ARCH_ARM64
template<> EIGEN_STRONG_INLINE uint8_t predux<Packet8uc>(const Packet8uc& a) { return vaddv_u8(a); }
template<> EIGEN_STRONG_INLINE uint8_t predux<Packet16uc>(const Packet16uc& a) { return vaddvq_u8(a); }
template<> EIGEN_STRONG_INLINE int16_t predux<Packet4s>(const Packet4s& a) { return vaddv_s16(a); }
template<> EIGEN_STRONG_INLINE int16_t predux<Packet8s>(const Packet8s& a) { return vaddvq_s16(a); }
template<> EIGEN_STRONG_INLINE uint16_t predux<Packet4us>(const Packet4us& a) { return vaddv_u16(a); }
template<> EIGEN_STRONG_INLINE uint16_t predux<Packet8us>(const Packet8us& a) { return vaddvq_u16(a); }
template<> EIGEN_STRONG_INLINE int32_t predux<Packet2i>(const Packet2i& a) { return vaddv_s32(a); }
template<> EIGEN_STRONG_INLINE int32_t predux<Packet4i>(const Packet4i& a) { return vaddvq_s32(a); }
template<> EIGEN_STRONG_INLINE uint32_t predux<Packet2ui>(const Packet2ui& a) { return vaddv_u32(a); }
template<> EIGEN_STRONG_INLINE uint32_t predux<Packet4ui>(const Packet4ui& a) { return vaddvq_u32(a); }
template<> EIGEN_STRONG_INLINE int64_t predux<Packet2l>(const Packet2l& a) { return vaddvq_s64(a); }
template<> EIGEN_STRONG_INLINE uint64_t predux<Packet2ul>(const Packet2ul& a) { return vaddvq_u64(a); }
#else
template<> EIGEN_STRONG_INLINE uint8_t predux<Packet8uc>(const Packet8uc& a)
{
uint8x8_t sum = vpadd_u8(a,a);
@ -2476,6 +2500,7 @@ template<> EIGEN_STRONG_INLINE int64_t predux<Packet2l>(const Packet2l& a)
{ return vgetq_lane_s64(a, 0) + vgetq_lane_s64(a, 1); }
template<> EIGEN_STRONG_INLINE uint64_t predux<Packet2ul>(const Packet2ul& a)
{ return vgetq_lane_u64(a, 0) + vgetq_lane_u64(a, 1); }
#endif
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4c predux_half_dowto4(const Packet8c& a)
{
@ -2576,6 +2601,10 @@ template<> EIGEN_STRONG_INLINE uint64_t predux_mul<Packet2ul>(const Packet2ul& a
{ return vgetq_lane_u64(a, 0) * vgetq_lane_u64(a, 1); }
// min
#if EIGEN_ARCH_ARM64
template<> EIGEN_STRONG_INLINE float predux_min<Packet2f>(const Packet2f& a) { return vminv_f32(a); }
template<> EIGEN_STRONG_INLINE float predux_min<Packet4f>(const Packet4f& a) { return vminvq_f32(a); }
#else
template<> EIGEN_STRONG_INLINE float predux_min<Packet2f>(const Packet2f& a)
{ return vget_lane_f32(vpmin_f32(a,a), 0); }
template<> EIGEN_STRONG_INLINE float predux_min<Packet4f>(const Packet4f& a)
@ -2583,6 +2612,7 @@ template<> EIGEN_STRONG_INLINE float predux_min<Packet4f>(const Packet4f& a)
const float32x2_t min = vmin_f32(vget_low_f32(a), vget_high_f32(a));
return vget_lane_f32(vpmin_f32(min, min), 0);
}
#endif
template<> EIGEN_STRONG_INLINE int8_t predux_min<Packet4c>(const Packet4c& a)
{
const int8x8_t a_dup = vreinterpret_s8_s32(vdup_n_s32(a));
@ -2590,6 +2620,10 @@ template<> EIGEN_STRONG_INLINE int8_t predux_min<Packet4c>(const Packet4c& a)
min = vpmin_s8(min, min);
return vget_lane_s8(min, 0);
}
#if EIGEN_ARCH_ARM64
template<> EIGEN_STRONG_INLINE int8_t predux_min<Packet8c>(const Packet8c& a) { return vminv_s8(a); }
template<> EIGEN_STRONG_INLINE int8_t predux_min<Packet16c>(const Packet16c& a) { return vminvq_s8(a); }
#else
template<> EIGEN_STRONG_INLINE int8_t predux_min<Packet8c>(const Packet8c& a)
{
int8x8_t min = vpmin_s8(a,a);
@ -2605,6 +2639,7 @@ template<> EIGEN_STRONG_INLINE int8_t predux_min<Packet16c>(const Packet16c& a)
min = vpmin_s8(min, min);
return vget_lane_s8(min, 0);
}
#endif
template<> EIGEN_STRONG_INLINE uint8_t predux_min<Packet4uc>(const Packet4uc& a)
{
const uint8x8_t a_dup = vreinterpret_u8_u32(vdup_n_u32(a));
@ -2612,6 +2647,18 @@ template<> EIGEN_STRONG_INLINE uint8_t predux_min<Packet4uc>(const Packet4uc& a)
min = vpmin_u8(min, min);
return vget_lane_u8(min, 0);
}
#if EIGEN_ARCH_ARM64
template<> EIGEN_STRONG_INLINE uint8_t predux_min<Packet8uc>(const Packet8uc& a) { return vminv_u8(a); }
template<> EIGEN_STRONG_INLINE uint8_t predux_min<Packet16uc>(const Packet16uc& a) { return vminvq_u8(a); }
template<> EIGEN_STRONG_INLINE int16_t predux_min<Packet4s>(const Packet4s& a) { return vminv_s16(a); }
template<> EIGEN_STRONG_INLINE int16_t predux_min<Packet8s>(const Packet8s& a) { return vminvq_s16(a); }
template<> EIGEN_STRONG_INLINE uint16_t predux_min<Packet4us>(const Packet4us& a) { return vminv_u16(a); }
template<> EIGEN_STRONG_INLINE uint16_t predux_min<Packet8us>(const Packet8us& a) { return vminvq_u16(a); }
template<> EIGEN_STRONG_INLINE int32_t predux_min<Packet2i>(const Packet2i& a) { return vminv_s32(a); }
template<> EIGEN_STRONG_INLINE int32_t predux_min<Packet4i>(const Packet4i& a) { return vminvq_s32(a); }
template<> EIGEN_STRONG_INLINE uint32_t predux_min<Packet2ui>(const Packet2ui& a) { return vminv_u32(a); }
template<> EIGEN_STRONG_INLINE uint32_t predux_min<Packet4ui>(const Packet4ui& a) { return vminvq_u32(a); }
#else
template<> EIGEN_STRONG_INLINE uint8_t predux_min<Packet8uc>(const Packet8uc& a)
{
uint8x8_t min = vpmin_u8(a,a);
@ -2665,12 +2712,17 @@ template<> EIGEN_STRONG_INLINE uint32_t predux_min<Packet4ui>(const Packet4ui& a
const uint32x2_t min = vmin_u32(vget_low_u32(a), vget_high_u32(a));
return vget_lane_u32(vpmin_u32(min, min), 0);
}
#endif
template<> EIGEN_STRONG_INLINE int64_t predux_min<Packet2l>(const Packet2l& a)
{ return (std::min)(vgetq_lane_s64(a, 0), vgetq_lane_s64(a, 1)); }
template<> EIGEN_STRONG_INLINE uint64_t predux_min<Packet2ul>(const Packet2ul& a)
{ return (std::min)(vgetq_lane_u64(a, 0), vgetq_lane_u64(a, 1)); }
// max
#if EIGEN_ARCH_ARM64
template<> EIGEN_STRONG_INLINE float predux_max<Packet2f>(const Packet2f& a) { return vmaxv_f32(a); }
template<> EIGEN_STRONG_INLINE float predux_max<Packet4f>(const Packet4f& a) { return vmaxvq_f32(a); }
#else
template<> EIGEN_STRONG_INLINE float predux_max<Packet2f>(const Packet2f& a)
{ return vget_lane_f32(vpmax_f32(a,a), 0); }
template<> EIGEN_STRONG_INLINE float predux_max<Packet4f>(const Packet4f& a)
@ -2678,6 +2730,7 @@ template<> EIGEN_STRONG_INLINE float predux_max<Packet4f>(const Packet4f& a)
const float32x2_t max = vmax_f32(vget_low_f32(a), vget_high_f32(a));
return vget_lane_f32(vpmax_f32(max, max), 0);
}
#endif
template<> EIGEN_STRONG_INLINE int8_t predux_max<Packet4c>(const Packet4c& a)
{
const int8x8_t a_dup = vreinterpret_s8_s32(vdup_n_s32(a));
@ -2685,6 +2738,10 @@ template<> EIGEN_STRONG_INLINE int8_t predux_max<Packet4c>(const Packet4c& a)
max = vpmax_s8(max, max);
return vget_lane_s8(max, 0);
}
#if EIGEN_ARCH_ARM64
template<> EIGEN_STRONG_INLINE int8_t predux_max<Packet8c>(const Packet8c& a) { return vmaxv_s8(a); }
template<> EIGEN_STRONG_INLINE int8_t predux_max<Packet16c>(const Packet16c& a) { return vmaxvq_s8(a); }
#else
template<> EIGEN_STRONG_INLINE int8_t predux_max<Packet8c>(const Packet8c& a)
{
int8x8_t max = vpmax_s8(a,a);
@ -2700,6 +2757,7 @@ template<> EIGEN_STRONG_INLINE int8_t predux_max<Packet16c>(const Packet16c& a)
max = vpmax_s8(max, max);
return vget_lane_s8(max, 0);
}
#endif
template<> EIGEN_STRONG_INLINE uint8_t predux_max<Packet4uc>(const Packet4uc& a)
{
const uint8x8_t a_dup = vreinterpret_u8_u32(vdup_n_u32(a));
@ -2707,6 +2765,18 @@ template<> EIGEN_STRONG_INLINE uint8_t predux_max<Packet4uc>(const Packet4uc& a)
max = vpmax_u8(max, max);
return vget_lane_u8(max, 0);
}
#if EIGEN_ARCH_ARM64
template<> EIGEN_STRONG_INLINE uint8_t predux_max<Packet8uc>(const Packet8uc& a) { return vmaxv_u8(a); }
template<> EIGEN_STRONG_INLINE uint8_t predux_max<Packet16uc>(const Packet16uc& a) { return vmaxvq_u8(a); }
template<> EIGEN_STRONG_INLINE int16_t predux_max<Packet4s>(const Packet4s& a) { return vmaxv_s16(a); }
template<> EIGEN_STRONG_INLINE int16_t predux_max<Packet8s>(const Packet8s& a) { return vmaxvq_s16(a); }
template<> EIGEN_STRONG_INLINE uint16_t predux_max<Packet4us>(const Packet4us& a) { return vmaxv_u16(a); }
template<> EIGEN_STRONG_INLINE uint16_t predux_max<Packet8us>(const Packet8us& a) { return vmaxvq_u16(a); }
template<> EIGEN_STRONG_INLINE int32_t predux_max<Packet2i>(const Packet2i& a) { return vmaxv_s32(a); }
template<> EIGEN_STRONG_INLINE int32_t predux_max<Packet4i>(const Packet4i& a) { return vmaxvq_s32(a); }
template<> EIGEN_STRONG_INLINE uint32_t predux_max<Packet2ui>(const Packet2ui& a) { return vmaxv_u32(a); }
template<> EIGEN_STRONG_INLINE uint32_t predux_max<Packet4ui>(const Packet4ui& a) { return vmaxvq_u32(a); }
#else
template<> EIGEN_STRONG_INLINE uint8_t predux_max<Packet8uc>(const Packet8uc& a)
{
uint8x8_t max = vpmax_u8(a,a);
@ -2760,6 +2830,7 @@ template<> EIGEN_STRONG_INLINE uint32_t predux_max<Packet4ui>(const Packet4ui& a
const uint32x2_t max = vmax_u32(vget_low_u32(a), vget_high_u32(a));
return vget_lane_u32(vpmax_u32(max, max), 0);
}
#endif
template<> EIGEN_STRONG_INLINE int64_t predux_max<Packet2l>(const Packet2l& a)
{ return (std::max)(vgetq_lane_s64(a, 0), vgetq_lane_s64(a, 1)); }
template<> EIGEN_STRONG_INLINE uint64_t predux_max<Packet2ul>(const Packet2ul& a)
@ -3848,14 +3919,8 @@ template<> EIGEN_STRONG_INLINE Packet2d preverse(const Packet2d& a)
template<> EIGEN_STRONG_INLINE Packet2d pabs(const Packet2d& a) { return vabsq_f64(a); }
#if EIGEN_COMP_CLANG && defined(__apple_build_version__)
// workaround ICE, see bug 907
template<> EIGEN_STRONG_INLINE double predux<Packet2d>(const Packet2d& a)
{ return (vget_low_f64(a) + vget_high_f64(a))[0]; }
#else
template<> EIGEN_STRONG_INLINE double predux<Packet2d>(const Packet2d& a)
{ return vget_lane_f64(vget_low_f64(a) + vget_high_f64(a), 0); }
#endif
{ return vaddvq_f64(a); }
// Other reduction functions:
// mul
@ -3869,11 +3934,11 @@ template<> EIGEN_STRONG_INLINE double predux_mul<Packet2d>(const Packet2d& a)
// min
template<> EIGEN_STRONG_INLINE double predux_min<Packet2d>(const Packet2d& a)
{ return vgetq_lane_f64(vpminq_f64(a,a), 0); }
{ return vminvq_f64(a); }
// max
template<> EIGEN_STRONG_INLINE double predux_max<Packet2d>(const Packet2d& a)
{ return vgetq_lane_f64(vpmaxq_f64(a,a), 0); }
{ return vmaxvq_f64(a); }
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void
@ -4478,51 +4543,29 @@ EIGEN_STRONG_INLINE Eigen::half predux_mul<Packet4hf>(const Packet4hf& a) {
template <>
EIGEN_STRONG_INLINE Eigen::half predux_min<Packet8hf>(const Packet8hf& a) {
float16x4_t a_lo, a_hi, min;
a_lo = vget_low_f16(a);
a_hi = vget_high_f16(a);
min = vpmin_f16(a_lo, a_hi);
min = vpmin_f16(min, min);
min = vpmin_f16(min, min);
Eigen::half h;
h.x = vget_lane_f16(min, 0);
h.x = vminvq_f16(a);
return h;
}
template <>
EIGEN_STRONG_INLINE Eigen::half predux_min<Packet4hf>(const Packet4hf& a) {
Packet4hf tmp;
tmp = vpmin_f16(a, a);
tmp = vpmin_f16(tmp, tmp);
Eigen::half h;
h.x = vget_lane_f16(tmp, 0);
h.x = vminv_f16(a);
return h;
}
template <>
EIGEN_STRONG_INLINE Eigen::half predux_max<Packet8hf>(const Packet8hf& a) {
float16x4_t a_lo, a_hi, max;
a_lo = vget_low_f16(a);
a_hi = vget_high_f16(a);
max = vpmax_f16(a_lo, a_hi);
max = vpmax_f16(max, max);
max = vpmax_f16(max, max);
Eigen::half h;
h.x = vget_lane_f16(max, 0);
h.x = vmaxvq_f16(a);
return h;
}
template <>
EIGEN_STRONG_INLINE Eigen::half predux_max<Packet4hf>(const Packet4hf& a) {
Packet4hf tmp;
tmp = vpmax_f16(a, a);
tmp = vpmax_f16(tmp, tmp);
Eigen::half h;
h.x = vget_lane_f16(tmp, 0);
h.x = vmaxv_f16(a);
return h;
}