BF16 for scalar_cmp_with_cast_op

This commit is contained in:
Sheng Yang 2020-07-01 18:33:42 +00:00 committed by Rasmus Munk Larsen
parent 8731452b97
commit 116c5235ac
5 changed files with 11 additions and 1 deletions

View File

@ -58,6 +58,9 @@ struct default_packet_traits
HasConj = 1,
HasSetLinear = 1,
HasBlend = 0,
// This flag is used to indicate whether packet comparison is supported.
// pcmp_eq, pcmp_lt and pcmp_le should be defined for it to be true.
HasCmp = 0,
HasDiv = 0,
HasSqrt = 0,

View File

@ -63,6 +63,7 @@ template<> struct packet_traits<float> : default_packet_traits
size = 8,
HasHalfPacket = 1,
HasCmp = 1,
HasDiv = 1,
HasSin = EIGEN_FAST_MATH,
HasCos = EIGEN_FAST_MATH,
@ -93,6 +94,7 @@ template<> struct packet_traits<double> : default_packet_traits
size=4,
HasHalfPacket = 1,
HasCmp = 1,
HasDiv = 1,
HasExp = 1,
HasSqrt = 1,

View File

@ -103,6 +103,7 @@ template<> struct packet_traits<float> : default_packet_traits
HasTanh = EIGEN_FAST_MATH,
HasErf = EIGEN_FAST_MATH,
#endif
HasCmp = 1,
HasDiv = 1
};
};
@ -119,6 +120,7 @@ template<> struct packet_traits<double> : default_packet_traits
HasSqrt = EIGEN_FAST_MATH,
HasRsqrt = EIGEN_FAST_MATH,
#endif
HasCmp = 1,
HasDiv = 1
};
};
@ -1656,6 +1658,7 @@ struct packet_traits<bfloat16> : default_packet_traits {
HasTanh = EIGEN_FAST_MATH,
HasErf = EIGEN_FAST_MATH,
#endif
HasCmp = 1,
HasDiv = 1
};
};

View File

@ -94,6 +94,7 @@ struct packet_traits<float> : default_packet_traits {
size = 4,
HasHalfPacket = 0,
HasCmp = 1,
HasDiv = 1,
HasSin = EIGEN_FAST_MATH,
HasCos = EIGEN_FAST_MATH,
@ -128,6 +129,7 @@ struct packet_traits<double> : default_packet_traits {
size=2,
HasHalfPacket = 0,
HasCmp = 1,
HasDiv = 1,
HasExp = 1,
HasSqrt = 1,

View File

@ -260,7 +260,7 @@ template<typename LhsScalar, typename RhsScalar, ComparisonName cmp>
struct functor_traits<scalar_cmp_with_cast_op<LhsScalar,RhsScalar, cmp> > {
enum {
Cost = (NumTraits<LhsScalar>::AddCost+NumTraits<RhsScalar>::AddCost)/2,
PacketAccess = internal::is_same<LhsScalar, RhsScalar>::value && internal::is_same<LhsScalar, float>::value
PacketAccess = internal::is_same<LhsScalar, RhsScalar>::value && packet_traits<LhsScalar>::HasCmp
};
};