mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-01-06 14:14:46 +08:00
bfloat16 packetmath for Arm Neon backend
This commit is contained in:
parent
704798d1df
commit
8ba1b0f41a
@ -38,6 +38,12 @@ template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Pack
|
||||
template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet4f ptanh<Packet4f>(const Packet4f& x)
|
||||
{ return internal::generic_fast_tanh_float(x); }
|
||||
|
||||
BF16_PACKET_FUNCTION(Packet4f, Packet4bf, psin)
|
||||
BF16_PACKET_FUNCTION(Packet4f, Packet4bf, pcos)
|
||||
BF16_PACKET_FUNCTION(Packet4f, Packet4bf, plog)
|
||||
BF16_PACKET_FUNCTION(Packet4f, Packet4bf, pexp)
|
||||
BF16_PACKET_FUNCTION(Packet4f, Packet4bf, ptanh)
|
||||
|
||||
} // end namespace internal
|
||||
|
||||
} // end namespace Eigen
|
||||
|
@ -3218,6 +3218,269 @@ template<> EIGEN_STRONG_INLINE Packet4ui psqrt(const Packet4ui& a) {
|
||||
return res;
|
||||
}
|
||||
|
||||
//---------- bfloat16 ----------
|
||||
// TODO: Add support for native armv8.6-a bfloat16_t
|
||||
|
||||
// TODO: Guard if we have native bfloat16 support
|
||||
typedef eigen_packet_wrapper<uint16x4_t, 19> Packet4bf;
|
||||
|
||||
template<> struct is_arithmetic<Packet4bf> { enum { value = true }; };
|
||||
|
||||
template<> struct packet_traits<bfloat16> : default_packet_traits
|
||||
{
|
||||
typedef Packet4bf type;
|
||||
typedef Packet4bf half;
|
||||
enum
|
||||
{
|
||||
Vectorizable = 1,
|
||||
AlignedOnScalar = 1,
|
||||
size = 4,
|
||||
HasHalfPacket = 0,
|
||||
|
||||
HasCmp = 1,
|
||||
HasAdd = 1,
|
||||
HasSub = 1,
|
||||
HasShift = 1,
|
||||
HasMul = 1,
|
||||
HasNegate = 1,
|
||||
HasAbs = 1,
|
||||
HasArg = 0,
|
||||
HasAbs2 = 1,
|
||||
HasAbsDiff = 1,
|
||||
HasMin = 1,
|
||||
HasMax = 1,
|
||||
HasConj = 1,
|
||||
HasSetLinear = 0,
|
||||
HasBlend = 0,
|
||||
HasDiv = 1,
|
||||
HasFloor = 1,
|
||||
|
||||
HasSin = EIGEN_FAST_MATH,
|
||||
HasCos = EIGEN_FAST_MATH,
|
||||
HasLog = 1,
|
||||
HasExp = 1,
|
||||
HasSqrt = 0,
|
||||
HasTanh = EIGEN_FAST_MATH,
|
||||
HasErf = EIGEN_FAST_MATH
|
||||
};
|
||||
};
|
||||
|
||||
template<> struct unpacket_traits<Packet4bf>
|
||||
{
|
||||
typedef bfloat16 type;
|
||||
typedef Packet4bf half;
|
||||
enum
|
||||
{
|
||||
size = 4,
|
||||
alignment = Aligned16,
|
||||
vectorizable = true,
|
||||
masked_load_available = false,
|
||||
masked_store_available = false
|
||||
};
|
||||
};
|
||||
|
||||
EIGEN_STRONG_INLINE Packet4bf F32ToBf16(const Packet4f& p)
|
||||
{
|
||||
// See the scalar implemention in BFloat16.h for a comprehensible explanation
|
||||
// of this fast rounding algorithm
|
||||
Packet4ui input = reinterpret_cast<Packet4ui>(p);
|
||||
|
||||
// lsb = (input >> 16) & 1
|
||||
Packet4ui lsb = vandq_u32(vshrq_n_u32(input, 16), vdupq_n_u32(1));
|
||||
|
||||
// rounding_bias = 0x7fff + lsb
|
||||
Packet4ui rounding_bias = vaddq_u32(lsb, vdupq_n_u32(0x7fff));
|
||||
|
||||
// input += rounding_bias
|
||||
input = vaddq_u32(input, rounding_bias);
|
||||
|
||||
// input = input >> 16
|
||||
input = vshrq_n_u32(input, 16);
|
||||
|
||||
// Replace float-nans by bfloat16-nans, that is 0x7fc0
|
||||
const Packet4ui bf16_nan = vdupq_n_u32(0x7fc0);
|
||||
const Packet4ui mask = vceqq_f32(p, p);
|
||||
input = vbslq_u32(mask, input, bf16_nan);
|
||||
|
||||
// output = static_cast<uint16_t>(input)
|
||||
return vmovn_u32(input);
|
||||
}
|
||||
|
||||
EIGEN_STRONG_INLINE Packet4f Bf16ToF32(const Packet4bf& p)
|
||||
{
|
||||
return reinterpret_cast<Packet4f>(vshlq_n_u32(vmovl_u16(p), 16));
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet4bf pset1<Packet4bf>(const bfloat16& from) {
|
||||
return pset1<Packet4us>(from.value);
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE bfloat16 pfirst<Packet4bf>(const Packet4bf& from) {
|
||||
return bfloat16_impl::raw_uint16_to_bfloat16(static_cast<uint16_t>(pfirst<Packet4us>(from)));
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet4bf pload<Packet4bf>(const bfloat16* from)
|
||||
{
|
||||
return pload<Packet4us>(reinterpret_cast<const uint16_t*>(from));
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet4bf ploadu<Packet4bf>(const bfloat16* from)
|
||||
{
|
||||
return ploadu<Packet4us>(reinterpret_cast<const uint16_t*>(from));
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE void pstore<bfloat16>(bfloat16* to, const Packet4bf& from)
|
||||
{
|
||||
EIGEN_DEBUG_ALIGNED_STORE vst1_u16(reinterpret_cast<uint16_t*>(to), from);
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE void pstoreu<bfloat16>(bfloat16* to, const Packet4bf& from)
|
||||
{
|
||||
EIGEN_DEBUG_UNALIGNED_STORE vst1_u16(reinterpret_cast<uint16_t*>(to), from);
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet4bf ploaddup<Packet4bf>(const bfloat16* from)
|
||||
{
|
||||
return ploaddup<Packet4us>(reinterpret_cast<const uint16_t*>(from));
|
||||
}
|
||||
|
||||
template <> EIGEN_STRONG_INLINE Packet4bf pabs(const Packet4bf& a) {
|
||||
return F32ToBf16(pabs<Packet4f>(Bf16ToF32(a)));
|
||||
}
|
||||
|
||||
template <> EIGEN_STRONG_INLINE Packet4bf pmin<Packet4bf>(const Packet4bf &a,
|
||||
const Packet4bf &b)
|
||||
{
|
||||
return F32ToBf16(pmin<Packet4f>(Bf16ToF32(a), Bf16ToF32(b)));
|
||||
}
|
||||
|
||||
template <> EIGEN_STRONG_INLINE Packet4bf pmax<Packet4bf>(const Packet4bf &a,
|
||||
const Packet4bf &b)
|
||||
{
|
||||
return F32ToBf16(pmax<Packet4f>(Bf16ToF32(a), Bf16ToF32(b)));
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet4bf por(const Packet4bf& a,const Packet4bf& b) {
|
||||
return por<Packet4us>(a, b);
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet4bf pxor(const Packet4bf& a,const Packet4bf& b) {
|
||||
return pxor<Packet4us>(a, b);
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet4bf pand(const Packet4bf& a,const Packet4bf& b) {
|
||||
return pand<Packet4us>(a, b);
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet4bf pandnot(const Packet4bf& a,const Packet4bf& b) {
|
||||
return pandnot<Packet4us>(a, b);
|
||||
}
|
||||
|
||||
template<> EIGEN_DEVICE_FUNC inline Packet4bf pselect(const Packet4bf& mask, const Packet4bf& a,
|
||||
const Packet4bf& b)
|
||||
{
|
||||
return pselect<Packet4us>(mask, a, b);
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet4bf pfloor<Packet4bf>(const Packet4bf& a)
|
||||
{
|
||||
return F32ToBf16(pfloor<Packet4f>(Bf16ToF32(a)));
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet4bf pconj(const Packet4bf& a) { return a; }
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet4bf padd<Packet4bf>(const Packet4bf& a, const Packet4bf& b) {
|
||||
return F32ToBf16(padd<Packet4f>(Bf16ToF32(a), Bf16ToF32(b)));
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet4bf psub<Packet4bf>(const Packet4bf& a, const Packet4bf& b) {
|
||||
return F32ToBf16(psub<Packet4f>(Bf16ToF32(a), Bf16ToF32(b)));
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet4bf pmul<Packet4bf>(const Packet4bf& a, const Packet4bf& b) {
|
||||
return F32ToBf16(pmul<Packet4f>(Bf16ToF32(a), Bf16ToF32(b)));
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet4bf pdiv<Packet4bf>(const Packet4bf& a, const Packet4bf& b) {
|
||||
return F32ToBf16(pdiv<Packet4f>(Bf16ToF32(a), Bf16ToF32(b)));
|
||||
}
|
||||
|
||||
template<>
|
||||
EIGEN_STRONG_INLINE Packet4bf pgather<bfloat16, Packet4bf>(const bfloat16* from, Index stride)
|
||||
{
|
||||
return pgather<uint16_t, Packet4us>(reinterpret_cast<const uint16_t*>(from), stride);
|
||||
}
|
||||
|
||||
template<>
|
||||
EIGEN_STRONG_INLINE void pscatter<bfloat16, Packet4bf>(bfloat16* to, const Packet4bf& from, Index stride)
|
||||
{
|
||||
pscatter<uint16_t, Packet4us>(reinterpret_cast<uint16_t*>(to), from, stride);
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE bfloat16 predux<Packet4bf>(const Packet4bf& a)
|
||||
{
|
||||
return static_cast<bfloat16>(predux<Packet4f>(Bf16ToF32(a)));
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE bfloat16 predux_max<Packet4bf>(const Packet4bf& a)
|
||||
{
|
||||
return static_cast<bfloat16>(predux_max<Packet4f>(Bf16ToF32(a)));
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE bfloat16 predux_min<Packet4bf>(const Packet4bf& a)
|
||||
{
|
||||
return static_cast<bfloat16>(predux_min<Packet4f>(Bf16ToF32(a)));
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE bfloat16 predux_mul<Packet4bf>(const Packet4bf& a)
|
||||
{
|
||||
return static_cast<bfloat16>(predux_mul<Packet4f>(Bf16ToF32(a)));
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet4bf preverse<Packet4bf>(const Packet4bf& a)
|
||||
{
|
||||
return preverse<Packet4us>(a);
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet4bf, 4>& kernel)
|
||||
{
|
||||
PacketBlock<Packet4us, 4> k;
|
||||
k.packet[0] = kernel.packet[0];
|
||||
k.packet[1] = kernel.packet[1];
|
||||
k.packet[2] = kernel.packet[2];
|
||||
k.packet[3] = kernel.packet[3];
|
||||
ptranspose(k);
|
||||
kernel.packet[0] = k.packet[0];
|
||||
kernel.packet[1] = k.packet[1];
|
||||
kernel.packet[2] = k.packet[2];
|
||||
kernel.packet[3] = k.packet[3];
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet4bf pabsdiff<Packet4bf>(const Packet4bf& a, const Packet4bf& b)
|
||||
{
|
||||
return F32ToBf16(pabsdiff<Packet4f>(Bf16ToF32(a), Bf16ToF32(b)));
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet4bf pcmp_eq<Packet4bf>(const Packet4bf& a, const Packet4bf& b)
|
||||
{
|
||||
return F32ToBf16(pcmp_eq<Packet4f>(Bf16ToF32(a), Bf16ToF32(b)));
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet4bf pcmp_lt<Packet4bf>(const Packet4bf& a, const Packet4bf& b)
|
||||
{
|
||||
return F32ToBf16(pcmp_lt<Packet4f>(Bf16ToF32(a), Bf16ToF32(b)));
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet4bf pcmp_le<Packet4bf>(const Packet4bf& a, const Packet4bf& b)
|
||||
{
|
||||
return F32ToBf16(pcmp_le<Packet4f>(Bf16ToF32(a), Bf16ToF32(b)));
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet4bf pnegate<Packet4bf>(const Packet4bf& a)
|
||||
{
|
||||
return pxor<Packet4us>(a, pset1<Packet4us>(static_cast<uint16_t>(0x8000)));
|
||||
}
|
||||
|
||||
//---------- double ----------
|
||||
|
||||
// Clang 3.5 in the iOS toolchain has an ICE triggered by NEON intrisics for double.
|
||||
|
Loading…
Reference in New Issue
Block a user