Handle PropagateFast the same way as PropagateNaN in minmax visitor to

This commit is contained in:
Rasmus Munk Larsen 2023-03-13 20:47:11 +00:00
parent 9d72412385
commit 79de101d23
2 changed files with 18 additions and 7 deletions

View File

@ -453,8 +453,10 @@ struct minmax_compare<Scalar, NaNPropagation, false> {
static EIGEN_DEVICE_FUNC inline Scalar predux(const Packet& p) { return predux_max<NaNPropagation>(p); }
};
// Default imlementatio
template <typename Derived, bool is_min, int NaNPropagation>
// Default implementation used by non-floating types, where we do not
// need special logic for NaN handling.
template <typename Derived, bool is_min, int NaNPropagation,
bool isInt = NumTraits<typename Derived::Scalar>::IsInteger>
struct minmax_coeff_visitor : coeff_visitor<Derived> {
using Scalar = typename Derived::Scalar;
using Packet = typename packet_traits<Scalar>::type;
@ -493,7 +495,7 @@ struct minmax_coeff_visitor : coeff_visitor<Derived> {
// Suppress NaN. The only case in which we return NaN is if the matrix is all NaN,
// in which case, row=0, col=0 is returned for the location.
template <typename Derived, bool is_min>
struct minmax_coeff_visitor<Derived, is_min, PropagateNumbers> : coeff_visitor<Derived> {
struct minmax_coeff_visitor<Derived, is_min, PropagateNumbers, false> : coeff_visitor<Derived> {
typedef typename Derived::Scalar Scalar;
using Packet = typename packet_traits<Scalar>::type;
using Comparator = minmax_compare<Scalar, PropagateNumbers, is_min>;
@ -537,10 +539,10 @@ struct minmax_coeff_visitor<Derived, is_min, PropagateNumbers> : coeff_visitor<D
}
};
// Propagate NaN. If the matrix contains NaN, the location of the first NaN will be returned in
// row and col.
template <typename Derived, bool is_min>
struct minmax_coeff_visitor<Derived, is_min, PropagateNaN> : coeff_visitor<Derived> {
// Propagate NaNs. If the matrix contains NaN, the location of the first NaN
// will be returned in row and col.
template <typename Derived, bool is_min, int NaNPropagation>
struct minmax_coeff_visitor<Derived, is_min, NaNPropagation, false> : coeff_visitor<Derived> {
typedef typename Derived::Scalar Scalar;
using Packet = typename packet_traits<Scalar>::type;
using Comparator = minmax_compare<Scalar, PropagateNaN, is_min>;

View File

@ -112,6 +112,15 @@ template<typename MatrixType> void matrixVisitor(const MatrixType& p)
VERIFY(eigen_maxcol == 0);
VERIFY((numext::isnan)(eigen_minc));
VERIFY((numext::isnan)(eigen_maxc));
eigen_minc = m.template minCoeff<PropagateFast>(&eigen_minrow, &eigen_mincol);
eigen_maxc = m.template maxCoeff<PropagateFast>(&eigen_maxrow, &eigen_maxcol);
VERIFY(eigen_minrow == 0);
VERIFY(eigen_maxrow == 0);
VERIFY(eigen_mincol == 0);
VERIFY(eigen_maxcol == 0);
VERIFY((numext::isnan)(eigen_minc));
VERIFY((numext::isnan)(eigen_maxc));
}
}