Cleaned up the fp16 code a little more

This commit is contained in:
Benoit Steiner 2016-05-24 22:43:26 -07:00
parent cb26784d07
commit d041a528da
2 changed files with 47 additions and 55 deletions

View File

@ -10,20 +10,12 @@
#ifndef EIGEN_PACKET_MATH_HALF_CUDA_H #ifndef EIGEN_PACKET_MATH_HALF_CUDA_H
#define EIGEN_PACKET_MATH_HALF_CUDA_H #define EIGEN_PACKET_MATH_HALF_CUDA_H
//#if defined(EIGEN_HAS_CUDA_FP16)
// Make sure this is only available when targeting a GPU: we don't want to
// introduce conflicts between these packet_traits definitions and the ones
// we'll use on the host side (SSE, AVX, ...)
//#if defined(__CUDACC__) && defined(EIGEN_USE_GPU)
namespace Eigen { namespace Eigen {
namespace internal { namespace internal {
// Most of the following operations require arch >= 3.0 // Most of the following operations require arch >= 3.0
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300 #if defined(EIGEN_HAS_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300
template<> struct is_arithmetic<half2> { enum { value = true }; }; template<> struct is_arithmetic<half2> { enum { value = true }; };
@ -90,27 +82,27 @@ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE half2 ploadt_ro<half2, Unaligned>(const Ei
#endif #endif
} }
template<> EIGEN_DEVICE_FUNC inline half2 pgather<Eigen::half, half2>(const Eigen::half* from, Index stride) { template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pgather<Eigen::half, half2>(const Eigen::half* from, Index stride) {
return __halves2half2(from[0*stride], from[1*stride]); return __halves2half2(from[0*stride], from[1*stride]);
} }
template<> EIGEN_DEVICE_FUNC inline void pscatter<Eigen::half, half2>(Eigen::half* to, const half2& from, Index stride) { template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pscatter<Eigen::half, half2>(Eigen::half* to, const half2& from, Index stride) {
to[stride*0] = __low2half(from); to[stride*0] = __low2half(from);
to[stride*1] = __high2half(from); to[stride*1] = __high2half(from);
} }
template<> EIGEN_DEVICE_FUNC inline Eigen::half pfirst<half2>(const half2& a) { template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half pfirst<half2>(const half2& a) {
return __low2half(a); return __low2half(a);
} }
template<> EIGEN_DEVICE_FUNC inline half2 pabs<half2>(const half2& a) { template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pabs<half2>(const half2& a) {
half2 result; half2 result;
result.x = a.x & 0x7FFF7FFF; result.x = a.x & 0x7FFF7FFF;
return result; return result;
} }
EIGEN_DEVICE_FUNC inline void EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void
ptranspose(PacketBlock<half2,2>& kernel) { ptranspose(PacketBlock<half2,2>& kernel) {
__half a1 = __low2half(kernel.packet[0]); __half a1 = __low2half(kernel.packet[0]);
__half a2 = __high2half(kernel.packet[0]); __half a2 = __high2half(kernel.packet[0]);
@ -229,7 +221,7 @@ template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pmax<half2>(const half2&
return __halves2half2(r1, r2); return __halves2half2(r1, r2);
} }
template<> EIGEN_DEVICE_FUNC inline Eigen::half predux<half2>(const half2& a) { template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half predux<half2>(const half2& a) {
#if __CUDA_ARCH__ >= 530 #if __CUDA_ARCH__ >= 530
return __hadd(__low2half(a), __high2half(a)); return __hadd(__low2half(a), __high2half(a));
#else #else
@ -239,7 +231,7 @@ template<> EIGEN_DEVICE_FUNC inline Eigen::half predux<half2>(const half2& a) {
#endif #endif
} }
template<> EIGEN_DEVICE_FUNC inline Eigen::half predux_max<half2>(const half2& a) { template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half predux_max<half2>(const half2& a) {
#if __CUDA_ARCH__ >= 530 #if __CUDA_ARCH__ >= 530
__half first = __low2half(a); __half first = __low2half(a);
__half second = __high2half(a); __half second = __high2half(a);
@ -251,7 +243,7 @@ template<> EIGEN_DEVICE_FUNC inline Eigen::half predux_max<half2>(const half2& a
#endif #endif
} }
template<> EIGEN_DEVICE_FUNC inline Eigen::half predux_min<half2>(const half2& a) { template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half predux_min<half2>(const half2& a) {
#if __CUDA_ARCH__ >= 530 #if __CUDA_ARCH__ >= 530
__half first = __low2half(a); __half first = __low2half(a);
__half second = __high2half(a); __half second = __high2half(a);
@ -263,7 +255,7 @@ template<> EIGEN_DEVICE_FUNC inline Eigen::half predux_min<half2>(const half2& a
#endif #endif
} }
template<> EIGEN_DEVICE_FUNC inline Eigen::half predux_mul<half2>(const half2& a) { template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half predux_mul<half2>(const half2& a) {
#if __CUDA_ARCH__ >= 530 #if __CUDA_ARCH__ >= 530
return __hmul(__low2half(a), __high2half(a)); return __hmul(__low2half(a), __high2half(a));
#else #else
@ -273,7 +265,7 @@ template<> EIGEN_DEVICE_FUNC inline Eigen::half predux_mul<half2>(const half2& a
#endif #endif
} }
template<> EIGEN_DEVICE_FUNC inline half2 plog<half2>(const half2& a) { template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 plog<half2>(const half2& a) {
float a1 = __low2float(a); float a1 = __low2float(a);
float a2 = __high2float(a); float a2 = __high2float(a);
float r1 = logf(a1); float r1 = logf(a1);
@ -281,7 +273,7 @@ template<> EIGEN_DEVICE_FUNC inline half2 plog<half2>(const half2& a) {
return __floats2half2_rn(r1, r2); return __floats2half2_rn(r1, r2);
} }
template<> EIGEN_DEVICE_FUNC inline half2 pexp<half2>(const half2& a) { template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pexp<half2>(const half2& a) {
float a1 = __low2float(a); float a1 = __low2float(a);
float a2 = __high2float(a); float a2 = __high2float(a);
float r1 = expf(a1); float r1 = expf(a1);
@ -289,7 +281,7 @@ template<> EIGEN_DEVICE_FUNC inline half2 pexp<half2>(const half2& a) {
return __floats2half2_rn(r1, r2); return __floats2half2_rn(r1, r2);
} }
template<> EIGEN_DEVICE_FUNC inline half2 psqrt<half2>(const half2& a) { template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 psqrt<half2>(const half2& a) {
float a1 = __low2float(a); float a1 = __low2float(a);
float a2 = __high2float(a); float a2 = __high2float(a);
float r1 = sqrtf(a1); float r1 = sqrtf(a1);
@ -297,7 +289,7 @@ template<> EIGEN_DEVICE_FUNC inline half2 psqrt<half2>(const half2& a) {
return __floats2half2_rn(r1, r2); return __floats2half2_rn(r1, r2);
} }
template<> EIGEN_DEVICE_FUNC inline half2 prsqrt<half2>(const half2& a) { template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 prsqrt<half2>(const half2& a) {
float a1 = __low2float(a); float a1 = __low2float(a);
float a2 = __high2float(a); float a2 = __high2float(a);
float r1 = rsqrtf(a1); float r1 = rsqrtf(a1);
@ -346,37 +338,37 @@ struct packet_traits<Eigen::half> : default_packet_traits {
template<> struct unpacket_traits<Packet8h> { typedef Eigen::half type; enum {size=8, alignment=Aligned16}; typedef Packet8h half; }; template<> struct unpacket_traits<Packet8h> { typedef Eigen::half type; enum {size=8, alignment=Aligned16}; typedef Packet8h half; };
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet8h pset1<Packet8h>(const Eigen::half& from) { template<> EIGEN_STRONG_INLINE Packet8h pset1<Packet8h>(const Eigen::half& from) {
Packet8h result; Packet8h result;
result.x = _mm_set1_epi16(from.x); result.x = _mm_set1_epi16(from.x);
return result; return result;
} }
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half pfirst<Packet8h>(const Packet8h& from) { template<> EIGEN_STRONG_INLINE Eigen::half pfirst<Packet8h>(const Packet8h& from) {
return raw_uint16_to_half(static_cast<unsigned short>(_mm_extract_epi16(from.x, 0))); return raw_uint16_to_half(static_cast<unsigned short>(_mm_extract_epi16(from.x, 0)));
} }
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet8h pload<Packet8h>(const Eigen::half* from) { template<> EIGEN_STRONG_INLINE Packet8h pload<Packet8h>(const Eigen::half* from) {
Packet8h result; Packet8h result;
result.x = _mm_load_si128(reinterpret_cast<const __m128i*>(from)); result.x = _mm_load_si128(reinterpret_cast<const __m128i*>(from));
return result; return result;
} }
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet8h ploadu<Packet8h>(const Eigen::half* from) { template<> EIGEN_STRONG_INLINE Packet8h ploadu<Packet8h>(const Eigen::half* from) {
Packet8h result; Packet8h result;
result.x = _mm_loadu_si128(reinterpret_cast<const __m128i*>(from)); result.x = _mm_loadu_si128(reinterpret_cast<const __m128i*>(from));
return result; return result;
} }
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pstore<Eigen::half>(Eigen::half* to, const Packet8h& from) { template<> EIGEN_STRONG_INLINE void pstore<Eigen::half>(Eigen::half* to, const Packet8h& from) {
_mm_store_si128((__m128i*)to, from.x); _mm_store_si128((__m128i*)to, from.x);
} }
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pstoreu<Eigen::half>(Eigen::half* to, const Packet8h& from) { template<> EIGEN_STRONG_INLINE void pstoreu<Eigen::half>(Eigen::half* to, const Packet8h& from) {
_mm_storeu_si128((__m128i*)to, from.x); _mm_storeu_si128((__m128i*)to, from.x);
} }
template<> EIGEN_DEVICE_FUNC inline Packet8h template<> EIGEN_STRONG_INLINE Packet8h
ploadquad<Packet8h>(const Eigen::half* from) { ploadquad<Packet8h>(const Eigen::half* from) {
Packet8h result; Packet8h result;
unsigned short a = from[0].x; unsigned short a = from[0].x;
@ -427,30 +419,30 @@ EIGEN_STRONG_INLINE Packet8h float2half(const Packet8f& a) {
#endif #endif
} }
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet8h pconj(const Packet8h& a) { return a; } template<> EIGEN_STRONG_INLINE Packet8h pconj(const Packet8h& a) { return a; }
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet8h padd<Packet8h>(const Packet8h& a, const Packet8h& b) { template<> EIGEN_STRONG_INLINE Packet8h padd<Packet8h>(const Packet8h& a, const Packet8h& b) {
Packet8f af = half2float(a); Packet8f af = half2float(a);
Packet8f bf = half2float(b); Packet8f bf = half2float(b);
Packet8f rf = padd(af, bf); Packet8f rf = padd(af, bf);
return float2half(rf); return float2half(rf);
} }
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet8h pmul<Packet8h>(const Packet8h& a, const Packet8h& b) { template<> EIGEN_STRONG_INLINE Packet8h pmul<Packet8h>(const Packet8h& a, const Packet8h& b) {
Packet8f af = half2float(a); Packet8f af = half2float(a);
Packet8f bf = half2float(b); Packet8f bf = half2float(b);
Packet8f rf = pmul(af, bf); Packet8f rf = pmul(af, bf);
return float2half(rf); return float2half(rf);
} }
template<> EIGEN_DEVICE_FUNC inline Packet8h pgather<Eigen::half, Packet8h>(const Eigen::half* from, Index stride) template<> EIGEN_STRONG_INLINE Packet8h pgather<Eigen::half, Packet8h>(const Eigen::half* from, Index stride)
{ {
Packet8h result; Packet8h result;
result.x = _mm_set_epi16(from[7*stride].x, from[6*stride].x, from[5*stride].x, from[4*stride].x, from[3*stride].x, from[2*stride].x, from[1*stride].x, from[0*stride].x); result.x = _mm_set_epi16(from[7*stride].x, from[6*stride].x, from[5*stride].x, from[4*stride].x, from[3*stride].x, from[2*stride].x, from[1*stride].x, from[0*stride].x);
return result; return result;
} }
template<> EIGEN_DEVICE_FUNC inline void pscatter<Eigen::half, Packet8h>(Eigen::half* to, const Packet8h& from, Index stride) template<> EIGEN_STRONG_INLINE void pscatter<Eigen::half, Packet8h>(Eigen::half* to, const Packet8h& from, Index stride)
{ {
EIGEN_ALIGN32 Eigen::half aux[8]; EIGEN_ALIGN32 Eigen::half aux[8];
pstore(aux, from); pstore(aux, from);
@ -464,7 +456,7 @@ template<> EIGEN_DEVICE_FUNC inline void pscatter<Eigen::half, Packet8h>(Eigen::
to[stride*7].x = aux[7].x; to[stride*7].x = aux[7].x;
} }
EIGEN_DEVICE_FUNC inline void EIGEN_STRONG_INLINE void
ptranspose(PacketBlock<Packet8h,8>& kernel) { ptranspose(PacketBlock<Packet8h,8>& kernel) {
__m128i a = kernel.packet[0].x; __m128i a = kernel.packet[0].x;
__m128i b = kernel.packet[1].x; __m128i b = kernel.packet[1].x;
@ -512,7 +504,7 @@ ptranspose(PacketBlock<Packet8h,8>& kernel) {
kernel.packet[7].x = a7b7c7d7e7f7g7h7; kernel.packet[7].x = a7b7c7d7e7f7g7h7;
} }
EIGEN_DEVICE_FUNC inline void EIGEN_STRONG_INLINE void
ptranspose(PacketBlock<Packet8h,4>& kernel) { ptranspose(PacketBlock<Packet8h,4>& kernel) {
EIGEN_ALIGN32 Eigen::half in[4][8]; EIGEN_ALIGN32 Eigen::half in[4][8];
pstore<Eigen::half>(in[0], kernel.packet[0]); pstore<Eigen::half>(in[0], kernel.packet[0]);
@ -550,7 +542,7 @@ template<> struct is_arithmetic<Packet4h> { enum { value = true }; };
template <> template <>
struct packet_traits<Eigen::half> : default_packet_traits { struct packet_traits<Eigen::half> : default_packet_traits {
typedef Packet4h type; typedef Packet4h type;
// There is no half-size packet for Packet8h. // There is no half-size packet for Packet4h.
typedef Packet4h half; typedef Packet4h half;
enum { enum {
Vectorizable = 1, Vectorizable = 1,
@ -579,19 +571,19 @@ struct packet_traits<Eigen::half> : default_packet_traits {
template<> struct unpacket_traits<Packet4h> { typedef Eigen::half type; enum {size=4, alignment=Aligned16}; typedef Packet4h half; }; template<> struct unpacket_traits<Packet4h> { typedef Eigen::half type; enum {size=4, alignment=Aligned16}; typedef Packet4h half; };
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h pset1<Packet4h>(const Eigen::half& from) { template<> EIGEN_STRONG_INLINE Packet4h pset1<Packet4h>(const Eigen::half& from) {
Packet4h result; Packet4h result;
result.x = _mm_set1_pi16(from.x); result.x = _mm_set1_pi16(from.x);
return result; return result;
} }
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half pfirst<Packet4h>(const Packet4h& from) { template<> EIGEN_STRONG_INLINE Eigen::half pfirst<Packet4h>(const Packet4h& from) {
return raw_uint16_to_half(static_cast<unsigned short>(_mm_cvtsi64_si32(from.x))); return raw_uint16_to_half(static_cast<unsigned short>(_mm_cvtsi64_si32(from.x)));
} }
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h pconj(const Packet4h& a) { return a; } template<> EIGEN_STRONG_INLINE Packet4h pconj(const Packet4h& a) { return a; }
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h padd<Packet4h>(const Packet4h& a, const Packet4h& b) { template<> EIGEN_STRONG_INLINE Packet4h padd<Packet4h>(const Packet4h& a, const Packet4h& b) {
__int64_t a64 = _mm_cvtm64_si64(a.x); __int64_t a64 = _mm_cvtm64_si64(a.x);
__int64_t b64 = _mm_cvtm64_si64(b.x); __int64_t b64 = _mm_cvtm64_si64(b.x);
@ -614,7 +606,7 @@ template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h padd<Packet4h>(const P
return result; return result;
} }
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h pmul<Packet4h>(const Packet4h& a, const Packet4h& b) { template<> EIGEN_STRONG_INLINE Packet4h pmul<Packet4h>(const Packet4h& a, const Packet4h& b) {
__int64_t a64 = _mm_cvtm64_si64(a.x); __int64_t a64 = _mm_cvtm64_si64(a.x);
__int64_t b64 = _mm_cvtm64_si64(b.x); __int64_t b64 = _mm_cvtm64_si64(b.x);
@ -637,41 +629,41 @@ template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h pmul<Packet4h>(const P
return result; return result;
} }
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h pload<Packet4h>(const Eigen::half* from) { template<> EIGEN_STRONG_INLINE Packet4h pload<Packet4h>(const Eigen::half* from) {
Packet4h result; Packet4h result;
result.x = _mm_cvtsi64_m64(*reinterpret_cast<const __int64_t*>(from)); result.x = _mm_cvtsi64_m64(*reinterpret_cast<const __int64_t*>(from));
return result; return result;
} }
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h ploadu<Packet4h>(const Eigen::half* from) { template<> EIGEN_STRONG_INLINE Packet4h ploadu<Packet4h>(const Eigen::half* from) {
Packet4h result; Packet4h result;
result.x = _mm_cvtsi64_m64(*reinterpret_cast<const __int64_t*>(from)); result.x = _mm_cvtsi64_m64(*reinterpret_cast<const __int64_t*>(from));
return result; return result;
} }
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pstore<Eigen::half>(Eigen::half* to, const Packet4h& from) { template<> EIGEN_STRONG_INLINE void pstore<Eigen::half>(Eigen::half* to, const Packet4h& from) {
__int64_t r = _mm_cvtm64_si64(from.x); __int64_t r = _mm_cvtm64_si64(from.x);
*(reinterpret_cast<__int64_t*>(to)) = r; *(reinterpret_cast<__int64_t*>(to)) = r;
} }
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pstoreu<Eigen::half>(Eigen::half* to, const Packet4h& from) { template<> EIGEN_STRONG_INLINE void pstoreu<Eigen::half>(Eigen::half* to, const Packet4h& from) {
__int64_t r = _mm_cvtm64_si64(from.x); __int64_t r = _mm_cvtm64_si64(from.x);
*(reinterpret_cast<__int64_t*>(to)) = r; *(reinterpret_cast<__int64_t*>(to)) = r;
} }
template<> EIGEN_DEVICE_FUNC inline Packet4h template<> EIGEN_STRONG_INLINE Packet4h
ploadquad<Packet4h>(const Eigen::half* from) { ploadquad<Packet4h>(const Eigen::half* from) {
return pset1<Packet4h>(*from); return pset1<Packet4h>(*from);
} }
template<> EIGEN_DEVICE_FUNC inline Packet4h pgather<Eigen::half, Packet4h>(const Eigen::half* from, Index stride) template<> EIGEN_STRONG_INLINE Packet4h pgather<Eigen::half, Packet4h>(const Eigen::half* from, Index stride)
{ {
Packet4h result; Packet4h result;
result.x = _mm_set_pi16(from[3*stride].x, from[2*stride].x, from[1*stride].x, from[0*stride].x); result.x = _mm_set_pi16(from[3*stride].x, from[2*stride].x, from[1*stride].x, from[0*stride].x);
return result; return result;
} }
template<> EIGEN_DEVICE_FUNC inline void pscatter<Eigen::half, Packet4h>(Eigen::half* to, const Packet4h& from, Index stride) template<> EIGEN_STRONG_INLINE void pscatter<Eigen::half, Packet4h>(Eigen::half* to, const Packet4h& from, Index stride)
{ {
__int64_t a = _mm_cvtm64_si64(from.x); __int64_t a = _mm_cvtm64_si64(from.x);
to[stride*0].x = static_cast<unsigned short>(a); to[stride*0].x = static_cast<unsigned short>(a);
@ -680,7 +672,7 @@ template<> EIGEN_DEVICE_FUNC inline void pscatter<Eigen::half, Packet4h>(Eigen::
to[stride*3].x = static_cast<unsigned short>(a >> 48); to[stride*3].x = static_cast<unsigned short>(a >> 48);
} }
EIGEN_DEVICE_FUNC inline void EIGEN_STRONG_INLINE void
ptranspose(PacketBlock<Packet4h,4>& kernel) { ptranspose(PacketBlock<Packet4h,4>& kernel) {
__m64 T0 = _mm_unpacklo_pi16(kernel.packet[0].x, kernel.packet[1].x); __m64 T0 = _mm_unpacklo_pi16(kernel.packet[0].x, kernel.packet[1].x);
__m64 T1 = _mm_unpacklo_pi16(kernel.packet[2].x, kernel.packet[3].x); __m64 T1 = _mm_unpacklo_pi16(kernel.packet[2].x, kernel.packet[3].x);

View File

@ -19,7 +19,7 @@ struct scalar_cast_op<float, Eigen::half> {
EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op) EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op)
typedef Eigen::half result_type; typedef Eigen::half result_type;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half operator() (const float& a) const { EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half operator() (const float& a) const {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300 #if defined(EIGEN_HAS_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300
return __float2half(a); return __float2half(a);
#else #else
return Eigen::half(a); return Eigen::half(a);
@ -37,7 +37,7 @@ struct scalar_cast_op<int, Eigen::half> {
EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op) EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op)
typedef Eigen::half result_type; typedef Eigen::half result_type;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half operator() (const int& a) const { EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half operator() (const int& a) const {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300 #if defined(EIGEN_HAS_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300
return __float2half(static_cast<float>(a)); return __float2half(static_cast<float>(a));
#else #else
return Eigen::half(static_cast<float>(a)); return Eigen::half(static_cast<float>(a));
@ -55,7 +55,7 @@ struct scalar_cast_op<Eigen::half, float> {
EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op) EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op)
typedef float result_type; typedef float result_type;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float operator() (const Eigen::half& a) const { EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float operator() (const Eigen::half& a) const {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300 #if defined(EIGEN_HAS_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300
return __half2float(a); return __half2float(a);
#else #else
return static_cast<float>(a); return static_cast<float>(a);
@ -69,7 +69,7 @@ struct functor_traits<scalar_cast_op<Eigen::half, float> >
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300 #if defined(EIGEN_HAS_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300
template <> template <>
struct type_casting_traits<Eigen::half, float> { struct type_casting_traits<Eigen::half, float> {
@ -139,7 +139,7 @@ struct type_casting_traits<Eigen::half, float> {
}; };
}; };
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4f pcast<Packet4h, Packet4f>(const Packet4h& a) { template<> EIGEN_STRONG_INLINE Packet4f pcast<Packet4h, Packet4f>(const Packet4h& a) {
__int64_t a64 = _mm_cvtm64_si64(a.x); __int64_t a64 = _mm_cvtm64_si64(a.x);
Eigen::half h = raw_uint16_to_half(static_cast<unsigned short>(a64)); Eigen::half h = raw_uint16_to_half(static_cast<unsigned short>(a64));
float f1 = static_cast<float>(h); float f1 = static_cast<float>(h);