bfloat16 packetmath for Arm Neon backend

Authored by David Tellenbach, 2020-08-13 15:48:40 +00:00; committed by Rasmus Munk Larsen
parent 704798d1df
commit 8ba1b0f41a
2 changed files with 269 additions and 0 deletions

@@ -38,6 +38,12 @@ template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Pack
template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet4f ptanh<Packet4f>(const Packet4f& x)
{ return internal::generic_fast_tanh_float(x); }
BF16_PACKET_FUNCTION(Packet4f, Packet4bf, psin)
BF16_PACKET_FUNCTION(Packet4f, Packet4bf, pcos)
BF16_PACKET_FUNCTION(Packet4f, Packet4bf, plog)
BF16_PACKET_FUNCTION(Packet4f, Packet4bf, pexp)
BF16_PACKET_FUNCTION(Packet4f, Packet4bf, ptanh)
} // end namespace internal
} // end namespace Eigen
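
Each BF16_PACKET_FUNCTION invocation above specializes the named math function for Packet4bf by converting to Packet4f, evaluating the float implementation, and converting back. The macro itself is defined outside this diff in Eigen's generic packet-math headers; a rough sketch of what one invocation expands to (attributes simplified, treat this as an approximation rather than the actual macro):

// Approximate expansion of BF16_PACKET_FUNCTION(Packet4f, Packet4bf, plog).
template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
Packet4bf plog<Packet4bf>(const Packet4bf& x)
{ return F32ToBf16(plog<Packet4f>(Bf16ToF32(x))); }  // compute in float, round back to bfloat16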

@@ -3218,6 +3218,269 @@ template<> EIGEN_STRONG_INLINE Packet4ui psqrt(const Packet4ui& a) {
return res;
}
//---------- bfloat16 ----------
// TODO: Add support for native armv8.6-a bfloat16_t
// TODO: Guard if we have native bfloat16 support
typedef eigen_packet_wrapper<uint16x4_t, 19> Packet4bf;
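For context, eigen_packet_wrapper is a thin wrapper from Eigen's GenericPacketMath.h: it stores the raw NEON vector and converts implicitly to and from it, and the integer template argument is only a tag that keeps Packet4bf a distinct overload type from Packet4us even though both hold four 16-bit lanes. A simplified sketch of its shape (illustrative only, not the actual definition):

// Simplified shape of eigen_packet_wrapper, for illustration only.
template<typename T, int unique_id>
struct eigen_packet_wrapper_sketch {
  eigen_packet_wrapper_sketch() {}
  eigen_packet_wrapper_sketch(const T& v) : m_val(v) {}  // implicit from the raw uint16x4_t
  operator T&() { return m_val; }                        // implicit back to the raw vector
  operator const T&() const { return m_val; }
  T m_val;
};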
template<> struct is_arithmetic<Packet4bf> { enum { value = true }; };
template<> struct packet_traits<bfloat16> : default_packet_traits
{
typedef Packet4bf type;
typedef Packet4bf half;
enum
{
Vectorizable = 1,
AlignedOnScalar = 1,
size = 4,
HasHalfPacket = 0,
HasCmp = 1,
HasAdd = 1,
HasSub = 1,
HasShift = 1,
HasMul = 1,
HasNegate = 1,
HasAbs = 1,
HasArg = 0,
HasAbs2 = 1,
HasAbsDiff = 1,
HasMin = 1,
HasMax = 1,
HasConj = 1,
HasSetLinear = 0,
HasBlend = 0,
HasDiv = 1,
HasFloor = 1,
HasSin = EIGEN_FAST_MATH,
HasCos = EIGEN_FAST_MATH,
HasLog = 1,
HasExp = 1,
HasSqrt = 0,
HasTanh = EIGEN_FAST_MATH,
HasErf = EIGEN_FAST_MATH
};
};
template<> struct unpacket_traits<Packet4bf>
{
typedef bfloat16 type;
typedef Packet4bf half;
enum
{
size = 4,
alignment = Aligned16,
vectorizable = true,
masked_load_available = false,
masked_store_available = false
};
};
EIGEN_STRONG_INLINE Packet4bf F32ToBf16(const Packet4f& p)
{
// See the scalar implementation in BFloat16.h for a comprehensible explanation
// of this fast rounding algorithm
Packet4ui input = reinterpret_cast<Packet4ui>(p);
// lsb = (input >> 16) & 1
Packet4ui lsb = vandq_u32(vshrq_n_u32(input, 16), vdupq_n_u32(1));
// rounding_bias = 0x7fff + lsb
Packet4ui rounding_bias = vaddq_u32(lsb, vdupq_n_u32(0x7fff));
// input += rounding_bias
input = vaddq_u32(input, rounding_bias);
// input = input >> 16
input = vshrq_n_u32(input, 16);
// Replace float NaNs with the canonical bfloat16 NaN, 0x7fc0
const Packet4ui bf16_nan = vdupq_n_u32(0x7fc0);
const Packet4ui mask = vceqq_f32(p, p);
input = vbslq_u32(mask, input, bf16_nan);
// output = static_cast<uint16_t>(input)
return vmovn_u32(input);
}
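
For reference, the scalar rounding that this vector code mirrors (see BFloat16.h) reduces to the sketch below; the function name is illustrative and not part of Eigen. Adding 0x7fff plus the lowest surviving mantissa bit rounds to nearest with ties to even, and NaNs map to the canonical bfloat16 NaN 0x7fc0, matching the vbslq_u32 select above.

// Illustrative scalar equivalent of F32ToBf16 (not part of this commit).
#include <cstdint>
#include <cstring>
inline uint16_t bf16_round_scalar(float f) {
  if (f != f) return 0x7fc0;                 // canonical bfloat16 NaN
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));      // bit-cast float -> uint32
  uint32_t lsb = (bits >> 16) & 1;           // lowest bit that survives truncation
  bits += 0x7fffu + lsb;                     // round to nearest, ties to even
  return static_cast<uint16_t>(bits >> 16);  // keep sign, exponent and top 7 mantissa bits
}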
EIGEN_STRONG_INLINE Packet4f Bf16ToF32(const Packet4bf& p)
{
return reinterpret_cast<Packet4f>(vshlq_n_u32(vmovl_u16(p), 16));
}
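
The inverse conversion is exact, since a bfloat16 pattern is simply the upper 16 bits of the corresponding float: widening each lane and shifting left by 16 restores the float bits. A scalar sketch, using the same headers as the sketch above (illustrative only):

// Illustrative scalar equivalent of Bf16ToF32 (not part of this commit).
inline float bf16_to_float_scalar(uint16_t b) {
  uint32_t bits = static_cast<uint32_t>(b) << 16;  // bfloat16 occupies the high half of a float
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}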
template<> EIGEN_STRONG_INLINE Packet4bf pset1<Packet4bf>(const bfloat16& from) {
return pset1<Packet4us>(from.value);
}
template<> EIGEN_STRONG_INLINE bfloat16 pfirst<Packet4bf>(const Packet4bf& from) {
return bfloat16_impl::raw_uint16_to_bfloat16(static_cast<uint16_t>(pfirst<Packet4us>(from)));
}
template<> EIGEN_STRONG_INLINE Packet4bf pload<Packet4bf>(const bfloat16* from)
{
return pload<Packet4us>(reinterpret_cast<const uint16_t*>(from));
}
template<> EIGEN_STRONG_INLINE Packet4bf ploadu<Packet4bf>(const bfloat16* from)
{
return ploadu<Packet4us>(reinterpret_cast<const uint16_t*>(from));
}
template<> EIGEN_STRONG_INLINE void pstore<bfloat16>(bfloat16* to, const Packet4bf& from)
{
EIGEN_DEBUG_ALIGNED_STORE vst1_u16(reinterpret_cast<uint16_t*>(to), from);
}
template<> EIGEN_STRONG_INLINE void pstoreu<bfloat16>(bfloat16* to, const Packet4bf& from)
{
EIGEN_DEBUG_UNALIGNED_STORE vst1_u16(reinterpret_cast<uint16_t*>(to), from);
}
template<> EIGEN_STRONG_INLINE Packet4bf ploaddup<Packet4bf>(const bfloat16* from)
{
return ploaddup<Packet4us>(reinterpret_cast<const uint16_t*>(from));
}
template <> EIGEN_STRONG_INLINE Packet4bf pabs(const Packet4bf& a) {
return F32ToBf16(pabs<Packet4f>(Bf16ToF32(a)));
}
template <> EIGEN_STRONG_INLINE Packet4bf pmin<Packet4bf>(const Packet4bf &a,
const Packet4bf &b)
{
return F32ToBf16(pmin<Packet4f>(Bf16ToF32(a), Bf16ToF32(b)));
}
template <> EIGEN_STRONG_INLINE Packet4bf pmax<Packet4bf>(const Packet4bf &a,
const Packet4bf &b)
{
return F32ToBf16(pmax<Packet4f>(Bf16ToF32(a), Bf16ToF32(b)));
}
template<> EIGEN_STRONG_INLINE Packet4bf por(const Packet4bf& a,const Packet4bf& b) {
return por<Packet4us>(a, b);
}
template<> EIGEN_STRONG_INLINE Packet4bf pxor(const Packet4bf& a,const Packet4bf& b) {
return pxor<Packet4us>(a, b);
}
template<> EIGEN_STRONG_INLINE Packet4bf pand(const Packet4bf& a,const Packet4bf& b) {
return pand<Packet4us>(a, b);
}
template<> EIGEN_STRONG_INLINE Packet4bf pandnot(const Packet4bf& a,const Packet4bf& b) {
return pandnot<Packet4us>(a, b);
}
template<> EIGEN_DEVICE_FUNC inline Packet4bf pselect(const Packet4bf& mask, const Packet4bf& a,
const Packet4bf& b)
{
return pselect<Packet4us>(mask, a, b);
}
template<> EIGEN_STRONG_INLINE Packet4bf pfloor<Packet4bf>(const Packet4bf& a)
{
return F32ToBf16(pfloor<Packet4f>(Bf16ToF32(a)));
}
template<> EIGEN_STRONG_INLINE Packet4bf pconj(const Packet4bf& a) { return a; }
template<> EIGEN_STRONG_INLINE Packet4bf padd<Packet4bf>(const Packet4bf& a, const Packet4bf& b) {
return F32ToBf16(padd<Packet4f>(Bf16ToF32(a), Bf16ToF32(b)));
}
template<> EIGEN_STRONG_INLINE Packet4bf psub<Packet4bf>(const Packet4bf& a, const Packet4bf& b) {
return F32ToBf16(psub<Packet4f>(Bf16ToF32(a), Bf16ToF32(b)));
}
template<> EIGEN_STRONG_INLINE Packet4bf pmul<Packet4bf>(const Packet4bf& a, const Packet4bf& b) {
return F32ToBf16(pmul<Packet4f>(Bf16ToF32(a), Bf16ToF32(b)));
}
template<> EIGEN_STRONG_INLINE Packet4bf pdiv<Packet4bf>(const Packet4bf& a, const Packet4bf& b) {
return F32ToBf16(pdiv<Packet4f>(Bf16ToF32(a), Bf16ToF32(b)));
}
template<>
EIGEN_STRONG_INLINE Packet4bf pgather<bfloat16, Packet4bf>(const bfloat16* from, Index stride)
{
return pgather<uint16_t, Packet4us>(reinterpret_cast<const uint16_t*>(from), stride);
}
template<>
EIGEN_STRONG_INLINE void pscatter<bfloat16, Packet4bf>(bfloat16* to, const Packet4bf& from, Index stride)
{
pscatter<uint16_t, Packet4us>(reinterpret_cast<uint16_t*>(to), from, stride);
}
template<> EIGEN_STRONG_INLINE bfloat16 predux<Packet4bf>(const Packet4bf& a)
{
return static_cast<bfloat16>(predux<Packet4f>(Bf16ToF32(a)));
}
template<> EIGEN_STRONG_INLINE bfloat16 predux_max<Packet4bf>(const Packet4bf& a)
{
return static_cast<bfloat16>(predux_max<Packet4f>(Bf16ToF32(a)));
}
template<> EIGEN_STRONG_INLINE bfloat16 predux_min<Packet4bf>(const Packet4bf& a)
{
return static_cast<bfloat16>(predux_min<Packet4f>(Bf16ToF32(a)));
}
template<> EIGEN_STRONG_INLINE bfloat16 predux_mul<Packet4bf>(const Packet4bf& a)
{
return static_cast<bfloat16>(predux_mul<Packet4f>(Bf16ToF32(a)));
}
template<> EIGEN_STRONG_INLINE Packet4bf preverse<Packet4bf>(const Packet4bf& a)
{
return preverse<Packet4us>(a);
}
EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet4bf, 4>& kernel)
{
PacketBlock<Packet4us, 4> k;
k.packet[0] = kernel.packet[0];
k.packet[1] = kernel.packet[1];
k.packet[2] = kernel.packet[2];
k.packet[3] = kernel.packet[3];
ptranspose(k);
kernel.packet[0] = k.packet[0];
kernel.packet[1] = k.packet[1];
kernel.packet[2] = k.packet[2];
kernel.packet[3] = k.packet[3];
}
template<> EIGEN_STRONG_INLINE Packet4bf pabsdiff<Packet4bf>(const Packet4bf& a, const Packet4bf& b)
{
return F32ToBf16(pabsdiff<Packet4f>(Bf16ToF32(a), Bf16ToF32(b)));
}
template<> EIGEN_STRONG_INLINE Packet4bf pcmp_eq<Packet4bf>(const Packet4bf& a, const Packet4bf& b)
{
return F32ToBf16(pcmp_eq<Packet4f>(Bf16ToF32(a), Bf16ToF32(b)));
}
template<> EIGEN_STRONG_INLINE Packet4bf pcmp_lt<Packet4bf>(const Packet4bf& a, const Packet4bf& b)
{
return F32ToBf16(pcmp_lt<Packet4f>(Bf16ToF32(a), Bf16ToF32(b)));
}
template<> EIGEN_STRONG_INLINE Packet4bf pcmp_le<Packet4bf>(const Packet4bf& a, const Packet4bf& b)
{
return F32ToBf16(pcmp_le<Packet4f>(Bf16ToF32(a), Bf16ToF32(b)));
}
template<> EIGEN_STRONG_INLINE Packet4bf pnegate<Packet4bf>(const Packet4bf& a)
{
return pxor<Packet4us>(a, pset1<Packet4us>(static_cast<uint16_t>(0x8000)));
}
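
pnegate works directly on the bit pattern: xor-ing with 0x8000 toggles bit 15, the bfloat16 sign bit, so no round trip through float is needed. A minimal scalar view of the same trick (illustrative, not part of this commit):

// Flipping bit 15 negates a bfloat16 stored as its raw 16-bit pattern.
inline uint16_t bf16_negate_bits(uint16_t b) { return static_cast<uint16_t>(b ^ 0x8000u); }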
//---------- double ----------
// Clang 3.5 in the iOS toolchain has an ICE triggered by NEON intrinsics for double.