Merged in rmlarsen/eigen2 (pull request PR-392)

Add vectorized clip functor for Eigen Tensors
This commit is contained in:
Benoit Steiner 2018-05-16 01:15:56 +00:00
commit ad355b3f05
3 changed files with 46 additions and 0 deletions

View File

@ -209,6 +209,12 @@ class TensorBase<Derived, ReadOnlyAccessors>
return unaryExpr(internal::scalar_abs_op<Scalar>());
}
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_clip_op<Scalar>, const Derived>
clip(Scalar min, Scalar max) const {
return unaryExpr(internal::scalar_clip_op<Scalar>(min, max));
}
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_conjugate_op<Scalar>, const Derived>
conjugate() const {

View File

@ -487,6 +487,25 @@ struct functor_traits<GaussianGenerator<T, Index, NumDims> > {
};
};
template <typename Scalar>
struct scalar_clip_op {
EIGEN_DEVICE_FUNC inline scalar_clip_op(const Scalar& _min, const Scalar& _max) : m_min(_min), m_max(_max) {}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar
operator()(const Scalar& x) const {
return numext::mini(numext::maxi(x, m_min), m_max);
}
template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet
packetOp(const Packet& x) const {
return internal::pmin(internal::pmax(x, pset1<Packet>(m_min)), pset1<Packet>(m_max));
}
const Scalar m_min;
const Scalar m_max;
};
template<typename Scalar>
struct functor_traits<scalar_clip_op<Scalar> >
{ enum { Cost = 2 * NumTraits<Scalar>::AddCost, PacketAccess = (packet_traits<Scalar>::HasMin && packet_traits<Scalar>::HasMax)}; };
} // end namespace internal
} // end namespace Eigen

View File

@ -340,6 +340,26 @@ void test_minmax_nan_propagation_templ() {
}
}
static void test_clip()
{
Tensor<float, 1> vec(6);
vec(0) = 4.0;
vec(1) = 8.0;
vec(2) = 15.0;
vec(3) = 16.0;
vec(4) = 23.0;
vec(5) = 42.0;
float kMin = 20;
float kMax = 30;
Tensor<float, 1> vec_clipped(6);
vec_clipped = vec.clip(kMin, kMax);
for (int i = 0; i < 6; ++i) {
VERIFY_IS_EQUAL(vec_clipped(i), numext::mini(numext::maxi(vec(i), kMin), kMax));
}
}
static void test_minmax_nan_propagation()
{
test_minmax_nan_propagation_templ<float>();
@ -356,5 +376,6 @@ void test_cxx11_tensor_expr()
CALL_SUBTEST(test_functors());
CALL_SUBTEST(test_type_casting());
CALL_SUBTEST(test_select());
CALL_SUBTEST(test_clip());
CALL_SUBTEST(test_minmax_nan_propagation());
}