mirror of
https://gitlab.com/libeigen/eigen.git
synced 2024-12-21 07:19:46 +08:00
Accurate pow, part 2. This change adds specializations of log2 and exp2 for double that
make pow<double> accurate to 1 ULP. Speed for AVX-512 is within 0.5% of the current implementation.
This commit is contained in:
parent
2ac0b78739
commit
88d4c6d4c8
@ -484,7 +484,7 @@ Packet pexp_float(const Packet _x)
|
||||
y1 = pmadd(y1, r, cst_cephes_exp_p5);
|
||||
y = pmadd(y, r3, y1);
|
||||
y = pmadd(y, r2, y2);
|
||||
|
||||
|
||||
// Return 2^m * exp(r).
|
||||
// TODO: replace pldexp with faster implementation since y in [-1, 1).
|
||||
return pmax(pldexp(y,m), _x);
|
||||
@ -1003,6 +1003,20 @@ EIGEN_STRONG_INLINE
|
||||
fast_twosum(r_hi, s, s_hi, s_lo);
|
||||
}
|
||||
|
||||
// This is a version of twosum for adding a floating point number x to
|
||||
// double word number {y_hi, y_lo} number, with the assumption
|
||||
// that |x| >= |y_hi|.
|
||||
template<typename Packet>
|
||||
EIGEN_STRONG_INLINE
|
||||
void fast_twosum(const Packet& x,
|
||||
const Packet& y_hi, const Packet& y_lo,
|
||||
Packet& s_hi, Packet& s_lo) {
|
||||
Packet r_hi, r_lo;
|
||||
fast_twosum(x, y_hi, r_hi, r_lo);
|
||||
const Packet s = padd(y_lo, r_lo);
|
||||
fast_twosum(r_hi, s, s_hi, s_lo);
|
||||
}
|
||||
|
||||
// This function implements the multiplication of a double word
|
||||
// number represented by {x_hi, x_lo} by a floating point number y.
|
||||
// It returns the result as a pair {p_hi, p_lo} such that
|
||||
@ -1024,6 +1038,50 @@ void twoprod(const Packet& x_hi, const Packet& x_lo, const Packet& y,
|
||||
fast_twosum(t_hi, t_lo2, p_hi, p_lo);
|
||||
}
|
||||
|
||||
// This function implements the multiplication of two double word
|
||||
// numbers represented by {x_hi, x_lo} and {y_hi, y_lo}.
|
||||
// It returns the result as a pair {p_hi, p_lo} such that
|
||||
// (x_hi + x_lo) * (y_hi + y_lo) = p_hi + p_lo holds with a relative error
|
||||
// of less than 2*2^{-2p}, where p is the number of significand bit
|
||||
// in the floating point type.
|
||||
template<typename Packet>
|
||||
EIGEN_STRONG_INLINE
|
||||
void twoprod(const Packet& x_hi, const Packet& x_lo,
|
||||
const Packet& y_hi, const Packet& y_lo,
|
||||
Packet& p_hi, Packet& p_lo) {
|
||||
Packet p_hi_hi, p_hi_lo;
|
||||
twoprod(x_hi, x_lo, y_hi, p_hi_hi, p_hi_lo);
|
||||
Packet p_lo_hi, p_lo_lo;
|
||||
twoprod(x_hi, x_lo, y_lo, p_lo_hi, p_lo_lo);
|
||||
fast_twosum(p_hi_hi, p_hi_lo, p_lo_hi, p_lo_lo, p_hi, p_lo);
|
||||
}
|
||||
|
||||
// This function computes the reciprocal of a floating point number
|
||||
// with extra precision and returns the result as a double word.
|
||||
template <typename Packet>
|
||||
void doubleword_reciprocal(const Packet& x, Packet& recip_hi, Packet& recip_lo) {
|
||||
typedef typename unpacket_traits<Packet>::type Scalar;
|
||||
// 1. Approximate the reciprocal as the reciprocal of the high order element.
|
||||
Packet approx_recip = prsqrt(x);
|
||||
approx_recip = pmul(approx_recip, approx_recip);
|
||||
|
||||
// 2. Run one step of Newton-Raphson iteration in double word arithmetic
|
||||
// to get the bottom half. The NR iteration for reciprocal of 'a' is
|
||||
// x_{i+1} = x_i * (2 - a * x_i)
|
||||
|
||||
// -a*x_i
|
||||
Packet t1_hi, t1_lo;
|
||||
twoprod(pnegate(x), approx_recip, t1_hi, t1_lo);
|
||||
// 2 - a*x_i
|
||||
Packet t2_hi, t2_lo;
|
||||
fast_twosum(pset1<Packet>(Scalar(2)), t1_hi, t2_hi, t2_lo);
|
||||
Packet t3_hi, t3_lo;
|
||||
fast_twosum(t2_hi, padd(t2_lo, t1_lo), t3_hi, t3_lo);
|
||||
// x_i * (2 - a * x_i)
|
||||
twoprod(t3_hi, t3_lo, approx_recip, recip_hi, recip_lo);
|
||||
}
|
||||
|
||||
|
||||
// This function computes log2(x) and returns the result as a double word.
|
||||
template <typename Scalar>
|
||||
struct accurate_log2 {
|
||||
@ -1115,6 +1173,101 @@ struct accurate_log2<float> {
|
||||
}
|
||||
};
|
||||
|
||||
// This specialization uses a more accurate algorithm to compute log2(x) for
|
||||
// floats in [1/sqrt(2);sqrt(2)] with a relative accuracy of ~1.27e-18.
|
||||
// This additional accuracy is needed to counter the error-magnification
|
||||
// inherent in multiplying by a potentially large exponent in pow(x,y).
|
||||
// The minimax polynomial used was calculated using the Sollya tool.
|
||||
// See sollya.org.
|
||||
|
||||
template <>
|
||||
struct accurate_log2<double> {
|
||||
template <typename Packet>
|
||||
EIGEN_STRONG_INLINE
|
||||
void operator()(const Packet& x, Packet& log2_x_hi, Packet& log2_x_lo) {
|
||||
// We use a transformation of variables:
|
||||
// r = c * (x-1) / (x+1),
|
||||
// such that
|
||||
// log2(x) = log2((1 + r/c) / (1 - r/c)) = f(r).
|
||||
// The function f(r) can be approximated well using an odd polynomial
|
||||
// of the form
|
||||
// P(r) = ((Q(r^2) * r^2 + C) * r^2 + 1) * r,
|
||||
// For the implementation of log2<double> here, Q is of degree 6 with
|
||||
// coefficient represented in working precision (double), while C is a
|
||||
// constant represented in extra precision as a double word to achieve
|
||||
// full accuracy.
|
||||
//
|
||||
// The polynomial coefficients were computed by the Sollya script:
|
||||
//
|
||||
// c = 2 / log(2);
|
||||
// trans = c * (x-1)/(x+1);
|
||||
// itrans = (1+x/c)/(1-x/c);
|
||||
// interval=[trans(sqrt(0.5)); trans(sqrt(2))];
|
||||
// print(interval);
|
||||
// f = log2(itrans(x));
|
||||
// p=fpminimax(f,[|1,3,5,7,9,11,13,15,17|],[|1,DD,double...|],interval,relative,floating);
|
||||
const Packet q12 = pset1<Packet>(2.87074255468000586e-9);
|
||||
const Packet q10 = pset1<Packet>(2.38957980901884082e-8);
|
||||
const Packet q8 = pset1<Packet>(2.31032094540014656e-7);
|
||||
const Packet q6 = pset1<Packet>(2.27279857398537278e-6);
|
||||
const Packet q4 = pset1<Packet>(2.31271023278625638e-5);
|
||||
const Packet q2 = pset1<Packet>(2.47556738444535513e-4);
|
||||
const Packet q0 = pset1<Packet>(2.88543873228900172e-3);
|
||||
const Packet C_hi = pset1<Packet>(0.0400377511598501157);
|
||||
const Packet C_lo = pset1<Packet>(-4.77726582251425391e-19);
|
||||
const Packet one = pset1<Packet>(1.0);
|
||||
|
||||
const Packet cst_2_log2e_hi = pset1<Packet>(2.88539008177792677);
|
||||
const Packet cst_2_log2e_lo = pset1<Packet>(4.07660016854549667e-17);
|
||||
// c * (x - 1)
|
||||
Packet num_hi, num_lo;
|
||||
twoprod(cst_2_log2e_hi, cst_2_log2e_lo, psub(x, one), num_hi, num_lo);
|
||||
// TODO(rmlarsen): Investigate if using the division algorithm by
|
||||
// Muller et al. is faster/more accurate.
|
||||
// 1 / (x + 1)
|
||||
Packet denom_hi, denom_lo;
|
||||
doubleword_reciprocal(padd(x, one), denom_hi, denom_lo);
|
||||
// r = c * (x-1) / (x+1),
|
||||
Packet r_hi, r_lo;
|
||||
twoprod(num_hi, num_lo, denom_hi, denom_lo, r_hi, r_lo);
|
||||
// r2 = r * r
|
||||
Packet r2_hi, r2_lo;
|
||||
twoprod(r_hi, r_lo, r_hi, r_lo, r2_hi, r2_lo);
|
||||
// r4 = r2 * r2
|
||||
Packet r4_hi, r4_lo;
|
||||
twoprod(r2_hi, r2_lo, r2_hi, r2_lo, r4_hi, r4_lo);
|
||||
|
||||
// Evaluate Q(r^2) in working precision. We evaluate it in two parts
|
||||
// (even and odd in r^2) to improve instruction level parallelism.
|
||||
Packet q_even = pmadd(q12, r4_hi, q8);
|
||||
Packet q_odd = pmadd(q10, r4_hi, q6);
|
||||
q_even = pmadd(q_even, r4_hi, q4);
|
||||
q_odd = pmadd(q_odd, r4_hi, q2);
|
||||
q_even = pmadd(q_even, r4_hi, q0);
|
||||
Packet q = pmadd(q_odd, r2_hi, q_even);
|
||||
|
||||
// Now evaluate the low order terms of P(x) in double word precision.
|
||||
// In the following, due to the increasing magnitude of the coefficients
|
||||
// and r being constrained to [-0.5, 0.5] we can use fast_twosum instead
|
||||
// of the slower twosum.
|
||||
// Q(r^2) * r^2
|
||||
Packet p_hi, p_lo;
|
||||
twoprod(r2_hi, r2_lo, q, p_hi, p_lo);
|
||||
// Q(r^2) * r^2 + C
|
||||
Packet p1_hi, p1_lo;
|
||||
fast_twosum(C_hi, C_lo, p_hi, p_lo, p1_hi, p1_lo);
|
||||
// (Q(r^2) * r^2 + C) * r^2
|
||||
Packet p2_hi, p2_lo;
|
||||
twoprod(r2_hi, r2_lo, p1_hi, p1_lo, p2_hi, p2_lo);
|
||||
// ((Q(r^2) * r^2 + C) * r^2 + 1)
|
||||
Packet p3_hi, p3_lo;
|
||||
fast_twosum(one, p2_hi, p2_lo, p3_hi, p3_lo);
|
||||
|
||||
// log(z) ~= ((Q(r^2) * r^2 + C) * r^2 + 1) * r
|
||||
twoprod(p3_hi, p3_lo, r_hi, r_lo, log2_x_hi, log2_x_lo);
|
||||
}
|
||||
};
|
||||
|
||||
// This function computes exp2(x) (i.e. 2**x).
|
||||
template <typename Scalar>
|
||||
struct fast_accurate_exp2 {
|
||||
@ -1161,8 +1314,75 @@ struct fast_accurate_exp2<float> {
|
||||
// to gain some instruction level parallelism.
|
||||
Packet x2 = pmul(x,x);
|
||||
Packet p_even = pmadd(p4, x2, p2);
|
||||
p_even = pmadd(p_even, x2, p0);
|
||||
Packet p_odd = pmadd(p3, x2, p1);
|
||||
p_even = pmadd(p_even, x2, p0);
|
||||
Packet p = pmadd(p_odd, x, p_even);
|
||||
|
||||
// Evaluate the remaining terms of Q(x) with extra precision using
|
||||
// double word arithmetic.
|
||||
Packet p_hi, p_lo;
|
||||
// x * p(x)
|
||||
twoprod(p, x, p_hi, p_lo);
|
||||
// C + x * p(x)
|
||||
Packet q1_hi, q1_lo;
|
||||
twosum(p_hi, p_lo, C_hi, C_lo, q1_hi, q1_lo);
|
||||
// x * (C + x * p(x))
|
||||
Packet q2_hi, q2_lo;
|
||||
twoprod(q1_hi, q1_lo, x, q2_hi, q2_lo);
|
||||
// 1 + x * (C + x * p(x))
|
||||
Packet q3_hi, q3_lo;
|
||||
// Since |q2_hi| <= sqrt(2)-1 < 1, we can use fast_twosum
|
||||
// for adding it to unity here.
|
||||
fast_twosum(one, q2_hi, q3_hi, q3_lo);
|
||||
return padd(q3_hi, padd(q2_lo, q3_lo));
|
||||
}
|
||||
};
|
||||
|
||||
// This specialization computes exp2(x) for doubles
// in [-0.5;0.5] with a relative accuracy of 1 ulp.
|
||||
// The minimax polynomial used was calculated using the Sollya tool.
|
||||
// See sollya.org.
|
||||
template <>
|
||||
struct fast_accurate_exp2<double> {
|
||||
template <typename Packet>
|
||||
EIGEN_STRONG_INLINE
|
||||
Packet operator()(const Packet& x) {
|
||||
// This function approximates exp2(x) by a degree 11 polynomial of the form
|
||||
// Q(x) = 1 + x * (C + x * P(x)), where the degree 9 polynomial P(x) is evaluated in
|
||||
// working precision, and the remaining steps are evaluated with extra precision using
|
||||
// double word arithmetic. C is an extra precise constant stored as a double word.
|
||||
//
|
||||
// The polynomial coefficients were calculated using Sollya commands:
|
||||
// > n = 11;
|
||||
// > f = 2^x;
|
||||
// > interval = [-0.5;0.5];
|
||||
// > p = fpminimax(f,n,[|1,DD,double...|],interval,relative,floating);
|
||||
|
||||
const Packet p9 = pset1<Packet>(4.431642109085495276e-10);
|
||||
const Packet p8 = pset1<Packet>(7.073829923303358410e-9);
|
||||
const Packet p7 = pset1<Packet>(1.017822306737031311e-7);
|
||||
const Packet p6 = pset1<Packet>(1.321543498017646657e-6);
|
||||
const Packet p5 = pset1<Packet>(1.525273342728892877e-5);
|
||||
const Packet p4 = pset1<Packet>(1.540353045780084423e-4);
|
||||
const Packet p3 = pset1<Packet>(1.333355814685869807e-3);
|
||||
const Packet p2 = pset1<Packet>(9.618129107593478832e-3);
|
||||
const Packet p1 = pset1<Packet>(5.550410866481961247e-2);
|
||||
const Packet p0 = pset1<Packet>(0.240226506959101332);
|
||||
const Packet C_hi = pset1<Packet>(0.693147180559945286);
|
||||
const Packet C_lo = pset1<Packet>(4.81927865669806721e-17);
|
||||
const Packet one = pset1<Packet>(1.0f);
|
||||
|
||||
// Evaluate P(x) in working precision.
|
||||
// We evaluate even and odd parts of the polynomial separately
|
||||
// to gain some instruction level parallelism.
|
||||
Packet x2 = pmul(x,x);
|
||||
Packet p_even = pmadd(p8, x2, p6);
|
||||
Packet p_odd = pmadd(p9, x2, p7);
|
||||
p_even = pmadd(p_even, x2, p4);
|
||||
p_odd = pmadd(p_odd, x2, p5);
|
||||
p_even = pmadd(p_even, x2, p2);
|
||||
p_odd = pmadd(p_odd, x2, p3);
|
||||
p_even = pmadd(p_even, x2, p0);
|
||||
p_odd = pmadd(p_odd, x2, p1);
|
||||
Packet p = pmadd(p_odd, x, p_even);
|
||||
|
||||
// Evaluate the remaining terms of Q(x) with extra precision using
|
||||
@ -1234,7 +1454,6 @@ EIGEN_STRONG_INLINE Packet generic_pow_impl(const Packet& x, const Packet& y) {
|
||||
// Since r_z is in [-0.5;0.5], we compute the first factor to high accuracy
|
||||
// using a specialized algorithm. Multiplication by the second factor can
|
||||
// be done exactly using pldexp(), since it is an integer power of 2.
|
||||
// Packet e_r = fast_accurate_exp2<Scalar>()(r_z);
|
||||
const Packet e_r = fast_accurate_exp2<Scalar>()(r_z);
|
||||
return pldexp(e_r, n_z);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user