Simplify and speed up pow() by 5-6%

This commit is contained in:
Rasmus Munk Larsen 2024-11-20 12:45:00 +00:00
parent 6c6ce9d06b
commit 5610a13b77

View File

@ -2016,130 +2016,6 @@ struct fast_accurate_exp2 {
}
};
// Specialization of fast_accurate_exp2 for floats: a faster algorithm that
// computes exp2(x) for x in [-0.5;0.5] with a relative accuracy of 1 ulp.
// The minimax polynomial was generated with the Sollya tool (see sollya.org).
template <>
struct fast_accurate_exp2<float> {
  template <typename Packet>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet operator()(const Packet& x) {
    // exp2(x) is approximated by the degree 6 polynomial
    //   Q(x) = 1 + x * (C + x * P(x)),
    // where P(x) has degree 4 and is evaluated in single precision, while the
    // outer steps are carried out with extra precision using double word
    // arithmetic. C is an extra precise constant stored as a double word
    // (high part + low part).
    //
    // Sollya commands used to obtain the coefficients:
    // > n = 6;
    // > f = 2^x;
    // > interval = [-0.5;0.5];
    // > p = fpminimax(f,n,[|1,double,single...|],interval,relative,floating);

    // Coefficients of P(x), lowest order first.
    const Packet k0 = pset1<Packet>(0.2402264923f);
    const Packet k1 = pset1<Packet>(5.550328270e-2f);
    const Packet k2 = pset1<Packet>(9.618283249e-3f);
    const Packet k3 = pset1<Packet>(1.340007293e-3f);
    const Packet k4 = pset1<Packet>(1.539513905e-4f);
    // Double word representation of C.
    const Packet lin_hi = pset1<Packet>(0.6931471825f);
    const Packet lin_lo = pset1<Packet>(2.36836577e-08f);
    const Packet unity = pset1<Packet>(1.0f);

    // P(x) in working precision. The even and odd parts form two independent
    // Horner chains in x^2, which exposes some instruction level parallelism.
    const Packet x_sq = pmul(x, x);
    const Packet odd_part = pmadd(k3, x_sq, k1);
    const Packet even_part = pmadd(pmadd(k4, x_sq, k2), x_sq, k0);
    const Packet poly = pmadd(odd_part, x, even_part);

    // The remaining terms of Q(x) use double word arithmetic.
    // xp = x * P(x)
    Packet xp_hi, xp_lo;
    twoprod(poly, x, xp_hi, xp_lo);
    // inner = C + x * P(x)
    Packet inner_hi, inner_lo;
    twosum(xp_hi, xp_lo, lin_hi, lin_lo, inner_hi, inner_lo);
    // tail = x * (C + x * P(x))
    Packet tail_hi, tail_lo;
    twoprod(inner_hi, inner_lo, x, tail_hi, tail_lo);
    // result = 1 + x * (C + x * P(x)).
    // |tail_hi| <= sqrt(2)-1 < 1, so fast_twosum is sufficient for adding it
    // to unity.
    Packet result_hi, result_lo;
    fast_twosum(unity, tail_hi, result_hi, result_lo);
    return padd(result_hi, padd(tail_lo, result_lo));
  }
};
// Specialization of fast_accurate_exp2 for doubles: a faster algorithm that
// computes exp2(x) for x in [-0.5;0.5] with a relative accuracy of 1 ulp.
// The minimax polynomial was generated with the Sollya tool (see sollya.org).
template <>
struct fast_accurate_exp2<double> {
  template <typename Packet>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet operator()(const Packet& x) {
    // exp2(x) is approximated by the degree 11 polynomial
    //   Q(x) = 1 + x * (C + x * P(x)),
    // where P(x) has degree 9 (coefficients k0..k9 below) and is evaluated in
    // working precision, while the outer steps are carried out with extra
    // precision using double word arithmetic. C is an extra precise constant
    // stored as a double word (high part + low part).
    //
    // Sollya commands used to obtain the coefficients:
    // > n = 11;
    // > f = 2^x;
    // > interval = [-0.5;0.5];
    // > p = fpminimax(f,n,[|1,DD,double...|],interval,relative,floating);

    // Coefficients of P(x), lowest order first.
    const Packet k0 = pset1<Packet>(0.240226506959101332);
    const Packet k1 = pset1<Packet>(5.550410866481961247e-2);
    const Packet k2 = pset1<Packet>(9.618129107593478832e-3);
    const Packet k3 = pset1<Packet>(1.333355814685869807e-3);
    const Packet k4 = pset1<Packet>(1.540353045780084423e-4);
    const Packet k5 = pset1<Packet>(1.525273342728892877e-5);
    const Packet k6 = pset1<Packet>(1.321543498017646657e-6);
    const Packet k7 = pset1<Packet>(1.017822306737031311e-7);
    const Packet k8 = pset1<Packet>(7.073829923303358410e-9);
    const Packet k9 = pset1<Packet>(4.431642109085495276e-10);
    // Double word representation of C.
    const Packet lin_hi = pset1<Packet>(0.693147180559945286);
    const Packet lin_lo = pset1<Packet>(4.81927865669806721e-17);
    const Packet unity = pset1<Packet>(1.0);

    // P(x) in working precision. The even and odd parts form two independent
    // Horner chains in x^2, which exposes some instruction level parallelism.
    const Packet x_sq = pmul(x, x);

    Packet even_part = pmadd(k8, x_sq, k6);
    even_part = pmadd(even_part, x_sq, k4);
    even_part = pmadd(even_part, x_sq, k2);
    even_part = pmadd(even_part, x_sq, k0);

    Packet odd_part = pmadd(k9, x_sq, k7);
    odd_part = pmadd(odd_part, x_sq, k5);
    odd_part = pmadd(odd_part, x_sq, k3);
    odd_part = pmadd(odd_part, x_sq, k1);

    const Packet poly = pmadd(odd_part, x, even_part);

    // The remaining terms of Q(x) use double word arithmetic.
    // xp = x * P(x)
    Packet xp_hi, xp_lo;
    twoprod(poly, x, xp_hi, xp_lo);
    // inner = C + x * P(x)
    Packet inner_hi, inner_lo;
    twosum(xp_hi, xp_lo, lin_hi, lin_lo, inner_hi, inner_lo);
    // tail = x * (C + x * P(x))
    Packet tail_hi, tail_lo;
    twoprod(inner_hi, inner_lo, x, tail_hi, tail_lo);
    // result = 1 + x * (C + x * P(x)).
    // |tail_hi| <= sqrt(2)-1 < 1, so fast_twosum is sufficient for adding it
    // to unity.
    Packet result_hi, result_lo;
    fast_twosum(unity, tail_hi, result_hi, result_lo);
    return padd(result_hi, padd(tail_lo, result_lo));
  }
};
// This function implements the non-trivial case of pow(x,y) where x is
// positive and y is (possibly) non-integer.
// Formally, pow(x,y) = exp2(y * log2(x)), where exp2(x) is shorthand for 2^x.
@ -2186,11 +2062,18 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet generic_pow_impl(const Packet& x, c
// We now have an accurate split of f = n_z + r_z and can compute
// x^y = 2**{n_z + r_z} = exp2(r_z) * 2**{n_z}.
// Since r_z is in [-0.5;0.5], we compute the first factor to high accuracy
// using a specialized algorithm. Multiplication by the second factor can
// be done exactly using pldexp(), since it is an integer power of 2.
const Packet e_r = fast_accurate_exp2<Scalar>()(r_z);
return pldexp(e_r, n_z);
// Multiplication by the second factor can be done exactly using pldexp(), since
// it is an integer power of 2.
const Packet e_r = generic_exp2(r_z);
// Since we know that e_r is in [1/sqrt(2); sqrt(2)], we can use the fast version
// of pldexp to multiply by 2**{n_z} when |n_z| is sufficiently small.
constexpr Scalar kPldExpThresh = std::numeric_limits<Scalar>::max_exponent - 2;
const Packet pldexp_fast_unsafe = pcmp_lt(pset1<Packet>(kPldExpThresh), pabs(n_z));
if (predux_any(pldexp_fast_unsafe)) {
return pldexp(e_r, n_z);
}
return pldexp_fast(e_r, n_z);
}
// Generic implementation of pow(x,y).