From f149e0ebc3d3d5ca63234e58ca72690caf07e3b5 Mon Sep 17 00:00:00 2001 From: Antonio Sanchez Date: Thu, 7 Jan 2021 09:39:05 -0800 Subject: [PATCH] Fix MSVC complex sqrt and packetmath test. MSVC incorrectly handles `inf` cases for `std::sqrt<std::complex<T>>`. Here we replace it with a custom version (currently used on GPU). Also fixed the `packetmath` test, which previously skipped several corner cases since `CHECK_CWISE1` only tests the first `PacketSize` elements. --- Eigen/src/Core/MathFunctions.h | 16 +++++++++++ Eigen/src/Core/MathFunctionsImpl.h | 44 ++++++++++++++++++++++++++++ Eigen/src/Core/arch/CUDA/Complex.h | 46 ++++-------------------------- test/packetmath.cpp | 14 ++++----- test/packetmath_test_shared.h | 11 +++++++ 5 files changed, 84 insertions(+), 47 deletions(-) diff --git a/Eigen/src/Core/MathFunctions.h b/Eigen/src/Core/MathFunctions.h index 928bc8e72..5b5ca46f6 100644 --- a/Eigen/src/Core/MathFunctions.h +++ b/Eigen/src/Core/MathFunctions.h @@ -338,6 +338,22 @@ struct sqrt_impl } }; +// Complex sqrt defined in MathFunctionsImpl.h. +template std::complex complex_sqrt(const std::complex& a_x); + +// MSVC incorrectly handles inf cases. +#if EIGEN_COMP_MSVC > 0 +template +struct sqrt_impl > +{ + EIGEN_DEVICE_FUNC + static EIGEN_ALWAYS_INLINE std::complex run(const std::complex& x) + { + return complex_sqrt(x); + } +}; +#endif + template struct sqrt_retval { diff --git a/Eigen/src/Core/MathFunctionsImpl.h b/Eigen/src/Core/MathFunctionsImpl.h index 8288ad834..8ecddebf6 100644 --- a/Eigen/src/Core/MathFunctionsImpl.h +++ b/Eigen/src/Core/MathFunctionsImpl.h @@ -99,6 +99,50 @@ struct hypot_impl } }; +// Generic complex sqrt implementation that correctly handles corner cases +// according to https://en.cppreference.com/w/cpp/numeric/complex/sqrt +template +EIGEN_DEVICE_FUNC std::complex complex_sqrt(const std::complex& z) { + // Computes the principal sqrt of the input. + // + // For a complex square root of the number x + i*y. 
We want to find real + // numbers u and v such that + // (u + i*v)^2 = x + i*y <=> + // u^2 - v^2 + i*2*u*v = x + i*y. + // By equating the real and imaginary parts we get: + // u^2 - v^2 = x + // 2*u*v = y. + // + // For x >= 0, this has the numerically stable solution + // u = sqrt(0.5 * (x + sqrt(x^2 + y^2))) + // v = y / (2 * u) + // and for x < 0, + // v = sign(y) * sqrt(0.5 * (-x + sqrt(x^2 + y^2))) + // u = y / (2 * v) + // + // Letting w = sqrt(0.5 * (|x| + |z|)), + // if x == 0: u = w, v = sign(y) * w + // if x > 0: u = w, v = y / (2 * w) + // if x < 0: u = |y| / (2 * w), v = sign(y) * w + + const T x = numext::real(z); + const T y = numext::imag(z); + const T zero = T(0); + const T cst_half = T(0.5); + + // Special case of isinf(y) + if ((numext::isinf)(y)) { + const T inf = std::numeric_limits::infinity(); + return std::complex(inf, y); + } + + T w = numext::sqrt(cst_half * (numext::abs(x) + numext::abs(z))); + return + x == zero ? std::complex(w, y < zero ? -w : w) + : x > zero ? std::complex(w, y / (2 * w)) + : std::complex(numext::abs(y) / (2 * w), y < zero ? -w : w ); +} + } // end namespace internal } // end namespace Eigen diff --git a/Eigen/src/Core/arch/CUDA/Complex.h b/Eigen/src/Core/arch/CUDA/Complex.h index 69334cafe..ab0207cac 100644 --- a/Eigen/src/Core/arch/CUDA/Complex.h +++ b/Eigen/src/Core/arch/CUDA/Complex.h @@ -95,46 +95,12 @@ template struct scalar_quotient_op, const std: template struct scalar_quotient_op, std::complex > : scalar_quotient_op, const std::complex > {}; template -struct sqrt_impl> { - static EIGEN_DEVICE_FUNC std::complex run(const std::complex& z) { - // Computes the principal sqrt of the input. - // - // For a complex square root of the number x + i*y. We want to find real - // numbers u and v such that - // (u + i*v)^2 = x + i*y <=> - // u^2 - v^2 + i*2*u*v = x + i*v. - // By equating the real and imaginary parts we get: - // u^2 - v^2 = x - // 2*u*v = y. 
- // - // For x >= 0, this has the numerically stable solution - // u = sqrt(0.5 * (x + sqrt(x^2 + y^2))) - // v = y / (2 * u) - // and for x < 0, - // v = sign(y) * sqrt(0.5 * (-x + sqrt(x^2 + y^2))) - // u = y / (2 * v) - // - // Letting w = sqrt(0.5 * (|x| + |z|)), - // if x == 0: u = w, v = sign(y) * w - // if x > 0: u = w, v = y / (2 * w) - // if x < 0: u = |y| / (2 * w), v = sign(y) * w - - const T x = numext::real(z); - const T y = numext::imag(z); - const T zero = T(0); - const T cst_half = T(0.5); - - // Special case of isinf(y) - if ((numext::isinf)(y)) { - const T inf = std::numeric_limits::infinity(); - return std::complex(inf, y); - } - - T w = numext::sqrt(cst_half * (numext::abs(x) + numext::abs(z))); - return - x == zero ? std::complex(w, y < zero ? -w : w) - : x > zero ? std::complex(w, y / (2 * w)) - : std::complex(numext::abs(y) / (2 * w), y < zero ? -w : w ); +struct sqrt_impl > +{ + EIGEN_DEVICE_FUNC + static EIGEN_ALWAYS_INLINE std::complex run(const std::complex& x) + { + return complex_sqrt(x); } }; diff --git a/test/packetmath.cpp b/test/packetmath.cpp index f19d72502..ab9bec183 100644 --- a/test/packetmath.cpp +++ b/test/packetmath.cpp @@ -933,7 +933,7 @@ void packetmath_complex() { for (int i = 0; i < size; ++i) { data1[i] = Scalar(internal::random(), internal::random()); } - CHECK_CWISE1(numext::sqrt, internal::psqrt); + CHECK_CWISE1_N(numext::sqrt, internal::psqrt, size); // Test misc. corner cases. 
const RealScalar zero = RealScalar(0); @@ -944,32 +944,32 @@ void packetmath_complex() { data1[1] = Scalar(-zero, zero); data1[2] = Scalar(one, zero); data1[3] = Scalar(zero, one); - CHECK_CWISE1(numext::sqrt, internal::psqrt); + CHECK_CWISE1_N(numext::sqrt, internal::psqrt, 4); data1[0] = Scalar(-one, zero); data1[1] = Scalar(zero, -one); data1[2] = Scalar(one, one); data1[3] = Scalar(-one, -one); - CHECK_CWISE1(numext::sqrt, internal::psqrt); + CHECK_CWISE1_N(numext::sqrt, internal::psqrt, 4); data1[0] = Scalar(inf, zero); data1[1] = Scalar(zero, inf); data1[2] = Scalar(-inf, zero); data1[3] = Scalar(zero, -inf); - CHECK_CWISE1(numext::sqrt, internal::psqrt); + CHECK_CWISE1_N(numext::sqrt, internal::psqrt, 4); data1[0] = Scalar(inf, inf); data1[1] = Scalar(-inf, inf); data1[2] = Scalar(inf, -inf); data1[3] = Scalar(-inf, -inf); - CHECK_CWISE1(numext::sqrt, internal::psqrt); + CHECK_CWISE1_N(numext::sqrt, internal::psqrt, 4); data1[0] = Scalar(nan, zero); data1[1] = Scalar(zero, nan); data1[2] = Scalar(nan, one); data1[3] = Scalar(one, nan); - CHECK_CWISE1(numext::sqrt, internal::psqrt); + CHECK_CWISE1_N(numext::sqrt, internal::psqrt, 4); data1[0] = Scalar(nan, nan); data1[1] = Scalar(inf, nan); data1[2] = Scalar(nan, inf); data1[3] = Scalar(-inf, nan); - CHECK_CWISE1(numext::sqrt, internal::psqrt); + CHECK_CWISE1_N(numext::sqrt, internal::psqrt, 4); } } diff --git a/test/packetmath_test_shared.h b/test/packetmath_test_shared.h index f8dc3711c..46a42604b 100644 --- a/test/packetmath_test_shared.h +++ b/test/packetmath_test_shared.h @@ -115,6 +115,17 @@ template bool areApprox(const Scalar* a, const Scalar* b, int s VERIFY(test::areApprox(ref, data2, PacketSize) && #POP); \ } +// Checks component-wise for input of size N. All of data1, data2, and ref +// should have size at least ceil(N/PacketSize)*PacketSize to avoid memory +// access errors. 
+#define CHECK_CWISE1_N(REFOP, POP, N) { \ + for (int i=0; i(data1 + j))); \ + VERIFY(test::areApprox(ref, data2, N) && #POP); \ +} + template struct packet_helper {