Add nan-propagation options to matrix and array plugins.

This commit is contained in:
Rasmus Munk Larsen 2021-10-21 12:00:50 -07:00 committed by Antonio Sánchez
parent b86e013321
commit 2d3fec8ff6
3 changed files with 66 additions and 15 deletions

View File

@ -30,15 +30,27 @@ operator/(const EIGEN_CURRENT_STORAGE_BASE_CLASS<OtherDerived> &other) const
*
* \sa max()
*/
EIGEN_MAKE_CWISE_BINARY_OP(min,min)
template <int NaNPropagation=PropagateFast, typename OtherDerived>
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const CwiseBinaryOp<internal::scalar_min_op<Scalar,Scalar,NaNPropagation>, const Derived, const OtherDerived>
#ifdef EIGEN_PARSED_BY_DOXYGEN
min
#else
(min)
#endif
(const OtherDerived &other) const
{
return CwiseBinaryOp<internal::scalar_min_op<Scalar,Scalar,NaNPropagation>, const Derived, const OtherDerived>(derived(), other.derived());
}
/** \returns an expression of the coefficient-wise min of \c *this and scalar \a other
*
* \sa max()
*/
template <int NaNPropagation=PropagateFast>
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const CwiseBinaryOp<internal::scalar_min_op<Scalar,Scalar>, const Derived,
const CwiseNullaryOp<internal::scalar_constant_op<Scalar>, PlainObject> >
EIGEN_STRONG_INLINE const CwiseBinaryOp<internal::scalar_min_op<Scalar,Scalar,NaNPropagation>, const Derived,
const CwiseNullaryOp<internal::scalar_constant_op<Scalar>, PlainObject> >
#ifdef EIGEN_PARSED_BY_DOXYGEN
min
#else
@ -46,7 +58,7 @@ min
#endif
(const Scalar &other) const
{
return (min)(Derived::PlainObject::Constant(rows(), cols(), other));
return (min<NaNPropagation>)(Derived::PlainObject::Constant(rows(), cols(), other));
}
/** \returns an expression of the coefficient-wise max of \c *this and \a other
@ -56,14 +68,26 @@ min
*
* \sa min()
*/
EIGEN_MAKE_CWISE_BINARY_OP(max,max)
template <int NaNPropagation=PropagateFast, typename OtherDerived>
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const CwiseBinaryOp<internal::scalar_max_op<Scalar,Scalar,NaNPropagation>, const Derived, const OtherDerived>
#ifdef EIGEN_PARSED_BY_DOXYGEN
max
#else
(max)
#endif
(const OtherDerived &other) const
{
return CwiseBinaryOp<internal::scalar_max_op<Scalar,Scalar,NaNPropagation>, const Derived, const OtherDerived>(derived(), other.derived());
}
/** \returns an expression of the coefficient-wise max of \c *this and scalar \a other
*
* \sa min()
*/
template <int NaNPropagation=PropagateFast>
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const CwiseBinaryOp<internal::scalar_max_op<Scalar,Scalar>, const Derived,
EIGEN_STRONG_INLINE const CwiseBinaryOp<internal::scalar_max_op<Scalar,Scalar,NaNPropagation>, const Derived,
const CwiseNullaryOp<internal::scalar_constant_op<Scalar>, PlainObject> >
#ifdef EIGEN_PARSED_BY_DOXYGEN
max
@ -72,7 +96,7 @@ max
#endif
(const Scalar &other) const
{
return (max)(Derived::PlainObject::Constant(rows(), cols(), other));
return (max<NaNPropagation>)(Derived::PlainObject::Constant(rows(), cols(), other));
}
/** \returns an expression of the coefficient-wise absdiff of \c *this and \a other

View File

@ -72,20 +72,21 @@ cwiseNotEqual(const EIGEN_CURRENT_STORAGE_BASE_CLASS<OtherDerived> &other) const
*
* \sa class CwiseBinaryOp, max()
*/
template<typename OtherDerived>
template<int NaNPropagation=PropagateFast, typename OtherDerived>
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const CwiseBinaryOp<internal::scalar_min_op<Scalar,Scalar>, const Derived, const OtherDerived>
EIGEN_STRONG_INLINE const CwiseBinaryOp<internal::scalar_min_op<Scalar,Scalar,NaNPropagation>, const Derived, const OtherDerived>
cwiseMin(const EIGEN_CURRENT_STORAGE_BASE_CLASS<OtherDerived> &other) const
{
return CwiseBinaryOp<internal::scalar_min_op<Scalar,Scalar>, const Derived, const OtherDerived>(derived(), other.derived());
return CwiseBinaryOp<internal::scalar_min_op<Scalar,Scalar,NaNPropagation>, const Derived, const OtherDerived>(derived(), other.derived());
}
/** \returns an expression of the coefficient-wise min of *this and scalar \a other
*
* \sa class CwiseBinaryOp, min()
*/
template<int NaNPropagation=PropagateFast>
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const CwiseBinaryOp<internal::scalar_min_op<Scalar,Scalar>, const Derived, const ConstantReturnType>
EIGEN_STRONG_INLINE const CwiseBinaryOp<internal::scalar_min_op<Scalar,Scalar,NaNPropagation>, const Derived, const ConstantReturnType>
cwiseMin(const Scalar &other) const
{
return cwiseMin(Derived::Constant(rows(), cols(), other));
@ -98,20 +99,21 @@ cwiseMin(const Scalar &other) const
*
* \sa class CwiseBinaryOp, min()
*/
template<typename OtherDerived>
template<int NaNPropagation=PropagateFast, typename OtherDerived>
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const CwiseBinaryOp<internal::scalar_max_op<Scalar,Scalar>, const Derived, const OtherDerived>
EIGEN_STRONG_INLINE const CwiseBinaryOp<internal::scalar_max_op<Scalar,Scalar,NaNPropagation>, const Derived, const OtherDerived>
cwiseMax(const EIGEN_CURRENT_STORAGE_BASE_CLASS<OtherDerived> &other) const
{
return CwiseBinaryOp<internal::scalar_max_op<Scalar,Scalar>, const Derived, const OtherDerived>(derived(), other.derived());
return CwiseBinaryOp<internal::scalar_max_op<Scalar,Scalar,NaNPropagation>, const Derived, const OtherDerived>(derived(), other.derived());
}
/** \returns an expression of the coefficient-wise max of *this and scalar \a other
*
* \sa class CwiseBinaryOp, min()
*/
template<int NaNPropagation=PropagateFast>
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const CwiseBinaryOp<internal::scalar_max_op<Scalar,Scalar>, const Derived, const ConstantReturnType>
EIGEN_STRONG_INLINE const CwiseBinaryOp<internal::scalar_max_op<Scalar,Scalar,NaNPropagation>, const Derived, const ConstantReturnType>
cwiseMax(const Scalar &other) const
{
return cwiseMax(Derived::Constant(rows(), cols(), other));

View File

@ -211,6 +211,31 @@ template<typename MatrixType> void cwise_min_max(const MatrixType& m)
VERIFY_IS_APPROX(MatrixType::Constant(rows,cols, maxM1).array(), (m1.array().max)( maxM1));
VERIFY_IS_APPROX(m1.array(), (m1.array().max)( minM1));
// Test NaN propagation for min/max.
if (!NumTraits<Scalar>::IsInteger) {
m1(0,0) = NumTraits<Scalar>::quiet_NaN();
// Elementwise.
VERIFY((numext::isnan)(m1.template cwiseMax<PropagateNaN>(MatrixType::Constant(rows,cols, Scalar(1)))(0,0)));
VERIFY((numext::isnan)(m1.template cwiseMin<PropagateNaN>(MatrixType::Constant(rows,cols, Scalar(1)))(0,0)));
VERIFY(!(numext::isnan)(m1.template cwiseMax<PropagateNumbers>(MatrixType::Constant(rows,cols, Scalar(1)))(0,0)));
VERIFY(!(numext::isnan)(m1.template cwiseMin<PropagateNumbers>(MatrixType::Constant(rows,cols, Scalar(1)))(0,0)));
VERIFY((numext::isnan)(m1.array().template max<PropagateNaN>(MatrixType::Constant(rows,cols, Scalar(1)).array())(0,0)));
VERIFY((numext::isnan)(m1.array().template min<PropagateNaN>(MatrixType::Constant(rows,cols, Scalar(1)).array())(0,0)));
VERIFY(!(numext::isnan)(m1.array().template max<PropagateNumbers>(MatrixType::Constant(rows,cols, Scalar(1)).array())(0,0)));
VERIFY(!(numext::isnan)(m1.array().template min<PropagateNumbers>(MatrixType::Constant(rows,cols, Scalar(1)).array())(0,0)));
// Reductions.
VERIFY((numext::isnan)(m1.template maxCoeff<PropagateNaN>()));
VERIFY((numext::isnan)(m1.template minCoeff<PropagateNaN>()));
if (m1.size() > 1) {
VERIFY(!(numext::isnan)(m1.template maxCoeff<PropagateNumbers>()));
VERIFY(!(numext::isnan)(m1.template minCoeff<PropagateNumbers>()));
} else {
VERIFY((numext::isnan)(m1.template maxCoeff<PropagateNumbers>()));
VERIFY((numext::isnan)(m1.template minCoeff<PropagateNumbers>()));
}
}
}
template<typename MatrixTraits> void resize(const MatrixTraits& t)