Support BFloat16 in Eigen

This commit is contained in:
Teng Lu 2020-06-20 19:16:24 +00:00 committed by Rasmus Munk Larsen
parent 6b9c92fe7e
commit 386d809bde
19 changed files with 1893 additions and 14 deletions

View File

@ -241,13 +241,22 @@ if(NOT MSVC)
option(EIGEN_TEST_AVX512 "Enable/Disable AVX512 in tests/examples" OFF)
if(EIGEN_TEST_AVX512)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx512f -mfma -DEIGEN_ENABLE_AVX512")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx512f -mfma")
if (NOT "${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fabi-version=6")
endif()
message(STATUS "Enabling AVX512 in tests/examples")
endif()
option(EIGEN_TEST_AVX512DQ "Enable/Disable AVX512DQ in tests/examples" OFF)
if(EIGEN_TEST_AVX512DQ)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx512dq")
if (NOT "${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fabi-version=6")
endif()
message(STATUS "Enabling AVX512DQ in tests/examples")
endif()
option(EIGEN_TEST_F16C "Enable/Disable F16C in tests/examples" OFF)
if(EIGEN_TEST_F16C)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mf16c")

View File

@ -51,6 +51,10 @@
#define EIGEN_HAS_GPU_FP16
#endif
#if defined(EIGEN_HAS_CUDA_BF16) || defined(EIGEN_HAS_HIP_BF16)
#define EIGEN_HAS_GPU_BF16
#endif
#if (defined _OPENMP) && (!defined EIGEN_DONT_PARALLELIZE)
#define EIGEN_HAS_OPENMP
#endif
@ -163,6 +167,7 @@ using std::ptrdiff_t;
#include "src/Core/arch/Default/ConjHelper.h"
// Generic half float support
#include "src/Core/arch/Default/Half.h"
#include "src/Core/arch/Default/BFloat16.h"
#include "src/Core/arch/Default/TypeCasting.h"
#include "src/Core/arch/Default/GenericPacketMathFunctionsFwd.h"

View File

@ -29,6 +29,12 @@ namespace internal {
#define _EIGEN_DECLARE_CONST_Packet8d_FROM_INT64(NAME, X) \
const Packet8d p8d_##NAME = _mm512_castsi512_pd(_mm512_set1_epi64(X))
#define _EIGEN_DECLARE_CONST_Packet16bf(NAME, X) \
const Packet16bf p16bf_##NAME = pset1<Packet16bf>(X)
#define _EIGEN_DECLARE_CONST_Packet16bf_FROM_INT(NAME, X) \
const Packet16bf p16bf_##NAME = preinterpret<Packet16bf,Packet16i>(pset1<Packet16i>(X))
// Natural logarithm
// Computes log(x) as log(2^e * m) = C*e + log(m), where the constant C =log(2)
// and m is in the range [sqrt(1/2),sqrt(2)). In this range, the logarithm can
@ -128,6 +134,11 @@ plog<Packet16f>(const Packet16f& _x) {
p16f_nan),
p16f_minus_inf);
}
template <>
EIGEN_STRONG_INLINE Packet16bf plog<Packet16bf>(const Packet16bf& _x) {
return F32ToBf16(plog<Packet16f>(Bf16ToF32(_x)));
}
#endif
// Exponential function. Works by writing "x = m*log(2) + r" where
@ -253,6 +264,10 @@ pexp<Packet8d>(const Packet8d& _x) {
return pmax(pmul(x, e), _x);
}*/
template <>
EIGEN_STRONG_INLINE Packet16bf pexp<Packet16bf>(const Packet16bf& _x) {
return F32ToBf16(pexp<Packet16f>(Bf16ToF32(_x)));
}
// Functions for sqrt.
// The EIGEN_FAST_MATH version uses the _mm_rsqrt_ps approximation and one step
@ -303,12 +318,18 @@ template <>
EIGEN_STRONG_INLINE Packet16f psqrt<Packet16f>(const Packet16f& x) {
return _mm512_sqrt_ps(x);
}
template <>
EIGEN_STRONG_INLINE Packet8d psqrt<Packet8d>(const Packet8d& x) {
return _mm512_sqrt_pd(x);
}
#endif
template <>
EIGEN_STRONG_INLINE Packet16bf psqrt<Packet16bf>(const Packet16bf& x) {
return F32ToBf16(psqrt<Packet16f>(Bf16ToF32(x)));
}
// prsqrt for float.
#if defined(EIGEN_VECTORIZE_AVX512ER)
@ -316,7 +337,6 @@ template <>
EIGEN_STRONG_INLINE Packet16f prsqrt<Packet16f>(const Packet16f& x) {
return _mm512_rsqrt28_ps(x);
}
#elif EIGEN_FAST_MATH
template <>
@ -347,8 +367,7 @@ prsqrt<Packet16f>(const Packet16f& _x) {
// For other arguments, choose the output of the intrinsic. This will
// return rsqrt(+inf) = 0, rsqrt(x) = NaN if x < 0, and rsqrt(0) = +inf.
return _mm512_mask_blend_ps(not_finite_pos_mask, y_newton, y_approx);
}
}
#else
template <>
@ -356,9 +375,13 @@ EIGEN_STRONG_INLINE Packet16f prsqrt<Packet16f>(const Packet16f& x) {
_EIGEN_DECLARE_CONST_Packet16f(one, 1.0f);
return _mm512_div_ps(p16f_one, _mm512_sqrt_ps(x));
}
#endif
template <>
EIGEN_STRONG_INLINE Packet16bf prsqrt<Packet16bf>(const Packet16bf& x) {
return F32ToBf16(prsqrt<Packet16f>(Bf16ToF32(x)));
}
// prsqrt for double.
#if EIGEN_FAST_MATH
template <>
@ -412,10 +435,20 @@ Packet16f plog1p<Packet16f>(const Packet16f& _x) {
return generic_plog1p(_x);
}
template<>
EIGEN_STRONG_INLINE Packet16bf plog1p<Packet16bf>(const Packet16bf& _x) {
return F32ToBf16(plog1p<Packet16f>(Bf16ToF32(_x)));
}
template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
Packet16f pexpm1<Packet16f>(const Packet16f& _x) {
return generic_expm1(_x);
}
template<>
EIGEN_STRONG_INLINE Packet16bf pexpm1<Packet16bf>(const Packet16bf& _x) {
return F32ToBf16(pexpm1<Packet16f>(Bf16ToF32(_x)));
}
#endif
#endif
@ -427,18 +460,33 @@ psin<Packet16f>(const Packet16f& _x) {
return psin_float(_x);
}
template <>
EIGEN_STRONG_INLINE Packet16bf psin<Packet16bf>(const Packet16bf& _x) {
return F32ToBf16(psin<Packet16f>(Bf16ToF32(_x)));
}
template <>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet16f
pcos<Packet16f>(const Packet16f& _x) {
return pcos_float(_x);
}
template <>
EIGEN_STRONG_INLINE Packet16bf pcos<Packet16bf>(const Packet16bf& _x) {
return F32ToBf16(pcos<Packet16f>(Bf16ToF32(_x)));
}
template <>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet16f
ptanh<Packet16f>(const Packet16f& _x) {
return internal::generic_fast_tanh_float(_x);
}
template <>
EIGEN_STRONG_INLINE Packet16bf ptanh<Packet16bf>(const Packet16bf& _x) {
return F32ToBf16(ptanh<Packet16f>(Bf16ToF32(_x)));
}
} // end namespace internal
} // end namespace Eigen

View File

