Add support for vectorizing logical comparisons.

This commit is contained in:
derekjchow 2021-07-19 16:51:17 -07:00 committed by Derek Chow
parent a77638387d
commit 66ca41bd47
2 changed files with 28 additions and 4 deletions

View File

@ -219,6 +219,7 @@ struct packet_traits<int8_t> : default_packet_traits
size = 16,
HasHalfPacket = 1,
HasCmp = 1,
HasAdd = 1,
HasSub = 1,
HasShift = 1,
@ -248,6 +249,7 @@ struct packet_traits<uint8_t> : default_packet_traits
size = 16,
HasHalfPacket = 1,
HasCmp = 1,
HasAdd = 1,
HasSub = 1,
HasShift = 1,

View File

@ -110,13 +110,13 @@ struct scalar_conj_product_op : binary_op_base<LhsScalar,RhsScalar>
enum {
Conj = NumTraits<LhsScalar>::IsComplex
};
typedef typename ScalarBinaryOpTraits<LhsScalar,RhsScalar,scalar_conj_product_op>::ReturnType result_type;
EIGEN_EMPTY_STRUCT_CTOR(scalar_conj_product_op)
/** \internal Scalar path: multiplies \a a and \a b through conj_helper::pmul.
  * The helper is instantiated with Conj and false as its last two template arguments. */
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator() (const LhsScalar& a, const RhsScalar& b) const
{
  typedef conj_helper<LhsScalar,RhsScalar,Conj,false> ScalarHelper;
  return ScalarHelper().pmul(a,b);
}
/** \internal Packet path: same conj_helper::pmul call as the scalar operator,
  * with the helper instantiated on Packet for both operand types. */
