mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-01-06 14:14:46 +08:00
Add support for Bfloat16 to use vector instructions on Altivec
architecture
This commit is contained in:
parent
46f8a18567
commit
704798d1df
@ -39,10 +39,10 @@ typedef __vector short int Packet8s;
|
|||||||
typedef __vector unsigned short int Packet8us;
|
typedef __vector unsigned short int Packet8us;
|
||||||
typedef __vector int8_t Packet16c;
|
typedef __vector int8_t Packet16c;
|
||||||
typedef __vector uint8_t Packet16uc;
|
typedef __vector uint8_t Packet16uc;
|
||||||
|
typedef eigen_packet_wrapper<__vector unsigned short int,0> Packet8bf;
|
||||||
|
|
||||||
// We don't want to write the same code all the time, but we need to reuse the constants
|
// We don't want to write the same code all the time, but we need to reuse the constants
|
||||||
// and it doesn't really work to declare them global, so we define macros instead
|
// and it doesn't really work to declare them global, so we define macros instead
|
||||||
|
|
||||||
#define _EIGEN_DECLARE_CONST_FAST_Packet4f(NAME,X) \
|
#define _EIGEN_DECLARE_CONST_FAST_Packet4f(NAME,X) \
|
||||||
Packet4f p4f_##NAME = {X, X, X, X}
|
Packet4f p4f_##NAME = {X, X, X, X}
|
||||||
|
|
||||||
@ -96,6 +96,7 @@ static Packet4f p4f_COUNTDOWN = { 0.0, 1.0, 2.0, 3.0 };
|
|||||||
static Packet4i p4i_COUNTDOWN = { 0, 1, 2, 3 };
|
static Packet4i p4i_COUNTDOWN = { 0, 1, 2, 3 };
|
||||||
static Packet8s p8s_COUNTDOWN = { 0, 1, 2, 3, 4, 5, 6, 7 };
|
static Packet8s p8s_COUNTDOWN = { 0, 1, 2, 3, 4, 5, 6, 7 };
|
||||||
static Packet8us p8us_COUNTDOWN = { 0, 1, 2, 3, 4, 5, 6, 7 };
|
static Packet8us p8us_COUNTDOWN = { 0, 1, 2, 3, 4, 5, 6, 7 };
|
||||||
|
|
||||||
static Packet16c p16c_COUNTDOWN = { 0, 1, 2, 3, 4, 5, 6, 7,
|
static Packet16c p16c_COUNTDOWN = { 0, 1, 2, 3, 4, 5, 6, 7,
|
||||||
8, 9, 10, 11, 12, 13, 14, 15};
|
8, 9, 10, 11, 12, 13, 14, 15};
|
||||||
static Packet16uc p16uc_COUNTDOWN = { 0, 1, 2, 3, 4, 5, 6, 7,
|
static Packet16uc p16uc_COUNTDOWN = { 0, 1, 2, 3, 4, 5, 6, 7,
|
||||||
@ -108,6 +109,8 @@ static Packet16uc p16uc_REVERSE8 = { 15,14,13,12,11,10,9,8,7,6,5,4,3,2,1,0 };
|
|||||||
static Packet16uc p16uc_DUPLICATE32_HI = { 0,1,2,3, 0,1,2,3, 4,5,6,7, 4,5,6,7 };
|
static Packet16uc p16uc_DUPLICATE32_HI = { 0,1,2,3, 0,1,2,3, 4,5,6,7, 4,5,6,7 };
|
||||||
static Packet16uc p16uc_DUPLICATE16_HI = { 0,1,0,1, 2,3,2,3, 4,5,4,5, 6,7,6,7 };
|
static Packet16uc p16uc_DUPLICATE16_HI = { 0,1,0,1, 2,3,2,3, 4,5,4,5, 6,7,6,7 };
|
||||||
static Packet16uc p16uc_DUPLICATE8_HI = { 0,0, 1,1, 2,2, 3,3, 4,4, 5,5, 6,6, 7,7 };
|
static Packet16uc p16uc_DUPLICATE8_HI = { 0,0, 1,1, 2,2, 3,3, 4,4, 5,5, 6,6, 7,7 };
|
||||||
|
static const Packet16uc p16uc_DUPLICATE16_EVEN= { 0,1 ,0,1, 4,5, 4,5, 8,9, 8,9, 12,13, 12,13 };
|
||||||
|
static const Packet16uc p16uc_DUPLICATE16_ODD = { 2,3 ,2,3, 6,7, 6,7, 10,11, 10,11, 14,15, 14,15 };
|
||||||
|
|
||||||
static Packet16uc p16uc_QUADRUPLICATE16_HI = { 0,1,0,1,0,1,0,1, 2,3,2,3,2,3,2,3 };
|
static Packet16uc p16uc_QUADRUPLICATE16_HI = { 0,1,0,1,0,1,0,1, 2,3,2,3,2,3,2,3 };
|
||||||
|
|
||||||
@ -189,6 +192,48 @@ struct packet_traits<float> : default_packet_traits {
|
|||||||
HasBlend = 1
|
HasBlend = 1
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
template <>
|
||||||
|
struct packet_traits<bfloat16> : default_packet_traits {
|
||||||
|
typedef Packet8bf type;
|
||||||
|
typedef Packet8bf half;
|
||||||
|
enum {
|
||||||
|
Vectorizable = 1,
|
||||||
|
AlignedOnScalar = 1,
|
||||||
|
size = 8,
|
||||||
|
HasHalfPacket = 0,
|
||||||
|
|
||||||
|
HasAdd = 1,
|
||||||
|
HasSub = 1,
|
||||||
|
HasMul = 1,
|
||||||
|
HasDiv = 1,
|
||||||
|
HasMin = 1,
|
||||||
|
HasMax = 1,
|
||||||
|
HasAbs = 1,
|
||||||
|
HasSin = EIGEN_FAST_MATH,
|
||||||
|
HasCos = EIGEN_FAST_MATH,
|
||||||
|
HasLog = 1,
|
||||||
|
HasExp = 1,
|
||||||
|
#ifdef __VSX__
|
||||||
|
HasSqrt = 1,
|
||||||
|
#if !EIGEN_COMP_CLANG
|
||||||
|
HasRsqrt = 1,
|
||||||
|
#else
|
||||||
|
HasRsqrt = 0,
|
||||||
|
#endif
|
||||||
|
#else
|
||||||
|
HasSqrt = 0,
|
||||||
|
HasRsqrt = 0,
|
||||||
|
HasTanh = EIGEN_FAST_MATH,
|
||||||
|
HasErf = EIGEN_FAST_MATH,
|
||||||
|
#endif
|
||||||
|
HasRound = 1,
|
||||||
|
HasFloor = 1,
|
||||||
|
HasCeil = 1,
|
||||||
|
HasNegate = 1,
|
||||||
|
HasBlend = 1
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
struct packet_traits<int> : default_packet_traits {
|
struct packet_traits<int> : default_packet_traits {
|
||||||
typedef Packet4i type;
|
typedef Packet4i type;
|
||||||
@ -319,6 +364,12 @@ template<> struct unpacket_traits<Packet16uc>
|
|||||||
enum {size=16, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false};
|
enum {size=16, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false};
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template<> struct unpacket_traits<Packet8bf>
|
||||||
|
{
|
||||||
|
typedef bfloat16 type;
|
||||||
|
typedef Packet8bf half;
|
||||||
|
enum {size=8, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false};
|
||||||
|
};
|
||||||
inline std::ostream & operator <<(std::ostream & s, const Packet16c & v)
|
inline std::ostream & operator <<(std::ostream & s, const Packet16c & v)
|
||||||
{
|
{
|
||||||
union {
|
union {
|
||||||
@ -421,6 +472,11 @@ template<> EIGEN_STRONG_INLINE Packet16uc pload<Packet16uc>(const uint8_t* f
|
|||||||
return pload_common<Packet16uc>(from);
|
return pload_common<Packet16uc>(from);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template<> EIGEN_STRONG_INLINE Packet8bf pload<Packet8bf>(const bfloat16* from)
|
||||||
|
{
|
||||||
|
return pload_common<Packet8us>(reinterpret_cast<const unsigned short int*>(from));
|
||||||
|
}
|
||||||
|
|
||||||
template <typename Packet>
|
template <typename Packet>
|
||||||
EIGEN_STRONG_INLINE void pstore_common(__UNPACK_TYPE__(Packet)* to, const Packet& from){
|
EIGEN_STRONG_INLINE void pstore_common(__UNPACK_TYPE__(Packet)* to, const Packet& from){
|
||||||
// some versions of GCC throw "unused-but-set-parameter" (float *to).
|
// some versions of GCC throw "unused-but-set-parameter" (float *to).
|
||||||
@ -454,6 +510,11 @@ template<> EIGEN_STRONG_INLINE void pstore<unsigned short int>(unsigned short in
|
|||||||
pstore_common<Packet8us>(to, from);
|
pstore_common<Packet8us>(to, from);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template<> EIGEN_STRONG_INLINE void pstore<bfloat16>(bfloat16* to, const Packet8bf& from)
|
||||||
|
{
|
||||||
|
pstore_common<Packet8us>(reinterpret_cast<unsigned short int*>(to), from);
|
||||||
|
}
|
||||||
|
|
||||||
template<> EIGEN_STRONG_INLINE void pstore<int8_t>(int8_t* to, const Packet16c& from)
|
template<> EIGEN_STRONG_INLINE void pstore<int8_t>(int8_t* to, const Packet16c& from)
|
||||||
{
|
{
|
||||||
pstore_common<Packet16c>(to, from);
|
pstore_common<Packet16c>(to, from);
|
||||||
@ -513,6 +574,10 @@ template<> EIGEN_STRONG_INLINE Packet4f pset1frombits<Packet4f>(unsigned int fro
|
|||||||
return reinterpret_cast<Packet4f>(pset1<Packet4i>(from));
|
return reinterpret_cast<Packet4f>(pset1<Packet4i>(from));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template<> EIGEN_STRONG_INLINE Packet8bf pset1<Packet8bf>(const bfloat16& from) {
|
||||||
|
return pset1_size8<Packet8us>(reinterpret_cast<const unsigned short int&>(from));
|
||||||
|
}
|
||||||
|
|
||||||
template<typename Packet> EIGEN_STRONG_INLINE void
|
template<typename Packet> EIGEN_STRONG_INLINE void
|
||||||
pbroadcast4_common(const __UNPACK_TYPE__(Packet) *a,
|
pbroadcast4_common(const __UNPACK_TYPE__(Packet) *a,
|
||||||
Packet& a0, Packet& a1, Packet& a2, Packet& a3)
|
Packet& a0, Packet& a1, Packet& a2, Packet& a3)
|
||||||
@ -700,6 +765,7 @@ template<> EIGEN_STRONG_INLINE Packet16uc plset<Packet16uc>(const uint8_t& a)
|
|||||||
|
|
||||||
template<> EIGEN_STRONG_INLINE Packet4f padd<Packet4f> (const Packet4f& a, const Packet4f& b) { return a + b; }
|
template<> EIGEN_STRONG_INLINE Packet4f padd<Packet4f> (const Packet4f& a, const Packet4f& b) { return a + b; }
|
||||||
template<> EIGEN_STRONG_INLINE Packet4i padd<Packet4i> (const Packet4i& a, const Packet4i& b) { return a + b; }
|
template<> EIGEN_STRONG_INLINE Packet4i padd<Packet4i> (const Packet4i& a, const Packet4i& b) { return a + b; }
|
||||||
|
template<> EIGEN_STRONG_INLINE Packet4ui padd<Packet4ui> (const Packet4ui& a, const Packet4ui& b) { return a + b; }
|
||||||
template<> EIGEN_STRONG_INLINE Packet8s padd<Packet8s> (const Packet8s& a, const Packet8s& b) { return a + b; }
|
template<> EIGEN_STRONG_INLINE Packet8s padd<Packet8s> (const Packet8s& a, const Packet8s& b) { return a + b; }
|
||||||
template<> EIGEN_STRONG_INLINE Packet8us padd<Packet8us> (const Packet8us& a, const Packet8us& b) { return a + b; }
|
template<> EIGEN_STRONG_INLINE Packet8us padd<Packet8us> (const Packet8us& a, const Packet8us& b) { return a + b; }
|
||||||
template<> EIGEN_STRONG_INLINE Packet16c padd<Packet16c> (const Packet16c& a, const Packet16c& b) { return a + b; }
|
template<> EIGEN_STRONG_INLINE Packet16c padd<Packet16c> (const Packet16c& a, const Packet16c& b) { return a + b; }
|
||||||
@ -721,6 +787,7 @@ template<> EIGEN_STRONG_INLINE Packet4i pmul<Packet4i> (const Packet4i& a,
|
|||||||
template<> EIGEN_STRONG_INLINE Packet16c pmul<Packet16c> (const Packet16c& a, const Packet16c& b) { return vec_mul(a,b); }
|
template<> EIGEN_STRONG_INLINE Packet16c pmul<Packet16c> (const Packet16c& a, const Packet16c& b) { return vec_mul(a,b); }
|
||||||
template<> EIGEN_STRONG_INLINE Packet16uc pmul<Packet16uc>(const Packet16uc& a, const Packet16uc& b) { return vec_mul(a,b); }
|
template<> EIGEN_STRONG_INLINE Packet16uc pmul<Packet16uc>(const Packet16uc& a, const Packet16uc& b) { return vec_mul(a,b); }
|
||||||
|
|
||||||
|
|
||||||
template<> EIGEN_STRONG_INLINE Packet4f pdiv<Packet4f>(const Packet4f& a, const Packet4f& b)
|
template<> EIGEN_STRONG_INLINE Packet4f pdiv<Packet4f>(const Packet4f& a, const Packet4f& b)
|
||||||
{
|
{
|
||||||
#ifndef __VSX__ // VSX actually provides a div instruction
|
#ifndef __VSX__ // VSX actually provides a div instruction
|
||||||
@ -765,6 +832,7 @@ template<> EIGEN_STRONG_INLINE Packet8us pmin<Packet8us>(const Packet8us& a, con
|
|||||||
template<> EIGEN_STRONG_INLINE Packet16c pmin<Packet16c>(const Packet16c& a, const Packet16c& b) { return vec_min(a, b); }
|
template<> EIGEN_STRONG_INLINE Packet16c pmin<Packet16c>(const Packet16c& a, const Packet16c& b) { return vec_min(a, b); }
|
||||||
template<> EIGEN_STRONG_INLINE Packet16uc pmin<Packet16uc>(const Packet16uc& a, const Packet16uc& b) { return vec_min(a, b); }
|
template<> EIGEN_STRONG_INLINE Packet16uc pmin<Packet16uc>(const Packet16uc& a, const Packet16uc& b) { return vec_min(a, b); }
|
||||||
|
|
||||||
|
|
||||||
template<> EIGEN_STRONG_INLINE Packet4f pmax<Packet4f>(const Packet4f& a, const Packet4f& b)
|
template<> EIGEN_STRONG_INLINE Packet4f pmax<Packet4f>(const Packet4f& a, const Packet4f& b)
|
||||||
{
|
{
|
||||||
#ifdef __VSX__
|
#ifdef __VSX__
|
||||||
@ -785,6 +853,7 @@ template<> EIGEN_STRONG_INLINE Packet16uc pmax<Packet16uc>(const Packet16uc& a,
|
|||||||
template<> EIGEN_STRONG_INLINE Packet4f pcmp_le(const Packet4f& a, const Packet4f& b) { return reinterpret_cast<Packet4f>(vec_cmple(a,b)); }
|
template<> EIGEN_STRONG_INLINE Packet4f pcmp_le(const Packet4f& a, const Packet4f& b) { return reinterpret_cast<Packet4f>(vec_cmple(a,b)); }
|
||||||
template<> EIGEN_STRONG_INLINE Packet4f pcmp_lt(const Packet4f& a, const Packet4f& b) { return reinterpret_cast<Packet4f>(vec_cmplt(a,b)); }
|
template<> EIGEN_STRONG_INLINE Packet4f pcmp_lt(const Packet4f& a, const Packet4f& b) { return reinterpret_cast<Packet4f>(vec_cmplt(a,b)); }
|
||||||
template<> EIGEN_STRONG_INLINE Packet4f pcmp_eq(const Packet4f& a, const Packet4f& b) { return reinterpret_cast<Packet4f>(vec_cmpeq(a,b)); }
|
template<> EIGEN_STRONG_INLINE Packet4f pcmp_eq(const Packet4f& a, const Packet4f& b) { return reinterpret_cast<Packet4f>(vec_cmpeq(a,b)); }
|
||||||
|
|
||||||
template<> EIGEN_STRONG_INLINE Packet4f pcmp_lt_or_nan(const Packet4f& a, const Packet4f& b) {
|
template<> EIGEN_STRONG_INLINE Packet4f pcmp_lt_or_nan(const Packet4f& a, const Packet4f& b) {
|
||||||
Packet4f c = reinterpret_cast<Packet4f>(vec_cmpge(a,b));
|
Packet4f c = reinterpret_cast<Packet4f>(vec_cmpge(a,b));
|
||||||
return vec_nor(c,c);
|
return vec_nor(c,c);
|
||||||
@ -793,12 +862,26 @@ template<> EIGEN_STRONG_INLINE Packet4i pcmp_eq(const Packet4i& a, const Packet4
|
|||||||
|
|
||||||
template<> EIGEN_STRONG_INLINE Packet4f pand<Packet4f>(const Packet4f& a, const Packet4f& b) { return vec_and(a, b); }
|
template<> EIGEN_STRONG_INLINE Packet4f pand<Packet4f>(const Packet4f& a, const Packet4f& b) { return vec_and(a, b); }
|
||||||
template<> EIGEN_STRONG_INLINE Packet4i pand<Packet4i>(const Packet4i& a, const Packet4i& b) { return vec_and(a, b); }
|
template<> EIGEN_STRONG_INLINE Packet4i pand<Packet4i>(const Packet4i& a, const Packet4i& b) { return vec_and(a, b); }
|
||||||
|
template<> EIGEN_STRONG_INLINE Packet4ui pand<Packet4ui>(const Packet4ui& a, const Packet4ui& b) { return vec_and(a, b); }
|
||||||
|
template<> EIGEN_STRONG_INLINE Packet8us pand<Packet8us>(const Packet8us& a, const Packet8us& b) { return vec_and(a, b); }
|
||||||
|
template<> EIGEN_STRONG_INLINE Packet8bf pand<Packet8bf>(const Packet8bf& a, const Packet8bf& b) {
|
||||||
|
return pand<Packet8us>(a, b);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
template<> EIGEN_STRONG_INLINE Packet4f por<Packet4f>(const Packet4f& a, const Packet4f& b) { return vec_or(a, b); }
|
template<> EIGEN_STRONG_INLINE Packet4f por<Packet4f>(const Packet4f& a, const Packet4f& b) { return vec_or(a, b); }
|
||||||
template<> EIGEN_STRONG_INLINE Packet4i por<Packet4i>(const Packet4i& a, const Packet4i& b) { return vec_or(a, b); }
|
template<> EIGEN_STRONG_INLINE Packet4i por<Packet4i>(const Packet4i& a, const Packet4i& b) { return vec_or(a, b); }
|
||||||
|
template<> EIGEN_STRONG_INLINE Packet8s por<Packet8s>(const Packet8s& a, const Packet8s& b) { return vec_or(a, b); }
|
||||||
|
template<> EIGEN_STRONG_INLINE Packet8us por<Packet8us>(const Packet8us& a, const Packet8us& b) { return vec_or(a, b); }
|
||||||
|
template<> EIGEN_STRONG_INLINE Packet8bf por<Packet8bf>(const Packet8bf& a, const Packet8bf& b) {
|
||||||
|
return por<Packet8us>(a, b);
|
||||||
|
}
|
||||||
|
|
||||||
template<> EIGEN_STRONG_INLINE Packet4f pxor<Packet4f>(const Packet4f& a, const Packet4f& b) { return vec_xor(a, b); }
|
template<> EIGEN_STRONG_INLINE Packet4f pxor<Packet4f>(const Packet4f& a, const Packet4f& b) { return vec_xor(a, b); }
|
||||||
template<> EIGEN_STRONG_INLINE Packet4i pxor<Packet4i>(const Packet4i& a, const Packet4i& b) { return vec_xor(a, b); }
|
template<> EIGEN_STRONG_INLINE Packet4i pxor<Packet4i>(const Packet4i& a, const Packet4i& b) { return vec_xor(a, b); }
|
||||||
|
template<> EIGEN_STRONG_INLINE Packet8bf pxor<Packet8bf>(const Packet8bf& a, const Packet8bf& b) {
|
||||||
|
return pxor<Packet8us>(a, b);
|
||||||
|
}
|
||||||
|
|
||||||
template<> EIGEN_STRONG_INLINE Packet4f pandnot<Packet4f>(const Packet4f& a, const Packet4f& b) { return vec_and(a, vec_nor(b, b)); }
|
template<> EIGEN_STRONG_INLINE Packet4f pandnot<Packet4f>(const Packet4f& a, const Packet4f& b) { return vec_and(a, vec_nor(b, b)); }
|
||||||
template<> EIGEN_STRONG_INLINE Packet4i pandnot<Packet4i>(const Packet4i& a, const Packet4i& b) { return vec_and(a, vec_nor(b, b)); }
|
template<> EIGEN_STRONG_INLINE Packet4i pandnot<Packet4i>(const Packet4i& a, const Packet4i& b) { return vec_and(a, vec_nor(b, b)); }
|
||||||
@ -806,6 +889,7 @@ template<> EIGEN_STRONG_INLINE Packet4i pandnot<Packet4i>(const Packet4i& a, con
|
|||||||
template<> EIGEN_STRONG_INLINE Packet4f pselect(const Packet4f& mask, const Packet4f& a, const Packet4f& b) {
|
template<> EIGEN_STRONG_INLINE Packet4f pselect(const Packet4f& mask, const Packet4f& a, const Packet4f& b) {
|
||||||
return vec_sel(b, a, reinterpret_cast<Packet4ui>(mask));
|
return vec_sel(b, a, reinterpret_cast<Packet4ui>(mask));
|
||||||
}
|
}
|
||||||
|
|
||||||
template<> EIGEN_STRONG_INLINE Packet4f pround<Packet4f>(const Packet4f& a) {
|
template<> EIGEN_STRONG_INLINE Packet4f pround<Packet4f>(const Packet4f& a) {
|
||||||
Packet4f t = vec_add(reinterpret_cast<Packet4f>(vec_or(vec_and(reinterpret_cast<Packet4ui>(a), p4ui_SIGN), p4ui_PREV0DOT5)), a);
|
Packet4f t = vec_add(reinterpret_cast<Packet4f>(vec_or(vec_and(reinterpret_cast<Packet4ui>(a), p4ui_SIGN), p4ui_PREV0DOT5)), a);
|
||||||
Packet4f res;
|
Packet4f res;
|
||||||
@ -852,6 +936,10 @@ template<> EIGEN_STRONG_INLINE Packet8us ploadu<Packet8us>(const unsigned short
|
|||||||
{
|
{
|
||||||
return ploadu_common<Packet8us>(from);
|
return ploadu_common<Packet8us>(from);
|
||||||
}
|
}
|
||||||
|
template<> EIGEN_STRONG_INLINE Packet8bf ploadu<Packet8bf>(const bfloat16* from)
|
||||||
|
{
|
||||||
|
return ploadu_common<Packet8us>(reinterpret_cast<const unsigned short int*>(from));
|
||||||
|
}
|
||||||
template<> EIGEN_STRONG_INLINE Packet16c ploadu<Packet16c>(const int8_t* from)
|
template<> EIGEN_STRONG_INLINE Packet16c ploadu<Packet16c>(const int8_t* from)
|
||||||
{
|
{
|
||||||
return ploadu_common<Packet16c>(from);
|
return ploadu_common<Packet16c>(from);
|
||||||
@ -909,6 +997,11 @@ template<> EIGEN_STRONG_INLINE Packet8us ploadquad<Packet8us>(const unsigned sho
|
|||||||
return vec_perm(p, p, p16uc_QUADRUPLICATE16_HI);
|
return vec_perm(p, p, p16uc_QUADRUPLICATE16_HI);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template<> EIGEN_STRONG_INLINE Packet8bf ploadquad<Packet8bf>(const bfloat16* from)
|
||||||
|
{
|
||||||
|
return ploadquad<Packet8us>(reinterpret_cast<const unsigned short int*>(from));
|
||||||
|
}
|
||||||
|
|
||||||
template<> EIGEN_STRONG_INLINE Packet16c ploaddup<Packet16c>(const int8_t* from)
|
template<> EIGEN_STRONG_INLINE Packet16c ploaddup<Packet16c>(const int8_t* from)
|
||||||
{
|
{
|
||||||
Packet16c p;
|
Packet16c p;
|
||||||
@ -962,6 +1055,10 @@ template<> EIGEN_STRONG_INLINE void pstoreu<unsigned short int>(unsigned short i
|
|||||||
{
|
{
|
||||||
pstoreu_common<Packet8us>(to, from);
|
pstoreu_common<Packet8us>(to, from);
|
||||||
}
|
}
|
||||||
|
template<> EIGEN_STRONG_INLINE void pstoreu<bfloat16>(bfloat16* to, const Packet8bf& from)
|
||||||
|
{
|
||||||
|
pstoreu_common<Packet8us>(reinterpret_cast<unsigned short int*>(to), from);
|
||||||
|
}
|
||||||
template<> EIGEN_STRONG_INLINE void pstoreu<int8_t>(int8_t* to, const Packet16c& from)
|
template<> EIGEN_STRONG_INLINE void pstoreu<int8_t>(int8_t* to, const Packet16c& from)
|
||||||
{
|
{
|
||||||
pstoreu_common<Packet16c>(to, from);
|
pstoreu_common<Packet16c>(to, from);
|
||||||
@ -1025,6 +1122,10 @@ template<> EIGEN_STRONG_INLINE Packet16uc preverse(const Packet16uc& a)
|
|||||||
{
|
{
|
||||||
return vec_perm(a, a, p16uc_REVERSE8);
|
return vec_perm(a, a, p16uc_REVERSE8);
|
||||||
}
|
}
|
||||||
|
template<> EIGEN_STRONG_INLINE Packet8bf preverse(const Packet8bf& a)
|
||||||
|
{
|
||||||
|
return preverse<Packet8us>(a);
|
||||||
|
}
|
||||||
|
|
||||||
template<> EIGEN_STRONG_INLINE Packet4f pabs(const Packet4f& a) { return vec_abs(a); }
|
template<> EIGEN_STRONG_INLINE Packet4f pabs(const Packet4f& a) { return vec_abs(a); }
|
||||||
template<> EIGEN_STRONG_INLINE Packet4i pabs(const Packet4i& a) { return vec_abs(a); }
|
template<> EIGEN_STRONG_INLINE Packet4i pabs(const Packet4i& a) { return vec_abs(a); }
|
||||||
@ -1032,6 +1133,10 @@ template<> EIGEN_STRONG_INLINE Packet8s pabs(const Packet8s& a) { return vec_abs
|
|||||||
template<> EIGEN_STRONG_INLINE Packet8us pabs(const Packet8us& a) { return a; }
|
template<> EIGEN_STRONG_INLINE Packet8us pabs(const Packet8us& a) { return a; }
|
||||||
template<> EIGEN_STRONG_INLINE Packet16c pabs(const Packet16c& a) { return vec_abs(a); }
|
template<> EIGEN_STRONG_INLINE Packet16c pabs(const Packet16c& a) { return vec_abs(a); }
|
||||||
template<> EIGEN_STRONG_INLINE Packet16uc pabs(const Packet16uc& a) { return a; }
|
template<> EIGEN_STRONG_INLINE Packet16uc pabs(const Packet16uc& a) { return a; }
|
||||||
|
template<> EIGEN_STRONG_INLINE Packet8bf pabs(const Packet8bf& a) {
|
||||||
|
_EIGEN_DECLARE_CONST_FAST_Packet8us(abs_mask,0x7FFF);
|
||||||
|
return pand<Packet8us>(p8us_abs_mask, a);
|
||||||
|
}
|
||||||
|
|
||||||
template<int N> EIGEN_STRONG_INLINE Packet4i parithmetic_shift_right(Packet4i a)
|
template<int N> EIGEN_STRONG_INLINE Packet4i parithmetic_shift_right(Packet4i a)
|
||||||
{ return vec_sra(a,reinterpret_cast<Packet4ui>(pset1<Packet4i>(N))); }
|
{ return vec_sra(a,reinterpret_cast<Packet4ui>(pset1<Packet4i>(N))); }
|
||||||
@ -1039,6 +1144,175 @@ template<int N> EIGEN_STRONG_INLINE Packet4i plogical_shift_right(Packet4i a)
|
|||||||
{ return vec_sr(a,reinterpret_cast<Packet4ui>(pset1<Packet4i>(N))); }
|
{ return vec_sr(a,reinterpret_cast<Packet4ui>(pset1<Packet4i>(N))); }
|
||||||
template<int N> EIGEN_STRONG_INLINE Packet4i plogical_shift_left(Packet4i a)
|
template<int N> EIGEN_STRONG_INLINE Packet4i plogical_shift_left(Packet4i a)
|
||||||
{ return vec_sl(a,reinterpret_cast<Packet4ui>(pset1<Packet4i>(N))); }
|
{ return vec_sl(a,reinterpret_cast<Packet4ui>(pset1<Packet4i>(N))); }
|
||||||
|
template<int N> EIGEN_STRONG_INLINE Packet4f plogical_shift_left(Packet4f a)
|
||||||
|
{
|
||||||
|
const _EIGEN_DECLARE_CONST_FAST_Packet4ui(mask, N);
|
||||||
|
Packet4ui r = vec_sl(reinterpret_cast<Packet4ui>(a), p4ui_mask);
|
||||||
|
return reinterpret_cast<Packet4f>(r);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<int N> EIGEN_STRONG_INLINE Packet4f plogical_shift_right(Packet4f a)
|
||||||
|
{
|
||||||
|
const _EIGEN_DECLARE_CONST_FAST_Packet4ui(mask, N);
|
||||||
|
Packet4ui r = vec_sr(reinterpret_cast<Packet4ui>(a), p4ui_mask);
|
||||||
|
return reinterpret_cast<Packet4f>(r);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<int N> EIGEN_STRONG_INLINE Packet4ui plogical_shift_right(Packet4ui a)
|
||||||
|
{
|
||||||
|
const _EIGEN_DECLARE_CONST_FAST_Packet4ui(mask, N);
|
||||||
|
return vec_sr(a, p4ui_mask);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<int N> EIGEN_STRONG_INLINE Packet4ui plogical_shift_left(Packet4ui a)
|
||||||
|
{
|
||||||
|
const _EIGEN_DECLARE_CONST_FAST_Packet4ui(mask, N);
|
||||||
|
return vec_sl(a, p4ui_mask);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<int N> EIGEN_STRONG_INLINE Packet8us plogical_shift_left(Packet8us a)
|
||||||
|
{
|
||||||
|
const _EIGEN_DECLARE_CONST_FAST_Packet8us(mask, N);
|
||||||
|
return vec_sl(a, p8us_mask);
|
||||||
|
}
|
||||||
|
|
||||||
|
EIGEN_STRONG_INLINE Packet4f Bf16ToF32Even(const Packet8bf& bf){
|
||||||
|
return plogical_shift_left<16>(reinterpret_cast<Packet4f>(bf.m_val));
|
||||||
|
}
|
||||||
|
|
||||||
|
EIGEN_STRONG_INLINE Packet4f Bf16ToF32Odd(const Packet8bf& bf){
|
||||||
|
const _EIGEN_DECLARE_CONST_FAST_Packet4ui(high_mask, 0xFFFF0000);
|
||||||
|
return pand<Packet4f>(
|
||||||
|
reinterpret_cast<Packet4f>(bf.m_val),
|
||||||
|
reinterpret_cast<Packet4f>(p4ui_high_mask)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
EIGEN_STRONG_INLINE Packet8bf F32ToBf16(Packet4f p4f){
|
||||||
|
Packet4ui input = reinterpret_cast<Packet4ui>(p4f);
|
||||||
|
Packet4ui lsb = plogical_shift_right<16>(input);
|
||||||
|
lsb = pand<Packet4ui>(lsb, reinterpret_cast<Packet4ui>(p4i_ONE));
|
||||||
|
|
||||||
|
_EIGEN_DECLARE_CONST_FAST_Packet4ui(BIAS,0x7FFFu);
|
||||||
|
Packet4ui rounding_bias = padd<Packet4ui>(lsb, p4ui_BIAS);
|
||||||
|
input = padd<Packet4ui>(input, rounding_bias);
|
||||||
|
|
||||||
|
//Test NaN and Subnormal - Begin
|
||||||
|
const _EIGEN_DECLARE_CONST_FAST_Packet4ui(exp_mask, 0x7F800000);
|
||||||
|
Packet4ui exp = pand<Packet4ui>(p4ui_exp_mask, reinterpret_cast<Packet4ui>(p4f));
|
||||||
|
|
||||||
|
const _EIGEN_DECLARE_CONST_FAST_Packet4ui(mantissa_mask, 0x7FFFFF);
|
||||||
|
Packet4ui mantissa = pand<Packet4ui>(p4ui_mantissa_mask, reinterpret_cast<Packet4ui>(p4f));
|
||||||
|
|
||||||
|
const _EIGEN_DECLARE_CONST_FAST_Packet4ui(max_exp, 0x7F800000);
|
||||||
|
Packet4bi is_max_exp = vec_cmpeq(exp, p4ui_max_exp);
|
||||||
|
Packet4bi is_zero_exp = vec_cmpeq(exp, reinterpret_cast<Packet4ui>(p4i_ZERO));
|
||||||
|
|
||||||
|
Packet4bi is_mant_not_zero = vec_cmpne(mantissa, reinterpret_cast<Packet4ui>(p4i_ZERO));
|
||||||
|
Packet4ui nan_selector = pand<Packet4ui>(
|
||||||
|
reinterpret_cast<Packet4ui>(is_max_exp),
|
||||||
|
reinterpret_cast<Packet4ui>(is_mant_not_zero)
|
||||||
|
);
|
||||||
|
|
||||||
|
Packet4ui subnormal_selector = pand<Packet4ui>(
|
||||||
|
reinterpret_cast<Packet4ui>(is_zero_exp),
|
||||||
|
reinterpret_cast<Packet4ui>(is_mant_not_zero)
|
||||||
|
);
|
||||||
|
|
||||||
|
const _EIGEN_DECLARE_CONST_FAST_Packet4ui(nan, 0x7FC00000);
|
||||||
|
input = vec_sel(input, p4ui_nan, nan_selector);
|
||||||
|
input = vec_sel(input, reinterpret_cast<Packet4ui>(p4f), subnormal_selector);
|
||||||
|
//Test NaN and Subnormal - End
|
||||||
|
|
||||||
|
input = plogical_shift_right<16>(input);
|
||||||
|
return reinterpret_cast<Packet8us>(input);
|
||||||
|
}
|
||||||
|
|
||||||
|
EIGEN_STRONG_INLINE Packet8bf F32ToBf16(Packet4f even, Packet4f odd){
|
||||||
|
Packet4f bf_odd, bf_even;
|
||||||
|
bf_odd = reinterpret_cast<Packet4f>(F32ToBf16(odd).m_val);
|
||||||
|
bf_odd = plogical_shift_left<16>(bf_odd);
|
||||||
|
bf_even = reinterpret_cast<Packet4f>(F32ToBf16(even).m_val);
|
||||||
|
return reinterpret_cast<Packet8us>(por<Packet4f>(bf_even, bf_odd));
|
||||||
|
}
|
||||||
|
#define BF16_TO_F32_UNARY_OP_WRAPPER(OP, A) \
|
||||||
|
Packet4f a_even = Bf16ToF32Even(A);\
|
||||||
|
Packet4f a_odd = Bf16ToF32Odd(A);\
|
||||||
|
Packet4f op_even = OP(a_even);\
|
||||||
|
Packet4f op_odd = OP(a_odd);\
|
||||||
|
return F32ToBf16(op_even, op_odd);\
|
||||||
|
|
||||||
|
#define BF16_TO_F32_BINARY_OP_WRAPPER(OP, A, B) \
|
||||||
|
Packet4f a_even = Bf16ToF32Even(A);\
|
||||||
|
Packet4f a_odd = Bf16ToF32Odd(A);\
|
||||||
|
Packet4f b_even = Bf16ToF32Even(B);\
|
||||||
|
Packet4f b_odd = Bf16ToF32Odd(B);\
|
||||||
|
Packet4f op_even = OP(a_even, b_even);\
|
||||||
|
Packet4f op_odd = OP(a_odd, b_odd);\
|
||||||
|
return F32ToBf16(op_even, op_odd);\
|
||||||
|
|
||||||
|
template<> EIGEN_STRONG_INLINE Packet8bf padd<Packet8bf>(const Packet8bf& a, const Packet8bf& b) {
|
||||||
|
BF16_TO_F32_BINARY_OP_WRAPPER(padd<Packet4f>, a, b);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<> EIGEN_STRONG_INLINE Packet8bf pmul<Packet8bf>(const Packet8bf& a, const Packet8bf& b) {
|
||||||
|
BF16_TO_F32_BINARY_OP_WRAPPER(pmul<Packet4f>, a, b);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<> EIGEN_STRONG_INLINE Packet8bf pdiv<Packet8bf>(const Packet8bf& a, const Packet8bf& b) {
|
||||||
|
BF16_TO_F32_BINARY_OP_WRAPPER(pdiv<Packet4f>, a, b);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<> EIGEN_STRONG_INLINE Packet8bf pnegate<Packet8bf>(const Packet8bf& a) {
|
||||||
|
BF16_TO_F32_UNARY_OP_WRAPPER(pnegate<Packet4f>, a);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<> EIGEN_STRONG_INLINE Packet8bf psub<Packet8bf>(const Packet8bf& a, const Packet8bf& b) {
|
||||||
|
BF16_TO_F32_BINARY_OP_WRAPPER(psub<Packet4f>, a, b);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<> EIGEN_STRONG_INLINE Packet8bf psqrt<Packet8bf> (const Packet8bf& a){
|
||||||
|
BF16_TO_F32_UNARY_OP_WRAPPER(vec_sqrt, a);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<> EIGEN_STRONG_INLINE Packet8bf pmadd(const Packet8bf& a, const Packet8bf& b, const Packet8bf& c) {
|
||||||
|
Packet4f a_even = Bf16ToF32Even(a);
|
||||||
|
Packet4f a_odd = Bf16ToF32Odd(a);
|
||||||
|
Packet4f b_even = Bf16ToF32Even(b);
|
||||||
|
Packet4f b_odd = Bf16ToF32Odd(b);
|
||||||
|
Packet4f c_even = Bf16ToF32Even(c);
|
||||||
|
Packet4f c_odd = Bf16ToF32Odd(c);
|
||||||
|
Packet4f pmadd_even = pmadd<Packet4f>(a_even, b_even, c_even);
|
||||||
|
Packet4f pmadd_odd = pmadd<Packet4f>(a_odd, b_odd, c_odd);
|
||||||
|
return F32ToBf16(pmadd_even, pmadd_odd);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<> EIGEN_STRONG_INLINE Packet8bf pmin<Packet8bf>(const Packet8bf& a, const Packet8bf& b) {
|
||||||
|
BF16_TO_F32_BINARY_OP_WRAPPER(pmin<Packet4f>, a, b);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<> EIGEN_STRONG_INLINE Packet8bf pmax<Packet8bf>(const Packet8bf& a, const Packet8bf& b) {
|
||||||
|
BF16_TO_F32_BINARY_OP_WRAPPER(pmax<Packet4f>, a, b);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<> EIGEN_STRONG_INLINE Packet8bf pcmp_lt(const Packet8bf& a, const Packet8bf& b) {
|
||||||
|
BF16_TO_F32_BINARY_OP_WRAPPER(pcmp_lt<Packet4f>, a, b);
|
||||||
|
}
|
||||||
|
template<> EIGEN_STRONG_INLINE Packet8bf pcmp_le(const Packet8bf& a, const Packet8bf& b) {
|
||||||
|
BF16_TO_F32_BINARY_OP_WRAPPER(pcmp_le<Packet4f>, a, b);
|
||||||
|
}
|
||||||
|
template<> EIGEN_STRONG_INLINE Packet8bf pcmp_eq(const Packet8bf& a, const Packet8bf& b) {
|
||||||
|
BF16_TO_F32_BINARY_OP_WRAPPER(pcmp_eq<Packet4f>, a, b);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<> EIGEN_STRONG_INLINE bfloat16 pfirst(const Packet8bf& a) {
|
||||||
|
return Eigen::bfloat16_impl::raw_uint16_to_bfloat16((pfirst<Packet8us>(a)));
|
||||||
|
}
|
||||||
|
|
||||||
|
template<> EIGEN_STRONG_INLINE Packet8bf ploaddup<Packet8bf>(const bfloat16* from)
|
||||||
|
{
|
||||||
|
return ploaddup<Packet8us>(reinterpret_cast<const unsigned short int*>(from));
|
||||||
|
}
|
||||||
|
|
||||||
template<> EIGEN_STRONG_INLINE Packet4f pfrexp<Packet4f>(const Packet4f& a, Packet4f& exponent) {
|
template<> EIGEN_STRONG_INLINE Packet4f pfrexp<Packet4f>(const Packet4f& a, Packet4f& exponent) {
|
||||||
return pfrexp_float(a,exponent);
|
return pfrexp_float(a,exponent);
|
||||||
@ -1070,6 +1344,13 @@ template<> EIGEN_STRONG_INLINE int predux<Packet4i>(const Packet4i& a)
|
|||||||
return pfirst(sum);
|
return pfirst(sum);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template<> EIGEN_STRONG_INLINE bfloat16 predux<Packet8bf>(const Packet8bf& a)
|
||||||
|
{
|
||||||
|
float redux_even = predux<Packet4f>(Bf16ToF32Even(a));
|
||||||
|
float redux_odd = predux<Packet4f>(Bf16ToF32Odd(a));
|
||||||
|
float f32_result = redux_even + redux_odd;
|
||||||
|
return bfloat16(f32_result);
|
||||||
|
}
|
||||||
template<typename Packet> EIGEN_STRONG_INLINE __UNPACK_TYPE__(Packet) predux_size8(const Packet& a)
|
template<typename Packet> EIGEN_STRONG_INLINE __UNPACK_TYPE__(Packet) predux_size8(const Packet& a)
|
||||||
{
|
{
|
||||||
union{
|
union{
|
||||||
@ -1166,6 +1447,15 @@ template<> EIGEN_STRONG_INLINE unsigned short int predux_mul<Packet8us>(const Pa
|
|||||||
return pfirst(octo);
|
return pfirst(octo);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template<> EIGEN_STRONG_INLINE bfloat16 predux_mul<Packet8bf>(const Packet8bf& a)
|
||||||
|
{
|
||||||
|
float redux_even = predux_mul<Packet4f>(Bf16ToF32Even(a));
|
||||||
|
float redux_odd = predux_mul<Packet4f>(Bf16ToF32Odd(a));
|
||||||
|
float f32_result = redux_even * redux_odd;
|
||||||
|
return bfloat16(f32_result);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
template<> EIGEN_STRONG_INLINE int8_t predux_mul<Packet16c>(const Packet16c& a)
|
template<> EIGEN_STRONG_INLINE int8_t predux_mul<Packet16c>(const Packet16c& a)
|
||||||
{
|
{
|
||||||
Packet16c pair, quad, octo, result;
|
Packet16c pair, quad, octo, result;
|
||||||
@ -1211,6 +1501,14 @@ template<> EIGEN_STRONG_INLINE int predux_min<Packet4i>(const Packet4i& a)
|
|||||||
return predux_min4<Packet4i>(a);
|
return predux_min4<Packet4i>(a);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template<> EIGEN_STRONG_INLINE bfloat16 predux_min<Packet8bf>(const Packet8bf& a)
|
||||||
|
{
|
||||||
|
float redux_even = predux_min<Packet4f>(Bf16ToF32Even(a));
|
||||||
|
float redux_odd = predux_min<Packet4f>(Bf16ToF32Odd(a));
|
||||||
|
float f32_result = (std::min)(redux_even, redux_odd);
|
||||||
|
return bfloat16(f32_result);
|
||||||
|
}
|
||||||
|
|
||||||
template<> EIGEN_STRONG_INLINE short int predux_min<Packet8s>(const Packet8s& a)
|
template<> EIGEN_STRONG_INLINE short int predux_min<Packet8s>(const Packet8s& a)
|
||||||
{
|
{
|
||||||
Packet8s pair, quad, octo;
|
Packet8s pair, quad, octo;
|
||||||
@ -1283,6 +1581,14 @@ template<> EIGEN_STRONG_INLINE int predux_max<Packet4i>(const Packet4i& a)
|
|||||||
return predux_max4<Packet4i>(a);
|
return predux_max4<Packet4i>(a);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template<> EIGEN_STRONG_INLINE bfloat16 predux_max<Packet8bf>(const Packet8bf& a)
|
||||||
|
{
|
||||||
|
float redux_even = predux_max<Packet4f>(Bf16ToF32Even(a));
|
||||||
|
float redux_odd = predux_max<Packet4f>(Bf16ToF32Odd(a));
|
||||||
|
float f32_result = (std::max)(redux_even, redux_odd);
|
||||||
|
return bfloat16(f32_result);
|
||||||
|
}
|
||||||
|
|
||||||
template<> EIGEN_STRONG_INLINE short int predux_max<Packet8s>(const Packet8s& a)
|
template<> EIGEN_STRONG_INLINE short int predux_max<Packet8s>(const Packet8s& a)
|
||||||
{
|
{
|
||||||
Packet8s pair, quad, octo;
|
Packet8s pair, quad, octo;
|
||||||
@ -1391,6 +1697,21 @@ ptranspose(PacketBlock<Packet8us,4>& kernel) {
|
|||||||
kernel.packet[3] = vec_mergel(t1, t3);
|
kernel.packet[3] = vec_mergel(t1, t3);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC inline void
|
||||||
|
ptranspose(PacketBlock<Packet8bf,4>& kernel) {
|
||||||
|
Packet8us t0, t1, t2, t3;
|
||||||
|
|
||||||
|
t0 = vec_mergeh(kernel.packet[0].m_val, kernel.packet[2].m_val);
|
||||||
|
t1 = vec_mergel(kernel.packet[0].m_val, kernel.packet[2].m_val);
|
||||||
|
t2 = vec_mergeh(kernel.packet[1].m_val, kernel.packet[3].m_val);
|
||||||
|
t3 = vec_mergel(kernel.packet[1].m_val, kernel.packet[3].m_val);
|
||||||
|
kernel.packet[0] = vec_mergeh(t0, t2);
|
||||||
|
kernel.packet[1] = vec_mergel(t0, t2);
|
||||||
|
kernel.packet[2] = vec_mergeh(t1, t3);
|
||||||
|
kernel.packet[3] = vec_mergel(t1, t3);
|
||||||
|
}
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC inline void
|
EIGEN_DEVICE_FUNC inline void
|
||||||
ptranspose(PacketBlock<Packet16c,4>& kernel) {
|
ptranspose(PacketBlock<Packet16c,4>& kernel) {
|
||||||
Packet16c t0, t1, t2, t3;
|
Packet16c t0, t1, t2, t3;
|
||||||
@ -1480,6 +1801,37 @@ ptranspose(PacketBlock<Packet8us,8>& kernel) {
|
|||||||
kernel.packet[7] = vec_mergel(sum[3], sum[7]);
|
kernel.packet[7] = vec_mergel(sum[3], sum[7]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC inline void
|
||||||
|
ptranspose(PacketBlock<Packet8bf,8>& kernel) {
|
||||||
|
Packet8bf v[8], sum[8];
|
||||||
|
|
||||||
|
v[0] = vec_mergeh(kernel.packet[0].m_val, kernel.packet[4].m_val);
|
||||||
|
v[1] = vec_mergel(kernel.packet[0].m_val, kernel.packet[4].m_val);
|
||||||
|
v[2] = vec_mergeh(kernel.packet[1].m_val, kernel.packet[5].m_val);
|
||||||
|
v[3] = vec_mergel(kernel.packet[1].m_val, kernel.packet[5].m_val);
|
||||||
|
v[4] = vec_mergeh(kernel.packet[2].m_val, kernel.packet[6].m_val);
|
||||||
|
v[5] = vec_mergel(kernel.packet[2].m_val, kernel.packet[6].m_val);
|
||||||
|
v[6] = vec_mergeh(kernel.packet[3].m_val, kernel.packet[7].m_val);
|
||||||
|
v[7] = vec_mergel(kernel.packet[3].m_val, kernel.packet[7].m_val);
|
||||||
|
sum[0] = vec_mergeh(v[0].m_val, v[4].m_val);
|
||||||
|
sum[1] = vec_mergel(v[0].m_val, v[4].m_val);
|
||||||
|
sum[2] = vec_mergeh(v[1].m_val, v[5].m_val);
|
||||||
|
sum[3] = vec_mergel(v[1].m_val, v[5].m_val);
|
||||||
|
sum[4] = vec_mergeh(v[2].m_val, v[6].m_val);
|
||||||
|
sum[5] = vec_mergel(v[2].m_val, v[6].m_val);
|
||||||
|
sum[6] = vec_mergeh(v[3].m_val, v[7].m_val);
|
||||||
|
sum[7] = vec_mergel(v[3].m_val, v[7].m_val);
|
||||||
|
|
||||||
|
kernel.packet[0] = vec_mergeh(sum[0].m_val, sum[4].m_val);
|
||||||
|
kernel.packet[1] = vec_mergel(sum[0].m_val, sum[4].m_val);
|
||||||
|
kernel.packet[2] = vec_mergeh(sum[1].m_val, sum[5].m_val);
|
||||||
|
kernel.packet[3] = vec_mergel(sum[1].m_val, sum[5].m_val);
|
||||||
|
kernel.packet[4] = vec_mergeh(sum[2].m_val, sum[6].m_val);
|
||||||
|
kernel.packet[5] = vec_mergel(sum[2].m_val, sum[6].m_val);
|
||||||
|
kernel.packet[6] = vec_mergeh(sum[3].m_val, sum[7].m_val);
|
||||||
|
kernel.packet[7] = vec_mergel(sum[3].m_val, sum[7].m_val);
|
||||||
|
}
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC inline void
|
EIGEN_DEVICE_FUNC inline void
|
||||||
ptranspose(PacketBlock<Packet16c,16>& kernel) {
|
ptranspose(PacketBlock<Packet16c,16>& kernel) {
|
||||||
Packet16c step1[16], step2[16], step3[16];
|
Packet16c step1[16], step2[16], step3[16];
|
||||||
@ -1656,6 +2008,10 @@ template<> EIGEN_STRONG_INLINE Packet8us pblend(const Selector<8>& ifPacket, con
|
|||||||
return vec_sel(elsePacket, thenPacket, mask);
|
return vec_sel(elsePacket, thenPacket, mask);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template<> EIGEN_STRONG_INLINE Packet8bf pblend(const Selector<8>& ifPacket, const Packet8bf& thenPacket, const Packet8bf& elsePacket) {
|
||||||
|
return pblend<Packet8us>(ifPacket, thenPacket, elsePacket);
|
||||||
|
}
|
||||||
|
|
||||||
template<> EIGEN_STRONG_INLINE Packet16c pblend(const Selector<16>& ifPacket, const Packet16c& thenPacket, const Packet16c& elsePacket) {
|
template<> EIGEN_STRONG_INLINE Packet16c pblend(const Selector<16>& ifPacket, const Packet16c& thenPacket, const Packet16c& elsePacket) {
|
||||||
Packet16uc select = { ifPacket.select[0], ifPacket.select[1], ifPacket.select[2], ifPacket.select[3],
|
Packet16uc select = { ifPacket.select[0], ifPacket.select[1], ifPacket.select[2], ifPacket.select[3],
|
||||||
ifPacket.select[4], ifPacket.select[5], ifPacket.select[6], ifPacket.select[7],
|
ifPacket.select[4], ifPacket.select[5], ifPacket.select[6], ifPacket.select[7],
|
||||||
@ -1694,15 +2050,78 @@ struct type_casting_traits<int, float> {
|
|||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct type_casting_traits<bfloat16, unsigned short int> {
|
||||||
|
enum {
|
||||||
|
VectorizedCast = 1,
|
||||||
|
SrcCoeffRatio = 1,
|
||||||
|
TgtCoeffRatio = 1
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct type_casting_traits<unsigned short int, bfloat16> {
|
||||||
|
enum {
|
||||||
|
VectorizedCast = 1,
|
||||||
|
SrcCoeffRatio = 1,
|
||||||
|
TgtCoeffRatio = 1
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
template<> EIGEN_STRONG_INLINE Packet4i pcast<Packet4f, Packet4i>(const Packet4f& a) {
|
template<> EIGEN_STRONG_INLINE Packet4i pcast<Packet4f, Packet4i>(const Packet4f& a) {
|
||||||
return vec_cts(a,0);
|
return vec_cts(a,0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template<> EIGEN_STRONG_INLINE Packet4ui pcast<Packet4f, Packet4ui>(const Packet4f& a) {
|
||||||
|
return vec_ctu(a,0);
|
||||||
|
}
|
||||||
|
|
||||||
template<> EIGEN_STRONG_INLINE Packet4f pcast<Packet4i, Packet4f>(const Packet4i& a) {
|
template<> EIGEN_STRONG_INLINE Packet4f pcast<Packet4i, Packet4f>(const Packet4i& a) {
|
||||||
return vec_ctf(a,0);
|
return vec_ctf(a,0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template<> EIGEN_STRONG_INLINE Packet4f pcast<Packet4ui, Packet4f>(const Packet4ui& a) {
|
||||||
|
return vec_ctf(a,0);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<> EIGEN_STRONG_INLINE Packet8us pcast<Packet8bf, Packet8us>(const Packet8bf& a) {
|
||||||
|
Packet4f float_even = Bf16ToF32Even(a);
|
||||||
|
Packet4f float_odd = Bf16ToF32Odd(a);
|
||||||
|
Packet4ui int_even = pcast<Packet4f, Packet4ui>(float_even);
|
||||||
|
Packet4ui int_odd = pcast<Packet4f, Packet4ui>(float_odd);
|
||||||
|
const _EIGEN_DECLARE_CONST_FAST_Packet4ui(low_mask, 0x0000FFFF);
|
||||||
|
Packet4ui low_even = pand<Packet4ui>(int_even, p4ui_low_mask);
|
||||||
|
Packet4ui low_odd = pand<Packet4ui>(int_odd, p4ui_low_mask);
|
||||||
|
|
||||||
|
//Check values that are bigger than USHRT_MAX (0xFFFF)
|
||||||
|
Packet4bi overflow_selector;
|
||||||
|
if(vec_any_gt(int_even, p4ui_low_mask)){
|
||||||
|
overflow_selector = vec_cmpgt(int_even, p4ui_low_mask);
|
||||||
|
low_even = vec_sel(low_even, p4ui_low_mask, overflow_selector);
|
||||||
|
}
|
||||||
|
if(vec_any_gt(int_odd, p4ui_low_mask)){
|
||||||
|
overflow_selector = vec_cmpgt(int_odd, p4ui_low_mask);
|
||||||
|
low_odd = vec_sel(low_even, p4ui_low_mask, overflow_selector);
|
||||||
|
}
|
||||||
|
|
||||||
|
low_odd = plogical_shift_left<16>(low_odd);
|
||||||
|
|
||||||
|
Packet4ui int_final = por<Packet4ui>(low_even, low_odd);
|
||||||
|
return reinterpret_cast<Packet8us>(int_final);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<> EIGEN_STRONG_INLINE Packet8bf pcast<Packet8us, Packet8bf>(const Packet8us& a) {
|
||||||
|
//short -> int -> float -> bfloat16
|
||||||
|
const _EIGEN_DECLARE_CONST_FAST_Packet4ui(low_mask, 0x0000FFFF);
|
||||||
|
Packet4ui int_cast = reinterpret_cast<Packet4ui>(a);
|
||||||
|
Packet4ui int_even = pand<Packet4ui>(int_cast, p4ui_low_mask);
|
||||||
|
Packet4ui int_odd = plogical_shift_right<16>(int_cast);
|
||||||
|
Packet4f float_even = pcast<Packet4ui, Packet4f>(int_even);
|
||||||
|
Packet4f float_odd = pcast<Packet4ui, Packet4f>(int_odd);
|
||||||
|
return F32ToBf16(float_even, float_odd);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
template<> EIGEN_STRONG_INLINE Packet4i preinterpret<Packet4i,Packet4f>(const Packet4f& a) {
|
template<> EIGEN_STRONG_INLINE Packet4i preinterpret<Packet4i,Packet4f>(const Packet4f& a) {
|
||||||
return reinterpret_cast<Packet4i>(a);
|
return reinterpret_cast<Packet4i>(a);
|
||||||
}
|
}
|
||||||
@ -2024,6 +2443,8 @@ template<> EIGEN_STRONG_INLINE Packet2d pblend(const Selector<2>& ifPacket, cons
|
|||||||
Packet2bl mask = reinterpret_cast<Packet2bl>( vec_cmpeq(reinterpret_cast<Packet2d>(select), reinterpret_cast<Packet2d>(p2l_ONE)) );
|
Packet2bl mask = reinterpret_cast<Packet2bl>( vec_cmpeq(reinterpret_cast<Packet2d>(select), reinterpret_cast<Packet2d>(p2l_ONE)) );
|
||||||
return vec_sel(elsePacket, thenPacket, mask);
|
return vec_sel(elsePacket, thenPacket, mask);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
#endif // __VSX__
|
#endif // __VSX__
|
||||||
} // end namespace internal
|
} // end namespace internal
|
||||||
|
|
||||||
|
@ -247,6 +247,20 @@ void packetmath_boolean_mask_ops() {
|
|||||||
data1[i + PacketSize] = internal::random<bool>() ? data1[i] : Scalar(0);
|
data1[i + PacketSize] = internal::random<bool>() ? data1[i] : Scalar(0);
|
||||||
}
|
}
|
||||||
CHECK_CWISE2_IF(true, internal::pcmp_eq, internal::pcmp_eq);
|
CHECK_CWISE2_IF(true, internal::pcmp_eq, internal::pcmp_eq);
|
||||||
|
|
||||||
|
//Test (-0) == (0) for signed operations
|
||||||
|
for (int i = 0; i < PacketSize; ++i) {
|
||||||
|
data1[i] = Scalar(-0.0);
|
||||||
|
data1[i + PacketSize] = internal::random<bool>() ? data1[i] : Scalar(0);
|
||||||
|
}
|
||||||
|
CHECK_CWISE2_IF(true, internal::pcmp_eq, internal::pcmp_eq);
|
||||||
|
|
||||||
|
//Test NaN
|
||||||
|
for (int i = 0; i < PacketSize; ++i) {
|
||||||
|
data1[i] = std::numeric_limits<Scalar>::quiet_NaN();
|
||||||
|
data1[i + PacketSize] = internal::random<bool>() ? data1[i] : Scalar(0);
|
||||||
|
}
|
||||||
|
CHECK_CWISE2_IF(true, internal::pcmp_eq, internal::pcmp_eq);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Packet16b representing bool does not support ptrue, pandnot or pcmp_eq, since the scalar path
|
// Packet16b representing bool does not support ptrue, pandnot or pcmp_eq, since the scalar path
|
||||||
@ -254,6 +268,22 @@ void packetmath_boolean_mask_ops() {
|
|||||||
template<>
|
template<>
|
||||||
void packetmath_boolean_mask_ops<bool, internal::packet_traits<bool>::type>() {}
|
void packetmath_boolean_mask_ops<bool, internal::packet_traits<bool>::type>() {}
|
||||||
|
|
||||||
|
template <typename Scalar, typename Packet>
|
||||||
|
void packetmath_minus_zero_add() {
|
||||||
|
const int PacketSize = internal::unpacket_traits<Packet>::size;
|
||||||
|
const int size = 2 * PacketSize;
|
||||||
|
EIGEN_ALIGN_MAX Scalar data1[size];
|
||||||
|
EIGEN_ALIGN_MAX Scalar data2[size];
|
||||||
|
EIGEN_ALIGN_MAX Scalar ref[size];
|
||||||
|
|
||||||
|
for (int i = 0; i < PacketSize; ++i) {
|
||||||
|
data1[i] = Scalar(-0.0);
|
||||||
|
data1[i + PacketSize] = Scalar(-0.0);
|
||||||
|
}
|
||||||
|
CHECK_CWISE2_IF(internal::packet_traits<Scalar>::HasAdd, REF_ADD, internal::padd);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
template <typename Scalar, typename Packet>
|
template <typename Scalar, typename Packet>
|
||||||
void packetmath() {
|
void packetmath() {
|
||||||
typedef internal::packet_traits<Scalar> PacketTraits;
|
typedef internal::packet_traits<Scalar> PacketTraits;
|
||||||
@ -454,6 +484,7 @@ void packetmath() {
|
|||||||
|
|
||||||
packetmath_boolean_mask_ops<Scalar, Packet>();
|
packetmath_boolean_mask_ops<Scalar, Packet>();
|
||||||
packetmath_pcast_ops_runner<Scalar, Packet>::run();
|
packetmath_pcast_ops_runner<Scalar, Packet>::run();
|
||||||
|
packetmath_minus_zero_add<Scalar, Packet>();
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename Scalar, typename Packet>
|
template <typename Scalar, typename Packet>
|
||||||
|
Loading…
Reference in New Issue
Block a user