Add packet generic ops predux_fmin, predux_fmin_nan, predux_fmax, and predux_fmax_nan that implement reductions with PropagateNaN, and PropagateNumbers semantics. Add (slow) generic implementations for most reductions.

This commit is contained in:
Rasmus Munk Larsen 2020-10-13 21:48:31 +00:00
parent 807e51528d
commit c6953f799b
6 changed files with 395 additions and 294 deletions

View File

@ -215,19 +215,166 @@ pmul(const bool& a, const bool& b) { return a && b; }
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
pdiv(const Packet& a, const Packet& b) { return a/b; }
/** \internal \returns one bits */
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
ptrue(const Packet& /*a*/) { Packet b; memset((void*)&b, 0xff, sizeof(b)); return b;}
/** \internal \returns zero bits */
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
pzero(const Packet& /*a*/) { Packet b; memset((void*)&b, 0, sizeof(b)); return b;}
/** \internal \returns a <= b as a bit mask */
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
pcmp_le(const Packet& a, const Packet& b) { return a<=b ? ptrue(a) : pzero(a); }
/** \internal \returns a < b as a bit mask */
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
pcmp_lt(const Packet& a, const Packet& b) { return a<b ? ptrue(a) : pzero(a); }
/** \internal \returns a == b as a bit mask */
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
pcmp_eq(const Packet& a, const Packet& b) { return a==b ? ptrue(a) : pzero(a); }
/** \internal \returns a < b or a==NaN or b==NaN as a bit mask */
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
pcmp_lt_or_nan(const Packet& a, const Packet& b) { return a>=b ? pzero(a) : ptrue(a); }
template<> EIGEN_DEVICE_FUNC inline float pzero<float>(const float& a) {
EIGEN_UNUSED_VARIABLE(a)
return 0.f;
}
template<> EIGEN_DEVICE_FUNC inline double pzero<double>(const double& a) {
EIGEN_UNUSED_VARIABLE(a)
return 0.;
}
template <typename RealScalar>
EIGEN_DEVICE_FUNC inline std::complex<RealScalar> ptrue(const std::complex<RealScalar>& /*a*/) {
RealScalar b;
b = ptrue(b);
return std::complex<RealScalar>(b, b);
}
template <typename Packet, typename Op>
EIGEN_DEVICE_FUNC inline Packet bitwise_helper(const Packet& a, const Packet& b, Op op) {
const unsigned char* a_ptr = reinterpret_cast<const unsigned char*>(&a);
const unsigned char* b_ptr = reinterpret_cast<const unsigned char*>(&b);
Packet c;
unsigned char* c_ptr = reinterpret_cast<unsigned char*>(&c);
for (size_t i = 0; i < sizeof(Packet); ++i) {
*c_ptr++ = op(*a_ptr++, *b_ptr++);
}
return c;
}
/** \internal \returns the bitwise and of \a a and \a b */
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
pand(const Packet& a, const Packet& b) {
EIGEN_USING_STD(bit_and);
return bitwise_helper(a ,b, bit_and<unsigned char>());
}
/** \internal \returns the bitwise or of \a a and \a b */
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
por(const Packet& a, const Packet& b) {
EIGEN_USING_STD(bit_or);
return bitwise_helper(a ,b, bit_or<unsigned char>());
}
/** \internal \returns the bitwise xor of \a a and \a b */
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
pxor(const Packet& a, const Packet& b) {
EIGEN_USING_STD(bit_xor);
return bitwise_helper(a ,b, bit_xor<unsigned char>());
}
/** \internal \returns the bitwise and of \a a and not \a b */
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
pandnot(const Packet& a, const Packet& b) { return pand(a, pxor(ptrue(b), b)); }
/** \internal \returns \a or \b for each field in packet according to \mask */
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
pselect(const Packet& mask, const Packet& a, const Packet& b) {
return por(pand(a,mask),pandnot(b,mask));
}
template<> EIGEN_DEVICE_FUNC inline float pselect<float>(
const float& cond, const float& a, const float&b) {
return numext::equal_strict(cond,0.f) ? b : a;
}
template<> EIGEN_DEVICE_FUNC inline double pselect<double>(
const double& cond, const double& a, const double& b) {
return numext::equal_strict(cond,0.) ? b : a;
}
template<> EIGEN_DEVICE_FUNC inline bool pselect<bool>(
const bool& cond, const bool& a, const bool& b) {
return cond ? a : b;
}
/** \internal \returns the min or of \a a and \a b (coeff-wise)
If either \a a or \a b are NaN, the result is implementation defined. */
template<int NaNPropagation>
struct pminmax_impl {
template <typename Packet, typename Op>
static EIGEN_DEVICE_FUNC inline Packet run(const Packet& a, const Packet& b, Op op) {
return op(a,b);
}
};
/** \internal \returns the min or max of \a a and \a b (coeff-wise)
If either \a a or \a b are NaN, NaN is returned. */
template<>
struct pminmax_impl<PropagateNaN> {
template <typename Packet, typename Op>
static EIGEN_DEVICE_FUNC inline Packet run(const Packet& a, const Packet& b, Op op) {
Packet not_nan_mask_a = pcmp_eq(a, a);
Packet not_nan_mask_b = pcmp_eq(b, b);
return pselect(not_nan_mask_a,
pselect(not_nan_mask_b, op(a, b), b),
a);
}
};
/** \internal \returns the min or max of \a a and \a b (coeff-wise)
If both \a a and \a b are NaN, NaN is returned.
Equivalent to std::fmin(a, b). */
template<>
struct pminmax_impl<PropagateNumbers> {
template <typename Packet, typename Op>
static EIGEN_DEVICE_FUNC inline Packet run(const Packet& a, const Packet& b, Op op) {
Packet not_nan_mask_a = pcmp_eq(a, a);
Packet not_nan_mask_b = pcmp_eq(b, b);
return pselect(not_nan_mask_a,
pselect(not_nan_mask_b, op(a, b), a),
b);
}
};
/** \internal \returns the min of \a a and \a b (coeff-wise).
If \a a or \b b is NaN, the return value is implementation defined. */
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
pmin(const Packet& a, const Packet& b) { return numext::mini(a, b); }
pmin(const Packet& a, const Packet& b) { return numext::mini(a,b); }
/** \internal \returns the min of \a a and \a b (coeff-wise).
NaNPropagation determines the NaN propagation semantics. */
template<int NaNPropagation, typename Packet> EIGEN_DEVICE_FUNC inline Packet
pmin(const Packet& a, const Packet& b) { return pminmax_impl<NaNPropagation>::run(a,b, pmin<Packet>); }
/** \internal \returns the max of \a a and \a b (coeff-wise)
If \a a or \b b is NaN, the return value is implementation defined. */
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
pmax(const Packet& a, const Packet& b) { return numext::maxi(a, b); }
/** \internal \returns the max of \a a and \a b (coeff-wise).
NaNPropagation determines the NaN propagation semantics. */
template<int NaNPropagation, typename Packet> EIGEN_DEVICE_FUNC inline Packet
pmax(const Packet& a, const Packet& b) { return pminmax_impl<NaNPropagation>::run(a,b, pmax<Packet>); }
/** \internal \returns the absolute value of \a a */
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
pabs(const Packet& a) { using std::abs; return abs(a); }
pabs(const Packet& a) { return numext::abs(a); }
template<> EIGEN_DEVICE_FUNC inline unsigned int
pabs(const unsigned int& a) { return a; }
template<> EIGEN_DEVICE_FUNC inline unsigned long
@ -279,105 +426,6 @@ pldexp(const Packet &a, const Packet &exponent) {
return ldexp(a, static_cast<int>(exponent));
}
/** \internal \returns zero bits */
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
pzero(const Packet& /*a*/) { Packet b; memset((void*)&b, 0, sizeof(b)); return b;}
template<> EIGEN_DEVICE_FUNC inline float pzero<float>(const float& a) {
EIGEN_UNUSED_VARIABLE(a)
return 0.f;
}
template<> EIGEN_DEVICE_FUNC inline double pzero<double>(const double& a) {
EIGEN_UNUSED_VARIABLE(a)
return 0.;
}
/** \internal \returns one bits */
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
ptrue(const Packet& /*a*/) { Packet b; memset((void*)&b, 0xff, sizeof(b)); return b;}
template <typename RealScalar>
EIGEN_DEVICE_FUNC inline std::complex<RealScalar> ptrue(const std::complex<RealScalar>& /*a*/) {
RealScalar b;
b = ptrue(b);
return std::complex<RealScalar>(b, b);
}
template <typename Packet, typename Op>
EIGEN_DEVICE_FUNC inline Packet bitwise_helper(const Packet& a, const Packet& b, Op op) {
const unsigned char* a_ptr = reinterpret_cast<const unsigned char*>(&a);
const unsigned char* b_ptr = reinterpret_cast<const unsigned char*>(&b);
Packet c;
unsigned char* c_ptr = reinterpret_cast<unsigned char*>(&c);
for (size_t i = 0; i < sizeof(Packet); ++i) {
*c_ptr++ = op(*a_ptr++, *b_ptr++);
}
return c;
}
/** \internal \returns the bitwise and of \a a and \a b */
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
pand(const Packet& a, const Packet& b) {
EIGEN_USING_STD(bit_and);
return bitwise_helper(a ,b, bit_and<unsigned char>());
}
/** \internal \returns the bitwise or of \a a and \a b */
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
por(const Packet& a, const Packet& b) {
EIGEN_USING_STD(bit_or);
return bitwise_helper(a ,b, bit_or<unsigned char>());
}
/** \internal \returns the bitwise xor of \a a and \a b */
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
pxor(const Packet& a, const Packet& b) {
EIGEN_USING_STD(bit_xor);
return bitwise_helper(a ,b, bit_xor<unsigned char>());
}
/** \internal \returns the bitwise and of \a a and not \a b */
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
pandnot(const Packet& a, const Packet& b) { return pand(a, pxor(ptrue(b), b)); }
/** \internal \returns a <= b as a bit mask */
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
pcmp_le(const Packet& a, const Packet& b) { return a<=b ? ptrue(a) : pzero(a); }
/** \internal \returns a < b as a bit mask */
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
pcmp_lt(const Packet& a, const Packet& b) { return a<b ? ptrue(a) : pzero(a); }
/** \internal \returns a == b as a bit mask */
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
pcmp_eq(const Packet& a, const Packet& b) { return a==b ? ptrue(a) : pzero(a); }
/** \internal \returns a < b or a==NaN or b==NaN as a bit mask */
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
pcmp_lt_or_nan(const Packet& a, const Packet& b) { return a>=b ? pzero(a) : ptrue(a); }
/** \internal \returns \a or \b for each field in packet according to \mask */
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
pselect(const Packet& mask, const Packet& a, const Packet& b) {
return por(pand(a,mask),pandnot(b,mask));
}
template<> EIGEN_DEVICE_FUNC inline float pselect<float>(
const float& cond, const float& a, const float&b) {
return numext::equal_strict(cond,0.f) ? b : a;
}
template<> EIGEN_DEVICE_FUNC inline double pselect<double>(
const double& cond, const double& a, const double& b) {
return numext::equal_strict(cond,0.) ? b : a;
}
template<> EIGEN_DEVICE_FUNC inline bool pselect<bool>(
const bool& cond, const bool& a, const bool& b) {
return cond ? a : b;
}
/** \internal \returns the min of \a a and \a b (coeff-wise) */
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
pabsdiff(const Packet& a, const Packet& b) { return pselect(pcmp_lt(a, b), psub(b, a), psub(a, b)); }
@ -507,57 +555,6 @@ template<typename Scalar> EIGEN_DEVICE_FUNC inline void prefetch(const Scalar* a
#endif
}
/** \internal \returns the first element of a packet */
template<typename Packet> EIGEN_DEVICE_FUNC inline typename unpacket_traits<Packet>::type pfirst(const Packet& a)
{ return a; }
/** \internal \returns the sum of the elements of \a a*/
template<typename Packet> EIGEN_DEVICE_FUNC inline typename unpacket_traits<Packet>::type predux(const Packet& a)
{ return a; }
/** \internal \returns the sum of the elements of upper and lower half of \a a if \a a is larger than 4.
* For a packet {a0, a1, a2, a3, a4, a5, a6, a7}, it returns a half packet {a0+a4, a1+a5, a2+a6, a3+a7}
* For packet-size smaller or equal to 4, this boils down to a noop.
*/
template<typename Packet> EIGEN_DEVICE_FUNC inline
typename conditional<(unpacket_traits<Packet>::size%8)==0,typename unpacket_traits<Packet>::half,Packet>::type
predux_half_dowto4(const Packet& a)
{ return a; }
/** \internal \returns the product of the elements of \a a */
template<typename Packet> EIGEN_DEVICE_FUNC inline typename unpacket_traits<Packet>::type predux_mul(const Packet& a)
{ return a; }
/** \internal \returns the min of the elements of \a a */
template<typename Packet> EIGEN_DEVICE_FUNC inline typename unpacket_traits<Packet>::type predux_min(const Packet& a)
{ return a; }
/** \internal \returns the max of the elements of \a a */
template<typename Packet> EIGEN_DEVICE_FUNC inline typename unpacket_traits<Packet>::type predux_max(const Packet& a)
{ return a; }
/** \internal \returns true if all coeffs of \a a means "true"
* It is supposed to be called on values returned by pcmp_*.
*/
// not needed yet
// template<typename Packet> EIGEN_DEVICE_FUNC inline bool predux_all(const Packet& a)
// { return bool(a); }
/** \internal \returns true if any coeffs of \a a means "true"
* It is supposed to be called on values returned by pcmp_*.
*/
template<typename Packet> EIGEN_DEVICE_FUNC inline bool predux_any(const Packet& a)
{
// Dirty but generic implementation where "true" is assumed to be non 0 and all the sames.
// It is expected that "true" is either:
// - Scalar(1)
// - bits full of ones (NaN for floats),
// - or first bit equals to 1 (1 for ints, smallest denormal for floats).
// For all these cases, taking the sum is just fine, and this boils down to a no-op for scalars.
typedef typename unpacket_traits<Packet>::type Scalar;
return numext::not_equal_strict(predux(a), Scalar(0));
}
/** \internal \returns the reversed elements of \a a*/
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet preverse(const Packet& a)
{ return a; }
@ -656,53 +653,104 @@ Packet print(const Packet& a) { using numext::rint; return rint(a); }
template<typename Packet> EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
Packet pceil(const Packet& a) { using numext::ceil; return ceil(a); }
/** \internal \returns the first element of a packet */
template<typename Packet>
EIGEN_DEVICE_FUNC inline typename unpacket_traits<Packet>::type
pfirst(const Packet& a)
{ return a; }
/** \internal \returns the max of \a a and \a b (coeff-wise)
If both \a a and \a b are NaN, NaN is returned.
Equivalent to std::fmax(a, b). */
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
pfmax(const Packet& a, const Packet& b) {
Packet not_nan_mask_a = pcmp_eq(a, a);
Packet not_nan_mask_b = pcmp_eq(b, b);
return pselect(not_nan_mask_a,
pselect(not_nan_mask_b, pmax(a, b), a),
b);
/** \internal \returns the sum of the elements of upper and lower half of \a a if \a a is larger than 4.
* For a packet {a0, a1, a2, a3, a4, a5, a6, a7}, it returns a half packet {a0+a4, a1+a5, a2+a6, a3+a7}
* For packet-size smaller or equal to 4, this boils down to a noop.
*/
template<typename Packet>
EIGEN_DEVICE_FUNC inline typename conditional<(unpacket_traits<Packet>::size%8)==0,typename unpacket_traits<Packet>::half,Packet>::type
predux_half_dowto4(const Packet& a)
{ return a; }
// Slow generic implementation of Packet reduction.
template <typename Packet, typename Op>
EIGEN_DEVICE_FUNC inline typename unpacket_traits<Packet>::type
predux_helper(const Packet& a, Op op) {
typedef typename unpacket_traits<Packet>::type Scalar;
const size_t n = unpacket_traits<Packet>::size;
Scalar elements[n];
pstoreu<Scalar>(elements, a);
for(size_t k = n / 2; k > 0; k /= 2) {
for(size_t i = 0; i < k; ++i) {
elements[i] = op(elements[i], elements[i + k]);
}
}
return elements[0];
}
/** \internal \returns the min of \a a and \a b (coeff-wise)
If both \a a and \a b are NaN, NaN is returned.
Equivalent to std::fmin(a, b). */
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
pfmin(const Packet& a, const Packet& b) {
Packet not_nan_mask_a = pcmp_eq(a, a);
Packet not_nan_mask_b = pcmp_eq(b, b);
return pselect(not_nan_mask_a,
pselect(not_nan_mask_b, pmin(a, b), a),
b);
/** \internal \returns the sum of the elements of \a a*/
template<typename Packet>
EIGEN_DEVICE_FUNC inline typename unpacket_traits<Packet>::type
predux(const Packet& a)
{
return predux_helper(a, padd<typename unpacket_traits<Packet>::type>);
}
/** \internal \returns the max of \a a and \a b (coeff-wise)
If either \a a or \a b are NaN, NaN is returned. */
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
pfmax_nan(const Packet& a, const Packet& b) {
Packet not_nan_mask_a = pcmp_eq(a, a);
Packet not_nan_mask_b = pcmp_eq(b, b);
return pselect(not_nan_mask_a,
pselect(not_nan_mask_b, pmax(a, b), b),
a);
/** \internal \returns the product of the elements of \a a */
template<typename Packet>
EIGEN_DEVICE_FUNC inline typename unpacket_traits<Packet>::type
predux_mul(const Packet& a)
{
return predux_helper(a, pmul<typename unpacket_traits<Packet>::type>);
}
/** \internal \returns the min of \a a and \a b (coeff-wise)
If either \a a or \a b are NaN, NaN is returned. */
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
pfmin_nan(const Packet& a, const Packet& b) {
Packet not_nan_mask_a = pcmp_eq(a, a);
Packet not_nan_mask_b = pcmp_eq(b, b);
return pselect(not_nan_mask_a,
pselect(not_nan_mask_b, pmin(a, b), b),
a);
/** \internal \returns the min of the elements of \a a */
template<typename Packet>
EIGEN_DEVICE_FUNC inline typename unpacket_traits<Packet>::type
predux_min(const Packet& a)
{
return predux_helper(a, pmin<PropagateFast, typename unpacket_traits<Packet>::type>);
}
template<int NaNPropagation, typename Packet>
EIGEN_DEVICE_FUNC inline typename unpacket_traits<Packet>::type
predux_min(const Packet& a)
{
return predux_helper(a, pmin<NaNPropagation, typename unpacket_traits<Packet>::type>);
}
/** \internal \returns the max of the elements of \a a */
template<typename Packet>
EIGEN_DEVICE_FUNC inline typename unpacket_traits<Packet>::type
predux_max(const Packet& a)
{
return predux_helper(a, pmax<PropagateFast, typename unpacket_traits<Packet>::type>);
}
template<int NaNPropagation, typename Packet>
EIGEN_DEVICE_FUNC inline typename unpacket_traits<Packet>::type
predux_max(const Packet& a)
{
return predux_helper(a, pmax<NaNPropagation, typename unpacket_traits<Packet>::type>);
}
/** \internal \returns true if all coeffs of \a a means "true"
* It is supposed to be called on values returned by pcmp_*.
*/
// not needed yet
// template<typename Packet> EIGEN_DEVICE_FUNC inline bool predux_all(const Packet& a)
// { return bool(a); }
/** \internal \returns true if any coeffs of \a a means "true"
* It is supposed to be called on values returned by pcmp_*.
*/
template<typename Packet> EIGEN_DEVICE_FUNC inline bool predux_any(const Packet& a)
{
// Dirty but generic implementation where "true" is assumed to be non 0 and all the sames.
// It is expected that "true" is either:
// - Scalar(1)
// - bits full of ones (NaN for floats),
// - or first bit equals to 1 (1 for ints, smallest denormal for floats).
// For all these cases, taking the sum is just fine, and this boils down to a no-op for scalars.
typedef typename unpacket_traits<Packet>::type Scalar;
return numext::not_equal_strict(predux(a), Scalar(0));
}
/***************************************************************************
* The following functions might not have to be overwritten for vectorized types

View File

@ -140,29 +140,18 @@ struct scalar_min_op : binary_op_base<LhsScalar,RhsScalar>
typedef typename ScalarBinaryOpTraits<LhsScalar,RhsScalar,scalar_min_op>::ReturnType result_type;
EIGEN_EMPTY_STRUCT_CTOR(scalar_min_op)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator() (const LhsScalar& a, const RhsScalar& b) const {
if (NaNPropagation == PropagateFast) {
return numext::mini(a, b);
} else if (NaNPropagation == PropagateNumbers) {
return internal::pfmin(a,b);
} else if (NaNPropagation == PropagateNaN) {
return internal::pfmin_nan(a,b);
}
return internal::pmin<NaNPropagation>(a, b);
}
template<typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a, const Packet& b) const
{
if (NaNPropagation == PropagateFast) {
return internal::pmin(a,b);
} else if (NaNPropagation == PropagateNumbers) {
return internal::pfmin(a,b);
} else if (NaNPropagation == PropagateNaN) {
return internal::pfmin_nan(a,b);
}
return internal::pmin<NaNPropagation>(a,b);
}
// TODO(rmlarsen): Handle all NaN propagation semantics reductions.
template<typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type predux(const Packet& a) const
{ return internal::predux_min(a); }
{
return internal::predux_min<NaNPropagation>(a);
}
};
template<typename LhsScalar,typename RhsScalar, int NaNPropagation>
@ -184,29 +173,18 @@ struct scalar_max_op : binary_op_base<LhsScalar,RhsScalar>
typedef typename ScalarBinaryOpTraits<LhsScalar,RhsScalar,scalar_max_op>::ReturnType result_type;
EIGEN_EMPTY_STRUCT_CTOR(scalar_max_op)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator() (const LhsScalar& a, const RhsScalar& b) const {
if (NaNPropagation == PropagateFast) {
return numext::maxi(a, b);
} else if (NaNPropagation == PropagateNumbers) {
return internal::pfmax(a,b);
} else if (NaNPropagation == PropagateNaN) {
return internal::pfmax_nan(a,b);
}
return internal::pmax<NaNPropagation>(a,b);
}
template<typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a, const Packet& b) const
{
if (NaNPropagation == PropagateFast) {
return internal::pmax(a,b);
} else if (NaNPropagation == PropagateNumbers) {
return internal::pfmax(a,b);
} else if (NaNPropagation == PropagateNaN) {
return internal::pfmax_nan(a,b);
}
return internal::pmax<NaNPropagation>(a,b);
}
// TODO(rmlarsen): Handle all NaN propagation semantics reductions.
template<typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type predux(const Packet& a) const
{ return internal::predux_max(a); }
{
return internal::predux_max<NaNPropagation>(a);
}
};
template<typename LhsScalar,typename RhsScalar, int NaNPropagation>

View File

@ -801,10 +801,6 @@ void packetmath_notcomplex() {
Array<Scalar, Dynamic, 1>::Map(data1, PacketSize * 4).setRandom();
ref[0] = data1[0];
for (int i = 0; i < PacketSize; ++i) ref[0] = (std::min)(ref[0], data1[i]);
VERIFY(internal::isApprox(ref[0], internal::predux_min(internal::pload<Packet>(data1))) && "internal::predux_min");
VERIFY((!PacketTraits::Vectorizable) || PacketTraits::HasMin);
VERIFY((!PacketTraits::Vectorizable) || PacketTraits::HasMax);
@ -817,13 +813,16 @@ void packetmath_notcomplex() {
using ::fmin;
using ::fmax;
#endif
CHECK_CWISE2_IF(PacketTraits::HasMin, fmin, internal::pfmin);
CHECK_CWISE2_IF(PacketTraits::HasMax, fmax, internal::pfmax);
CHECK_CWISE2_IF(PacketTraits::HasMin, fmin, (internal::pmin<PropagateNumbers>));
CHECK_CWISE2_IF(PacketTraits::HasMax, fmax, internal::pmax<PropagateNumbers>);
CHECK_CWISE1(numext::abs, internal::pabs);
CHECK_CWISE2_IF(PacketTraits::HasAbsDiff, REF_ABS_DIFF, internal::pabsdiff);
ref[0] = data1[0];
for (int i = 0; i < PacketSize; ++i) ref[0] = (std::max)(ref[0], data1[i]);
for (int i = 0; i < PacketSize; ++i) ref[0] = internal::pmin(ref[0], data1[i]);
VERIFY(internal::isApprox(ref[0], internal::predux_min(internal::pload<Packet>(data1))) && "internal::predux_min");
ref[0] = data1[0];
for (int i = 0; i < PacketSize; ++i) ref[0] = internal::pmax(ref[0], data1[i]);
VERIFY(internal::isApprox(ref[0], internal::predux_max(internal::pload<Packet>(data1))) && "internal::predux_max");
for (int i = 0; i < PacketSize; ++i) ref[i] = data1[0] + Scalar(i);
@ -852,16 +851,47 @@ void packetmath_notcomplex() {
}
}
for (int i = 0; i < PacketSize; ++i) {
data1[i] = internal::random<bool>() ? std::numeric_limits<Scalar>::quiet_NaN() : Scalar(0);
data1[i + PacketSize] = internal::random<bool>() ? std::numeric_limits<Scalar>::quiet_NaN() : Scalar(0);
// Test NaN propagation.
if (!NumTraits<Scalar>::IsInteger) {
// Test reductions with no NaNs.
ref[0] = data1[0];
for (int i = 0; i < PacketSize; ++i) ref[0] = internal::pmin<PropagateNumbers>(ref[0], data1[i]);
VERIFY(internal::isApprox(ref[0], internal::predux_min<PropagateNumbers>(internal::pload<Packet>(data1))) && "internal::predux_min<PropagateNumbers>");
ref[0] = data1[0];
for (int i = 0; i < PacketSize; ++i) ref[0] = internal::pmin<PropagateNaN>(ref[0], data1[i]);
VERIFY(internal::isApprox(ref[0], internal::predux_min<PropagateNaN>(internal::pload<Packet>(data1))) && "internal::predux_min<PropagateNaN>");
ref[0] = data1[0];
for (int i = 0; i < PacketSize; ++i) ref[0] = internal::pmax<PropagateNumbers>(ref[0], data1[i]);
VERIFY(internal::isApprox(ref[0], internal::predux_max<PropagateNumbers>(internal::pload<Packet>(data1))) && "internal::predux_max<PropagateNumbers>");
ref[0] = data1[0];
for (int i = 0; i < PacketSize; ++i) ref[0] = internal::pmax<PropagateNaN>(ref[0], data1[i]);
VERIFY(internal::isApprox(ref[0], internal::predux_max<PropagateNaN>(internal::pload<Packet>(data1))) && "internal::predux_max<PropagateNumbers>");
// A single NaN.
const size_t index = std::numeric_limits<size_t>::quiet_NaN() % PacketSize;
data1[index] = std::numeric_limits<Scalar>::quiet_NaN();
VERIFY(PacketSize==1 || !(numext::isnan)(internal::predux_min<PropagateNumbers>(internal::pload<Packet>(data1))));
VERIFY((numext::isnan)(internal::predux_min<PropagateNaN>(internal::pload<Packet>(data1))));
VERIFY(PacketSize==1 || !(numext::isnan)(internal::predux_max<PropagateNumbers>(internal::pload<Packet>(data1))));
VERIFY((numext::isnan)(internal::predux_max<PropagateNaN>(internal::pload<Packet>(data1))));
// All NaNs.
for (int i = 0; i < 4 * PacketSize; ++i) data1[i] = std::numeric_limits<Scalar>::quiet_NaN();
VERIFY((numext::isnan)(internal::predux_min<PropagateNumbers>(internal::pload<Packet>(data1))));
VERIFY((numext::isnan)(internal::predux_min<PropagateNaN>(internal::pload<Packet>(data1))));
VERIFY((numext::isnan)(internal::predux_max<PropagateNumbers>(internal::pload<Packet>(data1))));
VERIFY((numext::isnan)(internal::predux_max<PropagateNaN>(internal::pload<Packet>(data1))));
// Test NaN propagation for coefficient-wise min and max.
for (int i = 0; i < PacketSize; ++i) {
data1[i] = internal::random<bool>() ? std::numeric_limits<Scalar>::quiet_NaN() : Scalar(0);
data1[i + PacketSize] = internal::random<bool>() ? std::numeric_limits<Scalar>::quiet_NaN() : Scalar(0);
}
// Note: NaN propagation is implementation defined for pmin/pmax, so we do not test it here.
CHECK_CWISE2_IF(PacketTraits::HasMin, fmin, (internal::pmin<PropagateNumbers>));
CHECK_CWISE2_IF(PacketTraits::HasMax, fmax, internal::pmax<PropagateNumbers>);
CHECK_CWISE2_IF(PacketTraits::HasMin, propagate_nan_min, (internal::pmin<PropagateNaN>));
CHECK_CWISE2_IF(PacketTraits::HasMax, propagate_nan_max, internal::pmax<PropagateNaN>);
}
// Test NaN propagation for pfmin and pfmax. It should be equivalent to std::fmin.
// Note: NaN propagation is implementation defined for pmin/pmax, so we do not test it here.
CHECK_CWISE2_IF(PacketTraits::HasMin, fmin, internal::pfmin);
CHECK_CWISE2_IF(PacketTraits::HasMax, fmax, internal::pfmax);
CHECK_CWISE2_IF(PacketTraits::HasMin, propagate_nan_min, internal::pfmin_nan);
CHECK_CWISE2_IF(PacketTraits::HasMax, propagate_nan_max, internal::pfmax_nan);
}
template <>

View File

@ -682,28 +682,30 @@ class TensorBase<Derived, ReadOnlyAccessors>
return TensorReductionOp<internal::ProdReducer<CoeffReturnType>, const DimensionList<Index, NumDimensions>, const Derived>(derived(), in_dims, internal::ProdReducer<CoeffReturnType>());
}
template <typename Dims> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const TensorReductionOp<internal::MaxReducer<CoeffReturnType>, const Dims, const Derived>
template <typename Dims,int NanPropagation=PropagateFast> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const TensorReductionOp<internal::MaxReducer<CoeffReturnType,NanPropagation>, const Dims, const Derived>
maximum(const Dims& dims) const {
return TensorReductionOp<internal::MaxReducer<CoeffReturnType>, const Dims, const Derived>(derived(), dims, internal::MaxReducer<CoeffReturnType>());
return TensorReductionOp<internal::MaxReducer<CoeffReturnType,NanPropagation>, const Dims, const Derived>(derived(), dims, internal::MaxReducer<CoeffReturnType,NanPropagation>());
}
const TensorReductionOp<internal::MaxReducer<CoeffReturnType>, const DimensionList<Index, NumDimensions>, const Derived>
template <int NanPropagation=PropagateFast>
const TensorReductionOp<internal::MaxReducer<CoeffReturnType,NanPropagation>, const DimensionList<Index, NumDimensions>, const Derived>
maximum() const {
DimensionList<Index, NumDimensions> in_dims;
return TensorReductionOp<internal::MaxReducer<CoeffReturnType>, const DimensionList<Index, NumDimensions>, const Derived>(derived(), in_dims, internal::MaxReducer<CoeffReturnType>());
return TensorReductionOp<internal::MaxReducer<CoeffReturnType,NanPropagation>, const DimensionList<Index, NumDimensions>, const Derived>(derived(), in_dims, internal::MaxReducer<CoeffReturnType,NanPropagation>());
}
template <typename Dims> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const TensorReductionOp<internal::MinReducer<CoeffReturnType>, const Dims, const Derived>
template <typename Dims,int NanPropagation=PropagateFast> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const TensorReductionOp<internal::MinReducer<CoeffReturnType,NanPropagation>, const Dims, const Derived>
minimum(const Dims& dims) const {
return TensorReductionOp<internal::MinReducer<CoeffReturnType>, const Dims, const Derived>(derived(), dims, internal::MinReducer<CoeffReturnType>());
return TensorReductionOp<internal::MinReducer<CoeffReturnType,NanPropagation>, const Dims, const Derived>(derived(), dims, internal::MinReducer<CoeffReturnType,NanPropagation>());
}
const TensorReductionOp<internal::MinReducer<CoeffReturnType>, const DimensionList<Index, NumDimensions>, const Derived>
template <int NanPropagation=PropagateFast>
const TensorReductionOp<internal::MinReducer<CoeffReturnType,NanPropagation>, const DimensionList<Index, NumDimensions>, const Derived>
minimum() const {
DimensionList<Index, NumDimensions> in_dims;
return TensorReductionOp<internal::MinReducer<CoeffReturnType>, const DimensionList<Index, NumDimensions>, const Derived>(derived(), in_dims, internal::MinReducer<CoeffReturnType>());
return TensorReductionOp<internal::MinReducer<CoeffReturnType,NanPropagation>, const DimensionList<Index, NumDimensions>, const Derived>(derived(), in_dims, internal::MinReducer<CoeffReturnType,NanPropagation>());
}
template <typename Dims> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE

View File

@ -192,17 +192,19 @@ struct MinMaxBottomValue<T, false, false> {
};
template <typename T> struct MaxReducer
template <typename T, int NaNPropagation=PropagateFast> struct MaxReducer
{
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(const T t, T* accum) const {
if (t > *accum) { *accum = t; }
scalar_max_op<T, T, NaNPropagation> op;
*accum = op(t, *accum);
}
template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reducePacket(const Packet& p, Packet* accum) const {
(*accum) = pmax<Packet>(*accum, p);
scalar_max_op<T, T, NaNPropagation> op;
(*accum) = op.packetOp(*accum, p);
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T initialize() const {
return MinMaxBottomValue<T, true, Eigen::NumTraits<T>::IsInteger>::bottom_value();
return MinMaxBottomValue<T, /*IsMax=*/true, Eigen::NumTraits<T>::IsInteger>::bottom_value();
}
template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet initializePacket() const {
@ -217,32 +219,34 @@ template <typename T> struct MaxReducer
}
template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T finalizeBoth(const T saccum, const Packet& vaccum) const {
return numext::maxi(saccum, predux_max(vaccum));
scalar_max_op<T, T, NaNPropagation> op;
return op(saccum, op.predux(vaccum));
}
};
template <typename T, typename Device>
struct reducer_traits<MaxReducer<T>, Device> {
template <typename T, typename Device, int NaNPropagation>
struct reducer_traits<MaxReducer<T, NaNPropagation>, Device> {
enum {
Cost = NumTraits<T>::AddCost,
PacketAccess = PacketType<T, Device>::HasMax,
IsStateful = false,
IsExactlyAssociative = true
IsExactlyAssociative = (NaNPropagation!=PropagateFast)
};
};
template <typename T> struct MinReducer
template <typename T, int NaNPropagation=PropagateFast> struct MinReducer
{
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(const T t, T* accum) const {
if (t < *accum) { *accum = t; }
scalar_min_op<T, T, NaNPropagation> op;
*accum = op(t, *accum);
}
template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reducePacket(const Packet& p, Packet* accum) const {
(*accum) = pmin<Packet>(*accum, p);
scalar_min_op<T, T, NaNPropagation> op;
(*accum) = op.packetOp(*accum, p);
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T initialize() const {
return MinMaxBottomValue<T, false, Eigen::NumTraits<T>::IsInteger>::bottom_value();
return MinMaxBottomValue<T, /*IsMax=*/false, Eigen::NumTraits<T>::IsInteger>::bottom_value();
}
template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet initializePacket() const {
@ -257,21 +261,21 @@ template <typename T> struct MinReducer
}
template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T finalizeBoth(const T saccum, const Packet& vaccum) const {
return numext::mini(saccum, predux_min(vaccum));
scalar_min_op<T, T, NaNPropagation> op;
return op(saccum, op.predux(vaccum));
}
};
template <typename T, typename Device>
struct reducer_traits<MinReducer<T>, Device> {
template <typename T, typename Device, int NaNPropagation>
struct reducer_traits<MinReducer<T, NaNPropagation>, Device> {
enum {
Cost = NumTraits<T>::AddCost,
PacketAccess = PacketType<T, Device>::HasMin,
IsStateful = false,
IsExactlyAssociative = true
IsExactlyAssociative = (NaNPropagation!=PropagateFast)
};
};
template <typename T> struct ProdReducer
{
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(const T t, T* accum) const {
@ -282,7 +286,6 @@ template <typename T> struct ProdReducer
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reducePacket(const Packet& p, Packet* accum) const {
(*accum) = pmul<Packet>(*accum, p);
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T initialize() const {
internal::scalar_cast_op<int, T> conv;
return conv(1);

View File

@ -302,12 +302,17 @@ static void test_select()
template <typename Scalar>
void test_minmax_nan_propagation_templ() {
for (int size = 1; size < 17; ++size) {
const Scalar kNan = std::numeric_limits<Scalar>::quiet_NaN();
std::cout << "size = " << size << std::endl;
const Scalar kNaN = std::numeric_limits<Scalar>::quiet_NaN();
const Scalar kInf = std::numeric_limits<Scalar>::infinity();
const Scalar kZero(0);
Tensor<Scalar, 1> vec_nan(size);
Tensor<Scalar, 1> vec_all_nan(size);
Tensor<Scalar, 1> vec_one_nan(size);
Tensor<Scalar, 1> vec_zero(size);
vec_nan.setConstant(kNan);
vec_all_nan.setConstant(kNaN);
vec_zero.setZero();
vec_one_nan.setZero();
vec_one_nan(size/2) = kNaN;
auto verify_all_nan = [&](const Tensor<Scalar, 1>& v) {
for (int i = 0; i < size; ++i) {
@ -326,12 +331,12 @@ void test_minmax_nan_propagation_templ() {
// max(nan, 0) = nan
// max(0, nan) = nan
// max(0, 0) = 0
verify_all_nan(vec_nan.template cwiseMax<PropagateNaN>(kNan));
verify_all_nan(vec_nan.template cwiseMax<PropagateNaN>(vec_nan));
verify_all_nan(vec_nan.template cwiseMax<PropagateNaN>(kZero));
verify_all_nan(vec_nan.template cwiseMax<PropagateNaN>(vec_zero));
verify_all_nan(vec_zero.template cwiseMax<PropagateNaN>(kNan));
verify_all_nan(vec_zero.template cwiseMax<PropagateNaN>(vec_nan));
verify_all_nan(vec_all_nan.template cwiseMax<PropagateNaN>(kNaN));
verify_all_nan(vec_all_nan.template cwiseMax<PropagateNaN>(vec_all_nan));
verify_all_nan(vec_all_nan.template cwiseMax<PropagateNaN>(kZero));
verify_all_nan(vec_all_nan.template cwiseMax<PropagateNaN>(vec_zero));
verify_all_nan(vec_zero.template cwiseMax<PropagateNaN>(kNaN));
verify_all_nan(vec_zero.template cwiseMax<PropagateNaN>(vec_all_nan));
verify_all_zero(vec_zero.template cwiseMax<PropagateNaN>(kZero));
verify_all_zero(vec_zero.template cwiseMax<PropagateNaN>(vec_zero));
@ -340,12 +345,12 @@ void test_minmax_nan_propagation_templ() {
// max(nan, 0) = 0
// max(0, nan) = 0
// max(0, 0) = 0
verify_all_nan(vec_nan.template cwiseMax<PropagateNumbers>(kNan));
verify_all_nan(vec_nan.template cwiseMax<PropagateNumbers>(vec_nan));
verify_all_zero(vec_nan.template cwiseMax<PropagateNumbers>(kZero));
verify_all_zero(vec_nan.template cwiseMax<PropagateNumbers>(vec_zero));
verify_all_zero(vec_zero.template cwiseMax<PropagateNumbers>(kNan));
verify_all_zero(vec_zero.template cwiseMax<PropagateNumbers>(vec_nan));
verify_all_nan(vec_all_nan.template cwiseMax<PropagateNumbers>(kNaN));
verify_all_nan(vec_all_nan.template cwiseMax<PropagateNumbers>(vec_all_nan));
verify_all_zero(vec_all_nan.template cwiseMax<PropagateNumbers>(kZero));
verify_all_zero(vec_all_nan.template cwiseMax<PropagateNumbers>(vec_zero));
verify_all_zero(vec_zero.template cwiseMax<PropagateNumbers>(kNaN));
verify_all_zero(vec_zero.template cwiseMax<PropagateNumbers>(vec_all_nan));
verify_all_zero(vec_zero.template cwiseMax<PropagateNumbers>(kZero));
verify_all_zero(vec_zero.template cwiseMax<PropagateNumbers>(vec_zero));
@ -354,12 +359,12 @@ void test_minmax_nan_propagation_templ() {
// min(nan, 0) = nan
// min(0, nan) = nan
// min(0, 0) = 0
verify_all_nan(vec_nan.template cwiseMin<PropagateNaN>(kNan));
verify_all_nan(vec_nan.template cwiseMin<PropagateNaN>(vec_nan));
verify_all_nan(vec_nan.template cwiseMin<PropagateNaN>(kZero));
verify_all_nan(vec_nan.template cwiseMin<PropagateNaN>(vec_zero));
verify_all_nan(vec_zero.template cwiseMin<PropagateNaN>(kNan));
verify_all_nan(vec_zero.template cwiseMin<PropagateNaN>(vec_nan));
verify_all_nan(vec_all_nan.template cwiseMin<PropagateNaN>(kNaN));
verify_all_nan(vec_all_nan.template cwiseMin<PropagateNaN>(vec_all_nan));
verify_all_nan(vec_all_nan.template cwiseMin<PropagateNaN>(kZero));
verify_all_nan(vec_all_nan.template cwiseMin<PropagateNaN>(vec_zero));
verify_all_nan(vec_zero.template cwiseMin<PropagateNaN>(kNaN));
verify_all_nan(vec_zero.template cwiseMin<PropagateNaN>(vec_all_nan));
verify_all_zero(vec_zero.template cwiseMin<PropagateNaN>(kZero));
verify_all_zero(vec_zero.template cwiseMin<PropagateNaN>(vec_zero));
@ -368,14 +373,49 @@ void test_minmax_nan_propagation_templ() {
// min(nan, 0) = 0
// min(0, nan) = 0
// min(0, 0) = 0
verify_all_nan(vec_nan.template cwiseMin<PropagateNumbers>(kNan));
verify_all_nan(vec_nan.template cwiseMin<PropagateNumbers>(vec_nan));
verify_all_zero(vec_nan.template cwiseMin<PropagateNumbers>(kZero));
verify_all_zero(vec_nan.template cwiseMin<PropagateNumbers>(vec_zero));
verify_all_zero(vec_zero.template cwiseMin<PropagateNumbers>(kNan));
verify_all_zero(vec_zero.template cwiseMin<PropagateNumbers>(vec_nan));
verify_all_nan(vec_all_nan.template cwiseMin<PropagateNumbers>(kNaN));
verify_all_nan(vec_all_nan.template cwiseMin<PropagateNumbers>(vec_all_nan));
verify_all_zero(vec_all_nan.template cwiseMin<PropagateNumbers>(kZero));
verify_all_zero(vec_all_nan.template cwiseMin<PropagateNumbers>(vec_zero));
verify_all_zero(vec_zero.template cwiseMin<PropagateNumbers>(kNaN));
verify_all_zero(vec_zero.template cwiseMin<PropagateNumbers>(vec_all_nan));
verify_all_zero(vec_zero.template cwiseMin<PropagateNumbers>(kZero));
verify_all_zero(vec_zero.template cwiseMin<PropagateNumbers>(vec_zero));
// Test min and max reduction
Tensor<Scalar, 0> val;
val = vec_zero.minimum();
VERIFY_IS_EQUAL(val(), kZero);
val = vec_zero.template minimum<PropagateNaN>();
VERIFY_IS_EQUAL(val(), kZero);
val = vec_zero.template minimum<PropagateNumbers>();
VERIFY_IS_EQUAL(val(), kZero);
val = vec_zero.maximum();
VERIFY_IS_EQUAL(val(), kZero);
val = vec_zero.template maximum<PropagateNaN>();
VERIFY_IS_EQUAL(val(), kZero);
val = vec_zero.template maximum<PropagateNumbers>();
VERIFY_IS_EQUAL(val(), kZero);
// Test NaN propagation for tensor of all NaNs.
val = vec_all_nan.template minimum<PropagateNaN>();
VERIFY((numext::isnan)(val()));
val = vec_all_nan.template minimum<PropagateNumbers>();
VERIFY_IS_EQUAL(val(), kInf);
val = vec_all_nan.template maximum<PropagateNaN>();
VERIFY((numext::isnan)(val()));
val = vec_all_nan.template maximum<PropagateNumbers>();
VERIFY_IS_EQUAL(val(), -kInf);
// Test NaN propagation for tensor with a single NaN.
val = vec_one_nan.template minimum<PropagateNaN>();
VERIFY((numext::isnan)(val()));
val = vec_one_nan.template minimum<PropagateNumbers>();
VERIFY_IS_EQUAL(val(), (size == 1 ? kInf : kZero));
val = vec_one_nan.template maximum<PropagateNaN>();
VERIFY((numext::isnan)(val()));
val = vec_one_nan.template maximum<PropagateNumbers>();
VERIFY_IS_EQUAL(val(), (size == 1 ? -kInf : kZero));
}
}