Add increment/decrement operators to Eigen::half.

This is for consistency with bfloat16, and to support initialization
with `std::iota`.
This commit is contained in:
Antonio Sanchez 2021-03-15 10:50:37 -07:00
parent b271110788
commit 14487ed14e
2 changed files with 36 additions and 0 deletions

View File

@ -465,6 +465,28 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator / (const half& a, Index b) {
return half(static_cast<float>(a) / static_cast<float>(b));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator++(half& a) {
a += half(1);
return a;
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator--(half& a) {
a -= half(1);
return a;
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator++(half& a, int) {
half original_value = a;
++a;
return original_value;
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator--(half& a, int) {
half original_value = a;
--a;
return original_value;
}
// Conversion routines, including fallbacks for the host or older CUDA.
// Note that newer Intel CPUs (Haswell or newer) have vectorized versions of
// these in hardware. If we need more performance on older/other CPUs, they are

View File

@ -168,6 +168,20 @@ void test_arithmetic()
VERIFY_IS_APPROX(float(half(1.0f) / half(3.0f)), 0.33333f);
VERIFY_IS_EQUAL(float(-half(4096.0f)), -4096.0f);
VERIFY_IS_EQUAL(float(-half(-4096.0f)), 4096.0f);
half x(3);
half y = ++x;
VERIFY_IS_EQUAL(x, half(4));
VERIFY_IS_EQUAL(y, half(4));
y = --x;
VERIFY_IS_EQUAL(x, half(3));
VERIFY_IS_EQUAL(y, half(3));
y = x++;
VERIFY_IS_EQUAL(x, half(4));
VERIFY_IS_EQUAL(y, half(3));
y = x--;
VERIFY_IS_EQUAL(x, half(3));
VERIFY_IS_EQUAL(y, half(4));
}
void test_comparison()