Made it possible to leverage several binary functor in a CUDA kernel

Explicitely specified the return type of the various scalar_cmp_op functors.
This commit is contained in:
Benoit Steiner 2015-12-02 17:21:33 -08:00
parent c5b86893e7
commit d2d4c45d55

View File

@ -29,7 +29,7 @@ template<typename Scalar> struct scalar_sum_op {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(const Packet& a, const Packet& b) const
{ return internal::padd(a,b); }
template<typename Packet>
EIGEN_STRONG_INLINE const Scalar predux(const Packet& a) const
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar predux(const Packet& a) const
{ return internal::predux(a); }
};
template<typename Scalar>
@ -68,7 +68,7 @@ template<typename LhsScalar,typename RhsScalar> struct scalar_product_op {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(const Packet& a, const Packet& b) const
{ return internal::pmul(a,b); }
template<typename Packet>
EIGEN_STRONG_INLINE const result_type predux(const Packet& a) const
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const result_type predux(const Packet& a) const
{ return internal::predux_mul(a); }
};
template<typename LhsScalar,typename RhsScalar>
@ -175,30 +175,37 @@ struct result_of<scalar_cmp_op<Scalar, Cmp>(Scalar,Scalar)> {
template<typename Scalar> struct scalar_cmp_op<Scalar, cmp_EQ> {
typedef bool result_type;
EIGEN_EMPTY_STRUCT_CTOR(scalar_cmp_op)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const Scalar& a, const Scalar& b) const {return a==b;}
};
template<typename Scalar> struct scalar_cmp_op<Scalar, cmp_LT> {
typedef bool result_type;
EIGEN_EMPTY_STRUCT_CTOR(scalar_cmp_op)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const Scalar& a, const Scalar& b) const {return a<b;}
};
template<typename Scalar> struct scalar_cmp_op<Scalar, cmp_LE> {
typedef bool result_type;
EIGEN_EMPTY_STRUCT_CTOR(scalar_cmp_op)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const Scalar& a, const Scalar& b) const {return a<=b;}
};
template<typename Scalar> struct scalar_cmp_op<Scalar, cmp_GT> {
typedef bool result_type;
EIGEN_EMPTY_STRUCT_CTOR(scalar_cmp_op)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const Scalar& a, const Scalar& b) const {return a>b;}
};
template<typename Scalar> struct scalar_cmp_op<Scalar, cmp_GE> {
typedef bool result_type;
EIGEN_EMPTY_STRUCT_CTOR(scalar_cmp_op)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const Scalar& a, const Scalar& b) const {return a>=b;}
};
template<typename Scalar> struct scalar_cmp_op<Scalar, cmp_UNORD> {
typedef bool result_type;
EIGEN_EMPTY_STRUCT_CTOR(scalar_cmp_op)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const Scalar& a, const Scalar& b) const {return !(a<=b || b<=a);}
};
template<typename Scalar> struct scalar_cmp_op<Scalar, cmp_NEQ> {
typedef bool result_type;
EIGEN_EMPTY_STRUCT_CTOR(scalar_cmp_op)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const Scalar& a, const Scalar& b) const {return a!=b;}
};
@ -448,9 +455,9 @@ struct functor_traits<scalar_add_op<Scalar> >
*/
template<typename Scalar>
struct scalar_sub_op {
inline scalar_sub_op(const scalar_sub_op& other) : m_other(other.m_other) { }
inline scalar_sub_op(const Scalar& other) : m_other(other) { }
inline Scalar operator() (const Scalar& a) const { return a - m_other; }
EIGEN_DEVICE_FUNC inline scalar_sub_op(const scalar_sub_op& other) : m_other(other.m_other) { }
EIGEN_DEVICE_FUNC inline scalar_sub_op(const Scalar& other) : m_other(other) { }
EIGEN_DEVICE_FUNC inline Scalar operator() (const Scalar& a) const { return a - m_other; }
template <typename Packet>
EIGEN_DEVICE_FUNC inline const Packet packetOp(const Packet& a) const
{ return internal::psub(a, pset1<Packet>(m_other)); }
@ -466,9 +473,9 @@ struct functor_traits<scalar_sub_op<Scalar> >
*/
template<typename Scalar>
struct scalar_rsub_op {
inline scalar_rsub_op(const scalar_rsub_op& other) : m_other(other.m_other) { }
inline scalar_rsub_op(const Scalar& other) : m_other(other) { }
inline Scalar operator() (const Scalar& a) const { return m_other - a; }
EIGEN_DEVICE_FUNC inline scalar_rsub_op(const scalar_rsub_op& other) : m_other(other.m_other) { }
EIGEN_DEVICE_FUNC inline scalar_rsub_op(const Scalar& other) : m_other(other) { }
EIGEN_DEVICE_FUNC inline Scalar operator() (const Scalar& a) const { return m_other - a; }
template <typename Packet>
EIGEN_DEVICE_FUNC inline const Packet packetOp(const Packet& a) const
{ return internal::psub(pset1<Packet>(m_other), a); }
@ -485,8 +492,8 @@ struct functor_traits<scalar_rsub_op<Scalar> >
template<typename Scalar>
struct scalar_pow_op {
// FIXME default copy constructors seems bugged with std::complex<>
inline scalar_pow_op(const scalar_pow_op& other) : m_exponent(other.m_exponent) { }
inline scalar_pow_op(const Scalar& exponent) : m_exponent(exponent) {}
EIGEN_DEVICE_FUNC inline scalar_pow_op(const scalar_pow_op& other) : m_exponent(other.m_exponent) { }
EIGEN_DEVICE_FUNC inline scalar_pow_op(const Scalar& exponent) : m_exponent(exponent) {}
EIGEN_DEVICE_FUNC
inline Scalar operator() (const Scalar& a) const { return numext::pow(a, m_exponent); }
const Scalar m_exponent;
@ -501,7 +508,7 @@ struct functor_traits<scalar_pow_op<Scalar> >
*/
template<typename Scalar>
struct scalar_inverse_mult_op {
scalar_inverse_mult_op(const Scalar& other) : m_other(other) {}
EIGEN_DEVICE_FUNC scalar_inverse_mult_op(const Scalar& other) : m_other(other) {}
EIGEN_DEVICE_FUNC inline Scalar operator() (const Scalar& a) const { return m_other / a; }
template<typename Packet>
EIGEN_DEVICE_FUNC inline const Packet packetOp(const Packet& a) const