Extend the generic psin_float code to handle cosine and make SSE and AVX use it (-> this adds pcos for AVX)

This commit is contained in:
Gael Guennebaud 2018-11-30 11:26:30 +01:00
parent e19ece822d
commit b477d60bc6
4 changed files with 41 additions and 105 deletions

View File

@ -29,16 +29,18 @@ inline Packet8i pshiftleft(Packet8i v, int n)
#endif
}
// Sine function
// Computes sin(x) by wrapping x to the interval [-Pi/4,3*Pi/4] and
// evaluating interpolants in [-Pi/4,Pi/4] or [Pi/4,3*Pi/4]. The interpolants
// are (anti-)symmetric and thus have only odd/even coefficients
template <>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet8f
psin<Packet8f>(const Packet8f& _x) {
return psin_float(_x);
}
template <>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet8f
pcos<Packet8f>(const Packet8f& _x) {
return pcos_float(_x);
}
template <>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet8f
plog<Packet8f>(const Packet8f& _x) {

View File

@ -63,7 +63,7 @@ template<> struct packet_traits<float> : default_packet_traits
HasDiv = 1,
HasSin = EIGEN_FAST_MATH,
HasCos = 0,
HasCos = EIGEN_FAST_MATH,
HasLog = 1,
HasExp = 1,
HasSqrt = 1,

View File

@ -245,14 +245,23 @@ Packet pexp_double(const Packet _x)
// Construct the result 2^n * exp(g) = e * x. The max is used to catch
// non-finite values in the input.
//return pmax(pmul(x, _mm256_castsi256_pd(e)), _x);
return pmax(pldexp(x,fx), _x);
}
template<typename Packet>
/* The code is the rewriting of the cephes sinf/cosf functions.
Precision is excellent as long as x < 8192 (I did not bother to
take into account the special handling they have for greater values
-- it does not return garbage for arguments over 8192, though, but
the extra precision is missing).
Note that it is such that sinf((float)M_PI) = 8.74e-8, which is the
surprising but correct result.
*/
template<bool ComputeSine,typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
EIGEN_UNUSED
Packet psin_float(const Packet& _x)
Packet psincos_float(const Packet& _x)
{
typedef typename unpacket_traits<Packet>::integer_packet PacketI;
const Packet cst_1 = pset1<Packet>(1.0f);
@ -261,6 +270,7 @@ Packet psin_float(const Packet& _x)
const PacketI csti_1 = pset1<PacketI>(1);
const PacketI csti_not1 = pset1<PacketI>(~1);
const PacketI csti_2 = pset1<PacketI>(2);
const PacketI csti_3 = pset1<PacketI>(3);
const Packet cst_sign_mask = pset1frombits<Packet>(0x80000000u);
@ -290,7 +300,8 @@ Packet psin_float(const Packet& _x)
// Compute the sign to apply to the polynomial.
// sign = third_bit(y_int1) xor signbit(_x)
Packet sign_bit = pxor(_x, preinterpret<Packet>(pshiftleft<29>(y_int1)));
Packet sign_bit = ComputeSine ? pxor(_x, preinterpret<Packet>(pshiftleft<29>(y_int1)))
: preinterpret<Packet>(pshiftleft<29>(padd(y_int1,csti_3)));
sign_bit = pand(sign_bit, cst_sign_mask); // clear all but left most bit
// Get the polynomial selection mask from the second bit of y_int1
@ -323,11 +334,28 @@ Packet psin_float(const Packet& _x)
y2 = pmadd(y2, x, x);
// Select the correct result from the two polynoms.
y = pselect(poly_mask,y2,y1);
y = ComputeSine ? pselect(poly_mask,y2,y1)
: pselect(poly_mask,y1,y2);
// Update the sign
return pxor(y, sign_bit);
}
template<typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
EIGEN_UNUSED
Packet psin_float(const Packet& x)
{
return psincos_float<true>(x);
}
template<typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
EIGEN_UNUSED
Packet pcos_float(const Packet& x)
{
return psincos_float<false>(x);
}
} // end namespace internal
} // end namespace Eigen

View File

