mirror of
https://gitlab.com/libeigen/eigen.git
synced 2024-12-27 07:29:52 +08:00
Fix NEON sqrt for 32-bit, add prsqrt.
With !406, we accidentally broke arm 32-bit NEON builds, since `vsqrt_f32` is only available for 64-bit. Here we add back the `rsqrt` implementation for 32-bit, relying on a `prsqrt` implementation with better handling of edge cases. Note that several of the 32-bit NEON packet tests are currently failing - either due to denormal handling (NEON versions flush to zero, but scalar paths don't) or due to accuracy (e.g. sin/cos).
This commit is contained in:
parent
fe19714f80
commit
29ebd84cb7
@ -684,7 +684,7 @@ Packet plog2(const Packet& a) {
|
||||
|
||||
/** \internal \returns the square-root of \a a (coeff-wise) */
|
||||
template<typename Packet> EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
|
||||
Packet psqrt(const Packet& a) { EIGEN_USING_STD(sqrt); return sqrt(a); }
|
||||
Packet psqrt(const Packet& a) { return numext::sqrt(a); }
|
||||
|
||||
/** \internal \returns the reciprocal square-root of \a a (coeff-wise) */
|
||||
template<typename Packet> EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
|
||||
|
@ -202,6 +202,7 @@ struct packet_traits<float> : default_packet_traits
|
||||
HasLog = 1,
|
||||
HasExp = 1,
|
||||
HasSqrt = 1,
|
||||
HasRsqrt = 1,
|
||||
HasTanh = EIGEN_FAST_MATH,
|
||||
HasErf = EIGEN_FAST_MATH,
|
||||
HasBessel = 0, // Issues with accuracy.
|
||||
@ -3329,8 +3330,42 @@ template<> EIGEN_STRONG_INLINE Packet4ui psqrt(const Packet4ui& a) {
|
||||
return res;
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet4f prsqrt(const Packet4f& a) {
|
||||
// Compute approximate reciprocal sqrt.
|
||||
Packet4f x = vrsqrteq_f32(a);
|
||||
// Do Newton iterations for 1/sqrt(x).
|
||||
x = vmulq_f32(vrsqrtsq_f32(vmulq_f32(a, x), x), x);
|
||||
x = vmulq_f32(vrsqrtsq_f32(vmulq_f32(a, x), x), x);
|
||||
const Packet4f infinity = pset1<Packet4f>(NumTraits<float>::infinity());
|
||||
return pselect(pcmp_eq(a, pzero(a)), infinity, x);
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet2f prsqrt(const Packet2f& a) {
|
||||
// Compute approximate reciprocal sqrt.
|
||||
Packet2f x = vrsqrte_f32(a);
|
||||
// Do Newton iterations for 1/sqrt(x).
|
||||
x = vmul_f32(vrsqrts_f32(vmul_f32(a, x), x), x);
|
||||
x = vmul_f32(vrsqrts_f32(vmul_f32(a, x), x), x);
|
||||
const Packet2f infinity = pset1<Packet2f>(NumTraits<float>::infinity());
|
||||
return pselect(pcmp_eq(a, pzero(a)), infinity, x);
|
||||
}
|
||||
|
||||
// Unfortunately vsqrt_f32 is only available for A64.
|
||||
#if EIGEN_ARCH_ARM64
|
||||
template<> EIGEN_STRONG_INLINE Packet4f psqrt(const Packet4f& _x){return vsqrtq_f32(_x);}
|
||||
template<> EIGEN_STRONG_INLINE Packet2f psqrt(const Packet2f& _x){return vsqrt_f32(_x); }
|
||||
#else
|
||||
template<> EIGEN_STRONG_INLINE Packet4f psqrt(const Packet4f& a) {
|
||||
const Packet4f infinity = pset1<Packet4f>(NumTraits<float>::infinity());
|
||||
const Packet4f is_zero_or_inf = por(pcmp_eq(a, pzero(a)), pcmp_eq(a, infinity));
|
||||
return pselect(is_zero_or_inf, a, pmul(a, prsqrt(a)));
|
||||
}
|
||||
template<> EIGEN_STRONG_INLINE Packet2f psqrt(const Packet2f& a) {
|
||||
const Packet2f infinity = pset1<Packet2f>(NumTraits<float>::infinity());
|
||||
const Packet2f is_zero_or_inf = por(pcmp_eq(a, pzero(a)), pcmp_eq(a, infinity));
|
||||
return pselect(is_zero_or_inf, a, pmul(a, prsqrt(a)));
|
||||
}
|
||||
#endif
|
||||
|
||||
//---------- bfloat16 ----------
|
||||
// TODO: Add support for native armv8.6-a bfloat16_t
|
||||
@ -3722,6 +3757,7 @@ template<> struct packet_traits<double> : default_packet_traits
|
||||
HasLog = 1,
|
||||
HasExp = 1,
|
||||
HasSqrt = 1,
|
||||
HasRsqrt = 1,
|
||||
HasTanh = 0,
|
||||
HasErf = 0
|
||||
};
|
||||
@ -3933,6 +3969,17 @@ template<> EIGEN_STRONG_INLINE Packet2d pfrexp<Packet2d>(const Packet2d& a, Pack
|
||||
template<> EIGEN_STRONG_INLINE Packet2d pset1frombits<Packet2d>(uint64_t from)
|
||||
{ return vreinterpretq_f64_u64(vdupq_n_u64(from)); }
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet2d prsqrt(const Packet2d& a) {
|
||||
// Compute approximate reciprocal sqrt.
|
||||
Packet2d x = vrsqrteq_f64(a);
|
||||
// Do Newton iterations for 1/sqrt(x).
|
||||
x = vmulq_f64(vrsqrtsq_f64(vmulq_f64(a, x), x), x);
|
||||
x = vmulq_f64(vrsqrtsq_f64(vmulq_f64(a, x), x), x);
|
||||
x = vmulq_f64(vrsqrtsq_f64(vmulq_f64(a, x), x), x);
|
||||
const Packet2d infinity = pset1<Packet2d>(NumTraits<double>::infinity());
|
||||
return pselect(pcmp_eq(a, pzero(a)), infinity, x);
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet2d psqrt(const Packet2d& _x){ return vsqrtq_f64(_x); }
|
||||
|
||||
#endif // EIGEN_ARCH_ARM64
|
||||
@ -3978,6 +4025,7 @@ struct packet_traits<Eigen::half> : default_packet_traits {
|
||||
HasLog = 0,
|
||||
HasExp = 0,
|
||||
HasSqrt = 1,
|
||||
HasRsqrt = 1,
|
||||
HasErf = EIGEN_FAST_MATH,
|
||||
HasBessel = 0, // Issues with accuracy.
|
||||
HasNdtri = 0,
|
||||
|
@ -504,6 +504,7 @@ void packetmath() {
|
||||
data1[i] = numext::abs(internal::random<Scalar>());
|
||||
}
|
||||
CHECK_CWISE1_IF(PacketTraits::HasSqrt, numext::sqrt, internal::psqrt);
|
||||
CHECK_CWISE1_IF(PacketTraits::HasRsqrt, numext::rsqrt, internal::prsqrt);
|
||||
}
|
||||
|
||||
// Notice that this definition works for complex types as well.
|
||||
@ -532,7 +533,7 @@ void packetmath_real() {
|
||||
|
||||
CHECK_CWISE1_IF(PacketTraits::HasLog, std::log, internal::plog);
|
||||
CHECK_CWISE1_IF(PacketTraits::HasLog, log2, internal::plog2);
|
||||
CHECK_CWISE1_IF(PacketTraits::HasRsqrt, 1 / std::sqrt, internal::prsqrt);
|
||||
CHECK_CWISE1_IF(PacketTraits::HasRsqrt, numext::rsqrt, internal::prsqrt);
|
||||
|
||||
for (int i = 0; i < size; ++i) {
|
||||
data1[i] = Scalar(internal::random<double>(-1, 1) * std::pow(10., internal::random<double>(-3, 3)));
|
||||
|
Loading…
Reference in New Issue
Block a user