Add AVX plog<Packet4d> and AVX512 plog<Packet8d> ops,also unified AVX512 plog<Packet16f> op with generic api

This commit is contained in:
Guoqiang QI 2020-10-15 00:54:45 +00:00 committed by Rasmus Munk Larsen
parent af6f43d7ff
commit 4700713faf
4 changed files with 61 additions and 94 deletions

View File

@ -36,6 +36,12 @@ plog<Packet8f>(const Packet8f& _x) {
return plog_float(_x);
}
template <>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet4d
plog<Packet4d>(const Packet4d& _x) {
return plog_double(_x);
}
template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
Packet8f plog1p<Packet8f>(const Packet8f& _x) {
return generic_plog1p(_x);

View File

@ -98,6 +98,7 @@ template<> struct packet_traits<double> : default_packet_traits
HasCmp = 1,
HasDiv = 1,
HasLog = 1,
HasExp = 1,
HasSqrt = 1,
HasRsqrt = 1,
@ -215,6 +216,7 @@ template<> EIGEN_STRONG_INLINE Packet4d pset1<Packet4d>(const double& from) { re
template<> EIGEN_STRONG_INLINE Packet8i pset1<Packet8i>(const int& from) { return _mm256_set1_epi32(from); }
template<> EIGEN_STRONG_INLINE Packet8f pset1frombits<Packet8f>(unsigned int from) { return _mm256_castsi256_ps(pset1<Packet8i>(from)); }
template<> EIGEN_STRONG_INLINE Packet4d pset1frombits<Packet4d>(uint64_t from) { return _mm256_castsi256_pd(_mm256_set1_epi64x(from)); }
template<> EIGEN_STRONG_INLINE Packet8f pzero(const Packet8f& /*a*/) { return _mm256_setzero_ps(); }
template<> EIGEN_STRONG_INLINE Packet4d pzero(const Packet4d& /*a*/) { return _mm256_setzero_pd(); }
@ -686,6 +688,31 @@ template<> EIGEN_STRONG_INLINE Packet8f pfrexp<Packet8f>(const Packet8f& a, Pack
return pfrexp_float(a,exponent);
}
template<> EIGEN_STRONG_INLINE Packet4d pfrexp<Packet4d>(const Packet4d& a, Packet4d& exponent) {
const Packet4d cst_1022d = pset1<Packet4d>(1022.0);
const Packet4d cst_half = pset1<Packet4d>(0.5);
const Packet4d cst_inv_mant_mask = pset1frombits<Packet4d>(static_cast<uint64_t>(~0x7ff0000000000000ull));
__m256i a_expo = _mm256_castpd_si256(a);
#ifdef EIGEN_VECTORIZE_AVX2
a_expo = _mm256_srli_epi64(a_expo, 52);
#else
__m128i lo = _mm_srli_epi64(_mm256_extractf128_si256(a_expo, 0), 52);
__m128i hi = _mm_srli_epi64(_mm256_extractf128_si256(a_expo, 1), 52);
a_expo = _mm256_insertf128_si256(_mm256_castsi128_si256(lo), (hi), 1);
#endif
#ifdef EIGEN_VECTORIZE_AVX512DQ
// AVX512DQ finally provides an instruction for this
exponent = _mm256_cvtepi64_pd(a_expo);
#else
exponent = _mm256_set_pd(static_cast<double>(_mm256_extract_epi64(a_expo, 3)),
static_cast<double>(_mm256_extract_epi64(a_expo, 2)),
static_cast<double>(_mm256_extract_epi64(a_expo, 1)),
static_cast<double>(_mm256_extract_epi64(a_expo, 0)));
#endif
exponent = psub(exponent, cst_1022d);
return por(pand(a, cst_inv_mant_mask), cst_half);
}
template<> EIGEN_STRONG_INLINE Packet8f pldexp<Packet8f>(const Packet8f& a, const Packet8f& exponent) {
return pldexp_float(a,exponent);
}

View File

@ -35,104 +35,17 @@ namespace internal {
#define _EIGEN_DECLARE_CONST_Packet16bf_FROM_INT(NAME, X) \
const Packet16bf p16bf_##NAME = preinterpret<Packet16bf,Packet16i>(pset1<Packet16i>(X))
// Natural logarithm
// Computes log(x) as log(2^e * m) = C*e + log(m), where the constant C =log(2)
// and m is in the range [sqrt(1/2),sqrt(2)). In this range, the logarithm can
// be easily approximated by a polynomial centered on m=1 for stability.
#if defined(EIGEN_VECTORIZE_AVX512DQ)
template <>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet16f
plog<Packet16f>(const Packet16f& _x) {
Packet16f x = _x;
_EIGEN_DECLARE_CONST_Packet16f(1, 1.0f);
_EIGEN_DECLARE_CONST_Packet16f(half, 0.5f);
_EIGEN_DECLARE_CONST_Packet16f(126f, 126.0f);
return plog_float(_x);
}
_EIGEN_DECLARE_CONST_Packet16f_FROM_INT(inv_mant_mask, ~0x7f800000);
// The smallest non denormalized float number.
_EIGEN_DECLARE_CONST_Packet16f_FROM_INT(min_norm_pos, 0x00800000);
_EIGEN_DECLARE_CONST_Packet16f_FROM_INT(minus_inf, 0xff800000);
_EIGEN_DECLARE_CONST_Packet16f_FROM_INT(pos_inf, 0x7f800000);
_EIGEN_DECLARE_CONST_Packet16f_FROM_INT(nan, 0x7fc00000);
// Polynomial coefficients.
_EIGEN_DECLARE_CONST_Packet16f(cephes_SQRTHF, 0.707106781186547524f);
_EIGEN_DECLARE_CONST_Packet16f(cephes_log_p0, 7.0376836292E-2f);
_EIGEN_DECLARE_CONST_Packet16f(cephes_log_p1, -1.1514610310E-1f);
_EIGEN_DECLARE_CONST_Packet16f(cephes_log_p2, 1.1676998740E-1f);
_EIGEN_DECLARE_CONST_Packet16f(cephes_log_p3, -1.2420140846E-1f);
_EIGEN_DECLARE_CONST_Packet16f(cephes_log_p4, +1.4249322787E-1f);
_EIGEN_DECLARE_CONST_Packet16f(cephes_log_p5, -1.6668057665E-1f);
_EIGEN_DECLARE_CONST_Packet16f(cephes_log_p6, +2.0000714765E-1f);
_EIGEN_DECLARE_CONST_Packet16f(cephes_log_p7, -2.4999993993E-1f);
_EIGEN_DECLARE_CONST_Packet16f(cephes_log_p8, +3.3333331174E-1f);
_EIGEN_DECLARE_CONST_Packet16f(cephes_log_q1, -2.12194440e-4f);
_EIGEN_DECLARE_CONST_Packet16f(cephes_log_q2, 0.693359375f);
// invalid_mask is set to true when x is NaN
__mmask16 invalid_mask = _mm512_cmp_ps_mask(x, _mm512_setzero_ps(), _CMP_NGE_UQ);
__mmask16 iszero_mask = _mm512_cmp_ps_mask(x, _mm512_setzero_ps(), _CMP_EQ_OQ);
// Truncate input values to the minimum positive normal.
x = pmax(x, p16f_min_norm_pos);
// Extract the shifted exponents.
Packet16f emm0 = _mm512_cvtepi32_ps(_mm512_srli_epi32((preinterpret<Packet16i,Packet16f>(x)), 23));
Packet16f e = _mm512_sub_ps(emm0, p16f_126f);
// Set the exponents to -1, i.e. x are in the range [0.5,1).
x = _mm512_and_ps(x, p16f_inv_mant_mask);
x = _mm512_or_ps(x, p16f_half);
// part2: Shift the inputs from the range [0.5,1) to [sqrt(1/2),sqrt(2))
// and shift by -1. The values are then centered around 0, which improves
// the stability of the polynomial evaluation.
// if( x < SQRTHF ) {
// e -= 1;
// x = x + x - 1.0;
// } else { x = x - 1.0; }
__mmask16 mask = _mm512_cmp_ps_mask(x, p16f_cephes_SQRTHF, _CMP_LT_OQ);
Packet16f tmp = _mm512_mask_blend_ps(mask, _mm512_setzero_ps(), x);
x = psub(x, p16f_1);
e = psub(e, _mm512_mask_blend_ps(mask, _mm512_setzero_ps(), p16f_1));
x = padd(x, tmp);
Packet16f x2 = pmul(x, x);
Packet16f x3 = pmul(x2, x);
// Evaluate the polynomial approximant of degree 8 in three parts, probably
// to improve instruction-level parallelism.
Packet16f y, y1, y2;
y = pmadd(p16f_cephes_log_p0, x, p16f_cephes_log_p1);
y1 = pmadd(p16f_cephes_log_p3, x, p16f_cephes_log_p4);
y2 = pmadd(p16f_cephes_log_p6, x, p16f_cephes_log_p7);
y = pmadd(y, x, p16f_cephes_log_p2);
y1 = pmadd(y1, x, p16f_cephes_log_p5);
y2 = pmadd(y2, x, p16f_cephes_log_p8);
y = pmadd(y, x3, y1);
y = pmadd(y, x3, y2);
y = pmul(y, x3);
// Add the logarithm of the exponent back to the result of the interpolation.
y1 = pmul(e, p16f_cephes_log_q1);
tmp = pmul(x2, p16f_half);
y = padd(y, y1);
x = psub(x, tmp);
y2 = pmul(e, p16f_cephes_log_q2);
x = padd(x, y);
x = padd(x, y2);
__mmask16 pos_inf_mask = _mm512_cmp_ps_mask(_x,p16f_pos_inf,_CMP_EQ_OQ);
// Filter out invalid inputs, i.e.:
// - negative arg will be NAN,
// - 0 will be -INF.
// - +INF will be +INF
return _mm512_mask_blend_ps(iszero_mask,
_mm512_mask_blend_ps(invalid_mask,
_mm512_mask_blend_ps(pos_inf_mask,x,p16f_pos_inf),
p16f_nan),
p16f_minus_inf);
template <>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet8d
plog<Packet8d>(const Packet8d& _x) {
return plog_double(_x);
}
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, plog)

View File

@ -118,6 +118,9 @@ template<> struct packet_traits<double> : default_packet_traits
size = 8,
HasHalfPacket = 1,
#if EIGEN_GNUC_AT_LEAST(5, 3) || (!EIGEN_COMP_GNUC_STRICT)
#ifdef EIGEN_VECTORIZE_AVX512DQ
HasLog = 1,
#endif
HasSqrt = EIGEN_FAST_MATH,
HasRsqrt = EIGEN_FAST_MATH,
#endif
@ -184,6 +187,11 @@ EIGEN_STRONG_INLINE Packet16f pset1frombits<Packet16f>(unsigned int from) {
return _mm512_castsi512_ps(_mm512_set1_epi32(from));
}
template <>
EIGEN_STRONG_INLINE Packet8d pset1frombits<Packet8d>(uint64_t from) {
return _mm512_castsi512_pd(_mm512_set1_epi64(from));
}
template <>
EIGEN_STRONG_INLINE Packet16f pload1<Packet16f>(const float* from) {
return _mm512_broadcastss_ps(_mm_load_ps1(from));
@ -821,6 +829,20 @@ EIGEN_STRONG_INLINE Packet8d pabs(const Packet8d& a) {
_mm512_set1_epi64(0x7fffffffffffffff)));
}
template<>
EIGEN_STRONG_INLINE Packet16f pfrexp<Packet16f>(const Packet16f& a, Packet16f& exponent){
return pfrexp_float(a, exponent);
}
template<>
EIGEN_STRONG_INLINE Packet8d pfrexp<Packet8d>(const Packet8d& a, Packet8d& exponent){
const Packet8d cst_1022d = pset1<Packet8d>(1022.0);
const Packet8d cst_half = pset1<Packet8d>(0.5);
const Packet8d cst_inv_mant_mask = pset1frombits<Packet8d>(static_cast<uint64_t>(~0x7ff0000000000000ull));
exponent = psub(_mm512_cvtepi64_pd(_mm512_srli_epi64(_mm512_castpd_si512(a), 52)), cst_1022d);
return por(pand(a, cst_inv_mant_mask), cst_half);
}
#ifdef EIGEN_VECTORIZE_AVX512DQ
// AVX512F does not define _mm512_extractf32x8_ps to extract _m256 from _m512
#define EIGEN_EXTRACT_8f_FROM_16f(INPUT, OUTPUT) \
@ -1264,7 +1286,6 @@ template<> EIGEN_STRONG_INLINE Packet16f preinterpret<Packet16f,Packet16i>(const
return _mm512_castsi512_ps(a);
}
// Packet math for Eigen::half
template<> EIGEN_STRONG_INLINE Packet16h pset1<Packet16h>(const Eigen::half& from) {
return _mm256_set1_epi16(from.x);