@ -39,110 +39,16 @@ Packet2d pexp<Packet2d>(const Packet2d& x)
return pexp_double(x);
}
/* evaluation of 4 sines at once, using SSE2 intrinsics.
The code is the exact rewriting of the cephes sinf function.
Precision is excellent as long as x < 8192 (I did not bother to
take into account the special handling they have for greater values
-- it does not return garbage for arguments over 8192, though, but
the extra precision is missing).
Note that it is such that sinf((float)M_PI) = 8.74e-8, which is the
surprising but correct result.
*/
template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
Packet4f psin<Packet4f>(const Packet4f& _x)
{
return psin_float(_x);
}
/* almost the same as psin */
template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
Packet4f pcos<Packet4f>(const Packet4f& _x)
{
Packet4f x = _x;
_EIGEN_DECLARE_CONST_Packet4f(1 , 1.0f);
_EIGEN_DECLARE_CONST_Packet4f(half, 0.5f);
_EIGEN_DECLARE_CONST_Packet4i(1, 1);
_EIGEN_DECLARE_CONST_Packet4i(not1, ~1);
_EIGEN_DECLARE_CONST_Packet4i(2, 2);
_EIGEN_DECLARE_CONST_Packet4i(4, 4);
_EIGEN_DECLARE_CONST_Packet4f(minus_cephes_DP1,-0.78515625f);
_EIGEN_DECLARE_CONST_Packet4f(minus_cephes_DP2, -2.4187564849853515625e-4f);
_EIGEN_DECLARE_CONST_Packet4f(minus_cephes_DP3, -3.77489497744594108e-8f);
_EIGEN_DECLARE_CONST_Packet4f(sincof_p0, -1.9515295891E-4f);
_EIGEN_DECLARE_CONST_Packet4f(sincof_p1, 8.3321608736E-3f);
_EIGEN_DECLARE_CONST_Packet4f(sincof_p2, -1.6666654611E-1f);
_EIGEN_DECLARE_CONST_Packet4f(coscof_p0, 2.443315711809948E-005f);
_EIGEN_DECLARE_CONST_Packet4f(coscof_p1, -1.388731625493765E-003f);
_EIGEN_DECLARE_CONST_Packet4f(coscof_p2, 4.166664568298827E-002f);
_EIGEN_DECLARE_CONST_Packet4f(cephes_FOPI, 1.27323954473516f); // 4 / M_PI
Packet4f xmm1, xmm2, xmm3, y;
Packet4i emm0, emm2;
x = pabs(x);
/* scale by 4/Pi */
y = pmul(x, p4f_cephes_FOPI);
/* get the integer part of y */
emm2 = _mm_cvttps_epi32(y);
/* j=(j+1) & (~1) (see the cephes sources) */
emm2 = _mm_add_epi32(emm2, p4i_1);
emm2 = _mm_and_si128(emm2, p4i_not1);
y = _mm_cvtepi32_ps(emm2);
emm2 = _mm_sub_epi32(emm2, p4i_2);
/* get the swap sign flag */
emm0 = _mm_andnot_si128(emm2, p4i_4);
emm0 = _mm_slli_epi32(emm0, 29);
/* get the polynom selection mask */
emm2 = _mm_and_si128(emm2, p4i_2);
emm2 = _mm_cmpeq_epi32(emm2, _mm_setzero_si128());
Packet4f sign_bit = _mm_castsi128_ps(emm0);
Packet4f poly_mask = _mm_castsi128_ps(emm2);
/* The magic pass: "Extended precision modular arithmetic"
x = ((x - y * DP1) - y * DP2) - y * DP3; */
xmm1 = pmul(y, p4f_minus_cephes_DP1);
xmm2 = pmul(y, p4f_minus_cephes_DP2);
xmm3 = pmul(y, p4f_minus_cephes_DP3);
x = padd(x, xmm1);
x = padd(x, xmm2);
x = padd(x, xmm3);
/* Evaluate the first polynom (0 <= x <= Pi/4) */
y = p4f_coscof_p0;
Packet4f z = pmul(x,x);
y = pmadd(y,z,p4f_coscof_p1);
y = pmadd(y,z,p4f_coscof_p2);
y = pmul(y, z);
y = pmul(y, z);
Packet4f tmp = _mm_mul_ps(z, p4f_half);
y = psub(y, tmp);
y = padd(y, p4f_1);
/* Evaluate the second polynom (Pi/4 <= x <= 0) */
Packet4f y2 = p4f_sincof_p0;
y2 = pmadd(y2, z, p4f_sincof_p1);
y2 = pmadd(y2, z, p4f_sincof_p2);
y2 = pmul(y2, z);
y2 = pmadd(y2, x, x);
/* select the correct result from the two polynoms */
y2 = _mm_and_ps(poly_mask, y2);
y = _mm_andnot_ps(poly_mask, y);
y = _mm_or_ps(y,y2);
/* update the sign */
return _mm_xor_ps(y, sign_bit);
return pcos_float(_x);
}
#if EIGEN_FAST_MATH