Consolidate float and double implementations of patan().

This commit is contained in:
Rasmus Munk Larsen 2024-08-21 20:44:18 +00:00
parent 87239e058a
commit f91f8e9ab9
5 changed files with 97 additions and 138 deletions

View File

@ -23,7 +23,6 @@ namespace internal {
EIGEN_INSTANTIATE_GENERIC_MATH_FUNCS_FLOAT(Packet8f)
EIGEN_DOUBLE_PACKET_FUNCTION(atan, Packet4d)
EIGEN_DOUBLE_PACKET_FUNCTION(log, Packet4d)
EIGEN_DOUBLE_PACKET_FUNCTION(log2, Packet4d)
EIGEN_DOUBLE_PACKET_FUNCTION(exp, Packet4d)
@ -33,6 +32,11 @@ EIGEN_DOUBLE_PACKET_FUNCTION(sin, Packet4d)
EIGEN_DOUBLE_PACKET_FUNCTION(cos, Packet4d)
#endif
template <>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_UNUSED Packet4d patan<Packet4d>(const Packet4d& _x) {
return generic_patan(_x);
}
// Notice that for newer processors, it is counterproductive to use Newton
// iteration for square root. In particular, Skylake and Zen2 processors
// have approximately doubled throughput of the _mm_sqrt_ps instruction

View File

@ -958,144 +958,92 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet pasin_float(const Pac
return por(invalid_mask, p);
}
// Computes elementwise atan(x) for x in [-1:1] with 2 ulp accuracy.
template <typename Scalar>
struct patan_reduced {
template <typename Packet>
static EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet run(const Packet& x);
};
template <>
template <typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet patan_reduced_float(const Packet& x) {
const Packet q0 = pset1<Packet>(-0.3333314359188079833984375f);
const Packet q2 = pset1<Packet>(0.19993579387664794921875f);
const Packet q4 = pset1<Packet>(-0.14209578931331634521484375f);
const Packet q6 = pset1<Packet>(0.1066047251224517822265625f);
const Packet q8 = pset1<Packet>(-7.5408883392810821533203125e-2f);
const Packet q10 = pset1<Packet>(4.3082617223262786865234375e-2f);
const Packet q12 = pset1<Packet>(-1.62907354533672332763671875e-2f);
const Packet q14 = pset1<Packet>(2.90188402868807315826416015625e-3f);
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet patan_reduced<double>::run(const Packet& x) {
const Packet p1 = pset1<Packet>(3.3004361289279920e-01);
const Packet p3 = pset1<Packet>(8.2704055405494614e-01);
const Packet p5 = pset1<Packet>(7.5365702534987022e-01);
const Packet p7 = pset1<Packet>(3.0409318473444424e-01);
const Packet p9 = pset1<Packet>(5.2574296781008604e-02);
const Packet p11 = pset1<Packet>(3.0917513112462781e-03);
const Packet p13 = pset1<Packet>(2.6667153866462208e-05);
const Packet q0 = p1;
const Packet q2 = pset1<Packet>(9.3705509168587852e-01);
const Packet q4 = pset1<Packet>(1.0);
const Packet q6 = pset1<Packet>(4.9716458728465573e-01);
const Packet q8 = pset1<Packet>(1.1548932646420353e-01);
const Packet q10 = pset1<Packet>(1.0899150928962708e-02);
const Packet q12 = pset1<Packet>(2.7311202462436667e-04);
// Approximate atan(x) by a polynomial of the form
// P(x) = x + x^3 * Q(x^2),
// where Q(x^2) is a 7th order polynomial in x^2.
// We evaluate even and odd terms in x^2 in parallel
// to take advantage of instruction level parallelism
// and hardware with multiple FMA units.
Packet x2 = pmul(x, x);
Packet p = pmadd(p13, x2, p11);
Packet q = pmadd(q12, x2, q10);
p = pmadd(p, x2, p9);
q = pmadd(q, x2, q8);
p = pmadd(p, x2, p7);
q = pmadd(q, x2, q6);
p = pmadd(p, x2, p5);
q = pmadd(q, x2, q4);
p = pmadd(p, x2, p3);
q = pmadd(q, x2, q2);
p = pmadd(p, x2, p1);
q = pmadd(q, x2, q0);
return pmul(x, pdiv(p, q));
}
// note: if x == -0, this returns +0
const Packet x2 = pmul(x, x);
const Packet x4 = pmul(x2, x2);
Packet q_odd = pmadd(q14, x4, q10);
Packet q_even = pmadd(q12, x4, q8);
q_odd = pmadd(q_odd, x4, q6);
q_even = pmadd(q_even, x4, q4);
q_odd = pmadd(q_odd, x4, q2);
q_even = pmadd(q_even, x4, q0);
const Packet q = pmadd(q_odd, x2, q_even);
return pmadd(q, pmul(x, x2), x);
// Computes elementwise atan(x) for x in [-1:1] with 2 ulp accuracy.
template <>
template <typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet patan_reduced<float>::run(const Packet& x) {
const Packet p1 = pset1<Packet>(8.109951019287109375e-01f);
const Packet p3 = pset1<Packet>(7.296695709228515625e-01f);
const Packet p5 = pset1<Packet>(1.12026982009410858154296875e-01f);
const Packet q0 = p1;
const Packet q2 = pset1<Packet>(1.0f);
const Packet q4 = pset1<Packet>(2.8318560123443603515625e-01f);
const Packet q6 = pset1<Packet>(1.00917108356952667236328125e-02f);
Packet x2 = pmul(x, x);
Packet p = pmadd(p5, x2, p3);
Packet q = pmadd(q6, x2, q4);
p = pmadd(p, x2, p1);
q = pmadd(q, x2, q2);
q = pmadd(q, x2, q0);
return pmul(x, pdiv(p, q));
}
template <typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet patan_float(const Packet& x_in) {
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet generic_patan(const Packet& x_in) {
typedef typename unpacket_traits<Packet>::type Scalar;
static_assert(std::is_same<Scalar, float>::value, "Scalar type must be float");
constexpr float kPiOverTwo = static_cast<float>(EIGEN_PI / 2);
constexpr Scalar kPiOverTwo = static_cast<Scalar>(EIGEN_PI / 2);
const Packet cst_signmask = pset1<Packet>(-0.0f);
const Packet cst_one = pset1<Packet>(1.0f);
const Packet cst_signmask = pset1<Packet>(-Scalar(0));
const Packet cst_one = pset1<Packet>(Scalar(1));
const Packet cst_pi_over_two = pset1<Packet>(kPiOverTwo);
// "Large": For |x| > 1, use atan(1/x) = sign(x)*pi/2 - atan(x).
// "Small": For |x| <= 1, approximate atan(x) directly by a polynomial
// calculated using Sollya.
// calculated using Rminimax.
const Packet abs_x = pabs(x_in);
const Packet x_signmask = pand(x_in, cst_signmask);
const Packet large_mask = pcmp_lt(cst_one, abs_x);
const Packet x = pselect(large_mask, preciprocal(abs_x), abs_x);
const Packet p = patan_reduced_float(x);
const Packet p = patan_reduced<Scalar>::run(x);
// Apply transformations according to the range reduction masks.
Packet result = pselect(large_mask, psub(cst_pi_over_two, p), p);
// Return correct sign
return pxor(result, x_signmask);
}
// Computes elementwise atan(x) for x in [-tan(pi/8):tan(pi/8)]
// with 2 ulp accuracy.
template <typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet patan_reduced_double(const Packet& x) {
const Packet q0 = pset1<Packet>(-0.33333333333330028569463365784031338989734649658203);
const Packet q2 = pset1<Packet>(0.199999999990664090177006073645316064357757568359375);
const Packet q4 = pset1<Packet>(-0.142857141937123677255527809393242932856082916259766);
const Packet q6 = pset1<Packet>(0.111111065991039953404495577160560060292482376098633);
const Packet q8 = pset1<Packet>(-9.0907812986129224452902519715280504897236824035645e-2);
const Packet q10 = pset1<Packet>(7.6900542950704739442180368769186316058039665222168e-2);
const Packet q12 = pset1<Packet>(-6.6410112986494976294871150912513257935643196105957e-2);
const Packet q14 = pset1<Packet>(5.6920144995467943094258345126945641823112964630127e-2);
const Packet q16 = pset1<Packet>(-4.3577020814990513608577771265117917209863662719727e-2);
const Packet q18 = pset1<Packet>(2.1244050233624342527427586446719942614436149597168e-2);
// Approximate atan(x) on [0:tan(pi/8)] by a polynomial of the form
// P(x) = x + x^3 * Q(x^2),
// where Q(x^2) is a 9th order polynomial in x^2.
// We evaluate even and odd terms in x^2 in parallel
// to take advantage of instruction level parallelism
// and hardware with multiple FMA units.
const Packet x2 = pmul(x, x);
const Packet x4 = pmul(x2, x2);
Packet q_odd = pmadd(q18, x4, q14);
Packet q_even = pmadd(q16, x4, q12);
q_odd = pmadd(q_odd, x4, q10);
q_even = pmadd(q_even, x4, q8);
q_odd = pmadd(q_odd, x4, q6);
q_even = pmadd(q_even, x4, q4);
q_odd = pmadd(q_odd, x4, q2);
q_even = pmadd(q_even, x4, q0);
const Packet p = pmadd(q_odd, x2, q_even);
return pmadd(p, pmul(x, x2), x);
}
template <typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet patan_double(const Packet& x_in) {
typedef typename unpacket_traits<Packet>::type Scalar;
static_assert(std::is_same<Scalar, double>::value, "Scalar type must be double");
constexpr double kPiOverTwo = static_cast<double>(EIGEN_PI / 2);
constexpr double kPiOverFour = static_cast<double>(EIGEN_PI / 4);
constexpr double kTanPiOverEight = 0.4142135623730950488016887;
constexpr double kTan3PiOverEight = 2.4142135623730950488016887;
const Packet cst_signmask = pset1<Packet>(-0.0);
const Packet cst_one = pset1<Packet>(1.0);
const Packet cst_pi_over_two = pset1<Packet>(kPiOverTwo);
const Packet cst_pi_over_four = pset1<Packet>(kPiOverFour);
const Packet cst_large = pset1<Packet>(kTan3PiOverEight);
const Packet cst_medium = pset1<Packet>(kTanPiOverEight);
// Use the same range reduction strategy (to [0:tan(pi/8)]) as the
// Cephes library:
// "Large": For x >= tan(3*pi/8), use atan(1/x) = pi/2 - atan(x).
// "Medium": For x in [tan(pi/8) : tan(3*pi/8)),
// use atan(x) = pi/4 + atan((x-1)/(x+1)).
// "Small": For x < tan(pi/8), approximate atan(x) directly by a polynomial
// calculated using Sollya.
const Packet abs_x = pabs(x_in);
const Packet x_signmask = pand(x_in, cst_signmask);
const Packet large_mask = pcmp_lt(cst_large, abs_x);
const Packet medium_mask = pandnot(pcmp_lt(cst_medium, abs_x), large_mask);
Packet x = abs_x;
x = pselect(large_mask, preciprocal(abs_x), x);
x = pselect(medium_mask, pdiv(psub(abs_x, cst_one), padd(abs_x, cst_one)), x);
// Compute approximation of p ~= atan(x') where x' is the argument reduced to
// [0:tan(pi/8)].
Packet p = patan_reduced_double(x);
// Apply transformations according to the range reduction masks.
p = pselect(large_mask, psub(cst_pi_over_two, p), p);
p = pselect(medium_mask, padd(cst_pi_over_four, p), p);
// Return the correct sign
return pxor(p, x_signmask);
}
/** \internal \returns the hyperbolic tan of \a a (coeff-wise)
Doesn't do anything fancy, just a 9/8-degree rational interpolant which
is accurate up to a couple of ulps in the (approximate) range [-8, 8],

View File

@ -66,6 +66,10 @@ Packet generic_plog1p(const Packet& x);
template <typename Packet>
Packet generic_expm1(const Packet& x);
/** \internal \returns atan(x) */
template <typename Packet>
Packet generic_patan(const Packet& x);
/** \internal \returns exp(x) for single precision float */
template <typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet pexp_float(const Packet _x);
@ -98,14 +102,6 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet pasin_float(const Pac
template <typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet pacos_float(const Packet& x);
/** \internal \returns atan(x) for single precision float */
template <typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet patan_float(const Packet& x);
/** \internal \returns atan(x) for double precision float */
template <typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet patan_double(const Packet& x);
/** \internal \returns tanh(x) for single precision float */
template <typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet ptanh_float(const Packet& x);
@ -167,7 +163,6 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet generic_round(const Packet& a);
EIGEN_FLOAT_PACKET_FUNCTION(cos, PACKET) \
EIGEN_FLOAT_PACKET_FUNCTION(asin, PACKET) \
EIGEN_FLOAT_PACKET_FUNCTION(acos, PACKET) \
EIGEN_FLOAT_PACKET_FUNCTION(atan, PACKET) \
EIGEN_FLOAT_PACKET_FUNCTION(tanh, PACKET) \
EIGEN_FLOAT_PACKET_FUNCTION(atanh, PACKET) \
EIGEN_FLOAT_PACKET_FUNCTION(log, PACKET) \
@ -175,21 +170,28 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet generic_round(const Packet& a);
EIGEN_FLOAT_PACKET_FUNCTION(exp, PACKET) \
template <> \
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_UNUSED PACKET pexpm1<PACKET>(const PACKET& _x) { \
return internal::generic_expm1(_x); \
return generic_expm1(_x); \
} \
template <> \
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_UNUSED PACKET plog1p<PACKET>(const PACKET& _x) { \
return internal::generic_plog1p(_x); \
return generic_plog1p(_x); \
} \
template <> \
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_UNUSED PACKET patan<PACKET>(const PACKET& _x) { \
return generic_patan(_x); \
}
#define EIGEN_INSTANTIATE_GENERIC_MATH_FUNCS_DOUBLE(PACKET) \
EIGEN_DOUBLE_PACKET_FUNCTION(atan, PACKET) \
EIGEN_DOUBLE_PACKET_FUNCTION(log, PACKET) \
EIGEN_DOUBLE_PACKET_FUNCTION(sin, PACKET) \
EIGEN_DOUBLE_PACKET_FUNCTION(cos, PACKET) \
EIGEN_DOUBLE_PACKET_FUNCTION(log2, PACKET) \
EIGEN_DOUBLE_PACKET_FUNCTION(exp, PACKET) \
EIGEN_DOUBLE_PACKET_FUNCTION(tanh, PACKET)
#define EIGEN_INSTANTIATE_GENERIC_MATH_FUNCS_DOUBLE(PACKET) \
EIGEN_DOUBLE_PACKET_FUNCTION(log, PACKET) \
EIGEN_DOUBLE_PACKET_FUNCTION(sin, PACKET) \
EIGEN_DOUBLE_PACKET_FUNCTION(cos, PACKET) \
EIGEN_DOUBLE_PACKET_FUNCTION(log2, PACKET) \
EIGEN_DOUBLE_PACKET_FUNCTION(exp, PACKET) \
EIGEN_DOUBLE_PACKET_FUNCTION(tanh, PACKET) \
template <> \
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_UNUSED PACKET patan<PACKET>(const PACKET& _x) { \
return generic_patan(_x); \
}
} // end namespace internal
} // end namespace Eigen

View File

@ -12,6 +12,7 @@
#include <cerrno>
#include <ctime>
#include <iostream>
#include <iomanip>
#include <fstream>
#include <string>
#include <sstream>

View File

@ -59,7 +59,8 @@ bool areApproxAbs(const Scalar* a, const Scalar* b, int size, const typename Num
for (int i = 0; i < size; ++i) {
if (!isApproxAbs(a[i], b[i], refvalue)) {
print_mismatch(a, b, size);
std::cout << "Values differ in position " << i << ": " << a[i] << " vs " << b[i] << std::endl;
std::cout << std::setprecision(16) << "Values differ in position " << i << ": " << a[i] << " vs " << b[i]
<< std::endl;
return false;
}
}
@ -72,7 +73,8 @@ bool areApprox(const Scalar* a, const Scalar* b, int size) {
if (numext::not_equal_strict(a[i], b[i]) && !internal::isApprox(a[i], b[i]) &&
!((numext::isnan)(a[i]) && (numext::isnan)(b[i]))) {
print_mismatch(a, b, size);
std::cout << "Values differ in position " << i << ": " << a[i] << " vs " << b[i] << std::endl;
std::cout << std::setprecision(16) << "Values differ in position " << i << ": " << a[i] << " vs " << b[i]
<< std::endl;
return false;
}
}
@ -84,7 +86,8 @@ bool areEqual(const Scalar* a, const Scalar* b, int size) {
for (int i = 0; i < size; ++i) {
if (numext::not_equal_strict(a[i], b[i]) && !((numext::isnan)(a[i]) && (numext::isnan)(b[i]))) {
print_mismatch(a, b, size);
std::cout << "Values differ in position " << i << ": " << a[i] << " vs " << b[i] << std::endl;
std::cout << std::setprecision(16) << "Values differ in position " << i << ": " << a[i] << " vs " << b[i]
<< std::endl;
return false;
}
}
@ -97,7 +100,8 @@ bool areApprox(const Scalar* a, const Scalar* b, int size, const typename NumTra
if (numext::not_equal_strict(a[i], b[i]) && !internal::isApprox(a[i], b[i], precision) &&
!((numext::isnan)(a[i]) && (numext::isnan)(b[i]))) {
print_mismatch(a, b, size);
std::cout << "Values differ in position " << i << ": " << a[i] << " vs " << b[i] << std::endl;
std::cout << std::setprecision(16) << "Values differ in position " << i << ": " << a[i] << " vs " << b[i]
<< std::endl;
return false;
}
}