mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-01-06 14:14:46 +08:00
Fix undefined BF16 union behavior in AVX512.
This commit is contained in:
parent
b92206676c
commit
3ec4f0b641
@ -32,6 +32,7 @@ typedef __m512 Packet16f;
|
||||
typedef __m512i Packet16i;
|
||||
typedef __m512d Packet8d;
|
||||
typedef eigen_packet_wrapper<__m256i, 1> Packet16h;
|
||||
typedef eigen_packet_wrapper<__m256i, 2> Packet16bf;
|
||||
|
||||
template <>
|
||||
struct is_arithmetic<__m512> {
|
||||
@ -1620,13 +1621,6 @@ ptranspose(PacketBlock<Packet16h,4>& kernel) {
|
||||
kernel.packet[3] = pload<Packet16h>(out[3]);
|
||||
}
|
||||
|
||||
typedef union {
|
||||
#ifdef EIGEN_VECTORIZE_AVX512BF16
|
||||
__m256bh bh;
|
||||
#endif
|
||||
Packet8i i; // __m256i;
|
||||
} Packet16bf;
|
||||
|
||||
template <> struct is_arithmetic<Packet16bf> { enum { value = true }; };
|
||||
|
||||
template <>
|
||||
@ -1673,42 +1667,36 @@ struct unpacket_traits<Packet16bf>
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16bf pset1<Packet16bf>(const bfloat16& from) {
|
||||
Packet16bf r;
|
||||
r.i = _mm256_set1_epi16(from.value);
|
||||
return r;
|
||||
return _mm256_set1_epi16(from.value);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE bfloat16 pfirst<Packet16bf>(const Packet16bf& from) {
|
||||
bfloat16 t;
|
||||
t.value = static_cast<unsigned short>(_mm256_extract_epi16(from.i, 0));
|
||||
t.value = static_cast<unsigned short>(_mm256_extract_epi16(from, 0));
|
||||
return t;
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16bf pload<Packet16bf>(const bfloat16* from) {
|
||||
Packet16bf r;
|
||||
r.i = _mm256_load_si256(reinterpret_cast<const __m256i*>(from));
|
||||
return r;
|
||||
return _mm256_load_si256(reinterpret_cast<const __m256i*>(from));
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16bf ploadu<Packet16bf>(const bfloat16* from) {
|
||||
Packet16bf r;
|
||||
r.i = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(from));
|
||||
return r;
|
||||
return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(from));
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE void pstore<bfloat16>(bfloat16* to,
|
||||
const Packet16bf& from) {
|
||||
_mm256_store_si256(reinterpret_cast<__m256i*>(to), from.i);
|
||||
_mm256_store_si256(reinterpret_cast<__m256i*>(to), from);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE void pstoreu<bfloat16>(bfloat16* to,
|
||||
const Packet16bf& from) {
|
||||
_mm256_storeu_si256(reinterpret_cast<__m256i*>(to), from.i);
|
||||
_mm256_storeu_si256(reinterpret_cast<__m256i*>(to), from);
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet16bf
|
||||
@ -1722,8 +1710,7 @@ ploaddup<Packet16bf>(const bfloat16* from) {
|
||||
unsigned short f = from[5].value;
|
||||
unsigned short g = from[6].value;
|
||||
unsigned short h = from[7].value;
|
||||
r.i = _mm256_set_epi16(h, h, g, g, f, f, e, e, d, d, c, c, b, b, a, a);
|
||||
return r;
|
||||
return _mm256_set_epi16(h, h, g, g, f, f, e, e, d, d, c, c, b, b, a, a);
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet16bf
|
||||
@ -1733,12 +1720,11 @@ ploadquad(const bfloat16* from) {
|
||||
unsigned short b = from[1].value;
|
||||
unsigned short c = from[2].value;
|
||||
unsigned short d = from[3].value;
|
||||
r.i = _mm256_set_epi16(d, d, d, d, c, c, c, c, b, b, b, b, a, a, a, a);
|
||||
return r;
|
||||
return _mm256_set_epi16(d, d, d, d, c, c, c, c, b, b, b, b, a, a, a, a);
|
||||
}
|
||||
|
||||
EIGEN_STRONG_INLINE Packet16f Bf16ToF32(const Packet16bf& a) {
|
||||
return _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(a.i), 16));
|
||||
return _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(a), 16));
|
||||
}
|
||||
|
||||
// Convert float to bfloat16 according to round-to-nearest-even/denormals algorithm.
|
||||
@ -1754,8 +1740,11 @@ EIGEN_STRONG_INLINE Packet16bf F32ToBf16(const Packet16f& a) {
|
||||
#endif // EIGEN_VECTORIZE_AVX512DQ
|
||||
_MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_OFF);
|
||||
|
||||
#if defined(EIGEN_VECTORIZE_AVX512BF16)
|
||||
r.bh = _mm512_cvtneps_pbh(flush);
|
||||
#if defined(EIGEN_VECTORIZE_AVX512BF16) && EIGEN_GNUC_AT_LEAST(10, 1)
|
||||
// Since GCC 10.1 supports avx512bf16 and C style explicit cast
|
||||
// (C++ static_cast is not supported yet), do conversion via intrinsic
|
||||
// and register path for performance.
|
||||
r = (__m256i)(_mm512_cvtneps_pbh(flush));
|
||||
#else
|
||||
__m512i t;
|
||||
__m512i input = _mm512_castps_si512(flush);
|
||||
@ -1775,7 +1764,7 @@ EIGEN_STRONG_INLINE Packet16bf F32ToBf16(const Packet16f& a) {
|
||||
t = _mm512_mask_blend_epi32(mask, nan, t);
|
||||
|
||||
// output.value = static_cast<uint16_t>(input);
|
||||
r.i = _mm512_cvtepi32_epi16(t);
|
||||
r = _mm512_cvtepi32_epi16(t);
|
||||
#endif // EIGEN_VECTORIZE_AVX512BF16
|
||||
|
||||
return r;
|
||||
@ -1783,38 +1772,28 @@ EIGEN_STRONG_INLINE Packet16bf F32ToBf16(const Packet16f& a) {
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16bf ptrue(const Packet16bf& a) {
|
||||
Packet16bf r;
|
||||
r.i = ptrue<Packet8i>(a.i);
|
||||
return r;
|
||||
return ptrue<Packet8i>(a);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16bf por(const Packet16bf& a, const Packet16bf& b) {
|
||||
Packet16bf r;
|
||||
r.i = por<Packet8i>(a.i, b.i);
|
||||
return r;
|
||||
return por<Packet8i>(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16bf pxor(const Packet16bf& a, const Packet16bf& b) {
|
||||
Packet16bf r;
|
||||
r.i = pxor<Packet8i>(a.i, b.i);
|
||||
return r;
|
||||
return pxor<Packet8i>(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16bf pand(const Packet16bf& a, const Packet16bf& b) {
|
||||
Packet16bf r;
|
||||
r.i = pand<Packet8i>(a.i, b.i);
|
||||
return r;
|
||||
return pand<Packet8i>(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16bf pandnot(const Packet16bf& a,
|
||||
const Packet16bf& b) {
|
||||
Packet16bf r;
|
||||
r.i = pandnot<Packet8i>(a.i, b.i);
|
||||
return r;
|
||||
return pandnot<Packet8i>(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
@ -1823,50 +1802,39 @@ EIGEN_STRONG_INLINE Packet16bf pselect(const Packet16bf& mask,
|
||||
const Packet16bf& b) {
|
||||
// Input mask is expected to be all 0/1, handle it with 8-bit
|
||||
// intrinsic for performance.
|
||||
Packet16bf r;
|
||||
r.i = _mm256_blendv_epi8(b.i, a.i, mask.i);
|
||||
return r;
|
||||
return _mm256_blendv_epi8(b, a, mask);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16bf pcmp_eq(const Packet16bf& a,
|
||||
const Packet16bf& b) {
|
||||
Packet16bf result;
|
||||
result.i = Pack32To16(pcmp_eq(Bf16ToF32(a), Bf16ToF32(b)));
|
||||
return result;
|
||||
return Pack32To16(pcmp_eq(Bf16ToF32(a), Bf16ToF32(b)));
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16bf pcmp_le(const Packet16bf& a,
|
||||
const Packet16bf& b) {
|
||||
Packet16bf result;
|
||||
result.i = Pack32To16(pcmp_le(Bf16ToF32(a), Bf16ToF32(b)));
|
||||
return result;
|
||||
return Pack32To16(pcmp_le(Bf16ToF32(a), Bf16ToF32(b)));
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16bf pcmp_lt(const Packet16bf& a,
|
||||
const Packet16bf& b) {
|
||||
Packet16bf result;
|
||||
result.i = Pack32To16(pcmp_lt(Bf16ToF32(a), Bf16ToF32(b)));
|
||||
return result;
|
||||
return Pack32To16(pcmp_lt(Bf16ToF32(a), Bf16ToF32(b)));
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16bf pcmp_lt_or_nan(const Packet16bf& a,
|
||||
const Packet16bf& b) {
|
||||
Packet16bf result;
|
||||
result.i = Pack32To16(pcmp_lt_or_nan(Bf16ToF32(a), Bf16ToF32(b)));
|
||||
return result;
|
||||
return Pack32To16(pcmp_lt_or_nan(Bf16ToF32(a), Bf16ToF32(b)));
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16bf pnegate(const Packet16bf& a) {
|
||||
Packet16bf sign_mask;
|
||||
sign_mask.i = _mm256_set1_epi16(static_cast<unsigned short>(0x8000));
|
||||
sign_mask = _mm256_set1_epi16(static_cast<unsigned short>(0x8000));
|
||||
Packet16bf result;
|
||||
result.i = _mm256_xor_si256(a.i, sign_mask.i);
|
||||
return result;
|
||||
return _mm256_xor_si256(a, sign_mask);
|
||||
}
|
||||
|
||||
template <>
|
||||
@ -1917,8 +1885,8 @@ EIGEN_STRONG_INLINE Packet16bf pmax<Packet16bf>(const Packet16bf& a,
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8bf predux_half_dowto4<Packet16bf>(const Packet16bf& a) {
|
||||
Packet8bf lane0 = _mm256_extractf128_si256(a.i, 0);
|
||||
Packet8bf lane1 = _mm256_extractf128_si256(a.i, 1);
|
||||
Packet8bf lane0 = _mm256_extractf128_si256(a, 0);
|
||||
Packet8bf lane1 = _mm256_extractf128_si256(a, 1);
|
||||
return padd<Packet8bf>(lane0, lane1);
|
||||
}
|
||||
|
||||
@ -1949,22 +1917,19 @@ EIGEN_STRONG_INLINE Packet16bf preverse(const Packet16bf& a) {
|
||||
|
||||
Packet16bf res;
|
||||
// Swap hi and lo first because shuffle is in 128-bit lanes.
|
||||
res.i = _mm256_permute2x128_si256(a.i, a.i, 1);
|
||||
res = _mm256_permute2x128_si256(a, a, 1);
|
||||
// Shuffle 8-bit values in src within 2*128-bit lanes.
|
||||
res.i = _mm256_shuffle_epi8(res.i, m);
|
||||
return res;
|
||||
return _mm256_shuffle_epi8(res, m);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16bf pgather<bfloat16, Packet16bf>(const bfloat16* from,
|
||||
Index stride) {
|
||||
Packet16bf result;
|
||||
result.i = _mm256_set_epi16(
|
||||
return _mm256_set_epi16(
|
||||
from[15*stride].value, from[14*stride].value, from[13*stride].value, from[12*stride].value,
|
||||
from[11*stride].value, from[10*stride].value, from[9*stride].value, from[8*stride].value,
|
||||
from[7*stride].value, from[6*stride].value, from[5*stride].value, from[4*stride].value,
|
||||
from[3*stride].value, from[2*stride].value, from[1*stride].value, from[0*stride].value);
|
||||
return result;
|
||||
}
|
||||
|
||||
template <>
|
||||
@ -1992,22 +1957,22 @@ EIGEN_STRONG_INLINE void pscatter<bfloat16, Packet16bf>(bfloat16* to,
|
||||
}
|
||||
|
||||
EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16bf,16>& kernel) {
|
||||
__m256i a = kernel.packet[0].i;
|
||||
__m256i b = kernel.packet[1].i;
|
||||
__m256i c = kernel.packet[2].i;
|
||||
__m256i d = kernel.packet[3].i;
|
||||
__m256i e = kernel.packet[4].i;
|
||||
__m256i f = kernel.packet[5].i;
|
||||
__m256i g = kernel.packet[6].i;
|
||||
__m256i h = kernel.packet[7].i;
|
||||
__m256i i = kernel.packet[8].i;
|
||||
__m256i j = kernel.packet[9].i;
|
||||
__m256i k = kernel.packet[10].i;
|
||||
__m256i l = kernel.packet[11].i;
|
||||
__m256i m = kernel.packet[12].i;
|
||||
__m256i n = kernel.packet[13].i;
|
||||
__m256i o = kernel.packet[14].i;
|
||||
__m256i p = kernel.packet[15].i;
|
||||
__m256i a = kernel.packet[0];
|
||||
__m256i b = kernel.packet[1];
|
||||
__m256i c = kernel.packet[2];
|
||||
__m256i d = kernel.packet[3];
|
||||
__m256i e = kernel.packet[4];
|
||||
__m256i f = kernel.packet[5];
|
||||
__m256i g = kernel.packet[6];
|
||||
__m256i h = kernel.packet[7];
|
||||
__m256i i = kernel.packet[8];
|
||||
__m256i j = kernel.packet[9];
|
||||
__m256i k = kernel.packet[10];
|
||||
__m256i l = kernel.packet[11];
|
||||
__m256i m = kernel.packet[12];
|
||||
__m256i n = kernel.packet[13];
|
||||
__m256i o = kernel.packet[14];
|
||||
__m256i p = kernel.packet[15];
|
||||
|
||||
__m256i ab_07 = _mm256_unpacklo_epi16(a, b);
|
||||
__m256i cd_07 = _mm256_unpacklo_epi16(c, d);
|
||||
@ -2063,29 +2028,29 @@ EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16bf,16>& kernel) {
|
||||
__m256i ijklmnop_ef = _mm256_unpackhi_epi64(ijkl_cf, mnop_cf);
|
||||
|
||||
// NOTE: no unpacklo/hi instr in this case, so using permute instr.
|
||||
kernel.packet[0].i = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01, 0x20);
|
||||
kernel.packet[1].i = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23, 0x20);
|
||||
kernel.packet[2].i = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45, 0x20);
|
||||
kernel.packet[3].i = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67, 0x20);
|
||||
kernel.packet[4].i = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89, 0x20);
|
||||
kernel.packet[5].i = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab, 0x20);
|
||||
kernel.packet[6].i = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd, 0x20);
|
||||
kernel.packet[7].i = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef, 0x20);
|
||||
kernel.packet[8].i = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01, 0x31);
|
||||
kernel.packet[9].i = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23, 0x31);
|
||||
kernel.packet[10].i = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45, 0x31);
|
||||
kernel.packet[11].i = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67, 0x31);
|
||||
kernel.packet[12].i = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89, 0x31);
|
||||
kernel.packet[13].i = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab, 0x31);
|
||||
kernel.packet[14].i = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd, 0x31);
|
||||
kernel.packet[15].i = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef, 0x31);
|
||||
kernel.packet[0] = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01, 0x20);
|
||||
kernel.packet[1] = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23, 0x20);
|
||||
kernel.packet[2] = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45, 0x20);
|
||||
kernel.packet[3] = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67, 0x20);
|
||||
kernel.packet[4] = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89, 0x20);
|
||||
kernel.packet[5] = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab, 0x20);
|
||||
kernel.packet[6] = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd, 0x20);
|
||||
kernel.packet[7] = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef, 0x20);
|
||||
kernel.packet[8] = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01, 0x31);
|
||||
kernel.packet[9] = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23, 0x31);
|
||||
kernel.packet[10] = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45, 0x31);
|
||||
kernel.packet[11] = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67, 0x31);
|
||||
kernel.packet[12] = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89, 0x31);
|
||||
kernel.packet[13] = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab, 0x31);
|
||||
kernel.packet[14] = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd, 0x31);
|
||||
kernel.packet[15] = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef, 0x31);
|
||||
}
|
||||
|
||||
EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16bf,4>& kernel) {
|
||||
__m256i a = kernel.packet[0].i;
|
||||
__m256i b = kernel.packet[1].i;
|
||||
__m256i c = kernel.packet[2].i;
|
||||
__m256i d = kernel.packet[3].i;
|
||||
__m256i a = kernel.packet[0];
|
||||
__m256i b = kernel.packet[1];
|
||||
__m256i c = kernel.packet[2];
|
||||
__m256i d = kernel.packet[3];
|
||||
|
||||
__m256i ab_07 = _mm256_unpacklo_epi16(a, b);
|
||||
__m256i cd_07 = _mm256_unpacklo_epi16(c, d);
|
||||
@ -2098,10 +2063,10 @@ EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16bf,4>& kernel) {
|
||||
__m256i abcd_cf = _mm256_unpackhi_epi32(ab_8f, cd_8f);
|
||||
|
||||
// NOTE: no unpacklo/hi instr in this case, so using permute instr.
|
||||
kernel.packet[0].i = _mm256_permute2x128_si256(abcd_03, abcd_47, 0x20);
|
||||
kernel.packet[1].i = _mm256_permute2x128_si256(abcd_8b, abcd_cf, 0x20);
|
||||
kernel.packet[2].i = _mm256_permute2x128_si256(abcd_03, abcd_47, 0x31);
|
||||
kernel.packet[3].i = _mm256_permute2x128_si256(abcd_8b, abcd_cf, 0x31);
|
||||
kernel.packet[0] = _mm256_permute2x128_si256(abcd_03, abcd_47, 0x20);
|
||||
kernel.packet[1] = _mm256_permute2x128_si256(abcd_8b, abcd_cf, 0x20);
|
||||
kernel.packet[2] = _mm256_permute2x128_si256(abcd_03, abcd_47, 0x31);
|
||||
kernel.packet[3] = _mm256_permute2x128_si256(abcd_8b, abcd_cf, 0x31);
|
||||
}
|
||||
|
||||
} // end namespace internal
|
||||
|
Loading…
Reference in New Issue
Block a user