@ -362,6 +362,25 @@ EIGEN_STRONG_INLINE Packet16f cat256(Packet8f a, Packet8f b) {
}
#endif
// Helper function for bit packing snippet of low precision comparison.
// It packs the flags from 32x16 to 16x16.
EIGEN_STRONG_INLINE __m256i Pack32To16(Packet16f rf) {
// Split data into small pieces and handle with AVX instructions
// to guarantee internal order of vector.
// Operation:
// dst[15:0] := Saturate16(rf[31:0])
// dst[31:16] := Saturate16(rf[63:32])
// ...
// dst[255:240] := Saturate16(rf[255:224])
__m256i lo = _mm256_castps_si256(extract256<0>(rf));
__m256i hi = _mm256_castps_si256(extract256<1>(rf));
__m128i result_lo = _mm_packs_epi32(_mm256_extractf128_si256(lo, 0),
_mm256_extractf128_si256(lo, 1));
__m128i result_hi = _mm_packs_epi32(_mm256_extractf128_si256(hi, 0),
_mm256_extractf128_si256(hi, 1));
return _mm256_insertf128_si256(_mm256_castsi128_si256(result_lo), result_hi, 1);
}
template <>
EIGEN_STRONG_INLINE Packet16f pcmp_eq(const Packet16f& a, const Packet16f& b) {
__mmask16 mask = _mm512_cmp_ps_mask(a, b, _CMP_EQ_OQ);
@ -1342,15 +1361,7 @@ template<> EIGEN_STRONG_INLINE Packet16h pselect(const Packet16h& mask, const Pa
template<> EIGEN_STRONG_INLINE Packet16h pcmp_eq(const Packet16h& a,const Packet16h& b) {
Packet16f af = half2float(a);
Packet16f bf = half2float(b);
Packet16f rf = pcmp_eq(af, bf);
// Pack the 32-bit flags into 16-bits flags.
__m256i lo = _mm256_castps_si256(extract256<0>(rf));
__m256i hi = _mm256_castps_si256(extract256<1>(rf));
__m128i result_lo = _mm_packs_epi32(_mm256_extractf128_si256(lo, 0),
_mm256_extractf128_si256(lo, 1));
__m128i result_hi = _mm_packs_epi32(_mm256_extractf128_si256(hi, 0),
_mm256_extractf128_si256(hi, 1));
return _mm256_insertf128_si256(_mm256_castsi128_si256(result_lo), result_hi, 1);
return Pack32To16(pcmp_eq(af, bf));
}
template<> EIGEN_STRONG_INLINE Packet16h pnegate(const Packet16h& a) {
@ -1607,6 +1618,493 @@ ptranspose(PacketBlock<Packet16h,4>& kernel) {
kernel.packet[3] = pload<Packet16h>(out[3]);
}
typedef union {
#ifdef EIGEN_VECTORIZE_AVX512BF16
__m256bh bh;
#endif
Packet8i i; // __m256i;
} Packet16bf;
template <> struct is_arithmetic<Packet16bf> { enum { value = true }; };
template <>
struct packet_traits<bfloat16> : default_packet_traits {
typedef Packet16bf type;
// There is no half-size packet for current Packet16bf.
// TODO: support as SSE/AVX path.
typedef Packet16bf half;
enum {
Vectorizable = 1,
AlignedOnScalar = 1,
size = 16,
HasHalfPacket = 0,
HasBlend = 0,
HasInsert = 1,
HasSin = EIGEN_FAST_MATH,
HasCos = EIGEN_FAST_MATH,
#if EIGEN_GNUC_AT_LEAST(5, 3) || (!EIGEN_COMP_GNUC_STRICT)
#ifdef EIGEN_VECTORIZE_AVX512DQ
HasLog = 1,
HasLog1p = 1,
HasExpm1 = 1,
HasNdtri = 1,
HasBessel = 1,
#endif
HasExp = 1,
HasSqrt = EIGEN_FAST_MATH,
HasRsqrt = EIGEN_FAST_MATH,
HasTanh = EIGEN_FAST_MATH,
HasErf = EIGEN_FAST_MATH,
#endif
HasDiv = 1
};
};
template <>
struct unpacket_traits<Packet16bf>
{
typedef bfloat16 type;
enum {size=16, alignment=Aligned32, vectorizable=true, masked_load_available=false, masked_store_available=false};
typedef Packet16bf half;
};
template <>
EIGEN_STRONG_INLINE Packet16bf pset1<Packet16bf>(const bfloat16& from) {
Packet16bf r;
r.i = _mm256_set1_epi16(from.value);
return r;
}
template <>
EIGEN_STRONG_INLINE bfloat16 pfirst<Packet16bf>(const Packet16bf& from) {
bfloat16 t;
t.value = static_cast<unsigned short>(_mm256_extract_epi16(from.i, 0));
return t;
}
template <>
EIGEN_STRONG_INLINE Packet16bf pload<Packet16bf>(const bfloat16* from) {
Packet16bf r;
r.i = _mm256_load_si256(reinterpret_cast<const __m256i*>(from));
return r;
}
template <>
EIGEN_STRONG_INLINE Packet16bf ploadu<Packet16bf>(const bfloat16* from) {
Packet16bf r;
r.i = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(from));
return r;
}
template <>
EIGEN_STRONG_INLINE void pstore<bfloat16>(bfloat16* to,
const Packet16bf& from) {
_mm256_store_si256(reinterpret_cast<__m256i*>(to), from.i);
}
template <>
EIGEN_STRONG_INLINE void pstoreu<bfloat16>(bfloat16* to,
const Packet16bf& from) {
_mm256_storeu_si256(reinterpret_cast<__m256i*>(to), from.i);
}
template<> EIGEN_STRONG_INLINE Packet16bf
ploaddup<Packet16bf>(const bfloat16* from) {
Packet16bf r;
unsigned short a = from[0].value;
unsigned short b = from[1].value;
unsigned short c = from[2].value;
unsigned short d = from[3].value;
unsigned short e = from[4].value;
unsigned short f = from[5].value;
unsigned short g = from[6].value;
unsigned short h = from[7].value;
r.i = _mm256_set_epi16(h, h, g, g, f, f, e, e, d, d, c, c, b, b, a, a);
return r;
}
template<> EIGEN_STRONG_INLINE Packet16bf
ploadquad(const bfloat16* from) {
Packet16bf r;
unsigned short a = from[0].value;
unsigned short b = from[1].value;
unsigned short c = from[2].value;
unsigned short d = from[3].value;
r.i = _mm256_set_epi16(d, d, d, d, c, c, c, c, b, b, b, b, a, a, a, a);
return r;
}
EIGEN_STRONG_INLINE Packet16f Bf16ToF32(const Packet16bf& a) {
return _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(a.i), 16));
}
// Convert float to bfloat16 according to round-to-even/denormals alogrithm.
EIGEN_STRONG_INLINE Packet16bf F32ToBf16(const Packet16f& a) {
Packet16bf r;
// Flush input denormals value to zero with hardware capability.
_MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_ON);
__m512 flush = _mm512_and_ps(a, a);
_MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_OFF);
#if defined(EIGEN_VECTORIZE_AVX512BF16)
r.bh = _mm512_cvtneps_pbh(flush);
#else
__m512i t;
__m512i input = _mm512_castps_si512(flush);
__m512i nan = _mm512_set1_epi32(0x7fc0);
// uint32_t lsb = (input >> 16) & 1;
t = _mm512_and_si512(_mm512_srli_epi32(input, 16), _mm512_set1_epi32(1));
// uint32_t rounding_bias = 0x7fff + lsb;
t = _mm512_add_epi32(t, _mm512_set1_epi32(0x7fff));
// input += rounding_bias;
t = _mm512_add_epi32(t, input);
// input = input >> 16;
t = _mm512_srli_epi32(t, 16);
// Check NaN before converting back to bf16
__mmask16 mask = _mm512_cmp_ps_mask(flush, flush, _CMP_ORD_Q);
t = _mm512_mask_blend_epi32(mask, nan, t);
// output.value = static_cast<uint16_t>(input);
r.i = _mm512_cvtepi32_epi16(t);
#endif
return r;
}
template <>
EIGEN_STRONG_INLINE Packet16bf ptrue(const Packet16bf& a) {
Packet16bf r;
r.i = ptrue<Packet8i>(a.i);
return r;
}
template <>
EIGEN_STRONG_INLINE Packet16bf por(const Packet16bf& a, const Packet16bf& b) {
Packet16bf r;
r.i = por<Packet8i>(a.i, b.i);
return r;
}
template <>
EIGEN_STRONG_INLINE Packet16bf pxor(const Packet16bf& a, const Packet16bf& b) {
Packet16bf r;
r.i = pxor<Packet8i>(a.i, b.i);
return r;
}
template <>
EIGEN_STRONG_INLINE Packet16bf pand(const Packet16bf& a, const Packet16bf& b) {
Packet16bf r;
r.i = pand<Packet8i>(a.i, b.i);
return r;
}
template <>
EIGEN_STRONG_INLINE Packet16bf pandnot(const Packet16bf& a,
const Packet16bf& b) {
Packet16bf r;
r.i = pandnot<Packet8i>(a.i, b.i);
return r;
}
template <>
EIGEN_STRONG_INLINE Packet16bf pselect(const Packet16bf& mask,
const Packet16bf& a,
const Packet16bf& b) {
// Input mask is expected to be all 0/1, handle it with 8-bit
// intrinsic for performance.
Packet16bf r;
r.i = _mm256_blendv_epi8(b.i, a.i, mask.i);
return r;
}
template <>
EIGEN_STRONG_INLINE Packet16bf pcmp_eq(const Packet16bf& a,
const Packet16bf& b) {
Packet16bf result;
result.i = Pack32To16(pcmp_eq(Bf16ToF32(a), Bf16ToF32(b)));
return result;
}
template <>
EIGEN_STRONG_INLINE Packet16bf pcmp_le(const Packet16bf& a,
const Packet16bf& b) {
Packet16bf result;
result.i = Pack32To16(pcmp_le(Bf16ToF32(a), Bf16ToF32(b)));
return result;
}
template <>
EIGEN_STRONG_INLINE Packet16bf pcmp_lt(const Packet16bf& a,
const Packet16bf& b) {
Packet16bf result;
result.i = Pack32To16(pcmp_lt(Bf16ToF32(a), Bf16ToF32(b)));
return result;
}
template <>
EIGEN_STRONG_INLINE Packet16bf pcmp_lt_or_nan(const Packet16bf& a,
const Packet16bf& b) {
Packet16bf result;
result.i = Pack32To16(pcmp_lt_or_nan(Bf16ToF32(a), Bf16ToF32(b)));
return result;
}
template <>
EIGEN_STRONG_INLINE Packet16bf pnegate(const Packet16bf& a) {
Packet16bf sign_mask;
sign_mask.i = _mm256_set1_epi16(static_cast<unsigned short>(0x8000));
Packet16bf result;
result.i = _mm256_xor_si256(a.i, sign_mask.i);
return result;
}
template <>
EIGEN_STRONG_INLINE Packet16bf pconj(const Packet16bf& a) {
return a;
}
template <>
EIGEN_STRONG_INLINE Packet16bf pabs(const Packet16bf& a) {
return F32ToBf16(pabs<Packet16f>(Bf16ToF32(a)));
}
template <>
EIGEN_STRONG_INLINE Packet16bf padd<Packet16bf>(const Packet16bf& a,
const Packet16bf& b) {
return F32ToBf16(padd<Packet16f>(Bf16ToF32(a), Bf16ToF32(b)));
}
template <>
EIGEN_STRONG_INLINE Packet16bf psub<Packet16bf>(const Packet16bf& a,
const Packet16bf& b) {
return F32ToBf16(psub<Packet16f>(Bf16ToF32(a), Bf16ToF32(b)));
}
template <>
EIGEN_STRONG_INLINE Packet16bf pmul<Packet16bf>(const Packet16bf& a,
const Packet16bf& b) {
return F32ToBf16(pmul<Packet16f>(Bf16ToF32(a), Bf16ToF32(b)));
}
template <>
EIGEN_STRONG_INLINE Packet16bf pdiv<Packet16bf>(const Packet16bf& a,
const Packet16bf& b) {
return F32ToBf16(pdiv<Packet16f>(Bf16ToF32(a), Bf16ToF32(b)));
}
template <>
EIGEN_STRONG_INLINE Packet16bf pmin<Packet16bf>(const Packet16bf& a,
const Packet16bf& b) {
return F32ToBf16(pmin<Packet16f>(Bf16ToF32(a), Bf16ToF32(b)));
}
template <>
EIGEN_STRONG_INLINE Packet16bf pmax<Packet16bf>(const Packet16bf& a,
const Packet16bf& b) {
return F32ToBf16(pmax<Packet16f>(Bf16ToF32(a), Bf16ToF32(b)));
}
template <>
EIGEN_STRONG_INLINE bfloat16 predux<Packet16bf>(const Packet16bf& p) {
return static_cast<bfloat16>(predux<Packet16f>(Bf16ToF32(p)));
}
template <>
EIGEN_STRONG_INLINE bfloat16 predux_mul<Packet16bf>(const Packet16bf& from) {
return static_cast<bfloat16>(predux_mul<Packet16f>(Bf16ToF32(from)));
}
template <>
EIGEN_STRONG_INLINE bfloat16 predux_min<Packet16bf>(const Packet16bf& from) {
return static_cast<bfloat16>(predux_min<Packet16f>(Bf16ToF32(from)));
}
template <>
EIGEN_STRONG_INLINE bfloat16 predux_max<Packet16bf>(const Packet16bf& from) {
return static_cast<bfloat16>(predux_max<Packet16f>(Bf16ToF32(from)));
}
template <>
EIGEN_STRONG_INLINE Packet16bf preverse(const Packet16bf& a) {
__m256i m = _mm256_setr_epi8(14,15,12,13,10,11,8,9,6,7,4,5,2,3,0,1,
14,15,12,13,10,11,8,9,6,7,4,5,2,3,0,1);
Packet16bf res;
// Swap hi and lo first because shuffle is in 128-bit lanes.
res.i = _mm256_permute2x128_si256(a.i, a.i, 1);
// Shuffle 8-bit values in src within 2*128-bit lanes.
res.i = _mm256_shuffle_epi8(a.i, m);
return res;
}
template <>
EIGEN_STRONG_INLINE Packet16bf pgather<bfloat16, Packet16bf>(const bfloat16* from,
Index stride) {
Packet16bf result;
result.i = _mm256_set_epi16(
from[15*stride].value, from[14*stride].value, from[13*stride].value, from[12*stride].value,
from[11*stride].value, from[10*stride].value, from[9*stride].value, from[8*stride].value,
from[7*stride].value, from[6*stride].value, from[5*stride].value, from[4*stride].value,
from[3*stride].value, from[2*stride].value, from[1*stride].value, from[0*stride].value);
return result;
}
template <>
EIGEN_STRONG_INLINE void pscatter<bfloat16, Packet16bf>(bfloat16* to,
const Packet16bf& from,
Index stride) {
EIGEN_ALIGN64 bfloat16 aux[16];
pstore(aux, from);
to[stride*0].value = aux[0].value;
to[stride*1].value = aux[1].value;
to[stride*2].value = aux[2].value;
to[stride*3].value = aux[3].value;
to[stride*4].value = aux[4].value;
to[stride*5].value = aux[5].value;
to[stride*6].value = aux[6].value;
to[stride*7].value = aux[7].value;
to[stride*8].value = aux[8].value;
to[stride*9].value = aux[9].value;
to[stride*10].value = aux[10].value;
to[stride*11].value = aux[11].value;
to[stride*12].value = aux[12].value;
to[stride*13].value = aux[13].value;
to[stride*14].value = aux[14].value;
to[stride*15].value = aux[15].value;
}
EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16bf,16>& kernel) {
__m256i a = kernel.packet[0].i;
__m256i b = kernel.packet[1].i;
__m256i c = kernel.packet[2].i;
__m256i d = kernel.packet[3].i;
__m256i e = kernel.packet[4].i;
__m256i f = kernel.packet[5].i;
__m256i g = kernel.packet[6].i;
__m256i h = kernel.packet[7].i;
__m256i i = kernel.packet[8].i;
__m256i j = kernel.packet[9].i;
__m256i k = kernel.packet[10].i;
__m256i l = kernel.packet[11].i;
__m256i m = kernel.packet[12].i;
__m256i n = kernel.packet[13].i;
__m256i o = kernel.packet[14].i;
__m256i p = kernel.packet[15].i;
__m256i ab_07 = _mm256_unpacklo_epi16(a, b);
__m256i cd_07 = _mm256_unpacklo_epi16(c, d);
__m256i ef_07 = _mm256_unpacklo_epi16(e, f);
__m256i gh_07 = _mm256_unpacklo_epi16(g, h);
__m256i ij_07 = _mm256_unpacklo_epi16(i, j);
__m256i kl_07 = _mm256_unpacklo_epi16(k, l);
__m256i mn_07 = _mm256_unpacklo_epi16(m, n);
__m256i op_07 = _mm256_unpacklo_epi16(o, p);
__m256i ab_8f = _mm256_unpackhi_epi16(a, b);
__m256i cd_8f = _mm256_unpackhi_epi16(c, d);
__m256i ef_8f = _mm256_unpackhi_epi16(e, f);
__m256i gh_8f = _mm256_unpackhi_epi16(g, h);
__m256i ij_8f = _mm256_unpackhi_epi16(i, j);
__m256i kl_8f = _mm256_unpackhi_epi16(k, l);
__m256i mn_8f = _mm256_unpackhi_epi16(m, n);
__m256i op_8f = _mm256_unpackhi_epi16(o, p);
__m256i abcd_03 = _mm256_unpacklo_epi32(ab_07, cd_07);
__m256i abcd_47 = _mm256_unpackhi_epi32(ab_07, cd_07);
__m256i efgh_03 = _mm256_unpacklo_epi32(ef_07, gh_07);
__m256i efgh_47 = _mm256_unpackhi_epi32(ef_07, gh_07);
__m256i ijkl_03 = _mm256_unpacklo_epi32(ij_07, kl_07);
__m256i ijkl_47 = _mm256_unpackhi_epi32(ij_07, kl_07);
__m256i mnop_03 = _mm256_unpacklo_epi32(mn_07, op_07);
__m256i mnop_47 = _mm256_unpackhi_epi32(mn_07, op_07);
__m256i abcd_8b = _mm256_unpacklo_epi32(ab_8f, cd_8f);
__m256i abcd_cf = _mm256_unpackhi_epi32(ab_8f, cd_8f);
__m256i efgh_8b = _mm256_unpacklo_epi32(ef_8f, gh_8f);
__m256i efgh_cf = _mm256_unpackhi_epi32(ef_8f, gh_8f);
__m256i ijkl_8b = _mm256_unpacklo_epi32(ij_8f, kl_8f);
__m256i ijkl_cf = _mm256_unpackhi_epi32(ij_8f, kl_8f);
__m256i mnop_8b = _mm256_unpacklo_epi32(mn_8f, op_8f);
__m256i mnop_cf = _mm256_unpackhi_epi32(mn_8f, op_8f);
__m256i abcdefgh_01 = _mm256_unpacklo_epi64(abcd_03, efgh_03);
__m256i abcdefgh_23 = _mm256_unpackhi_epi64(abcd_03, efgh_03);
__m256i ijklmnop_01 = _mm256_unpacklo_epi64(ijkl_03, mnop_03);
__m256i ijklmnop_23 = _mm256_unpackhi_epi64(ijkl_03, mnop_03);
__m256i abcdefgh_45 = _mm256_unpacklo_epi64(abcd_47, efgh_47);
__m256i abcdefgh_67 = _mm256_unpackhi_epi64(abcd_47, efgh_47);
__m256i ijklmnop_45 = _mm256_unpacklo_epi64(ijkl_47, mnop_47);
__m256i ijklmnop_67 = _mm256_unpackhi_epi64(ijkl_47, mnop_47);
__m256i abcdefgh_89 = _mm256_unpacklo_epi64(abcd_8b, efgh_8b);
__m256i abcdefgh_ab = _mm256_unpackhi_epi64(abcd_8b, efgh_8b);
__m256i ijklmnop_89 = _mm256_unpacklo_epi64(ijkl_8b, mnop_8b);
__m256i ijklmnop_ab = _mm256_unpackhi_epi64(ijkl_8b, mnop_8b);
__m256i abcdefgh_cd = _mm256_unpacklo_epi64(abcd_cf, efgh_cf);
__m256i abcdefgh_ef = _mm256_unpackhi_epi64(abcd_cf, efgh_cf);
__m256i ijklmnop_cd = _mm256_unpacklo_epi64(ijkl_cf, mnop_cf);
__m256i ijklmnop_ef = _mm256_unpackhi_epi64(ijkl_cf, mnop_cf);
// NOTE: no unpacklo/hi instr in this case, so using permute instr.
kernel.packet[0].i = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01,
0x20);
kernel.packet[1].i = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23,
0x20);
kernel.packet[2].i = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45,
0x20);
kernel.packet[3].i = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67,
0x20);
kernel.packet[4].i = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89,
0x20);
kernel.packet[5].i = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab,
0x20);
kernel.packet[6].i = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd,
0x20);
kernel.packet[7].i = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef,
0x20);
kernel.packet[8].i = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01,
0x20);
kernel.packet[9].i = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23,
0x20);
kernel.packet[10].i = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45,
0x20);
kernel.packet[11].i = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67,
0x20);
kernel.packet[12].i = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89,
0x20);
kernel.packet[13].i = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab,
0x20);
kernel.packet[14].i = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd,
0x20);
kernel.packet[15].i = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef,
0x20);
}
EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16bf,4>& kernel) {
__m256i a = kernel.packet[0].i;
__m256i b = kernel.packet[1].i;
__m256i c = kernel.packet[2].i;
__m256i d = kernel.packet[3].i;
__m256i ab_07 = _mm256_unpacklo_epi16(a, b);
__m256i cd_07 = _mm256_unpacklo_epi16(c, d);
__m256i ab_8f = _mm256_unpackhi_epi16(a, b);
__m256i cd_8f = _mm256_unpackhi_epi16(c, d);
__m256i abcd_03 = _mm256_unpacklo_epi32(ab_07, cd_07);
__m256i abcd_47 = _mm256_unpackhi_epi32(ab_07, cd_07);
__m256i abcd_8b = _mm256_unpacklo_epi32(ab_8f, cd_8f);
__m256i abcd_cf = _mm256_unpackhi_epi32(ab_8f, cd_8f);
// NOTE: no unpacklo/hi instr in this case, so using permute instr.
kernel.packet[0].i = _mm256_permute2x128_si256(abcd_03, abcd_47, 0x20);
kernel.packet[1].i = _mm256_permute2x128_si256(abcd_8b, abcd_cf, 0x20);
kernel.packet[2].i = _mm256_permute2x128_si256(abcd_03, abcd_47, 0x31);
kernel.packet[3].i = _mm256_permute2x128_si256(abcd_8b, abcd_cf, 0x31);
}
} // end namespace internal

