Fix numext::arg return type.

The cxx11 path for `numext::arg` incorrectly returned the complex type
instead of the real type, leading to compile errors. Fixed this and
added tests.

Related to !477, which uncovered the issue.
This commit is contained in:
Antonio Sanchez 2021-05-07 08:24:32 -07:00 committed by Rasmus Munk Larsen
parent 722ca0b665
commit 90e9a33e1c
2 changed files with 22 additions and 5 deletions

View File

@ -592,8 +592,9 @@ struct arg_default_impl;
template<typename Scalar>
struct arg_default_impl<Scalar, true> {
typedef typename NumTraits<Scalar>::Real RealScalar;
EIGEN_DEVICE_FUNC
static inline Scalar run(const Scalar& x)
static inline RealScalar run(const Scalar& x)
{
#if defined(EIGEN_HIP_DEVICE_COMPILE)
// HIP does not seem to have a native device side implementation for the math routine "arg"
@ -601,7 +602,7 @@ struct arg_default_impl<Scalar, true> {
#else
EIGEN_USING_STD(arg);
#endif
return static_cast<Scalar>(arg(x));
return static_cast<RealScalar>(arg(x));
}
};
@ -612,7 +613,7 @@ struct arg_default_impl<Scalar, false> {
EIGEN_DEVICE_FUNC
static inline RealScalar run(const Scalar& x)
{
return (x < Scalar(0)) ? Scalar(EIGEN_PI) : Scalar(0);
return (x < Scalar(0)) ? RealScalar(EIGEN_PI) : RealScalar(0);
}
};
#else
@ -623,7 +624,7 @@ struct arg_default_impl
EIGEN_DEVICE_FUNC
static inline RealScalar run(const Scalar& x)
{
return (x < Scalar(0)) ? Scalar(EIGEN_PI) : Scalar(0);
return (x < RealScalar(0)) ? RealScalar(EIGEN_PI) : RealScalar(0);
}
};

View File

@ -61,6 +61,20 @@ void check_abs() {
}
}
template<typename T>
void check_arg() {
typedef typename NumTraits<T>::Real Real;
VERIFY_IS_EQUAL(numext::abs(T(0)), T(0));
VERIFY_IS_EQUAL(numext::abs(T(1)), T(1));
for(int k=0; k<100; ++k)
{
T x = internal::random<T>();
Real y = numext::arg(x);
VERIFY_IS_APPROX( y, std::arg(x) );
}
}
template<typename T>
struct check_sqrt_impl {
static void run() {
@ -242,10 +256,12 @@ EIGEN_DECLARE_TEST(numext) {
CALL_SUBTEST( check_abs<float>() );
CALL_SUBTEST( check_abs<double>() );
CALL_SUBTEST( check_abs<long double>() );
CALL_SUBTEST( check_abs<std::complex<float> >() );
CALL_SUBTEST( check_abs<std::complex<double> >() );
CALL_SUBTEST( check_arg<std::complex<float> >() );
CALL_SUBTEST( check_arg<std::complex<double> >() );
CALL_SUBTEST( check_sqrt<float>() );
CALL_SUBTEST( check_sqrt<double>() );
CALL_SUBTEST( check_sqrt<std::complex<float> >() );