Clean up half packet traits and add a few more missing packet ops.

This commit is contained in:
Rasmus Munk Larsen 2019-03-14 15:18:06 -07:00
parent b013176e52
commit 8450a6d519

View File

@ -30,6 +30,7 @@ template<> struct packet_traits<Eigen::half> : default_packet_traits
size=2,
HasHalfPacket = 0,
HasAdd = 1,
HasSub = 1,
HasMul = 1,
HasDiv = 1,
HasSqrt = 1,
@ -572,6 +573,7 @@ struct packet_traits<half> : default_packet_traits {
HasAdd = 1,
HasSub = 1,
HasMul = 1,
HasDiv = 1,
HasNegate = 1,
HasAbs = 0,
HasAbs2 = 0,
@ -579,7 +581,6 @@ struct packet_traits<half> : default_packet_traits {
HasMax = 0,
HasConj = 0,
HasSetLinear = 0,
HasDiv = 0,
HasSqrt = 0,
HasRsqrt = 0,
HasExp = 0,
@ -770,6 +771,13 @@ template<> EIGEN_STRONG_INLINE Packet16h pmul<Packet16h>(const Packet16h& a, con
return float2half(rf);
}
template<> EIGEN_STRONG_INLINE Packet16h pdiv<Packet16h>(const Packet16h& a, const Packet16h& b) {
Packet16f af = half2float(a);
Packet16f bf = half2float(b);
Packet16f rf = pdiv(af, bf);
return float2half(rf);
}
template<> EIGEN_STRONG_INLINE half predux<Packet16h>(const Packet16h& from) {
Packet16f from_float = half2float(from);
return half(predux(from_float));
@ -1054,6 +1062,7 @@ struct packet_traits<Eigen::half> : default_packet_traits {
HasAdd = 1,
HasSub = 1,
HasMul = 1,
HasDiv = 1,
HasNegate = 1,
HasAbs = 0,
HasAbs2 = 0,
@ -1061,7 +1070,6 @@ struct packet_traits<Eigen::half> : default_packet_traits {
HasMax = 0,
HasConj = 0,
HasSetLinear = 0,
HasDiv = 0,
HasSqrt = 0,
HasRsqrt = 0,
HasExp = 0,
@ -1221,6 +1229,13 @@ template<> EIGEN_STRONG_INLINE Packet8h pmul<Packet8h>(const Packet8h& a, const
return float2half(rf);
}
template<> EIGEN_STRONG_INLINE Packet8h pdiv<Packet8h>(const Packet8h& a, const Packet8h& b) {
Packet8f af = half2float(a);
Packet8f bf = half2float(b);
Packet8f rf = pdiv(af, bf);
return float2half(rf);
}
template<> EIGEN_STRONG_INLINE Packet8h pgather<Eigen::half, Packet8h>(const Eigen::half* from, Index stride)
{
Packet8h result;
@ -1407,9 +1422,10 @@ struct packet_traits<Eigen::half> : default_packet_traits {
AlignedOnScalar = 1,
size = 4,
HasHalfPacket = 0,
HasAdd = 0,
HasSub = 0,
HasMul = 0,
HasAdd = 1,
HasSub = 1,
HasMul = 1,
HasDiv = 1,
HasNegate = 0,
HasAbs = 0,
HasAbs2 = 0,
@ -1417,7 +1433,6 @@ struct packet_traits<Eigen::half> : default_packet_traits {
HasMax = 0,
HasConj = 0,
HasSetLinear = 0,
HasDiv = 0,
HasSqrt = 0,
HasRsqrt = 0,
HasExp = 0,
@ -1464,6 +1479,29 @@ template<> EIGEN_STRONG_INLINE Packet4h padd<Packet4h>(const Packet4h& a, const
return result;
}
template<> EIGEN_STRONG_INLINE Packet4h psub<Packet4h>(const Packet4h& a, const Packet4h& b) {
__int64_t a64 = _mm_cvtm64_si64(a.x);
__int64_t b64 = _mm_cvtm64_si64(b.x);
Eigen::half h[4];
Eigen::half ha = half_impl::raw_uint16_to_half(static_cast<unsigned short>(a64));
Eigen::half hb = half_impl::raw_uint16_to_half(static_cast<unsigned short>(b64));
h[0] = ha - hb;
ha = half_impl::raw_uint16_to_half(static_cast<unsigned short>(a64 >> 16));
hb = half_impl::raw_uint16_to_half(static_cast<unsigned short>(b64 >> 16));
h[1] = ha - hb;
ha = half_impl::raw_uint16_to_half(static_cast<unsigned short>(a64 >> 32));
hb = half_impl::raw_uint16_to_half(static_cast<unsigned short>(b64 >> 32));
h[2] = ha - hb;
ha = half_impl::raw_uint16_to_half(static_cast<unsigned short>(a64 >> 48));
hb = half_impl::raw_uint16_to_half(static_cast<unsigned short>(b64 >> 48));
h[3] = ha - hb;
Packet4h result;
result.x = _mm_set_pi16(h[3].x, h[2].x, h[1].x, h[0].x);
return result;
}
template<> EIGEN_STRONG_INLINE Packet4h pmul<Packet4h>(const Packet4h& a, const Packet4h& b) {
__int64_t a64 = _mm_cvtm64_si64(a.x);
__int64_t b64 = _mm_cvtm64_si64(b.x);
@ -1487,6 +1525,29 @@ template<> EIGEN_STRONG_INLINE Packet4h pmul<Packet4h>(const Packet4h& a, const
return result;
}
template<> EIGEN_STRONG_INLINE Packet4h pdiv<Packet4h>(const Packet4h& a, const Packet4h& b) {
__int64_t a64 = _mm_cvtm64_si64(a.x);
__int64_t b64 = _mm_cvtm64_si64(b.x);
Eigen::half h[4];
Eigen::half ha = half_impl::raw_uint16_to_half(static_cast<unsigned short>(a64));
Eigen::half hb = half_impl::raw_uint16_to_half(static_cast<unsigned short>(b64));
h[0] = ha / hb;
ha = half_impl::raw_uint16_to_half(static_cast<unsigned short>(a64 >> 16));
hb = half_impl::raw_uint16_to_half(static_cast<unsigned short>(b64 >> 16));
h[1] = ha / hb;
ha = half_impl::raw_uint16_to_half(static_cast<unsigned short>(a64 >> 32));
hb = half_impl::raw_uint16_to_half(static_cast<unsigned short>(b64 >> 32));
h[2] = ha / hb;
ha = half_impl::raw_uint16_to_half(static_cast<unsigned short>(a64 >> 48));
hb = half_impl::raw_uint16_to_half(static_cast<unsigned short>(b64 >> 48));
h[3] = ha / hb;
Packet4h result;
result.x = _mm_set_pi16(h[3].x, h[2].x, h[1].x, h[0].x);
return result;
}
template<> EIGEN_STRONG_INLINE Packet4h pload<Packet4h>(const Eigen::half* from) {
Packet4h result;
result.x = _mm_cvtsi64_m64(*reinterpret_cast<const __int64_t*>(from));