mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-04-06 19:10:36 +08:00
Fix numext::arg return type.
The cxx11 path for `numext::arg` incorrectly returned the complex type instead of the real type, leading to compile errors. Fixed this and added tests. Related to !477, which uncovered the issue.
This commit is contained in:
parent
722ca0b665
commit
90e9a33e1c
@ -592,8 +592,9 @@ struct arg_default_impl;
|
||||
|
||||
template<typename Scalar>
|
||||
struct arg_default_impl<Scalar, true> {
|
||||
typedef typename NumTraits<Scalar>::Real RealScalar;
|
||||
EIGEN_DEVICE_FUNC
|
||||
static inline Scalar run(const Scalar& x)
|
||||
static inline RealScalar run(const Scalar& x)
|
||||
{
|
||||
#if defined(EIGEN_HIP_DEVICE_COMPILE)
|
||||
// HIP does not seem to have a native device side implementation for the math routine "arg"
|
||||
@ -601,7 +602,7 @@ struct arg_default_impl<Scalar, true> {
|
||||
#else
|
||||
EIGEN_USING_STD(arg);
|
||||
#endif
|
||||
return static_cast<Scalar>(arg(x));
|
||||
return static_cast<RealScalar>(arg(x));
|
||||
}
|
||||
};
|
||||
|
||||
@ -612,7 +613,7 @@ struct arg_default_impl<Scalar, false> {
|
||||
EIGEN_DEVICE_FUNC
|
||||
static inline RealScalar run(const Scalar& x)
|
||||
{
|
||||
return (x < Scalar(0)) ? Scalar(EIGEN_PI) : Scalar(0);
|
||||
return (x < Scalar(0)) ? RealScalar(EIGEN_PI) : RealScalar(0);
|
||||
}
|
||||
};
|
||||
#else
|
||||
@ -623,7 +624,7 @@ struct arg_default_impl
|
||||
EIGEN_DEVICE_FUNC
|
||||
static inline RealScalar run(const Scalar& x)
|
||||
{
|
||||
return (x < Scalar(0)) ? Scalar(EIGEN_PI) : Scalar(0);
|
||||
return (x < RealScalar(0)) ? RealScalar(EIGEN_PI) : RealScalar(0);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -61,6 +61,20 @@ void check_abs() {
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void check_arg() {
|
||||
typedef typename NumTraits<T>::Real Real;
|
||||
VERIFY_IS_EQUAL(numext::abs(T(0)), T(0));
|
||||
VERIFY_IS_EQUAL(numext::abs(T(1)), T(1));
|
||||
|
||||
for(int k=0; k<100; ++k)
|
||||
{
|
||||
T x = internal::random<T>();
|
||||
Real y = numext::arg(x);
|
||||
VERIFY_IS_APPROX( y, std::arg(x) );
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
struct check_sqrt_impl {
|
||||
static void run() {
|
||||
@ -242,10 +256,12 @@ EIGEN_DECLARE_TEST(numext) {
|
||||
CALL_SUBTEST( check_abs<float>() );
|
||||
CALL_SUBTEST( check_abs<double>() );
|
||||
CALL_SUBTEST( check_abs<long double>() );
|
||||
|
||||
CALL_SUBTEST( check_abs<std::complex<float> >() );
|
||||
CALL_SUBTEST( check_abs<std::complex<double> >() );
|
||||
|
||||
CALL_SUBTEST( check_arg<std::complex<float> >() );
|
||||
CALL_SUBTEST( check_arg<std::complex<double> >() );
|
||||
|
||||
CALL_SUBTEST( check_sqrt<float>() );
|
||||
CALL_SUBTEST( check_sqrt<double>() );
|
||||
CALL_SUBTEST( check_sqrt<std::complex<float> >() );
|
||||
|
Loading…
x
Reference in New Issue
Block a user