mirror of
https://gitlab.com/libeigen/eigen.git
synced 2024-12-27 07:29:52 +08:00
Merged in rmlarsen/eigen2 (pull request PR-392)
Add vectorized clip functor for Eigen Tensors
This commit is contained in:
commit
ad355b3f05
@ -209,6 +209,12 @@ class TensorBase<Derived, ReadOnlyAccessors>
|
||||
return unaryExpr(internal::scalar_abs_op<Scalar>());
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC
|
||||
EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_clip_op<Scalar>, const Derived>
|
||||
clip(Scalar min, Scalar max) const {
|
||||
return unaryExpr(internal::scalar_clip_op<Scalar>(min, max));
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC
|
||||
EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_conjugate_op<Scalar>, const Derived>
|
||||
conjugate() const {
|
||||
|
@ -487,6 +487,25 @@ struct functor_traits<GaussianGenerator<T, Index, NumDims> > {
|
||||
};
|
||||
};
|
||||
|
||||
template <typename Scalar>
|
||||
struct scalar_clip_op {
|
||||
EIGEN_DEVICE_FUNC inline scalar_clip_op(const Scalar& _min, const Scalar& _max) : m_min(_min), m_max(_max) {}
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar
|
||||
operator()(const Scalar& x) const {
|
||||
return numext::mini(numext::maxi(x, m_min), m_max);
|
||||
}
|
||||
template <typename Packet>
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet
|
||||
packetOp(const Packet& x) const {
|
||||
return internal::pmin(internal::pmax(x, pset1<Packet>(m_min)), pset1<Packet>(m_max));
|
||||
}
|
||||
const Scalar m_min;
|
||||
const Scalar m_max;
|
||||
};
|
||||
template<typename Scalar>
|
||||
struct functor_traits<scalar_clip_op<Scalar> >
|
||||
{ enum { Cost = 2 * NumTraits<Scalar>::AddCost, PacketAccess = (packet_traits<Scalar>::HasMin && packet_traits<Scalar>::HasMax)}; };
|
||||
|
||||
} // end namespace internal
|
||||
} // end namespace Eigen
|
||||
|
||||
|
@ -340,6 +340,26 @@ void test_minmax_nan_propagation_templ() {
|
||||
}
|
||||
}
|
||||
|
||||
static void test_clip()
|
||||
{
|
||||
Tensor<float, 1> vec(6);
|
||||
vec(0) = 4.0;
|
||||
vec(1) = 8.0;
|
||||
vec(2) = 15.0;
|
||||
vec(3) = 16.0;
|
||||
vec(4) = 23.0;
|
||||
vec(5) = 42.0;
|
||||
|
||||
float kMin = 20;
|
||||
float kMax = 30;
|
||||
|
||||
Tensor<float, 1> vec_clipped(6);
|
||||
vec_clipped = vec.clip(kMin, kMax);
|
||||
for (int i = 0; i < 6; ++i) {
|
||||
VERIFY_IS_EQUAL(vec_clipped(i), numext::mini(numext::maxi(vec(i), kMin), kMax));
|
||||
}
|
||||
}
|
||||
|
||||
static void test_minmax_nan_propagation()
|
||||
{
|
||||
test_minmax_nan_propagation_templ<float>();
|
||||
@ -356,5 +376,6 @@ void test_cxx11_tensor_expr()
|
||||
CALL_SUBTEST(test_functors());
|
||||
CALL_SUBTEST(test_type_casting());
|
||||
CALL_SUBTEST(test_select());
|
||||
CALL_SUBTEST(test_clip());
|
||||
CALL_SUBTEST(test_minmax_nan_propagation());
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user