diff --git a/Eigen/src/Core/arch/AVX512/PacketMath.h b/Eigen/src/Core/arch/AVX512/PacketMath.h index 2b6693eed..76f3366d7 100644 --- a/Eigen/src/Core/arch/AVX512/PacketMath.h +++ b/Eigen/src/Core/arch/AVX512/PacketMath.h @@ -32,6 +32,7 @@ typedef __m512 Packet16f; typedef __m512i Packet16i; typedef __m512d Packet8d; typedef eigen_packet_wrapper<__m256i, 1> Packet16h; +typedef eigen_packet_wrapper<__m256i, 2> Packet16bf; template <> struct is_arithmetic<__m512> { @@ -1620,13 +1621,6 @@ ptranspose(PacketBlock& kernel) { kernel.packet[3] = pload(out[3]); } -typedef union { -#ifdef EIGEN_VECTORIZE_AVX512BF16 - __m256bh bh; -#endif - Packet8i i; // __m256i; -} Packet16bf; - template <> struct is_arithmetic { enum { value = true }; }; template <> @@ -1673,42 +1667,36 @@ struct unpacket_traits template <> EIGEN_STRONG_INLINE Packet16bf pset1(const bfloat16& from) { - Packet16bf r; - r.i = _mm256_set1_epi16(from.value); - return r; + return _mm256_set1_epi16(from.value); } template <> EIGEN_STRONG_INLINE bfloat16 pfirst(const Packet16bf& from) { bfloat16 t; - t.value = static_cast(_mm256_extract_epi16(from.i, 0)); + t.value = static_cast(_mm256_extract_epi16(from, 0)); return t; } template <> EIGEN_STRONG_INLINE Packet16bf pload(const bfloat16* from) { - Packet16bf r; - r.i = _mm256_load_si256(reinterpret_cast(from)); - return r; + return _mm256_load_si256(reinterpret_cast(from)); } template <> EIGEN_STRONG_INLINE Packet16bf ploadu(const bfloat16* from) { - Packet16bf r; - r.i = _mm256_loadu_si256(reinterpret_cast(from)); - return r; + return _mm256_loadu_si256(reinterpret_cast(from)); } template <> EIGEN_STRONG_INLINE void pstore(bfloat16* to, const Packet16bf& from) { - _mm256_store_si256(reinterpret_cast<__m256i*>(to), from.i); + _mm256_store_si256(reinterpret_cast<__m256i*>(to), from); } template <> EIGEN_STRONG_INLINE void pstoreu(bfloat16* to, const Packet16bf& from) { - _mm256_storeu_si256(reinterpret_cast<__m256i*>(to), from.i); + 
_mm256_storeu_si256(reinterpret_cast<__m256i*>(to), from); } template<> EIGEN_STRONG_INLINE Packet16bf @@ -1722,8 +1710,7 @@ ploaddup(const bfloat16* from) { unsigned short f = from[5].value; unsigned short g = from[6].value; unsigned short h = from[7].value; - r.i = _mm256_set_epi16(h, h, g, g, f, f, e, e, d, d, c, c, b, b, a, a); - return r; + return _mm256_set_epi16(h, h, g, g, f, f, e, e, d, d, c, c, b, b, a, a); } template<> EIGEN_STRONG_INLINE Packet16bf @@ -1733,12 +1720,11 @@ ploadquad(const bfloat16* from) { unsigned short b = from[1].value; unsigned short c = from[2].value; unsigned short d = from[3].value; - r.i = _mm256_set_epi16(d, d, d, d, c, c, c, c, b, b, b, b, a, a, a, a); - return r; + return _mm256_set_epi16(d, d, d, d, c, c, c, c, b, b, b, b, a, a, a, a); } EIGEN_STRONG_INLINE Packet16f Bf16ToF32(const Packet16bf& a) { - return _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(a.i), 16)); + return _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(a), 16)); } // Convert float to bfloat16 according to round-to-nearest-even/denormals algorithm. @@ -1754,8 +1740,11 @@ EIGEN_STRONG_INLINE Packet16bf F32ToBf16(const Packet16f& a) { #endif // EIGEN_VECTORIZE_AVX512DQ _MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_OFF); -#if defined(EIGEN_VECTORIZE_AVX512BF16) - r.bh = _mm512_cvtneps_pbh(flush); +#if defined(EIGEN_VECTORIZE_AVX512BF16) && EIGEN_GNUC_AT_LEAST(10, 1) + // Since GCC 10.1 supports avx512bf16 and C style explicit cast + // (C++ static_cast is not supported yet), do conversion via intrinsic + // and register path for performance. 
+ r = (__m256i)(_mm512_cvtneps_pbh(flush)); #else __m512i t; __m512i input = _mm512_castps_si512(flush); @@ -1775,7 +1764,7 @@ EIGEN_STRONG_INLINE Packet16bf F32ToBf16(const Packet16f& a) { t = _mm512_mask_blend_epi32(mask, nan, t); // output.value = static_cast(input); - r.i = _mm512_cvtepi32_epi16(t); + r = _mm512_cvtepi32_epi16(t); #endif // EIGEN_VECTORIZE_AVX512BF16 return r; @@ -1783,38 +1772,28 @@ EIGEN_STRONG_INLINE Packet16bf F32ToBf16(const Packet16f& a) { template <> EIGEN_STRONG_INLINE Packet16bf ptrue(const Packet16bf& a) { - Packet16bf r; - r.i = ptrue(a.i); - return r; + return ptrue(a); } template <> EIGEN_STRONG_INLINE Packet16bf por(const Packet16bf& a, const Packet16bf& b) { - Packet16bf r; - r.i = por(a.i, b.i); - return r; + return por(a, b); } template <> EIGEN_STRONG_INLINE Packet16bf pxor(const Packet16bf& a, const Packet16bf& b) { - Packet16bf r; - r.i = pxor(a.i, b.i); - return r; + return pxor(a, b); } template <> EIGEN_STRONG_INLINE Packet16bf pand(const Packet16bf& a, const Packet16bf& b) { - Packet16bf r; - r.i = pand(a.i, b.i); - return r; + return pand(a, b); } template <> EIGEN_STRONG_INLINE Packet16bf pandnot(const Packet16bf& a, const Packet16bf& b) { - Packet16bf r; - r.i = pandnot(a.i, b.i); - return r; + return pandnot(a, b); } template <> @@ -1823,50 +1802,39 @@ EIGEN_STRONG_INLINE Packet16bf pselect(const Packet16bf& mask, const Packet16bf& b) { // Input mask is expected to be all 0/1, handle it with 8-bit // intrinsic for performance. 
- Packet16bf r; - r.i = _mm256_blendv_epi8(b.i, a.i, mask.i); - return r; + return _mm256_blendv_epi8(b, a, mask); } template <> EIGEN_STRONG_INLINE Packet16bf pcmp_eq(const Packet16bf& a, const Packet16bf& b) { - Packet16bf result; - result.i = Pack32To16(pcmp_eq(Bf16ToF32(a), Bf16ToF32(b))); - return result; + return Pack32To16(pcmp_eq(Bf16ToF32(a), Bf16ToF32(b))); } template <> EIGEN_STRONG_INLINE Packet16bf pcmp_le(const Packet16bf& a, const Packet16bf& b) { - Packet16bf result; - result.i = Pack32To16(pcmp_le(Bf16ToF32(a), Bf16ToF32(b))); - return result; + return Pack32To16(pcmp_le(Bf16ToF32(a), Bf16ToF32(b))); } template <> EIGEN_STRONG_INLINE Packet16bf pcmp_lt(const Packet16bf& a, const Packet16bf& b) { - Packet16bf result; - result.i = Pack32To16(pcmp_lt(Bf16ToF32(a), Bf16ToF32(b))); - return result; + return Pack32To16(pcmp_lt(Bf16ToF32(a), Bf16ToF32(b))); } template <> EIGEN_STRONG_INLINE Packet16bf pcmp_lt_or_nan(const Packet16bf& a, const Packet16bf& b) { - Packet16bf result; - result.i = Pack32To16(pcmp_lt_or_nan(Bf16ToF32(a), Bf16ToF32(b))); - return result; + return Pack32To16(pcmp_lt_or_nan(Bf16ToF32(a), Bf16ToF32(b))); } template <> EIGEN_STRONG_INLINE Packet16bf pnegate(const Packet16bf& a) { Packet16bf sign_mask; - sign_mask.i = _mm256_set1_epi16(static_cast(0x8000)); + sign_mask = _mm256_set1_epi16(static_cast(0x8000)); Packet16bf result; - result.i = _mm256_xor_si256(a.i, sign_mask.i); - return result; + return _mm256_xor_si256(a, sign_mask); } template <> @@ -1917,8 +1885,8 @@ EIGEN_STRONG_INLINE Packet16bf pmax(const Packet16bf& a, template <> EIGEN_STRONG_INLINE Packet8bf predux_half_dowto4(const Packet16bf& a) { - Packet8bf lane0 = _mm256_extractf128_si256(a.i, 0); - Packet8bf lane1 = _mm256_extractf128_si256(a.i, 1); + Packet8bf lane0 = _mm256_extractf128_si256(a, 0); + Packet8bf lane1 = _mm256_extractf128_si256(a, 1); return padd(lane0, lane1); } @@ -1949,22 +1917,19 @@ EIGEN_STRONG_INLINE Packet16bf preverse(const Packet16bf& a) { 
Packet16bf res; // Swap hi and lo first because shuffle is in 128-bit lanes. - res.i = _mm256_permute2x128_si256(a.i, a.i, 1); + res = _mm256_permute2x128_si256(a, a, 1); // Shuffle 8-bit values in src within 2*128-bit lanes. - res.i = _mm256_shuffle_epi8(res.i, m); - return res; + return _mm256_shuffle_epi8(res, m); } template <> EIGEN_STRONG_INLINE Packet16bf pgather(const bfloat16* from, Index stride) { - Packet16bf result; - result.i = _mm256_set_epi16( + return _mm256_set_epi16( from[15*stride].value, from[14*stride].value, from[13*stride].value, from[12*stride].value, from[11*stride].value, from[10*stride].value, from[9*stride].value, from[8*stride].value, from[7*stride].value, from[6*stride].value, from[5*stride].value, from[4*stride].value, from[3*stride].value, from[2*stride].value, from[1*stride].value, from[0*stride].value); - return result; } template <> @@ -1992,22 +1957,22 @@ EIGEN_STRONG_INLINE void pscatter(bfloat16* to, } EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) { - __m256i a = kernel.packet[0].i; - __m256i b = kernel.packet[1].i; - __m256i c = kernel.packet[2].i; - __m256i d = kernel.packet[3].i; - __m256i e = kernel.packet[4].i; - __m256i f = kernel.packet[5].i; - __m256i g = kernel.packet[6].i; - __m256i h = kernel.packet[7].i; - __m256i i = kernel.packet[8].i; - __m256i j = kernel.packet[9].i; - __m256i k = kernel.packet[10].i; - __m256i l = kernel.packet[11].i; - __m256i m = kernel.packet[12].i; - __m256i n = kernel.packet[13].i; - __m256i o = kernel.packet[14].i; - __m256i p = kernel.packet[15].i; + __m256i a = kernel.packet[0]; + __m256i b = kernel.packet[1]; + __m256i c = kernel.packet[2]; + __m256i d = kernel.packet[3]; + __m256i e = kernel.packet[4]; + __m256i f = kernel.packet[5]; + __m256i g = kernel.packet[6]; + __m256i h = kernel.packet[7]; + __m256i i = kernel.packet[8]; + __m256i j = kernel.packet[9]; + __m256i k = kernel.packet[10]; + __m256i l = kernel.packet[11]; + __m256i m = kernel.packet[12]; + __m256i n = 
kernel.packet[13]; + __m256i o = kernel.packet[14]; + __m256i p = kernel.packet[15]; __m256i ab_07 = _mm256_unpacklo_epi16(a, b); __m256i cd_07 = _mm256_unpacklo_epi16(c, d); @@ -2063,29 +2028,29 @@ EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) { __m256i ijklmnop_ef = _mm256_unpackhi_epi64(ijkl_cf, mnop_cf); // NOTE: no unpacklo/hi instr in this case, so using permute instr. - kernel.packet[0].i = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01, 0x20); - kernel.packet[1].i = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23, 0x20); - kernel.packet[2].i = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45, 0x20); - kernel.packet[3].i = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67, 0x20); - kernel.packet[4].i = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89, 0x20); - kernel.packet[5].i = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab, 0x20); - kernel.packet[6].i = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd, 0x20); - kernel.packet[7].i = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef, 0x20); - kernel.packet[8].i = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01, 0x31); - kernel.packet[9].i = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23, 0x31); - kernel.packet[10].i = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45, 0x31); - kernel.packet[11].i = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67, 0x31); - kernel.packet[12].i = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89, 0x31); - kernel.packet[13].i = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab, 0x31); - kernel.packet[14].i = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd, 0x31); - kernel.packet[15].i = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef, 0x31); + kernel.packet[0] = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01, 0x20); + kernel.packet[1] = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23, 0x20); + kernel.packet[2] = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45, 0x20); + kernel.packet[3] = 
_mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67, 0x20); + kernel.packet[4] = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89, 0x20); + kernel.packet[5] = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab, 0x20); + kernel.packet[6] = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd, 0x20); + kernel.packet[7] = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef, 0x20); + kernel.packet[8] = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01, 0x31); + kernel.packet[9] = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23, 0x31); + kernel.packet[10] = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45, 0x31); + kernel.packet[11] = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67, 0x31); + kernel.packet[12] = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89, 0x31); + kernel.packet[13] = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab, 0x31); + kernel.packet[14] = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd, 0x31); + kernel.packet[15] = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef, 0x31); } EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) { - __m256i a = kernel.packet[0].i; - __m256i b = kernel.packet[1].i; - __m256i c = kernel.packet[2].i; - __m256i d = kernel.packet[3].i; + __m256i a = kernel.packet[0]; + __m256i b = kernel.packet[1]; + __m256i c = kernel.packet[2]; + __m256i d = kernel.packet[3]; __m256i ab_07 = _mm256_unpacklo_epi16(a, b); __m256i cd_07 = _mm256_unpacklo_epi16(c, d); @@ -2098,10 +2063,10 @@ EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) { __m256i abcd_cf = _mm256_unpackhi_epi32(ab_8f, cd_8f); // NOTE: no unpacklo/hi instr in this case, so using permute instr. 
- kernel.packet[0].i = _mm256_permute2x128_si256(abcd_03, abcd_47, 0x20); - kernel.packet[1].i = _mm256_permute2x128_si256(abcd_8b, abcd_cf, 0x20); - kernel.packet[2].i = _mm256_permute2x128_si256(abcd_03, abcd_47, 0x31); - kernel.packet[3].i = _mm256_permute2x128_si256(abcd_8b, abcd_cf, 0x31); + kernel.packet[0] = _mm256_permute2x128_si256(abcd_03, abcd_47, 0x20); + kernel.packet[1] = _mm256_permute2x128_si256(abcd_8b, abcd_cf, 0x20); + kernel.packet[2] = _mm256_permute2x128_si256(abcd_03, abcd_47, 0x31); + kernel.packet[3] = _mm256_permute2x128_si256(abcd_8b, abcd_cf, 0x31); } } // end namespace internal