Fix ambiguous call to CUDA __half constructor.

This commit is contained in:
Antonio Sanchez 2021-03-08 21:06:28 -08:00
parent 94327dbfba
commit 853a5c4b84

View File

@ -848,19 +848,23 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC uint16_t bit_cast<uint16_t, Eigen::half>(c
#if defined(EIGEN_HAS_CUDA_FP16) && EIGEN_CUDA_SDK_VER >= 90000
__device__ EIGEN_STRONG_INLINE Eigen::half __shfl_sync(unsigned mask, Eigen::half var, int srcLane, int width=warpSize) {
return static_cast<Eigen::half>(__shfl_sync(mask, static_cast<__half>(var), srcLane, width));
const __half h = var;
return static_cast<Eigen::half>(__shfl_sync(mask, h, srcLane, width));
}
__device__ EIGEN_STRONG_INLINE Eigen::half __shfl_up_sync(unsigned mask, Eigen::half var, unsigned int delta, int width=warpSize) {
return static_cast<Eigen::half>(__shfl_up_sync(mask, static_cast<__half>(var), delta, width));
const __half h = var;
return static_cast<Eigen::half>(__shfl_up_sync(mask, h, delta, width));
}
__device__ EIGEN_STRONG_INLINE Eigen::half __shfl_down_sync(unsigned mask, Eigen::half var, unsigned int delta, int width=warpSize) {
return static_cast<Eigen::half>(__shfl_down_sync(mask, static_cast<__half>(var), delta, width));
const __half h = var;
return static_cast<Eigen::half>(__shfl_down_sync(mask, h, delta, width));
}
__device__ EIGEN_STRONG_INLINE Eigen::half __shfl_xor_sync(unsigned mask, Eigen::half var, int laneMask, int width=warpSize) {
return static_cast<Eigen::half>(__shfl_xor_sync(mask, static_cast<__half>(var), laneMask, width));
const __half h = var;
return static_cast<Eigen::half>(__shfl_xor_sync(mask, h, laneMask, width));
}
#else // HIP or CUDA SDK < 9.0