From 704798d1df4866be335ca013da19a44791f85a7e Mon Sep 17 00:00:00 2001 From: Pedro Caldeira Date: Wed, 24 Jun 2020 15:27:26 -0500 Subject: [PATCH] Add support for Bfloat16 to use vector instructions on Altivec architecture --- Eigen/src/Core/arch/AltiVec/PacketMath.h | 431 ++++++++++++++++++++++- test/packetmath.cpp | 31 ++ 2 files changed, 457 insertions(+), 5 deletions(-) diff --git a/Eigen/src/Core/arch/AltiVec/PacketMath.h b/Eigen/src/Core/arch/AltiVec/PacketMath.h index e0da7c329..09ad0a74e 100755 --- a/Eigen/src/Core/arch/AltiVec/PacketMath.h +++ b/Eigen/src/Core/arch/AltiVec/PacketMath.h @@ -39,10 +39,10 @@ typedef __vector short int Packet8s; typedef __vector unsigned short int Packet8us; typedef __vector int8_t Packet16c; typedef __vector uint8_t Packet16uc; +typedef eigen_packet_wrapper<__vector unsigned short int,0> Packet8bf; // We don't want to write the same code all the time, but we need to reuse the constants // and it doesn't really work to declare them global, so we define macros instead - #define _EIGEN_DECLARE_CONST_FAST_Packet4f(NAME,X) \ Packet4f p4f_##NAME = {X, X, X, X} @@ -96,6 +96,7 @@ static Packet4f p4f_COUNTDOWN = { 0.0, 1.0, 2.0, 3.0 }; static Packet4i p4i_COUNTDOWN = { 0, 1, 2, 3 }; static Packet8s p8s_COUNTDOWN = { 0, 1, 2, 3, 4, 5, 6, 7 }; static Packet8us p8us_COUNTDOWN = { 0, 1, 2, 3, 4, 5, 6, 7 }; + static Packet16c p16c_COUNTDOWN = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; static Packet16uc p16uc_COUNTDOWN = { 0, 1, 2, 3, 4, 5, 6, 7, @@ -108,6 +109,8 @@ static Packet16uc p16uc_REVERSE8 = { 15,14,13,12,11,10,9,8,7,6,5,4,3,2,1,0 }; static Packet16uc p16uc_DUPLICATE32_HI = { 0,1,2,3, 0,1,2,3, 4,5,6,7, 4,5,6,7 }; static Packet16uc p16uc_DUPLICATE16_HI = { 0,1,0,1, 2,3,2,3, 4,5,4,5, 6,7,6,7 }; static Packet16uc p16uc_DUPLICATE8_HI = { 0,0, 1,1, 2,2, 3,3, 4,4, 5,5, 6,6, 7,7 }; +static const Packet16uc p16uc_DUPLICATE16_EVEN= { 0,1 ,0,1, 4,5, 4,5, 8,9, 8,9, 12,13, 12,13 }; +static const Packet16uc p16uc_DUPLICATE16_ODD = { 2,3 ,2,3, 6,7, 6,7, 10,11, 10,11, 14,15, 14,15 }; static Packet16uc p16uc_QUADRUPLICATE16_HI = { 0,1,0,1,0,1,0,1, 2,3,2,3,2,3,2,3 }; @@ -189,6 +192,48 @@ struct packet_traits : default_packet_traits { HasBlend = 1 }; }; +template <> +struct packet_traits : default_packet_traits { + typedef Packet8bf type; + typedef Packet8bf half; + enum { + Vectorizable = 1, + AlignedOnScalar = 1, + size = 8, + HasHalfPacket = 0, + + HasAdd = 1, + HasSub = 1, + HasMul = 1, + HasDiv = 1, + HasMin = 1, + HasMax = 1, + HasAbs = 1, + HasSin = EIGEN_FAST_MATH, + HasCos = EIGEN_FAST_MATH, + HasLog = 1, + HasExp = 1, +#ifdef __VSX__ + HasSqrt = 1, +#if !EIGEN_COMP_CLANG + HasRsqrt = 1, +#else + HasRsqrt = 0, +#endif +#else + HasSqrt = 0, + HasRsqrt = 0, + HasTanh = EIGEN_FAST_MATH, + HasErf = EIGEN_FAST_MATH, +#endif + HasRound = 1, + HasFloor = 1, + HasCeil = 1, + HasNegate = 1, + HasBlend = 1 + }; +}; + template <> struct packet_traits : default_packet_traits { typedef Packet4i type; @@ -319,6 +364,12 @@ template<> struct unpacket_traits enum {size=16, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false}; }; +template<> struct unpacket_traits +{ + typedef bfloat16 type; + typedef Packet8bf half; + enum {size=8, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false}; +}; inline std::ostream & operator <<(std::ostream & s, const Packet16c & v) { union { @@ -421,6 +472,11 @@ template<> EIGEN_STRONG_INLINE Packet16uc pload(const uint8_t* f return 
pload_common(from); } +template<> EIGEN_STRONG_INLINE Packet8bf pload(const bfloat16* from) +{ + return pload_common(reinterpret_cast(from)); +} + template EIGEN_STRONG_INLINE void pstore_common(__UNPACK_TYPE__(Packet)* to, const Packet& from){ // some versions of GCC throw "unused-but-set-parameter" (float *to). @@ -431,7 +487,7 @@ EIGEN_STRONG_INLINE void pstore_common(__UNPACK_TYPE__(Packet)* to, const Packet vec_xst(from, 0, to); #else vec_st(from, 0, to); -#endif +#endif } template<> EIGEN_STRONG_INLINE void pstore(float* to, const Packet4f& from) @@ -454,6 +510,11 @@ template<> EIGEN_STRONG_INLINE void pstore(unsigned short in pstore_common(to, from); } +template<> EIGEN_STRONG_INLINE void pstore(bfloat16* to, const Packet8bf& from) +{ + pstore_common(reinterpret_cast(to), from); +} + template<> EIGEN_STRONG_INLINE void pstore(int8_t* to, const Packet16c& from) { pstore_common(to, from); @@ -513,6 +574,10 @@ template<> EIGEN_STRONG_INLINE Packet4f pset1frombits(unsigned int fro return reinterpret_cast(pset1(from)); } +template<> EIGEN_STRONG_INLINE Packet8bf pset1(const bfloat16& from) { + return pset1_size8(reinterpret_cast(from)); +} + template EIGEN_STRONG_INLINE void pbroadcast4_common(const __UNPACK_TYPE__(Packet) *a, Packet& a0, Packet& a1, Packet& a2, Packet& a3) @@ -700,6 +765,7 @@ template<> EIGEN_STRONG_INLINE Packet16uc plset(const uint8_t& a) template<> EIGEN_STRONG_INLINE Packet4f padd (const Packet4f& a, const Packet4f& b) { return a + b; } template<> EIGEN_STRONG_INLINE Packet4i padd (const Packet4i& a, const Packet4i& b) { return a + b; } +template<> EIGEN_STRONG_INLINE Packet4ui padd (const Packet4ui& a, const Packet4ui& b) { return a + b; } template<> EIGEN_STRONG_INLINE Packet8s padd (const Packet8s& a, const Packet8s& b) { return a + b; } template<> EIGEN_STRONG_INLINE Packet8us padd (const Packet8us& a, const Packet8us& b) { return a + b; } template<> EIGEN_STRONG_INLINE Packet16c padd (const Packet16c& a, const Packet16c& b) { return a + b; } @@ -721,6 +787,7 @@ template<> EIGEN_STRONG_INLINE Packet4i pmul (const Packet4i& a, template<> EIGEN_STRONG_INLINE Packet16c pmul (const Packet16c& a, const Packet16c& b) { return vec_mul(a,b); } template<> EIGEN_STRONG_INLINE Packet16uc pmul(const Packet16uc& a, const Packet16uc& b) { return vec_mul(a,b); } + template<> EIGEN_STRONG_INLINE Packet4f pdiv(const Packet4f& a, const Packet4f& b) { #ifndef __VSX__ // VSX actually provides a div instruction @@ -765,6 +832,7 @@ template<> EIGEN_STRONG_INLINE Packet8us pmin(const Packet8us& a, con template<> EIGEN_STRONG_INLINE Packet16c pmin(const Packet16c& a, const Packet16c& b) { return vec_min(a, b); } template<> EIGEN_STRONG_INLINE Packet16uc pmin(const Packet16uc& a, const Packet16uc& b) { return vec_min(a, b); } + template<> EIGEN_STRONG_INLINE Packet4f pmax(const Packet4f& a, const Packet4f& b) { #ifdef __VSX__ @@ -785,6 +853,7 @@ template<> EIGEN_STRONG_INLINE Packet16uc pmax(const Packet16uc& a, template<> EIGEN_STRONG_INLINE Packet4f pcmp_le(const Packet4f& a, const Packet4f& b) { return reinterpret_cast(vec_cmple(a,b)); } template<> EIGEN_STRONG_INLINE Packet4f pcmp_lt(const Packet4f& a, const Packet4f& b) { return reinterpret_cast(vec_cmplt(a,b)); } template<> EIGEN_STRONG_INLINE Packet4f pcmp_eq(const Packet4f& a, const Packet4f& b) { return reinterpret_cast(vec_cmpeq(a,b)); } + template<> EIGEN_STRONG_INLINE Packet4f pcmp_lt_or_nan(const Packet4f& a, const Packet4f& b) { Packet4f c = reinterpret_cast(vec_cmpge(a,b)); return vec_nor(c,c); @@ -793,12 +862,26 @@ 
template<> EIGEN_STRONG_INLINE Packet4i pcmp_eq(const Packet4i& a, const Packet4 template<> EIGEN_STRONG_INLINE Packet4f pand(const Packet4f& a, const Packet4f& b) { return vec_and(a, b); } template<> EIGEN_STRONG_INLINE Packet4i pand(const Packet4i& a, const Packet4i& b) { return vec_and(a, b); } +template<> EIGEN_STRONG_INLINE Packet4ui pand(const Packet4ui& a, const Packet4ui& b) { return vec_and(a, b); } +template<> EIGEN_STRONG_INLINE Packet8us pand(const Packet8us& a, const Packet8us& b) { return vec_and(a, b); } +template<> EIGEN_STRONG_INLINE Packet8bf pand(const Packet8bf& a, const Packet8bf& b) { + return pand(a, b); +} + template<> EIGEN_STRONG_INLINE Packet4f por(const Packet4f& a, const Packet4f& b) { return vec_or(a, b); } template<> EIGEN_STRONG_INLINE Packet4i por(const Packet4i& a, const Packet4i& b) { return vec_or(a, b); } +template<> EIGEN_STRONG_INLINE Packet8s por(const Packet8s& a, const Packet8s& b) { return vec_or(a, b); } +template<> EIGEN_STRONG_INLINE Packet8us por(const Packet8us& a, const Packet8us& b) { return vec_or(a, b); } +template<> EIGEN_STRONG_INLINE Packet8bf por(const Packet8bf& a, const Packet8bf& b) { + return por(a, b); +} template<> EIGEN_STRONG_INLINE Packet4f pxor(const Packet4f& a, const Packet4f& b) { return vec_xor(a, b); } template<> EIGEN_STRONG_INLINE Packet4i pxor(const Packet4i& a, const Packet4i& b) { return vec_xor(a, b); } +template<> EIGEN_STRONG_INLINE Packet8bf pxor(const Packet8bf& a, const Packet8bf& b) { + return pxor(a, b); +} template<> EIGEN_STRONG_INLINE Packet4f pandnot(const Packet4f& a, const Packet4f& b) { return vec_and(a, vec_nor(b, b)); } template<> EIGEN_STRONG_INLINE Packet4i pandnot(const Packet4i& a, const Packet4i& b) { return vec_and(a, vec_nor(b, b)); } @@ -806,6 +889,7 @@ template<> EIGEN_STRONG_INLINE Packet4i pandnot(const Packet4i& a, con template<> EIGEN_STRONG_INLINE Packet4f pselect(const Packet4f& mask, const Packet4f& a, const Packet4f& b) { return vec_sel(b, a, reinterpret_cast(mask)); } + template<> EIGEN_STRONG_INLINE Packet4f pround(const Packet4f& a) { Packet4f t = vec_add(reinterpret_cast(vec_or(vec_and(reinterpret_cast(a), p4ui_SIGN), p4ui_PREV0DOT5)), a); Packet4f res; @@ -852,6 +936,10 @@ template<> EIGEN_STRONG_INLINE Packet8us ploadu(const unsigned short { return ploadu_common(from); } +template<> EIGEN_STRONG_INLINE Packet8bf ploadu(const bfloat16* from) +{ + return ploadu_common(reinterpret_cast(from)); +} template<> EIGEN_STRONG_INLINE Packet16c ploadu(const int8_t* from) { return ploadu_common(from); @@ -909,6 +997,11 @@ template<> EIGEN_STRONG_INLINE Packet8us ploadquad(const unsigned sho return vec_perm(p, p, p16uc_QUADRUPLICATE16_HI); } +template<> EIGEN_STRONG_INLINE Packet8bf ploadquad(const bfloat16* from) +{ + return ploadquad(reinterpret_cast(from)); +} + template<> EIGEN_STRONG_INLINE Packet16c ploaddup(const int8_t* from) { Packet16c p; @@ -962,6 +1055,10 @@ template<> EIGEN_STRONG_INLINE void pstoreu(unsigned short i { pstoreu_common(to, from); } +template<> EIGEN_STRONG_INLINE void pstoreu(bfloat16* to, const Packet8bf& from) +{ + pstoreu_common(reinterpret_cast(to), from); +} template<> EIGEN_STRONG_INLINE void pstoreu(int8_t* to, const Packet16c& from) { pstoreu_common(to, from); @@ -977,17 +1074,17 @@ template<> EIGEN_STRONG_INLINE void prefetch(const int* addr) { EIGE template<> EIGEN_STRONG_INLINE float pfirst(const Packet4f& a) { EIGEN_ALIGN16 float x; vec_ste(a, 0, &x); return x; } template<> EIGEN_STRONG_INLINE int pfirst(const Packet4i& a) { EIGEN_ALIGN16 int x; 
vec_ste(a, 0, &x); return x; } -template EIGEN_STRONG_INLINE __UNPACK_TYPE__(Packet) pfirst_common(const Packet& a) { +template EIGEN_STRONG_INLINE __UNPACK_TYPE__(Packet) pfirst_common(const Packet& a) { EIGEN_ALIGN16 __UNPACK_TYPE__(Packet) x; vec_ste(a, 0, &x); return x; } -template<> EIGEN_STRONG_INLINE short int pfirst(const Packet8s& a) { +template<> EIGEN_STRONG_INLINE short int pfirst(const Packet8s& a) { return pfirst_common(a); } -template<> EIGEN_STRONG_INLINE unsigned short int pfirst(const Packet8us& a) { +template<> EIGEN_STRONG_INLINE unsigned short int pfirst(const Packet8us& a) { return pfirst_common(a); } @@ -1025,6 +1122,10 @@ template<> EIGEN_STRONG_INLINE Packet16uc preverse(const Packet16uc& a) { return vec_perm(a, a, p16uc_REVERSE8); } +template<> EIGEN_STRONG_INLINE Packet8bf preverse(const Packet8bf& a) +{ + return preverse(a); +} template<> EIGEN_STRONG_INLINE Packet4f pabs(const Packet4f& a) { return vec_abs(a); } template<> EIGEN_STRONG_INLINE Packet4i pabs(const Packet4i& a) { return vec_abs(a); } @@ -1032,6 +1133,10 @@ template<> EIGEN_STRONG_INLINE Packet8s pabs(const Packet8s& a) { return vec_abs template<> EIGEN_STRONG_INLINE Packet8us pabs(const Packet8us& a) { return a; } template<> EIGEN_STRONG_INLINE Packet16c pabs(const Packet16c& a) { return vec_abs(a); } template<> EIGEN_STRONG_INLINE Packet16uc pabs(const Packet16uc& a) { return a; } +template<> EIGEN_STRONG_INLINE Packet8bf pabs(const Packet8bf& a) { + _EIGEN_DECLARE_CONST_FAST_Packet8us(abs_mask,0x7FFF); + return pand(p8us_abs_mask, a); +} template EIGEN_STRONG_INLINE Packet4i parithmetic_shift_right(Packet4i a) { return vec_sra(a,reinterpret_cast(pset1(N))); } @@ -1039,6 +1144,175 @@ template EIGEN_STRONG_INLINE Packet4i plogical_shift_right(Packet4i a) { return vec_sr(a,reinterpret_cast(pset1(N))); } template EIGEN_STRONG_INLINE Packet4i plogical_shift_left(Packet4i a) { return vec_sl(a,reinterpret_cast(pset1(N))); } +template EIGEN_STRONG_INLINE Packet4f plogical_shift_left(Packet4f a) +{ + const _EIGEN_DECLARE_CONST_FAST_Packet4ui(mask, N); + Packet4ui r = vec_sl(reinterpret_cast(a), p4ui_mask); + return reinterpret_cast(r); +} + +template EIGEN_STRONG_INLINE Packet4f plogical_shift_right(Packet4f a) +{ + const _EIGEN_DECLARE_CONST_FAST_Packet4ui(mask, N); + Packet4ui r = vec_sr(reinterpret_cast(a), p4ui_mask); + return reinterpret_cast(r); +} + +template EIGEN_STRONG_INLINE Packet4ui plogical_shift_right(Packet4ui a) +{ + const _EIGEN_DECLARE_CONST_FAST_Packet4ui(mask, N); + return vec_sr(a, p4ui_mask); +} + +template EIGEN_STRONG_INLINE Packet4ui plogical_shift_left(Packet4ui a) +{ + const _EIGEN_DECLARE_CONST_FAST_Packet4ui(mask, N); + return vec_sl(a, p4ui_mask); +} + +template EIGEN_STRONG_INLINE Packet8us plogical_shift_left(Packet8us a) +{ + const _EIGEN_DECLARE_CONST_FAST_Packet8us(mask, N); + return vec_sl(a, p8us_mask); +} + +EIGEN_STRONG_INLINE Packet4f Bf16ToF32Even(const Packet8bf& bf){ + return plogical_shift_left<16>(reinterpret_cast(bf.m_val)); +} + +EIGEN_STRONG_INLINE Packet4f Bf16ToF32Odd(const Packet8bf& bf){ + const _EIGEN_DECLARE_CONST_FAST_Packet4ui(high_mask, 0xFFFF0000); + return pand( + reinterpret_cast(bf.m_val), + reinterpret_cast(p4ui_high_mask) + ); +} + +EIGEN_STRONG_INLINE Packet8bf F32ToBf16(Packet4f p4f){ + Packet4ui input = reinterpret_cast(p4f); + Packet4ui lsb = plogical_shift_right<16>(input); + lsb = pand(lsb, reinterpret_cast(p4i_ONE)); + + _EIGEN_DECLARE_CONST_FAST_Packet4ui(BIAS,0x7FFFu); + Packet4ui rounding_bias = padd(lsb, p4ui_BIAS); + input = 
padd(input, rounding_bias); + + //Test NaN and Subnormal - Begin + const _EIGEN_DECLARE_CONST_FAST_Packet4ui(exp_mask, 0x7F800000); + Packet4ui exp = pand(p4ui_exp_mask, reinterpret_cast(p4f)); + + const _EIGEN_DECLARE_CONST_FAST_Packet4ui(mantissa_mask, 0x7FFFFF); + Packet4ui mantissa = pand(p4ui_mantissa_mask, reinterpret_cast(p4f)); + + const _EIGEN_DECLARE_CONST_FAST_Packet4ui(max_exp, 0x7F800000); + Packet4bi is_max_exp = vec_cmpeq(exp, p4ui_max_exp); + Packet4bi is_zero_exp = vec_cmpeq(exp, reinterpret_cast(p4i_ZERO)); + + Packet4bi is_mant_not_zero = vec_cmpne(mantissa, reinterpret_cast(p4i_ZERO)); + Packet4ui nan_selector = pand( + reinterpret_cast(is_max_exp), + reinterpret_cast(is_mant_not_zero) + ); + + Packet4ui subnormal_selector = pand( + reinterpret_cast(is_zero_exp), + reinterpret_cast(is_mant_not_zero) + ); + + const _EIGEN_DECLARE_CONST_FAST_Packet4ui(nan, 0x7FC00000); + input = vec_sel(input, p4ui_nan, nan_selector); + input = vec_sel(input, reinterpret_cast(p4f), subnormal_selector); + //Test NaN and Subnormal - End + + input = plogical_shift_right<16>(input); + return reinterpret_cast(input); +} + +EIGEN_STRONG_INLINE Packet8bf F32ToBf16(Packet4f even, Packet4f odd){ + Packet4f bf_odd, bf_even; + bf_odd = reinterpret_cast(F32ToBf16(odd).m_val); + bf_odd = plogical_shift_left<16>(bf_odd); + bf_even = reinterpret_cast(F32ToBf16(even).m_val); + return reinterpret_cast(por(bf_even, bf_odd)); +} +#define BF16_TO_F32_UNARY_OP_WRAPPER(OP, A) \ + Packet4f a_even = Bf16ToF32Even(A);\ + Packet4f a_odd = Bf16ToF32Odd(A);\ + Packet4f op_even = OP(a_even);\ + Packet4f op_odd = OP(a_odd);\ + return F32ToBf16(op_even, op_odd);\ + +#define BF16_TO_F32_BINARY_OP_WRAPPER(OP, A, B) \ + Packet4f a_even = Bf16ToF32Even(A);\ + Packet4f a_odd = Bf16ToF32Odd(A);\ + Packet4f b_even = Bf16ToF32Even(B);\ + Packet4f b_odd = Bf16ToF32Odd(B);\ + Packet4f op_even = OP(a_even, b_even);\ + Packet4f op_odd = OP(a_odd, b_odd);\ + return F32ToBf16(op_even, op_odd);\ + +template<> EIGEN_STRONG_INLINE Packet8bf padd(const Packet8bf& a, const Packet8bf& b) { + BF16_TO_F32_BINARY_OP_WRAPPER(padd, a, b); +} + +template<> EIGEN_STRONG_INLINE Packet8bf pmul(const Packet8bf& a, const Packet8bf& b) { + BF16_TO_F32_BINARY_OP_WRAPPER(pmul, a, b); +} + +template<> EIGEN_STRONG_INLINE Packet8bf pdiv(const Packet8bf& a, const Packet8bf& b) { + BF16_TO_F32_BINARY_OP_WRAPPER(pdiv, a, b); +} + +template<> EIGEN_STRONG_INLINE Packet8bf pnegate(const Packet8bf& a) { + BF16_TO_F32_UNARY_OP_WRAPPER(pnegate, a); +} + +template<> EIGEN_STRONG_INLINE Packet8bf psub(const Packet8bf& a, const Packet8bf& b) { + BF16_TO_F32_BINARY_OP_WRAPPER(psub, a, b); +} + +template<> EIGEN_STRONG_INLINE Packet8bf psqrt (const Packet8bf& a){ + BF16_TO_F32_UNARY_OP_WRAPPER(vec_sqrt, a); +} + +template<> EIGEN_STRONG_INLINE Packet8bf pmadd(const Packet8bf& a, const Packet8bf& b, const Packet8bf& c) { + Packet4f a_even = Bf16ToF32Even(a); + Packet4f a_odd = Bf16ToF32Odd(a); + Packet4f b_even = Bf16ToF32Even(b); + Packet4f b_odd = Bf16ToF32Odd(b); + Packet4f c_even = Bf16ToF32Even(c); + Packet4f c_odd = Bf16ToF32Odd(c); + Packet4f pmadd_even = pmadd(a_even, b_even, c_even); + Packet4f pmadd_odd = pmadd(a_odd, b_odd, c_odd); + return F32ToBf16(pmadd_even, pmadd_odd); +} + +template<> EIGEN_STRONG_INLINE Packet8bf pmin(const Packet8bf& a, const Packet8bf& b) { + BF16_TO_F32_BINARY_OP_WRAPPER(pmin, a, b); +} + +template<> EIGEN_STRONG_INLINE Packet8bf pmax(const Packet8bf& a, const Packet8bf& b) { + BF16_TO_F32_BINARY_OP_WRAPPER(pmax, a, b); +} + 
+template<> EIGEN_STRONG_INLINE Packet8bf pcmp_lt(const Packet8bf& a, const Packet8bf& b) { + BF16_TO_F32_BINARY_OP_WRAPPER(pcmp_lt, a, b); +} +template<> EIGEN_STRONG_INLINE Packet8bf pcmp_le(const Packet8bf& a, const Packet8bf& b) { + BF16_TO_F32_BINARY_OP_WRAPPER(pcmp_le, a, b); +} +template<> EIGEN_STRONG_INLINE Packet8bf pcmp_eq(const Packet8bf& a, const Packet8bf& b) { + BF16_TO_F32_BINARY_OP_WRAPPER(pcmp_eq, a, b); +} + +template<> EIGEN_STRONG_INLINE bfloat16 pfirst(const Packet8bf& a) { + return Eigen::bfloat16_impl::raw_uint16_to_bfloat16((pfirst(a))); +} + +template<> EIGEN_STRONG_INLINE Packet8bf ploaddup(const bfloat16* from) +{ + return ploaddup(reinterpret_cast(from)); +} template<> EIGEN_STRONG_INLINE Packet4f pfrexp(const Packet4f& a, Packet4f& exponent) { return pfrexp_float(a,exponent); @@ -1070,6 +1344,13 @@ template<> EIGEN_STRONG_INLINE int predux(const Packet4i& a) return pfirst(sum); } +template<> EIGEN_STRONG_INLINE bfloat16 predux(const Packet8bf& a) +{ + float redux_even = predux(Bf16ToF32Even(a)); + float redux_odd = predux(Bf16ToF32Odd(a)); + float f32_result = redux_even + redux_odd; + return bfloat16(f32_result); +} template EIGEN_STRONG_INLINE __UNPACK_TYPE__(Packet) predux_size8(const Packet& a) { union{ @@ -1166,6 +1447,15 @@ template<> EIGEN_STRONG_INLINE unsigned short int predux_mul(const Pa return pfirst(octo); } +template<> EIGEN_STRONG_INLINE bfloat16 predux_mul(const Packet8bf& a) +{ + float redux_even = predux_mul(Bf16ToF32Even(a)); + float redux_odd = predux_mul(Bf16ToF32Odd(a)); + float f32_result = redux_even * redux_odd; + return bfloat16(f32_result); +} + + template<> EIGEN_STRONG_INLINE int8_t predux_mul(const Packet16c& a) { Packet16c pair, quad, octo, result; @@ -1211,6 +1501,14 @@ template<> EIGEN_STRONG_INLINE int predux_min(const Packet4i& a) return predux_min4(a); } +template<> EIGEN_STRONG_INLINE bfloat16 predux_min(const Packet8bf& a) +{ + float redux_even = predux_min(Bf16ToF32Even(a)); + float redux_odd = predux_min(Bf16ToF32Odd(a)); + float f32_result = (std::min)(redux_even, redux_odd); + return bfloat16(f32_result); +} + template<> EIGEN_STRONG_INLINE short int predux_min(const Packet8s& a) { Packet8s pair, quad, octo; @@ -1283,6 +1581,14 @@ template<> EIGEN_STRONG_INLINE int predux_max(const Packet4i& a) return predux_max4(a); } +template<> EIGEN_STRONG_INLINE bfloat16 predux_max(const Packet8bf& a) +{ + float redux_even = predux_max(Bf16ToF32Even(a)); + float redux_odd = predux_max(Bf16ToF32Odd(a)); + float f32_result = (std::max)(redux_even, redux_odd); + return bfloat16(f32_result); +} + template<> EIGEN_STRONG_INLINE short int predux_max(const Packet8s& a) { Packet8s pair, quad, octo; @@ -1391,6 +1697,21 @@ ptranspose(PacketBlock& kernel) { kernel.packet[3] = vec_mergel(t1, t3); } + +EIGEN_DEVICE_FUNC inline void +ptranspose(PacketBlock& kernel) { + Packet8us t0, t1, t2, t3; + + t0 = vec_mergeh(kernel.packet[0].m_val, kernel.packet[2].m_val); + t1 = vec_mergel(kernel.packet[0].m_val, kernel.packet[2].m_val); + t2 = vec_mergeh(kernel.packet[1].m_val, kernel.packet[3].m_val); + t3 = vec_mergel(kernel.packet[1].m_val, kernel.packet[3].m_val); + kernel.packet[0] = vec_mergeh(t0, t2); + kernel.packet[1] = vec_mergel(t0, t2); + kernel.packet[2] = vec_mergeh(t1, t3); + kernel.packet[3] = vec_mergel(t1, t3); +} + EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock& kernel) { Packet16c t0, t1, t2, t3; @@ -1480,6 +1801,37 @@ ptranspose(PacketBlock& kernel) { kernel.packet[7] = vec_mergel(sum[3], sum[7]); } +EIGEN_DEVICE_FUNC inline 
void +ptranspose(PacketBlock& kernel) { + Packet8bf v[8], sum[8]; + + v[0] = vec_mergeh(kernel.packet[0].m_val, kernel.packet[4].m_val); + v[1] = vec_mergel(kernel.packet[0].m_val, kernel.packet[4].m_val); + v[2] = vec_mergeh(kernel.packet[1].m_val, kernel.packet[5].m_val); + v[3] = vec_mergel(kernel.packet[1].m_val, kernel.packet[5].m_val); + v[4] = vec_mergeh(kernel.packet[2].m_val, kernel.packet[6].m_val); + v[5] = vec_mergel(kernel.packet[2].m_val, kernel.packet[6].m_val); + v[6] = vec_mergeh(kernel.packet[3].m_val, kernel.packet[7].m_val); + v[7] = vec_mergel(kernel.packet[3].m_val, kernel.packet[7].m_val); + sum[0] = vec_mergeh(v[0].m_val, v[4].m_val); + sum[1] = vec_mergel(v[0].m_val, v[4].m_val); + sum[2] = vec_mergeh(v[1].m_val, v[5].m_val); + sum[3] = vec_mergel(v[1].m_val, v[5].m_val); + sum[4] = vec_mergeh(v[2].m_val, v[6].m_val); + sum[5] = vec_mergel(v[2].m_val, v[6].m_val); + sum[6] = vec_mergeh(v[3].m_val, v[7].m_val); + sum[7] = vec_mergel(v[3].m_val, v[7].m_val); + + kernel.packet[0] = vec_mergeh(sum[0].m_val, sum[4].m_val); + kernel.packet[1] = vec_mergel(sum[0].m_val, sum[4].m_val); + kernel.packet[2] = vec_mergeh(sum[1].m_val, sum[5].m_val); + kernel.packet[3] = vec_mergel(sum[1].m_val, sum[5].m_val); + kernel.packet[4] = vec_mergeh(sum[2].m_val, sum[6].m_val); + kernel.packet[5] = vec_mergel(sum[2].m_val, sum[6].m_val); + kernel.packet[6] = vec_mergeh(sum[3].m_val, sum[7].m_val); + kernel.packet[7] = vec_mergel(sum[3].m_val, sum[7].m_val); +} + EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock& kernel) { Packet16c step1[16], step2[16], step3[16]; @@ -1656,6 +2008,10 @@ template<> EIGEN_STRONG_INLINE Packet8us pblend(const Selector<8>& ifPacket, con return vec_sel(elsePacket, thenPacket, mask); } +template<> EIGEN_STRONG_INLINE Packet8bf pblend(const Selector<8>& ifPacket, const Packet8bf& thenPacket, const Packet8bf& elsePacket) { + return pblend(ifPacket, thenPacket, elsePacket); +} + template<> EIGEN_STRONG_INLINE Packet16c pblend(const Selector<16>& ifPacket, const Packet16c& thenPacket, const Packet16c& elsePacket) { Packet16uc select = { ifPacket.select[0], ifPacket.select[1], ifPacket.select[2], ifPacket.select[3], ifPacket.select[4], ifPacket.select[5], ifPacket.select[6], ifPacket.select[7], @@ -1694,15 +2050,78 @@ struct type_casting_traits { }; }; +template <> +struct type_casting_traits { + enum { + VectorizedCast = 1, + SrcCoeffRatio = 1, + TgtCoeffRatio = 1 + }; +}; + +template <> +struct type_casting_traits { + enum { + VectorizedCast = 1, + SrcCoeffRatio = 1, + TgtCoeffRatio = 1 + }; +}; template<> EIGEN_STRONG_INLINE Packet4i pcast(const Packet4f& a) { return vec_cts(a,0); } +template<> EIGEN_STRONG_INLINE Packet4ui pcast(const Packet4f& a) { + return vec_ctu(a,0); +} + template<> EIGEN_STRONG_INLINE Packet4f pcast(const Packet4i& a) { return vec_ctf(a,0); } +template<> EIGEN_STRONG_INLINE Packet4f pcast(const Packet4ui& a) { + return vec_ctf(a,0); +} + +template<> EIGEN_STRONG_INLINE Packet8us pcast(const Packet8bf& a) { + Packet4f float_even = Bf16ToF32Even(a); + Packet4f float_odd = Bf16ToF32Odd(a); + Packet4ui int_even = pcast(float_even); + Packet4ui int_odd = pcast(float_odd); + const _EIGEN_DECLARE_CONST_FAST_Packet4ui(low_mask, 0x0000FFFF); + Packet4ui low_even = pand(int_even, p4ui_low_mask); + Packet4ui low_odd = pand(int_odd, p4ui_low_mask); + + //Check values that are bigger than USHRT_MAX (0xFFFF) + Packet4bi overflow_selector; + if(vec_any_gt(int_even, p4ui_low_mask)){ + overflow_selector = vec_cmpgt(int_even, p4ui_low_mask); + 
low_even = vec_sel(low_even, p4ui_low_mask, overflow_selector); + } + if(vec_any_gt(int_odd, p4ui_low_mask)){ + overflow_selector = vec_cmpgt(int_odd, p4ui_low_mask); + low_odd = vec_sel(low_even, p4ui_low_mask, overflow_selector); + } + + low_odd = plogical_shift_left<16>(low_odd); + + Packet4ui int_final = por(low_even, low_odd); + return reinterpret_cast(int_final); +} + +template<> EIGEN_STRONG_INLINE Packet8bf pcast(const Packet8us& a) { + //short -> int -> float -> bfloat16 + const _EIGEN_DECLARE_CONST_FAST_Packet4ui(low_mask, 0x0000FFFF); + Packet4ui int_cast = reinterpret_cast(a); + Packet4ui int_even = pand(int_cast, p4ui_low_mask); + Packet4ui int_odd = plogical_shift_right<16>(int_cast); + Packet4f float_even = pcast(int_even); + Packet4f float_odd = pcast(int_odd); + return F32ToBf16(float_even, float_odd); +} + + template<> EIGEN_STRONG_INLINE Packet4i preinterpret(const Packet4f& a) { return reinterpret_cast(a); } @@ -2024,6 +2443,8 @@ template<> EIGEN_STRONG_INLINE Packet2d pblend(const Selector<2>& ifPacket, cons Packet2bl mask = reinterpret_cast( vec_cmpeq(reinterpret_cast(select), reinterpret_cast(p2l_ONE)) ); return vec_sel(elsePacket, thenPacket, mask); } + + #endif // __VSX__ } // end namespace internal diff --git a/test/packetmath.cpp b/test/packetmath.cpp index 0fe29102a..c8ea3139e 100644 --- a/test/packetmath.cpp +++ b/test/packetmath.cpp @@ -247,6 +247,20 @@ void packetmath_boolean_mask_ops() { data1[i + PacketSize] = internal::random() ? data1[i] : Scalar(0); } CHECK_CWISE2_IF(true, internal::pcmp_eq, internal::pcmp_eq); + + //Test (-0) == (0) for signed operations + for (int i = 0; i < PacketSize; ++i) { + data1[i] = Scalar(-0.0); + data1[i + PacketSize] = internal::random() ? data1[i] : Scalar(0); + } + CHECK_CWISE2_IF(true, internal::pcmp_eq, internal::pcmp_eq); + + //Test NaN + for (int i = 0; i < PacketSize; ++i) { + data1[i] = std::numeric_limits::quiet_NaN(); + data1[i + PacketSize] = internal::random() ? data1[i] : Scalar(0); + } + CHECK_CWISE2_IF(true, internal::pcmp_eq, internal::pcmp_eq); } // Packet16b representing bool does not support ptrue, pandnot or pcmp_eq, since the scalar path @@ -254,6 +268,22 @@ void packetmath_boolean_mask_ops() { template<> void packetmath_boolean_mask_ops::type>() {} +template +void packetmath_minus_zero_add() { + const int PacketSize = internal::unpacket_traits::size; + const int size = 2 * PacketSize; + EIGEN_ALIGN_MAX Scalar data1[size]; + EIGEN_ALIGN_MAX Scalar data2[size]; + EIGEN_ALIGN_MAX Scalar ref[size]; + + for (int i = 0; i < PacketSize; ++i) { + data1[i] = Scalar(-0.0); + data1[i + PacketSize] = Scalar(-0.0); + } + CHECK_CWISE2_IF(internal::packet_traits::HasAdd, REF_ADD, internal::padd); +} + + template void packetmath() { typedef internal::packet_traits PacketTraits; @@ -454,6 +484,7 @@ void packetmath() { packetmath_boolean_mask_ops(); packetmath_pcast_ops_runner::run(); + packetmath_minus_zero_add(); } template
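A few reference sketches of the bfloat16 emulation strategy used above. The patch widens bfloat16 lanes to float, operates in float, and narrows back with round-to-nearest-even plus NaN/subnormal handling (Bf16ToF32Even/Bf16ToF32Odd and F32ToBf16). Below is a minimal scalar model of what those routines compute per lane; the helper names bf16_bits_to_float/float_to_bf16_bits and the main() driver are illustrative only, not part of the patch.

#include <cassert>
#include <cstdint>
#include <cstring>

// A bfloat16 value is the upper 16 bits of an IEEE-754 float.
static float bf16_bits_to_float(uint16_t b) {
  uint32_t bits = static_cast<uint32_t>(b) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

// Scalar model of F32ToBf16: add a bias of 0x7FFF plus the lowest surviving
// bit (round to nearest, ties to even), force a quiet-NaN pattern for NaN,
// pass subnormals through truncated, then keep the upper 16 bits.
static uint16_t float_to_bf16_bits(float f) {
  uint32_t input;
  std::memcpy(&input, &f, sizeof(input));

  uint32_t lsb = (input >> 16) & 1u;            // lowest bit kept after narrowing
  uint32_t rounded = input + (0x7FFFu + lsb);   // rounding bias, ties to even

  uint32_t exp = input & 0x7F800000u;
  uint32_t mantissa = input & 0x007FFFFFu;
  if (exp == 0x7F800000u && mantissa != 0u)     // NaN: force a quiet NaN
    rounded = 0x7FC00000u;
  else if (exp == 0u && mantissa != 0u)         // subnormal: keep truncated input
    rounded = input;

  return static_cast<uint16_t>(rounded >> 16);
}

// Emulation pattern of BF16_TO_F32_BINARY_OP_WRAPPER, here for padd<Packet8bf>:
// widen every lane to float, operate in float, narrow the result once.
static void bf16_add(const uint16_t a[8], const uint16_t b[8], uint16_t out[8]) {
  for (int i = 0; i < 8; ++i)
    out[i] = float_to_bf16_bits(bf16_bits_to_float(a[i]) + bf16_bits_to_float(b[i]));
}

int main() {
  assert(float_to_bf16_bits(1.0f) == 0x3F80);   // 1.0f is exactly representable
  uint16_t a[8], b[8], c[8];
  for (int i = 0; i < 8; ++i) { a[i] = 0x3F80; b[i] = 0x3F80; }  // eight copies of 1.0
  bf16_add(a, b, c);
  assert(c[0] == 0x4000);                       // 1.0 + 1.0 == 2.0 (0x4000 in bfloat16)
  return 0;
}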
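pabs<Packet8bf> avoids the float round trip entirely: bfloat16 keeps the IEEE sign in bit 15, so clearing that bit on the 16-bit payload is already the absolute value. The same operation per lane, as a sketch:

#include <cstdint>

// pabs<Packet8bf> per lane: mask off the sign bit (0x7FFF), no conversion needed.
static inline uint16_t bf16_abs(uint16_t lane) { return lane & 0x7FFFu; }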
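The reductions (predux, predux_mul, predux_min, predux_max) follow the same even/odd widening scheme: reduce the even-lane and odd-lane float halves separately, combine the two partial results in float, and only then convert to bfloat16, so a single rounding step is applied at the end. A scalar sketch of predux<Packet8bf> under that scheme; the final bfloat16(f32_result) conversion from the patch is omitted here, and the helper names are illustrative:

#include <cstdint>
#include <cstring>

static float lane_to_float(uint16_t b) {
  uint32_t bits = static_cast<uint32_t>(b) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

// Sum of eight bfloat16 lanes accumulated in float, even and odd lanes
// reduced separately as in the vector code, then combined.
static float bf16_sum_as_float(const uint16_t lanes[8]) {
  float even = 0.0f, odd = 0.0f;
  for (int i = 0; i < 8; i += 2) {
    even += lane_to_float(lanes[i]);      // predux(Bf16ToF32Even(a))
    odd  += lane_to_float(lanes[i + 1]);  // predux(Bf16ToF32Odd(a))
  }
  return even + odd;                      // rounded to bfloat16 only once, afterwards
}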
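For pcast<Packet8bf, Packet8us>, each bfloat16 lane is widened to float, converted to an unsigned integer, and saturated to USHRT_MAX (0xFFFF) when it does not fit in 16 bits, which is what the vec_cmpgt/vec_sel overflow handling above implements. A scalar sketch of the per-lane behaviour; negative and NaN inputs are modeled as 0, on the assumption that the vec_ctu conversion saturates them that way, and the helper name is illustrative:

#include <cstdint>
#include <cstring>

static uint16_t bf16_to_u16_saturating(uint16_t bf) {
  uint32_t bits = static_cast<uint32_t>(bf) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  if (!(f > 0.0f)) return 0;               // negatives and NaN clamp to 0 (assumption)
  if (f > 65535.0f) return 0xFFFFu;        // saturate to USHRT_MAX
  return static_cast<uint16_t>(f);         // otherwise truncate toward zero
}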
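With the patch applied, the Packet8bf specializations plug into Eigen's generic packet layer, so bfloat16 data can be processed eight lanes at a time through pload/padd/pstore and friends. A usage sketch, assuming an Altivec/VSX build of this Eigen tree; on other targets packet_traits<bfloat16>::type may be a different (possibly scalar) packet, and the sketch assumes the packet size divides 8:

#include <Eigen/Core>
#include <iostream>

int main() {
  using Eigen::bfloat16;
  // Packet8bf on an Altivec/VSX build with this patch.
  typedef Eigen::internal::packet_traits<bfloat16>::type Packet;
  const int N = Eigen::internal::unpacket_traits<Packet>::size;

  EIGEN_ALIGN_MAX bfloat16 a[8], b[8], c[8];
  for (int i = 0; i < 8; ++i) {
    a[i] = bfloat16(static_cast<float>(i));
    b[i] = bfloat16(0.5f);
  }

  // Process the arrays one packet at a time through the generic packet API.
  for (int i = 0; i < 8; i += N) {
    Packet pa = Eigen::internal::ploadu<Packet>(a + i);
    Packet pb = Eigen::internal::ploadu<Packet>(b + i);
    Eigen::internal::pstoreu(c + i, Eigen::internal::padd(pa, pb));  // c[i] = a[i] + b[i]
  }

  for (int i = 0; i < 8; ++i)
    std::cout << static_cast<float>(c[i]) << (i == 7 ? '\n' : ' ');
  return 0;
}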
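The packetmath.cpp additions check that the vectorized path matches IEEE-754 scalar semantics for signed zero and NaN: padd must keep the sign of zero for (-0) + (-0), and pcmp_eq must treat -0 as equal to +0 and NaN as unequal to everything, including itself. The scalar facts those checks rely on, as a quick self-contained reference:

#include <cassert>
#include <cmath>
#include <limits>

int main() {
  float nz = -0.0f;
  assert(std::signbit(nz + nz));                        // (-0) + (-0) keeps the sign: the padd check
  assert(nz == 0.0f);                                   // -0 compares equal to +0: the pcmp_eq check
  float nan = std::numeric_limits<float>::quiet_NaN();
  assert(!(nan == nan));                                // NaN is never equal to NaN
  return 0;
}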