mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-01-24 14:45:14 +08:00
Support BFloat16 in Eigen
This commit is contained in:
parent
6b9c92fe7e
commit
386d809bde
@ -241,13 +241,22 @@ if(NOT MSVC)
|
||||
|
||||
option(EIGEN_TEST_AVX512 "Enable/Disable AVX512 in tests/examples" OFF)
|
||||
if(EIGEN_TEST_AVX512)
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx512f -mfma -DEIGEN_ENABLE_AVX512")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx512f -mfma")
|
||||
if (NOT "${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fabi-version=6")
|
||||
endif()
|
||||
message(STATUS "Enabling AVX512 in tests/examples")
|
||||
endif()
|
||||
|
||||
option(EIGEN_TEST_AVX512DQ "Enable/Disable AVX512DQ in tests/examples" OFF)
|
||||
if(EIGEN_TEST_AVX512DQ)
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx512dq")
|
||||
if (NOT "${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fabi-version=6")
|
||||
endif()
|
||||
message(STATUS "Enabling AVX512DQ in tests/examples")
|
||||
endif()
|
||||
|
||||
option(EIGEN_TEST_F16C "Enable/Disable F16C in tests/examples" OFF)
|
||||
if(EIGEN_TEST_F16C)
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mf16c")
|
||||
|
@ -51,6 +51,10 @@
|
||||
#define EIGEN_HAS_GPU_FP16
|
||||
#endif
|
||||
|
||||
#if defined(EIGEN_HAS_CUDA_BF16) || defined(EIGEN_HAS_HIP_BF16)
|
||||
#define EIGEN_HAS_GPU_BF16
|
||||
#endif
|
||||
|
||||
#if (defined _OPENMP) && (!defined EIGEN_DONT_PARALLELIZE)
|
||||
#define EIGEN_HAS_OPENMP
|
||||
#endif
|
||||
@ -163,6 +167,7 @@ using std::ptrdiff_t;
|
||||
#include "src/Core/arch/Default/ConjHelper.h"
|
||||
// Generic half float support
|
||||
#include "src/Core/arch/Default/Half.h"
|
||||
#include "src/Core/arch/Default/BFloat16.h"
|
||||
#include "src/Core/arch/Default/TypeCasting.h"
|
||||
#include "src/Core/arch/Default/GenericPacketMathFunctionsFwd.h"
|
||||
|
||||
|
@ -29,6 +29,12 @@ namespace internal {
|
||||
#define _EIGEN_DECLARE_CONST_Packet8d_FROM_INT64(NAME, X) \
|
||||
const Packet8d p8d_##NAME = _mm512_castsi512_pd(_mm512_set1_epi64(X))
|
||||
|
||||
#define _EIGEN_DECLARE_CONST_Packet16bf(NAME, X) \
|
||||
const Packet16bf p16bf_##NAME = pset1<Packet16bf>(X)
|
||||
|
||||
#define _EIGEN_DECLARE_CONST_Packet16bf_FROM_INT(NAME, X) \
|
||||
const Packet16bf p16bf_##NAME = preinterpret<Packet16bf,Packet16i>(pset1<Packet16i>(X))
|
||||
|
||||
// Natural logarithm
|
||||
// Computes log(x) as log(2^e * m) = C*e + log(m), where the constant C =log(2)
|
||||
// and m is in the range [sqrt(1/2),sqrt(2)). In this range, the logarithm can
|
||||
@ -128,6 +134,11 @@ plog<Packet16f>(const Packet16f& _x) {
|
||||
p16f_nan),
|
||||
p16f_minus_inf);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16bf plog<Packet16bf>(const Packet16bf& _x) {
|
||||
return F32ToBf16(plog<Packet16f>(Bf16ToF32(_x)));
|
||||
}
|
||||
#endif
|
||||
|
||||
// Exponential function. Works by writing "x = m*log(2) + r" where
|
||||
@ -253,6 +264,10 @@ pexp<Packet8d>(const Packet8d& _x) {
|
||||
return pmax(pmul(x, e), _x);
|
||||
}*/
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16bf pexp<Packet16bf>(const Packet16bf& _x) {
|
||||
return F32ToBf16(pexp<Packet16f>(Bf16ToF32(_x)));
|
||||
}
|
||||
|
||||
// Functions for sqrt.
|
||||
// The EIGEN_FAST_MATH version uses the _mm_rsqrt_ps approximation and one step
|
||||
@ -303,12 +318,18 @@ template <>
|
||||
EIGEN_STRONG_INLINE Packet16f psqrt<Packet16f>(const Packet16f& x) {
|
||||
return _mm512_sqrt_ps(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8d psqrt<Packet8d>(const Packet8d& x) {
|
||||
return _mm512_sqrt_pd(x);
|
||||
}
|
||||
#endif
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16bf psqrt<Packet16bf>(const Packet16bf& x) {
|
||||
return F32ToBf16(psqrt<Packet16f>(Bf16ToF32(x)));
|
||||
}
|
||||
|
||||
// prsqrt for float.
|
||||
#if defined(EIGEN_VECTORIZE_AVX512ER)
|
||||
|
||||
@ -316,7 +337,6 @@ template <>
|
||||
EIGEN_STRONG_INLINE Packet16f prsqrt<Packet16f>(const Packet16f& x) {
|
||||
return _mm512_rsqrt28_ps(x);
|
||||
}
|
||||
|
||||
#elif EIGEN_FAST_MATH
|
||||
|
||||
template <>
|
||||
@ -347,8 +367,7 @@ prsqrt<Packet16f>(const Packet16f& _x) {
|
||||
// For other arguments, choose the output of the intrinsic. This will
|
||||
// return rsqrt(+inf) = 0, rsqrt(x) = NaN if x < 0, and rsqrt(0) = +inf.
|
||||
return _mm512_mask_blend_ps(not_finite_pos_mask, y_newton, y_approx);
|
||||
}
|
||||
|
||||
}
|
||||
#else
|
||||
|
||||
template <>
|
||||
@ -356,9 +375,13 @@ EIGEN_STRONG_INLINE Packet16f prsqrt<Packet16f>(const Packet16f& x) {
|
||||
_EIGEN_DECLARE_CONST_Packet16f(one, 1.0f);
|
||||
return _mm512_div_ps(p16f_one, _mm512_sqrt_ps(x));
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16bf prsqrt<Packet16bf>(const Packet16bf& x) {
|
||||
return F32ToBf16(prsqrt<Packet16f>(Bf16ToF32(x)));
|
||||
}
|
||||
|
||||
// prsqrt for double.
|
||||
#if EIGEN_FAST_MATH
|
||||
template <>
|
||||
@ -412,10 +435,20 @@ Packet16f plog1p<Packet16f>(const Packet16f& _x) {
|
||||
return generic_plog1p(_x);
|
||||
}
|
||||
|
||||
template<>
|
||||
EIGEN_STRONG_INLINE Packet16bf plog1p<Packet16bf>(const Packet16bf& _x) {
|
||||
return F32ToBf16(plog1p<Packet16f>(Bf16ToF32(_x)));
|
||||
}
|
||||
|
||||
template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
|
||||
Packet16f pexpm1<Packet16f>(const Packet16f& _x) {
|
||||
return generic_expm1(_x);
|
||||
}
|
||||
|
||||
template<>
|
||||
EIGEN_STRONG_INLINE Packet16bf pexpm1<Packet16bf>(const Packet16bf& _x) {
|
||||
return F32ToBf16(pexpm1<Packet16f>(Bf16ToF32(_x)));
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
@ -427,18 +460,33 @@ psin<Packet16f>(const Packet16f& _x) {
|
||||
return psin_float(_x);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16bf psin<Packet16bf>(const Packet16bf& _x) {
|
||||
return F32ToBf16(psin<Packet16f>(Bf16ToF32(_x)));
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet16f
|
||||
pcos<Packet16f>(const Packet16f& _x) {
|
||||
return pcos_float(_x);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16bf pcos<Packet16bf>(const Packet16bf& _x) {
|
||||
return F32ToBf16(pcos<Packet16f>(Bf16ToF32(_x)));
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet16f
|
||||
ptanh<Packet16f>(const Packet16f& _x) {
|
||||
return internal::generic_fast_tanh_float(_x);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16bf ptanh<Packet16bf>(const Packet16bf& _x) {
|
||||
return F32ToBf16(ptanh<Packet16f>(Bf16ToF32(_x)));
|
||||
}
|
||||
|
||||
} // end namespace internal
|
||||
|
||||
} // end namespace Eigen
|
||||
|
@ -362,6 +362,25 @@ EIGEN_STRONG_INLINE Packet16f cat256(Packet8f a, Packet8f b) {
|
||||
}
|
||||
#endif
|
||||
|
||||
// Helper function for bit packing snippet of low precision comparison.
|
||||
// It packs the flags from 32x16 to 16x16.
|
||||
EIGEN_STRONG_INLINE __m256i Pack32To16(Packet16f rf) {
|
||||
// Split data into small pieces and handle with AVX instructions
|
||||
// to guarantee internal order of vector.
|
||||
// Operation:
|
||||
// dst[15:0] := Saturate16(rf[31:0])
|
||||
// dst[31:16] := Saturate16(rf[63:32])
|
||||
// ...
|
||||
// dst[255:240] := Saturate16(rf[255:224])
|
||||
__m256i lo = _mm256_castps_si256(extract256<0>(rf));
|
||||
__m256i hi = _mm256_castps_si256(extract256<1>(rf));
|
||||
__m128i result_lo = _mm_packs_epi32(_mm256_extractf128_si256(lo, 0),
|
||||
_mm256_extractf128_si256(lo, 1));
|
||||
__m128i result_hi = _mm_packs_epi32(_mm256_extractf128_si256(hi, 0),
|
||||
_mm256_extractf128_si256(hi, 1));
|
||||
return _mm256_insertf128_si256(_mm256_castsi128_si256(result_lo), result_hi, 1);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16f pcmp_eq(const Packet16f& a, const Packet16f& b) {
|
||||
__mmask16 mask = _mm512_cmp_ps_mask(a, b, _CMP_EQ_OQ);
|
||||
@ -1342,15 +1361,7 @@ template<> EIGEN_STRONG_INLINE Packet16h pselect(const Packet16h& mask, const Pa
|
||||
template<> EIGEN_STRONG_INLINE Packet16h pcmp_eq(const Packet16h& a,const Packet16h& b) {
|
||||
Packet16f af = half2float(a);
|
||||
Packet16f bf = half2float(b);
|
||||
Packet16f rf = pcmp_eq(af, bf);
|
||||
// Pack the 32-bit flags into 16-bits flags.
|
||||
__m256i lo = _mm256_castps_si256(extract256<0>(rf));
|
||||
__m256i hi = _mm256_castps_si256(extract256<1>(rf));
|
||||
__m128i result_lo = _mm_packs_epi32(_mm256_extractf128_si256(lo, 0),
|
||||
_mm256_extractf128_si256(lo, 1));
|
||||
__m128i result_hi = _mm_packs_epi32(_mm256_extractf128_si256(hi, 0),
|
||||
_mm256_extractf128_si256(hi, 1));
|
||||
return _mm256_insertf128_si256(_mm256_castsi128_si256(result_lo), result_hi, 1);
|
||||
return Pack32To16(pcmp_eq(af, bf));
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet16h pnegate(const Packet16h& a) {
|
||||
@ -1607,6 +1618,493 @@ ptranspose(PacketBlock<Packet16h,4>& kernel) {
|
||||
kernel.packet[3] = pload<Packet16h>(out[3]);
|
||||
}
|
||||
|
||||
typedef union {
|
||||
#ifdef EIGEN_VECTORIZE_AVX512BF16
|
||||
__m256bh bh;
|
||||
#endif
|
||||
Packet8i i; // __m256i;
|
||||
} Packet16bf;
|
||||
|
||||
template <> struct is_arithmetic<Packet16bf> { enum { value = true }; };
|
||||
|
||||
template <>
|
||||
struct packet_traits<bfloat16> : default_packet_traits {
|
||||
typedef Packet16bf type;
|
||||
// There is no half-size packet for current Packet16bf.
|
||||
// TODO: support as SSE/AVX path.
|
||||
typedef Packet16bf half;
|
||||
enum {
|
||||
Vectorizable = 1,
|
||||
AlignedOnScalar = 1,
|
||||
size = 16,
|
||||
HasHalfPacket = 0,
|
||||
HasBlend = 0,
|
||||
HasInsert = 1,
|
||||
HasSin = EIGEN_FAST_MATH,
|
||||
HasCos = EIGEN_FAST_MATH,
|
||||
#if EIGEN_GNUC_AT_LEAST(5, 3) || (!EIGEN_COMP_GNUC_STRICT)
|
||||
#ifdef EIGEN_VECTORIZE_AVX512DQ
|
||||
HasLog = 1,
|
||||
HasLog1p = 1,
|
||||
HasExpm1 = 1,
|
||||
HasNdtri = 1,
|
||||
HasBessel = 1,
|
||||
#endif
|
||||
HasExp = 1,
|
||||
HasSqrt = EIGEN_FAST_MATH,
|
||||
HasRsqrt = EIGEN_FAST_MATH,
|
||||
HasTanh = EIGEN_FAST_MATH,
|
||||
HasErf = EIGEN_FAST_MATH,
|
||||
#endif
|
||||
HasDiv = 1
|
||||
};
|
||||
};
|
||||
|
||||
template <>
|
||||
struct unpacket_traits<Packet16bf>
|
||||
{
|
||||
typedef bfloat16 type;
|
||||
enum {size=16, alignment=Aligned32, vectorizable=true, masked_load_available=false, masked_store_available=false};
|
||||
typedef Packet16bf half;
|
||||
};
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16bf pset1<Packet16bf>(const bfloat16& from) {
|
||||
Packet16bf r;
|
||||
r.i = _mm256_set1_epi16(from.value);
|
||||
return r;
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE bfloat16 pfirst<Packet16bf>(const Packet16bf& from) {
|
||||
bfloat16 t;
|
||||
t.value = static_cast<unsigned short>(_mm256_extract_epi16(from.i, 0));
|
||||
return t;
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16bf pload<Packet16bf>(const bfloat16* from) {
|
||||
Packet16bf r;
|
||||
r.i = _mm256_load_si256(reinterpret_cast<const __m256i*>(from));
|
||||
return r;
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16bf ploadu<Packet16bf>(const bfloat16* from) {
|
||||
Packet16bf r;
|
||||
r.i = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(from));
|
||||
return r;
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE void pstore<bfloat16>(bfloat16* to,
|
||||
const Packet16bf& from) {
|
||||
_mm256_store_si256(reinterpret_cast<__m256i*>(to), from.i);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE void pstoreu<bfloat16>(bfloat16* to,
|
||||
const Packet16bf& from) {
|
||||
_mm256_storeu_si256(reinterpret_cast<__m256i*>(to), from.i);
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet16bf
|
||||
ploaddup<Packet16bf>(const bfloat16* from) {
|
||||
Packet16bf r;
|
||||
unsigned short a = from[0].value;
|
||||
unsigned short b = from[1].value;
|
||||
unsigned short c = from[2].value;
|
||||
unsigned short d = from[3].value;
|
||||
unsigned short e = from[4].value;
|
||||
unsigned short f = from[5].value;
|
||||
unsigned short g = from[6].value;
|
||||
unsigned short h = from[7].value;
|
||||
r.i = _mm256_set_epi16(h, h, g, g, f, f, e, e, d, d, c, c, b, b, a, a);
|
||||
return r;
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet16bf
|
||||
ploadquad(const bfloat16* from) {
|
||||
Packet16bf r;
|
||||
unsigned short a = from[0].value;
|
||||
unsigned short b = from[1].value;
|
||||
unsigned short c = from[2].value;
|
||||
unsigned short d = from[3].value;
|
||||
r.i = _mm256_set_epi16(d, d, d, d, c, c, c, c, b, b, b, b, a, a, a, a);
|
||||
return r;
|
||||
}
|
||||
|
||||
EIGEN_STRONG_INLINE Packet16f Bf16ToF32(const Packet16bf& a) {
|
||||
return _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(a.i), 16));
|
||||
}
|
||||
|
||||
// Convert float to bfloat16 according to round-to-even/denormals alogrithm.
|
||||
EIGEN_STRONG_INLINE Packet16bf F32ToBf16(const Packet16f& a) {
|
||||
Packet16bf r;
|
||||
|
||||
// Flush input denormals value to zero with hardware capability.
|
||||
_MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_ON);
|
||||
__m512 flush = _mm512_and_ps(a, a);
|
||||
_MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_OFF);
|
||||
|
||||
#if defined(EIGEN_VECTORIZE_AVX512BF16)
|
||||
r.bh = _mm512_cvtneps_pbh(flush);
|
||||
#else
|
||||
__m512i t;
|
||||
__m512i input = _mm512_castps_si512(flush);
|
||||
__m512i nan = _mm512_set1_epi32(0x7fc0);
|
||||
|
||||
// uint32_t lsb = (input >> 16) & 1;
|
||||
t = _mm512_and_si512(_mm512_srli_epi32(input, 16), _mm512_set1_epi32(1));
|
||||
// uint32_t rounding_bias = 0x7fff + lsb;
|
||||
t = _mm512_add_epi32(t, _mm512_set1_epi32(0x7fff));
|
||||
// input += rounding_bias;
|
||||
t = _mm512_add_epi32(t, input);
|
||||
// input = input >> 16;
|
||||
t = _mm512_srli_epi32(t, 16);
|
||||
|
||||
// Check NaN before converting back to bf16
|
||||
__mmask16 mask = _mm512_cmp_ps_mask(flush, flush, _CMP_ORD_Q);
|
||||
t = _mm512_mask_blend_epi32(mask, nan, t);
|
||||
|
||||
// output.value = static_cast<uint16_t>(input);
|
||||
r.i = _mm512_cvtepi32_epi16(t);
|
||||
#endif
|
||||
|
||||
return r;
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16bf ptrue(const Packet16bf& a) {
|
||||
Packet16bf r;
|
||||
r.i = ptrue<Packet8i>(a.i);
|
||||
return r;
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16bf por(const Packet16bf& a, const Packet16bf& b) {
|
||||
Packet16bf r;
|
||||
r.i = por<Packet8i>(a.i, b.i);
|
||||
return r;
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16bf pxor(const Packet16bf& a, const Packet16bf& b) {
|
||||
Packet16bf r;
|
||||
r.i = pxor<Packet8i>(a.i, b.i);
|
||||
return r;
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16bf pand(const Packet16bf& a, const Packet16bf& b) {
|
||||
Packet16bf r;
|
||||
r.i = pand<Packet8i>(a.i, b.i);
|
||||
return r;
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16bf pandnot(const Packet16bf& a,
|
||||
const Packet16bf& b) {
|
||||
Packet16bf r;
|
||||
r.i = pandnot<Packet8i>(a.i, b.i);
|
||||
return r;
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16bf pselect(const Packet16bf& mask,
|
||||
const Packet16bf& a,
|
||||
const Packet16bf& b) {
|
||||
// Input mask is expected to be all 0/1, handle it with 8-bit
|
||||
// intrinsic for performance.
|
||||
Packet16bf r;
|
||||
r.i = _mm256_blendv_epi8(b.i, a.i, mask.i);
|
||||
return r;
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16bf pcmp_eq(const Packet16bf& a,
|
||||
const Packet16bf& b) {
|
||||
Packet16bf result;
|
||||
result.i = Pack32To16(pcmp_eq(Bf16ToF32(a), Bf16ToF32(b)));
|
||||
return result;
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16bf pcmp_le(const Packet16bf& a,
|
||||
const Packet16bf& b) {
|
||||
Packet16bf result;
|
||||
result.i = Pack32To16(pcmp_le(Bf16ToF32(a), Bf16ToF32(b)));
|
||||
return result;
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16bf pcmp_lt(const Packet16bf& a,
|
||||
const Packet16bf& b) {
|
||||
Packet16bf result;
|
||||
result.i = Pack32To16(pcmp_lt(Bf16ToF32(a), Bf16ToF32(b)));
|
||||
return result;
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16bf pcmp_lt_or_nan(const Packet16bf& a,
|
||||
const Packet16bf& b) {
|
||||
Packet16bf result;
|
||||
result.i = Pack32To16(pcmp_lt_or_nan(Bf16ToF32(a), Bf16ToF32(b)));
|
||||
return result;
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16bf pnegate(const Packet16bf& a) {
|
||||
Packet16bf sign_mask;
|
||||
sign_mask.i = _mm256_set1_epi16(static_cast<unsigned short>(0x8000));
|
||||
Packet16bf result;
|
||||
result.i = _mm256_xor_si256(a.i, sign_mask.i);
|
||||
return result;
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16bf pconj(const Packet16bf& a) {
|
||||
return a;
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16bf pabs(const Packet16bf& a) {
|
||||
return F32ToBf16(pabs<Packet16f>(Bf16ToF32(a)));
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16bf padd<Packet16bf>(const Packet16bf& a,
|
||||
const Packet16bf& b) {
|
||||
return F32ToBf16(padd<Packet16f>(Bf16ToF32(a), Bf16ToF32(b)));
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16bf psub<Packet16bf>(const Packet16bf& a,
|
||||
const Packet16bf& b) {
|
||||
return F32ToBf16(psub<Packet16f>(Bf16ToF32(a), Bf16ToF32(b)));
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16bf pmul<Packet16bf>(const Packet16bf& a,
|
||||
const Packet16bf& b) {
|
||||
return F32ToBf16(pmul<Packet16f>(Bf16ToF32(a), Bf16ToF32(b)));
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16bf pdiv<Packet16bf>(const Packet16bf& a,
|
||||
const Packet16bf& b) {
|
||||
return F32ToBf16(pdiv<Packet16f>(Bf16ToF32(a), Bf16ToF32(b)));
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16bf pmin<Packet16bf>(const Packet16bf& a,
|
||||
const Packet16bf& b) {
|
||||
return F32ToBf16(pmin<Packet16f>(Bf16ToF32(a), Bf16ToF32(b)));
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16bf pmax<Packet16bf>(const Packet16bf& a,
|
||||
const Packet16bf& b) {
|
||||
return F32ToBf16(pmax<Packet16f>(Bf16ToF32(a), Bf16ToF32(b)));
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE bfloat16 predux<Packet16bf>(const Packet16bf& p) {
|
||||
return static_cast<bfloat16>(predux<Packet16f>(Bf16ToF32(p)));
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE bfloat16 predux_mul<Packet16bf>(const Packet16bf& from) {
|
||||
return static_cast<bfloat16>(predux_mul<Packet16f>(Bf16ToF32(from)));
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE bfloat16 predux_min<Packet16bf>(const Packet16bf& from) {
|
||||
return static_cast<bfloat16>(predux_min<Packet16f>(Bf16ToF32(from)));
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE bfloat16 predux_max<Packet16bf>(const Packet16bf& from) {
|
||||
return static_cast<bfloat16>(predux_max<Packet16f>(Bf16ToF32(from)));
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16bf preverse(const Packet16bf& a) {
|
||||
__m256i m = _mm256_setr_epi8(14,15,12,13,10,11,8,9,6,7,4,5,2,3,0,1,
|
||||
14,15,12,13,10,11,8,9,6,7,4,5,2,3,0,1);
|
||||
|
||||
Packet16bf res;
|
||||
// Swap hi and lo first because shuffle is in 128-bit lanes.
|
||||
res.i = _mm256_permute2x128_si256(a.i, a.i, 1);
|
||||
// Shuffle 8-bit values in src within 2*128-bit lanes.
|
||||
res.i = _mm256_shuffle_epi8(a.i, m);
|
||||
return res;
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16bf pgather<bfloat16, Packet16bf>(const bfloat16* from,
|
||||
Index stride) {
|
||||
Packet16bf result;
|
||||
result.i = _mm256_set_epi16(
|
||||
from[15*stride].value, from[14*stride].value, from[13*stride].value, from[12*stride].value,
|
||||
from[11*stride].value, from[10*stride].value, from[9*stride].value, from[8*stride].value,
|
||||
from[7*stride].value, from[6*stride].value, from[5*stride].value, from[4*stride].value,
|
||||
from[3*stride].value, from[2*stride].value, from[1*stride].value, from[0*stride].value);
|
||||
return result;
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE void pscatter<bfloat16, Packet16bf>(bfloat16* to,
|
||||
const Packet16bf& from,
|
||||
Index stride) {
|
||||
EIGEN_ALIGN64 bfloat16 aux[16];
|
||||
pstore(aux, from);
|
||||
to[stride*0].value = aux[0].value;
|
||||
to[stride*1].value = aux[1].value;
|
||||
to[stride*2].value = aux[2].value;
|
||||
to[stride*3].value = aux[3].value;
|
||||
to[stride*4].value = aux[4].value;
|
||||
to[stride*5].value = aux[5].value;
|
||||
to[stride*6].value = aux[6].value;
|
||||
to[stride*7].value = aux[7].value;
|
||||
to[stride*8].value = aux[8].value;
|
||||
to[stride*9].value = aux[9].value;
|
||||
to[stride*10].value = aux[10].value;
|
||||
to[stride*11].value = aux[11].value;
|
||||
to[stride*12].value = aux[12].value;
|
||||
to[stride*13].value = aux[13].value;
|
||||
to[stride*14].value = aux[14].value;
|
||||
to[stride*15].value = aux[15].value;
|
||||
}
|
||||
|
||||
EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16bf,16>& kernel) {
|
||||
__m256i a = kernel.packet[0].i;
|
||||
__m256i b = kernel.packet[1].i;
|
||||
__m256i c = kernel.packet[2].i;
|
||||
__m256i d = kernel.packet[3].i;
|
||||
__m256i e = kernel.packet[4].i;
|
||||
__m256i f = kernel.packet[5].i;
|
||||
__m256i g = kernel.packet[6].i;
|
||||
__m256i h = kernel.packet[7].i;
|
||||
__m256i i = kernel.packet[8].i;
|
||||
__m256i j = kernel.packet[9].i;
|
||||
__m256i k = kernel.packet[10].i;
|
||||
__m256i l = kernel.packet[11].i;
|
||||
__m256i m = kernel.packet[12].i;
|
||||
__m256i n = kernel.packet[13].i;
|
||||
__m256i o = kernel.packet[14].i;
|
||||
__m256i p = kernel.packet[15].i;
|
||||
|
||||
__m256i ab_07 = _mm256_unpacklo_epi16(a, b);
|
||||
__m256i cd_07 = _mm256_unpacklo_epi16(c, d);
|
||||
__m256i ef_07 = _mm256_unpacklo_epi16(e, f);
|
||||
__m256i gh_07 = _mm256_unpacklo_epi16(g, h);
|
||||
__m256i ij_07 = _mm256_unpacklo_epi16(i, j);
|
||||
__m256i kl_07 = _mm256_unpacklo_epi16(k, l);
|
||||
__m256i mn_07 = _mm256_unpacklo_epi16(m, n);
|
||||
__m256i op_07 = _mm256_unpacklo_epi16(o, p);
|
||||
|
||||
__m256i ab_8f = _mm256_unpackhi_epi16(a, b);
|
||||
__m256i cd_8f = _mm256_unpackhi_epi16(c, d);
|
||||
__m256i ef_8f = _mm256_unpackhi_epi16(e, f);
|
||||
__m256i gh_8f = _mm256_unpackhi_epi16(g, h);
|
||||
__m256i ij_8f = _mm256_unpackhi_epi16(i, j);
|
||||
__m256i kl_8f = _mm256_unpackhi_epi16(k, l);
|
||||
__m256i mn_8f = _mm256_unpackhi_epi16(m, n);
|
||||
__m256i op_8f = _mm256_unpackhi_epi16(o, p);
|
||||
|
||||
__m256i abcd_03 = _mm256_unpacklo_epi32(ab_07, cd_07);
|
||||
__m256i abcd_47 = _mm256_unpackhi_epi32(ab_07, cd_07);
|
||||
__m256i efgh_03 = _mm256_unpacklo_epi32(ef_07, gh_07);
|
||||
__m256i efgh_47 = _mm256_unpackhi_epi32(ef_07, gh_07);
|
||||
__m256i ijkl_03 = _mm256_unpacklo_epi32(ij_07, kl_07);
|
||||
__m256i ijkl_47 = _mm256_unpackhi_epi32(ij_07, kl_07);
|
||||
__m256i mnop_03 = _mm256_unpacklo_epi32(mn_07, op_07);
|
||||
__m256i mnop_47 = _mm256_unpackhi_epi32(mn_07, op_07);
|
||||
|
||||
__m256i abcd_8b = _mm256_unpacklo_epi32(ab_8f, cd_8f);
|
||||
__m256i abcd_cf = _mm256_unpackhi_epi32(ab_8f, cd_8f);
|
||||
__m256i efgh_8b = _mm256_unpacklo_epi32(ef_8f, gh_8f);
|
||||
__m256i efgh_cf = _mm256_unpackhi_epi32(ef_8f, gh_8f);
|
||||
__m256i ijkl_8b = _mm256_unpacklo_epi32(ij_8f, kl_8f);
|
||||
__m256i ijkl_cf = _mm256_unpackhi_epi32(ij_8f, kl_8f);
|
||||
__m256i mnop_8b = _mm256_unpacklo_epi32(mn_8f, op_8f);
|
||||
__m256i mnop_cf = _mm256_unpackhi_epi32(mn_8f, op_8f);
|
||||
|
||||
__m256i abcdefgh_01 = _mm256_unpacklo_epi64(abcd_03, efgh_03);
|
||||
__m256i abcdefgh_23 = _mm256_unpackhi_epi64(abcd_03, efgh_03);
|
||||
__m256i ijklmnop_01 = _mm256_unpacklo_epi64(ijkl_03, mnop_03);
|
||||
__m256i ijklmnop_23 = _mm256_unpackhi_epi64(ijkl_03, mnop_03);
|
||||
__m256i abcdefgh_45 = _mm256_unpacklo_epi64(abcd_47, efgh_47);
|
||||
__m256i abcdefgh_67 = _mm256_unpackhi_epi64(abcd_47, efgh_47);
|
||||
__m256i ijklmnop_45 = _mm256_unpacklo_epi64(ijkl_47, mnop_47);
|
||||
__m256i ijklmnop_67 = _mm256_unpackhi_epi64(ijkl_47, mnop_47);
|
||||
__m256i abcdefgh_89 = _mm256_unpacklo_epi64(abcd_8b, efgh_8b);
|
||||
__m256i abcdefgh_ab = _mm256_unpackhi_epi64(abcd_8b, efgh_8b);
|
||||
__m256i ijklmnop_89 = _mm256_unpacklo_epi64(ijkl_8b, mnop_8b);
|
||||
__m256i ijklmnop_ab = _mm256_unpackhi_epi64(ijkl_8b, mnop_8b);
|
||||
__m256i abcdefgh_cd = _mm256_unpacklo_epi64(abcd_cf, efgh_cf);
|
||||
__m256i abcdefgh_ef = _mm256_unpackhi_epi64(abcd_cf, efgh_cf);
|
||||
__m256i ijklmnop_cd = _mm256_unpacklo_epi64(ijkl_cf, mnop_cf);
|
||||
__m256i ijklmnop_ef = _mm256_unpackhi_epi64(ijkl_cf, mnop_cf);
|
||||
|
||||
// NOTE: no unpacklo/hi instr in this case, so using permute instr.
|
||||
kernel.packet[0].i = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01,
|
||||
0x20);
|
||||
kernel.packet[1].i = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23,
|
||||
0x20);
|
||||
kernel.packet[2].i = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45,
|
||||
0x20);
|
||||
kernel.packet[3].i = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67,
|
||||
0x20);
|
||||
kernel.packet[4].i = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89,
|
||||
0x20);
|
||||
kernel.packet[5].i = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab,
|
||||
0x20);
|
||||
kernel.packet[6].i = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd,
|
||||
0x20);
|
||||
kernel.packet[7].i = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef,
|
||||
0x20);
|
||||
kernel.packet[8].i = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01,
|
||||
0x20);
|
||||
kernel.packet[9].i = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23,
|
||||
0x20);
|
||||
kernel.packet[10].i = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45,
|
||||
0x20);
|
||||
kernel.packet[11].i = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67,
|
||||
0x20);
|
||||
kernel.packet[12].i = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89,
|
||||
0x20);
|
||||
kernel.packet[13].i = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab,
|
||||
0x20);
|
||||
kernel.packet[14].i = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd,
|
||||
0x20);
|
||||
kernel.packet[15].i = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef,
|
||||
0x20);
|
||||
}
|
||||
|
||||
EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16bf,4>& kernel) {
|
||||
__m256i a = kernel.packet[0].i;
|
||||
__m256i b = kernel.packet[1].i;
|
||||
__m256i c = kernel.packet[2].i;
|
||||
__m256i d = kernel.packet[3].i;
|
||||
|
||||
__m256i ab_07 = _mm256_unpacklo_epi16(a, b);
|
||||
__m256i cd_07 = _mm256_unpacklo_epi16(c, d);
|
||||
__m256i ab_8f = _mm256_unpackhi_epi16(a, b);
|
||||
__m256i cd_8f = _mm256_unpackhi_epi16(c, d);
|
||||
|
||||
__m256i abcd_03 = _mm256_unpacklo_epi32(ab_07, cd_07);
|
||||
__m256i abcd_47 = _mm256_unpackhi_epi32(ab_07, cd_07);
|
||||
__m256i abcd_8b = _mm256_unpacklo_epi32(ab_8f, cd_8f);
|
||||
__m256i abcd_cf = _mm256_unpackhi_epi32(ab_8f, cd_8f);
|
||||
|
||||
// NOTE: no unpacklo/hi instr in this case, so using permute instr.
|
||||
kernel.packet[0].i = _mm256_permute2x128_si256(abcd_03, abcd_47, 0x20);
|
||||
kernel.packet[1].i = _mm256_permute2x128_si256(abcd_8b, abcd_cf, 0x20);
|
||||
kernel.packet[2].i = _mm256_permute2x128_si256(abcd_03, abcd_47, 0x31);
|
||||
kernel.packet[3].i = _mm256_permute2x128_si256(abcd_8b, abcd_cf, 0x31);
|
||||
}
|
||||
|
||||
} // end namespace internal
|
||||
|
||||
|
@ -40,6 +40,32 @@ template<> EIGEN_STRONG_INLINE Packet16h pcast<Packet16f, Packet16h>(const Packe
|
||||
return float2half(a);
|
||||
}
|
||||
|
||||
template <>
|
||||
struct type_casting_traits<bfloat16, float> {
|
||||
enum {
|
||||
VectorizedCast = 1,
|
||||
SrcCoeffRatio = 1,
|
||||
TgtCoeffRatio = 1
|
||||
};
|
||||
};
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet16f pcast<Packet16bf, Packet16f>(const Packet16bf& a) {
|
||||
return Bf16ToF32(a);
|
||||
}
|
||||
|
||||
template <>
|
||||
struct type_casting_traits<float, bfloat16> {
|
||||
enum {
|
||||
VectorizedCast = 1,
|
||||
SrcCoeffRatio = 1,
|
||||
TgtCoeffRatio = 1
|
||||
};
|
||||
};
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet16bf pcast<Packet16f, Packet16bf>(const Packet16f& a) {
|
||||
return F32ToBf16(a);
|
||||
}
|
||||
|
||||
} // end namespace internal
|
||||
|
||||
} // end namespace Eigen
|
||||
|
703
Eigen/src/Core/arch/Default/BFloat16.h
Normal file
703
Eigen/src/Core/arch/Default/BFloat16.h
Normal file
@ -0,0 +1,703 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
|
||||
#ifndef EIGEN_BFLOAT16_H
|
||||
#define EIGEN_BFLOAT16_H
|
||||
|
||||
#if __cplusplus > 199711L
|
||||
#define EIGEN_EXPLICIT_CAST(tgt_type) explicit operator tgt_type()
|
||||
#else
|
||||
#define EIGEN_EXPLICIT_CAST(tgt_type) operator tgt_type()
|
||||
#endif
|
||||
|
||||
namespace Eigen {
|
||||
|
||||
struct bfloat16;
|
||||
|
||||
namespace bfloat16_impl {
|
||||
|
||||
// Make our own __bfloat16_raw definition.
|
||||
struct __bfloat16_raw {
|
||||
EIGEN_DEVICE_FUNC __bfloat16_raw() : value(0) {}
|
||||
explicit EIGEN_DEVICE_FUNC __bfloat16_raw(unsigned short raw) : value(raw) {}
|
||||
unsigned short value;
|
||||
};
|
||||
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw raw_uint16_to_bfloat16(unsigned short value);
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne(float ff);
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float bfloat16_to_float(__bfloat16_raw h);
|
||||
|
||||
struct bfloat16_base : public __bfloat16_raw {
|
||||
EIGEN_DEVICE_FUNC bfloat16_base() {}
|
||||
EIGEN_DEVICE_FUNC bfloat16_base(const __bfloat16_raw& h) : __bfloat16_raw(h) {}
|
||||
};
|
||||
|
||||
} // namespace bfloat16_impl
|
||||
|
||||
// Class definition.
|
||||
struct bfloat16 : public bfloat16_impl::bfloat16_base {
|
||||
|
||||
typedef bfloat16_impl::__bfloat16_raw __bfloat16_raw;
|
||||
|
||||
EIGEN_DEVICE_FUNC bfloat16() {}
|
||||
|
||||
EIGEN_DEVICE_FUNC bfloat16(const __bfloat16_raw& h) : bfloat16_impl::bfloat16_base(h) {}
|
||||
|
||||
explicit EIGEN_DEVICE_FUNC bfloat16(bool b)
|
||||
: bfloat16_impl::bfloat16_base(bfloat16_impl::raw_uint16_to_bfloat16(b ? 0x3f80 : 0)) {}
|
||||
template<class T>
|
||||
explicit EIGEN_DEVICE_FUNC bfloat16(const T& val)
|
||||
: bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne(static_cast<float>(val))) {}
|
||||
explicit EIGEN_DEVICE_FUNC bfloat16(float f)
|
||||
: bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne(f)) {}
|
||||
// Following the convention of numpy, converting between complex and
|
||||
// float will lead to loss of imag value.
|
||||
// Single precision complex.
|
||||
typedef std::complex<float> complex64;
|
||||
// Double precision complex.
|
||||
typedef std::complex<double> complex128;
|
||||
explicit EIGEN_DEVICE_FUNC bfloat16(const complex64& val)
|
||||
: bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne(val.real())) {}
|
||||
explicit EIGEN_DEVICE_FUNC bfloat16(const complex128& val)
|
||||
: bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne(static_cast<float>(val.real()))) {}
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(bool) const {
|
||||
// +0.0 and -0.0 become false, everything else becomes true.
|
||||
return (value & 0x7fff) != 0;
|
||||
}
|
||||
EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(signed char) const {
|
||||
return static_cast<signed char>(bfloat16_impl::bfloat16_to_float(*this));
|
||||
}
|
||||
EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(unsigned char) const {
|
||||
return static_cast<unsigned char>(bfloat16_impl::bfloat16_to_float(*this));
|
||||
}
|
||||
EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(short) const {
|
||||
return static_cast<short>(bfloat16_impl::bfloat16_to_float(*this));
|
||||
}
|
||||
EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(unsigned short) const {
|
||||
return static_cast<unsigned short>(bfloat16_impl::bfloat16_to_float(*this));
|
||||
}
|
||||
EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(int) const {
|
||||
return static_cast<int>(bfloat16_impl::bfloat16_to_float(*this));
|
||||
}
|
||||
EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(unsigned int) const {
|
||||
return static_cast<unsigned int>(bfloat16_impl::bfloat16_to_float(*this));
|
||||
}
|
||||
EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(long) const {
|
||||
return static_cast<long>(bfloat16_impl::bfloat16_to_float(*this));
|
||||
}
|
||||
EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(unsigned long) const {
|
||||
return static_cast<unsigned long>(bfloat16_impl::bfloat16_to_float(*this));
|
||||
}
|
||||
EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(long long) const {
|
||||
return static_cast<long long>(bfloat16_impl::bfloat16_to_float(*this));
|
||||
}
|
||||
EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(unsigned long long) const {
|
||||
return static_cast<unsigned long long>(bfloat16_to_float(*this));
|
||||
}
|
||||
EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(float) const {
|
||||
return bfloat16_impl::bfloat16_to_float(*this);
|
||||
}
|
||||
EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(double) const {
|
||||
return static_cast<double>(bfloat16_impl::bfloat16_to_float(*this));
|
||||
}
|
||||
EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(complex64) const {
|
||||
return complex64(bfloat16_impl::bfloat16_to_float(*this), float(0.0));
|
||||
}
|
||||
EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(complex128) const {
|
||||
return complex128(static_cast<double>(bfloat16_impl::bfloat16_to_float(*this)), double(0.0));
|
||||
}
|
||||
EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(Eigen::half) const {
|
||||
return static_cast<Eigen::half>(bfloat16_impl::bfloat16_to_float(*this));
|
||||
}
|
||||
};
|
||||
|
||||
} // end namespace Eigen
|
||||
|
||||
namespace std {
|
||||
template<>
|
||||
struct numeric_limits<Eigen::bfloat16> {
|
||||
static const bool is_specialized = true;
|
||||
static const bool is_signed = true;
|
||||
static const bool is_integer = false;
|
||||
static const bool is_exact = false;
|
||||
static const bool has_infinity = true;
|
||||
static const bool has_quiet_NaN = true;
|
||||
static const bool has_signaling_NaN = true;
|
||||
static const float_denorm_style has_denorm = numeric_limits<float>::has_denorm;
|
||||
static const bool has_denorm_loss = numeric_limits<float>::has_denorm_loss;
|
||||
static const std::float_round_style round_style = numeric_limits<float>::round_style;
|
||||
static const bool is_iec559 = false;
|
||||
static const bool is_bounded = true;
|
||||
static const bool is_modulo = false;
|
||||
static const int digits = 8;
|
||||
static const int digits10 = 2;
|
||||
static const int max_digits10 = 4;
|
||||
static const int radix = 2;
|
||||
static const int min_exponent = numeric_limits<float>::min_exponent;
|
||||
static const int min_exponent10 = numeric_limits<float>::min_exponent10;
|
||||
static const int max_exponent = numeric_limits<float>::max_exponent;
|
||||
static const int max_exponent10 = numeric_limits<float>::max_exponent10;
|
||||
static const bool traps = numeric_limits<float>::traps;
|
||||
static const bool tinyness_before = numeric_limits<float>::tinyness_before;
|
||||
|
||||
static Eigen::bfloat16 (min)() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x0080); }
|
||||
static Eigen::bfloat16 lowest() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0xff7f); }
|
||||
static Eigen::bfloat16 (max)() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7f7f); }
|
||||
static Eigen::bfloat16 epsilon() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x3c00); }
|
||||
static Eigen::bfloat16 round_error() { return Eigen::bfloat16(0x3f00); }
|
||||
static Eigen::bfloat16 infinity() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7f80); }
|
||||
static Eigen::bfloat16 quiet_NaN() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7fc0); }
|
||||
static Eigen::bfloat16 signaling_NaN() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7f81); }
|
||||
static Eigen::bfloat16 denorm_min() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x0001); }
|
||||
};
|
||||
|
||||
// If std::numeric_limits<T> is specialized, should also specialize
|
||||
// std::numeric_limits<const T>, std::numeric_limits<volatile T>, and
|
||||
// std::numeric_limits<const volatile T>
|
||||
// https://stackoverflow.com/a/16519653/
|
||||
template<>
|
||||
struct numeric_limits<const Eigen::bfloat16> : numeric_limits<Eigen::bfloat16> {};
|
||||
template<>
|
||||
struct numeric_limits<volatile Eigen::bfloat16> : numeric_limits<Eigen::bfloat16> {};
|
||||
template<>
|
||||
struct numeric_limits<const volatile Eigen::bfloat16> : numeric_limits<Eigen::bfloat16> {};
|
||||
} // end namespace std
|
||||
|
||||
namespace Eigen {
|
||||
|
||||
namespace bfloat16_impl {
|
||||
|
||||
// We need to distinguish ‘clang as the CUDA compiler’ from ‘clang as the host compiler,
|
||||
// invoked by NVCC’ (e.g. on MacOS). The former needs to see both host and device implementation
|
||||
// of the functions, while the latter can only deal with one of them.
|
||||
#if !defined(EIGEN_HAS_NATIVE_BF16) || (EIGEN_COMP_CLANG && !EIGEN_COMP_NVCC) // Emulate support for bfloat16 floats
|
||||
|
||||
#if EIGEN_COMP_CLANG && defined(EIGEN_CUDACC)
|
||||
// We need to provide emulated *host-side* BF16 operators for clang.
|
||||
#pragma push_macro("EIGEN_DEVICE_FUNC")
|
||||
#undef EIGEN_DEVICE_FUNC
|
||||
#if defined(EIGEN_HAS_CUDA_BF16) && defined(EIGEN_HAS_NATIVE_BF16)
|
||||
#define EIGEN_DEVICE_FUNC __host__
|
||||
#else // both host and device need emulated ops.
|
||||
#define EIGEN_DEVICE_FUNC __host__ __device__
|
||||
#endif
|
||||
#endif
|
||||
|
||||
// Definitions for CPUs, mostly working through conversion
|
||||
// to/from fp32.
|
||||
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator + (const bfloat16& a, const bfloat16& b) {
|
||||
return bfloat16(float(a) + float(b));
|
||||
}
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator + (const bfloat16& a, const int& b) {
|
||||
return bfloat16(float(a) + static_cast<float>(b));
|
||||
}
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator + (const int& a, const bfloat16& b) {
|
||||
return bfloat16(static_cast<float>(a) + float(b));
|
||||
}
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator * (const bfloat16& a, const bfloat16& b) {
|
||||
return bfloat16(float(a) * float(b));
|
||||
}
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator - (const bfloat16& a, const bfloat16& b) {
|
||||
return bfloat16(float(a) - float(b));
|
||||
}
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator / (const bfloat16& a, const bfloat16& b) {
|
||||
return bfloat16(float(a) / float(b));
|
||||
}
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator - (const bfloat16& a) {
|
||||
bfloat16 result;
|
||||
result.value = a.value ^ 0x8000;
|
||||
return result;
|
||||
}
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator += (bfloat16& a, const bfloat16& b) {
|
||||
a = bfloat16(float(a) + float(b));
|
||||
return a;
|
||||
}
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator *= (bfloat16& a, const bfloat16& b) {
|
||||
a = bfloat16(float(a) * float(b));
|
||||
return a;
|
||||
}
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator -= (bfloat16& a, const bfloat16& b) {
|
||||
a = bfloat16(float(a) - float(b));
|
||||
return a;
|
||||
}
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator /= (bfloat16& a, const bfloat16& b) {
|
||||
a = bfloat16(float(a) / float(b));
|
||||
return a;
|
||||
}
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator++(bfloat16& a) {
|
||||
a += bfloat16(1);
|
||||
return a;
|
||||
}
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator--(bfloat16& a) {
|
||||
a -= bfloat16(1);
|
||||
return a;
|
||||
}
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator++(bfloat16& a, int) {
|
||||
bfloat16 original_value = a;
|
||||
++a;
|
||||
return original_value;
|
||||
}
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator--(bfloat16& a, int) {
|
||||
bfloat16 original_value = a;
|
||||
--a;
|
||||
return original_value;
|
||||
}
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator == (const bfloat16& a, const bfloat16& b) {
|
||||
return numext::equal_strict(float(a),float(b));
|
||||
}
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator != (const bfloat16& a, const bfloat16& b) {
|
||||
return numext::not_equal_strict(float(a), float(b));
|
||||
}
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator < (const bfloat16& a, const bfloat16& b) {
|
||||
return float(a) < float(b);
|
||||
}
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator <= (const bfloat16& a, const bfloat16& b) {
|
||||
return float(a) <= float(b);
|
||||
}
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator > (const bfloat16& a, const bfloat16& b) {
|
||||
return float(a) > float(b);
|
||||
}
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator >= (const bfloat16& a, const bfloat16& b) {
|
||||
return float(a) >= float(b);
|
||||
}
|
||||
|
||||
#if EIGEN_COMP_CLANG && defined(EIGEN_CUDACC)
|
||||
#pragma pop_macro("EIGEN_DEVICE_FUNC")
|
||||
#endif
|
||||
#endif // Emulate support for bfloat16 floats
|
||||
|
||||
// Division by an index. Do it in full float precision to avoid accuracy
|
||||
// issues in converting the denominator to bfloat16.
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator / (const bfloat16& a, Index b) {
|
||||
return bfloat16(static_cast<float>(a) / static_cast<float>(b));
|
||||
}
|
||||
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw truncate_to_bfloat16(const float v) {
|
||||
__bfloat16_raw output;
|
||||
if (Eigen::numext::isnan EIGEN_NOT_A_MACRO(v)) {
|
||||
output.value = 0x7FC0;
|
||||
return output;
|
||||
} else if (std::fabs(v) < std::numeric_limits<float>::min EIGEN_NOT_A_MACRO()) {
|
||||
// Flush denormal to +/- 0.
|
||||
output.value = std::signbit(v) ? 0x8000 : 0;
|
||||
return output;
|
||||
}
|
||||
const uint16_t* p = reinterpret_cast<const uint16_t*>(&v);
|
||||
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
|
||||
output.value = p[0];
|
||||
#else
|
||||
output.value = p[1];
|
||||
#endif
|
||||
return output;
|
||||
}
|
||||
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw raw_uint16_to_bfloat16(unsigned short value) {
|
||||
__bfloat16_raw h;
|
||||
h.value = value;
|
||||
return h;
|
||||
}
|
||||
|
||||
union float32_bits {
|
||||
unsigned int u;
|
||||
float f;
|
||||
};
|
||||
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne(float ff) {
|
||||
#if (defined(EIGEN_HAS_CUDA_BF16) && defined(EIGEN_HAS_HIP_BF16))
|
||||
// Nothing to do here
|
||||
#else
|
||||
unsigned int input;
|
||||
float32_bits f;
|
||||
f.f = ff;
|
||||
input = f.u;
|
||||
__bfloat16_raw output;
|
||||
|
||||
if (Eigen::numext::isnan EIGEN_NOT_A_MACRO(ff)) {
|
||||
// If the value is a NaN, squash it to a qNaN with msb of fraction set,
|
||||
// this makes sure after truncation we don't end up with an inf.
|
||||
//
|
||||
// qNaN magic: All exponent bits set + most significant bit of fraction
|
||||
// set.
|
||||
output.value = 0x7fc0;
|
||||
} else if (std::fabs(ff) < std::numeric_limits<float>::min EIGEN_NOT_A_MACRO()) {
|
||||
// Flush denormal to +/- 0.0
|
||||
output.value = std::signbit(ff) ? 0x8000 : 0;
|
||||
} else {
|
||||
// Fast rounding algorithm that rounds a half value to nearest even. This
|
||||
// reduces expected error when we convert a large number of floats. Here
|
||||
// is how it works:
|
||||
//
|
||||
// Definitions:
|
||||
// To convert a float 32 to bfloat16, a float 32 can be viewed as 32 bits
|
||||
// with the following tags:
|
||||
//
|
||||
// Sign | Exp (8 bits) | Frac (23 bits)
|
||||
// S EEEEEEEE FFFFFFLRTTTTTTTTTTTTTTT
|
||||
//
|
||||
// S: Sign bit.
|
||||
// E: Exponent bits.
|
||||
// F: First 6 bits of fraction.
|
||||
// L: Least significant bit of resulting bfloat16 if we truncate away the
|
||||
// rest of the float32. This is also the 7th bit of fraction
|
||||
// R: Rounding bit, 8th bit of fraction.
|
||||
// T: Sticky bits, rest of fraction, 15 bits.
|
||||
//
|
||||
// To round half to nearest even, there are 3 cases where we want to round
|
||||
// down (simply truncate the result of the bits away, which consists of
|
||||
// rounding bit and sticky bits) and two cases where we want to round up
|
||||
// (truncate then add one to the result).
|
||||
//
|
||||
// The fast converting algorithm simply adds lsb (L) to 0x7fff (15 bits of
|
||||
// 1s) as the rounding bias, adds the rounding bias to the input, then
|
||||
// truncates the last 16 bits away.
|
||||
//
|
||||
// To understand how it works, we can analyze this algorithm case by case:
|
||||
//
|
||||
// 1. L = 0, R = 0:
|
||||
// Expect: round down, this is less than half value.
|
||||
//
|
||||
// Algorithm:
|
||||
// - Rounding bias: 0x7fff + 0 = 0x7fff
|
||||
// - Adding rounding bias to input may create any carry, depending on
|
||||
// whether there is any value set to 1 in T bits.
|
||||
// - R may be set to 1 if there is a carry.
|
||||
// - L remains 0.
|
||||
// - Note that this case also handles Inf and -Inf, where all fraction
|
||||
// bits, including L, R and Ts are all 0. The output remains Inf after
|
||||
// this algorithm.
|
||||
//
|
||||
// 2. L = 1, R = 0:
|
||||
// Expect: round down, this is less than half value.
|
||||
//
|
||||
// Algorithm:
|
||||
// - Rounding bias: 0x7fff + 1 = 0x8000
|
||||
// - Adding rounding bias to input doesn't change sticky bits but
|
||||
// adds 1 to rounding bit.
|
||||
// - L remains 1.
|
||||
//
|
||||
// 3. L = 0, R = 1, all of T are 0:
|
||||
// Expect: round down, this is exactly at half, the result is already
|
||||
// even (L=0).
|
||||
//
|
||||
// Algorithm:
|
||||
// - Rounding bias: 0x7fff + 0 = 0x7fff
|
||||
// - Adding rounding bias to input sets all sticky bits to 1, but
|
||||
// doesn't create a carry.
|
||||
// - R remains 1.
|
||||
// - L remains 0.
|
||||
//
|
||||
// 4. L = 1, R = 1:
|
||||
// Expect: round up, this is exactly at half, the result needs to be
|
||||
// round to the next even number.
|
||||
//
|
||||
// Algorithm:
|
||||
// - Rounding bias: 0x7fff + 1 = 0x8000
|
||||
// - Adding rounding bias to input doesn't change sticky bits, but
|
||||
// creates a carry from rounding bit.
|
||||
// - The carry sets L to 0, creates another carry bit and propagate
|
||||
// forward to F bits.
|
||||
// - If all the F bits are 1, a carry then propagates to the exponent
|
||||
// bits, which then creates the minimum value with the next exponent
|
||||
// value. Note that we won't have the case where exponents are all 1,
|
||||
// since that's either a NaN (handled in the other if condition) or inf
|
||||
// (handled in case 1).
|
||||
//
|
||||
// 5. L = 0, R = 1, any of T is 1:
|
||||
// Expect: round up, this is greater than half.
|
||||
//
|
||||
// Algorithm:
|
||||
// - Rounding bias: 0x7fff + 0 = 0x7fff
|
||||
// - Adding rounding bias to input creates a carry from sticky bits,
|
||||
// sets rounding bit to 0, then create another carry.
|
||||
// - The second carry sets L to 1.
|
||||
//
|
||||
// Examples:
|
||||
//
|
||||
// Exact half value that is already even:
|
||||
// Input:
|
||||
// Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
|
||||
// S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
|
||||
// 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 1000000000000000
|
||||
//
|
||||
// This falls into case 3. We truncate the rest of 16 bits and no
|
||||
// carry is created into F and L:
|
||||
//
|
||||
// Output:
|
||||
// Sign | Exp (8 bit) | Frac (first 7 bit)
|
||||
// S E E E E E E E E F F F F F F L
|
||||
// 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0
|
||||
//
|
||||
// Exact half value, round to next even number:
|
||||
// Input:
|
||||
// Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
|
||||
// S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
|
||||
// 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1000000000000000
|
||||
//
|
||||
// This falls into case 4. We create a carry from R and T,
|
||||
// which then propagates into L and F:
|
||||
//
|
||||
// Output:
|
||||
// Sign | Exp (8 bit) | Frac (first 7 bit)
|
||||
// S E E E E E E E E F F F F F F L
|
||||
// 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0
|
||||
//
|
||||
//
|
||||
// Max denormal value round to min normal value:
|
||||
// Input:
|
||||
// Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
|
||||
// S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
|
||||
// 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1111111111111111
|
||||
//
|
||||
// This falls into case 4. We create a carry from R and T,
|
||||
// propagate into L and F, which then propagates into exponent
|
||||
// bits:
|
||||
//
|
||||
// Output:
|
||||
// Sign | Exp (8 bit) | Frac (first 7 bit)
|
||||
// S E E E E E E E E F F F F F F L
|
||||
// 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0
|
||||
//
|
||||
// Max normal value round to Inf:
|
||||
// Input:
|
||||
// Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
|
||||
// S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
|
||||
// 0 1 1 1 1 1 1 1 0 1 1 1 1 1 1 1 1111111111111111
|
||||
//
|
||||
// This falls into case 4. We create a carry from R and T,
|
||||
// propagate into L and F, which then propagates into exponent
|
||||
// bits:
|
||||
//
|
||||
// Sign | Exp (8 bit) | Frac (first 7 bit)
|
||||
// S E E E E E E E E F F F F F F L
|
||||
// 0 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0
|
||||
//
|
||||
//
|
||||
// Least significant bit of resulting bfloat.
|
||||
unsigned int lsb = (input >> 16) & 1;
|
||||
unsigned int rounding_bias = 0x7fff + lsb;
|
||||
input += rounding_bias;
|
||||
output.value = static_cast<unsigned short>(input >> 16);
|
||||
}
|
||||
return output;
|
||||
#endif
|
||||
}
|
||||
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float bfloat16_to_float(__bfloat16_raw h) {
|
||||
float result = 0;
|
||||
unsigned short* q = reinterpret_cast<unsigned short*>(&result);
|
||||
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
|
||||
q[0] = h.value;
|
||||
#else
|
||||
q[1] = h.value;
|
||||
#endif
|
||||
return result;
|
||||
}
|
||||
// --- standard functions ---
|
||||
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool (isinf)(const bfloat16& a) {
|
||||
return std::isinf EIGEN_NOT_A_MACRO(float(a));
|
||||
}
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool (isnan)(const bfloat16& a) {
|
||||
return std::isnan EIGEN_NOT_A_MACRO(float(a));
|
||||
}
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool (isfinite)(const bfloat16& a) {
|
||||
return !(isinf EIGEN_NOT_A_MACRO (a)) && !(isnan EIGEN_NOT_A_MACRO (a));
|
||||
}
|
||||
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 abs(const bfloat16& a) {
|
||||
bfloat16 result;
|
||||
result.value = a.value & 0x7FFF;
|
||||
return result;
|
||||
}
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 exp(const bfloat16& a) {
|
||||
return bfloat16(::expf(float(a)));
|
||||
}
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 expm1(const bfloat16& a) {
|
||||
return bfloat16(numext::expm1(float(a)));
|
||||
}
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log(const bfloat16& a) {
|
||||
return bfloat16(::logf(float(a)));
|
||||
}
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log1p(const bfloat16& a) {
|
||||
return bfloat16(numext::log1p(float(a)));
|
||||
}
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log10(const bfloat16& a) {
|
||||
return bfloat16(::log10f(float(a)));
|
||||
}
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 sqrt(const bfloat16& a) {
|
||||
return bfloat16(::sqrtf(float(a)));
|
||||
}
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 pow(const bfloat16& a, const bfloat16& b) {
|
||||
return bfloat16(::powf(float(a), float(b)));
|
||||
}
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 sin(const bfloat16& a) {
|
||||
return bfloat16(::sinf(float(a)));
|
||||
}
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 cos(const bfloat16& a) {
|
||||
return bfloat16(::cosf(float(a)));
|
||||
}
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 tan(const bfloat16& a) {
|
||||
return bfloat16(::tanf(float(a)));
|
||||
}
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 asin(const bfloat16& a) {
|
||||
return bfloat16(::asinf(float(a)));
|
||||
}
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 acos(const bfloat16& a) {
|
||||
return bfloat16(::acosf(float(a)));
|
||||
}
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 atan(const bfloat16& a) {
|
||||
return bfloat16(::atanf(float(a)));
|
||||
}
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 sinh(const bfloat16& a) {
|
||||
return bfloat16(::sinhf(float(a)));
|
||||
}
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 cosh(const bfloat16& a) {
|
||||
return bfloat16(::coshf(float(a)));
|
||||
}
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 tanh(const bfloat16& a) {
|
||||
return bfloat16(::tanhf(float(a)));
|
||||
}
|
||||
#if EIGEN_HAS_CXX11_MATH
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 asinh(const bfloat16& a) {
|
||||
return bfloat16(::asinh(float(a)));
|
||||
}
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 acosh(const bfloat16& a) {
|
||||
return bfloat16(::acosh(float(a)));
|
||||
}
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 atanh(const bfloat16& a) {
|
||||
return bfloat16(::atanh(float(a)));
|
||||
}
|
||||
#endif
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 floor(const bfloat16& a) {
|
||||
return bfloat16(::floorf(float(a)));
|
||||
}
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 ceil(const bfloat16& a) {
|
||||
return bfloat16(::ceilf(float(a)));
|
||||
}
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 fmod(const bfloat16& a, const bfloat16& b) {
|
||||
return bfloat16(::fmodf(float(a), float(b)));
|
||||
}
|
||||
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 (min)(const bfloat16& a, const bfloat16& b) {
|
||||
const float f1 = static_cast<float>(a);
|
||||
const float f2 = static_cast<float>(b);
|
||||
return f2 < f1 ? b : a;
|
||||
}
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 (max)(const bfloat16& a, const bfloat16& b) {
|
||||
const float f1 = static_cast<float>(a);
|
||||
const float f2 = static_cast<float>(b);
|
||||
return f1 < f2 ? b : a;
|
||||
}
|
||||
|
||||
#ifndef EIGEN_NO_IO
|
||||
EIGEN_ALWAYS_INLINE std::ostream& operator << (std::ostream& os, const bfloat16& v) {
|
||||
os << static_cast<float>(v);
|
||||
return os;
|
||||
}
|
||||
#endif
|
||||
|
||||
} // end namespace bfloat16_impl
|
||||
|
||||
namespace internal {
|
||||
|
||||
template<>
|
||||
struct random_default_impl<bfloat16, false, false>
|
||||
{
|
||||
static inline bfloat16 run(const bfloat16& x, const bfloat16& y)
|
||||
{
|
||||
return x + (y-x) * bfloat16(float(std::rand()) / float(RAND_MAX));
|
||||
}
|
||||
static inline bfloat16 run()
|
||||
{
|
||||
return run(bfloat16(-1.f), bfloat16(1.f));
|
||||
}
|
||||
};
|
||||
|
||||
template<> struct is_arithmetic<bfloat16> { enum { value = true }; };
|
||||
|
||||
} // end namespace internal
|
||||
|
||||
template<> struct NumTraits<Eigen::bfloat16>
|
||||
: GenericNumTraits<Eigen::bfloat16>
|
||||
{
|
||||
enum {
|
||||
IsSigned = true,
|
||||
IsInteger = false,
|
||||
IsComplex = false,
|
||||
RequireInitialization = false
|
||||
};
|
||||
|
||||
EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Eigen::bfloat16 epsilon() {
|
||||
return bfloat16_impl::raw_uint16_to_bfloat16(0x3c00);
|
||||
}
|
||||
EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Eigen::bfloat16 dummy_precision() { return Eigen::bfloat16(5e-2f); }
|
||||
EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Eigen::bfloat16 highest() {
|
||||
return bfloat16_impl::raw_uint16_to_bfloat16(0x7F7F);
|
||||
}
|
||||
EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Eigen::bfloat16 lowest() {
|
||||
return bfloat16_impl::raw_uint16_to_bfloat16(0xFF7F);
|
||||
}
|
||||
EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Eigen::bfloat16 infinity() {
|
||||
return bfloat16_impl::raw_uint16_to_bfloat16(0x7f80);
|
||||
}
|
||||
EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Eigen::bfloat16 quiet_NaN() {
|
||||
return bfloat16_impl::raw_uint16_to_bfloat16(0x7fc0);
|
||||
}
|
||||
};
|
||||
|
||||
} // end namespace Eigen
|
||||
|
||||
namespace std {
|
||||
|
||||
#if __cplusplus > 199711L
|
||||
template <>
|
||||
struct hash<Eigen::bfloat16> {
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::size_t operator()(const Eigen::bfloat16& a) const {
|
||||
return hash<float>()(static_cast<float>(a));
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
} // end namespace std
|
||||
|
||||
|
||||
namespace Eigen {
|
||||
namespace numext {
|
||||
|
||||
template<>
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
|
||||
bool (isnan)(const Eigen::bfloat16& h) {
|
||||
return (bfloat16_impl::isnan)(h);
|
||||
}
|
||||
|
||||
template<>
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
|
||||
bool (isinf)(const Eigen::bfloat16& h) {
|
||||
return (bfloat16_impl::isinf)(h);
|
||||
}
|
||||
|
||||
template<>
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
|
||||
bool (isfinite)(const Eigen::bfloat16& h) {
|
||||
return (bfloat16_impl::isfinite)(h);
|
||||
}
|
||||
|
||||
} // namespace Eigen
|
||||
} // namespace numext
|
||||
|
||||
#endif // EIGEN_BFLOAT16_H
|
@ -71,6 +71,49 @@ template<>
|
||||
struct functor_traits<scalar_cast_op<Eigen::half, float> >
|
||||
{ enum { Cost = NumTraits<float>::AddCost, PacketAccess = false }; };
|
||||
|
||||
|
||||
template<>
|
||||
struct scalar_cast_op<float, Eigen::bfloat16> {
|
||||
EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op)
|
||||
typedef Eigen::bfloat16 result_type;
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::bfloat16 operator() (const float& a) const {
|
||||
return Eigen::bfloat16(a);
|
||||
}
|
||||
};
|
||||
|
||||
template<>
|
||||
struct functor_traits<scalar_cast_op<float, Eigen::bfloat16> >
|
||||
{ enum { Cost = NumTraits<float>::AddCost, PacketAccess = false }; };
|
||||
|
||||
|
||||
template<>
|
||||
struct scalar_cast_op<int, Eigen::bfloat16> {
|
||||
EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op)
|
||||
typedef Eigen::bfloat16 result_type;
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::bfloat16 operator() (const int& a) const {
|
||||
return Eigen::bfloat16(static_cast<float>(a));
|
||||
}
|
||||
};
|
||||
|
||||
template<>
|
||||
struct functor_traits<scalar_cast_op<int, Eigen::bfloat16> >
|
||||
{ enum { Cost = NumTraits<float>::AddCost, PacketAccess = false }; };
|
||||
|
||||
|
||||
template<>
|
||||
struct scalar_cast_op<Eigen::bfloat16, float> {
|
||||
EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op)
|
||||
typedef float result_type;
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float operator() (const Eigen::bfloat16& a) const {
|
||||
return static_cast<float>(a);
|
||||
}
|
||||
};
|
||||
|
||||
template<>
|
||||
struct functor_traits<scalar_cast_op<Eigen::bfloat16, float> >
|
||||
{ enum { Cost = NumTraits<float>::AddCost, PacketAccess = false }; };
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -288,6 +288,9 @@
|
||||
#ifdef __AVX512ER__
|
||||
#define EIGEN_VECTORIZE_AVX512ER
|
||||
#endif
|
||||
#ifdef __AVX512BF16__
|
||||
#define EIGEN_VECTORIZE_AVX512BF16
|
||||
#endif
|
||||
#endif
|
||||
#endif
|
||||
|
||||
|
@ -286,6 +286,7 @@ ei_add_test(ctorleak)
|
||||
ei_add_test(mpl2only)
|
||||
ei_add_test(inplace_decomposition)
|
||||
ei_add_test(half_float)
|
||||
ei_add_test(bfloat16_float)
|
||||
ei_add_test(array_of_string)
|
||||
ei_add_test(num_dimensions)
|
||||
ei_add_test(stl_iterators)
|
||||
|
399
test/bfloat16_float.cpp
Normal file
399
test/bfloat16_float.cpp
Normal file
@ -0,0 +1,399 @@
|
||||
// This file is part of Eigen, a lightweight C++ template library
|
||||
// for linear algebra.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla
|
||||
// Public License v. 2.0. If a copy of the MPL was not distributed
|
||||
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
#include <sstream>
|
||||
#include <memory>
|
||||
#include <math.h>
|
||||
|
||||
#include "main.h"
|
||||
|
||||
#include <Eigen/src/Core/arch/Default/BFloat16.h>
|
||||
|
||||
// Make sure it's possible to forward declare Eigen::bfloat16
|
||||
namespace Eigen {
|
||||
struct bfloat16;
|
||||
}
|
||||
|
||||
using Eigen::bfloat16;
|
||||
|
||||
float BinaryToFloat(uint32_t sign, uint32_t exponent, uint32_t high_mantissa,
|
||||
uint32_t low_mantissa) {
|
||||
float dest;
|
||||
uint32_t src = (sign << 31) + (exponent << 23) + (high_mantissa << 16) + low_mantissa;
|
||||
memcpy(static_cast<void*>(&dest),
|
||||
static_cast<const void*>(&src), sizeof(dest));
|
||||
return dest;
|
||||
}
|
||||
|
||||
void test_truncate(float input, float expected_truncation, float expected_rounding){
|
||||
bfloat16 truncated = Eigen::bfloat16_impl::truncate_to_bfloat16(input);
|
||||
bfloat16 rounded = Eigen::bfloat16_impl::float_to_bfloat16_rtne(input);
|
||||
if ((numext::isnan)(input)){
|
||||
VERIFY((numext::isnan)(static_cast<float>(truncated)) || (numext::isinf)(static_cast<float>(truncated)));
|
||||
VERIFY((numext::isnan)(static_cast<float>(rounded)) || (numext::isinf)(static_cast<float>(rounded)));
|
||||
return;
|
||||
}
|
||||
VERIFY_IS_EQUAL(expected_truncation, static_cast<float>(truncated));
|
||||
VERIFY_IS_EQUAL(expected_rounding, static_cast<float>(rounded));
|
||||
}
|
||||
|
||||
void test_conversion()
|
||||
{
|
||||
using Eigen::bfloat16_impl::__bfloat16_raw;
|
||||
|
||||
// Conversion from float.
|
||||
VERIFY_IS_EQUAL(bfloat16(1.0f).value, 0x3f80);
|
||||
VERIFY_IS_EQUAL(bfloat16(0.5f).value, 0x3f00);
|
||||
VERIFY_IS_EQUAL(bfloat16(0.33333f).value, 0x3eab);
|
||||
VERIFY_IS_EQUAL(bfloat16(3.38e38f).value, 0x7f7e);
|
||||
VERIFY_IS_EQUAL(bfloat16(3.40e38f).value, 0x7f80); // Becomes infinity.
|
||||
|
||||
// Verify round-to-nearest-even behavior.
|
||||
float val1 = static_cast<float>(bfloat16(__bfloat16_raw(0x3c00)));
|
||||
float val2 = static_cast<float>(bfloat16(__bfloat16_raw(0x3c01)));
|
||||
float val3 = static_cast<float>(bfloat16(__bfloat16_raw(0x3c02)));
|
||||
VERIFY_IS_EQUAL(bfloat16(0.5f * (val1 + val2)).value, 0x3c00);
|
||||
VERIFY_IS_EQUAL(bfloat16(0.5f * (val2 + val3)).value, 0x3c02);
|
||||
|
||||
// Conversion from int.
|
||||
VERIFY_IS_EQUAL(bfloat16(-1).value, 0xbf80);
|
||||
VERIFY_IS_EQUAL(bfloat16(0).value, 0x0000);
|
||||
VERIFY_IS_EQUAL(bfloat16(1).value, 0x3f80);
|
||||
VERIFY_IS_EQUAL(bfloat16(2).value, 0x4000);
|
||||
VERIFY_IS_EQUAL(bfloat16(3).value, 0x4040);
|
||||
VERIFY_IS_EQUAL(bfloat16(12).value, 0x4140);
|
||||
|
||||
// Conversion from bool.
|
||||
VERIFY_IS_EQUAL(bfloat16(false).value, 0x0000);
|
||||
VERIFY_IS_EQUAL(bfloat16(true).value, 0x3f80);
|
||||
|
||||
// Conversion to float.
|
||||
VERIFY_IS_EQUAL(static_cast<float>(bfloat16(__bfloat16_raw(0x0000))), 0.0f);
|
||||
VERIFY_IS_EQUAL(static_cast<float>(bfloat16(__bfloat16_raw(0x3f80))), 1.0f);
|
||||
|
||||
// Zero representations
|
||||
VERIFY_IS_EQUAL(bfloat16(0.0f), bfloat16(0.0f));
|
||||
VERIFY_IS_EQUAL(bfloat16(-0.0f), bfloat16(0.0f));
|
||||
VERIFY_IS_EQUAL(bfloat16(-0.0f), bfloat16(-0.0f));
|
||||
VERIFY_IS_EQUAL(bfloat16(0.0f).value, 0x0000);
|
||||
VERIFY_IS_EQUAL(bfloat16(-0.0f).value, 0x8000);
|
||||
|
||||
// Flush denormals to zero
|
||||
for (float denorm = -std::numeric_limits<float>::denorm_min();
|
||||
denorm < std::numeric_limits<float>::denorm_min();
|
||||
denorm = nextafterf(denorm, 1.0f)) {
|
||||
bfloat16 bf_trunc = Eigen::bfloat16_impl::truncate_to_bfloat16(denorm);
|
||||
VERIFY_IS_EQUAL(static_cast<float>(bf_trunc), 0.0f);
|
||||
if (std::signbit(denorm)) {
|
||||
VERIFY_IS_EQUAL(bf_trunc.value, 0x8000);
|
||||
} else {
|
||||
VERIFY_IS_EQUAL(bf_trunc.value, 0x0000);
|
||||
}
|
||||
bfloat16 bf_round = Eigen::bfloat16_impl::float_to_bfloat16_rtne(denorm);
|
||||
VERIFY_IS_EQUAL(static_cast<float>(bf_round), 0.0f);
|
||||
if (std::signbit(denorm)) {
|
||||
VERIFY_IS_EQUAL(bf_round.value, 0x8000);
|
||||
} else {
|
||||
VERIFY_IS_EQUAL(bf_round.value, 0x0000);
|
||||
}
|
||||
}
|
||||
|
||||
// Default is zero
|
||||
VERIFY_IS_EQUAL(static_cast<float>(bfloat16()), 0.0f);
|
||||
|
||||
// Representable floats round trip via bfloat16
|
||||
VERIFY_IS_EQUAL(static_cast<float>(static_cast<bfloat16>(-std::numeric_limits<float>::infinity())), -std::numeric_limits<float>::infinity());
|
||||
VERIFY_IS_EQUAL(static_cast<float>(static_cast<bfloat16>(std::numeric_limits<float>::infinity())), std::numeric_limits<float>::infinity());
|
||||
VERIFY_IS_EQUAL(static_cast<float>(static_cast<bfloat16>(-1.0f)), -1.0f);
|
||||
VERIFY_IS_EQUAL(static_cast<float>(static_cast<bfloat16>(-0.5f)), -0.5f);
|
||||
VERIFY_IS_EQUAL(static_cast<float>(static_cast<bfloat16>(-0.0f)), -0.0f);
|
||||
VERIFY_IS_EQUAL(static_cast<float>(static_cast<bfloat16>(1.0f)), 1.0f);
|
||||
VERIFY_IS_EQUAL(static_cast<float>(static_cast<bfloat16>(0.5f)), 0.5f);
|
||||
VERIFY_IS_EQUAL(static_cast<float>(static_cast<bfloat16>(0.0f)), 0.0f);
|
||||
|
||||
// Truncate test
|
||||
test_truncate(
|
||||
BinaryToFloat(0, 0x80, 0x48, 0xf5c3),
|
||||
BinaryToFloat(0, 0x80, 0x48, 0x0000),
|
||||
BinaryToFloat(0, 0x80, 0x49, 0x0000));
|
||||
test_truncate(
|
||||
BinaryToFloat(1, 0x80, 0x48, 0xf5c3),
|
||||
BinaryToFloat(1, 0x80, 0x48, 0x0000),
|
||||
BinaryToFloat(1, 0x80, 0x49, 0x0000));
|
||||
test_truncate(
|
||||
BinaryToFloat(0, 0x80, 0x48, 0x8000),
|
||||
BinaryToFloat(0, 0x80, 0x48, 0x0000),
|
||||
BinaryToFloat(0, 0x80, 0x48, 0x0000));
|
||||
test_truncate(
|
||||
BinaryToFloat(0, 0xff, 0x00, 0x0001),
|
||||
BinaryToFloat(0, 0xff, 0x40, 0x0000),
|
||||
BinaryToFloat(0, 0xff, 0x40, 0x0000));
|
||||
test_truncate(
|
||||
BinaryToFloat(0, 0xff, 0x7f, 0xffff),
|
||||
BinaryToFloat(0, 0xff, 0x40, 0x0000),
|
||||
BinaryToFloat(0, 0xff, 0x40, 0x0000));
|
||||
test_truncate(
|
||||
BinaryToFloat(1, 0x80, 0x48, 0xc000),
|
||||
BinaryToFloat(1, 0x80, 0x48, 0x0000),
|
||||
BinaryToFloat(1, 0x80, 0x49, 0x0000));
|
||||
test_truncate(
|
||||
BinaryToFloat(0, 0x80, 0x48, 0x0000),
|
||||
BinaryToFloat(0, 0x80, 0x48, 0x0000),
|
||||
BinaryToFloat(0, 0x80, 0x48, 0x0000));
|
||||
test_truncate(
|
||||
BinaryToFloat(0, 0x80, 0x48, 0x4000),
|
||||
BinaryToFloat(0, 0x80, 0x48, 0x0000),
|
||||
BinaryToFloat(0, 0x80, 0x48, 0x0000));
|
||||
test_truncate(
|
||||
BinaryToFloat(0, 0x80, 0x48, 0x8000),
|
||||
BinaryToFloat(0, 0x80, 0x48, 0x0000),
|
||||
BinaryToFloat(0, 0x80, 0x48, 0x0000));
|
||||
test_truncate(
|
||||
BinaryToFloat(0, 0x00, 0x48, 0x8000),
|
||||
BinaryToFloat(0, 0x00, 0x00, 0x0000),
|
||||
BinaryToFloat(0, 0x00, 0x00, 0x0000));
|
||||
test_truncate(
|
||||
BinaryToFloat(0, 0x00, 0x7f, 0xc000),
|
||||
BinaryToFloat(0, 0x00, 0x00, 0x0000),
|
||||
BinaryToFloat(0, 0x00, 0x00, 0x0000));
|
||||
|
||||
// Conversion
|
||||
Array<float,1,100> a;
|
||||
for (int i = 0; i < 100; i++) a(i) = i + 1.25;
|
||||
Array<bfloat16,1,100> b = a.cast<bfloat16>();
|
||||
Array<float,1,100> c = b.cast<float>();
|
||||
for (int i = 0; i < 100; ++i) {
|
||||
VERIFY_LE(numext::abs(c(i) - a(i)), a(i) / 128);
|
||||
}
|
||||
|
||||
// Epsilon
|
||||
VERIFY_LE(1.0f, static_cast<float>((std::numeric_limits<bfloat16>::epsilon)() + bfloat16(1.0f)));
|
||||
VERIFY_IS_EQUAL(1.0f, static_cast<float>((std::numeric_limits<bfloat16>::epsilon)() / bfloat16(2.0f) + bfloat16(1.0f)));
|
||||
|
||||
// Negate
|
||||
VERIFY_IS_EQUAL(static_cast<float>(-bfloat16(3.0f)), -3.0f);
|
||||
VERIFY_IS_EQUAL(static_cast<float>(-bfloat16(-4.5f)), 4.5f);
|
||||
|
||||
|
||||
#if !EIGEN_COMP_MSVC
|
||||
// Visual Studio errors out on divisions by 0
|
||||
VERIFY((numext::isnan)(static_cast<float>(bfloat16(0.0 / 0.0))));
|
||||
VERIFY((numext::isinf)(static_cast<float>(bfloat16(1.0 / 0.0))));
|
||||
VERIFY((numext::isinf)(static_cast<float>(bfloat16(-1.0 / 0.0))));
|
||||
|
||||
// Visual Studio errors out on divisions by 0
|
||||
VERIFY((numext::isnan)(bfloat16(0.0 / 0.0)));
|
||||
VERIFY((numext::isinf)(bfloat16(1.0 / 0.0)));
|
||||
VERIFY((numext::isinf)(bfloat16(-1.0 / 0.0)));
|
||||
#endif
|
||||
|
||||
// NaNs and infinities.
|
||||
VERIFY(!(numext::isinf)(static_cast<float>(bfloat16(3.38e38f)))); // Largest finite number.
|
||||
VERIFY(!(numext::isnan)(static_cast<float>(bfloat16(0.0f))));
|
||||
VERIFY((numext::isinf)(static_cast<float>(bfloat16(__bfloat16_raw(0xff80)))));
|
||||
VERIFY((numext::isnan)(static_cast<float>(bfloat16(__bfloat16_raw(0xffc0)))));
|
||||
VERIFY((numext::isinf)(static_cast<float>(bfloat16(__bfloat16_raw(0x7f80)))));
|
||||
VERIFY((numext::isnan)(static_cast<float>(bfloat16(__bfloat16_raw(0x7fc0)))));
|
||||
|
||||
// Exactly same checks as above, just directly on the bfloat16 representation.
|
||||
VERIFY(!(numext::isinf)(bfloat16(__bfloat16_raw(0x7bff))));
|
||||
VERIFY(!(numext::isnan)(bfloat16(__bfloat16_raw(0x0000))));
|
||||
VERIFY((numext::isinf)(bfloat16(__bfloat16_raw(0xff80))));
|
||||
VERIFY((numext::isnan)(bfloat16(__bfloat16_raw(0xffc0))));
|
||||
VERIFY((numext::isinf)(bfloat16(__bfloat16_raw(0x7f80))));
|
||||
VERIFY((numext::isnan)(bfloat16(__bfloat16_raw(0x7fc0))));
|
||||
}
|
||||
|
||||
void test_numtraits()
|
||||
{
|
||||
std::cout << "epsilon = " << NumTraits<bfloat16>::epsilon() << " (0x" << std::hex << NumTraits<bfloat16>::epsilon().value << ")" << std::endl;
|
||||
std::cout << "highest = " << NumTraits<bfloat16>::highest() << " (0x" << std::hex << NumTraits<bfloat16>::highest().value << ")" << std::endl;
|
||||
std::cout << "lowest = " << NumTraits<bfloat16>::lowest() << " (0x" << std::hex << NumTraits<bfloat16>::lowest().value << ")" << std::endl;
|
||||
std::cout << "min = " << (std::numeric_limits<bfloat16>::min)() << " (0x" << std::hex << (std::numeric_limits<bfloat16>::min)().value << ")" << std::endl;
|
||||
std::cout << "denorm min = " << (std::numeric_limits<bfloat16>::denorm_min)() << " (0x" << std::hex << (std::numeric_limits<bfloat16>::denorm_min)().value << ")" << std::endl;
|
||||
std::cout << "infinity = " << NumTraits<bfloat16>::infinity() << " (0x" << std::hex << NumTraits<bfloat16>::infinity().value << ")" << std::endl;
|
||||
std::cout << "quiet nan = " << NumTraits<bfloat16>::quiet_NaN() << " (0x" << std::hex << NumTraits<bfloat16>::quiet_NaN().value << ")" << std::endl;
|
||||
std::cout << "signaling nan = " << std::numeric_limits<bfloat16>::signaling_NaN() << " (0x" << std::hex << std::numeric_limits<bfloat16>::signaling_NaN().value << ")" << std::endl;
|
||||
|
||||
VERIFY(NumTraits<bfloat16>::IsSigned);
|
||||
|
||||
VERIFY_IS_EQUAL( std::numeric_limits<bfloat16>::infinity().value, bfloat16(std::numeric_limits<float>::infinity()).value );
|
||||
VERIFY_IS_EQUAL( std::numeric_limits<bfloat16>::quiet_NaN().value, bfloat16(std::numeric_limits<float>::quiet_NaN()).value );
|
||||
VERIFY( (std::numeric_limits<bfloat16>::min)() > bfloat16(0.f) );
|
||||
VERIFY( (std::numeric_limits<bfloat16>::denorm_min)() > bfloat16(0.f) );
|
||||
VERIFY_IS_EQUAL( (std::numeric_limits<bfloat16>::denorm_min)()/bfloat16(2), bfloat16(0.f) );
|
||||
}
|
||||
|
||||
void test_arithmetic()
|
||||
{
|
||||
VERIFY_IS_EQUAL(static_cast<float>(bfloat16(2) + bfloat16(2)), 4);
|
||||
VERIFY_IS_EQUAL(static_cast<float>(bfloat16(2) + bfloat16(-2)), 0);
|
||||
VERIFY_IS_APPROX(static_cast<float>(bfloat16(0.33333f) + bfloat16(0.66667f)), 1.0f);
|
||||
VERIFY_IS_EQUAL(static_cast<float>(bfloat16(2.0f) * bfloat16(-5.5f)), -11.0f);
|
||||
VERIFY_IS_APPROX(static_cast<float>(bfloat16(1.0f) / bfloat16(3.0f)), 0.3339f);
|
||||
VERIFY_IS_EQUAL(static_cast<float>(-bfloat16(4096.0f)), -4096.0f);
|
||||
VERIFY_IS_EQUAL(static_cast<float>(-bfloat16(-4096.0f)), 4096.0f);
|
||||
}
|
||||
|
||||
void test_comparison()
|
||||
{
|
||||
VERIFY(bfloat16(1.0f) > bfloat16(0.5f));
|
||||
VERIFY(bfloat16(0.5f) < bfloat16(1.0f));
|
||||
VERIFY(!(bfloat16(1.0f) < bfloat16(0.5f)));
|
||||
VERIFY(!(bfloat16(0.5f) > bfloat16(1.0f)));
|
||||
|
||||
VERIFY(!(bfloat16(4.0f) > bfloat16(4.0f)));
|
||||
VERIFY(!(bfloat16(4.0f) < bfloat16(4.0f)));
|
||||
|
||||
VERIFY(!(bfloat16(0.0f) < bfloat16(-0.0f)));
|
||||
VERIFY(!(bfloat16(-0.0f) < bfloat16(0.0f)));
|
||||
VERIFY(!(bfloat16(0.0f) > bfloat16(-0.0f)));
|
||||
VERIFY(!(bfloat16(-0.0f) > bfloat16(0.0f)));
|
||||
|
||||
VERIFY(bfloat16(0.2f) > bfloat16(-1.0f));
|
||||
VERIFY(bfloat16(-1.0f) < bfloat16(0.2f));
|
||||
VERIFY(bfloat16(-16.0f) < bfloat16(-15.0f));
|
||||
|
||||
VERIFY(bfloat16(1.0f) == bfloat16(1.0f));
|
||||
VERIFY(bfloat16(1.0f) != bfloat16(2.0f));
|
||||
|
||||
// Comparisons with NaNs and infinities.
|
||||
#if !EIGEN_COMP_MSVC
|
||||
// Visual Studio errors out on divisions by 0
|
||||
VERIFY(!(bfloat16(0.0 / 0.0) == bfloat16(0.0 / 0.0)));
|
||||
VERIFY(bfloat16(0.0 / 0.0) != bfloat16(0.0 / 0.0));
|
||||
|
||||
VERIFY(!(bfloat16(1.0) == bfloat16(0.0 / 0.0)));
|
||||
VERIFY(!(bfloat16(1.0) < bfloat16(0.0 / 0.0)));
|
||||
VERIFY(!(bfloat16(1.0) > bfloat16(0.0 / 0.0)));
|
||||
VERIFY(bfloat16(1.0) != bfloat16(0.0 / 0.0));
|
||||
|
||||
VERIFY(bfloat16(1.0) < bfloat16(1.0 / 0.0));
|
||||
VERIFY(bfloat16(1.0) > bfloat16(-1.0 / 0.0));
|
||||
#endif
|
||||
}
|
||||
|
||||
void test_basic_functions()
|
||||
{
|
||||
VERIFY_IS_EQUAL(static_cast<float>(numext::abs(bfloat16(3.5f))), 3.5f);
|
||||
VERIFY_IS_EQUAL(static_cast<float>(abs(bfloat16(3.5f))), 3.5f);
|
||||
VERIFY_IS_EQUAL(static_cast<float>(numext::abs(bfloat16(-3.5f))), 3.5f);
|
||||
VERIFY_IS_EQUAL(static_cast<float>(abs(bfloat16(-3.5f))), 3.5f);
|
||||
|
||||
VERIFY_IS_EQUAL(static_cast<float>(numext::floor(bfloat16(3.5f))), 3.0f);
|
||||
VERIFY_IS_EQUAL(static_cast<float>(floor(bfloat16(3.5f))), 3.0f);
|
||||
VERIFY_IS_EQUAL(static_cast<float>(numext::floor(bfloat16(-3.5f))), -4.0f);
|
||||
VERIFY_IS_EQUAL(static_cast<float>(floor(bfloat16(-3.5f))), -4.0f);
|
||||
|
||||
VERIFY_IS_EQUAL(static_cast<float>(numext::ceil(bfloat16(3.5f))), 4.0f);
|
||||
VERIFY_IS_EQUAL(static_cast<float>(ceil(bfloat16(3.5f))), 4.0f);
|
||||
VERIFY_IS_EQUAL(static_cast<float>(numext::ceil(bfloat16(-3.5f))), -3.0f);
|
||||
VERIFY_IS_EQUAL(static_cast<float>(ceil(bfloat16(-3.5f))), -3.0f);
|
||||
|
||||
VERIFY_IS_APPROX(static_cast<float>(numext::sqrt(bfloat16(0.0f))), 0.0f);
|
||||
VERIFY_IS_APPROX(static_cast<float>(sqrt(bfloat16(0.0f))), 0.0f);
|
||||
VERIFY_IS_APPROX(static_cast<float>(numext::sqrt(bfloat16(4.0f))), 2.0f);
|
||||
VERIFY_IS_APPROX(static_cast<float>(sqrt(bfloat16(4.0f))), 2.0f);
|
||||
|
||||
VERIFY_IS_APPROX(static_cast<float>(numext::pow(bfloat16(0.0f), bfloat16(1.0f))), 0.0f);
|
||||
VERIFY_IS_APPROX(static_cast<float>(pow(bfloat16(0.0f), bfloat16(1.0f))), 0.0f);
|
||||
VERIFY_IS_APPROX(static_cast<float>(numext::pow(bfloat16(2.0f), bfloat16(2.0f))), 4.0f);
|
||||
VERIFY_IS_APPROX(static_cast<float>(pow(bfloat16(2.0f), bfloat16(2.0f))), 4.0f);
|
||||
|
||||
VERIFY_IS_EQUAL(static_cast<float>(numext::exp(bfloat16(0.0f))), 1.0f);
|
||||
VERIFY_IS_EQUAL(static_cast<float>(exp(bfloat16(0.0f))), 1.0f);
|
||||
VERIFY_IS_APPROX(static_cast<float>(numext::exp(bfloat16(EIGEN_PI))), 20.f + static_cast<float>(EIGEN_PI));
|
||||
VERIFY_IS_APPROX(static_cast<float>(exp(bfloat16(EIGEN_PI))), 20.f + static_cast<float>(EIGEN_PI));
|
||||
|
||||
VERIFY_IS_EQUAL(static_cast<float>(numext::expm1(bfloat16(0.0f))), 0.0f);
|
||||
VERIFY_IS_EQUAL(static_cast<float>(expm1(bfloat16(0.0f))), 0.0f);
|
||||
VERIFY_IS_APPROX(static_cast<float>(numext::expm1(bfloat16(2.0f))), 6.375f);
|
||||
VERIFY_IS_APPROX(static_cast<float>(expm1(bfloat16(2.0f))), 6.375f);
|
||||
|
||||
VERIFY_IS_EQUAL(static_cast<float>(numext::log(bfloat16(1.0f))), 0.0f);
|
||||
VERIFY_IS_EQUAL(static_cast<float>(log(bfloat16(1.0f))), 0.0f);
|
||||
VERIFY_IS_APPROX(static_cast<float>(numext::log(bfloat16(10.0f))), 2.296875f);
|
||||
VERIFY_IS_APPROX(static_cast<float>(log(bfloat16(10.0f))), 2.296875f);
|
||||
|
||||
VERIFY_IS_EQUAL(static_cast<float>(numext::log1p(bfloat16(0.0f))), 0.0f);
|
||||
VERIFY_IS_EQUAL(static_cast<float>(log1p(bfloat16(0.0f))), 0.0f);
|
||||
VERIFY_IS_APPROX(static_cast<float>(numext::log1p(bfloat16(10.0f))), 2.390625f);
|
||||
VERIFY_IS_APPROX(static_cast<float>(log1p(bfloat16(10.0f))), 2.390625f);
|
||||
}
|
||||
|
||||
void test_trigonometric_functions()
|
||||
{
|
||||
VERIFY_IS_APPROX(numext::cos(bfloat16(0.0f)), bfloat16(cosf(0.0f)));
|
||||
VERIFY_IS_APPROX(cos(bfloat16(0.0f)), bfloat16(cosf(0.0f)));
|
||||
VERIFY_IS_APPROX(numext::cos(bfloat16(EIGEN_PI)), bfloat16(cosf(EIGEN_PI)));
|
||||
// VERIFY_IS_APPROX(numext::cos(bfloat16(EIGEN_PI/2)), bfloat16(cosf(EIGEN_PI/2)));
|
||||
// VERIFY_IS_APPROX(numext::cos(bfloat16(3*EIGEN_PI/2)), bfloat16(cosf(3*EIGEN_PI/2)));
|
||||
VERIFY_IS_APPROX(numext::cos(bfloat16(3.5f)), bfloat16(cosf(3.5f)));
|
||||
|
||||
VERIFY_IS_APPROX(numext::sin(bfloat16(0.0f)), bfloat16(sinf(0.0f)));
|
||||
VERIFY_IS_APPROX(sin(bfloat16(0.0f)), bfloat16(sinf(0.0f)));
|
||||
// VERIFY_IS_APPROX(numext::sin(bfloat16(EIGEN_PI)), bfloat16(sinf(EIGEN_PI)));
|
||||
VERIFY_IS_APPROX(numext::sin(bfloat16(EIGEN_PI/2)), bfloat16(sinf(EIGEN_PI/2)));
|
||||
VERIFY_IS_APPROX(numext::sin(bfloat16(3*EIGEN_PI/2)), bfloat16(sinf(3*EIGEN_PI/2)));
|
||||
VERIFY_IS_APPROX(numext::sin(bfloat16(3.5f)), bfloat16(sinf(3.5f)));
|
||||
|
||||
VERIFY_IS_APPROX(numext::tan(bfloat16(0.0f)), bfloat16(tanf(0.0f)));
|
||||
VERIFY_IS_APPROX(tan(bfloat16(0.0f)), bfloat16(tanf(0.0f)));
|
||||
// VERIFY_IS_APPROX(numext::tan(bfloat16(EIGEN_PI)), bfloat16(tanf(EIGEN_PI)));
|
||||
// VERIFY_IS_APPROX(numext::tan(bfloat16(EIGEN_PI/2)), bfloat16(tanf(EIGEN_PI/2)));
|
||||
// VERIFY_IS_APPROX(numext::tan(bfloat16(3*EIGEN_PI/2)), bfloat16(tanf(3*EIGEN_PI/2)));
|
||||
VERIFY_IS_APPROX(numext::tan(bfloat16(3.5f)), bfloat16(tanf(3.5f)));
|
||||
}
|
||||
|
||||
void test_array()
|
||||
{
|
||||
typedef Array<bfloat16,1,Dynamic> ArrayXh;
|
||||
Index size = internal::random<Index>(1,10);
|
||||
Index i = internal::random<Index>(0,size-1);
|
||||
ArrayXh a1 = ArrayXh::Random(size), a2 = ArrayXh::Random(size);
|
||||
VERIFY_IS_APPROX( a1+a1, bfloat16(2)*a1 );
|
||||
VERIFY( (a1.abs() >= bfloat16(0)).all() );
|
||||
VERIFY_IS_APPROX( (a1*a1).sqrt(), a1.abs() );
|
||||
|
||||
VERIFY( ((a1.min)(a2) <= (a1.max)(a2)).all() );
|
||||
a1(i) = bfloat16(-10.);
|
||||
VERIFY_IS_EQUAL( a1.minCoeff(), bfloat16(-10.) );
|
||||
a1(i) = bfloat16(10.);
|
||||
VERIFY_IS_EQUAL( a1.maxCoeff(), bfloat16(10.) );
|
||||
|
||||
std::stringstream ss;
|
||||
ss << a1;
|
||||
}
|
||||
|
||||
void test_product()
|
||||
{
|
||||
typedef Matrix<bfloat16,Dynamic,Dynamic> MatrixXh;
|
||||
Index rows = internal::random<Index>(1,EIGEN_TEST_MAX_SIZE);
|
||||
Index cols = internal::random<Index>(1,EIGEN_TEST_MAX_SIZE);
|
||||
Index depth = internal::random<Index>(1,EIGEN_TEST_MAX_SIZE);
|
||||
MatrixXh Ah = MatrixXh::Random(rows,depth);
|
||||
MatrixXh Bh = MatrixXh::Random(depth,cols);
|
||||
MatrixXh Ch = MatrixXh::Random(rows,cols);
|
||||
MatrixXf Af = Ah.cast<float>();
|
||||
MatrixXf Bf = Bh.cast<float>();
|
||||
MatrixXf Cf = Ch.cast<float>();
|
||||
VERIFY_IS_APPROX(Ch.noalias()+=Ah*Bh, (Cf.noalias()+=Af*Bf).cast<bfloat16>());
|
||||
}
|
||||
|
||||
EIGEN_DECLARE_TEST(bfloat16_float)
|
||||
{
|
||||
CALL_SUBTEST(test_numtraits());
|
||||
for(int i = 0; i < g_repeat; i++) {
|
||||
CALL_SUBTEST(test_conversion());
|
||||
CALL_SUBTEST(test_arithmetic());
|
||||
CALL_SUBTEST(test_comparison());
|
||||
CALL_SUBTEST(test_basic_functions());
|
||||
CALL_SUBTEST(test_trigonometric_functions());
|
||||
CALL_SUBTEST(test_array());
|
||||
CALL_SUBTEST(test_product());
|
||||
}
|
||||
}
|
@ -435,6 +435,7 @@ EIGEN_TEST_SCALAR_TEST_OVERLOAD(unsigned long long)
|
||||
EIGEN_TEST_SCALAR_TEST_OVERLOAD(float)
|
||||
EIGEN_TEST_SCALAR_TEST_OVERLOAD(double)
|
||||
EIGEN_TEST_SCALAR_TEST_OVERLOAD(half)
|
||||
EIGEN_TEST_SCALAR_TEST_OVERLOAD(bfloat16)
|
||||
|
||||
#undef EIGEN_TEST_SCALAR_TEST_OVERLOAD
|
||||
|
||||
|
@ -45,6 +45,7 @@ EIGEN_DECLARE_TEST(numext) {
|
||||
CALL_SUBTEST( check_abs<long>() );
|
||||
CALL_SUBTEST( check_abs<unsigned long>() );
|
||||
CALL_SUBTEST( check_abs<half>() );
|
||||
CALL_SUBTEST( check_abs<bfloat16>() );
|
||||
CALL_SUBTEST( check_abs<float>() );
|
||||
CALL_SUBTEST( check_abs<double>() );
|
||||
CALL_SUBTEST( check_abs<long double>() );
|
||||
|
@ -836,6 +836,7 @@ EIGEN_DECLARE_TEST(packetmath)
|
||||
#ifdef EIGEN_PACKET_MATH_SSE_H
|
||||
CALL_SUBTEST_14(( packetmath<bool,internal::packet_traits<bool>::type>() ));
|
||||
#endif
|
||||
CALL_SUBTEST_15(( packetmath<bfloat16,internal::packet_traits<bfloat16>::type>() ));
|
||||
g_first_pass = false;
|
||||
}
|
||||
}
|
||||
|
@ -50,6 +50,7 @@ T apply_bit_op(Bits a, Bits b, Func f) {
|
||||
EIGEN_TEST_MAKE_BITWISE2(OP,FUNC,float) \
|
||||
EIGEN_TEST_MAKE_BITWISE2(OP,FUNC,double) \
|
||||
EIGEN_TEST_MAKE_BITWISE2(OP,FUNC,half) \
|
||||
EIGEN_TEST_MAKE_BITWISE2(OP,FUNC,bfloat16) \
|
||||
EIGEN_TEST_MAKE_BITWISE2(OP,FUNC,std::complex<float>) \
|
||||
EIGEN_TEST_MAKE_BITWISE2(OP,FUNC,std::complex<double>)
|
||||
|
||||
|
@ -101,6 +101,17 @@ Eigen::half RandomToTypeUniform<Eigen::half>(uint64_t* state, uint64_t stream) {
|
||||
return result - Eigen::half(1.0f);
|
||||
}
|
||||
|
||||
template <> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||
Eigen::bfloat16 RandomToTypeUniform<Eigen::bfloat16>(uint64_t* state, uint64_t stream) {
|
||||
Eigen::bfloat16 result;
|
||||
// Generate 7 random bits for the mantissa
|
||||
unsigned rnd = PCG_XSH_RS_generator(state, stream);
|
||||
result.value = static_cast<uint16_t>(rnd & 0x7fu);
|
||||
// Set the exponent
|
||||
result.value |= (static_cast<uint16_t>(127) << 7);
|
||||
// Return the final result
|
||||
return result - Eigen::bfloat16(1.0f);
|
||||
}
|
||||
|
||||
template <> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||
float RandomToTypeUniform<float>(uint64_t* state, uint64_t stream) {
|
||||
|
@ -62,6 +62,7 @@ namespace Eigen {
|
||||
|
||||
#include "src/SpecialFunctions/BesselFunctionsImpl.h"
|
||||
#include "src/SpecialFunctions/BesselFunctionsPacketMath.h"
|
||||
#include "src/SpecialFunctions/BesselFunctionsBFloat16.h"
|
||||
#include "src/SpecialFunctions/BesselFunctionsHalf.h"
|
||||
#include "src/SpecialFunctions/BesselFunctionsFunctors.h"
|
||||
#include "src/SpecialFunctions/BesselFunctionsArrayAPI.h"
|
||||
@ -70,6 +71,7 @@ namespace Eigen {
|
||||
#include "src/SpecialFunctions/HipVectorCompatibility.h"
|
||||
#endif
|
||||
#include "src/SpecialFunctions/SpecialFunctionsPacketMath.h"
|
||||
#include "src/SpecialFunctions/SpecialFunctionsBFloat16.h"
|
||||
#include "src/SpecialFunctions/SpecialFunctionsHalf.h"
|
||||
#include "src/SpecialFunctions/SpecialFunctionsFunctors.h"
|
||||
#include "src/SpecialFunctions/SpecialFunctionsArrayAPI.h"
|
||||
|
@ -0,0 +1,68 @@
|
||||
// This file is part of Eigen, a lightweight C++ template library
|
||||
// for linear algebra.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla
|
||||
// Public License v. 2.0. If a copy of the MPL was not distributed
|
||||
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
#ifndef EIGEN_BESSELFUNCTIONS_BFLOAT16_H
|
||||
#define EIGEN_BESSELFUNCTIONS_BFLOAT16_H
|
||||
|
||||
namespace Eigen {
|
||||
namespace numext {
|
||||
|
||||
#if EIGEN_HAS_C99_MATH
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 bessel_i0(const Eigen::bfloat16& x) {
|
||||
return Eigen::bfloat16(Eigen::numext::bessel_i0(static_cast<float>(x)));
|
||||
}
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 bessel_i0e(const Eigen::bfloat16& x) {
|
||||
return Eigen::bfloat16(Eigen::numext::bessel_i0e(static_cast<float>(x)));
|
||||
}
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 bessel_i1(const Eigen::bfloat16& x) {
|
||||
return Eigen::bfloat16(Eigen::numext::bessel_i1(static_cast<float>(x)));
|
||||
}
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 bessel_i1e(const Eigen::bfloat16& x) {
|
||||
return Eigen::bfloat16(Eigen::numext::bessel_i1e(static_cast<float>(x)));
|
||||
}
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 bessel_j0(const Eigen::bfloat16& x) {
|
||||
return Eigen::bfloat16(Eigen::numext::bessel_j0(static_cast<float>(x)));
|
||||
}
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 bessel_j1(const Eigen::bfloat16& x) {
|
||||
return Eigen::bfloat16(Eigen::numext::bessel_j1(static_cast<float>(x)));
|
||||
}
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 bessel_y0(const Eigen::bfloat16& x) {
|
||||
return Eigen::bfloat16(Eigen::numext::bessel_y0(static_cast<float>(x)));
|
||||
}
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 bessel_y1(const Eigen::bfloat16& x) {
|
||||
return Eigen::bfloat16(Eigen::numext::bessel_y1(static_cast<float>(x)));
|
||||
}
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 bessel_k0(const Eigen::bfloat16& x) {
|
||||
return Eigen::bfloat16(Eigen::numext::bessel_k0(static_cast<float>(x)));
|
||||
}
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 bessel_k0e(const Eigen::bfloat16& x) {
|
||||
return Eigen::bfloat16(Eigen::numext::bessel_k0e(static_cast<float>(x)));
|
||||
}
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 bessel_k1(const Eigen::bfloat16& x) {
|
||||
return Eigen::bfloat16(Eigen::numext::bessel_k1(static_cast<float>(x)));
|
||||
}
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 bessel_k1e(const Eigen::bfloat16& x) {
|
||||
return Eigen::bfloat16(Eigen::numext::bessel_k1e(static_cast<float>(x)));
|
||||
}
|
||||
#endif
|
||||
|
||||
} // end namespace numext
|
||||
} // end namespace Eigen
|
||||
|
||||
#endif // EIGEN_BESSELFUNCTIONS_BFLOAT16_H
|
@ -0,0 +1,58 @@
|
||||
// This file is part of Eigen, a lightweight C++ template library
|
||||
// for linear algebra.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla
|
||||
// Public License v. 2.0. If a copy of the MPL was not distributed
|
||||
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
#ifndef EIGEN_SPECIALFUNCTIONS_BFLOAT16_H
|
||||
#define EIGEN_SPECIALFUNCTIONS_BFLOAT16_H
|
||||
|
||||
namespace Eigen {
|
||||
namespace numext {
|
||||
|
||||
#if EIGEN_HAS_C99_MATH
|
||||
template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 lgamma(const Eigen::bfloat16& a) {
|
||||
return Eigen::bfloat16(Eigen::numext::lgamma(static_cast<float>(a)));
|
||||
}
|
||||
template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 digamma(const Eigen::bfloat16& a) {
|
||||
return Eigen::bfloat16(Eigen::numext::digamma(static_cast<float>(a)));
|
||||
}
|
||||
template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 zeta(const Eigen::bfloat16& x, const Eigen::bfloat16& q) {
|
||||
return Eigen::bfloat16(Eigen::numext::zeta(static_cast<float>(x), static_cast<float>(q)));
|
||||
}
|
||||
template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 polygamma(const Eigen::bfloat16& n, const Eigen::bfloat16& x) {
|
||||
return Eigen::bfloat16(Eigen::numext::polygamma(static_cast<float>(n), static_cast<float>(x)));
|
||||
}
|
||||
template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 erf(const Eigen::bfloat16& a) {
|
||||
return Eigen::bfloat16(Eigen::numext::erf(static_cast<float>(a)));
|
||||
}
|
||||
template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 erfc(const Eigen::bfloat16& a) {
|
||||
return Eigen::bfloat16(Eigen::numext::erfc(static_cast<float>(a)));
|
||||
}
|
||||
template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 ndtri(const Eigen::bfloat16& a) {
|
||||
return Eigen::bfloat16(Eigen::numext::ndtri(static_cast<float>(a)));
|
||||
}
|
||||
template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 igamma(const Eigen::bfloat16& a, const Eigen::bfloat16& x) {
|
||||
return Eigen::bfloat16(Eigen::numext::igamma(static_cast<float>(a), static_cast<float>(x)));
|
||||
}
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 igamma_der_a(const Eigen::bfloat16& a, const Eigen::bfloat16& x) {
|
||||
return Eigen::bfloat16(Eigen::numext::igamma_der_a(static_cast<float>(a), static_cast<float>(x)));
|
||||
}
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 gamma_sample_der_alpha(const Eigen::bfloat16& alpha, const Eigen::bfloat16& sample) {
|
||||
return Eigen::bfloat16(Eigen::numext::gamma_sample_der_alpha(static_cast<float>(alpha), static_cast<float>(sample)));
|
||||
}
|
||||
template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 igammac(const Eigen::bfloat16& a, const Eigen::bfloat16& x) {
|
||||
return Eigen::bfloat16(Eigen::numext::igammac(static_cast<float>(a), static_cast<float>(x)));
|
||||
}
|
||||
template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 betainc(const Eigen::bfloat16& a, const Eigen::bfloat16& b, const Eigen::bfloat16& x) {
|
||||
return Eigen::bfloat16(Eigen::numext::betainc(static_cast<float>(a), static_cast<float>(b), static_cast<float>(x)));
|
||||
}
|
||||
#endif
|
||||
|
||||
} // end namespace numext
|
||||
} // end namespace Eigen
|
||||
|
||||
#endif // EIGEN_SPECIALFUNCTIONS_BFLOAT16_H
|
@ -511,6 +511,7 @@ EIGEN_DECLARE_TEST(cxx11_tensor_reduction) {
|
||||
CALL_SUBTEST(( test_simple_reductions<float,ColMajor>() ));
|
||||
CALL_SUBTEST(( test_simple_reductions<float,RowMajor>() ));
|
||||
CALL_SUBTEST(( test_simple_reductions<Eigen::half,ColMajor>() ));
|
||||
CALL_SUBTEST(( test_simple_reductions<Eigen::bfloat16,ColMajor>() ));
|
||||
CALL_SUBTEST(test_reductions_in_expr<ColMajor>());
|
||||
CALL_SUBTEST(test_reductions_in_expr<RowMajor>());
|
||||
CALL_SUBTEST(test_full_reductions<ColMajor>());
|
||||
|
Loading…
Reference in New Issue
Block a user