add EqualSpaced / setEqualSpaced

This commit is contained in:
Charles Schlosser 2022-12-13 00:54:57 +00:00 committed by Rasmus Munk Larsen
parent 273f803846
commit 2004831941
4 changed files with 78 additions and 6 deletions

View File

@ -306,6 +306,20 @@ DenseBase<Derived>::LinSpaced(const Scalar& low, const Scalar& high)
return DenseBase<Derived>::NullaryExpr(Derived::SizeAtCompileTime, internal::linspaced_op<Scalar>(low,high,Derived::SizeAtCompileTime));
}
template <typename Derived>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const typename DenseBase<Derived>::RandomAccessEqualSpacedReturnType
DenseBase<Derived>::EqualSpaced(Index size, const Scalar& low, const Scalar& step) {
EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived)
return DenseBase<Derived>::NullaryExpr(size, internal::equalspaced_op<Scalar>(low, step));
}
template <typename Derived>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const typename DenseBase<Derived>::RandomAccessEqualSpacedReturnType
DenseBase<Derived>::EqualSpaced(const Scalar& low, const Scalar& step) {
EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived)
return DenseBase<Derived>::NullaryExpr(Derived::SizeAtCompileTime, internal::equalspaced_op<Scalar>(low, step));
}
/** \returns true if all coefficients in this matrix are approximately equal to \a val, to within precision \a prec */
template<typename Derived>
EIGEN_DEVICE_FUNC bool DenseBase<Derived>::isApproxToConstant
@ -455,6 +469,19 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Derived& DenseBase<Derived>::setLinSpaced(
return setLinSpaced(size(), low, high);
}
template <typename Derived>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Derived& DenseBase<Derived>::setEqualSpaced(Index newSize, const Scalar& low,
const Scalar& step) {
EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived)
return derived() = Derived::NullaryExpr(newSize, internal::equalspaced_op<Scalar>(low, step));
}
template <typename Derived>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Derived& DenseBase<Derived>::setEqualSpaced(const Scalar& low,
const Scalar& step) {
EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived)
return setEqualSpaced(size(), low, step);
}
// zero:
/** \returns an expression of a zero matrix.

View File

@ -258,6 +258,8 @@ template<typename Derived> class DenseBase
EIGEN_DEPRECATED typedef CwiseNullaryOp<internal::linspaced_op<Scalar>,PlainObject> SequentialLinSpacedReturnType;
/** \internal Represents a vector with linearly spaced coefficients that allows random access. */
typedef CwiseNullaryOp<internal::linspaced_op<Scalar>,PlainObject> RandomAccessLinSpacedReturnType;
/** \internal Represents a vector with equally spaced coefficients that allows random access. */
typedef CwiseNullaryOp<internal::equalspaced_op<Scalar>, PlainObject> RandomAccessEqualSpacedReturnType;
/** \internal the return type of MatrixBase::eigenvalues() */
typedef Matrix<typename NumTraits<typename internal::traits<Derived>::Scalar>::Real, internal::traits<Derived>::ColsAtCompileTime, 1> EigenvaluesReturnType;
@ -336,6 +338,11 @@ template<typename Derived> class DenseBase
EIGEN_DEVICE_FUNC static const RandomAccessLinSpacedReturnType
LinSpaced(const Scalar& low, const Scalar& high);
EIGEN_DEVICE_FUNC static const RandomAccessEqualSpacedReturnType
EqualSpaced(Index size, const Scalar& low, const Scalar& step);
EIGEN_DEVICE_FUNC static const RandomAccessEqualSpacedReturnType
EqualSpaced(const Scalar& low, const Scalar& step);
template<typename CustomNullaryOp> EIGEN_DEVICE_FUNC
static const CwiseNullaryOp<CustomNullaryOp, PlainObject>
NullaryExpr(Index rows, Index cols, const CustomNullaryOp& func);
@ -357,6 +364,8 @@ template<typename Derived> class DenseBase
EIGEN_DEVICE_FUNC Derived& setConstant(const Scalar& value);
EIGEN_DEVICE_FUNC Derived& setLinSpaced(Index size, const Scalar& low, const Scalar& high);
EIGEN_DEVICE_FUNC Derived& setLinSpaced(const Scalar& low, const Scalar& high);
EIGEN_DEVICE_FUNC Derived& setEqualSpaced(Index size, const Scalar& low, const Scalar& step);
EIGEN_DEVICE_FUNC Derived& setEqualSpaced(const Scalar& low, const Scalar& step);
EIGEN_DEVICE_FUNC Derived& setZero();
EIGEN_DEVICE_FUNC Derived& setOnes();
EIGEN_DEVICE_FUNC Derived& setRandom();

