Add support for Bfloat16 to use vector instructions on Altivec

architecture
This commit is contained in:
Pedro Caldeira 2020-06-24 15:27:26 -05:00
parent 46f8a18567
commit 704798d1df
2 changed files with 457 additions and 5 deletions

View File

@ -39,10 +39,10 @@ typedef __vector short int Packet8s;
typedef __vector unsigned short int Packet8us;
typedef __vector int8_t Packet16c;
typedef __vector uint8_t Packet16uc;
typedef eigen_packet_wrapper<__vector unsigned short int,0> Packet8bf;
// We don't want to write the same code all the time, but we need to reuse the constants
// and it doesn't really work to declare them global, so we define macros instead
#define _EIGEN_DECLARE_CONST_FAST_Packet4f(NAME,X) \
Packet4f p4f_##NAME = {X, X, X, X}
@ -96,6 +96,7 @@ static Packet4f p4f_COUNTDOWN = { 0.0, 1.0, 2.0, 3.0 };
static Packet4i p4i_COUNTDOWN = { 0, 1, 2, 3 };
static Packet8s p8s_COUNTDOWN = { 0, 1, 2, 3, 4, 5, 6, 7 };
static Packet8us p8us_COUNTDOWN = { 0, 1, 2, 3, 4, 5, 6, 7 };
static Packet16c p16c_COUNTDOWN = { 0, 1, 2, 3, 4, 5, 6, 7,
8, 9, 10, 11, 12, 13, 14, 15};
static Packet16uc p16uc_COUNTDOWN = { 0, 1, 2, 3, 4, 5, 6, 7,
@ -108,6 +109,8 @@ static Packet16uc p16uc_REVERSE8 = { 15,14,13,12,11,10,9,8,7,6,5,4,3,2,1,0 };
static Packet16uc p16uc_DUPLICATE32_HI = { 0,1,2,3, 0,1,2,3, 4,5,6,7, 4,5,6,7 };
static Packet16uc p16uc_DUPLICATE16_HI = { 0,1,0,1, 2,3,2,3, 4,5,4,5, 6,7,6,7 };
static Packet16uc p16uc_DUPLICATE8_HI = { 0,0, 1,1, 2,2, 3,3, 4,4, 5,5, 6,6, 7,7 };
static const Packet16uc p16uc_DUPLICATE16_EVEN= { 0,1 ,0,1, 4,5, 4,5, 8,9, 8,9, 12,13, 12,13 };
static const Packet16uc p16uc_DUPLICATE16_ODD = { 2,3 ,2,3, 6,7, 6,7, 10,11, 10,11, 14,15, 14,15 };
static Packet16uc p16uc_QUADRUPLICATE16_HI = { 0,1,0,1,0,1,0,1, 2,3,2,3,2,3,2,3 };
@ -189,6 +192,48 @@ struct packet_traits<float> : default_packet_traits {
HasBlend = 1
};
};
template <>
struct packet_traits<bfloat16> : default_packet_traits {
typedef Packet8bf type;
typedef Packet8bf half;
enum {
Vectorizable = 1,
AlignedOnScalar = 1,
size = 8,
HasHalfPacket = 0,
HasAdd = 1,
HasSub = 1,
HasMul = 1,
HasDiv = 1,
HasMin = 1,
HasMax = 1,
HasAbs = 1,
HasSin = EIGEN_FAST_MATH,
HasCos = EIGEN_FAST_MATH,
HasLog = 1,
HasExp = 1,
#ifdef __VSX__
HasSqrt = 1,
#if !EIGEN_COMP_CLANG
HasRsqrt = 1,
#else
HasRsqrt = 0,
#endif
#else
HasSqrt = 0,
HasRsqrt = 0,
HasTanh = EIGEN_FAST_MATH,
HasErf = EIGEN_FAST_MATH,
#endif
HasRound = 1,
HasFloor = 1,
HasCeil = 1,
HasNegate = 1,
HasBlend = 1
};
};
template <>
struct packet_traits<int> : default_packet_traits {
typedef Packet4i type;
@ -319,6 +364,12 @@ template<> struct unpacket_traits<Packet16uc>
enum {size=16, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false};
};
template<> struct unpacket_traits<Packet8bf>
{
typedef bfloat16 type;
typedef Packet8bf half;
enum {size=8, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false};
};
inline std::ostream & operator <<(std::ostream & s, const Packet16c & v)
{
union {
@ -421,6 +472,11 @@ template<> EIGEN_STRONG_INLINE Packet16uc pload<Packet16uc>(const uint8_t* f
return pload_common<Packet16uc>(from);
}
template<> EIGEN_STRONG_INLINE Packet8bf pload<Packet8bf>(const bfloat16* from)
{
return pload_common<Packet8us>(reinterpret_cast<const unsigned short int*>(from));
}
template <typename Packet>
EIGEN_STRONG_INLINE void pstore_common(__UNPACK_TYPE__(Packet)* to, const Packet& from){
// some versions of GCC throw "unused-but-set-parameter" (float *to).
@ -431,7 +487,7 @@ EIGEN_STRONG_INLINE void pstore_common(__UNPACK_TYPE__(Packet)* to, const Packet
vec_xst(from, 0, to);
#else
vec_st(from, 0, to);
#endif
#endif
}
template<> EIGEN_STRONG_INLINE void pstore<float>(float* to, const Packet4f& from)
@ -454,6 +510,11 @@ template<> EIGEN_STRONG_INLINE void pstore<unsigned short int>(unsigned short in
pstore_common<Packet8us>(to, from);
}
template<> EIGEN_STRONG_INLINE void pstore<bfloat16>(bfloat16* to, const Packet8bf& from)
{
pstore_common<Packet8us>(reinterpret_cast<unsigned short int*>(to), from);
}
template<> EIGEN_STRONG_INLINE void pstore<int8_t>(int8_t* to, const Packet16c& from)
{
pstore_common<Packet16c>(to, from);
@ -513,6 +574,10 @@ template<> EIGEN_STRONG_INLINE Packet4f pset1frombits<Packet4f>(unsigned int fro
return reinterpret_cast<Packet4f>(pset1<Packet4i>(from));
}
template<> EIGEN_STRONG_INLINE Packet8bf pset1<Packet8bf>(const bfloat16& from) {
return pset1_size8<Packet8us>(reinterpret_cast<const unsigned short int&>(from));
}
template<typename Packet> EIGEN_STRONG_INLINE void
pbroadcast4_common(const __UNPACK_TYPE__(Packet) *a,
Packet& a0, Packet& a1, Packet& a2, Packet& a3)
@ -700,6 +765,7 @@ template<> EIGEN_STRONG_INLINE Packet16uc plset<Packet16uc>(const uint8_t& a)
template<> EIGEN_STRONG_INLINE Packet4f padd<Packet4f> (const Packet4f& a, const Packet4f& b) { return a + b; }
template<> EIGEN_STRONG_INLINE Packet4i padd<Packet4i> (const Packet4i& a, const Packet4i& b) { return a + b; }
template<> EIGEN_STRONG_INLINE Packet4ui padd<Packet4ui> (const Packet4ui& a, const Packet4ui& b) { return a + b; }
template<> EIGEN_STRONG_INLINE Packet8s padd<Packet8s> (const Packet8s& a, const Packet8s& b) { return a + b; }
template<> EIGEN_STRONG_INLINE Packet8us padd<Packet8us> (const Packet8us& a, const Packet8us& b) { return a + b; }
template<> EIGEN_STRONG_INLINE Packet16c padd<Packet16c> (const Packet16c& a, const Packet16c& b) { return a + b; }
@ -721,6 +787,7 @@ template<> EIGEN_STRONG_INLINE Packet4i pmul<Packet4i> (const Packet4i& a,
template<> EIGEN_STRONG_INLINE Packet16c pmul<Packet16c> (const Packet16c& a, const Packet16c& b) { return vec_mul(a,b); }
template<> EIGEN_STRONG_INLINE Packet16uc pmul<Packet16uc>(const Packet16uc& a, const Packet16uc& b) { return vec_mul(a,b); }
template<> EIGEN_STRONG_INLINE Packet4f pdiv<Packet4f>(const Packet4f& a, const Packet4f& b)
{
#ifndef __VSX__ // VSX actually provides a div instruction
@ -765,6 +832,7 @@ template<> EIGEN_STRONG_INLINE Packet8us pmin<Packet8us>(const Packet8us& a, con
template<> EIGEN_STRONG_INLINE Packet16c pmin<Packet16c>(const Packet16c& a, const Packet16c& b) { return vec_min(a, b); }
template<> EIGEN_STRONG_INLINE Packet16uc pmin<Packet16uc>(const Packet16uc& a, const Packet16uc& b) { return vec_min(a, b); }
template<> EIGEN_STRONG_INLINE Packet4f pmax<Packet4f>(const Packet4f& a, const Packet4f& b)
{
#ifdef __VSX__
@ -785,6 +853,7 @@ template<> EIGEN_STRONG_INLINE Packet16uc pmax<Packet16uc>(const Packet16uc& a,
template<> EIGEN_STRONG_INLINE Packet4f pcmp_le(const Packet4f& a, const Packet4f& b) { return reinterpret_cast<Packet4f>(vec_cmple(a,b)); }
template<> EIGEN_STRONG_INLINE Packet4f pcmp_lt(const Packet4f& a, const Packet4f& b) { return reinterpret_cast<Packet4f>(vec_cmplt(a,b)); }
template<> EIGEN_STRONG_INLINE Packet4f pcmp_eq(const Packet4f& a, const Packet4f& b) { return reinterpret_cast<Packet4f>(vec_cmpeq(a,b)); }
template<> EIGEN_STRONG_INLINE Packet4f pcmp_lt_or_nan(const Packet4f& a, const Packet4f& b) {
Packet4f c = reinterpret_cast<Packet4f>(vec_cmpge(a,b));
return vec_nor(c,c);
@ -793,12 +862,26 @@ template<> EIGEN_STRONG_INLINE Packet4i pcmp_eq(const Packet4i& a, const Packet4
template<> EIGEN_STRONG_INLINE Packet4f pand<Packet4f>(const Packet4f& a, const Packet4f& b) { return vec_and(a, b); }
template<> EIGEN_STRONG_INLINE Packet4i pand<Packet4i>(const Packet4i& a, const Packet4i& b) { return vec_and(a, b); }
template<> EIGEN_STRONG_INLINE Packet4ui pand<Packet4ui>(const Packet4ui& a, const Packet4ui& b) { return vec_and(a, b); }
template<> EIGEN_STRONG_INLINE Packet8us pand<Packet8us>(const Packet8us& a, const Packet8us& b) { return vec_and(a, b); }
template<> EIGEN_STRONG_INLINE Packet8bf pand<Packet8bf>(const Packet8bf& a, const Packet8bf& b) {
return pand<Packet8us>(a, b);
}
template<> EIGEN_STRONG_INLINE Packet4f por<Packet4f>(const Packet4f& a, const Packet4f& b) { return vec_or(a, b); }
template<> EIGEN_STRONG_INLINE Packet4i por<Packet4i>(const Packet4i& a, const Packet4i& b) { return vec_or(a, b); }
template<> EIGEN_STRONG_INLINE Packet8s por<Packet8s>(const Packet8s& a, const Packet8s& b) { return vec_or(a, b); }
template<> EIGEN_STRONG_INLINE Packet8us por<Packet8us>(const Packet8us& a, const Packet8us& b) { return vec_or(a, b); }
template<> EIGEN_STRONG_INLINE Packet8bf por<Packet8bf>(const Packet8bf& a, const Packet8bf& b) {
return por<Packet8us>(a, b);
}
template<> EIGEN_STRONG_INLINE Packet4f pxor<Packet4f>(const Packet4f& a, const Packet4f& b) { return vec_xor(a, b); }
template<> EIGEN_STRONG_INLINE Packet4i pxor<Packet4i>(const Packet4i& a, const Packet4i& b) { return vec_xor(a, b); }
template<> EIGEN_STRONG_INLINE Packet8bf pxor<Packet8bf>(const Packet8bf& a, const Packet8bf& b) {
return pxor<Packet8us>(a, b);
}
template<> EIGEN_STRONG_INLINE Packet4f pandnot<Packet4f>(const Packet4f& a, const Packet4f& b) { return vec_and(a, vec_nor(b, b)); }
template<> EIGEN_STRONG_INLINE Packet4i pandnot<Packet4i>(const Packet4i& a, const Packet4i& b) { return vec_and(a, vec_nor(b, b)); }
@ -806,6 +889,7 @@ template<> EIGEN_STRONG_INLINE Packet4i pandnot<Packet4i>(const Packet4i& a, con
template<> EIGEN_STRONG_INLINE Packet4f pselect(const Packet4f& mask, const Packet4f& a, const Packet4f& b) {
return vec_sel(b, a, reinterpret_cast<Packet4ui>(mask));
}
template<> EIGEN_STRONG_INLINE Packet4f pround<Packet4f>(const Packet4f& a) {
Packet4f t = vec_add(reinterpret_cast<Packet4f>(vec_or(vec_and(reinterpret_cast<Packet4ui>(a), p4ui_SIGN), p4ui_PREV0DOT5)), a);
Packet4f res;
@ -852,6 +936,10 @@ template<> EIGEN_STRONG_INLINE Packet8us ploadu<Packet8us>(const unsigned short
{
return ploadu_common<Packet8us>(from);
}
template<> EIGEN_STRONG_INLINE Packet8bf ploadu<Packet8bf>(const bfloat16* from)
{
return ploadu_common<Packet8us>(reinterpret_cast<const unsigned short int*>(from));
}
template<> EIGEN_STRONG_INLINE Packet16c ploadu<Packet16c>(const int8_t* from)
{
return ploadu_common<Packet16c>(from);
@ -909,6 +997,11 @@ template<> EIGEN_STRONG_INLINE Packet8us ploadquad<Packet8us>(const unsigned sho
return vec_perm(p, p, p16uc_QUADRUPLICATE16_HI);
}
template<> EIGEN_STRONG_INLINE Packet8bf ploadquad<Packet8bf>(const bfloat16* from)
{
return ploadquad<Packet8us>(reinterpret_cast<const unsigned short int*>(from));
}
template<> EIGEN_STRONG_INLINE Packet16c ploaddup<Packet16c>(const int8_t* from)
{
Packet16c p;
@ -962,6 +1055,10 @@ template<> EIGEN_STRONG_INLINE void pstoreu<unsigned short int>(unsigned short i
{
pstoreu_common<Packet8us>(to, from);
}
template<> EIGEN_STRONG_INLINE void pstoreu<bfloat16>(bfloat16* to, const Packet8bf& from)
{
pstoreu_common<Packet8us>(reinterpret_cast<unsigned short int*>(to), from);
}
template<> EIGEN_STRONG_INLINE void pstoreu<int8_t>(int8_t* to, const Packet16c& from)
{
pstoreu_common<Packet16c>(to, from);
@ -977,17 +1074,17 @@ template<> EIGEN_STRONG_INLINE void prefetch<int>(const int* addr) { EIGE
template<> EIGEN_STRONG_INLINE float pfirst<Packet4f>(const Packet4f& a) { EIGEN_ALIGN16 float x; vec_ste(a, 0, &x); return x; }
template<> EIGEN_STRONG_INLINE int pfirst<Packet4i>(const Packet4i& a) { EIGEN_ALIGN16 int x; vec_ste(a, 0, &x); return x; }
template<typename Packet> EIGEN_STRONG_INLINE __UNPACK_TYPE__(Packet) pfirst_common(const Packet& a) {
template<typename Packet> EIGEN_STRONG_INLINE __UNPACK_TYPE__(Packet) pfirst_common(const Packet& a) {
EIGEN_ALIGN16 __UNPACK_TYPE__(Packet) x;
vec_ste(a, 0, &x);
return x;
}
template<> EIGEN_STRONG_INLINE short int pfirst<Packet8s>(const Packet8s& a) {
template<> EIGEN_STRONG_INLINE short int pfirst<Packet8s>(const Packet8s& a) {
return pfirst_common<Packet8s>(a);
}
template<> EIGEN_STRONG_INLINE unsigned short int pfirst<Packet8us>(const Packet8us& a) {
template<> EIGEN_STRONG_INLINE unsigned short int pfirst<Packet8us>(const Packet8us& a) {
return pfirst_common<Packet8us>(a);
}
@ -1025,6 +1122,10 @@ template<> EIGEN_STRONG_INLINE Packet16uc preverse(const Packet16uc& a)
{
return vec_perm(a, a, p16uc_REVERSE8);
}
template<> EIGEN_STRONG_INLINE Packet8bf preverse(const Packet8bf& a)
{
return preverse<Packet8us>(a);
}
template<> EIGEN_STRONG_INLINE Packet4f pabs(const Packet4f& a) { return vec_abs(a); }
template<> EIGEN_STRONG_INLINE Packet4i pabs(const Packet4i& a) { return vec_abs(a); }
@ -1032,6 +1133,10 @@ template<> EIGEN_STRONG_INLINE Packet8s pabs(const Packet8s& a) { return vec_abs
template<> EIGEN_STRONG_INLINE Packet8us pabs(const Packet8us& a) { return a; }
template<> EIGEN_STRONG_INLINE Packet16c pabs(const Packet16c& a) { return vec_abs(a); }
template<> EIGEN_STRONG_INLINE Packet16uc pabs(const Packet16uc& a) { return a; }
template<> EIGEN_STRONG_INLINE Packet8bf pabs(const Packet8bf& a) {
_EIGEN_DECLARE_CONST_FAST_Packet8us(abs_mask,0x7FFF);
return pand<Packet8us>(p8us_abs_mask, a);
}
template<int N> EIGEN_STRONG_INLINE Packet4i parithmetic_shift_right(Packet4i a)
{ return vec_sra(a,reinterpret_cast<Packet4ui>(pset1<Packet4i>(N))); }
@ -1039,6 +1144,175 @@ template<int N> EIGEN_STRONG_INLINE Packet4i plogical_shift_right(Packet4i a)
{ return vec_sr(a,reinterpret_cast<Packet4ui>(pset1<Packet4i>(N))); }
template<int N> EIGEN_STRONG_INLINE Packet4i plogical_shift_left(Packet4i a)
{ return vec_sl(a,reinterpret_cast<Packet4ui>(pset1<Packet4i>(N))); }
template<int N> EIGEN_STRONG_INLINE Packet4f plogical_shift_left(Packet4f a)
{
const _EIGEN_DECLARE_CONST_FAST_Packet4ui(mask, N);
Packet4ui r = vec_sl(reinterpret_cast<Packet4ui>(a), p4ui_mask);
return reinterpret_cast<Packet4f>(r);
}
template<int N> EIGEN_STRONG_INLINE Packet4f plogical_shift_right(Packet4f a)
{
const _EIGEN_DECLARE_CONST_FAST_Packet4ui(mask, N);
Packet4ui r = vec_sr(reinterpret_cast<Packet4ui>(a), p4ui_mask);
return reinterpret_cast<Packet4f>(r);
}
template<int N> EIGEN_STRONG_INLINE Packet4ui plogical_shift_right(Packet4ui a)
{
const _EIGEN_DECLARE_CONST_FAST_Packet4ui(mask, N);
return vec_sr(a, p4ui_mask);
}
template<int N> EIGEN_STRONG_INLINE Packet4ui plogical_shift_left(Packet4ui a)
{
const _EIGEN_DECLARE_CONST_FAST_Packet4ui(mask, N);
return vec_sl(a, p4ui_mask);
}
template<int N> EIGEN_STRONG_INLINE Packet8us plogical_shift_left(Packet8us a)
{
const _EIGEN_DECLARE_CONST_FAST_Packet8us(mask, N);
return vec_sl(a, p8us_mask);
}
EIGEN_STRONG_INLINE Packet4f Bf16ToF32Even(const Packet8bf& bf){
return plogical_shift_left<16>(reinterpret_cast<Packet4f>(bf.m_val));
}
EIGEN_STRONG_INLINE Packet4f Bf16ToF32Odd(const Packet8bf& bf){
const _EIGEN_DECLARE_CONST_FAST_Packet4ui(high_mask, 0xFFFF0000);
return pand<Packet4f>(
reinterpret_cast<Packet4f>(bf.m_val),
reinterpret_cast<Packet4f>(p4ui_high_mask)
);
}
EIGEN_STRONG_INLINE Packet8bf F32ToBf16(Packet4f p4f){
Packet4ui input = reinterpret_cast<Packet4ui>(p4f);
Packet4ui lsb = plogical_shift_right<16>(input);
lsb = pand<Packet4ui>(lsb, reinterpret_cast<Packet4ui>(p4i_ONE));
_EIGEN_DECLARE_CONST_FAST_Packet4ui(BIAS,0x7FFFu);
Packet4ui rounding_bias = padd<Packet4ui>(lsb, p4ui_BIAS);
input = padd<Packet4ui>(input, rounding_bias);
//Test NaN and Subnormal - Begin
const _EIGEN_DECLARE_CONST_FAST_Packet4ui(exp_mask, 0x7F800000);
Packet4ui exp = pand<Packet4ui>(p4ui_exp_mask, reinterpret_cast<Packet4ui>(p4f));
const _EIGEN_DECLARE_CONST_FAST_Packet4ui(mantissa_mask, 0x7FFFFF);
Packet4ui mantissa = pand<Packet4ui>(p4ui_mantissa_mask, reinterpret_cast<Packet4ui>(p4f));
const _EIGEN_DECLARE_CONST_FAST_Packet4ui(max_exp, 0x7F800000);
Packet4bi is_max_exp = vec_cmpeq(exp, p4ui_max_exp);
Packet4bi is_zero_exp = vec_cmpeq(exp, reinterpret_cast<Packet4ui>(p4i_ZERO));
Packet4bi is_mant_not_zero = vec_cmpne(mantissa, reinterpret_cast<Packet4ui>(p4i_ZERO));
Packet4ui nan_selector = pand<Packet4ui>(
reinterpret_cast<Packet4ui>(is_max_exp),
reinterpret_cast<Packet4ui>(is_mant_not_zero)
);
Packet4ui subnormal_selector = pand<Packet4ui>(
reinterpret_cast<Packet4ui>(is_zero_exp),
reinterpret_cast<Packet4ui>(is_mant_not_zero)
);
const _EIGEN_DECLARE_CONST_FAST_Packet4ui(nan, 0x7FC00000);
input = vec_sel(input, p4ui_nan, nan_selector);
input = vec_sel(input, reinterpret_cast<Packet4ui>(p4f), subnormal_selector);
//Test NaN and Subnormal - End
input = plogical_shift_right<16>(input);
return reinterpret_cast<Packet8us>(input);
}
EIGEN_STRONG_INLINE Packet8bf F32ToBf16(Packet4f even, Packet4f odd){
Packet4f bf_odd, bf_even;
bf_odd = reinterpret_cast<Packet4f>(F32ToBf16(odd).m_val);
bf_odd = plogical_shift_left<16>(bf_odd);
bf_even = reinterpret_cast<Packet4f>(F32ToBf16(even).m_val);
return reinterpret_cast<Packet8us>(por<Packet4f>(bf_even, bf_odd));
}
#define BF16_TO_F32_UNARY_OP_WRAPPER(OP, A) \
Packet4f a_even = Bf16ToF32Even(A);\
Packet4f a_odd = Bf16ToF32Odd(A);\
Packet4f op_even = OP(a_even);\
Packet4f op_odd = OP(a_odd);\
return F32ToBf16(op_even, op_odd);\
#define BF16_TO_F32_BINARY_OP_WRAPPER(OP, A, B) \
Packet4f a_even = Bf16ToF32Even(A);\
Packet4f a_odd = Bf16ToF32Odd(A);\
Packet4f b_even = Bf16ToF32Even(B);\
Packet4f b_odd = Bf16ToF32Odd(B);\
Packet4f op_even = OP(a_even, b_even);\
Packet4f op_odd = OP(a_odd, b_odd);\
return F32ToBf16(op_even, op_odd);\
template<> EIGEN_STRONG_INLINE Packet8bf padd<Packet8bf>(const Packet8bf& a, const Packet8bf& b) {
BF16_TO_F32_BINARY_OP_WRAPPER(padd<Packet4f>, a, b);
}
template<> EIGEN_STRONG_INLINE Packet8bf pmul<Packet8bf>(const Packet8bf& a, const Packet8bf& b) {
BF16_TO_F32_BINARY_OP_WRAPPER(pmul<Packet4f>, a, b);
}
template<> EIGEN_STRONG_INLINE Packet8bf pdiv<Packet8bf>(const Packet8bf& a, const Packet8bf& b) {
BF16_TO_F32_BINARY_OP_WRAPPER(pdiv<Packet4f>, a, b);
}
template<> EIGEN_STRONG_INLINE Packet8bf pnegate<Packet8bf>(const Packet8bf& a) {
BF16_TO_F32_UNARY_OP_WRAPPER(pnegate<Packet4f>, a);
}
template<> EIGEN_STRONG_INLINE Packet8bf psub<Packet8bf>(const Packet8bf& a, const Packet8bf& b) {
BF16_TO_F32_BINARY_OP_WRAPPER(psub<Packet4f>, a, b);
}
template<> EIGEN_STRONG_INLINE Packet8bf psqrt<Packet8bf> (const Packet8bf& a){
BF16_TO_F32_UNARY_OP_WRAPPER(vec_sqrt, a);
}
template<> EIGEN_STRONG_INLINE Packet8bf pmadd(const Packet8bf& a, const Packet8bf& b, const Packet8bf& c) {
Packet4f a_even = Bf16ToF32Even(a);
Packet4f a_odd = Bf16ToF32Odd(a);
Packet4f b_even = Bf16ToF32Even(b);
Packet4f b_odd = Bf16ToF32Odd(b);
Packet4f c_even = Bf16ToF32Even(c);
Packet4f c_odd = Bf16ToF32Odd(c);
Packet4f pmadd_even = pmadd<Packet4f>(a_even, b_even, c_even);
Packet4f pmadd_odd = pmadd<Packet4f>(a_odd, b_odd, c_odd);
return F32ToBf16(pmadd_even, pmadd_odd);
}
template<> EIGEN_STRONG_INLINE Packet8bf pmin<Packet8bf>(const Packet8bf& a, const Packet8bf& b) {
BF16_TO_F32_BINARY_OP_WRAPPER(pmin<Packet4f>, a, b);
}
template<> EIGEN_STRONG_INLINE Packet8bf pmax<Packet8bf>(const Packet8bf& a, const Packet8bf& b) {
BF16_TO_F32_BINARY_OP_WRAPPER(pmax<Packet4f>, a, b);
}
template<> EIGEN_STRONG_INLINE Packet8bf pcmp_lt(const Packet8bf& a, const Packet8bf& b) {
BF16_TO_F32_BINARY_OP_WRAPPER(pcmp_lt<Packet4f>, a, b);
}
template<> EIGEN_STRONG_INLINE Packet8bf pcmp_le(const Packet8bf& a, const Packet8bf& b) {
BF16_TO_F32_BINARY_OP_WRAPPER(pcmp_le<Packet4f>, a, b);
}
template<> EIGEN_STRONG_INLINE Packet8bf pcmp_eq(const Packet8bf& a, const Packet8bf& b) {
BF16_TO_F32_BINARY_OP_WRAPPER(pcmp_eq<Packet4f>, a, b);
}
template<> EIGEN_STRONG_INLINE bfloat16 pfirst(const Packet8bf& a) {
return Eigen::bfloat16_impl::raw_uint16_to_bfloat16((pfirst<Packet8us>(a)));
}
template<> EIGEN_STRONG_INLINE Packet8bf ploaddup<Packet8bf>(const bfloat16* from)
{
return ploaddup<Packet8us>(reinterpret_cast<const unsigned short int*>(from));
}
template<> EIGEN_STRONG_INLINE Packet4f pfrexp<Packet4f>(const Packet4f& a, Packet4f& exponent) {
return pfrexp_float(a,exponent);
@ -1070,6 +1344,13 @@ template<> EIGEN_STRONG_INLINE int predux<Packet4i>(const Packet4i& a)
return pfirst(sum);
}
template<> EIGEN_STRONG_INLINE bfloat16 predux<Packet8bf>(const Packet8bf& a)
{
float redux_even = predux<Packet4f>(Bf16ToF32Even(a));
float redux_odd = predux<Packet4f>(Bf16ToF32Odd(a));
float f32_result = redux_even + redux_odd;
return bfloat16(f32_result);
}
template<typename Packet> EIGEN_STRONG_INLINE __UNPACK_TYPE__(Packet) predux_size8(const Packet& a)
{
union{
@ -1166,6 +1447,15 @@ template<> EIGEN_STRONG_INLINE unsigned short int predux_mul<Packet8us>(const Pa
return pfirst(octo);
}
template<> EIGEN_STRONG_INLINE bfloat16 predux_mul<Packet8bf>(const Packet8bf& a)
{
float redux_even = predux_mul<Packet4f>(Bf16ToF32Even(a));
float redux_odd = predux_mul<Packet4f>(Bf16ToF32Odd(a));
float f32_result = redux_even * redux_odd;
return bfloat16(f32_result);
}
template<> EIGEN_STRONG_INLINE int8_t predux_mul<Packet16c>(const Packet16c& a)
{
Packet16c pair, quad, octo, result;
@ -1211,6 +1501,14 @@ template<> EIGEN_STRONG_INLINE int predux_min<Packet4i>(const Packet4i& a)
return predux_min4<Packet4i>(a);
}
template<> EIGEN_STRONG_INLINE bfloat16 predux_min<Packet8bf>(const Packet8bf& a)
{
float redux_even = predux_min<Packet4f>(Bf16ToF32Even(a));
float redux_odd = predux_min<Packet4f>(Bf16ToF32Odd(a));
float f32_result = (std::min)(redux_even, redux_odd);
return bfloat16(f32_result);
}
template<> EIGEN_STRONG_INLINE short int predux_min<Packet8s>(const Packet8s& a)
{
Packet8s pair, quad, octo;
@ -1283,6 +1581,14 @@ template<> EIGEN_STRONG_INLINE int predux_max<Packet4i>(const Packet4i& a)
return predux_max4<Packet4i>(a);
}
template<> EIGEN_STRONG_INLINE bfloat16 predux_max<Packet8bf>(const Packet8bf& a)
{
float redux_even = predux_max<Packet4f>(Bf16ToF32Even(a));
float redux_odd = predux_max<Packet4f>(Bf16ToF32Odd(a));
float f32_result = (std::max)(redux_even, redux_odd);
return bfloat16(f32_result);
}
template<> EIGEN_STRONG_INLINE short int predux_max<Packet8s>(const Packet8s& a)
{
Packet8s pair, quad, octo;
@ -1391,6 +1697,21 @@ ptranspose(PacketBlock<Packet8us,4>& kernel) {
kernel.packet[3] = vec_mergel(t1, t3);
}
EIGEN_DEVICE_FUNC inline void
ptranspose(PacketBlock<Packet8bf,4>& kernel) {
Packet8us t0, t1, t2, t3;
t0 = vec_mergeh(kernel.packet[0].m_val, kernel.packet[2].m_val);
t1 = vec_mergel(kernel.packet[0].m_val, kernel.packet[2].m_val);
t2 = vec_mergeh(kernel.packet[1].m_val, kernel.packet[3].m_val);
t3 = vec_mergel(kernel.packet[1].m_val, kernel.packet[3].m_val);
kernel.packet[0] = vec_mergeh(t0, t2);
kernel.packet[1] = vec_mergel(t0, t2);
kernel.packet[2] = vec_mergeh(t1, t3);
kernel.packet[3] = vec_mergel(t1, t3);
}
EIGEN_DEVICE_FUNC inline void
ptranspose(PacketBlock<Packet16c,4>& kernel) {
Packet16c t0, t1, t2, t3;
@ -1480,6 +1801,37 @@ ptranspose(PacketBlock<Packet8us,8>& kernel) {
kernel.packet[7] = vec_mergel(sum[3], sum[7]);
}
EIGEN_DEVICE_FUNC inline void
ptranspose(PacketBlock<Packet8bf,8>& kernel) {
Packet8bf v[8], sum[8];
v[0] = vec_mergeh(kernel.packet[0].m_val, kernel.packet[4].m_val);
v[1] = vec_mergel(kernel.packet[0].m_val, kernel.packet[4].m_val);
v[2] = vec_mergeh(kernel.packet[1].m_val, kernel.packet[5].m_val);
v[3] = vec_mergel(kernel.packet[1].m_val, kernel.packet[5].m_val);
v[4] = vec_mergeh(kernel.packet[2].m_val, kernel.packet[6].m_val);
v[5] = vec_mergel(kernel.packet[2].m_val, kernel.packet[6].m_val);
v[6] = vec_mergeh(kernel.packet[3].m_val, kernel.packet[7].m_val);
v[7] = vec_mergel(kernel.packet[3].m_val, kernel.packet[7].m_val);
sum[0] = vec_mergeh(v[0].m_val, v[4].m_val);
sum[1] = vec_mergel(v[0].m_val, v[4].m_val);
sum[2] = vec_mergeh(v[1].m_val, v[5].m_val);
sum[3] = vec_mergel(v[1].m_val, v[5].m_val);
sum[4] = vec_mergeh(v[2].m_val, v[6].m_val);
sum[5] = vec_mergel(v[2].m_val, v[6].m_val);
sum[6] = vec_mergeh(v[3].m_val, v[7].m_val);
sum[7] = vec_mergel(v[3].m_val, v[7].m_val);
kernel.packet[0] = vec_mergeh(sum[0].m_val, sum[4].m_val);
kernel.packet[1] = vec_mergel(sum[0].m_val, sum[4].m_val);
kernel.packet[2] = vec_mergeh(sum[1].m_val, sum[5].m_val);
kernel.packet[3] = vec_mergel(sum[1].m_val, sum[5].m_val);
kernel.packet[4] = vec_mergeh(sum[2].m_val, sum[6].m_val);
kernel.packet[5] = vec_mergel(sum[2].m_val, sum[6].m_val);
kernel.packet[6] = vec_mergeh(sum[3].m_val, sum[7].m_val);
kernel.packet[7] = vec_mergel(sum[3].m_val, sum[7].m_val);
}
EIGEN_DEVICE_FUNC inline void
ptranspose(PacketBlock<Packet16c,16>& kernel) {
Packet16c step1[16], step2[16], step3[16];
@ -1656,6 +2008,10 @@ template<> EIGEN_STRONG_INLINE Packet8us pblend(const Selector<8>& ifPacket, con
return vec_sel(elsePacket, thenPacket, mask);
}
template<> EIGEN_STRONG_INLINE Packet8bf pblend(const Selector<8>& ifPacket, const Packet8bf& thenPacket, const Packet8bf& elsePacket) {
return pblend<Packet8us>(ifPacket, thenPacket, elsePacket);
}
template<> EIGEN_STRONG_INLINE Packet16c pblend(const Selector<16>& ifPacket, const Packet16c& thenPacket, const Packet16c& elsePacket) {
Packet16uc select = { ifPacket.select[0], ifPacket.select[1], ifPacket.select[2], ifPacket.select[3],
ifPacket.select[4], ifPacket.select[5], ifPacket.select[6], ifPacket.select[7],
@ -1694,15 +2050,78 @@ struct type_casting_traits<int, float> {
};
};
template <>
struct type_casting_traits<bfloat16, unsigned short int> {
enum {
VectorizedCast = 1,
SrcCoeffRatio = 1,
TgtCoeffRatio = 1
};
};
template <>
struct type_casting_traits<unsigned short int, bfloat16> {
enum {
VectorizedCast = 1,
SrcCoeffRatio = 1,
TgtCoeffRatio = 1
};
};
template<> EIGEN_STRONG_INLINE Packet4i pcast<Packet4f, Packet4i>(const Packet4f& a) {
return vec_cts(a,0);
}
template<> EIGEN_STRONG_INLINE Packet4ui pcast<Packet4f, Packet4ui>(const Packet4f& a) {
return vec_ctu(a,0);
}
template<> EIGEN_STRONG_INLINE Packet4f pcast<Packet4i, Packet4f>(const Packet4i& a) {
return vec_ctf(a,0);
}
template<> EIGEN_STRONG_INLINE Packet4f pcast<Packet4ui, Packet4f>(const Packet4ui& a) {
return vec_ctf(a,0);
}
template<> EIGEN_STRONG_INLINE Packet8us pcast<Packet8bf, Packet8us>(const Packet8bf& a) {
Packet4f float_even = Bf16ToF32Even(a);
Packet4f float_odd = Bf16ToF32Odd(a);
Packet4ui int_even = pcast<Packet4f, Packet4ui>(float_even);
Packet4ui int_odd = pcast<Packet4f, Packet4ui>(float_odd);
const _EIGEN_DECLARE_CONST_FAST_Packet4ui(low_mask, 0x0000FFFF);
Packet4ui low_even = pand<Packet4ui>(int_even, p4ui_low_mask);
Packet4ui low_odd = pand<Packet4ui>(int_odd, p4ui_low_mask);
//Check values that are bigger than USHRT_MAX (0xFFFF)
Packet4bi overflow_selector;
if(vec_any_gt(int_even, p4ui_low_mask)){
overflow_selector = vec_cmpgt(int_even, p4ui_low_mask);
low_even = vec_sel(low_even, p4ui_low_mask, overflow_selector);
}
if(vec_any_gt(int_odd, p4ui_low_mask)){
overflow_selector = vec_cmpgt(int_odd, p4ui_low_mask);
low_odd = vec_sel(low_even, p4ui_low_mask, overflow_selector);
}
low_odd = plogical_shift_left<16>(low_odd);
Packet4ui int_final = por<Packet4ui>(low_even, low_odd);
return reinterpret_cast<Packet8us>(int_final);
}
template<> EIGEN_STRONG_INLINE Packet8bf pcast<Packet8us, Packet8bf>(const Packet8us& a) {
//short -> int -> float -> bfloat16
const _EIGEN_DECLARE_CONST_FAST_Packet4ui(low_mask, 0x0000FFFF);
Packet4ui int_cast = reinterpret_cast<Packet4ui>(a);
Packet4ui int_even = pand<Packet4ui>(int_cast, p4ui_low_mask);
Packet4ui int_odd = plogical_shift_right<16>(int_cast);
Packet4f float_even = pcast<Packet4ui, Packet4f>(int_even);
Packet4f float_odd = pcast<Packet4ui, Packet4f>(int_odd);
return F32ToBf16(float_even, float_odd);
}
template<> EIGEN_STRONG_INLINE Packet4i preinterpret<Packet4i,Packet4f>(const Packet4f& a) {
return reinterpret_cast<Packet4i>(a);
}
@ -2024,6 +2443,8 @@ template<> EIGEN_STRONG_INLINE Packet2d pblend(const Selector<2>& ifPacket, cons
Packet2bl mask = reinterpret_cast<Packet2bl>( vec_cmpeq(reinterpret_cast<Packet2d>(select), reinterpret_cast<Packet2d>(p2l_ONE)) );
return vec_sel(elsePacket, thenPacket, mask);
}
#endif // __VSX__
} // end namespace internal

View File

@ -247,6 +247,20 @@ void packetmath_boolean_mask_ops() {
data1[i + PacketSize] = internal::random<bool>() ? data1[i] : Scalar(0);
}
CHECK_CWISE2_IF(true, internal::pcmp_eq, internal::pcmp_eq);
//Test (-0) == (0) for signed operations
for (int i = 0; i < PacketSize; ++i) {
data1[i] = Scalar(-0.0);
data1[i + PacketSize] = internal::random<bool>() ? data1[i] : Scalar(0);
}
CHECK_CWISE2_IF(true, internal::pcmp_eq, internal::pcmp_eq);
//Test NaN
for (int i = 0; i < PacketSize; ++i) {
data1[i] = std::numeric_limits<Scalar>::quiet_NaN();
data1[i + PacketSize] = internal::random<bool>() ? data1[i] : Scalar(0);
}
CHECK_CWISE2_IF(true, internal::pcmp_eq, internal::pcmp_eq);
}
// Packet16b representing bool does not support ptrue, pandnot or pcmp_eq, since the scalar path
@ -254,6 +268,22 @@ void packetmath_boolean_mask_ops() {
template<>
void packetmath_boolean_mask_ops<bool, internal::packet_traits<bool>::type>() {}
template <typename Scalar, typename Packet>
void packetmath_minus_zero_add() {
const int PacketSize = internal::unpacket_traits<Packet>::size;
const int size = 2 * PacketSize;
EIGEN_ALIGN_MAX Scalar data1[size];
EIGEN_ALIGN_MAX Scalar data2[size];
EIGEN_ALIGN_MAX Scalar ref[size];
for (int i = 0; i < PacketSize; ++i) {
data1[i] = Scalar(-0.0);
data1[i + PacketSize] = Scalar(-0.0);
}
CHECK_CWISE2_IF(internal::packet_traits<Scalar>::HasAdd, REF_ADD, internal::padd);
}
template <typename Scalar, typename Packet>
void packetmath() {
typedef internal::packet_traits<Scalar> PacketTraits;
@ -454,6 +484,7 @@ void packetmath() {
packetmath_boolean_mask_ops<Scalar, Packet>();
packetmath_pcast_ops_runner<Scalar, Packet>::run();
packetmath_minus_zero_add<Scalar, Packet>();
}
template <typename Scalar, typename Packet>