mirror of
https://gitlab.com/libeigen/eigen.git
synced 2024-12-21 07:19:46 +08:00
Implement missing AVX half ops.
Minimal implementation of AVX `Eigen::half` ops to bring in line with `bfloat16`. Allows `packetmath_13` to pass. Also adjusted `bfloat16` packet traits to match the supported set of ops (e.g. Bessel is not actually implemented).
This commit is contained in:
parent
38abf2be42
commit
a3b300f1af
@ -158,6 +158,16 @@ Packet4d prsqrt<Packet4d>(const Packet4d& _x) {
|
|||||||
return _mm256_div_pd(p4d_one, _mm256_sqrt_pd(_x));
|
return _mm256_div_pd(p4d_one, _mm256_sqrt_pd(_x));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
F16_PACKET_FUNCTION(Packet8f, Packet8h, psin)
|
||||||
|
F16_PACKET_FUNCTION(Packet8f, Packet8h, pcos)
|
||||||
|
F16_PACKET_FUNCTION(Packet8f, Packet8h, plog)
|
||||||
|
F16_PACKET_FUNCTION(Packet8f, Packet8h, plog1p)
|
||||||
|
F16_PACKET_FUNCTION(Packet8f, Packet8h, pexpm1)
|
||||||
|
F16_PACKET_FUNCTION(Packet8f, Packet8h, pexp)
|
||||||
|
F16_PACKET_FUNCTION(Packet8f, Packet8h, ptanh)
|
||||||
|
F16_PACKET_FUNCTION(Packet8f, Packet8h, psqrt)
|
||||||
|
F16_PACKET_FUNCTION(Packet8f, Packet8h, prsqrt)
|
||||||
|
|
||||||
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, psin)
|
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, psin)
|
||||||
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, pcos)
|
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, pcos)
|
||||||
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, plog)
|
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, plog)
|
||||||
|
@ -119,22 +119,34 @@ struct packet_traits<Eigen::half> : default_packet_traits {
|
|||||||
AlignedOnScalar = 1,
|
AlignedOnScalar = 1,
|
||||||
size = 8,
|
size = 8,
|
||||||
HasHalfPacket = 0,
|
HasHalfPacket = 0,
|
||||||
|
|
||||||
|
HasCmp = 1,
|
||||||
HasAdd = 1,
|
HasAdd = 1,
|
||||||
HasSub = 1,
|
HasSub = 1,
|
||||||
HasMul = 1,
|
HasMul = 1,
|
||||||
HasDiv = 1,
|
HasDiv = 1,
|
||||||
|
HasSin = EIGEN_FAST_MATH,
|
||||||
|
HasCos = EIGEN_FAST_MATH,
|
||||||
HasNegate = 1,
|
HasNegate = 1,
|
||||||
HasAbs = 0,
|
HasAbs = 1,
|
||||||
HasAbs2 = 0,
|
HasAbs2 = 0,
|
||||||
HasMin = 0,
|
HasMin = 1,
|
||||||
HasMax = 0,
|
HasMax = 1,
|
||||||
HasConj = 0,
|
HasConj = 1,
|
||||||
HasSetLinear = 0,
|
HasSetLinear = 0,
|
||||||
HasSqrt = 0,
|
HasLog = 1,
|
||||||
HasRsqrt = 0,
|
HasLog1p = 1,
|
||||||
HasExp = 0,
|
HasExpm1 = 1,
|
||||||
HasLog = 0,
|
HasExp = 1,
|
||||||
HasBlend = 0
|
HasSqrt = 1,
|
||||||
|
HasRsqrt = 1,
|
||||||
|
HasTanh = EIGEN_FAST_MATH,
|
||||||
|
HasErf = EIGEN_FAST_MATH,
|
||||||
|
HasBlend = 0,
|
||||||
|
HasRound = 1,
|
||||||
|
HasFloor = 1,
|
||||||
|
HasCeil = 1,
|
||||||
|
HasRint = 1
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -150,16 +162,24 @@ struct packet_traits<bfloat16> : default_packet_traits {
|
|||||||
size = 8,
|
size = 8,
|
||||||
HasHalfPacket = 0,
|
HasHalfPacket = 0,
|
||||||
|
|
||||||
HasCmp = 1,
|
HasCmp = 1,
|
||||||
|
HasAdd = 1,
|
||||||
|
HasSub = 1,
|
||||||
|
HasMul = 1,
|
||||||
HasDiv = 1,
|
HasDiv = 1,
|
||||||
HasSin = EIGEN_FAST_MATH,
|
HasSin = EIGEN_FAST_MATH,
|
||||||
HasCos = EIGEN_FAST_MATH,
|
HasCos = EIGEN_FAST_MATH,
|
||||||
|
HasNegate = 1,
|
||||||
|
HasAbs = 1,
|
||||||
|
HasAbs2 = 0,
|
||||||
|
HasMin = 1,
|
||||||
|
HasMax = 1,
|
||||||
|
HasConj = 1,
|
||||||
|
HasSetLinear = 0,
|
||||||
HasLog = 1,
|
HasLog = 1,
|
||||||
HasLog1p = 1,
|
HasLog1p = 1,
|
||||||
HasExpm1 = 1,
|
HasExpm1 = 1,
|
||||||
HasExp = 1,
|
HasExp = 1,
|
||||||
HasNdtri = 1,
|
|
||||||
HasBessel = 1,
|
|
||||||
HasSqrt = 1,
|
HasSqrt = 1,
|
||||||
HasRsqrt = 1,
|
HasRsqrt = 1,
|
||||||
HasTanh = EIGEN_FAST_MATH,
|
HasTanh = EIGEN_FAST_MATH,
|
||||||
@ -870,8 +890,7 @@ template<> EIGEN_STRONG_INLINE Packet4d pblend(const Selector<4>& ifPacket, cons
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Packet math for Eigen::half
|
// Packet math for Eigen::half
|
||||||
// TODO(cantonios): add missing packet ops
|
|
||||||
// - pabs, pmin, pmax, plset, pround, print, pceil, pfloor, pcmp_lt, pcmp_le, pcmp_lt_or_nan
|
|
||||||
template<> struct unpacket_traits<Packet8h> { typedef Eigen::half type; enum {size=8, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false}; typedef Packet8h half; };
|
template<> struct unpacket_traits<Packet8h> { typedef Eigen::half type; enum {size=8, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false}; typedef Packet8h half; };
|
||||||
|
|
||||||
template<> EIGEN_STRONG_INLINE Packet8h pset1<Packet8h>(const Eigen::half& from) {
|
template<> EIGEN_STRONG_INLINE Packet8h pset1<Packet8h>(const Eigen::half& from) {
|
||||||
@ -914,6 +933,16 @@ ploadquad<Packet8h>(const Eigen::half* from) {
|
|||||||
return _mm_set_epi16(b, b, b, b, a, a, a, a);
|
return _mm_set_epi16(b, b, b, b, a, a, a, a);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template<> EIGEN_STRONG_INLINE Packet8h ptrue(const Packet8h& a) {
|
||||||
|
return _mm_cmpeq_epi32(a, a);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
EIGEN_STRONG_INLINE Packet8h pabs(const Packet8h& a) {
|
||||||
|
const __m128i sign_mask = _mm_set1_epi16(static_cast<numext::uint16_t>(0x8000));
|
||||||
|
return _mm_andnot_si128(sign_mask, a);
|
||||||
|
}
|
||||||
|
|
||||||
EIGEN_STRONG_INLINE Packet8f half2float(const Packet8h& a) {
|
EIGEN_STRONG_INLINE Packet8f half2float(const Packet8h& a) {
|
||||||
#ifdef EIGEN_HAS_FP16_C
|
#ifdef EIGEN_HAS_FP16_C
|
||||||
return _mm256_cvtph_ps(a);
|
return _mm256_cvtph_ps(a);
|
||||||
@ -951,8 +980,21 @@ EIGEN_STRONG_INLINE Packet8h float2half(const Packet8f& a) {
|
|||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
template<> EIGEN_STRONG_INLINE Packet8h ptrue(const Packet8h& a) {
|
template <>
|
||||||
return _mm_cmpeq_epi32(a, a);
|
EIGEN_STRONG_INLINE Packet8h pmin<Packet8h>(const Packet8h& a,
|
||||||
|
const Packet8h& b) {
|
||||||
|
return float2half(pmin<Packet8f>(half2float(a), half2float(b)));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
EIGEN_STRONG_INLINE Packet8h pmax<Packet8h>(const Packet8h& a,
|
||||||
|
const Packet8h& b) {
|
||||||
|
return float2half(pmax<Packet8f>(half2float(a), half2float(b)));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
EIGEN_STRONG_INLINE Packet8h plset<Packet8h>(const half& a) {
|
||||||
|
return float2half(plset<Packet8f>(static_cast<float>(a)));
|
||||||
}
|
}
|
||||||
|
|
||||||
template<> EIGEN_STRONG_INLINE Packet8h por(const Packet8h& a,const Packet8h& b) {
|
template<> EIGEN_STRONG_INLINE Packet8h por(const Packet8h& a,const Packet8h& b) {
|
||||||
@ -974,13 +1016,36 @@ template<> EIGEN_STRONG_INLINE Packet8h pselect(const Packet8h& mask, const Pack
|
|||||||
return _mm_blendv_epi8(b, a, mask);
|
return _mm_blendv_epi8(b, a, mask);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template<> EIGEN_STRONG_INLINE Packet8h pround<Packet8h>(const Packet8h& a) {
|
||||||
|
return float2half(pround<Packet8f>(half2float(a)));
|
||||||
|
}
|
||||||
|
|
||||||
|
template<> EIGEN_STRONG_INLINE Packet8h print<Packet8h>(const Packet8h& a) {
|
||||||
|
return float2half(print<Packet8f>(half2float(a)));
|
||||||
|
}
|
||||||
|
|
||||||
|
template<> EIGEN_STRONG_INLINE Packet8h pceil<Packet8h>(const Packet8h& a) {
|
||||||
|
return float2half(pceil<Packet8f>(half2float(a)));
|
||||||
|
}
|
||||||
|
|
||||||
|
template<> EIGEN_STRONG_INLINE Packet8h pfloor<Packet8h>(const Packet8h& a) {
|
||||||
|
return float2half(pfloor<Packet8f>(half2float(a)));
|
||||||
|
}
|
||||||
|
|
||||||
template<> EIGEN_STRONG_INLINE Packet8h pcmp_eq(const Packet8h& a,const Packet8h& b) {
|
template<> EIGEN_STRONG_INLINE Packet8h pcmp_eq(const Packet8h& a,const Packet8h& b) {
|
||||||
Packet8f af = half2float(a);
|
return Pack16To8(pcmp_eq(half2float(a), half2float(b)));
|
||||||
Packet8f bf = half2float(b);
|
}
|
||||||
Packet8f rf = pcmp_eq(af, bf);
|
|
||||||
// Pack the 32-bit flags into 16-bits flags.
|
template<> EIGEN_STRONG_INLINE Packet8h pcmp_le(const Packet8h& a,const Packet8h& b) {
|
||||||
return _mm_packs_epi32(_mm256_extractf128_si256(_mm256_castps_si256(rf), 0),
|
return Pack16To8(pcmp_le(half2float(a), half2float(b)));
|
||||||
_mm256_extractf128_si256(_mm256_castps_si256(rf), 1));
|
}
|
||||||
|
|
||||||
|
template<> EIGEN_STRONG_INLINE Packet8h pcmp_lt(const Packet8h& a,const Packet8h& b) {
|
||||||
|
return Pack16To8(pcmp_lt(half2float(a), half2float(b)));
|
||||||
|
}
|
||||||
|
|
||||||
|
template<> EIGEN_STRONG_INLINE Packet8h pcmp_lt_or_nan(const Packet8h& a,const Packet8h& b) {
|
||||||
|
return Pack16To8(pcmp_lt_or_nan(half2float(a), half2float(b)));
|
||||||
}
|
}
|
||||||
|
|
||||||
template<> EIGEN_STRONG_INLINE Packet8h pconj(const Packet8h& a) { return a; }
|
template<> EIGEN_STRONG_INLINE Packet8h pconj(const Packet8h& a) { return a; }
|
||||||
@ -1148,6 +1213,8 @@ ptranspose(PacketBlock<Packet8h,4>& kernel) {
|
|||||||
kernel.packet[3] = pload<Packet8h>(out[3]);
|
kernel.packet[3] = pload<Packet8h>(out[3]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// BFloat16 implementation.
|
||||||
|
|
||||||
EIGEN_STRONG_INLINE Packet8f Bf16ToF32(const Packet8bf& a) {
|
EIGEN_STRONG_INLINE Packet8f Bf16ToF32(const Packet8bf& a) {
|
||||||
#ifdef EIGEN_VECTORIZE_AVX2
|
#ifdef EIGEN_VECTORIZE_AVX2
|
||||||
__m256i extend = _mm256_cvtepu16_epi32(a);
|
__m256i extend = _mm256_cvtepu16_epi32(a);
|
||||||
@ -1262,7 +1329,8 @@ template<> EIGEN_STRONG_INLINE Packet8bf ptrue(const Packet8bf& a) {
|
|||||||
|
|
||||||
template <>
|
template <>
|
||||||
EIGEN_STRONG_INLINE Packet8bf pabs(const Packet8bf& a) {
|
EIGEN_STRONG_INLINE Packet8bf pabs(const Packet8bf& a) {
|
||||||
return F32ToBf16(pabs<Packet8f>(Bf16ToF32(a)));
|
const __m128i sign_mask = _mm_set1_epi16(static_cast<numext::uint16_t>(0x8000));
|
||||||
|
return _mm_andnot_si128(sign_mask, a);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
|
@ -56,6 +56,13 @@
|
|||||||
#define EIGEN_CONSTEXPR
|
#define EIGEN_CONSTEXPR
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#define F16_PACKET_FUNCTION(PACKET_F, PACKET_F16, METHOD) \
|
||||||
|
template <> \
|
||||||
|
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED \
|
||||||
|
PACKET_F16 METHOD<PACKET_F16>(const PACKET_F16& _x) { \
|
||||||
|
return float2half(METHOD<PACKET_F>(half2float(_x))); \
|
||||||
|
}
|
||||||
|
|
||||||
namespace Eigen {
|
namespace Eigen {
|
||||||
|
|
||||||
struct half;
|
struct half;
|
||||||
|
@ -556,7 +556,7 @@ void packetmath_real() {
|
|||||||
VERIFY((numext::isnan)(data2[0]));
|
VERIFY((numext::isnan)(data2[0]));
|
||||||
// TODO(rmlarsen): Re-enable for bfloat16.
|
// TODO(rmlarsen): Re-enable for bfloat16.
|
||||||
if (!internal::is_same<Scalar, bfloat16>::value) {
|
if (!internal::is_same<Scalar, bfloat16>::value) {
|
||||||
VERIFY_IS_EQUAL(std::exp(small), data2[1]);
|
VERIFY_IS_APPROX(std::exp(small), data2[1]);
|
||||||
}
|
}
|
||||||
|
|
||||||
data1[0] = -small;
|
data1[0] = -small;
|
||||||
@ -564,21 +564,21 @@ void packetmath_real() {
|
|||||||
h.store(data2, internal::pexp(h.load(data1)));
|
h.store(data2, internal::pexp(h.load(data1)));
|
||||||
// TODO(rmlarsen): Re-enable for bfloat16.
|
// TODO(rmlarsen): Re-enable for bfloat16.
|
||||||
if (!internal::is_same<Scalar, bfloat16>::value) {
|
if (!internal::is_same<Scalar, bfloat16>::value) {
|
||||||
VERIFY_IS_EQUAL(std::exp(-small), data2[0]);
|
VERIFY_IS_APPROX(std::exp(-small), data2[0]);
|
||||||
}
|
}
|
||||||
VERIFY_IS_EQUAL(std::exp(Scalar(0)), data2[1]);
|
VERIFY_IS_EQUAL(std::exp(Scalar(0)), data2[1]);
|
||||||
|
|
||||||
data1[0] = (std::numeric_limits<Scalar>::min)();
|
data1[0] = (std::numeric_limits<Scalar>::min)();
|
||||||
data1[1] = -(std::numeric_limits<Scalar>::min)();
|
data1[1] = -(std::numeric_limits<Scalar>::min)();
|
||||||
h.store(data2, internal::pexp(h.load(data1)));
|
h.store(data2, internal::pexp(h.load(data1)));
|
||||||
VERIFY_IS_EQUAL(std::exp((std::numeric_limits<Scalar>::min)()), data2[0]);
|
VERIFY_IS_APPROX(std::exp((std::numeric_limits<Scalar>::min)()), data2[0]);
|
||||||
VERIFY_IS_EQUAL(std::exp(-(std::numeric_limits<Scalar>::min)()), data2[1]);
|
VERIFY_IS_APPROX(std::exp(-(std::numeric_limits<Scalar>::min)()), data2[1]);
|
||||||
|
|
||||||
data1[0] = std::numeric_limits<Scalar>::denorm_min();
|
data1[0] = std::numeric_limits<Scalar>::denorm_min();
|
||||||
data1[1] = -std::numeric_limits<Scalar>::denorm_min();
|
data1[1] = -std::numeric_limits<Scalar>::denorm_min();
|
||||||
h.store(data2, internal::pexp(h.load(data1)));
|
h.store(data2, internal::pexp(h.load(data1)));
|
||||||
VERIFY_IS_EQUAL(std::exp(std::numeric_limits<Scalar>::denorm_min()), data2[0]);
|
VERIFY_IS_APPROX(std::exp(std::numeric_limits<Scalar>::denorm_min()), data2[0]);
|
||||||
VERIFY_IS_EQUAL(std::exp(-std::numeric_limits<Scalar>::denorm_min()), data2[1]);
|
VERIFY_IS_APPROX(std::exp(-std::numeric_limits<Scalar>::denorm_min()), data2[1]);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (PacketTraits::HasTanh) {
|
if (PacketTraits::HasTanh) {
|
||||||
@ -618,7 +618,7 @@ void packetmath_real() {
|
|||||||
test::packet_helper<PacketTraits::HasLog, Packet> h;
|
test::packet_helper<PacketTraits::HasLog, Packet> h;
|
||||||
h.store(data2, internal::plog(h.load(data1)));
|
h.store(data2, internal::plog(h.load(data1)));
|
||||||
VERIFY((numext::isnan)(data2[0]));
|
VERIFY((numext::isnan)(data2[0]));
|
||||||
VERIFY_IS_EQUAL(std::log(std::numeric_limits<Scalar>::epsilon()), data2[1]);
|
VERIFY_IS_APPROX(std::log(std::numeric_limits<Scalar>::epsilon()), data2[1]);
|
||||||
|
|
||||||
data1[0] = -std::numeric_limits<Scalar>::epsilon();
|
data1[0] = -std::numeric_limits<Scalar>::epsilon();
|
||||||
data1[1] = Scalar(0);
|
data1[1] = Scalar(0);
|
||||||
@ -629,7 +629,7 @@ void packetmath_real() {
|
|||||||
data1[0] = (std::numeric_limits<Scalar>::min)();
|
data1[0] = (std::numeric_limits<Scalar>::min)();
|
||||||
data1[1] = -(std::numeric_limits<Scalar>::min)();
|
data1[1] = -(std::numeric_limits<Scalar>::min)();
|
||||||
h.store(data2, internal::plog(h.load(data1)));
|
h.store(data2, internal::plog(h.load(data1)));
|
||||||
VERIFY_IS_EQUAL(std::log((std::numeric_limits<Scalar>::min)()), data2[0]);
|
VERIFY_IS_APPROX(std::log((std::numeric_limits<Scalar>::min)()), data2[0]);
|
||||||
VERIFY((numext::isnan)(data2[1]));
|
VERIFY((numext::isnan)(data2[1]));
|
||||||
|
|
||||||
// Note: 32-bit arm always flushes denorms to zero.
|
// Note: 32-bit arm always flushes denorms to zero.
|
||||||
@ -672,8 +672,10 @@ void packetmath_real() {
|
|||||||
VERIFY((numext::isnan)(data2[0]));
|
VERIFY((numext::isnan)(data2[0]));
|
||||||
VERIFY((numext::isnan)(data2[1]));
|
VERIFY((numext::isnan)(data2[1]));
|
||||||
}
|
}
|
||||||
// TODO(rmlarsen): Re-enable for bfloat16.
|
// TODO(rmlarsen): Re-enable for half and bfloat16.
|
||||||
if (PacketTraits::HasCos && !internal::is_same<Scalar, bfloat16>::value) {
|
if (PacketTraits::HasCos
|
||||||
|
&& !internal::is_same<Scalar, half>::value
|
||||||
|
&& !internal::is_same<Scalar, bfloat16>::value) {
|
||||||
test::packet_helper<PacketTraits::HasCos, Packet> h;
|
test::packet_helper<PacketTraits::HasCos, Packet> h;
|
||||||
for (Scalar k = Scalar(1); k < Scalar(10000) / std::numeric_limits<Scalar>::epsilon(); k *= Scalar(2)) {
|
for (Scalar k = Scalar(1); k < Scalar(10000) / std::numeric_limits<Scalar>::epsilon(); k *= Scalar(2)) {
|
||||||
for (int k1 = 0; k1 <= 1; ++k1) {
|
for (int k1 = 0; k1 <= 1; ++k1) {
|
||||||
@ -1074,12 +1076,7 @@ EIGEN_DECLARE_TEST(packetmath) {
|
|||||||
CALL_SUBTEST_10(test::runner<uint64_t>::run());
|
CALL_SUBTEST_10(test::runner<uint64_t>::run());
|
||||||
CALL_SUBTEST_11(test::runner<std::complex<float> >::run());
|
CALL_SUBTEST_11(test::runner<std::complex<float> >::run());
|
||||||
CALL_SUBTEST_12(test::runner<std::complex<double> >::run());
|
CALL_SUBTEST_12(test::runner<std::complex<double> >::run());
|
||||||
#if defined(EIGEN_VECTORIZE_AVX)
|
|
||||||
// AVX half packets not fully implemented.
|
|
||||||
CALL_SUBTEST_13((packetmath<half, internal::packet_traits<half>::type>()));
|
|
||||||
#else
|
|
||||||
CALL_SUBTEST_13(test::runner<half>::run());
|
CALL_SUBTEST_13(test::runner<half>::run());
|
||||||
#endif
|
|
||||||
CALL_SUBTEST_14((packetmath<bool, internal::packet_traits<bool>::type>()));
|
CALL_SUBTEST_14((packetmath<bool, internal::packet_traits<bool>::type>()));
|
||||||
CALL_SUBTEST_15(test::runner<bfloat16>::run());
|
CALL_SUBTEST_15(test::runner<bfloat16>::run());
|
||||||
g_first_pass = false;
|
g_first_pass = false;
|
||||||
|
Loading…
Reference in New Issue
Block a user