diff --git a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h index e21d3ef1c..652892e2c 100644 --- a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h +++ b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h @@ -2016,130 +2016,6 @@ struct fast_accurate_exp2 { } }; -// This specialization uses a faster algorithm to compute exp2(x) for floats -// in [-0.5;0.5] with a relative accuracy of 1 ulp. -// The minimax polynomial used was calculated using the Sollya tool. -// See sollya.org. -template <> -struct fast_accurate_exp2<float> { - template <typename Packet> - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet operator()(const Packet& x) { - // This function approximates exp2(x) by a degree 6 polynomial of the form - // Q(x) = 1 + x * (C + x * P(x)), where the degree 4 polynomial P(x) is evaluated in - // single precision, and the remaining steps are evaluated with extra precision using - // double word arithmetic. C is an extra precise constant stored as a double word. - // - // The polynomial coefficients were calculated using Sollya commands: - // > n = 6; - // > f = 2^x; - // > interval = [-0.5;0.5]; - // > p = fpminimax(f,n,[|1,double,single...|],interval,relative,floating); - - const Packet p4 = pset1<Packet>(1.539513905e-4f); - const Packet p3 = pset1<Packet>(1.340007293e-3f); - const Packet p2 = pset1<Packet>(9.618283249e-3f); - const Packet p1 = pset1<Packet>(5.550328270e-2f); - const Packet p0 = pset1<Packet>(0.2402264923f); - - const Packet C_hi = pset1<Packet>(0.6931471825f); - const Packet C_lo = pset1<Packet>(2.36836577e-08f); - const Packet one = pset1<Packet>(1.0f); - - // Evaluate P(x) in working precision. - // We evaluate even and odd parts of the polynomial separately - // to gain some instruction level parallelism.
- Packet x2 = pmul(x, x); - Packet p_even = pmadd(p4, x2, p2); - Packet p_odd = pmadd(p3, x2, p1); - p_even = pmadd(p_even, x2, p0); - Packet p = pmadd(p_odd, x, p_even); - - // Evaluate the remaining terms of Q(x) with extra precision using - // double word arithmetic. - Packet p_hi, p_lo; - // x * p(x) - twoprod(p, x, p_hi, p_lo); - // C + x * p(x) - Packet q1_hi, q1_lo; - twosum(p_hi, p_lo, C_hi, C_lo, q1_hi, q1_lo); - // x * (C + x * p(x)) - Packet q2_hi, q2_lo; - twoprod(q1_hi, q1_lo, x, q2_hi, q2_lo); - // 1 + x * (C + x * p(x)) - Packet q3_hi, q3_lo; - // Since |q2_hi| <= sqrt(2)-1 < 1, we can use fast_twosum - // for adding it to unity here. - fast_twosum(one, q2_hi, q3_hi, q3_lo); - return padd(q3_hi, padd(q2_lo, q3_lo)); - } -}; - -// in [-0.5;0.5] with a relative accuracy of 1 ulp. -// The minimax polynomial used was calculated using the Sollya tool. -// See sollya.org. -template <> -struct fast_accurate_exp2<double> { - template <typename Packet> - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet operator()(const Packet& x) { - // This function approximates exp2(x) by a degree 10 polynomial of the form - // Q(x) = 1 + x * (C + x * P(x)), where the degree 8 polynomial P(x) is evaluated in - // single precision, and the remaining steps are evaluated with extra precision using - // double word arithmetic. C is an extra precise constant stored as a double word.
- // - // The polynomial coefficients were calculated using Sollya commands: - // > n = 11; - // > f = 2^x; - // > interval = [-0.5;0.5]; - // > p = fpminimax(f,n,[|1,DD,double...|],interval,relative,floating); - - const Packet p9 = pset1<Packet>(4.431642109085495276e-10); - const Packet p8 = pset1<Packet>(7.073829923303358410e-9); - const Packet p7 = pset1<Packet>(1.017822306737031311e-7); - const Packet p6 = pset1<Packet>(1.321543498017646657e-6); - const Packet p5 = pset1<Packet>(1.525273342728892877e-5); - const Packet p4 = pset1<Packet>(1.540353045780084423e-4); - const Packet p3 = pset1<Packet>(1.333355814685869807e-3); - const Packet p2 = pset1<Packet>(9.618129107593478832e-3); - const Packet p1 = pset1<Packet>(5.550410866481961247e-2); - const Packet p0 = pset1<Packet>(0.240226506959101332); - const Packet C_hi = pset1<Packet>(0.693147180559945286); - const Packet C_lo = pset1<Packet>(4.81927865669806721e-17); - const Packet one = pset1<Packet>(1.0); - - // Evaluate P(x) in working precision. - // We evaluate even and odd parts of the polynomial separately - // to gain some instruction level parallelism. - Packet x2 = pmul(x, x); - Packet p_even = pmadd(p8, x2, p6); - Packet p_odd = pmadd(p9, x2, p7); - p_even = pmadd(p_even, x2, p4); - p_odd = pmadd(p_odd, x2, p5); - p_even = pmadd(p_even, x2, p2); - p_odd = pmadd(p_odd, x2, p3); - p_even = pmadd(p_even, x2, p0); - p_odd = pmadd(p_odd, x2, p1); - Packet p = pmadd(p_odd, x, p_even); - - // Evaluate the remaining terms of Q(x) with extra precision using - // double word arithmetic. - Packet p_hi, p_lo; - // x * p(x) - twoprod(p, x, p_hi, p_lo); - // C + x * p(x) - Packet q1_hi, q1_lo; - twosum(p_hi, p_lo, C_hi, C_lo, q1_hi, q1_lo); - // x * (C + x * p(x)) - Packet q2_hi, q2_lo; - twoprod(q1_hi, q1_lo, x, q2_hi, q2_lo); - // 1 + x * (C + x * p(x)) - Packet q3_hi, q3_lo; - // Since |q2_hi| <= sqrt(2)-1 < 1, we can use fast_twosum - // for adding it to unity here.
- fast_twosum(one, q2_hi, q3_hi, q3_lo); - return padd(q3_hi, padd(q2_lo, q3_lo)); - } -}; - // This function implements the non-trivial case of pow(x,y) where x is // positive and y is (possibly) non-integer. // Formally, pow(x,y) = exp2(y * log2(x)), where exp2(x) is shorthand for 2^x. @@ -2186,11 +2062,18 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet generic_pow_impl(const Packet& x, c // We now have an accurate split of f = n_z + r_z and can compute // x^y = 2**{n_z + r_z) = exp2(r_z) * 2**{n_z}. - // Since r_z is in [-0.5;0.5], we compute the first factor to high accuracy - // using a specialized algorithm. Multiplication by the second factor can - // be done exactly using pldexp(), since it is an integer power of 2. - const Packet e_r = fast_accurate_exp2<Scalar>()(r_z); - return pldexp(e_r, n_z); + // Multiplication by the second factor can be done exactly using pldexp(), since + // it is an integer power of 2. + const Packet e_r = generic_exp2(r_z); + + // Since r_z is in [-0.5;0.5], e_r is in [1/sqrt(2); sqrt(2)], so we can use the fast version + // of pldexp to multiply by 2**{n_z} when no |n_z| in the packet exceeds kPldExpThresh. + constexpr Scalar kPldExpThresh = std::numeric_limits<Scalar>::max_exponent - 2; + const Packet pldexp_fast_unsafe = pcmp_lt(pset1<Packet>(kPldExpThresh), pabs(n_z)); + if (predux_any(pldexp_fast_unsafe)) { + return pldexp(e_r, n_z); + } + return pldexp_fast(e_r, n_z); } // Generic implementation of pow(x,y).