Modify tensor argmin/argmax to always return first occurence.

As written, depending on multithreading/gpu, the returned index from
`argmin`/`argmax` is not currently stable.  Here we modify the functors
to always keep the first occurence (i.e. if the value is equal to the
current min/max, then keep the one with the smallest index).

This is otherwise causing unpredictable results in some TF tests.
This commit is contained in:
Antonio Sanchez 2021-06-25 14:22:19 -07:00
parent 2d132d1736
commit 3a087ccb99

View File

@ -365,12 +365,16 @@ struct reducer_traits<OrReducer, Device> {
}; };
}; };
// Argmin/Argmax reducers. Returns the first occurrence if multiple locations
// Argmin/Argmax reducers // contain the same min/max value.
template <typename T> struct ArgMaxTupleReducer template <typename T> struct ArgMaxTupleReducer
{ {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(const T t, T* accum) const { EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(const T t, T* accum) const {
if (t.second > accum->second) { *accum = t; } if (t.second < accum->second) {
return;
} else if (t.second > accum->second || t.first < accum->first) {
*accum = t;
}
} }
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T initialize() const { EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T initialize() const {
return T(0, NumTraits<typename T::second_type>::lowest()); return T(0, NumTraits<typename T::second_type>::lowest());
@ -394,7 +398,11 @@ struct reducer_traits<ArgMaxTupleReducer<T>, Device> {
template <typename T> struct ArgMinTupleReducer template <typename T> struct ArgMinTupleReducer
{ {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(const T& t, T* accum) const { EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(const T& t, T* accum) const {
if (t.second < accum->second) { *accum = t; } if (t.second > accum->second) {
return;
} else if (t.second < accum->second || t.first < accum->first) {
*accum = t;
}
} }
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T initialize() const { EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T initialize() const {
return T(0, NumTraits<typename T::second_type>::highest()); return T(0, NumTraits<typename T::second_type>::highest());