mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-04-24 19:40:45 +08:00
AVX path for BF16
This commit is contained in:
parent
4ab32e2de2
commit
56b3e3f3f8
@ -239,6 +239,12 @@ if(NOT MSVC)
|
|||||||
message(STATUS "Enabling FMA in tests/examples")
|
message(STATUS "Enabling FMA in tests/examples")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
# Opt-in switch (default OFF) that builds tests/examples with AVX2 enabled.
option(EIGEN_TEST_AVX2 "Enable/Disable AVX2 in tests/examples" OFF)
|
||||||
|
if(EIGEN_TEST_AVX2)
|
||||||
|
# FMA is enabled together with AVX2 by appending both flags to the C++ flags.
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx2 -mfma")
|
||||||
|
message(STATUS "Enabling AVX2 in tests/examples")
|
||||||
|
endif()
|
||||||
|
|
||||||
option(EIGEN_TEST_AVX512 "Enable/Disable AVX512 in tests/examples" OFF)
|
option(EIGEN_TEST_AVX512 "Enable/Disable AVX512 in tests/examples" OFF)
|
||||||
if(EIGEN_TEST_AVX512)
|
if(EIGEN_TEST_AVX512)
|
||||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx512f -mfma")
|
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx512f -mfma")
|
||||||
|
@ -58,15 +58,15 @@ pexp<Packet8f>(const Packet8f& _x) {
|
|||||||
// Hyperbolic Tangent function.
|
// Hyperbolic Tangent function.
|
||||||
template <>
|
template <>
|
||||||
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet8f
|
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet8f
|
||||||
ptanh<Packet8f>(const Packet8f& x) {
|
ptanh<Packet8f>(const Packet8f& _x) {
|
||||||
return internal::generic_fast_tanh_float(x);
|
return internal::generic_fast_tanh_float(_x);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Exponential function for doubles.
|
// Exponential function for doubles.
|
||||||
template <>
|
template <>
|
||||||
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet4d
|
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet4d
|
||||||
pexp<Packet4d>(const Packet4d& x) {
|
pexp<Packet4d>(const Packet4d& _x) {
|
||||||
return pexp_double(x);
|
return pexp_double(_x);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Functions for sqrt.
|
// Functions for sqrt.
|
||||||
@ -96,13 +96,13 @@ psqrt<Packet8f>(const Packet8f& _x) {
|
|||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
template <> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
|
template <> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
|
||||||
Packet8f psqrt<Packet8f>(const Packet8f& x) {
|
Packet8f psqrt<Packet8f>(const Packet8f& _x) {
|
||||||
return _mm256_sqrt_ps(x);
|
return _mm256_sqrt_ps(_x);
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
template <> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
|
template <> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
|
||||||
Packet4d psqrt<Packet4d>(const Packet4d& x) {
|
Packet4d psqrt<Packet4d>(const Packet4d& _x) {
|
||||||
return _mm256_sqrt_pd(x);
|
return _mm256_sqrt_pd(_x);
|
||||||
}
|
}
|
||||||
#if EIGEN_FAST_MATH
|
#if EIGEN_FAST_MATH
|
||||||
|
|
||||||
@ -140,18 +140,27 @@ Packet8f prsqrt<Packet8f>(const Packet8f& _x) {
|
|||||||
|
|
||||||
#else
|
#else
|
||||||
template <> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
|
template <> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
|
||||||
Packet8f prsqrt<Packet8f>(const Packet8f& x) {
|
Packet8f prsqrt<Packet8f>(const Packet8f& _x) {
|
||||||
_EIGEN_DECLARE_CONST_Packet8f(one, 1.0f);
|
_EIGEN_DECLARE_CONST_Packet8f(one, 1.0f);
|
||||||
return _mm256_div_ps(p8f_one, _mm256_sqrt_ps(x));
|
return _mm256_div_ps(p8f_one, _mm256_sqrt_ps(_x));
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
template <> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
|
template <> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
|
||||||
Packet4d prsqrt<Packet4d>(const Packet4d& x) {
|
Packet4d prsqrt<Packet4d>(const Packet4d& _x) {
|
||||||
_EIGEN_DECLARE_CONST_Packet4d(one, 1.0);
|
_EIGEN_DECLARE_CONST_Packet4d(one, 1.0);
|
||||||
return _mm256_div_pd(p4d_one, _mm256_sqrt_pd(x));
|
return _mm256_div_pd(p4d_one, _mm256_sqrt_pd(_x));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Math functions for Packet8bf, one macro invocation per function.
// NOTE(review): BF16_PACKET_FUNCTION is defined outside this view; presumably
// it implements each Packet8bf function through its Packet8f counterpart --
// confirm against the macro definition.
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, psin)
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, pcos)
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, plog)
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, plog1p)
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, pexpm1)
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, pexp)
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, ptanh)
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, psqrt)
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, prsqrt)
|
||||||
|
|
||||||
} // end namespace internal
|
} // end namespace internal
|
||||||
|
|
||||||
|
@ -32,11 +32,13 @@ typedef __m256 Packet8f;
|
|||||||
typedef __m256i Packet8i;
|
typedef __m256i Packet8i;
|
||||||
typedef __m256d Packet4d;
|
typedef __m256d Packet4d;
|
||||||
typedef eigen_packet_wrapper<__m128i, 2> Packet8h;
|
typedef eigen_packet_wrapper<__m128i, 2> Packet8h;
|
||||||
|
// Packet of bfloat16 values held in a 128-bit integer register. The wrapper
// index (3) differs from Packet8h's (2), so the two are distinct types even
// though both wrap __m128i.
typedef eigen_packet_wrapper<__m128i, 3> Packet8bf;
|
||||||
|
|
||||||
template<> struct is_arithmetic<__m256> { enum { value = true }; };
|
template<> struct is_arithmetic<__m256> { enum { value = true }; };
|
||||||
template<> struct is_arithmetic<__m256i> { enum { value = true }; };
|
template<> struct is_arithmetic<__m256i> { enum { value = true }; };
|
||||||
template<> struct is_arithmetic<__m256d> { enum { value = true }; };
|
template<> struct is_arithmetic<__m256d> { enum { value = true }; };
|
||||||
template<> struct is_arithmetic<Packet8h> { enum { value = true }; };
|
template<> struct is_arithmetic<Packet8h> { enum { value = true }; };
|
||||||
|
// Marks Packet8bf as an arithmetic type, like the other packet types above.
template<> struct is_arithmetic<Packet8bf> { enum { value = true }; };
|
||||||
|
|
||||||
#define _EIGEN_DECLARE_CONST_Packet8f(NAME,X) \
|
#define _EIGEN_DECLARE_CONST_Packet8f(NAME,X) \
|
||||||
const Packet8f p8f_##NAME = pset1<Packet8f>(X)
|
const Packet8f p8f_##NAME = pset1<Packet8f>(X)
|
||||||
@ -134,6 +136,40 @@ struct packet_traits<Eigen::half> : default_packet_traits {
|
|||||||
HasBlend = 0
|
HasBlend = 0
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct packet_traits<bfloat16> : default_packet_traits {
|
||||||
|
typedef Packet8bf type;
|
||||||
|
// There is no half-size packet for current Packet8bf.
|
||||||
|
// TODO: support as SSE path.
|
||||||
|
typedef Packet8bf half;
|
||||||
|
enum {
|
||||||
|
Vectorizable = 1,
|
||||||
|
AlignedOnScalar = 1,
|
||||||
|
size = 8,
|
||||||
|
HasHalfPacket = 0,
|
||||||
|
|
||||||
|
HasCmp = 1,
|
||||||
|
HasDiv = 1,
|
||||||
|
HasSin = EIGEN_FAST_MATH,
|
||||||
|
HasCos = EIGEN_FAST_MATH,
|
||||||
|
HasLog = 1,
|
||||||
|
HasLog1p = 1,
|
||||||
|
HasExpm1 = 1,
|
||||||
|
HasExp = 1,
|
||||||
|
HasNdtri = 1,
|
||||||
|
HasBessel = 1,
|
||||||
|
HasSqrt = 1,
|
||||||
|
HasRsqrt = 1,
|
||||||
|
HasTanh = EIGEN_FAST_MATH,
|
||||||
|
HasErf = EIGEN_FAST_MATH,
|
||||||
|
HasBlend = 0,
|
||||||
|
HasRound = 1,
|
||||||
|
HasFloor = 1,
|
||||||
|
HasCeil = 1,
|
||||||
|
HasRint = 1
|
||||||
|
};
|
||||||
|
};
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
template<> struct scalar_div_cost<float,true> { enum { value = 14 }; };
|
template<> struct scalar_div_cost<float,true> { enum { value = 14 }; };
|
||||||
@ -165,6 +201,14 @@ template<> struct unpacket_traits<Packet4d> {
|
|||||||
enum {size=4, alignment=Aligned32, vectorizable=true, masked_load_available=false, masked_store_available=false};
|
enum {size=4, alignment=Aligned32, vectorizable=true, masked_load_available=false, masked_store_available=false};
|
||||||
};
|
};
|
||||||
template<> struct unpacket_traits<Packet8i> { typedef int type; typedef Packet4i half; enum {size=8, alignment=Aligned32, vectorizable=false, masked_load_available=false, masked_store_available=false}; };
|
template<> struct unpacket_traits<Packet8i> { typedef int type; typedef Packet4i half; enum {size=8, alignment=Aligned32, vectorizable=false, masked_load_available=false, masked_store_available=false}; };
|
||||||
|
// Reverse mapping for Packet8bf: scalar type bfloat16, 8 coefficients, 16-byte
// alignment. `half` is the packet itself because no smaller bf16 packet exists here.
template<> struct unpacket_traits<Packet8bf> { typedef bfloat16 type; typedef Packet8bf half; enum {size=8, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false}; };
|
||||||
|
|
||||||
|
// Helper function for bit packing snippet of low precision comparison.
|
||||||
|
// It packs the flags from 16x16 to 8x16.
|
||||||
|
EIGEN_STRONG_INLINE __m128i Pack16To8(Packet8f rf) {
|
||||||
|
return _mm_packs_epi32(_mm256_extractf128_si256(_mm256_castps_si256(rf), 0),
|
||||||
|
_mm256_extractf128_si256(_mm256_castps_si256(rf), 1));
|
||||||
|
}
|
||||||
|
|
||||||
template<> EIGEN_STRONG_INLINE Packet8f pset1<Packet8f>(const float& from) { return _mm256_set1_ps(from); }
|
template<> EIGEN_STRONG_INLINE Packet8f pset1<Packet8f>(const float& from) { return _mm256_set1_ps(from); }
|
||||||
template<> EIGEN_STRONG_INLINE Packet4d pset1<Packet4d>(const double& from) { return _mm256_set1_pd(from); }
|
template<> EIGEN_STRONG_INLINE Packet4d pset1<Packet4d>(const double& from) { return _mm256_set1_pd(from); }
|
||||||
@ -1032,6 +1076,307 @@ ptranspose(PacketBlock<Packet8h,4>& kernel) {
|
|||||||
kernel.packet[3] = pload<Packet8h>(out[3]);
|
kernel.packet[3] = pload<Packet8h>(out[3]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Widens eight bfloat16 values to float: each 16-bit pattern is zero-extended
// to 32 bits and shifted left by 16, so it lands in the upper half of the
// float word and the lower 16 bits are zero.
EIGEN_STRONG_INLINE Packet8f Bf16ToF32(const Packet8bf& a) {
#ifdef EIGEN_VECTORIZE_AVX2
  // 256-bit integer path: all eight lanes at once.
  const __m256i widened = _mm256_cvtepu16_epi32(a);
  const __m256i shifted = _mm256_slli_epi32(widened, 16);
  return _mm256_castsi256_ps(shifted);
#else
  // 128-bit path: convert lanes 0-3 and 4-7 separately, then join them.
  const __m128i lower = _mm_slli_epi32(_mm_cvtepu16_epi32(a), 16);
  const __m128i upper = _mm_slli_epi32(_mm_cvtepu16_epi32(_mm_srli_si128(a, 8)), 16);
  const __m256i joined = _mm256_insertf128_si256(_mm256_castsi128_si256(lower), upper, 1);
  return _mm256_castsi256_ps(joined);
#endif
}
|
||||||
|
|
||||||
|
// Convert float to bfloat16 according to round-to-nearest-even/denormals algorithm.
|
||||||
|
EIGEN_STRONG_INLINE Packet8bf F32ToBf16(const Packet8f& a) {
|
||||||
|
Packet8bf r;
|
||||||
|
|
||||||
|
// Flush input denormals value to zero with hardware capability.
|
||||||
|
_MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_ON);
|
||||||
|
__m256 flush = _mm256_and_ps(a, a);
|
||||||
|
_MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_OFF);
|
||||||
|
|
||||||
|
__m256i input = _mm256_castps_si256(flush);
|
||||||
|
|
||||||
|
#ifdef EIGEN_VECTORIZE_AVX2
|
||||||
|
// uint32_t lsb = (input >> 16);
|
||||||
|
__m256i t = _mm256_srli_epi32(input, 16);
|
||||||
|
// uint32_t lsb = lsb & 1;
|
||||||
|
t = _mm256_and_si256(t, _mm256_set1_epi32(1));
|
||||||
|
// uint32_t rounding_bias = 0x7fff + lsb;
|
||||||
|
t = _mm256_add_epi32(t, _mm256_set1_epi32(0x7fff));
|
||||||
|
// input += rounding_bias;
|
||||||
|
t = _mm256_add_epi32(t, input);
|
||||||
|
// input = input >> 16;
|
||||||
|
t = _mm256_srli_epi32(t, 16);
|
||||||
|
// Check NaN before converting back to bf16
|
||||||
|
__m256 mask = _mm256_cmp_ps(flush, flush, _CMP_ORD_Q);
|
||||||
|
__m256i nan = _mm256_set1_epi32(0x7fc0);
|
||||||
|
t = _mm256_blendv_epi8(nan, t, _mm256_castps_si256(mask));
|
||||||
|
// output.value = static_cast<uint16_t>(input);
|
||||||
|
return _mm_packus_epi32(_mm256_extractf128_si256(t, 0),
|
||||||
|
_mm256_extractf128_si256(t, 1));
|
||||||
|
#else
|
||||||
|
// uint32_t lsb = (input >> 16);
|
||||||
|
__m128i lo = _mm_srli_epi32(_mm256_extractf128_si256(input, 0), 16);
|
||||||
|
__m128i hi = _mm_srli_epi32(_mm256_extractf128_si256(input, 1), 16);
|
||||||
|
// uint32_t lsb = lsb & 1;
|
||||||
|
lo = _mm_and_si128(lo, _mm_set1_epi32(1));
|
||||||
|
hi = _mm_and_si128(hi, _mm_set1_epi32(1));
|
||||||
|
// uint32_t rounding_bias = 0x7fff + lsb;
|
||||||
|
lo = _mm_add_epi32(lo, _mm_set1_epi32(0x7fff));
|
||||||
|
hi = _mm_add_epi32(hi, _mm_set1_epi32(0x7fff));
|
||||||
|
// input += rounding_bias;
|
||||||
|
lo = _mm_add_epi32(lo, _mm256_extractf128_si256(input, 0));
|
||||||
|
hi = _mm_add_epi32(hi, _mm256_extractf128_si256(input, 1));
|
||||||
|
// input = input >> 16;
|
||||||
|
lo = _mm_srli_epi32(lo, 16);
|
||||||
|
hi = _mm_srli_epi32(hi, 16);
|
||||||
|
// Check NaN before converting back to bf16
|
||||||
|
__m256 mask = _mm256_cmp_ps(flush, flush, _CMP_ORD_Q);
|
||||||
|
__m128i nan = _mm_set1_epi32(0x7fc0);
|
||||||
|
lo = _mm_blendv_epi8(nan, lo, _mm_castps_si128(_mm256_castps256_ps128(mask)));
|
||||||
|
hi = _mm_blendv_epi8(nan, hi, _mm_castps_si128(_mm256_extractf128_ps(mask, 1)));
|
||||||
|
// output.value = static_cast<uint16_t>(input);
|
||||||
|
return _mm_packus_epi32(lo, hi);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
// Broadcasts the raw 16-bit pattern of `from` into every lane of the packet.
template<> EIGEN_STRONG_INLINE Packet8bf pset1<Packet8bf>(const bfloat16& from) {
  const __m128i broadcast = _mm_set1_epi16(from.value);
  return broadcast;
}
|
||||||
|
|
||||||
|
// Returns lane 0 of the packet as a bfloat16 built from its raw 16-bit pattern.
template<> EIGEN_STRONG_INLINE bfloat16 pfirst<Packet8bf>(const Packet8bf& from) {
  const unsigned short bits = static_cast<unsigned short>(_mm_extract_epi16(from, 0));
  return bfloat16_impl::raw_uint16_to_bfloat16(bits);
}
|
||||||
|
|
||||||
|
// Aligned 128-bit load of eight bfloat16 values starting at `from`.
template<> EIGEN_STRONG_INLINE Packet8bf pload<Packet8bf>(const bfloat16* from) {
  const __m128i* src = reinterpret_cast<const __m128i*>(from);
  return _mm_load_si128(src);
}
|
||||||
|
|
||||||
|
// Unaligned 128-bit load of eight bfloat16 values starting at `from`.
template<> EIGEN_STRONG_INLINE Packet8bf ploadu<Packet8bf>(const bfloat16* from) {
  const __m128i* src = reinterpret_cast<const __m128i*>(from);
  return _mm_loadu_si128(src);
}
|
||||||
|
|
||||||
|
// Aligned 128-bit store of the packet to `to`.
template<> EIGEN_STRONG_INLINE void pstore<bfloat16>(bfloat16* to, const Packet8bf& from) {
  __m128i* dst = reinterpret_cast<__m128i*>(to);
  _mm_store_si128(dst, from);
}
|
||||||
|
|
||||||
|
// Unaligned 128-bit store of the packet to `to`.
template<> EIGEN_STRONG_INLINE void pstoreu<bfloat16>(bfloat16* to, const Packet8bf& from) {
  __m128i* dst = reinterpret_cast<__m128i*>(to);
  _mm_storeu_si128(dst, from);
}
|
||||||
|
|
||||||
|
// Reads four values from `from` and duplicates each one into two adjacent
// lanes: {f0, f0, f1, f1, f2, f2, f3, f3}.
template<> EIGEN_STRONG_INLINE Packet8bf
ploaddup<Packet8bf>(const bfloat16* from) {
  const unsigned short v0 = from[0].value;
  const unsigned short v1 = from[1].value;
  const unsigned short v2 = from[2].value;
  const unsigned short v3 = from[3].value;
  // _mm_setr_epi16 lists lanes from lowest to highest.
  return _mm_setr_epi16(v0, v0, v1, v1, v2, v2, v3, v3);
}
|
||||||
|
|
||||||
|
// Reads two values from `from` and replicates each one into four adjacent
// lanes: {f0, f0, f0, f0, f1, f1, f1, f1}.
template<> EIGEN_STRONG_INLINE Packet8bf
ploadquad<Packet8bf>(const bfloat16* from) {
  const unsigned short v0 = from[0].value;
  const unsigned short v1 = from[1].value;
  // _mm_setr_epi16 lists lanes from lowest to highest.
  return _mm_setr_epi16(v0, v0, v0, v0, v1, v1, v1, v1);
}
|
||||||
|
|
||||||
|
// Produces an all-ones packet: comparing a register with itself for integer
// equality sets every bit.
template<> EIGEN_STRONG_INLINE Packet8bf ptrue(const Packet8bf& a) {
  const __m128i all_ones = _mm_cmpeq_epi32(a, a);
  return all_ones;
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
EIGEN_STRONG_INLINE Packet8bf pabs(const Packet8bf& a) {
|
||||||
|
return F32ToBf16(pabs<Packet8f>(Bf16ToF32(a)));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
EIGEN_STRONG_INLINE Packet8bf pmin<Packet8bf>(const Packet8bf& a,
|
||||||
|
const Packet8bf& b) {
|
||||||
|
return F32ToBf16(pmin<Packet8f>(Bf16ToF32(a), Bf16ToF32(b)));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
EIGEN_STRONG_INLINE Packet8bf pmax<Packet8bf>(const Packet8bf& a,
|
||||||
|
const Packet8bf& b) {
|
||||||
|
return F32ToBf16(pmax<Packet8f>(Bf16ToF32(a), Bf16ToF32(b)));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bitwise OR of the two packets' raw bits.
template<> EIGEN_STRONG_INLINE Packet8bf por(const Packet8bf& a,const Packet8bf& b) {
  const __m128i bits = _mm_or_si128(a, b);
  return bits;
}
|
||||||
|
// Bitwise XOR of the two packets' raw bits.
template<> EIGEN_STRONG_INLINE Packet8bf pxor(const Packet8bf& a,const Packet8bf& b) {
  const __m128i bits = _mm_xor_si128(a, b);
  return bits;
}
|
||||||
|
// Bitwise AND of the two packets' raw bits.
template<> EIGEN_STRONG_INLINE Packet8bf pand(const Packet8bf& a,const Packet8bf& b) {
  const __m128i bits = _mm_and_si128(a, b);
  return bits;
}
|
||||||
|
// Computes a & ~b. The intrinsic negates its FIRST operand, hence the
// swapped argument order.
template<> EIGEN_STRONG_INLINE Packet8bf pandnot(const Packet8bf& a,const Packet8bf& b) {
  const __m128i bits = _mm_andnot_si128(b, a);
  return bits;
}
|
||||||
|
|
||||||
|
// Byte-wise select: takes the byte from `a` where the corresponding mask
// byte has its top bit set, otherwise the byte from `b`.
template<> EIGEN_STRONG_INLINE Packet8bf pselect(const Packet8bf& mask, const Packet8bf& a, const Packet8bf& b) {
  const __m128i chosen = _mm_blendv_epi8(b, a, mask);
  return chosen;
}
|
||||||
|
|
||||||
|
// pround, delegated to the Packet8f implementation and converted back.
template<> EIGEN_STRONG_INLINE Packet8bf pround<Packet8bf>(const Packet8bf& a)
{
  const Packet8f rounded = pround<Packet8f>(Bf16ToF32(a));
  return F32ToBf16(rounded);
}
|
||||||
|
|
||||||
|
// print, delegated to the Packet8f implementation and converted back.
template<> EIGEN_STRONG_INLINE Packet8bf print<Packet8bf>(const Packet8bf& a) {
  const Packet8f rounded = print<Packet8f>(Bf16ToF32(a));
  return F32ToBf16(rounded);
}
|
||||||
|
|
||||||
|
// pceil, delegated to the Packet8f implementation and converted back.
template<> EIGEN_STRONG_INLINE Packet8bf pceil<Packet8bf>(const Packet8bf& a) {
  const Packet8f raised = pceil<Packet8f>(Bf16ToF32(a));
  return F32ToBf16(raised);
}
|
||||||
|
|
||||||
|
// pfloor, delegated to the Packet8f implementation and converted back.
template<> EIGEN_STRONG_INLINE Packet8bf pfloor<Packet8bf>(const Packet8bf& a) {
  const Packet8f lowered = pfloor<Packet8f>(Bf16ToF32(a));
  return F32ToBf16(lowered);
}
|
||||||
|
|
||||||
|
// Lane-wise a == b, compared in float; the 32-bit masks are narrowed to 16 bits.
template<> EIGEN_STRONG_INLINE Packet8bf pcmp_eq(const Packet8bf& a,const Packet8bf& b) {
  const Packet8f wide_mask = pcmp_eq(Bf16ToF32(a), Bf16ToF32(b));
  return Pack16To8(wide_mask);
}
|
||||||
|
|
||||||
|
// Lane-wise a <= b, compared in float; the 32-bit masks are narrowed to 16 bits.
template<> EIGEN_STRONG_INLINE Packet8bf pcmp_le(const Packet8bf& a,const Packet8bf& b) {
  const Packet8f wide_mask = pcmp_le(Bf16ToF32(a), Bf16ToF32(b));
  return Pack16To8(wide_mask);
}
|
||||||
|
|
||||||
|
// Lane-wise a < b, compared in float; the 32-bit masks are narrowed to 16 bits.
template<> EIGEN_STRONG_INLINE Packet8bf pcmp_lt(const Packet8bf& a,const Packet8bf& b) {
  const Packet8f wide_mask = pcmp_lt(Bf16ToF32(a), Bf16ToF32(b));
  return Pack16To8(wide_mask);
}
|
||||||
|
|
||||||
|
// Delegates to the Packet8f pcmp_lt_or_nan and narrows the 32-bit masks to 16 bits.
template<> EIGEN_STRONG_INLINE Packet8bf pcmp_lt_or_nan(const Packet8bf& a,const Packet8bf& b) {
  const Packet8f wide_mask = pcmp_lt_or_nan(Bf16ToF32(a), Bf16ToF32(b));
  return Pack16To8(wide_mask);
}
|
||||||
|
|
||||||
|
// Conjugation returns the packet unchanged.
template<> EIGEN_STRONG_INLINE Packet8bf pconj(const Packet8bf& a) {
  return a;
}
|
||||||
|
|
||||||
|
// Negates every lane by flipping bit 15 (the sign bit) of each 16-bit element.
template<> EIGEN_STRONG_INLINE Packet8bf pnegate(const Packet8bf& a) {
  const __m128i sign_bits = _mm_set1_epi16(static_cast<unsigned short>(0x8000));
  return _mm_xor_si128(a, sign_bits);
}
|
||||||
|
|
||||||
|
// Lane-wise a + b, computed in float and converted back to bfloat16.
template<> EIGEN_STRONG_INLINE Packet8bf padd<Packet8bf>(const Packet8bf& a, const Packet8bf& b) {
  const Packet8f sum = padd<Packet8f>(Bf16ToF32(a), Bf16ToF32(b));
  return F32ToBf16(sum);
}
|
||||||
|
|
||||||
|
// Lane-wise a - b, computed in float and converted back to bfloat16.
template<> EIGEN_STRONG_INLINE Packet8bf psub<Packet8bf>(const Packet8bf& a, const Packet8bf& b) {
  const Packet8f difference = psub<Packet8f>(Bf16ToF32(a), Bf16ToF32(b));
  return F32ToBf16(difference);
}
|
||||||
|
|
||||||
|
// Lane-wise a * b, computed in float and converted back to bfloat16.
template<> EIGEN_STRONG_INLINE Packet8bf pmul<Packet8bf>(const Packet8bf& a, const Packet8bf& b) {
  const Packet8f product = pmul<Packet8f>(Bf16ToF32(a), Bf16ToF32(b));
  return F32ToBf16(product);
}
|
||||||
|
|
||||||
|
// Lane-wise a / b, computed in float and converted back to bfloat16.
template<> EIGEN_STRONG_INLINE Packet8bf pdiv<Packet8bf>(const Packet8bf& a, const Packet8bf& b) {
  const Packet8f quotient = pdiv<Packet8f>(Bf16ToF32(a), Bf16ToF32(b));
  return F32ToBf16(quotient);
}
|
||||||
|
|
||||||
|
|
||||||
|
// Strided gather: lane i receives the raw bits of from[i*stride], i = 0..7.
template<> EIGEN_STRONG_INLINE Packet8bf pgather<bfloat16, Packet8bf>(const bfloat16* from, Index stride)
{
  // _mm_setr_epi16 lists lanes from lowest to highest.
  return _mm_setr_epi16(from[0*stride].value,
                        from[1*stride].value,
                        from[2*stride].value,
                        from[3*stride].value,
                        from[4*stride].value,
                        from[5*stride].value,
                        from[6*stride].value,
                        from[7*stride].value);
}
|
||||||
|
|
||||||
|
// Strided scatter: lane i of `from` is written to to[stride*i], i = 0..7.
template<> EIGEN_STRONG_INLINE void pscatter<bfloat16, Packet8bf>(bfloat16* to, const Packet8bf& from, Index stride)
{
  // Spill the packet to an aligned temporary, then copy lane by lane.
  EIGEN_ALIGN32 bfloat16 aux[8];
  pstore(aux, from);
  for (Index lane = 0; lane < 8; ++lane) {
    to[stride*lane] = aux[lane];
  }
}
|
||||||
|
|
||||||
|
// Reduction via predux<Packet8f> on the widened packet; the float result is
// cast to bfloat16.
template<> EIGEN_STRONG_INLINE bfloat16 predux<Packet8bf>(const Packet8bf& a) {
  const Packet8f widened = Bf16ToF32(a);
  return static_cast<bfloat16>(predux<Packet8f>(widened));
}
|
||||||
|
|
||||||
|
// Reduction via predux_max<Packet8f> on the widened packet; the float result
// is cast to bfloat16.
template<> EIGEN_STRONG_INLINE bfloat16 predux_max<Packet8bf>(const Packet8bf& a) {
  const Packet8f widened = Bf16ToF32(a);
  return static_cast<bfloat16>(predux_max<Packet8f>(widened));
}
|
||||||
|
|
||||||
|
// Reduction via predux_min<Packet8f> on the widened packet; the float result
// is cast to bfloat16.
template<> EIGEN_STRONG_INLINE bfloat16 predux_min<Packet8bf>(const Packet8bf& a) {
  const Packet8f widened = Bf16ToF32(a);
  return static_cast<bfloat16>(predux_min<Packet8f>(widened));
}
|
||||||
|
|
||||||
|
// Reduction via predux_mul<Packet8f> on the widened packet; the float result
// is cast to bfloat16.
template<> EIGEN_STRONG_INLINE bfloat16 predux_mul<Packet8bf>(const Packet8bf& a) {
  const Packet8f widened = Bf16ToF32(a);
  return static_cast<bfloat16>(predux_mul<Packet8f>(widened));
}
|
||||||
|
|
||||||
|
// Reverses the order of the eight 16-bit lanes.
template<> EIGEN_STRONG_INLINE Packet8bf preverse(const Packet8bf& a)
{
  // Byte shuffle control: output lane k takes bytes (14-2k, 15-2k), i.e. the
  // two bytes of input lane 7-k in their original order.
  const __m128i byte_order = _mm_setr_epi8(14,15,12,13,10,11,8,9,6,7,4,5,2,3,0,1);
  return _mm_shuffle_epi8(a, byte_order);
}
|
||||||
|
|
||||||
|
// In-place transpose of an 8x8 block of 16-bit elements, done in three
// interleave stages (16-bit, 32-bit, then 64-bit granularity).
EIGEN_STRONG_INLINE void
ptranspose(PacketBlock<Packet8bf,8>& kernel) {
  const __m128i r0 = kernel.packet[0];
  const __m128i r1 = kernel.packet[1];
  const __m128i r2 = kernel.packet[2];
  const __m128i r3 = kernel.packet[3];
  const __m128i r4 = kernel.packet[4];
  const __m128i r5 = kernel.packet[5];
  const __m128i r6 = kernel.packet[6];
  const __m128i r7 = kernel.packet[7];

  // Stage 1: interleave 16-bit elements of neighbouring rows.
  // *_lo covers columns 0-3, *_hi covers columns 4-7.
  const __m128i r01_lo = _mm_unpacklo_epi16(r0, r1);
  const __m128i r23_lo = _mm_unpacklo_epi16(r2, r3);
  const __m128i r45_lo = _mm_unpacklo_epi16(r4, r5);
  const __m128i r67_lo = _mm_unpacklo_epi16(r6, r7);
  const __m128i r01_hi = _mm_unpackhi_epi16(r0, r1);
  const __m128i r23_hi = _mm_unpackhi_epi16(r2, r3);
  const __m128i r45_hi = _mm_unpackhi_epi16(r4, r5);
  const __m128i r67_hi = _mm_unpackhi_epi16(r6, r7);

  // Stage 2: interleave 32-bit pairs, giving rows 0-3 or rows 4-7 for two
  // columns at a time (cXY = columns X and Y).
  const __m128i top_c01 = _mm_unpacklo_epi32(r01_lo, r23_lo);
  const __m128i top_c23 = _mm_unpackhi_epi32(r01_lo, r23_lo);
  const __m128i bot_c01 = _mm_unpacklo_epi32(r45_lo, r67_lo);
  const __m128i bot_c23 = _mm_unpackhi_epi32(r45_lo, r67_lo);
  const __m128i top_c45 = _mm_unpacklo_epi32(r01_hi, r23_hi);
  const __m128i top_c67 = _mm_unpackhi_epi32(r01_hi, r23_hi);
  const __m128i bot_c45 = _mm_unpacklo_epi32(r45_hi, r67_hi);
  const __m128i bot_c67 = _mm_unpackhi_epi32(r45_hi, r67_hi);

  // Stage 3: join the top (rows 0-3) and bottom (rows 4-7) 64-bit halves so
  // that each output packet holds one full column.
  kernel.packet[0] = _mm_unpacklo_epi64(top_c01, bot_c01);
  kernel.packet[1] = _mm_unpackhi_epi64(top_c01, bot_c01);
  kernel.packet[2] = _mm_unpacklo_epi64(top_c23, bot_c23);
  kernel.packet[3] = _mm_unpackhi_epi64(top_c23, bot_c23);
  kernel.packet[4] = _mm_unpacklo_epi64(top_c45, bot_c45);
  kernel.packet[5] = _mm_unpackhi_epi64(top_c45, bot_c45);
  kernel.packet[6] = _mm_unpacklo_epi64(top_c67, bot_c67);
  kernel.packet[7] = _mm_unpackhi_epi64(top_c67, bot_c67);
}
|
||||||
|
|
||||||
|
// In-place rearrangement of a 4-packet block using two interleave stages
// (16-bit, then 32-bit). Each output packet holds two consecutive columns of
// the four input rows.
EIGEN_STRONG_INLINE void
ptranspose(PacketBlock<Packet8bf,4>& kernel) {
  const __m128i r0 = kernel.packet[0];
  const __m128i r1 = kernel.packet[1];
  const __m128i r2 = kernel.packet[2];
  const __m128i r3 = kernel.packet[3];

  // Stage 1: interleave 16-bit elements of neighbouring rows.
  // *_lo covers columns 0-3, *_hi covers columns 4-7.
  const __m128i r01_lo = _mm_unpacklo_epi16(r0, r1);
  const __m128i r23_lo = _mm_unpacklo_epi16(r2, r3);
  const __m128i r01_hi = _mm_unpackhi_epi16(r0, r1);
  const __m128i r23_hi = _mm_unpackhi_epi16(r2, r3);

  // Stage 2: interleave 32-bit pairs across the two row groups.
  kernel.packet[0] = _mm_unpacklo_epi32(r01_lo, r23_lo);
  kernel.packet[1] = _mm_unpackhi_epi32(r01_lo, r23_lo);
  kernel.packet[2] = _mm_unpacklo_epi32(r01_hi, r23_hi);
  kernel.packet[3] = _mm_unpackhi_epi32(r01_hi, r23_hi);
}
|
||||||
|
|
||||||
} // end namespace internal
|
} // end namespace internal
|
||||||
|
|
||||||
} // end namespace Eigen
|
} // end namespace Eigen
|
||||||
|
@ -76,12 +76,38 @@ struct type_casting_traits<float, Eigen::half> {
|
|||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct type_casting_traits<bfloat16, float> {
|
||||||
|
enum {
|
||||||
|
VectorizedCast = 1,
|
||||||
|
SrcCoeffRatio = 1,
|
||||||
|
TgtCoeffRatio = 1
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
// Packet cast bfloat16 -> float; forwards to Bf16ToF32.
template<> EIGEN_STRONG_INLINE Packet8f pcast<Packet8bf, Packet8f>(const Packet8bf& a) {
  const Packet8f widened = Bf16ToF32(a);
  return widened;
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct type_casting_traits<float, bfloat16> {
|
||||||
|
enum {
|
||||||
|
VectorizedCast = 1,
|
||||||
|
SrcCoeffRatio = 1,
|
||||||
|
TgtCoeffRatio = 1
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
#endif // EIGEN_VECTORIZE_AVX512
|
#endif // EIGEN_VECTORIZE_AVX512
|
||||||
|
|
||||||
template<> EIGEN_STRONG_INLINE Packet8h pcast<Packet8f, Packet8h>(const Packet8f& a) {
|
template<> EIGEN_STRONG_INLINE Packet8h pcast<Packet8f, Packet8h>(const Packet8f& a) {
|
||||||
return float2half(a);
|
return float2half(a);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Packet cast float -> bfloat16; forwards to F32ToBf16.
template<> EIGEN_STRONG_INLINE Packet8bf pcast<Packet8f, Packet8bf>(const Packet8f& a) {
  const Packet8bf narrowed = F32ToBf16(a);
  return narrowed;
}
|
||||||
|
|
||||||
} // end namespace internal
|
} // end namespace internal
|
||||||
|
|
||||||
} // end namespace Eigen
|
} // end namespace Eigen
|
||||||
|
@ -135,10 +135,7 @@ plog<Packet16f>(const Packet16f& _x) {
|
|||||||
p16f_minus_inf);
|
p16f_minus_inf);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, plog)
|
||||||
EIGEN_STRONG_INLINE Packet16bf plog<Packet16bf>(const Packet16bf& _x) {
|
|
||||||
return F32ToBf16(plog<Packet16f>(Bf16ToF32(_x)));
|
|
||||||
}
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
// Exponential function. Works by writing "x = m*log(2) + r" where
|
// Exponential function. Works by writing "x = m*log(2) + r" where
|
||||||
@ -264,10 +261,7 @@ pexp<Packet8d>(const Packet8d& _x) {
|
|||||||
return pmax(pmul(x, e), _x);
|
return pmax(pmul(x, e), _x);
|
||||||
}*/
|
}*/
|
||||||
|
|
||||||
template <>
|
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pexp)
|
||||||
EIGEN_STRONG_INLINE Packet16bf pexp<Packet16bf>(const Packet16bf& _x) {
|
|
||||||
return F32ToBf16(pexp<Packet16f>(Bf16ToF32(_x)));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Functions for sqrt.
|
// Functions for sqrt.
|
||||||
// The EIGEN_FAST_MATH version uses the _mm_rsqrt_ps approximation and one step
|
// The EIGEN_FAST_MATH version uses the _mm_rsqrt_ps approximation and one step
|
||||||
@ -325,10 +319,7 @@ EIGEN_STRONG_INLINE Packet8d psqrt<Packet8d>(const Packet8d& x) {
|
|||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
template <>
|
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, psqrt)
|
||||||
EIGEN_STRONG_INLINE Packet16bf psqrt<Packet16bf>(const Packet16bf& x) {
|
|
||||||
return F32ToBf16(psqrt<Packet16f>(Bf16ToF32(x)));
|
|
||||||
}
|
|
||||||
|
|
||||||
// prsqrt for float.
|
// prsqrt for float.
|
||||||
#if defined(EIGEN_VECTORIZE_AVX512ER)
|
#if defined(EIGEN_VECTORIZE_AVX512ER)
|
||||||
@ -377,10 +368,7 @@ EIGEN_STRONG_INLINE Packet16f prsqrt<Packet16f>(const Packet16f& x) {
|
|||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
template <>
|
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, prsqrt)
|
||||||
EIGEN_STRONG_INLINE Packet16bf prsqrt<Packet16bf>(const Packet16bf& x) {
|
|
||||||
return F32ToBf16(prsqrt<Packet16f>(Bf16ToF32(x)));
|
|
||||||
}
|
|
||||||
|
|
||||||
// prsqrt for double.
|
// prsqrt for double.
|
||||||
#if EIGEN_FAST_MATH
|
#if EIGEN_FAST_MATH
|
||||||
@ -435,20 +423,14 @@ Packet16f plog1p<Packet16f>(const Packet16f& _x) {
|
|||||||
return generic_plog1p(_x);
|
return generic_plog1p(_x);
|
||||||
}
|
}
|
||||||
|
|
||||||
template<>
|
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, plog1p)
|
||||||
EIGEN_STRONG_INLINE Packet16bf plog1p<Packet16bf>(const Packet16bf& _x) {
|
|
||||||
return F32ToBf16(plog1p<Packet16f>(Bf16ToF32(_x)));
|
|
||||||
}
|
|
||||||
|
|
||||||
template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
|
template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
|
||||||
Packet16f pexpm1<Packet16f>(const Packet16f& _x) {
|
Packet16f pexpm1<Packet16f>(const Packet16f& _x) {
|
||||||
return generic_expm1(_x);
|
return generic_expm1(_x);
|
||||||
}
|
}
|
||||||
|
|
||||||
template<>
|
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pexpm1)
|
||||||
EIGEN_STRONG_INLINE Packet16bf pexpm1<Packet16bf>(const Packet16bf& _x) {
|
|
||||||
return F32ToBf16(pexpm1<Packet16f>(Bf16ToF32(_x)));
|
|
||||||
}
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
@ -460,32 +442,21 @@ psin<Packet16f>(const Packet16f& _x) {
|
|||||||
return psin_float(_x);
|
return psin_float(_x);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
|
||||||
EIGEN_STRONG_INLINE Packet16bf psin<Packet16bf>(const Packet16bf& _x) {
|
|
||||||
return F32ToBf16(psin<Packet16f>(Bf16ToF32(_x)));
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet16f
|
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet16f
|
||||||
pcos<Packet16f>(const Packet16f& _x) {
|
pcos<Packet16f>(const Packet16f& _x) {
|
||||||
return pcos_float(_x);
|
return pcos_float(_x);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
|
||||||
EIGEN_STRONG_INLINE Packet16bf pcos<Packet16bf>(const Packet16bf& _x) {
|
|
||||||
return F32ToBf16(pcos<Packet16f>(Bf16ToF32(_x)));
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet16f
|
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet16f
|
||||||
ptanh<Packet16f>(const Packet16f& _x) {
|
ptanh<Packet16f>(const Packet16f& _x) {
|
||||||
return internal::generic_fast_tanh_float(_x);
|
return internal::generic_fast_tanh_float(_x);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, psin)
|
||||||
EIGEN_STRONG_INLINE Packet16bf ptanh<Packet16bf>(const Packet16bf& _x) {
|
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pcos)
|
||||||
return F32ToBf16(ptanh<Packet16f>(Bf16ToF32(_x)));
|
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, ptanh)
|
||||||
}
|
|
||||||
|
|
||||||
} // end namespace internal
|
} // end namespace internal
|
||||||
|
|
||||||
|
@ -1633,13 +1633,13 @@ template <>
|
|||||||
struct packet_traits<bfloat16> : default_packet_traits {
|
struct packet_traits<bfloat16> : default_packet_traits {
|
||||||
typedef Packet16bf type;
|
typedef Packet16bf type;
|
||||||
// There is no half-size packet for current Packet16bf.
|
// There is no half-size packet for current Packet16bf.
|
||||||
// TODO: support as SSE/AVX path.
|
// TODO: support as SSE path.
|
||||||
typedef Packet16bf half;
|
typedef Packet8bf half;
|
||||||
enum {
|
enum {
|
||||||
Vectorizable = 1,
|
Vectorizable = 1,
|
||||||
AlignedOnScalar = 1,
|
AlignedOnScalar = 1,
|
||||||
size = 16,
|
size = 16,
|
||||||
HasHalfPacket = 0,
|
HasHalfPacket = 1,
|
||||||
HasBlend = 0,
|
HasBlend = 0,
|
||||||
HasInsert = 1,
|
HasInsert = 1,
|
||||||
HasSin = EIGEN_FAST_MATH,
|
HasSin = EIGEN_FAST_MATH,
|
||||||
@ -1668,7 +1668,7 @@ struct unpacket_traits<Packet16bf>
|
|||||||
{
|
{
|
||||||
typedef bfloat16 type;
|
typedef bfloat16 type;
|
||||||
enum {size=16, alignment=Aligned32, vectorizable=true, masked_load_available=false, masked_store_available=false};
|
enum {size=16, alignment=Aligned32, vectorizable=true, masked_load_available=false, masked_store_available=false};
|
||||||
typedef Packet16bf half;
|
typedef Packet8bf half;
|
||||||
};
|
};
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
@ -1741,13 +1741,17 @@ EIGEN_STRONG_INLINE Packet16f Bf16ToF32(const Packet16bf& a) {
|
|||||||
return _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(a.i), 16));
|
return _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(a.i), 16));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Convert float to bfloat16 according to round-to-even/denormals alogrithm.
|
// Convert float to bfloat16 according to round-to-nearest-even/denormals algorithm.
|
||||||
EIGEN_STRONG_INLINE Packet16bf F32ToBf16(const Packet16f& a) {
|
EIGEN_STRONG_INLINE Packet16bf F32ToBf16(const Packet16f& a) {
|
||||||
Packet16bf r;
|
Packet16bf r;
|
||||||
|
|
||||||
// Flush input denormals value to zero with hardware capability.
|
// Flush input denormals value to zero with hardware capability.
|
||||||
_MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_ON);
|
_MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_ON);
|
||||||
|
#if defined(EIGEN_VECTORIZE_AVX512DQ)
|
||||||
__m512 flush = _mm512_and_ps(a, a);
|
__m512 flush = _mm512_and_ps(a, a);
|
||||||
|
#else
|
||||||
|
__m512 flush = _mm512_max_ps(a, a);
|
||||||
|
#endif // EIGEN_VECTORIZE_AVX512DQ
|
||||||
_MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_OFF);
|
_MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_OFF);
|
||||||
|
|
||||||
#if defined(EIGEN_VECTORIZE_AVX512BF16)
|
#if defined(EIGEN_VECTORIZE_AVX512BF16)
|
||||||
@ -1772,7 +1776,7 @@ EIGEN_STRONG_INLINE Packet16bf F32ToBf16(const Packet16f& a) {
|
|||||||
|
|
||||||
// output.value = static_cast<uint16_t>(input);
|
// output.value = static_cast<uint16_t>(input);
|
||||||
r.i = _mm512_cvtepi32_epi16(t);
|
r.i = _mm512_cvtepi32_epi16(t);
|
||||||
#endif
|
#endif // EIGEN_VECTORIZE_AVX512BF16
|
||||||
|
|
||||||
return r;
|
return r;
|
||||||
}
|
}
|
||||||
@ -1911,6 +1915,13 @@ EIGEN_STRONG_INLINE Packet16bf pmax<Packet16bf>(const Packet16bf& a,
|
|||||||
return F32ToBf16(pmax<Packet16f>(Bf16ToF32(a), Bf16ToF32(b)));
|
return F32ToBf16(pmax<Packet16f>(Bf16ToF32(a), Bf16ToF32(b)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
EIGEN_STRONG_INLINE Packet8bf predux_half_dowto4<Packet16bf>(const Packet16bf& a) {
|
||||||
|
Packet8bf lane0 = _mm256_extractf128_si256(a.i, 0);
|
||||||
|
Packet8bf lane1 = _mm256_extractf128_si256(a.i, 1);
|
||||||
|
return padd<Packet8bf>(lane0, lane1);
|
||||||
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
EIGEN_STRONG_INLINE bfloat16 predux<Packet16bf>(const Packet16bf& p) {
|
EIGEN_STRONG_INLINE bfloat16 predux<Packet16bf>(const Packet16bf& p) {
|
||||||
return static_cast<bfloat16>(predux<Packet16f>(Bf16ToF32(p)));
|
return static_cast<bfloat16>(predux<Packet16f>(Bf16ToF32(p)));
|
||||||
@ -1940,7 +1951,7 @@ EIGEN_STRONG_INLINE Packet16bf preverse(const Packet16bf& a) {
|
|||||||
// Swap hi and lo first because shuffle is in 128-bit lanes.
|
// Swap hi and lo first because shuffle is in 128-bit lanes.
|
||||||
res.i = _mm256_permute2x128_si256(a.i, a.i, 1);
|
res.i = _mm256_permute2x128_si256(a.i, a.i, 1);
|
||||||
// Shuffle 8-bit values in src within 2*128-bit lanes.
|
// Shuffle 8-bit values in src within 2*128-bit lanes.
|
||||||
res.i = _mm256_shuffle_epi8(a.i, m);
|
res.i = _mm256_shuffle_epi8(res.i, m);
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2052,38 +2063,22 @@ EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16bf,16>& kernel) {
|
|||||||
__m256i ijklmnop_ef = _mm256_unpackhi_epi64(ijkl_cf, mnop_cf);
|
__m256i ijklmnop_ef = _mm256_unpackhi_epi64(ijkl_cf, mnop_cf);
|
||||||
|
|
||||||
// NOTE: no unpacklo/hi instr in this case, so using permute instr.
|
// NOTE: no unpacklo/hi instr in this case, so using permute instr.
|
||||||
kernel.packet[0].i = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01,
|
kernel.packet[0].i = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01, 0x20);
|
||||||
0x20);
|
kernel.packet[1].i = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23, 0x20);
|
||||||
kernel.packet[1].i = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23,
|
kernel.packet[2].i = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45, 0x20);
|
||||||
0x20);
|
kernel.packet[3].i = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67, 0x20);
|
||||||
kernel.packet[2].i = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45,
|
kernel.packet[4].i = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89, 0x20);
|
||||||
0x20);
|
kernel.packet[5].i = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab, 0x20);
|
||||||
kernel.packet[3].i = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67,
|
kernel.packet[6].i = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd, 0x20);
|
||||||
0x20);
|
kernel.packet[7].i = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef, 0x20);
|
||||||
kernel.packet[4].i = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89,
|
kernel.packet[8].i = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01, 0x31);
|
||||||
0x20);
|
kernel.packet[9].i = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23, 0x31);
|
||||||
kernel.packet[5].i = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab,
|
kernel.packet[10].i = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45, 0x31);
|
||||||
0x20);
|
kernel.packet[11].i = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67, 0x31);
|
||||||
kernel.packet[6].i = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd,
|
kernel.packet[12].i = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89, 0x31);
|
||||||
0x20);
|
kernel.packet[13].i = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab, 0x31);
|
||||||
kernel.packet[7].i = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef,
|
kernel.packet[14].i = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd, 0x31);
|
||||||
0x20);
|
kernel.packet[15].i = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef, 0x31);
|
||||||
kernel.packet[8].i = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01,
|
|
||||||
0x20);
|
|
||||||
kernel.packet[9].i = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23,
|
|
||||||
0x20);
|
|
||||||
kernel.packet[10].i = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45,
|
|
||||||
0x20);
|
|
||||||
kernel.packet[11].i = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67,
|
|
||||||
0x20);
|
|
||||||
kernel.packet[12].i = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89,
|
|
||||||
0x20);
|
|
||||||
kernel.packet[13].i = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab,
|
|
||||||
0x20);
|
|
||||||
kernel.packet[14].i = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd,
|
|
||||||
0x20);
|
|
||||||
kernel.packet[15].i = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef,
|
|
||||||
0x20);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16bf,4>& kernel) {
|
EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16bf,4>& kernel) {
|
||||||
|
@ -23,6 +23,13 @@ limitations under the License.
|
|||||||
#define EIGEN_EXPLICIT_CAST(tgt_type) operator tgt_type()
|
#define EIGEN_EXPLICIT_CAST(tgt_type) operator tgt_type()
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#define BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, METHOD) \
|
||||||
|
template <> \
|
||||||
|
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED \
|
||||||
|
PACKET_BF16 METHOD<PACKET_BF16>(const PACKET_BF16& _x) { \
|
||||||
|
return F32ToBf16(METHOD<PACKET_F>(Bf16ToF32(_x))); \
|
||||||
|
}
|
||||||
|
|
||||||
namespace Eigen {
|
namespace Eigen {
|
||||||
|
|
||||||
struct bfloat16;
|
struct bfloat16;
|
||||||
|
@ -310,6 +310,12 @@ macro(ei_testing_print_summary)
|
|||||||
message(STATUS "AVX: Using architecture defaults")
|
message(STATUS "AVX: Using architecture defaults")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
if(EIGEN_TEST_AVX2)
|
||||||
|
message(STATUS "AVX2: ON")
|
||||||
|
else()
|
||||||
|
message(STATUS "AVX2: Using architecture defaults")
|
||||||
|
endif()
|
||||||
|
|
||||||
if(EIGEN_TEST_FMA)
|
if(EIGEN_TEST_FMA)
|
||||||
message(STATUS "FMA: ON")
|
message(STATUS "FMA: ON")
|
||||||
else()
|
else()
|
||||||
@ -322,6 +328,12 @@ macro(ei_testing_print_summary)
|
|||||||
message(STATUS "AVX512: Using architecture defaults")
|
message(STATUS "AVX512: Using architecture defaults")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
if(EIGEN_TEST_AVX512DQ)
|
||||||
|
message(STATUS "AVX512DQ: ON")
|
||||||
|
else()
|
||||||
|
message(STATUS "AVX512DQ: Using architecture defaults")
|
||||||
|
endif()
|
||||||
|
|
||||||
if(EIGEN_TEST_ALTIVEC)
|
if(EIGEN_TEST_ALTIVEC)
|
||||||
message(STATUS "Altivec: ON")
|
message(STATUS "Altivec: ON")
|
||||||
else()
|
else()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user