template<typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
Packet packetOp(const Packet& a, const Packet& b) const
{
  typedef conj_helper<Packet,Packet,Conj,false> PacketHelper;
  return PacketHelper().pmul(a,b);
}
@ -205,7 +205,8 @@ template<typename LhsScalar, typename RhsScalar, ComparisonName cmp>
// Traits for the scalar_cmp_op comparison functors: an evaluation cost estimate
// and whether a packet (vectorized) code path may be used.
struct functor_traits<scalar_cmp_op<LhsScalar,RhsScalar, cmp> > {
enum {
// Cost estimate: the average of the AddCost of the two operand scalar types.
Cost = (NumTraits<LhsScalar>::AddCost+NumTraits<RhsScalar>::AddCost)/2,
// NOTE(review): two PacketAccess definitions follow with no separating comma.
// This looks like the before/after lines of a diff rendered without +/- markers;
// confirm which one applies before treating this text as compilable code.
PacketAccess = false
// Packet access is enabled only when both operand scalar types are identical
// and packet_traits reports HasCmp for that type.
PacketAccess = is_same<LhsScalar, RhsScalar>::value &&
packet_traits<LhsScalar>::HasCmp
};
};
@ -221,6 +222,9 @@ struct scalar_cmp_op<LhsScalar,RhsScalar, cmp_EQ> : binary_op_base<LhsScalar,Rhs
typedef bool result_type;
EIGEN_EMPTY_STRUCT_CTOR(scalar_cmp_op)
// Scalar path of the equality functor: the result of a == b.
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const LhsScalar& a, const RhsScalar& b) const
{
  return a == b;
}
// Packet path of the equality functor: forwards both packets to
// internal::pcmp_eq and returns its result unchanged.
// NOTE(review): the enclosing functor declares result_type as bool, while this
// returns a Packet of the same type as its inputs. Whether callers expect a
// per-lane mask or 0/1 values is not visible from here; confirm.
template<typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
Packet packetOp(const Packet& a, const Packet& b) const
{
  return internal::pcmp_eq(a, b);
}
};
template<typename LhsScalar, typename RhsScalar>
struct scalar_cmp_op<LhsScalar,RhsScalar, cmp_LT> : binary_op_base<LhsScalar,RhsScalar>
@ -228,6 +232,9 @@ struct scalar_cmp_op<LhsScalar,RhsScalar, cmp_LT> : binary_op_base<LhsScalar,Rhs
typedef bool result_type;
EIGEN_EMPTY_STRUCT_CTOR(scalar_cmp_op)
// Scalar path of the less-than functor: the result of a < b.
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const LhsScalar& a, const RhsScalar& b) const
{
  return a < b;
}
// Packet path of the less-than functor: forwards (a, b) to internal::pcmp_lt.
template<typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
Packet packetOp(const Packet& a, const Packet& b) const
{
  return internal::pcmp_lt(a, b);
}
};
template<typename LhsScalar, typename RhsScalar>
struct scalar_cmp_op<LhsScalar,RhsScalar, cmp_LE> : binary_op_base<LhsScalar,RhsScalar>
@ -235,6 +242,9 @@ struct scalar_cmp_op<LhsScalar,RhsScalar, cmp_LE> : binary_op_base<LhsScalar,Rhs
typedef bool result_type;
EIGEN_EMPTY_STRUCT_CTOR(scalar_cmp_op)
// Scalar path of the less-or-equal functor: the result of a <= b.
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const LhsScalar& a, const RhsScalar& b) const
{
  return a <= b;
}
// Packet path of the less-or-equal functor: forwards (a, b) to internal::pcmp_le.
template<typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
Packet packetOp(const Packet& a, const Packet& b) const
{
  return internal::pcmp_le(a, b);
}
};
template<typename LhsScalar, typename RhsScalar>
struct scalar_cmp_op<LhsScalar,RhsScalar, cmp_GT> : binary_op_base<LhsScalar,RhsScalar>
@ -242,6 +252,9 @@ struct scalar_cmp_op<LhsScalar,RhsScalar, cmp_GT> : binary_op_base<LhsScalar,Rhs
typedef bool result_type;
EIGEN_EMPTY_STRUCT_CTOR(scalar_cmp_op)
// Scalar path of the greater-than functor: the result of a > b.
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const LhsScalar& a, const RhsScalar& b) const
{
  return a > b;
}
// Packet path of the greater-than functor. There is no dedicated call here:
// a > b is expressed as b < a by passing the operands to internal::pcmp_lt
// in swapped order.
template<typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
Packet packetOp(const Packet& a, const Packet& b) const
{
  return internal::pcmp_lt(b, a);
}
};
template<typename LhsScalar, typename RhsScalar>
struct scalar_cmp_op<LhsScalar,RhsScalar, cmp_GE> : binary_op_base<LhsScalar,RhsScalar>
@ -249,6 +262,9 @@ struct scalar_cmp_op<LhsScalar,RhsScalar, cmp_GE> : binary_op_base<LhsScalar,Rhs
typedef bool result_type;
EIGEN_EMPTY_STRUCT_CTOR(scalar_cmp_op)
// Scalar path of the greater-or-equal functor: the result of a >= b.
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const LhsScalar& a, const RhsScalar& b) const
{
  return a >= b;
}
// Packet path of the greater-or-equal functor: a >= b is expressed as b <= a
// by passing the operands to internal::pcmp_le in swapped order.
template<typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
Packet packetOp(const Packet& a, const Packet& b) const
{
  return internal::pcmp_le(b, a);
}
};
template<typename LhsScalar, typename RhsScalar>
struct scalar_cmp_op<LhsScalar,RhsScalar, cmp_UNORD> : binary_op_base<LhsScalar,RhsScalar>
@ -256,6 +272,9 @@ struct scalar_cmp_op<LhsScalar,RhsScalar, cmp_UNORD> : binary_op_base<LhsScalar,
typedef bool result_type;
EIGEN_EMPTY_STRUCT_CTOR(scalar_cmp_op)
// Scalar path of the "unordered" functor: true when neither a <= b nor b <= a holds.
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const LhsScalar& a, const RhsScalar& b) const
{
  const bool ordered = (a<=b || b<=a);
  return !ordered;
}
// Packet path of the "unordered" functor. Mirrors the scalar !(a<=b || b<=a):
// both <= comparisons are OR-ed together and the result is compared for
// equality against pzero(a).
// Assumes pcmp_le yields an all-zero lane wherever the comparison is false,
// so that the final pcmp_eq acts as a per-lane negation; TODO confirm.
template<typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
Packet packetOp(const Packet& a, const Packet& b) const
{
  const Packet a_le_b  = internal::pcmp_le(a, b);
  const Packet b_le_a  = internal::pcmp_le(b, a);
  const Packet ordered = internal::por(a_le_b, b_le_a);
  return internal::pcmp_eq(ordered, internal::pzero(a));
}
};
template<typename LhsScalar, typename RhsScalar>
struct scalar_cmp_op<LhsScalar,RhsScalar, cmp_NEQ> : binary_op_base<LhsScalar,RhsScalar>
@ -263,6 +282,9 @@ struct scalar_cmp_op<LhsScalar,RhsScalar, cmp_NEQ> : binary_op_base<LhsScalar,Rh
typedef bool result_type;
EIGEN_EMPTY_STRUCT_CTOR(scalar_cmp_op)
// Scalar path of the not-equal functor: the result of a != b.
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const LhsScalar& a, const RhsScalar& b) const
{
  return a != b;
}
// Packet path of the not-equal functor: computes the equality comparison
// first, then compares that result for equality against pzero(a).
// Assumes pcmp_eq yields an all-zero lane wherever the operands differ, so
// the second pcmp_eq acts as a per-lane negation; TODO confirm.
template<typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
Packet packetOp(const Packet& a, const Packet& b) const
{
  const Packet equal_lanes = internal::pcmp_eq(a, b);
  return internal::pcmp_eq(equal_lanes, internal::pzero(a));
}
};
/** \internal