mirror of
https://gitlab.com/libeigen/eigen.git
synced 2024-12-21 07:19:46 +08:00
Fix MSVC complex sqrt and packetmath test.
MSVC incorrectly handles `inf` cases for `std::sqrt<std::complex<T>>`. Here we replace it with a custom version (currently used on GPU). Also fixed the `packetmath` test, which previously skipped several corner cases since `CHECK_CWISE1` only tests the first `PacketSize` elements.
This commit is contained in:
parent
8d9cfba799
commit
f149e0ebc3
@ -338,6 +338,22 @@ struct sqrt_impl
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Complex sqrt defined in MathFunctionsImpl.h.
|
||||||
|
template<typename T> std::complex<T> complex_sqrt(const std::complex<T>& a_x);
|
||||||
|
|
||||||
|
// MSVC incorrectly handles inf cases.
|
||||||
|
#if EIGEN_COMP_MSVC > 0
|
||||||
|
template<typename T>
|
||||||
|
struct sqrt_impl<std::complex<T> >
|
||||||
|
{
|
||||||
|
EIGEN_DEVICE_FUNC
|
||||||
|
static EIGEN_ALWAYS_INLINE std::complex<T> run(const std::complex<T>& x)
|
||||||
|
{
|
||||||
|
return complex_sqrt<T>(x);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
#endif
|
||||||
|
|
||||||
template<typename Scalar>
|
template<typename Scalar>
|
||||||
struct sqrt_retval
|
struct sqrt_retval
|
||||||
{
|
{
|
||||||
|
@ -99,6 +99,50 @@ struct hypot_impl
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Generic complex sqrt implementation that correctly handles corner cases
|
||||||
|
// according to https://en.cppreference.com/w/cpp/numeric/complex/sqrt
|
||||||
|
template<typename T>
|
||||||
|
EIGEN_DEVICE_FUNC std::complex<T> complex_sqrt(const std::complex<T>& z) {
|
||||||
|
// Computes the principal sqrt of the input.
|
||||||
|
//
|
||||||
|
// For a complex square root of the number x + i*y. We want to find real
|
||||||
|
// numbers u and v such that
|
||||||
|
// (u + i*v)^2 = x + i*y <=>
|
||||||
|
// u^2 - v^2 + i*2*u*v = x + i*v.
|
||||||
|
// By equating the real and imaginary parts we get:
|
||||||
|
// u^2 - v^2 = x
|
||||||
|
// 2*u*v = y.
|
||||||
|
//
|
||||||
|
// For x >= 0, this has the numerically stable solution
|
||||||
|
// u = sqrt(0.5 * (x + sqrt(x^2 + y^2)))
|
||||||
|
// v = y / (2 * u)
|
||||||
|
// and for x < 0,
|
||||||
|
// v = sign(y) * sqrt(0.5 * (-x + sqrt(x^2 + y^2)))
|
||||||
|
// u = y / (2 * v)
|
||||||
|
//
|
||||||
|
// Letting w = sqrt(0.5 * (|x| + |z|)),
|
||||||
|
// if x == 0: u = w, v = sign(y) * w
|
||||||
|
// if x > 0: u = w, v = y / (2 * w)
|
||||||
|
// if x < 0: u = |y| / (2 * w), v = sign(y) * w
|
||||||
|
|
||||||
|
const T x = numext::real(z);
|
||||||
|
const T y = numext::imag(z);
|
||||||
|
const T zero = T(0);
|
||||||
|
const T cst_half = T(0.5);
|
||||||
|
|
||||||
|
// Special case of isinf(y)
|
||||||
|
if ((numext::isinf)(y)) {
|
||||||
|
const T inf = std::numeric_limits<T>::infinity();
|
||||||
|
return std::complex<T>(inf, y);
|
||||||
|
}
|
||||||
|
|
||||||
|
T w = numext::sqrt(cst_half * (numext::abs(x) + numext::abs(z)));
|
||||||
|
return
|
||||||
|
x == zero ? std::complex<T>(w, y < zero ? -w : w)
|
||||||
|
: x > zero ? std::complex<T>(w, y / (2 * w))
|
||||||
|
: std::complex<T>(numext::abs(y) / (2 * w), y < zero ? -w : w );
|
||||||
|
}
|
||||||
|
|
||||||
} // end namespace internal
|
} // end namespace internal
|
||||||
|
|
||||||
} // end namespace Eigen
|
} // end namespace Eigen
|
||||||
|
@ -95,46 +95,12 @@ template<typename T> struct scalar_quotient_op<const std::complex<T>, const std:
|
|||||||
template<typename T> struct scalar_quotient_op<std::complex<T>, std::complex<T> > : scalar_quotient_op<const std::complex<T>, const std::complex<T> > {};
|
template<typename T> struct scalar_quotient_op<std::complex<T>, std::complex<T> > : scalar_quotient_op<const std::complex<T>, const std::complex<T> > {};
|
||||||
|
|
||||||
template<typename T>
|
template<typename T>
|
||||||
struct sqrt_impl<std::complex<T>> {
|
struct sqrt_impl<std::complex<T> >
|
||||||
static EIGEN_DEVICE_FUNC std::complex<T> run(const std::complex<T>& z) {
|
{
|
||||||
// Computes the principal sqrt of the input.
|
EIGEN_DEVICE_FUNC
|
||||||
//
|
static EIGEN_ALWAYS_INLINE std::complex<T> run(const std::complex<T>& x)
|
||||||
// For a complex square root of the number x + i*y. We want to find real
|
{
|
||||||
// numbers u and v such that
|
return complex_sqrt<T>(x);
|
||||||
// (u + i*v)^2 = x + i*y <=>
|
|
||||||
// u^2 - v^2 + i*2*u*v = x + i*v.
|
|
||||||
// By equating the real and imaginary parts we get:
|
|
||||||
// u^2 - v^2 = x
|
|
||||||
// 2*u*v = y.
|
|
||||||
//
|
|
||||||
// For x >= 0, this has the numerically stable solution
|
|
||||||
// u = sqrt(0.5 * (x + sqrt(x^2 + y^2)))
|
|
||||||
// v = y / (2 * u)
|
|
||||||
// and for x < 0,
|
|
||||||
// v = sign(y) * sqrt(0.5 * (-x + sqrt(x^2 + y^2)))
|
|
||||||
// u = y / (2 * v)
|
|
||||||
//
|
|
||||||
// Letting w = sqrt(0.5 * (|x| + |z|)),
|
|
||||||
// if x == 0: u = w, v = sign(y) * w
|
|
||||||
// if x > 0: u = w, v = y / (2 * w)
|
|
||||||
// if x < 0: u = |y| / (2 * w), v = sign(y) * w
|
|
||||||
|
|
||||||
const T x = numext::real(z);
|
|
||||||
const T y = numext::imag(z);
|
|
||||||
const T zero = T(0);
|
|
||||||
const T cst_half = T(0.5);
|
|
||||||
|
|
||||||
// Special case of isinf(y)
|
|
||||||
if ((numext::isinf)(y)) {
|
|
||||||
const T inf = std::numeric_limits<T>::infinity();
|
|
||||||
return std::complex<T>(inf, y);
|
|
||||||
}
|
|
||||||
|
|
||||||
T w = numext::sqrt(cst_half * (numext::abs(x) + numext::abs(z)));
|
|
||||||
return
|
|
||||||
x == zero ? std::complex<T>(w, y < zero ? -w : w)
|
|
||||||
: x > zero ? std::complex<T>(w, y / (2 * w))
|
|
||||||
: std::complex<T>(numext::abs(y) / (2 * w), y < zero ? -w : w );
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -933,7 +933,7 @@ void packetmath_complex() {
|
|||||||
for (int i = 0; i < size; ++i) {
|
for (int i = 0; i < size; ++i) {
|
||||||
data1[i] = Scalar(internal::random<RealScalar>(), internal::random<RealScalar>());
|
data1[i] = Scalar(internal::random<RealScalar>(), internal::random<RealScalar>());
|
||||||
}
|
}
|
||||||
CHECK_CWISE1(numext::sqrt, internal::psqrt);
|
CHECK_CWISE1_N(numext::sqrt, internal::psqrt, size);
|
||||||
|
|
||||||
// Test misc. corner cases.
|
// Test misc. corner cases.
|
||||||
const RealScalar zero = RealScalar(0);
|
const RealScalar zero = RealScalar(0);
|
||||||
@ -944,32 +944,32 @@ void packetmath_complex() {
|
|||||||
data1[1] = Scalar(-zero, zero);
|
data1[1] = Scalar(-zero, zero);
|
||||||
data1[2] = Scalar(one, zero);
|
data1[2] = Scalar(one, zero);
|
||||||
data1[3] = Scalar(zero, one);
|
data1[3] = Scalar(zero, one);
|
||||||
CHECK_CWISE1(numext::sqrt, internal::psqrt);
|
CHECK_CWISE1_N(numext::sqrt, internal::psqrt, 4);
|
||||||
data1[0] = Scalar(-one, zero);
|
data1[0] = Scalar(-one, zero);
|
||||||
data1[1] = Scalar(zero, -one);
|
data1[1] = Scalar(zero, -one);
|
||||||
data1[2] = Scalar(one, one);
|
data1[2] = Scalar(one, one);
|
||||||
data1[3] = Scalar(-one, -one);
|
data1[3] = Scalar(-one, -one);
|
||||||
CHECK_CWISE1(numext::sqrt, internal::psqrt);
|
CHECK_CWISE1_N(numext::sqrt, internal::psqrt, 4);
|
||||||
data1[0] = Scalar(inf, zero);
|
data1[0] = Scalar(inf, zero);
|
||||||
data1[1] = Scalar(zero, inf);
|
data1[1] = Scalar(zero, inf);
|
||||||
data1[2] = Scalar(-inf, zero);
|
data1[2] = Scalar(-inf, zero);
|
||||||
data1[3] = Scalar(zero, -inf);
|
data1[3] = Scalar(zero, -inf);
|
||||||
CHECK_CWISE1(numext::sqrt, internal::psqrt);
|
CHECK_CWISE1_N(numext::sqrt, internal::psqrt, 4);
|
||||||
data1[0] = Scalar(inf, inf);
|
data1[0] = Scalar(inf, inf);
|
||||||
data1[1] = Scalar(-inf, inf);
|
data1[1] = Scalar(-inf, inf);
|
||||||
data1[2] = Scalar(inf, -inf);
|
data1[2] = Scalar(inf, -inf);
|
||||||
data1[3] = Scalar(-inf, -inf);
|
data1[3] = Scalar(-inf, -inf);
|
||||||
CHECK_CWISE1(numext::sqrt, internal::psqrt);
|
CHECK_CWISE1_N(numext::sqrt, internal::psqrt, 4);
|
||||||
data1[0] = Scalar(nan, zero);
|
data1[0] = Scalar(nan, zero);
|
||||||
data1[1] = Scalar(zero, nan);
|
data1[1] = Scalar(zero, nan);
|
||||||
data1[2] = Scalar(nan, one);
|
data1[2] = Scalar(nan, one);
|
||||||
data1[3] = Scalar(one, nan);
|
data1[3] = Scalar(one, nan);
|
||||||
CHECK_CWISE1(numext::sqrt, internal::psqrt);
|
CHECK_CWISE1_N(numext::sqrt, internal::psqrt, 4);
|
||||||
data1[0] = Scalar(nan, nan);
|
data1[0] = Scalar(nan, nan);
|
||||||
data1[1] = Scalar(inf, nan);
|
data1[1] = Scalar(inf, nan);
|
||||||
data1[2] = Scalar(nan, inf);
|
data1[2] = Scalar(nan, inf);
|
||||||
data1[3] = Scalar(-inf, nan);
|
data1[3] = Scalar(-inf, nan);
|
||||||
CHECK_CWISE1(numext::sqrt, internal::psqrt);
|
CHECK_CWISE1_N(numext::sqrt, internal::psqrt, 4);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -115,6 +115,17 @@ template<typename Scalar> bool areApprox(const Scalar* a, const Scalar* b, int s
|
|||||||
VERIFY(test::areApprox(ref, data2, PacketSize) && #POP); \
|
VERIFY(test::areApprox(ref, data2, PacketSize) && #POP); \
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Checks component-wise for input of size N. All of data1, data2, and ref
|
||||||
|
// should have size at least ceil(N/PacketSize)*PacketSize to avoid memory
|
||||||
|
// access errors.
|
||||||
|
#define CHECK_CWISE1_N(REFOP, POP, N) { \
|
||||||
|
for (int i=0; i<N; ++i) \
|
||||||
|
ref[i] = REFOP(data1[i]); \
|
||||||
|
for (int j=0; j<N; j+=PacketSize) \
|
||||||
|
internal::pstore(data2 + j, POP(internal::pload<Packet>(data1 + j))); \
|
||||||
|
VERIFY(test::areApprox(ref, data2, N) && #POP); \
|
||||||
|
}
|
||||||
|
|
||||||
template<bool Cond,typename Packet>
|
template<bool Cond,typename Packet>
|
||||||
struct packet_helper
|
struct packet_helper
|
||||||
{
|
{
|
||||||
|
Loading…
Reference in New Issue
Block a user