View File

@ -40,6 +40,32 @@ template<> EIGEN_STRONG_INLINE Packet16h pcast<Packet16f, Packet16h>(const Packe
return float2half(a);
}
template <>
struct type_casting_traits<bfloat16, float> {
enum {
VectorizedCast = 1,
SrcCoeffRatio = 1,
TgtCoeffRatio = 1
};
};
template<> EIGEN_STRONG_INLINE Packet16f pcast<Packet16bf, Packet16f>(const Packet16bf& a) {
return Bf16ToF32(a);
}
template <>
struct type_casting_traits<float, bfloat16> {
enum {
VectorizedCast = 1,
SrcCoeffRatio = 1,
TgtCoeffRatio = 1
};
};
template<> EIGEN_STRONG_INLINE Packet16bf pcast<Packet16f, Packet16bf>(const Packet16f& a) {
return F32ToBf16(a);
}
} // end namespace internal
} // end namespace Eigen

View File

@ -0,0 +1,703 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef EIGEN_BFLOAT16_H
#define EIGEN_BFLOAT16_H
#if __cplusplus > 199711L
#define EIGEN_EXPLICIT_CAST(tgt_type) explicit operator tgt_type()
#else
#define EIGEN_EXPLICIT_CAST(tgt_type) operator tgt_type()
#endif
namespace Eigen {
struct bfloat16;
namespace bfloat16_impl {
// Make our own __bfloat16_raw definition.
struct __bfloat16_raw {
EIGEN_DEVICE_FUNC __bfloat16_raw() : value(0) {}
explicit EIGEN_DEVICE_FUNC __bfloat16_raw(unsigned short raw) : value(raw) {}
unsigned short value;
};
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw raw_uint16_to_bfloat16(unsigned short value);
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne(float ff);
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float bfloat16_to_float(__bfloat16_raw h);
struct bfloat16_base : public __bfloat16_raw {
EIGEN_DEVICE_FUNC bfloat16_base() {}
EIGEN_DEVICE_FUNC bfloat16_base(const __bfloat16_raw& h) : __bfloat16_raw(h) {}
};
} // namespace bfloat16_impl
// Class definition.
struct bfloat16 : public bfloat16_impl::bfloat16_base {
typedef bfloat16_impl::__bfloat16_raw __bfloat16_raw;
EIGEN_DEVICE_FUNC bfloat16() {}
EIGEN_DEVICE_FUNC bfloat16(const __bfloat16_raw& h) : bfloat16_impl::bfloat16_base(h) {}
explicit EIGEN_DEVICE_FUNC bfloat16(bool b)
: bfloat16_impl::bfloat16_base(bfloat16_impl::raw_uint16_to_bfloat16(b ? 0x3f80 : 0)) {}
template<class T>
explicit EIGEN_DEVICE_FUNC bfloat16(const T& val)
: bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne(static_cast<float>(val))) {}
explicit EIGEN_DEVICE_FUNC bfloat16(float f)
: bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne(f)) {}
// Following the convention of numpy, converting between complex and
// float will lead to loss of imag value.
// Single precision complex.
typedef std::complex<float> complex64;
// Double precision complex.
typedef std::complex<double> complex128;
explicit EIGEN_DEVICE_FUNC bfloat16(const complex64& val)
: bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne(val.real())) {}
explicit EIGEN_DEVICE_FUNC bfloat16(const complex128& val)
: bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne(static_cast<float>(val.real()))) {}
EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(bool) const {
// +0.0 and -0.0 become false, everything else becomes true.
return (value & 0x7fff) != 0;
}
EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(signed char) const {
return static_cast<signed char>(bfloat16_impl::bfloat16_to_float(*this));
}
EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(unsigned char) const {
return static_cast<unsigned char>(bfloat16_impl::bfloat16_to_float(*this));
}
EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(short) const {
return static_cast<short>(bfloat16_impl::bfloat16_to_float(*this));
}
EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(unsigned short) const {
return static_cast<unsigned short>(bfloat16_impl::bfloat16_to_float(*this));
}
EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(int) const {
return static_cast<int>(bfloat16_impl::bfloat16_to_float(*this));
}
EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(unsigned int) const {
return static_cast<unsigned int>(bfloat16_impl::bfloat16_to_float(*this));
}
EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(long) const {
return static_cast<long>(bfloat16_impl::bfloat16_to_float(*this));
}
EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(unsigned long) const {
return static_cast<unsigned long>(bfloat16_impl::bfloat16_to_float(*this));
}
EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(long long) const {
return static_cast<long long>(bfloat16_impl::bfloat16_to_float(*this));
}
EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(unsigned long long) const {
return static_cast<unsigned long long>(bfloat16_to_float(*this));
}
EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(float) const {
return bfloat16_impl::bfloat16_to_float(*this);
}
EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(double) const {
return static_cast<double>(bfloat16_impl::bfloat16_to_float(*this));
}
EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(complex64) const {
return complex64(bfloat16_impl::bfloat16_to_float(*this), float(0.0));
}
EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(complex128) const {
return complex128(static_cast<double>(bfloat16_impl::bfloat16_to_float(*this)), double(0.0));
}
EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(Eigen::half) const {
return static_cast<Eigen::half>(bfloat16_impl::bfloat16_to_float(*this));
}
};
} // end namespace Eigen
namespace std {
template<>
struct numeric_limits<Eigen::bfloat16> {
static const bool is_specialized = true;
static const bool is_signed = true;
static const bool is_integer = false;
static const bool is_exact = false;
static const bool has_infinity = true;
static const bool has_quiet_NaN = true;
static const bool has_signaling_NaN = true;
static const float_denorm_style has_denorm = numeric_limits<float>::has_denorm;
static const bool has_denorm_loss = numeric_limits<float>::has_denorm_loss;
static const std::float_round_style round_style = numeric_limits<float>::round_style;
static const bool is_iec559 = false;
static const bool is_bounded = true;
static const bool is_modulo = false;
static const int digits = 8;
static const int digits10 = 2;
static const int max_digits10 = 4;
static const int radix = 2;
static const int min_exponent = numeric_limits<float>::min_exponent;
static const int min_exponent10 = numeric_limits<float>::min_exponent10;
static const int max_exponent = numeric_limits<float>::max_exponent;
static const int max_exponent10 = numeric_limits<float>::max_exponent10;
static const bool traps = numeric_limits<float>::traps;
static const bool tinyness_before = numeric_limits<float>::tinyness_before;
static Eigen::bfloat16 (min)() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x0080); }
static Eigen::bfloat16 lowest() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0xff7f); }
static Eigen::bfloat16 (max)() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7f7f); }
static Eigen::bfloat16 epsilon() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x3c00); }
static Eigen::bfloat16 round_error() { return Eigen::bfloat16(0x3f00); }
static Eigen::bfloat16 infinity() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7f80); }
static Eigen::bfloat16 quiet_NaN() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7fc0); }
static Eigen::bfloat16 signaling_NaN() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7f81); }
static Eigen::bfloat16 denorm_min() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x0001); }
};
// If std::numeric_limits<T> is specialized, should also specialize
// std::numeric_limits<const T>, std::numeric_limits<volatile T>, and
// std::numeric_limits<const volatile T>
// https://stackoverflow.com/a/16519653/
template<>
struct numeric_limits<const Eigen::bfloat16> : numeric_limits<Eigen::bfloat16> {};
template<>
struct numeric_limits<volatile Eigen::bfloat16> : numeric_limits<Eigen::bfloat16> {};
template<>
struct numeric_limits<const volatile Eigen::bfloat16> : numeric_limits<Eigen::bfloat16> {};
} // end namespace std
namespace Eigen {
namespace bfloat16_impl {
// We need to distinguish clang as the CUDA compiler from clang as the host compiler,
// invoked by NVCC (e.g. on MacOS). The former needs to see both host and device implementation
// of the functions, while the latter can only deal with one of them.
#if !defined(EIGEN_HAS_NATIVE_BF16) || (EIGEN_COMP_CLANG && !EIGEN_COMP_NVCC) // Emulate support for bfloat16 floats
#if EIGEN_COMP_CLANG && defined(EIGEN_CUDACC)
// We need to provide emulated *host-side* BF16 operators for clang.
#pragma push_macro("EIGEN_DEVICE_FUNC")
#undef EIGEN_DEVICE_FUNC
#if defined(EIGEN_HAS_CUDA_BF16) && defined(EIGEN_HAS_NATIVE_BF16)
#define EIGEN_DEVICE_FUNC __host__
#else // both host and device need emulated ops.
#define EIGEN_DEVICE_FUNC __host__ __device__
#endif
#endif
// Definitions for CPUs, mostly working through conversion
// to/from fp32.
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator + (const bfloat16& a, const bfloat16& b) {
return bfloat16(float(a) + float(b));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator + (const bfloat16& a, const int& b) {
return bfloat16(float(a) + static_cast<float>(b));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator + (const int& a, const bfloat16& b) {
return bfloat16(static_cast<float>(a) + float(b));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator * (const bfloat16& a, const bfloat16& b) {
return bfloat16(float(a) * float(b));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator - (const bfloat16& a, const bfloat16& b) {
return bfloat16(float(a) - float(b));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator / (const bfloat16& a, const bfloat16& b) {
return bfloat16(float(a) / float(b));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator - (const bfloat16& a) {
bfloat16 result;
result.value = a.value ^ 0x8000;
return result;
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator += (bfloat16& a, const bfloat16& b) {
a = bfloat16(float(a) + float(b));
return a;
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator *= (bfloat16& a, const bfloat16& b) {
a = bfloat16(float(a) * float(b));
return a;
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator -= (bfloat16& a, const bfloat16& b) {
a = bfloat16(float(a) - float(b));
return a;
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator /= (bfloat16& a, const bfloat16& b) {
a = bfloat16(float(a) / float(b));
return a;
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator++(bfloat16& a) {
a += bfloat16(1);
return a;
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator--(bfloat16& a) {
a -= bfloat16(1);
return a;
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator++(bfloat16& a, int) {
bfloat16 original_value = a;
++a;
return original_value;
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator--(bfloat16& a, int) {
bfloat16 original_value = a;
--a;
return original_value;
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator == (const bfloat16& a, const bfloat16& b) {
return numext::equal_strict(float(a),float(b));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator != (const bfloat16& a, const bfloat16& b) {
return numext::not_equal_strict(float(a), float(b));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator < (const bfloat16& a, const bfloat16& b) {
return float(a) < float(b);
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator <= (const bfloat16& a, const bfloat16& b) {
return float(a) <= float(b);
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator > (const bfloat16& a, const bfloat16& b) {
return float(a) > float(b);
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator >= (const bfloat16& a, const bfloat16& b) {
return float(a) >= float(b);
}
#if EIGEN_COMP_CLANG && defined(EIGEN_CUDACC)
#pragma pop_macro("EIGEN_DEVICE_FUNC")
#endif
#endif // Emulate support for bfloat16 floats
// Division by an index. Do it in full float precision to avoid accuracy
// issues in converting the denominator to bfloat16.
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator / (const bfloat16& a, Index b) {
return bfloat16(static_cast<float>(a) / static_cast<float>(b));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw truncate_to_bfloat16(const float v) {
__bfloat16_raw output;
if (Eigen::numext::isnan EIGEN_NOT_A_MACRO(v)) {
output.value = 0x7FC0;
return output;
} else if (std::fabs(v) < std::numeric_limits<float>::min EIGEN_NOT_A_MACRO()) {
// Flush denormal to +/- 0.
output.value = std::signbit(v) ? 0x8000 : 0;
return output;
}
const uint16_t* p = reinterpret_cast<const uint16_t*>(&v);
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
output.value = p[0];
#else
output.value = p[1];
#endif
return output;
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw raw_uint16_to_bfloat16(unsigned short value) {
__bfloat16_raw h;
h.value = value;
return h;
}
union float32_bits {
unsigned int u;
float f;
};
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne(float ff) {
#if (defined(EIGEN_HAS_CUDA_BF16) && defined(EIGEN_HAS_HIP_BF16))
// Nothing to do here
#else
unsigned int input;
float32_bits f;
f.f = ff;
input = f.u;
__bfloat16_raw output;
if (Eigen::numext::isnan EIGEN_NOT_A_MACRO(ff)) {
// If the value is a NaN, squash it to a qNaN with msb of fraction set,
// this makes sure after truncation we don't end up with an inf.
//
// qNaN magic: All exponent bits set + most significant bit of fraction
// set.
output.value = 0x7fc0;
} else if (std::fabs(ff) < std::numeric_limits<float>::min EIGEN_NOT_A_MACRO()) {
// Flush denormal to +/- 0.0
output.value = std::signbit(ff) ? 0x8000 : 0;
} else {
// Fast rounding algorithm that rounds a half value to nearest even. This
// reduces expected error when we convert a large number of floats. Here
// is how it works:
//
// Definitions:
// To convert a float 32 to bfloat16, a float 32 can be viewed as 32 bits
// with the following tags:
//
// Sign | Exp (8 bits) | Frac (23 bits)
// S EEEEEEEE FFFFFFLRTTTTTTTTTTTTTTT
//
// S: Sign bit.
// E: Exponent bits.
// F: First 6 bits of fraction.
// L: Least significant bit of resulting bfloat16 if we truncate away the
// rest of the float32. This is also the 7th bit of fraction
// R: Rounding bit, 8th bit of fraction.
// T: Sticky bits, rest of fraction, 15 bits.
//
// To round half to nearest even, there are 3 cases where we want to round
// down (simply truncate the result of the bits away, which consists of
// rounding bit and sticky bits) and two cases where we want to round up
// (truncate then add one to the result).
//
// The fast converting algorithm simply adds lsb (L) to 0x7fff (15 bits of
// 1s) as the rounding bias, adds the rounding bias to the input, then
// truncates the last 16 bits away.
//
// To understand how it works, we can analyze this algorithm case by case:
//
// 1. L = 0, R = 0:
// Expect: round down, this is less than half value.
//
// Algorithm:
// - Rounding bias: 0x7fff + 0 = 0x7fff
// - Adding rounding bias to input may create any carry, depending on
// whether there is any value set to 1 in T bits.
// - R may be set to 1 if there is a carry.
// - L remains 0.
// - Note that this case also handles Inf and -Inf, where all fraction
// bits, including L, R and Ts are all 0. The output remains Inf after
// this algorithm.
//
// 2. L = 1, R = 0:
// Expect: round down, this is less than half value.
//
// Algorithm:
// - Rounding bias: 0x7fff + 1 = 0x8000
// - Adding rounding bias to input doesn't change sticky bits but
// adds 1 to rounding bit.
// - L remains 1.
//
// 3. L = 0, R = 1, all of T are 0:
// Expect: round down, this is exactly at half, the result is already
// even (L=0).
//
// Algorithm:
// - Rounding bias: 0x7fff + 0 = 0x7fff
// - Adding rounding bias to input sets all sticky bits to 1, but
// doesn't create a carry.
// - R remains 1.
// - L remains 0.
//
// 4. L = 1, R = 1:
// Expect: round up, this is exactly at half, the result needs to be
// round to the next even number.
//
// Algorithm:
// - Rounding bias: 0x7fff + 1 = 0x8000
// - Adding rounding bias to input doesn't change sticky bits, but
// creates a carry from rounding bit.
// - The carry sets L to 0, creates another carry bit and propagate
// forward to F bits.
// - If all the F bits are 1, a carry then propagates to the exponent
// bits, which then creates the minimum value with the next exponent
// value. Note that we won't have the case where exponents are all 1,
// since that's either a NaN (handled in the other if condition) or inf
// (handled in case 1).
//
// 5. L = 0, R = 1, any of T is 1:
// Expect: round up, this is greater than half.
//
// Algorithm:
// - Rounding bias: 0x7fff + 0 = 0x7fff
// - Adding rounding bias to input creates a carry from sticky bits,
// sets rounding bit to 0, then create another carry.
// - The second carry sets L to 1.
//
// Examples:
//
// Exact half value that is already even:
// Input:
// Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
// S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
// 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 1000000000000000
//
// This falls into case 3. We truncate the rest of 16 bits and no
// carry is created into F and L:
//
// Output:
// Sign | Exp (8 bit) | Frac (first 7 bit)
// S E E E E E E E E F F F F F F L
// 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0
//
// Exact half value, round to next even number:
// Input:
// Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
// S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
// 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1000000000000000
//
// This falls into case 4. We create a carry from R and T,
// which then propagates into L and F:
//
// Output:
// Sign | Exp (8 bit) | Frac (first 7 bit)
// S E E E E E E E E F F F F F F L
// 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0
//
//
// Max denormal value round to min normal value:
// Input:
// Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
// S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
// 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1111111111111111
//
// This falls into case 4. We create a carry from R and T,
// propagate into L and F, which then propagates into exponent
// bits:
//
// Output:
// Sign | Exp (8 bit) | Frac (first 7 bit)
// S E E E E E E E E F F F F F F L
// 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0
//
// Max normal value round to Inf:
// Input:
// Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
// S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
// 0 1 1 1 1 1 1 1 0 1 1 1 1 1 1 1 1111111111111111
//
// This falls into case 4. We create a carry from R and T,
// propagate into L and F, which then propagates into exponent
// bits:
//
// Sign | Exp (8 bit) | Frac (first 7 bit)
// S E E E E E E E E F F F F F F L
// 0 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0
//
//
// Least significant bit of resulting bfloat.
unsigned int lsb = (input >> 16) & 1;
unsigned int rounding_bias = 0x7fff + lsb;
input += rounding_bias;
output.value = static_cast<unsigned short>(input >> 16);
}
return output;
#endif
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float bfloat16_to_float(__bfloat16_raw h) {
float result = 0;
unsigned short* q = reinterpret_cast<unsigned short*>(&result);
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
q[0] = h.value;
#else
q[1] = h.value;
#endif
return result;
}
// --- standard functions ---
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool (isinf)(const bfloat16& a) {
return std::isinf EIGEN_NOT_A_MACRO(float(a));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool (isnan)(const bfloat16& a) {
return std::isnan EIGEN_NOT_A_MACRO(float(a));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool (isfinite)(const bfloat16& a) {
return !(isinf EIGEN_NOT_A_MACRO (a)) && !(isnan EIGEN_NOT_A_MACRO (a));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 abs(const bfloat16& a) {
bfloat16 result;
result.value = a.value & 0x7FFF;
return result;
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 exp(const bfloat16& a) {
return bfloat16(::expf(float(a)));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 expm1(const bfloat16& a) {
return bfloat16(numext::expm1(float(a)));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log(const bfloat16& a) {
return bfloat16(::logf(float(a)));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log1p(const bfloat16& a) {
return bfloat16(numext::log1p(float(a)));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log10(const bfloat16& a) {
return bfloat16(::log10f(float(a)));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 sqrt(const bfloat16& a) {
return bfloat16(::sqrtf(float(a)));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 pow(const bfloat16& a, const bfloat16& b) {
return bfloat16(::powf(float(a), float(b)));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 sin(const bfloat16& a) {
return bfloat16(::sinf(float(a)));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 cos(const bfloat16& a) {
return bfloat16(::cosf(float(a)));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 tan(const bfloat16& a) {
return bfloat16(::tanf(float(a)));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 asin(const bfloat16& a) {
return bfloat16(::asinf(float(a)));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 acos(const bfloat16& a) {
return bfloat16(::acosf(float(a)));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 atan(const bfloat16& a) {
return bfloat16(::atanf(float(a)));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 sinh(const bfloat16& a) {
return bfloat16(::sinhf(float(a)));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 cosh(const bfloat16& a) {
return bfloat16(::coshf(float(a)));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 tanh(const bfloat16& a) {
return bfloat16(::tanhf(float(a)));
}
#if EIGEN_HAS_CXX11_MATH
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 asinh(const bfloat16& a) {
return bfloat16(::asinh(float(a)));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 acosh(const bfloat16& a) {
return bfloat16(::acosh(float(a)));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 atanh(const bfloat16& a) {
return bfloat16(::atanh(float(a)));
}
#endif
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 floor(const bfloat16& a) {
return bfloat16(::floorf(float(a)));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 ceil(const bfloat16& a) {
return bfloat16(::ceilf(float(a)));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 fmod(const bfloat16& a, const bfloat16& b) {
return bfloat16(::fmodf(float(a), float(b)));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 (min)(const bfloat16& a, const bfloat16& b) {
const float f1 = static_cast<float>(a);
const float f2 = static_cast<float>(b);
return f2 < f1 ? b : a;
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 (max)(const bfloat16& a, const bfloat16& b) {
const float f1 = static_cast<float>(a);
const float f2 = static_cast<float>(b);
return f1 < f2 ? b : a;
}
#ifndef EIGEN_NO_IO
EIGEN_ALWAYS_INLINE std::ostream& operator << (std::ostream& os, const bfloat16& v) {
os << static_cast<float>(v);
return os;
}
#endif
} // end namespace bfloat16_impl
namespace internal {
template<>
struct random_default_impl<bfloat16, false, false>
{
static inline bfloat16 run(const bfloat16& x, const bfloat16& y)
{
return x + (y-x) * bfloat16(float(std::rand()) / float(RAND_MAX));
}
static inline bfloat16 run()
{
return run(bfloat16(-1.f), bfloat16(1.f));
}
};
template<> struct is_arithmetic<bfloat16> { enum { value = true }; };
} // end namespace internal
template<> struct NumTraits<Eigen::bfloat16>
: GenericNumTraits<Eigen::bfloat16>
{
enum {
IsSigned = true,
IsInteger = false,
IsComplex = false,
RequireInitialization = false
};
EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Eigen::bfloat16 epsilon() {
return bfloat16_impl::raw_uint16_to_bfloat16(0x3c00);
}
EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Eigen::bfloat16 dummy_precision() { return Eigen::bfloat16(5e-2f); }
EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Eigen::bfloat16 highest() {
return bfloat16_impl::raw_uint16_to_bfloat16(0x7F7F);
}
EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Eigen::bfloat16 lowest() {
return bfloat16_impl::raw_uint16_to_bfloat16(0xFF7F);
}
EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Eigen::bfloat16 infinity() {
return bfloat16_impl::raw_uint16_to_bfloat16(0x7f80);
}
EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Eigen::bfloat16 quiet_NaN() {
return bfloat16_impl::raw_uint16_to_bfloat16(0x7fc0);
}
};
} // end namespace Eigen
namespace std {
#if __cplusplus > 199711L
template <>
struct hash<Eigen::bfloat16> {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::size_t operator()(const Eigen::bfloat16& a) const {
return hash<float>()(static_cast<float>(a));
}
};
#endif
} // end namespace std
namespace Eigen {
namespace numext {
template<>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
bool (isnan)(const Eigen::bfloat16& h) {
return (bfloat16_impl::isnan)(h);
}
template<>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
bool (isinf)(const Eigen::bfloat16& h) {
return (bfloat16_impl::isinf)(h);
}
template<>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
bool (isfinite)(const Eigen::bfloat16& h) {
return (bfloat16_impl::isfinite)(h);
}
} // namespace Eigen
} // namespace numext
#endif // EIGEN_BFLOAT16_H

View File

@ -71,6 +71,49 @@ template<>
struct functor_traits<scalar_cast_op<Eigen::half, float> >
{ enum { Cost = NumTraits<float>::AddCost, PacketAccess = false }; };
template<>
struct scalar_cast_op<float, Eigen::bfloat16> {
EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op)
typedef Eigen::bfloat16 result_type;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::bfloat16 operator() (const float& a) const {
return Eigen::bfloat16(a);
}
};
template<>
struct functor_traits<scalar_cast_op<float, Eigen::bfloat16> >
{ enum { Cost = NumTraits<float>::AddCost, PacketAccess = false }; };
template<>
struct scalar_cast_op<int, Eigen::bfloat16> {
EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op)
typedef Eigen::bfloat16 result_type;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::bfloat16 operator() (const int& a) const {
return Eigen::bfloat16(static_cast<float>(a));
}
};
template<>
struct functor_traits<scalar_cast_op<int, Eigen::bfloat16> >
{ enum { Cost = NumTraits<float>::AddCost, PacketAccess = false }; };
template<>
struct scalar_cast_op<Eigen::bfloat16, float> {
EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op)
typedef float result_type;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float operator() (const Eigen::bfloat16& a) const {
return static_cast<float>(a);
}
};
template<>
struct functor_traits<scalar_cast_op<Eigen::bfloat16, float> >
{ enum { Cost = NumTraits<float>::AddCost, PacketAccess = false }; };
}
}

View File

@ -288,6 +288,9 @@
#ifdef __AVX512ER__
#define EIGEN_VECTORIZE_AVX512ER
#endif
#ifdef __AVX512BF16__
#define EIGEN_VECTORIZE_AVX512BF16
#endif
#endif
#endif

View File

@ -286,6 +286,7 @@ ei_add_test(ctorleak)
ei_add_test(mpl2only)
ei_add_test(inplace_decomposition)
ei_add_test(half_float)
ei_add_test(bfloat16_float)
ei_add_test(array_of_string)
ei_add_test(num_dimensions)
ei_add_test(stl_iterators)

399
test/bfloat16_float.cpp Normal file
View File

@ -0,0 +1,399 @@
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#include <sstream>
#include <memory>
#include <math.h>
#include "main.h"
#include <Eigen/src/Core/arch/Default/BFloat16.h>
// Make sure it's possible to forward declare Eigen::bfloat16
namespace Eigen {
struct bfloat16;
}
using Eigen::bfloat16;
float BinaryToFloat(uint32_t sign, uint32_t exponent, uint32_t high_mantissa,
uint32_t low_mantissa) {
float dest;
uint32_t src = (sign << 31) + (exponent << 23) + (high_mantissa << 16) + low_mantissa;
memcpy(static_cast<void*>(&dest),
static_cast<const void*>(&src), sizeof(dest));
return dest;
}
void test_truncate(float input, float expected_truncation, float expected_rounding){
bfloat16 truncated = Eigen::bfloat16_impl::truncate_to_bfloat16(input);
bfloat16 rounded = Eigen::bfloat16_impl::float_to_bfloat16_rtne(input);
if ((numext::isnan)(input)){
VERIFY((numext::isnan)(static_cast<float>(truncated)) || (numext::isinf)(static_cast<float>(truncated)));
VERIFY((numext::isnan)(static_cast<float>(rounded)) || (numext::isinf)(static_cast<float>(rounded)));
return;
}
VERIFY_IS_EQUAL(expected_truncation, static_cast<float>(truncated));
VERIFY_IS_EQUAL(expected_rounding, static_cast<float>(rounded));
}
void test_conversion()
{
using Eigen::bfloat16_impl::__bfloat16_raw;
// Conversion from float.
VERIFY_IS_EQUAL(bfloat16(1.0f).value, 0x3f80);
VERIFY_IS_EQUAL(bfloat16(0.5f).value, 0x3f00);
VERIFY_IS_EQUAL(bfloat16(0.33333f).value, 0x3eab);
VERIFY_IS_EQUAL(bfloat16(3.38e38f).value, 0x7f7e);
VERIFY_IS_EQUAL(bfloat16(3.40e38f).value, 0x7f80); // Becomes infinity.
// Verify round-to-nearest-even behavior.
float val1 = static_cast<float>(bfloat16(__bfloat16_raw(0x3c00)));
float val2 = static_cast<float>(bfloat16(__bfloat16_raw(0x3c01)));
float val3 = static_cast<float>(bfloat16(__bfloat16_raw(0x3c02)));
VERIFY_IS_EQUAL(bfloat16(0.5f * (val1 + val2)).value, 0x3c00);
VERIFY_IS_EQUAL(bfloat16(0.5f * (val2 + val3)).value, 0x3c02);
// Conversion from int.
VERIFY_IS_EQUAL(bfloat16(-1).value, 0xbf80);
VERIFY_IS_EQUAL(bfloat16(0).value, 0x0000);
VERIFY_IS_EQUAL(bfloat16(1).value, 0x3f80);
VERIFY_IS_EQUAL(bfloat16(2).value, 0x4000);
VERIFY_IS_EQUAL(bfloat16(3).value, 0x4040);
VERIFY_IS_EQUAL(bfloat16(12).value, 0x4140);
// Conversion from bool.
VERIFY_IS_EQUAL(bfloat16(false).value, 0x0000);
VERIFY_IS_EQUAL(bfloat16(true).value, 0x3f80);
// Conversion to float.
VERIFY_IS_EQUAL(static_cast<float>(bfloat16(__bfloat16_raw(0x0000))), 0.0f);
VERIFY_IS_EQUAL(static_cast<float>(bfloat16(__bfloat16_raw(0x3f80))), 1.0f);
// Zero representations
VERIFY_IS_EQUAL(bfloat16(0.0f), bfloat16(0.0f));
VERIFY_IS_EQUAL(bfloat16(-0.0f), bfloat16(0.0f));
VERIFY_IS_EQUAL(bfloat16(-0.0f), bfloat16(-0.0f));
VERIFY_IS_EQUAL(bfloat16(0.0f).value, 0x0000);
VERIFY_IS_EQUAL(bfloat16(-0.0f).value, 0x8000);
// Flush denormals to zero
for (float denorm = -std::numeric_limits<float>::denorm_min();
denorm < std::numeric_limits<float>::denorm_min();
denorm = nextafterf(denorm, 1.0f)) {
bfloat16 bf_trunc = Eigen::bfloat16_impl::truncate_to_bfloat16(denorm);
VERIFY_IS_EQUAL(static_cast<float>(bf_trunc), 0.0f);
if (std::signbit(denorm)) {
VERIFY_IS_EQUAL(bf_trunc.value, 0x8000);
} else {
VERIFY_IS_EQUAL(bf_trunc.value, 0x0000);
}
bfloat16 bf_round = Eigen::bfloat16_impl::float_to_bfloat16_rtne(denorm);
VERIFY_IS_EQUAL(static_cast<float>(bf_round), 0.0f);
if (std::signbit(denorm)) {
VERIFY_IS_EQUAL(bf_round.value, 0x8000);
} else {
VERIFY_IS_EQUAL(bf_round.value, 0x0000);
}
}
// Default is zero
VERIFY_IS_EQUAL(static_cast<float>(bfloat16()), 0.0f);
// Representable floats round trip via bfloat16
VERIFY_IS_EQUAL(static_cast<float>(static_cast<bfloat16>(-std::numeric_limits<float>::infinity())), -std::numeric_limits<float>::infinity());
VERIFY_IS_EQUAL(static_cast<float>(static_cast<bfloat16>(std::numeric_limits<float>::infinity())), std::numeric_limits<float>::infinity());
VERIFY_IS_EQUAL(static_cast<float>(static_cast<bfloat16>(-1.0f)), -1.0f);
VERIFY_IS_EQUAL(static_cast<float>(static_cast<bfloat16>(-0.5f)), -0.5f);
VERIFY_IS_EQUAL(static_cast<float>(static_cast<bfloat16>(-0.0f)), -0.0f);
VERIFY_IS_EQUAL(static_cast<float>(static_cast<bfloat16>(1.0f)), 1.0f);
VERIFY_IS_EQUAL(static_cast<float>(static_cast<bfloat16>(0.5f)), 0.5f);
VERIFY_IS_EQUAL(static_cast<float>(static_cast<bfloat16>(0.0f)), 0.0f);
// Truncate test
test_truncate(
BinaryToFloat(0, 0x80, 0x48, 0xf5c3),
BinaryToFloat(0, 0x80, 0x48, 0x0000),
BinaryToFloat(0, 0x80, 0x49, 0x0000));
test_truncate(
BinaryToFloat(1, 0x80, 0x48, 0xf5c3),
BinaryToFloat(1, 0x80, 0x48, 0x0000),
BinaryToFloat(1, 0x80, 0x49, 0x0000));
test_truncate(
BinaryToFloat(0, 0x80, 0x48, 0x8000),
BinaryToFloat(0, 0x80, 0x48, 0x0000),
BinaryToFloat(0, 0x80, 0x48, 0x0000));
test_truncate(
BinaryToFloat(0, 0xff, 0x00, 0x0001),
BinaryToFloat(0, 0xff, 0x40, 0x0000),
BinaryToFloat(0, 0xff, 0x40, 0x0000));
test_truncate(
BinaryToFloat(0, 0xff, 0x7f, 0xffff),
BinaryToFloat(0, 0xff, 0x40, 0x0000),
BinaryToFloat(0, 0xff, 0x40, 0x0000));
test_truncate(
BinaryToFloat(1, 0x80, 0x48, 0xc000),
BinaryToFloat(1, 0x80, 0x48, 0x0000),
BinaryToFloat(1, 0x80, 0x49, 0x0000));
test_truncate(
BinaryToFloat(0, 0x80, 0x48, 0x0000),
BinaryToFloat(0, 0x80, 0x48, 0x0000),
BinaryToFloat(0, 0x80, 0x48, 0x0000));
test_truncate(
BinaryToFloat(0, 0x80, 0x48, 0x4000),
BinaryToFloat(0, 0x80, 0x48, 0x0000),
BinaryToFloat(0, 0x80, 0x48, 0x0000));
test_truncate(
BinaryToFloat(0, 0x80, 0x48, 0x8000),
BinaryToFloat(0, 0x80, 0x48, 0x0000),
BinaryToFloat(0, 0x80, 0x48, 0x0000));
test_truncate(
BinaryToFloat(0, 0x00, 0x48, 0x8000),
BinaryToFloat(0, 0x00, 0x00, 0x0000),
BinaryToFloat(0, 0x00, 0x00, 0x0000));
test_truncate(
BinaryToFloat(0, 0x00, 0x7f, 0xc000),
BinaryToFloat(0, 0x00, 0x00, 0x0000),
BinaryToFloat(0, 0x00, 0x00, 0x0000));
// Conversion
Array<float,1,100> a;
for (int i = 0; i < 100; i++) a(i) = i + 1.25;
Array<bfloat16,1,100> b = a.cast<bfloat16>();
Array<float,1,100> c = b.cast<float>();
for (int i = 0; i < 100; ++i) {
VERIFY_LE(numext::abs(c(i) - a(i)), a(i) / 128);
}
// Epsilon
VERIFY_LE(1.0f, static_cast<float>((std::numeric_limits<bfloat16>::epsilon)() + bfloat16(1.0f)));
VERIFY_IS_EQUAL(1.0f, static_cast<float>((std::numeric_limits<bfloat16>::epsilon)() / bfloat16(2.0f) + bfloat16(1.0f)));
// Negate
VERIFY_IS_EQUAL(static_cast<float>(-bfloat16(3.0f)), -3.0f);
VERIFY_IS_EQUAL(static_cast<float>(-bfloat16(-4.5f)), 4.5f);
#if !EIGEN_COMP_MSVC
// Visual Studio errors out on divisions by 0
VERIFY((numext::isnan)(static_cast<float>(bfloat16(0.0 / 0.0))));
VERIFY((numext::isinf)(static_cast<float>(bfloat16(1.0 / 0.0))));
VERIFY((numext::isinf)(static_cast<float>(bfloat16(-1.0 / 0.0))));
// Visual Studio errors out on divisions by 0
VERIFY((numext::isnan)(bfloat16(0.0 / 0.0)));
VERIFY((numext::isinf)(bfloat16(1.0 / 0.0)));
VERIFY((numext::isinf)(bfloat16(-1.0 / 0.0)));
#endif
// NaNs and infinities.
VERIFY(!(numext::isinf)(static_cast<float>(bfloat16(3.38e38f)))); // Largest finite number.
VERIFY(!(numext::isnan)(static_cast<float>(bfloat16(0.0f))));
VERIFY((numext::isinf)(static_cast<float>(bfloat16(__bfloat16_raw(0xff80)))));
VERIFY((numext::isnan)(static_cast<float>(bfloat16(__bfloat16_raw(0xffc0)))));
VERIFY((numext::isinf)(static_cast<float>(bfloat16(__bfloat16_raw(0x7f80)))));
VERIFY((numext::isnan)(static_cast<float>(bfloat16(__bfloat16_raw(0x7fc0)))));
// Exactly same checks as above, just directly on the bfloat16 representation.
VERIFY(!(numext::isinf)(bfloat16(__bfloat16_raw(0x7bff))));
VERIFY(!(numext::isnan)(bfloat16(__bfloat16_raw(0x0000))));
VERIFY((numext::isinf)(bfloat16(__bfloat16_raw(0xff80))));
VERIFY((numext::isnan)(bfloat16(__bfloat16_raw(0xffc0))));
VERIFY((numext::isinf)(bfloat16(__bfloat16_raw(0x7f80))));
VERIFY((numext::isnan)(bfloat16(__bfloat16_raw(0x7fc0))));
}
void test_numtraits()
{
std::cout << "epsilon = " << NumTraits<bfloat16>::epsilon() << " (0x" << std::hex << NumTraits<bfloat16>::epsilon().value << ")" << std::endl;
std::cout << "highest = " << NumTraits<bfloat16>::highest() << " (0x" << std::hex << NumTraits<bfloat16>::highest().value << ")" << std::endl;
std::cout << "lowest = " << NumTraits<bfloat16>::lowest() << " (0x" << std::hex << NumTraits<bfloat16>::lowest().value << ")" << std::endl;
std::cout << "min = " << (std::numeric_limits<bfloat16>::min)() << " (0x" << std::hex << (std::numeric_limits<bfloat16>::min)().value << ")" << std::endl;
std::cout << "denorm min = " << (std::numeric_limits<bfloat16>::denorm_min)() << " (0x" << std::hex << (std::numeric_limits<bfloat16>::denorm_min)().value << ")" << std::endl;
std::cout << "infinity = " << NumTraits<bfloat16>::infinity() << " (0x" << std::hex << NumTraits<bfloat16>::infinity().value << ")" << std::endl;
std::cout << "quiet nan = " << NumTraits<bfloat16>::quiet_NaN() << " (0x" << std::hex << NumTraits<bfloat16>::quiet_NaN().value << ")" << std::endl;
std::cout << "signaling nan = " << std::numeric_limits<bfloat16>::signaling_NaN() << " (0x" << std::hex << std::numeric_limits<bfloat16>::signaling_NaN().value << ")" << std::endl;
VERIFY(NumTraits<bfloat16>::IsSigned);
VERIFY_IS_EQUAL( std::numeric_limits<bfloat16>::infinity().value, bfloat16(std::numeric_limits<float>::infinity()).value );
VERIFY_IS_EQUAL( std::numeric_limits<bfloat16>::quiet_NaN().value, bfloat16(std::numeric_limits<float>::quiet_NaN()).value );
VERIFY( (std::numeric_limits<bfloat16>::min)() > bfloat16(0.f) );
VERIFY( (std::numeric_limits<bfloat16>::denorm_min)() > bfloat16(0.f) );
VERIFY_IS_EQUAL( (std::numeric_limits<bfloat16>::denorm_min)()/bfloat16(2), bfloat16(0.f) );
}
void test_arithmetic()
{
VERIFY_IS_EQUAL(static_cast<float>(bfloat16(2) + bfloat16(2)), 4);
VERIFY_IS_EQUAL(static_cast<float>(bfloat16(2) + bfloat16(-2)), 0);
VERIFY_IS_APPROX(static_cast<float>(bfloat16(0.33333f) + bfloat16(0.66667f)), 1.0f);
VERIFY_IS_EQUAL(static_cast<float>(bfloat16(2.0f) * bfloat16(-5.5f)), -11.0f);
VERIFY_IS_APPROX(static_cast<float>(bfloat16(1.0f) / bfloat16(3.0f)), 0.3339f);
VERIFY_IS_EQUAL(static_cast<float>(-bfloat16(4096.0f)), -4096.0f);
VERIFY_IS_EQUAL(static_cast<float>(-bfloat16(-4096.0f)), 4096.0f);
}
void test_comparison()
{
VERIFY(bfloat16(1.0f) > bfloat16(0.5f));
VERIFY(bfloat16(0.5f) < bfloat16(1.0f));
VERIFY(!(bfloat16(1.0f) < bfloat16(0.5f)));
VERIFY(!(bfloat16(0.5f) > bfloat16(1.0f)));
VERIFY(!(bfloat16(4.0f) > bfloat16(4.0f)));
VERIFY(!(bfloat16(4.0f) < bfloat16(4.0f)));
VERIFY(!(bfloat16(0.0f) < bfloat16(-0.0f)));
VERIFY(!(bfloat16(-0.0f) < bfloat16(0.0f)));
VERIFY(!(bfloat16(0.0f) > bfloat16(-0.0f)));
VERIFY(!(bfloat16(-0.0f) > bfloat16(0.0f)));
VERIFY(bfloat16(0.2f) > bfloat16(-1.0f));
VERIFY(bfloat16(-1.0f) < bfloat16(0.2f));
VERIFY(bfloat16(-16.0f) < bfloat16(-15.0f));
VERIFY(bfloat16(1.0f) == bfloat16(1.0f));
VERIFY(bfloat16(1.0f) != bfloat16(2.0f));
// Comparisons with NaNs and infinities.
#if !EIGEN_COMP_MSVC
// Visual Studio errors out on divisions by 0
VERIFY(!(bfloat16(0.0 / 0.0) == bfloat16(0.0 / 0.0)));
VERIFY(bfloat16(0.0 / 0.0) != bfloat16(0.0 / 0.0));
VERIFY(!(bfloat16(1.0) == bfloat16(0.0 / 0.0)));
VERIFY(!(bfloat16(1.0) < bfloat16(0.0 / 0.0)));
VERIFY(!(bfloat16(1.0) > bfloat16(0.0 / 0.0)));
VERIFY(bfloat16(1.0) != bfloat16(0.0 / 0.0));
VERIFY(bfloat16(1.0) < bfloat16(1.0 / 0.0));
VERIFY(bfloat16(1.0) > bfloat16(-1.0 / 0.0));
#endif
}
void test_basic_functions()
{
VERIFY_IS_EQUAL(static_cast<float>(numext::abs(bfloat16(3.5f))), 3.5f);
VERIFY_IS_EQUAL(static_cast<float>(abs(bfloat16(3.5f))), 3.5f);
VERIFY_IS_EQUAL(static_cast<float>(numext::abs(bfloat16(-3.5f))), 3.5f);
VERIFY_IS_EQUAL(static_cast<float>(abs(bfloat16(-3.5f))), 3.5f);
VERIFY_IS_EQUAL(static_cast<float>(numext::floor(bfloat16(3.5f))), 3.0f);
VERIFY_IS_EQUAL(static_cast<float>(floor(bfloat16(3.5f))), 3.0f);
VERIFY_IS_EQUAL(static_cast<float>(numext::floor(bfloat16(-3.5f))), -4.0f);
VERIFY_IS_EQUAL(static_cast<float>(floor(bfloat16(-3.5f))), -4.0f);
VERIFY_IS_EQUAL(static_cast<float>(numext::ceil(bfloat16(3.5f))), 4.0f);
VERIFY_IS_EQUAL(static_cast<float>(ceil(bfloat16(3.5f))), 4.0f);
VERIFY_IS_EQUAL(static_cast<float>(numext::ceil(bfloat16(-3.5f))), -3.0f);
VERIFY_IS_EQUAL(static_cast<float>(ceil(bfloat16(-3.5f))), -3.0f);
VERIFY_IS_APPROX(static_cast<float>(numext::sqrt(bfloat16(0.0f))), 0.0f);
VERIFY_IS_APPROX(static_cast<float>(sqrt(bfloat16(0.0f))), 0.0f);
VERIFY_IS_APPROX(static_cast<float>(numext::sqrt(bfloat16(4.0f))), 2.0f);
VERIFY_IS_APPROX(static_cast<float>(sqrt(bfloat16(4.0f))), 2.0f);
VERIFY_IS_APPROX(static_cast<float>(numext::pow(bfloat16(0.0f), bfloat16(1.0f))), 0.0f);
VERIFY_IS_APPROX(static_cast<float>(pow(bfloat16(0.0f), bfloat16(1.0f))), 0.0f);
VERIFY_IS_APPROX(static_cast<float>(numext::pow(bfloat16(2.0f), bfloat16(2.0f))), 4.0f);
VERIFY_IS_APPROX(static_cast<float>(pow(bfloat16(2.0f), bfloat16(2.0f))), 4.0f);
VERIFY_IS_EQUAL(static_cast<float>(numext::exp(bfloat16(0.0f))), 1.0f);
VERIFY_IS_EQUAL(static_cast<float>(exp(bfloat16(0.0f))), 1.0f);
VERIFY_IS_APPROX(static_cast<float>(numext::exp(bfloat16(EIGEN_PI))), 20.f + static_cast<float>(EIGEN_PI));
VERIFY_IS_APPROX(static_cast<float>(exp(bfloat16(EIGEN_PI))), 20.f + static_cast<float>(EIGEN_PI));
VERIFY_IS_EQUAL(static_cast<float>(numext::expm1(bfloat16(0.0f))), 0.0f);
VERIFY_IS_EQUAL(static_cast<float>(expm1(bfloat16(0.0f))), 0.0f);
VERIFY_IS_APPROX(static_cast<float>(numext::expm1(bfloat16(2.0f))), 6.375f);
VERIFY_IS_APPROX(static_cast<float>(expm1(bfloat16(2.0f))), 6.375f);
VERIFY_IS_EQUAL(static_cast<float>(numext::log(bfloat16(1.0f))), 0.0f);
VERIFY_IS_EQUAL(static_cast<float>(log(bfloat16(1.0f))), 0.0f);
VERIFY_IS_APPROX(static_cast<float>(numext::log(bfloat16(10.0f))), 2.296875f);
VERIFY_IS_APPROX(static_cast<float>(log(bfloat16(10.0f))), 2.296875f);
VERIFY_IS_EQUAL(static_cast<float>(numext::log1p(bfloat16(0.0f))), 0.0f);
VERIFY_IS_EQUAL(static_cast<float>(log1p(bfloat16(0.0f))), 0.0f);
VERIFY_IS_APPROX(static_cast<float>(numext::log1p(bfloat16(10.0f))), 2.390625f);
VERIFY_IS_APPROX(static_cast<float>(log1p(bfloat16(10.0f))), 2.390625f);
}
void test_trigonometric_functions()
{
VERIFY_IS_APPROX(numext::cos(bfloat16(0.0f)), bfloat16(cosf(0.0f)));
VERIFY_IS_APPROX(cos(bfloat16(0.0f)), bfloat16(cosf(0.0f)));
VERIFY_IS_APPROX(numext::cos(bfloat16(EIGEN_PI)), bfloat16(cosf(EIGEN_PI)));
// VERIFY_IS_APPROX(numext::cos(bfloat16(EIGEN_PI/2)), bfloat16(cosf(EIGEN_PI/2)));
// VERIFY_IS_APPROX(numext::cos(bfloat16(3*EIGEN_PI/2)), bfloat16(cosf(3*EIGEN_PI/2)));
VERIFY_IS_APPROX(numext::cos(bfloat16(3.5f)), bfloat16(cosf(3.5f)));
VERIFY_IS_APPROX(numext::sin(bfloat16(0.0f)), bfloat16(sinf(0.0f)));
VERIFY_IS_APPROX(sin(bfloat16(0.0f)), bfloat16(sinf(0.0f)));
// VERIFY_IS_APPROX(numext::sin(bfloat16(EIGEN_PI)), bfloat16(sinf(EIGEN_PI)));
VERIFY_IS_APPROX(numext::sin(bfloat16(EIGEN_PI/2)), bfloat16(sinf(EIGEN_PI/2)));
VERIFY_IS_APPROX(numext::sin(bfloat16(3*EIGEN_PI/2)), bfloat16(sinf(3*EIGEN_PI/2)));
VERIFY_IS_APPROX(numext::sin(bfloat16(3.5f)), bfloat16(sinf(3.5f)));
VERIFY_IS_APPROX(numext::tan(bfloat16(0.0f)), bfloat16(tanf(0.0f)));
VERIFY_IS_APPROX(tan(bfloat16(0.0f)), bfloat16(tanf(0.0f)));
// VERIFY_IS_APPROX(numext::tan(bfloat16(EIGEN_PI)), bfloat16(tanf(EIGEN_PI)));
// VERIFY_IS_APPROX(numext::tan(bfloat16(EIGEN_PI/2)), bfloat16(tanf(EIGEN_PI/2)));
// VERIFY_IS_APPROX(numext::tan(bfloat16(3*EIGEN_PI/2)), bfloat16(tanf(3*EIGEN_PI/2)));
VERIFY_IS_APPROX(numext::tan(bfloat16(3.5f)), bfloat16(tanf(3.5f)));
}
void test_array()
{
typedef Array<bfloat16,1,Dynamic> ArrayXh;
Index size = internal::random<Index>(1,10);
Index i = internal::random<Index>(0,size-1);
ArrayXh a1 = ArrayXh::Random(size), a2 = ArrayXh::Random(size);
VERIFY_IS_APPROX( a1+a1, bfloat16(2)*a1 );
VERIFY( (a1.abs() >= bfloat16(0)).all() );
VERIFY_IS_APPROX( (a1*a1).sqrt(), a1.abs() );
VERIFY( ((a1.min)(a2) <= (a1.max)(a2)).all() );
a1(i) = bfloat16(-10.);
VERIFY_IS_EQUAL( a1.minCoeff(), bfloat16(-10.) );
a1(i) = bfloat16(10.);
VERIFY_IS_EQUAL( a1.maxCoeff(), bfloat16(10.) );
std::stringstream ss;
ss << a1;
}
void test_product()
{
typedef Matrix<bfloat16,Dynamic,Dynamic> MatrixXh;
Index rows = internal::random<Index>(1,EIGEN_TEST_MAX_SIZE);
Index cols = internal::random<Index>(1,EIGEN_TEST_MAX_SIZE);
Index depth = internal::random<Index>(1,EIGEN_TEST_MAX_SIZE);
MatrixXh Ah = MatrixXh::Random(rows,depth);
MatrixXh Bh = MatrixXh::Random(depth,cols);
MatrixXh Ch = MatrixXh::Random(rows,cols);
MatrixXf Af = Ah.cast<float>();
MatrixXf Bf = Bh.cast<float>();
MatrixXf Cf = Ch.cast<float>();
VERIFY_IS_APPROX(Ch.noalias()+=Ah*Bh, (Cf.noalias()+=Af*Bf).cast<bfloat16>());
}
EIGEN_DECLARE_TEST(bfloat16_float)
{
CALL_SUBTEST(test_numtraits());
for(int i = 0; i < g_repeat; i++) {
CALL_SUBTEST(test_conversion());
CALL_SUBTEST(test_arithmetic());
CALL_SUBTEST(test_comparison());
CALL_SUBTEST(test_basic_functions());
CALL_SUBTEST(test_trigonometric_functions());
CALL_SUBTEST(test_array());
CALL_SUBTEST(test_product());
}
}

View File

@ -435,6 +435,7 @@ EIGEN_TEST_SCALAR_TEST_OVERLOAD(unsigned long long)
EIGEN_TEST_SCALAR_TEST_OVERLOAD(float)
EIGEN_TEST_SCALAR_TEST_OVERLOAD(double)
EIGEN_TEST_SCALAR_TEST_OVERLOAD(half)
EIGEN_TEST_SCALAR_TEST_OVERLOAD(bfloat16)
#undef EIGEN_TEST_SCALAR_TEST_OVERLOAD

View File

@ -45,6 +45,7 @@ EIGEN_DECLARE_TEST(numext) {
CALL_SUBTEST( check_abs<long>() );
CALL_SUBTEST( check_abs<unsigned long>() );
CALL_SUBTEST( check_abs<half>() );
CALL_SUBTEST( check_abs<bfloat16>() );
CALL_SUBTEST( check_abs<float>() );
CALL_SUBTEST( check_abs<double>() );
CALL_SUBTEST( check_abs<long double>() );

View File

@ -836,6 +836,7 @@ EIGEN_DECLARE_TEST(packetmath)
#ifdef EIGEN_PACKET_MATH_SSE_H
CALL_SUBTEST_14(( packetmath<bool,internal::packet_traits<bool>::type>() ));
#endif
CALL_SUBTEST_15(( packetmath<bfloat16,internal::packet_traits<bfloat16>::type>() ));
g_first_pass = false;
}
}

View File

@ -50,6 +50,7 @@ T apply_bit_op(Bits a, Bits b, Func f) {
EIGEN_TEST_MAKE_BITWISE2(OP,FUNC,float) \
EIGEN_TEST_MAKE_BITWISE2(OP,FUNC,double) \
EIGEN_TEST_MAKE_BITWISE2(OP,FUNC,half) \
EIGEN_TEST_MAKE_BITWISE2(OP,FUNC,bfloat16) \
EIGEN_TEST_MAKE_BITWISE2(OP,FUNC,std::complex<float>) \
EIGEN_TEST_MAKE_BITWISE2(OP,FUNC,std::complex<double>)

View File

@ -101,6 +101,17 @@ Eigen::half RandomToTypeUniform<Eigen::half>(uint64_t* state, uint64_t stream) {
return result - Eigen::half(1.0f);
}
template <> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
Eigen::bfloat16 RandomToTypeUniform<Eigen::bfloat16>(uint64_t* state, uint64_t stream) {
Eigen::bfloat16 result;
// Generate 7 random bits for the mantissa
unsigned rnd = PCG_XSH_RS_generator(state, stream);
result.value = static_cast<uint16_t>(rnd & 0x7fu);
// Set the exponent
result.value |= (static_cast<uint16_t>(127) << 7);
// Return the final result
return result - Eigen::bfloat16(1.0f);
}
template <> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
float RandomToTypeUniform<float>(uint64_t* state, uint64_t stream) {

View File

@ -62,6 +62,7 @@ namespace Eigen {
#include "src/SpecialFunctions/BesselFunctionsImpl.h"
#include "src/SpecialFunctions/BesselFunctionsPacketMath.h"
#include "src/SpecialFunctions/BesselFunctionsBFloat16.h"
#include "src/SpecialFunctions/BesselFunctionsHalf.h"
#include "src/SpecialFunctions/BesselFunctionsFunctors.h"
#include "src/SpecialFunctions/BesselFunctionsArrayAPI.h"
@ -70,6 +71,7 @@ namespace Eigen {
#include "src/SpecialFunctions/HipVectorCompatibility.h"
#endif
#include "src/SpecialFunctions/SpecialFunctionsPacketMath.h"
#include "src/SpecialFunctions/SpecialFunctionsBFloat16.h"
#include "src/SpecialFunctions/SpecialFunctionsHalf.h"
#include "src/SpecialFunctions/SpecialFunctionsFunctors.h"
#include "src/SpecialFunctions/SpecialFunctionsArrayAPI.h"

View File

@ -0,0 +1,68 @@
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#ifndef EIGEN_BESSELFUNCTIONS_BFLOAT16_H
#define EIGEN_BESSELFUNCTIONS_BFLOAT16_H
namespace Eigen {
namespace numext {
#if EIGEN_HAS_C99_MATH
template <>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 bessel_i0(const Eigen::bfloat16& x) {
return Eigen::bfloat16(Eigen::numext::bessel_i0(static_cast<float>(x)));
}
template <>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 bessel_i0e(const Eigen::bfloat16& x) {
return Eigen::bfloat16(Eigen::numext::bessel_i0e(static_cast<float>(x)));
}
template <>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 bessel_i1(const Eigen::bfloat16& x) {
return Eigen::bfloat16(Eigen::numext::bessel_i1(static_cast<float>(x)));
}
template <>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 bessel_i1e(const Eigen::bfloat16& x) {
return Eigen::bfloat16(Eigen::numext::bessel_i1e(static_cast<float>(x)));
}
template <>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 bessel_j0(const Eigen::bfloat16& x) {
return Eigen::bfloat16(Eigen::numext::bessel_j0(static_cast<float>(x)));
}
template <>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 bessel_j1(const Eigen::bfloat16& x) {
return Eigen::bfloat16(Eigen::numext::bessel_j1(static_cast<float>(x)));
}
template <>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 bessel_y0(const Eigen::bfloat16& x) {
return Eigen::bfloat16(Eigen::numext::bessel_y0(static_cast<float>(x)));
}
template <>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 bessel_y1(const Eigen::bfloat16& x) {
return Eigen::bfloat16(Eigen::numext::bessel_y1(static_cast<float>(x)));
}
template <>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 bessel_k0(const Eigen::bfloat16& x) {
return Eigen::bfloat16(Eigen::numext::bessel_k0(static_cast<float>(x)));
}
template <>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 bessel_k0e(const Eigen::bfloat16& x) {
return Eigen::bfloat16(Eigen::numext::bessel_k0e(static_cast<float>(x)));
}
template <>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 bessel_k1(const Eigen::bfloat16& x) {
return Eigen::bfloat16(Eigen::numext::bessel_k1(static_cast<float>(x)));
}
template <>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 bessel_k1e(const Eigen::bfloat16& x) {
return Eigen::bfloat16(Eigen::numext::bessel_k1e(static_cast<float>(x)));
}
#endif
} // end namespace numext
} // end namespace Eigen
#endif // EIGEN_BESSELFUNCTIONS_BFLOAT16_H

View File

@ -0,0 +1,58 @@
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#ifndef EIGEN_SPECIALFUNCTIONS_BFLOAT16_H
#define EIGEN_SPECIALFUNCTIONS_BFLOAT16_H
namespace Eigen {
namespace numext {
#if EIGEN_HAS_C99_MATH
template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 lgamma(const Eigen::bfloat16& a) {
return Eigen::bfloat16(Eigen::numext::lgamma(static_cast<float>(a)));
}
template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 digamma(const Eigen::bfloat16& a) {
return Eigen::bfloat16(Eigen::numext::digamma(static_cast<float>(a)));
}
template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 zeta(const Eigen::bfloat16& x, const Eigen::bfloat16& q) {
return Eigen::bfloat16(Eigen::numext::zeta(static_cast<float>(x), static_cast<float>(q)));
}
template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 polygamma(const Eigen::bfloat16& n, const Eigen::bfloat16& x) {
return Eigen::bfloat16(Eigen::numext::polygamma(static_cast<float>(n), static_cast<float>(x)));
}
template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 erf(const Eigen::bfloat16& a) {
return Eigen::bfloat16(Eigen::numext::erf(static_cast<float>(a)));
}
template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 erfc(const Eigen::bfloat16& a) {
return Eigen::bfloat16(Eigen::numext::erfc(static_cast<float>(a)));
}
template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 ndtri(const Eigen::bfloat16& a) {
return Eigen::bfloat16(Eigen::numext::ndtri(static_cast<float>(a)));
}
template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 igamma(const Eigen::bfloat16& a, const Eigen::bfloat16& x) {
return Eigen::bfloat16(Eigen::numext::igamma(static_cast<float>(a), static_cast<float>(x)));
}
template <>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 igamma_der_a(const Eigen::bfloat16& a, const Eigen::bfloat16& x) {
return Eigen::bfloat16(Eigen::numext::igamma_der_a(static_cast<float>(a), static_cast<float>(x)));
}
template <>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 gamma_sample_der_alpha(const Eigen::bfloat16& alpha, const Eigen::bfloat16& sample) {
return Eigen::bfloat16(Eigen::numext::gamma_sample_der_alpha(static_cast<float>(alpha), static_cast<float>(sample)));
}
template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 igammac(const Eigen::bfloat16& a, const Eigen::bfloat16& x) {
return Eigen::bfloat16(Eigen::numext::igammac(static_cast<float>(a), static_cast<float>(x)));
}
template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 betainc(const Eigen::bfloat16& a, const Eigen::bfloat16& b, const Eigen::bfloat16& x) {
return Eigen::bfloat16(Eigen::numext::betainc(static_cast<float>(a), static_cast<float>(b), static_cast<float>(x)));
}
#endif
} // end namespace numext
} // end namespace Eigen
#endif // EIGEN_SPECIALFUNCTIONS_BFLOAT16_H

View File

@ -511,6 +511,7 @@ EIGEN_DECLARE_TEST(cxx11_tensor_reduction) {
CALL_SUBTEST(( test_simple_reductions<float,ColMajor>() ));
CALL_SUBTEST(( test_simple_reductions<float,RowMajor>() ));
CALL_SUBTEST(( test_simple_reductions<Eigen::half,ColMajor>() ));
CALL_SUBTEST(( test_simple_reductions<Eigen::bfloat16,ColMajor>() ));
CALL_SUBTEST(test_reductions_in_expr<ColMajor>());
CALL_SUBTEST(test_reductions_in_expr<RowMajor>());
CALL_SUBTEST(test_full_reductions<ColMajor>());