mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-01-06 14:14:46 +08:00
Unify SSE and AVX implementation of pexp
This commit is contained in:
parent
c2f35b1b47
commit
cf8b85d5c5
@ -220,6 +220,12 @@ pandnot(const Packet& a, const Packet& b) { return a & (!b); }
|
||||
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
|
||||
pfrexp(const Packet &a, Packet &exponent) { return std::frexp(a,&exponent); }
|
||||
|
||||
/** \internal \returns a * 2^exponent
|
||||
* See https://en.cppreference.com/w/cpp/numeric/math/ldexp
|
||||
*/
|
||||
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
|
||||
pldexp(const Packet &a, const Packet &exponent) { return std::ldexp(a,exponent); }
|
||||
|
||||
/** \internal \returns zeros */
|
||||
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
|
||||
pzero(const Packet& a) { return pxor(a,a); }
|
||||
@ -656,6 +662,23 @@ pfrexp_float(const Packet& a, Packet& exponent) {
|
||||
return por(pand(a, cst_inv_mant_mask), cst_half);
|
||||
}
|
||||
|
||||
/** \internal shift the bits by n and cast the result to the initial type, i.e.:
|
||||
* return reinterpret_cast<float>(int(a) >> n)
|
||||
*/
|
||||
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
|
||||
pcast_and_shiftleft(Packet a, int n);
|
||||
|
||||
/** Default implementation of pldexp for float.
|
||||
* It is expected to be called by implementers of template<> pldexp,
|
||||
* and the above pcast_and_shiftleft function must be implemented.
|
||||
*/
|
||||
template<typename Packet> EIGEN_STRONG_INLINE Packet
|
||||
pldexp_float(Packet a, Packet exponent) {
|
||||
const Packet cst_127 = pset1<Packet>(127.f);
|
||||
// return a * 2^exponent
|
||||
return pmul(a, pcast_and_shiftleft(padd(exponent, cst_127), 23));
|
||||
}
|
||||
|
||||
} // end namespace internal
|
||||
|
||||
} // end namespace Eigen
|
||||
|
@ -111,62 +111,7 @@ plog<Packet8f>(const Packet8f& _x) {
|
||||
template <>
|
||||
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet8f
|
||||
pexp<Packet8f>(const Packet8f& _x) {
|
||||
_EIGEN_DECLARE_CONST_Packet8f(1, 1.0f);
|
||||
_EIGEN_DECLARE_CONST_Packet8f(half, 0.5f);
|
||||
_EIGEN_DECLARE_CONST_Packet8f(127, 127.0f);
|
||||
|
||||
_EIGEN_DECLARE_CONST_Packet8f(exp_hi, 88.3762626647950f);
|
||||
_EIGEN_DECLARE_CONST_Packet8f(exp_lo, -88.3762626647949f);
|
||||
|
||||
_EIGEN_DECLARE_CONST_Packet8f(cephes_LOG2EF, 1.44269504088896341f);
|
||||
|
||||
_EIGEN_DECLARE_CONST_Packet8f(cephes_exp_p0, 1.9875691500E-4f);
|
||||
_EIGEN_DECLARE_CONST_Packet8f(cephes_exp_p1, 1.3981999507E-3f);
|
||||
_EIGEN_DECLARE_CONST_Packet8f(cephes_exp_p2, 8.3334519073E-3f);
|
||||
_EIGEN_DECLARE_CONST_Packet8f(cephes_exp_p3, 4.1665795894E-2f);
|
||||
_EIGEN_DECLARE_CONST_Packet8f(cephes_exp_p4, 1.6666665459E-1f);
|
||||
_EIGEN_DECLARE_CONST_Packet8f(cephes_exp_p5, 5.0000001201E-1f);
|
||||
|
||||
// Clamp x.
|
||||
Packet8f x = pmax(pmin(_x, p8f_exp_hi), p8f_exp_lo);
|
||||
|
||||
// Express exp(x) as exp(m*ln(2) + r), start by extracting
|
||||
// m = floor(x/ln(2) + 0.5).
|
||||
Packet8f m = _mm256_floor_ps(pmadd(x, p8f_cephes_LOG2EF, p8f_half));
|
||||
|
||||
// Get r = x - m*ln(2). If no FMA instructions are available, m*ln(2) is
|
||||
// subtracted out in two parts, m*C1+m*C2 = m*ln(2), to avoid accumulating
|
||||
// truncation errors. Note that we don't use the "pmadd" function here to
|
||||
// ensure that a precision-preserving FMA instruction is used.
|
||||
#ifdef EIGEN_VECTORIZE_FMA
|
||||
_EIGEN_DECLARE_CONST_Packet8f(nln2, -0.6931471805599453f);
|
||||
Packet8f r = _mm256_fmadd_ps(m, p8f_nln2, x);
|
||||
#else
|
||||
_EIGEN_DECLARE_CONST_Packet8f(cephes_exp_C1, 0.693359375f);
|
||||
_EIGEN_DECLARE_CONST_Packet8f(cephes_exp_C2, -2.12194440e-4f);
|
||||
Packet8f r = psub(x, pmul(m, p8f_cephes_exp_C1));
|
||||
r = psub(r, pmul(m, p8f_cephes_exp_C2));
|
||||
#endif
|
||||
|
||||
Packet8f r2 = pmul(r, r);
|
||||
|
||||
// TODO(gonnet): Split into odd/even polynomials and try to exploit
|
||||
// instruction-level parallelism.
|
||||
Packet8f y = p8f_cephes_exp_p0;
|
||||
y = pmadd(y, r, p8f_cephes_exp_p1);
|
||||
y = pmadd(y, r, p8f_cephes_exp_p2);
|
||||
y = pmadd(y, r, p8f_cephes_exp_p3);
|
||||
y = pmadd(y, r, p8f_cephes_exp_p4);
|
||||
y = pmadd(y, r, p8f_cephes_exp_p5);
|
||||
y = pmadd(y, r2, r);
|
||||
y = padd(y, p8f_1);
|
||||
|
||||
// Build emm0 = 2^m.
|
||||
Packet8i emm0 = _mm256_cvttps_epi32(padd(m, p8f_127));
|
||||
emm0 = pshiftleft(emm0, 23);
|
||||
|
||||
// Return 2^m * exp(r).
|
||||
return pmax(pmul(y, _mm256_castsi256_ps(emm0)), _x);
|
||||
return pexp_float(_x);
|
||||
}
|
||||
|
||||
// Hyperbolic Tangent function.
|
||||
|
@ -389,6 +389,22 @@ template<> EIGEN_STRONG_INLINE Packet8f pfrexp<Packet8f>(const Packet8f& a, Pack
|
||||
return pfrexp_float(a,exponent);
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet8f pcast_and_shiftleft<Packet8f>(Packet8f v, int n)
|
||||
{
|
||||
Packet8i vi = _mm256_cvttps_epi32(v);
|
||||
#ifdef EIGEN_VECTORIZE_AVX2
|
||||
return _mm256_castsi256_ps(_mm256_slli_epi32(vi, n));
|
||||
#else
|
||||
__m128i lo = _mm_slli_epi32(_mm256_extractf128_si256(vi, 0), n);
|
||||
__m128i hi = _mm_slli_epi32(_mm256_extractf128_si256(vi, 1), n);
|
||||
return _mm256_castsi256_ps(_mm256_insertf128_si256(_mm256_castsi128_si256(lo), (hi), 1));
|
||||
#endif
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet8f pldexp<Packet8f>(const Packet8f& a, const Packet8f& exponent) {
|
||||
return pldexp_float(a,exponent);
|
||||
}
|
||||
|
||||
// preduxp should be ok
|
||||
// FIXME: why is this ok? why isn't the simply implementation working as expected?
|
||||
template<> EIGEN_STRONG_INLINE Packet8f preduxp<Packet8f>(const Packet8f* vecs)
|
||||
|
@ -9,7 +9,7 @@
|
||||
// Public License v. 2.0. If a copy of the MPL was not distributed
|
||||
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
/* The log function of this file initially comes from
|
||||
/* The exp and log functions of this file initially come from
|
||||
* Julien Pommier's sse math library: http://gruntthepeon.free.fr/ssemath/
|
||||
*/
|
||||
|
||||
@ -25,12 +25,12 @@ namespace internal {
|
||||
template <typename Packet>
|
||||
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
|
||||
EIGEN_UNUSED
|
||||
Packet plog_float(const Packet _x) {
|
||||
Packet plog_float(const Packet _x)
|
||||
{
|
||||
Packet x = _x;
|
||||
|
||||
const Packet cst_1 = pset1<Packet>(1.0f);
|
||||
const Packet cst_half = pset1<Packet>(0.5f);
|
||||
//const Packet cst_126f = pset1<Packet>(126.0f);
|
||||
// The smallest non denormalized float number.
|
||||
const Packet cst_min_norm_pos = pset1frombits<Packet>( 0x00800000u);
|
||||
const Packet cst_minus_inf = pset1frombits<Packet>( 0xff800000u);
|
||||
@ -101,5 +101,64 @@ Packet plog_float(const Packet _x) {
|
||||
return pselect(iszero_mask, cst_minus_inf, por(x, invalid_mask));
|
||||
}
|
||||
|
||||
// Exponential function. Works by writing "x = m*log(2) + r" where
|
||||
// "m = floor(x/log(2)+1/2)" and "r" is the remainder. The result is then
|
||||
// "exp(x) = 2^m*exp(r)" where exp(r) is in the range [-1,1).
|
||||
template <typename Packet>
|
||||
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
|
||||
EIGEN_UNUSED
|
||||
Packet pexp_float(const Packet _x)
|
||||
{
|
||||
const Packet cst_1 = pset1<Packet>(1.0f);
|
||||
const Packet cst_half = pset1<Packet>(0.5f);
|
||||
const Packet cst_exp_hi = pset1<Packet>( 88.3762626647950f);
|
||||
const Packet cst_exp_lo = pset1<Packet>(-88.3762626647949f);
|
||||
|
||||
const Packet cst_cephes_LOG2EF = pset1<Packet>(1.44269504088896341f);
|
||||
const Packet cst_cephes_exp_p0 = pset1<Packet>(1.9875691500E-4f);
|
||||
const Packet cst_cephes_exp_p1 = pset1<Packet>(1.3981999507E-3f);
|
||||
const Packet cst_cephes_exp_p2 = pset1<Packet>(8.3334519073E-3f);
|
||||
const Packet cst_cephes_exp_p3 = pset1<Packet>(4.1665795894E-2f);
|
||||
const Packet cst_cephes_exp_p4 = pset1<Packet>(1.6666665459E-1f);
|
||||
const Packet cst_cephes_exp_p5 = pset1<Packet>(5.0000001201E-1f);
|
||||
|
||||
// Clamp x.
|
||||
Packet x = pmax(pmin(_x, cst_exp_hi), cst_exp_lo);
|
||||
|
||||
// Express exp(x) as exp(m*ln(2) + r), start by extracting
|
||||
// m = floor(x/ln(2) + 0.5).
|
||||
Packet m = pfloor(pmadd(x, cst_cephes_LOG2EF, cst_half));
|
||||
|
||||
// Get r = x - m*ln(2). If no FMA instructions are available, m*ln(2) is
|
||||
// subtracted out in two parts, m*C1+m*C2 = m*ln(2), to avoid accumulating
|
||||
// truncation errors.
|
||||
Packet r;
|
||||
#ifdef EIGEN_HAS_SINGLE_INSTRUCTION_MADD
|
||||
const Packet cst_nln2 = pset1<Packet>(-0.6931471805599453f);
|
||||
r = pmadd(m, cst_nln2, x);
|
||||
#else
|
||||
const Packet cst_cephes_exp_C1 = pset1<Packet>(0.693359375f);
|
||||
const Packet cst_cephes_exp_C2 = pset1<Packet>(-2.12194440e-4f);
|
||||
r = psub(x, pmul(m, cst_cephes_exp_C1));
|
||||
r = psub(r, pmul(m, cst_cephes_exp_C2));
|
||||
#endif
|
||||
|
||||
Packet r2 = pmul(r, r);
|
||||
|
||||
// TODO(gonnet): Split into odd/even polynomials and try to exploit
|
||||
// instruction-level parallelism.
|
||||
Packet y = cst_cephes_exp_p0;
|
||||
y = pmadd(y, r, cst_cephes_exp_p1);
|
||||
y = pmadd(y, r, cst_cephes_exp_p2);
|
||||
y = pmadd(y, r, cst_cephes_exp_p3);
|
||||
y = pmadd(y, r, cst_cephes_exp_p4);
|
||||
y = pmadd(y, r, cst_cephes_exp_p5);
|
||||
y = pmadd(y, r2, r);
|
||||
y = padd(y, cst_1);
|
||||
|
||||
// Return 2^m * exp(r).
|
||||
return pmax(pldexp(y,m), _x);
|
||||
}
|
||||
|
||||
} // end namespace internal
|
||||
} // end namespace Eigen
|
||||
|
@ -30,67 +30,7 @@ Packet4f plog<Packet4f>(const Packet4f& _x)
|
||||
template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
|
||||
Packet4f pexp<Packet4f>(const Packet4f& _x)
|
||||
{
|
||||
Packet4f x = _x;
|
||||
_EIGEN_DECLARE_CONST_Packet4f(1 , 1.0f);
|
||||
_EIGEN_DECLARE_CONST_Packet4f(half, 0.5f);
|
||||
_EIGEN_DECLARE_CONST_Packet4i(0x7f, 0x7f);
|
||||
|
||||
|
||||
_EIGEN_DECLARE_CONST_Packet4f(exp_hi, 88.3762626647950f);
|
||||
_EIGEN_DECLARE_CONST_Packet4f(exp_lo, -88.3762626647949f);
|
||||
|
||||
_EIGEN_DECLARE_CONST_Packet4f(cephes_LOG2EF, 1.44269504088896341f);
|
||||
_EIGEN_DECLARE_CONST_Packet4f(cephes_exp_C1, 0.693359375f);
|
||||
_EIGEN_DECLARE_CONST_Packet4f(cephes_exp_C2, -2.12194440e-4f);
|
||||
|
||||
_EIGEN_DECLARE_CONST_Packet4f(cephes_exp_p0, 1.9875691500E-4f);
|
||||
_EIGEN_DECLARE_CONST_Packet4f(cephes_exp_p1, 1.3981999507E-3f);
|
||||
_EIGEN_DECLARE_CONST_Packet4f(cephes_exp_p2, 8.3334519073E-3f);
|
||||
_EIGEN_DECLARE_CONST_Packet4f(cephes_exp_p3, 4.1665795894E-2f);
|
||||
_EIGEN_DECLARE_CONST_Packet4f(cephes_exp_p4, 1.6666665459E-1f);
|
||||
_EIGEN_DECLARE_CONST_Packet4f(cephes_exp_p5, 5.0000001201E-1f);
|
||||
|
||||
Packet4f tmp, fx;
|
||||
Packet4i emm0;
|
||||
|
||||
// clamp x
|
||||
x = pmax(pmin(x, p4f_exp_hi), p4f_exp_lo);
|
||||
|
||||
/* express exp(x) as exp(g + n*log(2)) */
|
||||
fx = pmadd(x, p4f_cephes_LOG2EF, p4f_half);
|
||||
|
||||
#ifdef EIGEN_VECTORIZE_SSE4_1
|
||||
fx = _mm_floor_ps(fx);
|
||||
#else
|
||||
emm0 = _mm_cvttps_epi32(fx);
|
||||
tmp = _mm_cvtepi32_ps(emm0);
|
||||
/* if greater, substract 1 */
|
||||
Packet4f mask = _mm_cmpgt_ps(tmp, fx);
|
||||
mask = _mm_and_ps(mask, p4f_1);
|
||||
fx = psub(tmp, mask);
|
||||
#endif
|
||||
|
||||
tmp = pmul(fx, p4f_cephes_exp_C1);
|
||||
Packet4f z = pmul(fx, p4f_cephes_exp_C2);
|
||||
x = psub(x, tmp);
|
||||
x = psub(x, z);
|
||||
|
||||
z = pmul(x,x);
|
||||
|
||||
Packet4f y = p4f_cephes_exp_p0;
|
||||
y = pmadd(y, x, p4f_cephes_exp_p1);
|
||||
y = pmadd(y, x, p4f_cephes_exp_p2);
|
||||
y = pmadd(y, x, p4f_cephes_exp_p3);
|
||||
y = pmadd(y, x, p4f_cephes_exp_p4);
|
||||
y = pmadd(y, x, p4f_cephes_exp_p5);
|
||||
y = pmadd(y, z, x);
|
||||
y = padd(y, p4f_1);
|
||||
|
||||
// build 2^n
|
||||
emm0 = _mm_cvttps_epi32(fx);
|
||||
emm0 = _mm_add_epi32(emm0, p4i_0x7f);
|
||||
emm0 = _mm_slli_epi32(emm0, 23);
|
||||
return pmax(pmul(y, Packet4f(_mm_castsi128_ps(emm0))), _x);
|
||||
return pexp_float(_x);
|
||||
}
|
||||
template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
|
||||
Packet2d pexp<Packet2d>(const Packet2d& _x)
|
||||
|
@ -110,12 +110,12 @@ template<> struct packet_traits<float> : default_packet_traits
|
||||
HasSqrt = 1,
|
||||
HasRsqrt = 1,
|
||||
HasTanh = EIGEN_FAST_MATH,
|
||||
HasBlend = 1
|
||||
HasBlend = 1,
|
||||
HasFloor = 1
|
||||
|
||||
#ifdef EIGEN_VECTORIZE_SSE4_1
|
||||
,
|
||||
HasRound = 1,
|
||||
HasFloor = 1,
|
||||
HasCeil = 1
|
||||
#endif
|
||||
};
|
||||
@ -348,6 +348,17 @@ template<> EIGEN_STRONG_INLINE Packet2d pceil<Packet2d>(const Packet2d& a) { ret
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet4f pfloor<Packet4f>(const Packet4f& a) { return _mm_floor_ps(a); }
|
||||
template<> EIGEN_STRONG_INLINE Packet2d pfloor<Packet2d>(const Packet2d& a) { return _mm_floor_pd(a); }
|
||||
#else
|
||||
template<> EIGEN_STRONG_INLINE Packet4f pfloor<Packet4f>(const Packet4f& a)
|
||||
{
|
||||
const Packet4f cst_1 = pset1<Packet4f>(1.0f);
|
||||
Packet4i emm0 = _mm_cvttps_epi32(a);
|
||||
Packet4f tmp = _mm_cvtepi32_ps(emm0);
|
||||
/* if greater, substract 1 */
|
||||
Packet4f mask = _mm_cmpgt_ps(tmp, a);
|
||||
mask = _mm_and_ps(mask, cst_1);
|
||||
return psub(tmp, mask);
|
||||
}
|
||||
#endif
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet4f pand<Packet4f>(const Packet4f& a, const Packet4f& b) { return _mm_and_ps(a,b); }
|
||||
@ -536,6 +547,16 @@ template<> EIGEN_STRONG_INLINE Packet4f pfrexp<Packet4f>(const Packet4f& a, Pack
|
||||
return pfrexp_float(a,exponent);
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet4f pcast_and_shiftleft<Packet4f>(Packet4f v, int n)
|
||||
{
|
||||
Packet4i vi = _mm_cvttps_epi32(v);
|
||||
return _mm_castsi128_ps(_mm_slli_epi32(vi, n));
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet4f pldexp<Packet4f>(const Packet4f& a, const Packet4f& exponent) {
|
||||
return pldexp_float(a,exponent);
|
||||
}
|
||||
|
||||
// with AVX, the default implementations based on pload1 are faster
|
||||
#ifndef __AVX__
|
||||
template<> EIGEN_STRONG_INLINE void
|
||||
|
Loading…
Reference in New Issue
Block a user