Merge branch 'rmlarsen1/eigen-nan_prop'

This commit is contained in:
Rasmus Munk Larsen 2021-02-26 09:21:24 -08:00
commit fe19714f80
3 changed files with 41 additions and 5 deletions

View File

@ -449,9 +449,23 @@ template<typename Derived> class DenseBase
EIGEN_DEVICE_FUNC Scalar prod() const;
template<int NaNPropagation>
EIGEN_DEVICE_FUNC typename internal::traits<Derived>::Scalar minCoeff() const;
template<int NaNPropagation>
EIGEN_DEVICE_FUNC typename internal::traits<Derived>::Scalar maxCoeff() const;
// By default, the fastest version with undefined NaN propagation semantics is
// used.
// TODO(rmlarsen): Replace with default template argument when we move to
// c++11 or beyond.
EIGEN_DEVICE_FUNC typename internal::traits<Derived>::Scalar minCoeff() const {
return minCoeff<PropagateFast>();
}
EIGEN_DEVICE_FUNC typename internal::traits<Derived>::Scalar maxCoeff() const {
return maxCoeff<PropagateFast>();
}
template<typename IndexType> EIGEN_DEVICE_FUNC
typename internal::traits<Derived>::Scalar minCoeff(IndexType* row, IndexType* col) const;
template<typename IndexType> EIGEN_DEVICE_FUNC

View File

@ -419,25 +419,33 @@ DenseBase<Derived>::redux(const Func& func) const
}
/** \returns the minimum of all coefficients of \c *this.
* In case \c *this contains NaN, NaNPropagation determines the behavior:
* NaNPropagation == PropagateFast : undefined
* NaNPropagation == PropagateNaN : result is NaN
* NaNPropagation == PropagateNumbers : result is minimum of elements that are not NaN
* \warning the matrix must be not empty, otherwise an assertion is triggered.
* \warning the result is undefined if \c *this contains NaN.
*/
template<typename Derived>
template<int NaNPropagation>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE typename internal::traits<Derived>::Scalar
DenseBase<Derived>::minCoeff() const
{
return derived().redux(Eigen::internal::scalar_min_op<Scalar,Scalar>());
return derived().redux(Eigen::internal::scalar_min_op<Scalar,Scalar, NaNPropagation>());
}
/** \returns the maximum of all coefficients of \c *this.
/** \returns the maximum of all coefficients of \c *this.
* In case \c *this contains NaN, NaNPropagation determines the behavior:
* NaNPropagation == PropagateFast : undefined
* NaNPropagation == PropagateNaN : result is NaN
* NaNPropagation == PropagateNumbers : result is maximum of elements that are not NaN
* \warning the matrix must be not empty, otherwise an assertion is triggered.
* \warning the result is undefined if \c *this contains NaN.
*/
template<typename Derived>
template<int NaNPropagation>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE typename internal::traits<Derived>::Scalar
DenseBase<Derived>::maxCoeff() const
{
return derived().redux(Eigen::internal::scalar_max_op<Scalar,Scalar>());
return derived().redux(Eigen::internal::scalar_max_op<Scalar,Scalar, NaNPropagation>());
}
/** \returns the sum of all coefficients of \c *this

View File

@ -610,6 +610,20 @@ template<typename ArrayType> void min_max(const ArrayType& m)
VERIFY_IS_APPROX(ArrayType::Constant(rows,cols, maxM1), (m1.max)( maxM1));
VERIFY_IS_APPROX(m1, (m1.max)( minM1));
// min/max with various NaN propagation options.
if (m1.size() > 1 && !NumTraits<Scalar>::IsInteger) {
m1(0,0) = std::numeric_limits<Scalar>::quiet_NaN();
maxM1 = m1.template maxCoeff<PropagateNaN>();
minM1 = m1.template minCoeff<PropagateNaN>();
VERIFY((numext::isnan)(maxM1));
VERIFY((numext::isnan)(minM1));
maxM1 = m1.template maxCoeff<PropagateNumbers>();
minM1 = m1.template minCoeff<PropagateNumbers>();
VERIFY(!(numext::isnan)(maxM1));
VERIFY(!(numext::isnan)(minM1));
}
}
EIGEN_DECLARE_TEST(array_cwise)