Accurate pow, part 2. This change adds specializations of log2 and exp2 for double that

make pow<double> accurate the 1 ULP. Speed for AVX-512 is within 0.5% of the currect
implementation.
This commit is contained in:
Rasmus Munk Larsen 2021-02-22 16:06:00 -08:00
parent 2ac0b78739
commit 88d4c6d4c8

View File

@ -484,7 +484,7 @@ Packet pexp_float(const Packet _x)
y1 = pmadd(y1, r, cst_cephes_exp_p5);
y = pmadd(y, r3, y1);
y = pmadd(y, r2, y2);
// Return 2^m * exp(r).
// TODO: replace pldexp with faster implementation since y in [-1, 1).
return pmax(pldexp(y,m), _x);
@ -1003,6 +1003,20 @@ EIGEN_STRONG_INLINE
fast_twosum(r_hi, s, s_hi, s_lo);
}
// This is a version of twosum for adding a floating point number x to
// double word number {y_hi, y_lo} number, with the assumption
// that |x| >= |y_hi|.
template<typename Packet>
EIGEN_STRONG_INLINE
void fast_twosum(const Packet& x,
const Packet& y_hi, const Packet& y_lo,
Packet& s_hi, Packet& s_lo) {
Packet r_hi, r_lo;
fast_twosum(x, y_hi, r_hi, r_lo);
const Packet s = padd(y_lo, r_lo);
fast_twosum(r_hi, s, s_hi, s_lo);
}
// This function implements the multiplication of a double word
// number represented by {x_hi, x_lo} by a floating point number y.
// It returns the result as a pair {p_hi, p_lo} such that
@ -1024,6 +1038,50 @@ void twoprod(const Packet& x_hi, const Packet& x_lo, const Packet& y,
fast_twosum(t_hi, t_lo2, p_hi, p_lo);
}
// This function implements the multiplication of two double word
// numbers represented by {x_hi, x_lo} and {y_hi, y_lo}.
// It returns the result as a pair {p_hi, p_lo} such that
// (x_hi + x_lo) * (y_hi + y_lo) = p_hi + p_lo holds with a relative error
// of less than 2*2^{-2p}, where p is the number of significand bit
// in the floating point type.
template<typename Packet>
EIGEN_STRONG_INLINE
void twoprod(const Packet& x_hi, const Packet& x_lo,
const Packet& y_hi, const Packet& y_lo,
Packet& p_hi, Packet& p_lo) {
Packet p_hi_hi, p_hi_lo;
twoprod(x_hi, x_lo, y_hi, p_hi_hi, p_hi_lo);
Packet p_lo_hi, p_lo_lo;
twoprod(x_hi, x_lo, y_lo, p_lo_hi, p_lo_lo);
fast_twosum(p_hi_hi, p_hi_lo, p_lo_hi, p_lo_lo, p_hi, p_lo);
}
// This function computes the reciprocal of a floating point number
// with extra precision and returns the result as a double word.
template <typename Packet>
void doubleword_reciprocal(const Packet& x, Packet& recip_hi, Packet& recip_lo) {
typedef typename unpacket_traits<Packet>::type Scalar;
// 1. Approximate the reciprocal as the reciprocal of the high order element.
Packet approx_recip = prsqrt(x);
approx_recip = pmul(approx_recip, approx_recip);
// 2. Run one step of Newton-Raphson iteration in double word arithmetic
// to get the bottom half. The NR iteration for reciprocal of 'a' is
// x_{i+1} = x_i * (2 - a * x_i)
// -a*x_i
Packet t1_hi, t1_lo;
twoprod(pnegate(x), approx_recip, t1_hi, t1_lo);
// 2 - a*x_i
Packet t2_hi, t2_lo;
fast_twosum(pset1<Packet>(Scalar(2)), t1_hi, t2_hi, t2_lo);
Packet t3_hi, t3_lo;
fast_twosum(t2_hi, padd(t2_lo, t1_lo), t3_hi, t3_lo);
// x_i * (2 - a * x_i)
twoprod(t3_hi, t3_lo, approx_recip, recip_hi, recip_lo);
}
// This function computes log2(x) and returns the result as a double word.
template <typename Scalar>
struct accurate_log2 {
@ -1115,6 +1173,101 @@ struct accurate_log2<float> {
}
};
// This specialization uses a more accurate algorithm to compute log2(x) for
// floats in [1/sqrt(2);sqrt(2)] with a relative accuracy of ~1.27e-18.
// This additional accuracy is needed to counter the error-magnification
// inherent in multiplying by a potentially large exponent in pow(x,y).
// The minimax polynomial used was calculated using the Sollya tool.
// See sollya.org.
template <>
struct accurate_log2<double> {
template <typename Packet>
EIGEN_STRONG_INLINE
void operator()(const Packet& x, Packet& log2_x_hi, Packet& log2_x_lo) {
// We use a transformation of variables:
// r = c * (x-1) / (x+1),
// such that
// log2(x) = log2((1 + r/c) / (1 - r/c)) = f(r).
// The function f(r) can be approximated well using an odd polynomial
// of the form
// P(r) = ((Q(r^2) * r^2 + C) * r^2 + 1) * r,
// For the implementation of log2<double> here, Q is of degree 6 with
// coefficient represented in working precision (double), while C is a
// constant represented in extra precision as a double word to achieve
// full accuracy.
//
// The polynomial coefficients were computed by the Sollya script:
//
// c = 2 / log(2);
// trans = c * (x-1)/(x+1);
// itrans = (1+x/c)/(1-x/c);
// interval=[trans(sqrt(0.5)); trans(sqrt(2))];
// print(interval);
// f = log2(itrans(x));
// p=fpminimax(f,[|1,3,5,7,9,11,13,15,17|],[|1,DD,double...|],interval,relative,floating);
const Packet q12 = pset1<Packet>(2.87074255468000586e-9);
const Packet q10 = pset1<Packet>(2.38957980901884082e-8);
const Packet q8 = pset1<Packet>(2.31032094540014656e-7);
const Packet q6 = pset1<Packet>(2.27279857398537278e-6);
const Packet q4 = pset1<Packet>(2.31271023278625638e-5);
const Packet q2 = pset1<Packet>(2.47556738444535513e-4);
const Packet q0 = pset1<Packet>(2.88543873228900172e-3);
const Packet C_hi = pset1<Packet>(0.0400377511598501157);
const Packet C_lo = pset1<Packet>(-4.77726582251425391e-19);
const Packet one = pset1<Packet>(1.0);
const Packet cst_2_log2e_hi = pset1<Packet>(2.88539008177792677);
const Packet cst_2_log2e_lo = pset1<Packet>(4.07660016854549667e-17);
// c * (x - 1)
Packet num_hi, num_lo;
twoprod(cst_2_log2e_hi, cst_2_log2e_lo, psub(x, one), num_hi, num_lo);
// TODO(rmlarsen): Investigate if using the division algorithm by
// Muller et al. is faster/more accurate.
// 1 / (x + 1)
Packet denom_hi, denom_lo;
doubleword_reciprocal(padd(x, one), denom_hi, denom_lo);
// r = c * (x-1) / (x+1),
Packet r_hi, r_lo;
twoprod(num_hi, num_lo, denom_hi, denom_lo, r_hi, r_lo);
// r2 = r * r
Packet r2_hi, r2_lo;
twoprod(r_hi, r_lo, r_hi, r_lo, r2_hi, r2_lo);
// r4 = r2 * r2
Packet r4_hi, r4_lo;
twoprod(r2_hi, r2_lo, r2_hi, r2_lo, r4_hi, r4_lo);
// Evaluate Q(r^2) in working precision. We evaluate it in two parts
// (even and odd in r^2) to improve instruction level parallelism.
Packet q_even = pmadd(q12, r4_hi, q8);
Packet q_odd = pmadd(q10, r4_hi, q6);
q_even = pmadd(q_even, r4_hi, q4);
q_odd = pmadd(q_odd, r4_hi, q2);
q_even = pmadd(q_even, r4_hi, q0);
Packet q = pmadd(q_odd, r2_hi, q_even);
// Now evaluate the low order terms of P(x) in double word precision.
// In the following, due to the increasing magnitude of the coefficients
// and r being constrained to [-0.5, 0.5] we can use fast_twosum instead
// of the slower twosum.
// Q(r^2) * r^2
Packet p_hi, p_lo;
twoprod(r2_hi, r2_lo, q, p_hi, p_lo);
// Q(r^2) * r^2 + C
Packet p1_hi, p1_lo;
fast_twosum(C_hi, C_lo, p_hi, p_lo, p1_hi, p1_lo);
// (Q(r^2) * r^2 + C) * r^2
Packet p2_hi, p2_lo;
twoprod(r2_hi, r2_lo, p1_hi, p1_lo, p2_hi, p2_lo);
// ((Q(r^2) * r^2 + C) * r^2 + 1)
Packet p3_hi, p3_lo;
fast_twosum(one, p2_hi, p2_lo, p3_hi, p3_lo);
// log(z) ~= ((Q(r^2) * r^2 + C) * r^2 + 1) * r
twoprod(p3_hi, p3_lo, r_hi, r_lo, log2_x_hi, log2_x_lo);
}
};
// This function computes exp2(x) (i.e. 2**x).
template <typename Scalar>
struct fast_accurate_exp2 {
@ -1161,8 +1314,75 @@ struct fast_accurate_exp2<float> {
// to gain some instruction level parallelism.
Packet x2 = pmul(x,x);
Packet p_even = pmadd(p4, x2, p2);
p_even = pmadd(p_even, x2, p0);
Packet p_odd = pmadd(p3, x2, p1);
p_even = pmadd(p_even, x2, p0);
Packet p = pmadd(p_odd, x, p_even);
// Evaluate the remaining terms of Q(x) with extra precision using
// double word arithmetic.
Packet p_hi, p_lo;
// x * p(x)
twoprod(p, x, p_hi, p_lo);
// C + x * p(x)
Packet q1_hi, q1_lo;
twosum(p_hi, p_lo, C_hi, C_lo, q1_hi, q1_lo);
// x * (C + x * p(x))
Packet q2_hi, q2_lo;
twoprod(q1_hi, q1_lo, x, q2_hi, q2_lo);
// 1 + x * (C + x * p(x))
Packet q3_hi, q3_lo;
// Since |q2_hi| <= sqrt(2)-1 < 1, we can use fast_twosum
// for adding it to unity here.
fast_twosum(one, q2_hi, q3_hi, q3_lo);
return padd(q3_hi, padd(q2_lo, q3_lo));
}
};
// in [-0.5;0.5] with a relative accuracy of 1 ulp.
// The minimax polynomial used was calculated using the Sollya tool.
// See sollya.org.
template <>
struct fast_accurate_exp2<double> {
template <typename Packet>
EIGEN_STRONG_INLINE
Packet operator()(const Packet& x) {
// This function approximates exp2(x) by a degree 10 polynomial of the form
// Q(x) = 1 + x * (C + x * P(x)), where the degree 8 polynomial P(x) is evaluated in
// single precision, and the remaining steps are evaluated with extra precision using
// double word arithmetic. C is an extra precise constant stored as a double word.
//
// The polynomial coefficients were calculated using Sollya commands:
// > n = 11;
// > f = 2^x;
// > interval = [-0.5;0.5];
// > p = fpminimax(f,n,[|1,DD,double...|],interval,relative,floating);
const Packet p9 = pset1<Packet>(4.431642109085495276e-10);
const Packet p8 = pset1<Packet>(7.073829923303358410e-9);
const Packet p7 = pset1<Packet>(1.017822306737031311e-7);
const Packet p6 = pset1<Packet>(1.321543498017646657e-6);
const Packet p5 = pset1<Packet>(1.525273342728892877e-5);
const Packet p4 = pset1<Packet>(1.540353045780084423e-4);
const Packet p3 = pset1<Packet>(1.333355814685869807e-3);
const Packet p2 = pset1<Packet>(9.618129107593478832e-3);
const Packet p1 = pset1<Packet>(5.550410866481961247e-2);
const Packet p0 = pset1<Packet>(0.240226506959101332);
const Packet C_hi = pset1<Packet>(0.693147180559945286);
const Packet C_lo = pset1<Packet>(4.81927865669806721e-17);
const Packet one = pset1<Packet>(1.0f);
// Evaluate P(x) in working precision.
// We evaluate even and odd parts of the polynomial separately
// to gain some instruction level parallelism.
Packet x2 = pmul(x,x);
Packet p_even = pmadd(p8, x2, p6);
Packet p_odd = pmadd(p9, x2, p7);
p_even = pmadd(p_even, x2, p4);
p_odd = pmadd(p_odd, x2, p5);
p_even = pmadd(p_even, x2, p2);
p_odd = pmadd(p_odd, x2, p3);
p_even = pmadd(p_even, x2, p0);
p_odd = pmadd(p_odd, x2, p1);
Packet p = pmadd(p_odd, x, p_even);
// Evaluate the remaining terms of Q(x) with extra precision using
@ -1234,7 +1454,6 @@ EIGEN_STRONG_INLINE Packet generic_pow_impl(const Packet& x, const Packet& y) {
// Since r_z is in [-0.5;0.5], we compute the first factor to high accuracy
// using a specialized algorithm. Multiplication by the second factor can
// be done exactly using pldexp(), since it is an integer power of 2.
// Packet e_r = fast_accurate_exp2<Scalar>()(r_z);
const Packet e_r = fast_accurate_exp2<Scalar>()(r_z);
return pldexp(e_r, n_z);
}