View File

@ -145,6 +145,39 @@ template <typename Scalar> struct linspaced_op
const linspaced_op_impl<Scalar,NumTraits<Scalar>::IsInteger> impl;
};
template <typename Scalar>
struct equalspaced_op {
typedef typename NumTraits<Scalar>::Real RealScalar;
EIGEN_DEVICE_FUNC equalspaced_op(const Scalar& start, const Scalar& step) : m_start(start), m_step(step) {}
template <typename IndexType>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator()(IndexType i) const {
return m_start + m_step * static_cast<Scalar>(i);
}
template <typename Packet, typename IndexType>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(IndexType i) const {
const Packet cst_start = pset1<Packet>(m_start);
const Packet cst_step = pset1<Packet>(m_step);
const Packet cst_lin0 = plset<Packet>(Scalar(0));
const Packet cst_offset = pmadd(cst_lin0, cst_step, cst_start);
Packet istep = pset1<Packet>(static_cast<Scalar>(i) * m_step);
return padd(cst_offset, istep);
}
const Scalar m_start;
const Scalar m_step;
};
template <typename Scalar>
struct functor_traits<equalspaced_op<Scalar> > {
enum {
Cost = NumTraits<Scalar>::AddCost + NumTraits<Scalar>::MulCost,
PacketAccess =
packet_traits<Scalar>::HasSetLinear && packet_traits<Scalar>::HasMul && packet_traits<Scalar>::HasAdd,
IsRepeatable = true
};
};
// Linear access is automatically determined from the operator() prototypes available for the given functor.
// If it exposes an operator()(i,j), then we assume the i and j coefficients are required independently
// and linear access is not possible. In all other cases, linear access is enabled.

View File

@ -78,8 +78,9 @@ void testVectorType(const VectorType& base)
const Scalar step = ((size == 1) ? 1 : (high-low)/RealScalar(size-1));
// check whether the result yields what we expect it to do
VectorType m(base);
VectorType m(base), o(base);
m.setLinSpaced(size,low,high);
o.setEqualSpaced(size, low, step);
if(!NumTraits<Scalar>::IsInteger)
{
@ -87,6 +88,7 @@ void testVectorType(const VectorType& base)
for (int i=0; i<size; ++i)
n(i) = low+RealScalar(i)*step;
VERIFY_IS_APPROX(m,n);
VERIFY_IS_APPROX(n,o);
CALL_SUBTEST( check_extremity_accuracy(m, low, high) );
}
@ -256,11 +258,12 @@ void nullary_overflow()
{
// Check possible overflow issue
int n = 60000;
ArrayXi a1(n), a2(n);
a1.setLinSpaced(n, 0, n-1);
for(int i=0; i<n; ++i)
a2(i) = i;
VERIFY_IS_APPROX(a1,a2);
ArrayXi a1(n), a2(n), a_ref(n);
a1.setLinSpaced(n, 0, n - 1);
a2.setEqualSpaced(n, 0, 1);
for (int i = 0; i < n; ++i) a_ref(i) = i;
VERIFY_IS_APPROX(a1, a_ref);
VERIFY_IS_APPROX(a2, a_ref);
}
template<int>