Fix neon cmp* functions for bf16.

The current impl corrupts the comparison masks when converting
from float back to bfloat16.  The resulting masks are then
no longer all zeros or all ones, which breaks when used with
`pselect` (e.g. in `pmin<PropagateNumbers>`).  This was
causing `packetmath_15` to fail on arm.

Introducing a simple `F32MaskToBf16Mask` corrects this (takes
the lower 16 bits of each 32-bit float mask lane).
This commit is contained in:
Antonio Sanchez 2020-12-01 16:28:48 -08:00 committed by Rasmus Munk Larsen
parent ddd48b242c
commit 2627e2f2e6

View File

@ -3371,6 +3371,10 @@ EIGEN_STRONG_INLINE Packet4f Bf16ToF32(const Packet4bf& p)
return reinterpret_cast<Packet4f>(vshlq_n_u32(vmovl_u16(p), 16));
}
// Narrows a float comparison mask to a bfloat16 comparison mask.
//
// Each 32-bit lane of `p` is expected to be a mask (all zeros or all ones).
// The lanes are reinterpreted as unsigned integers (a bit-for-bit view, not a
// numeric conversion) and then narrowed to their lower 16 bits, so an
// all-zeros/all-ones float lane stays all-zeros/all-ones in the result.
EIGEN_STRONG_INLINE Packet4bf F32MaskToBf16Mask(const Packet4f& p) {
  // `p` is a float vector, so the f32 -> u32 reinterpret is the one required
  // here; vmovn_u32 takes an unsigned 32-bit vector as its operand.
  return vmovn_u32(vreinterpretq_u32_f32(p));
}
// pset1 specialization for Packet4bf: forwards the raw 16-bit storage of
// `from` (its `value` member) to the Packet4us overload and returns the
// resulting packet as a Packet4bf.
template<> EIGEN_STRONG_INLINE Packet4bf pset1<Packet4bf>(const bfloat16& from) {
  const Packet4us raw_lanes = pset1<Packet4us>(from.value);
  return raw_lanes;
}
@ -3528,17 +3532,17 @@ template<> EIGEN_STRONG_INLINE Packet4bf pabsdiff<Packet4bf>(const Packet4bf& a,
// Lane-wise equality comparison of two bfloat16 packets.
//
// Both operands are widened with Bf16ToF32, compared with the Packet4f
// overload, and the resulting float mask is narrowed with F32MaskToBf16Mask.
// The mask must not go through F32ToBf16: that is a value conversion and it
// corrupts the mask, leaving lanes that are neither all zeros nor all ones,
// which breaks consumers such as pselect.
template<> EIGEN_STRONG_INLINE Packet4bf pcmp_eq<Packet4bf>(const Packet4bf& a, const Packet4bf& b)
{
  return F32MaskToBf16Mask(pcmp_eq<Packet4f>(Bf16ToF32(a), Bf16ToF32(b)));
}
// Lane-wise "less than" comparison of two bfloat16 packets.
//
// Both operands are widened with Bf16ToF32, compared with the Packet4f
// overload, and the resulting float mask is narrowed with F32MaskToBf16Mask.
// The mask must not go through F32ToBf16: that is a value conversion and it
// corrupts the mask, leaving lanes that are neither all zeros nor all ones,
// which breaks consumers such as pselect.
template<> EIGEN_STRONG_INLINE Packet4bf pcmp_lt<Packet4bf>(const Packet4bf& a, const Packet4bf& b)
{
  return F32MaskToBf16Mask(pcmp_lt<Packet4f>(Bf16ToF32(a), Bf16ToF32(b)));
}
// Lane-wise "less than or equal" comparison of two bfloat16 packets.
//
// Both operands are widened with Bf16ToF32, compared with the Packet4f
// overload, and the resulting float mask is narrowed with F32MaskToBf16Mask.
// The mask must not go through F32ToBf16: that is a value conversion and it
// corrupts the mask, leaving lanes that are neither all zeros nor all ones,
// which breaks consumers such as pselect.
template<> EIGEN_STRONG_INLINE Packet4bf pcmp_le<Packet4bf>(const Packet4bf& a, const Packet4bf& b)
{
  return F32MaskToBf16Mask(pcmp_le<Packet4f>(Bf16ToF32(a), Bf16ToF32(b)));
}
template<> EIGEN_STRONG_INLINE Packet4bf pnegate<Packet4bf>(const Packet4bf& a)