Various fixes for packet ops.

1. Fix buggy pcmp_eq and unit test for half types.
2. Add unit test for pselect and add specializations for SSE 4.1, AVX512, and half types.
3. Get rid of FIXME: Implement faster pnegate for half by XOR'ing with a sign bit mask.
This commit is contained in:
Rasmus Munk Larsen 2019-06-20 11:47:49 -07:00
parent e0be7f30e1
commit 988f24b730
4 changed files with 89 additions and 13 deletions

View File

@ -252,6 +252,24 @@ EIGEN_STRONG_INLINE Packet8d pmadd(const Packet8d& a, const Packet8d& b,
}
#endif
template <>
EIGEN_DEVICE_FUNC inline Packet16f pselect(const Packet16f& mask,
const Packet16f& a,
const Packet16f& b) {
__mmask16 mask16 = _mm512_cmp_epi32_mask(
_mm512_castps_si512(mask), _mm512_setzero_epi32(), _MM_CMPINT_EQ);
return _mm512_mask_blend_ps(mask16, a, b);
}
template <>
EIGEN_DEVICE_FUNC inline Packet8d pselect(const Packet8d& mask,
const Packet8d& a,
const Packet8d& b) {
__mmask8 mask8 = _mm512_cmp_epi64_mask(_mm512_castpd_si512(mask),
_mm512_setzero_epi32(), _MM_CMPINT_EQ);
return _mm512_mask_blend_pd(mask8, a, b);
}
template <>
EIGEN_STRONG_INLINE Packet16f pmin<Packet16f>(const Packet16f& a,
const Packet16f& b) {

View File

@ -176,6 +176,15 @@ template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 plset<half2>(const Eigen:
#endif
}
template <>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pselect<half2>(const half2& mask,
const half2& a,
const half2& b) {
half result_low = __low2half(mask) == half(0) ? __low2half(b) : __low2half(a);
half result_high = __high2half(mask) == half(0) ? __high2half(b) : __high2half(a);
return __halves2half2(result_low, result_high);
}
template <>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pcmp_eq<half2>(const half2& a,
const half2& b) {
@ -726,18 +735,29 @@ template<> EIGEN_STRONG_INLINE Packet16h pandnot(const Packet16h& a,const Packet
Packet16h r; r.x = pandnot(Packet8i(a.x),Packet8i(b.x)); return r;
}
template<> EIGEN_STRONG_INLINE Packet16h pselect(const Packet16h& mask, const Packet16h& a, const Packet16h& b) {
Packet16h r; r.x = _mm256_blendv_epi8(b.x, a.x, mask.x); return r;
}
template<> EIGEN_STRONG_INLINE Packet16h pcmp_eq(const Packet16h& a,const Packet16h& b) {
Packet16f af = half2float(a);
Packet16f bf = half2float(b);
Packet16f rf = pcmp_eq(af, bf);
return float2half(rf);
// Pack the 32-bit flags into 16-bits flags.
__m256i lo = _mm256_castps_si256(extract256<0>(rf));
__m256i hi = _mm256_castps_si256(extract256<1>(rf));
__m128i result_lo = _mm_packs_epi32(_mm256_extractf128_si256(lo, 0),
_mm256_extractf128_si256(lo, 1));
__m128i result_hi = _mm_packs_epi32(_mm256_extractf128_si256(hi, 0),
_mm256_extractf128_si256(hi, 1));
Packet16h result; result.x = _mm256_insertf128_si256(_mm256_castsi128_si256(result_lo), result_hi, 1);
return result;
}
template<> EIGEN_STRONG_INLINE Packet16h pnegate(const Packet16h& a) {
// FIXME we could do that with bit manipulation
Packet16f af = half2float(a);
Packet16f rf = pnegate(af);
return float2half(rf);
Packet16h sign_mask; sign_mask.x = _mm256_set1_epi16(static_cast<unsigned short>(0x8000));
Packet16h result; result.x = _mm256_xor_si256(a.x, sign_mask.x);
return result;
}
template<> EIGEN_STRONG_INLINE Packet16h padd<Packet16h>(const Packet16h& a, const Packet16h& b) {
@ -1182,20 +1202,26 @@ template<> EIGEN_STRONG_INLINE Packet8h pandnot(const Packet8h& a,const Packet8h
Packet8h r; r.x = _mm_andnot_si128(b.x,a.x); return r;
}
template<> EIGEN_STRONG_INLINE Packet8h pselect(const Packet8h& mask, const Packet8h& a, const Packet8h& b) {
Packet8h r; r.x = _mm_blendv_epi8(b.x, a.x, mask.x); return r;
}
template<> EIGEN_STRONG_INLINE Packet8h pcmp_eq(const Packet8h& a,const Packet8h& b) {
Packet8f af = half2float(a);
Packet8f bf = half2float(b);
Packet8f rf = pcmp_eq(af, bf);
return float2half(rf);
// Pack the 32-bit flags into 16-bits flags.
Packet8h result; result.x = _mm_packs_epi32(_mm256_extractf128_si256(_mm256_castps_si256(rf), 0),
_mm256_extractf128_si256(_mm256_castps_si256(rf), 1));
return result;
}
template<> EIGEN_STRONG_INLINE Packet8h pconj(const Packet8h& a) { return a; }
template<> EIGEN_STRONG_INLINE Packet8h pnegate(const Packet8h& a) {
// FIXME we could do that with bit manipulation
Packet8f af = half2float(a);
Packet8f rf = pnegate(af);
return float2half(rf);
Packet8h sign_mask; sign_mask.x = _mm_set1_epi16(static_cast<unsigned short>(0x8000));
Packet8h result; result.x = _mm_xor_si128(a.x, sign_mask.x);
return result;
}
template<> EIGEN_STRONG_INLINE Packet8h padd<Packet8h>(const Packet8h& a, const Packet8h& b) {

View File

@ -273,6 +273,12 @@ template<> EIGEN_STRONG_INLINE Packet4f pmadd(const Packet4f& a, const Packet4f&
template<> EIGEN_STRONG_INLINE Packet2d pmadd(const Packet2d& a, const Packet2d& b, const Packet2d& c) { return _mm_fmadd_pd(a,b,c); }
#endif
#ifdef EIGEN_VECTORIZE_SSE4_1
template<> EIGEN_DEVICE_FUNC inline Packet4f pselect(const Packet4f& mask, const Packet4f& a, const Packet4f& b) { return _mm_blendv_ps(b,a,mask); }
template<> EIGEN_DEVICE_FUNC inline Packet2d pselect(const Packet2d& mask, const Packet2d& a, const Packet2d& b) { return _mm_blendv_pd(b,a,mask); }
#endif
template<> EIGEN_STRONG_INLINE Packet4f pmin<Packet4f>(const Packet4f& a, const Packet4f& b) {
#if EIGEN_COMP_GNUC && EIGEN_COMP_GNUC < 63
// There appears to be a bug in GCC, by which the optimizer may

View File

@ -166,6 +166,16 @@ struct packet_helper<false,Packet>
VERIFY(areApprox(ref, data2, PacketSize) && #POP); \
}
#define CHECK_CWISE3_IF(COND, REFOP, POP) if (COND) { \
packet_helper<COND, Packet> h; \
for (int i = 0; i < PacketSize; ++i) \
ref[i] = \
REFOP(data1[i], data1[i + PacketSize], data1[i + 2 * PacketSize]); \
h.store(data2, POP(h.load(data1), h.load(data1 + PacketSize), \
h.load(data1 + 2 * PacketSize))); \
VERIFY(areApprox(ref, data2, PacketSize) && #POP); \
}
#define REF_ADD(a,b) ((a)+(b))
#define REF_SUB(a,b) ((a)-(b))
#define REF_MUL(a,b) ((a)*(b))
@ -447,19 +457,35 @@ template<typename Scalar,typename Packet> void packetmath()
data1[i] = internal::random<Scalar>();
unsigned char v = internal::random<bool>() ? 0xff : 0;
char* bytes = (char*)(data1+PacketSize+i);
for(int k=0; k<int(sizeof(Scalar)); ++k)
for(int k=0; k<int(sizeof(Scalar)); ++k) {
bytes[k] = v;
}
}
CHECK_CWISE2_IF(true, internal::por, internal::por);
CHECK_CWISE2_IF(true, internal::pxor, internal::pxor);
CHECK_CWISE2_IF(true, internal::pand, internal::pand);
CHECK_CWISE2_IF(true, internal::pandnot, internal::pandnot);
}
{
for (int i = 0; i < PacketSize; ++i) {
// "if" mask
unsigned char v = internal::random<bool>() ? 0xff : 0;
char* bytes = (char*)(data1+i);
for(int k=0; k<int(sizeof(Scalar)); ++k) {
bytes[k] = v;
}
// "then" packet
data1[i+PacketSize] = internal::random<Scalar>();
// "else" packet
data1[i+2*PacketSize] = internal::random<Scalar>();
}
CHECK_CWISE3_IF(true, internal::pselect, internal::pselect);
}
{
for (int i = 0; i < PacketSize; ++i) {
data1[i] = internal::random<Scalar>();
data2[i] = (i % 2) ? data1[i] : Scalar(0);
data1[i] = Scalar(i);
data1[i + PacketSize] = internal::random<bool>() ? data1[i] : Scalar(0);
}
CHECK_CWISE2_IF(true, internal::pcmp_eq, internal::pcmp_eq);
}