mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-03-13 18:37:27 +08:00
bug #1004: one more rewrite of LinSpaced for floating point numbers to guarantee both interpolation and monotonicity.
This version simply does low+i*step plus a branch to return high if i==size-1. Vectorization is accomplished with a branch and the help of pinsertlast. Some quick benchmark revealed that the overhead is really marginal, even when filling small vectors.
This commit is contained in:
parent
13fc18d3a2
commit
58146be99b
@ -43,15 +43,12 @@ template <typename Scalar, typename Packet>
|
||||
struct linspaced_op_impl<Scalar,Packet,/*IsInteger*/false>
|
||||
{
|
||||
linspaced_op_impl(const Scalar& low, const Scalar& high, Index num_steps) :
|
||||
m_low(low), m_step(num_steps==1 ? Scalar() : (high-low)/Scalar(num_steps-1)), m_interPacket(plset<Packet>(0))
|
||||
{
|
||||
// Compute the correction to be applied to ensure 'high' is returned exactly for i==num_steps-1
|
||||
m_corr = (high - (m_low+Scalar(num_steps-1)*m_step))/Scalar(num_steps<=1 ? 1 : num_steps-1);
|
||||
}
|
||||
m_low(low), m_high(high), m_size1(num_steps==1 ? 1 : num_steps-1), m_step(num_steps==1 ? Scalar() : (high-low)/Scalar(num_steps-1)), m_interPacket(plset<Packet>(0))
|
||||
{}
|
||||
|
||||
template<typename IndexType>
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator() (IndexType i) const {
|
||||
return m_low + i*m_step + i*m_corr;
|
||||
return (i==m_size1)? m_high : (m_low + i*m_step);
|
||||
}
|
||||
|
||||
template<typename IndexType>
|
||||
@ -60,11 +57,15 @@ struct linspaced_op_impl<Scalar,Packet,/*IsInteger*/false>
|
||||
// Principle:
|
||||
// [low, ..., low] + ( [step, ..., step] * ( [i, ..., i] + [0, ..., size] ) )
|
||||
Packet pi = padd(pset1<Packet>(Scalar(i)),m_interPacket);
|
||||
return padd(padd(pset1<Packet>(m_low), pmul(pset1<Packet>(m_step), pi)),
|
||||
pmul(pset1<Packet>(m_corr), pi)); }
|
||||
Packet res = padd(pset1<Packet>(m_low), pmul(pset1<Packet>(m_step), pi));
|
||||
if(i==m_size1-unpacket_traits<Packet>::size+1)
|
||||
res = pinsertlast(res, m_high);
|
||||
return res;
|
||||
}
|
||||
|
||||
const Scalar m_low;
|
||||
Scalar m_corr;
|
||||
const Scalar m_high;
|
||||
const Index m_size1;
|
||||
const Scalar m_step;
|
||||
const Packet m_interPacket;
|
||||
};
|
||||
@ -104,8 +105,8 @@ template <typename Scalar, typename PacketType> struct functor_traits< linspaced
|
||||
enum
|
||||
{
|
||||
Cost = 1,
|
||||
PacketAccess = (!NumTraits<Scalar>::IsInteger) && packet_traits<Scalar>::HasSetLinear
|
||||
&& ((!NumTraits<Scalar>::IsInteger) || packet_traits<Scalar>::HasDiv),
|
||||
PacketAccess = (!NumTraits<Scalar>::IsInteger) && packet_traits<Scalar>::HasSetLinear && packet_traits<Scalar>::HasBlend,
|
||||
/*&& ((!NumTraits<Scalar>::IsInteger) || packet_traits<Scalar>::HasDiv),*/ // <- vectorization for integer is currently disabled
|
||||
IsRepeatable = true
|
||||
};
|
||||
};
|
||||
@ -129,9 +130,34 @@ template <typename Scalar, typename PacketType> struct linspaced_op
|
||||
// Linear access is automatically determined from the operator() prototypes available for the given functor.
|
||||
// If it exposes an operator()(i,j), then we assume the i and j coefficients are required independently
|
||||
// and linear access is not possible. In all other cases, linear access is enabled.
|
||||
// Users should not have to deal with this struture.
|
||||
// Users should not have to deal with this structure.
|
||||
template<typename Functor> struct functor_has_linear_access { enum { ret = !has_binary_operator<Functor>::value }; };
|
||||
|
||||
// For unreliable compilers, let's specialize the has_*ary_operator
|
||||
// helpers so that at least built-in nullary functors work fine.
|
||||
#if !( (EIGEN_COMP_MSVC>1600) || (EIGEN_GNUC_AT_LEAST(4,8)) || (EIGEN_COMP_ICC>=1600))
|
||||
template<typename Scalar,typename IndexType>
|
||||
struct has_nullary_operator<scalar_constant_op<Scalar>,IndexType> { enum { value = 1}; };
|
||||
template<typename Scalar,typename IndexType>
|
||||
struct has_unary_operator<scalar_constant_op<Scalar>,IndexType> { enum { value = 0}; };
|
||||
template<typename Scalar,typename IndexType>
|
||||
struct has_binary_operator<scalar_constant_op<Scalar>,IndexType> { enum { value = 0}; };
|
||||
|
||||
template<typename Scalar,typename IndexType>
|
||||
struct has_nullary_operator<scalar_identity_op<Scalar>,IndexType> { enum { value = 0}; };
|
||||
template<typename Scalar,typename IndexType>
|
||||
struct has_unary_operator<scalar_identity_op<Scalar>,IndexType> { enum { value = 0}; };
|
||||
template<typename Scalar,typename IndexType>
|
||||
struct has_binary_operator<scalar_identity_op<Scalar>,IndexType> { enum { value = 1}; };
|
||||
|
||||
template<typename Scalar, typename PacketType,typename IndexType>
|
||||
struct has_nullary_operator<linspaced_op<Scalar,PacketType>,IndexType> { enum { value = 0}; };
|
||||
template<typename Scalar, typename PacketType,typename IndexType>
|
||||
struct has_unary_operator<linspaced_op<Scalar,PacketType>,IndexType> { enum { value = 1}; };
|
||||
template<typename Scalar, typename PacketType,typename IndexType>
|
||||
struct has_binary_operator<linspaced_op<Scalar,PacketType>,IndexType> { enum { value = 0}; };
|
||||
#endif
|
||||
|
||||
} // end namespace internal
|
||||
|
||||
} // end namespace Eigen
|
||||
|
@ -30,6 +30,7 @@ bool equalsIdentity(const MatrixType& A)
|
||||
|
||||
bool diagOK = (A.diagonal().array() == 1).all();
|
||||
return offDiagOK && diagOK;
|
||||
|
||||
}
|
||||
|
||||
template<typename VectorType>
|
||||
@ -43,6 +44,10 @@ void testVectorType(const VectorType& base)
|
||||
Scalar low = (size == 1 ? high : internal::random<Scalar>(-500,500));
|
||||
if (low>high) std::swap(low,high);
|
||||
|
||||
// check low==high
|
||||
if(internal::random<float>(0.f,1.f)<0.05f)
|
||||
low = high;
|
||||
|
||||
const Scalar step = ((size == 1) ? 1 : (high-low)/(size-1));
|
||||
|
||||
// check whether the result yields what we expect it to do
|
||||
@ -77,6 +82,8 @@ void testVectorType(const VectorType& base)
|
||||
}
|
||||
|
||||
VERIFY( m(m.size()-1) <= high );
|
||||
VERIFY( (m.array() <= high).all() );
|
||||
VERIFY( (m.array() >= low).all() );
|
||||
|
||||
|
||||
VERIFY( m(m.size()-1) >= low );
|
||||
|
Loading…
x
Reference in New Issue
Block a user