mirror of
https://gitlab.com/libeigen/eigen.git
synced 2024-12-21 07:19:46 +08:00
Modify tensor argmin/argmax to always return first occurence.
As written, depending on multithreading/gpu, the returned index from `argmin`/`argmax` is not currently stable. Here we modify the functors to always keep the first occurence (i.e. if the value is equal to the current min/max, then keep the one with the smallest index). This is otherwise causing unpredictable results in some TF tests.
This commit is contained in:
parent
2d132d1736
commit
3a087ccb99
@ -365,12 +365,16 @@ struct reducer_traits<OrReducer, Device> {
|
|||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Argmin/Argmax reducers. Returns the first occurrence if multiple locations
|
||||||
// Argmin/Argmax reducers
|
// contain the same min/max value.
|
||||||
template <typename T> struct ArgMaxTupleReducer
|
template <typename T> struct ArgMaxTupleReducer
|
||||||
{
|
{
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(const T t, T* accum) const {
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(const T t, T* accum) const {
|
||||||
if (t.second > accum->second) { *accum = t; }
|
if (t.second < accum->second) {
|
||||||
|
return;
|
||||||
|
} else if (t.second > accum->second || t.first < accum->first) {
|
||||||
|
*accum = t;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T initialize() const {
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T initialize() const {
|
||||||
return T(0, NumTraits<typename T::second_type>::lowest());
|
return T(0, NumTraits<typename T::second_type>::lowest());
|
||||||
@ -394,7 +398,11 @@ struct reducer_traits<ArgMaxTupleReducer<T>, Device> {
|
|||||||
template <typename T> struct ArgMinTupleReducer
|
template <typename T> struct ArgMinTupleReducer
|
||||||
{
|
{
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(const T& t, T* accum) const {
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(const T& t, T* accum) const {
|
||||||
if (t.second < accum->second) { *accum = t; }
|
if (t.second > accum->second) {
|
||||||
|
return;
|
||||||
|
} else if (t.second < accum->second || t.first < accum->first) {
|
||||||
|
*accum = t;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T initialize() const {
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T initialize() const {
|
||||||
return T(0, NumTraits<typename T::second_type>::highest());
|
return T(0, NumTraits<typename T::second_type>::highest());
|
||||||
|
Loading…
Reference in New Issue
Block a user