Fix a few issues for AVX512. This change enables vectorized versions of log, exp, log1p, expm1 when AVX512DQ is not available.

This commit is contained in:
Rasmus Munk Larsen 2020-12-01 11:31:47 -08:00
parent 1992af3de2
commit e57281a741
3 changed files with 10 additions and 15 deletions

View File

@ -232,6 +232,7 @@ EIGEN_STRONG_INLINE __m128i Pack16To8(Packet8f rf) {
_mm256_extractf128_si256(_mm256_castps_si256(rf), 1));
}
template<> EIGEN_STRONG_INLINE Packet8f pset1<Packet8f>(const float& from) { return _mm256_set1_ps(from); }
template<> EIGEN_STRONG_INLINE Packet4d pset1<Packet4d>(const double& from) { return _mm256_set1_pd(from); }
template<> EIGEN_STRONG_INLINE Packet8i pset1<Packet8i>(const int& from) { return _mm256_set1_epi32(from); }
@ -724,13 +725,11 @@ template<> EIGEN_STRONG_INLINE Packet4d pfrexp<Packet4d>(const Packet4d& a, Pack
__m256i a_expo = _mm256_castpd_si256(pand(a, cst_exp_mask));
#ifdef EIGEN_VECTORIZE_AVX2
a_expo = _mm256_srli_epi64(a_expo, 52);
#endif
#if defined(EIGEN_VECTORIZE_AVX2) && defined(EIGEN_VECTORIZE_AVX512DQ)
exponent = _mm256_cvtepi64_pd(a_expo);
__m128i lo = _mm256_extractf128_si256(a_expo, 0);
__m128i hi = _mm256_extractf128_si256(a_expo, 1);
#else
__m128i lo = _mm256_extractf128_si256(a_expo, 0);
__m128i hi = _mm256_extractf128_si256(a_expo, 1);
#ifndef EIGEN_VECTORIZE_AVX2
lo = _mm_srli_epi64(lo, 52);
hi = _mm_srli_epi64(hi, 52);
#endif
@ -738,7 +737,6 @@ template<> EIGEN_STRONG_INLINE Packet4d pfrexp<Packet4d>(const Packet4d& a, Pack
Packet2d exponent_hi = _mm_cvtepi32_pd(vec4i_swizzle1(hi, 0, 2, 1, 3));
exponent = _mm256_insertf128_pd(exponent, exponent_lo, 0);
exponent = _mm256_insertf128_pd(exponent, exponent_hi, 1);
#endif // EIGEN_VECTORIZE_AVX512DQ
exponent = psub(exponent, cst_1022d);
const Packet4d cst_mant_mask = pset1frombits<Packet4d>(static_cast<uint64_t>(~0x7ff0000000000000ull));
return por(pand(a, cst_mant_mask), cst_half);

View File

@ -35,7 +35,6 @@ namespace internal {
#define _EIGEN_DECLARE_CONST_Packet16bf_FROM_INT(NAME, X) \
const Packet16bf p16bf_##NAME = preinterpret<Packet16bf,Packet16i>(pset1<Packet16i>(X))
#if defined(EIGEN_VECTORIZE_AVX512DQ)
template <>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet16f
plog<Packet16f>(const Packet16f& _x) {
@ -50,7 +49,6 @@ plog<Packet8d>(const Packet8d& _x) {
F16_PACKET_FUNCTION(Packet16f, Packet16h, plog)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, plog)
#endif
// Exponential function. Works by writing "x = m*log(2) + r" where
// "m = floor(x/log(2)+1/2)" and "r" is the remainder. The result is then
@ -334,7 +332,6 @@ EIGEN_STRONG_INLINE Packet8d prsqrt<Packet8d>(const Packet8d& x) {
}
#endif
#if defined(EIGEN_VECTORIZE_AVX512DQ)
template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
Packet16f plog1p<Packet16f>(const Packet16f& _x) {
return generic_plog1p(_x);
@ -350,7 +347,6 @@ Packet16f pexpm1<Packet16f>(const Packet16f& _x) {
F16_PACKET_FUNCTION(Packet16f, Packet16h, pexpm1)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pexpm1)
#endif
#endif

View File

@ -108,13 +108,11 @@ template<> struct packet_traits<float> : default_packet_traits
HasSin = EIGEN_FAST_MATH,
HasCos = EIGEN_FAST_MATH,
#if EIGEN_GNUC_AT_LEAST(5, 3) || (!EIGEN_COMP_GNUC_STRICT)
#ifdef EIGEN_VECTORIZE_AVX512DQ
HasLog = 1,
HasLog1p = 1,
HasExpm1 = 1,
HasNdtri = 1,
HasBessel = 1,
#endif
HasExp = 1,
HasSqrt = EIGEN_FAST_MATH,
HasRsqrt = EIGEN_FAST_MATH,
@ -139,9 +137,7 @@ template<> struct packet_traits<double> : default_packet_traits
size = 8,
HasHalfPacket = 1,
#if EIGEN_GNUC_AT_LEAST(5, 3) || (!EIGEN_COMP_GNUC_STRICT)
#ifdef EIGEN_VECTORIZE_AVX512DQ
HasLog = 1,
#endif
HasSqrt = EIGEN_FAST_MATH,
HasRsqrt = EIGEN_FAST_MATH,
#endif
@ -888,11 +884,16 @@ EIGEN_STRONG_INLINE Packet16f pfrexp<Packet16f>(const Packet16f& a, Packet16f& e
}
template<>
EIGEN_STRONG_INLINE Packet8d pfrexp<Packet8d>(const Packet8d& a, Packet8d& exponent){
EIGEN_STRONG_INLINE Packet8d pfrexp<Packet8d>(const Packet8d& a, Packet8d& exponent) {
const Packet8d cst_1022d = pset1<Packet8d>(1022.0);
#ifdef EIGEN_TEST_AVX512DQ
exponent = psub(_mm512_cvtepi64_pd(_mm512_srli_epi64(_mm512_castpd_si512(a), 52)), cst_1022d);
#else
exponent = psub(_mm512_cvtepi32_pd(_mm512_cvtepi64_epi32(_mm512_srli_epi64(_mm512_castpd_si512(a), 52))),
cst_1022d);
#endif
const Packet8d cst_half = pset1<Packet8d>(0.5);
const Packet8d cst_inv_mant_mask = pset1frombits<Packet8d>(static_cast<uint64_t>(~0x7ff0000000000000ull));
exponent = psub(_mm512_cvtepi64_pd(_mm512_srli_epi64(_mm512_castpd_si512(a), 52)), cst_1022d);
return por(pand(a, cst_inv_mant_mask), cst_half);
}