Turned the Index type used by the nullary wrapper into a template parameter.

This commit is contained in:
Benoit Steiner 2016-09-02 14:10:29 -07:00
parent 6c05c3dd49
commit 5a6be66cef

View File

@ -343,25 +343,29 @@ template<typename Scalar,typename NullaryOp,
bool has_binary = has_binary_operator<NullaryOp>::value>
struct nullary_wrapper
{
template <typename Index>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar operator()(const NullaryOp& op, Index i, Index j) const { return op(i,j); }
template <typename Index>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar operator()(const NullaryOp& op, Index i) const { return op(i); }
template <typename T> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T packetOp(const NullaryOp& op, Index i, Index j) const { return op.template packetOp<T>(i,j); }
template <typename T> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T packetOp(const NullaryOp& op, Index i) const { return op.template packetOp<T>(i); }
template <typename T, typename Index> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T packetOp(const NullaryOp& op, Index i, Index j) const { return op.template packetOp<T>(i,j); }
template <typename T, typename Index> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T packetOp(const NullaryOp& op, Index i) const { return op.template packetOp<T>(i); }
};
template<typename Scalar,typename NullaryOp>
struct nullary_wrapper<Scalar,NullaryOp,true,false,false>
{
template <typename Index>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar operator()(const NullaryOp& op, Index=0, Index=0) const { return op(); }
template <typename T> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T packetOp(const NullaryOp& op, Index=0, Index=0) const { return op.template packetOp<T>(); }
template <typename T, typename Index> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T packetOp(const NullaryOp& op, Index=0, Index=0) const { return op.template packetOp<T>(); }
};
template<typename Scalar,typename NullaryOp>
struct nullary_wrapper<Scalar,NullaryOp,false,false,true>
{
template <typename Index>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar operator()(const NullaryOp& op, Index i, Index j=0) const { return op(i,j); }
template <typename T> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T packetOp(const NullaryOp& op, Index i, Index j=0) const { return op.template packetOp<T>(i,j); }
template <typename T, typename Index> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T packetOp(const NullaryOp& op, Index i, Index j=0) const { return op.template packetOp<T>(i,j); }
};
// We need the following specialization for vector-only functors assigned to a runtime vector,
@ -374,11 +378,12 @@ struct nullary_wrapper<Scalar,NullaryOp,false,true,false>
typedef nullary_wrapper<Scalar,NullaryOp,false,true,true> base;
using base::operator();
using base::packetOp;
template <typename Index>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar operator()(const NullaryOp& op, Index i, Index j) const {
eigen_assert(i==0 || j==0);
return op(i+j);
}
template <typename T> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T packetOp(const NullaryOp& op, Index i, Index j) const {
template <typename T, typename Index> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T packetOp(const NullaryOp& op, Index i, Index j) const {
eigen_assert(i==0 || j==0);
return op.template packetOp<T>(i+j);
}
@ -413,30 +418,32 @@ struct evaluator<CwiseNullaryOp<NullaryOp,PlainObjectType> >
typedef typename XprType::CoeffReturnType CoeffReturnType;
template <typename Index>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
CoeffReturnType coeff(Index row, Index col) const
{
return m_wrapper(m_functor, row, col);
}
template <typename Index>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
CoeffReturnType coeff(Index index) const
{
return m_wrapper(m_functor,index);
}
template<int LoadMode, typename PacketType>
template<int LoadMode, typename PacketType, typename Index>
EIGEN_STRONG_INLINE
PacketType packet(Index row, Index col) const
{
return m_wrapper.template packetOp<PacketType>(m_functor,row, col);
return m_wrapper.template packetOp<PacketType>(m_functor, row, col);
}
template<int LoadMode, typename PacketType>
template<int LoadMode, typename PacketType, typename Index>
EIGEN_STRONG_INLINE
PacketType packet(Index index) const
{
return m_wrapper.template packetOp<PacketType>(m_functor,index);
return m_wrapper.template packetOp<PacketType>(m_functor, index);
}
protected: