mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-01-18 14:34:17 +08:00
bde6741641
Replaces `std::sqrt` with `complex_sqrt` for all platforms (previously `complex_sqrt` was only used for CUDA and MSVC), and implements a custom `complex_rsqrt`. Also introduces `numext::rsqrt` to simplify the implementation, and modifies `numext::hypot` to adhere to IEC 60559 (IEEE 754) for special cases.

The `complex_sqrt` and `complex_rsqrt` implementations were found to be significantly faster than `std::sqrt<std::complex<T>>` and `1/numext::sqrt<std::complex<T>>`. Benchmark file attached.

```
GCC 10, Intel Xeon, x86_64:
---------------------------------------------------------------------------
Benchmark                               Time         CPU   Iterations
---------------------------------------------------------------------------
BM_Sqrt<std::complex<float>>         9.21 ns     9.21 ns     73225448
BM_StdSqrt<std::complex<float>>      17.1 ns     17.1 ns     40966545
BM_Sqrt<std::complex<double>>        8.53 ns     8.53 ns     81111062
BM_StdSqrt<std::complex<double>>     21.5 ns     21.5 ns     32757248
BM_Rsqrt<std::complex<float>>        10.3 ns     10.3 ns     68047474
BM_DivSqrt<std::complex<float>>      16.3 ns     16.3 ns     42770127
BM_Rsqrt<std::complex<double>>       11.3 ns     11.3 ns     61322028
BM_DivSqrt<std::complex<double>>     16.5 ns     16.5 ns     42200711

Clang 11, Intel Xeon, x86_64:
---------------------------------------------------------------------------
Benchmark                               Time         CPU   Iterations
---------------------------------------------------------------------------
BM_Sqrt<std::complex<float>>         7.46 ns     7.45 ns     90742042
BM_StdSqrt<std::complex<float>>      16.6 ns     16.6 ns     42369878
BM_Sqrt<std::complex<double>>        8.49 ns     8.49 ns     81629030
BM_StdSqrt<std::complex<double>>     21.8 ns     21.7 ns     31809588
BM_Rsqrt<std::complex<float>>        8.39 ns     8.39 ns     82933666
BM_DivSqrt<std::complex<float>>      14.4 ns     14.4 ns     48638676
BM_Rsqrt<std::complex<double>>       9.83 ns     9.82 ns     70068956
BM_DivSqrt<std::complex<double>>     15.7 ns     15.7 ns     44487798

Clang 9, Pixel 2, aarch64:
---------------------------------------------------------------------------
Benchmark                               Time         CPU   Iterations
---------------------------------------------------------------------------
BM_Sqrt<std::complex<float>>         24.2 ns     24.1 ns     28616031
BM_StdSqrt<std::complex<float>>       104 ns      103 ns      6826926
BM_Sqrt<std::complex<double>>        31.8 ns     31.8 ns     22157591
BM_StdSqrt<std::complex<double>>      128 ns      128 ns      5437375
BM_Rsqrt<std::complex<float>>        31.9 ns     31.8 ns     22384383
BM_DivSqrt<std::complex<float>>      99.2 ns     98.9 ns      7250438
BM_Rsqrt<std::complex<double>>       46.0 ns     45.8 ns     15338689
BM_DivSqrt<std::complex<double>>      119 ns      119 ns      5898944
```
260 lines
8.4 KiB
C++
260 lines
8.4 KiB
C++
// This file is part of Eigen, a lightweight C++ template library
|
|
// for linear algebra.
|
|
//
|
|
// Copyright (C) 2017 Gael Guennebaud <gael.guennebaud@inria.fr>
|
|
//
|
|
// This Source Code Form is subject to the terms of the Mozilla
|
|
// Public License v. 2.0. If a copy of the MPL was not distributed
|
|
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
|
|
|
#include "main.h"
|
|
|
|
template<typename T, typename U>
|
|
bool check_if_equal_or_nans(const T& actual, const U& expected) {
|
|
return ((actual == expected) || ((numext::isnan)(actual) && (numext::isnan)(expected)));
|
|
}
|
|
|
|
template<typename T, typename U>
|
|
bool check_if_equal_or_nans(const std::complex<T>& actual, const std::complex<U>& expected) {
|
|
return check_if_equal_or_nans(numext::real(actual), numext::real(expected))
|
|
&& check_if_equal_or_nans(numext::imag(actual), numext::imag(expected));
|
|
}
|
|
|
|
// Same predicate as check_if_equal_or_nans. On a mismatch it additionally
// prints both values to std::cerr, which makes a failing check easier to diagnose.
template<typename T, typename U>
bool test_is_equal_or_nans(const T& actual, const U& expected)
{
  const bool matches = check_if_equal_or_nans(actual, expected);
  if (!matches) {
    std::cerr
      << "\n actual = " << actual
      << "\n expected = " << expected << "\n\n";
  }
  return matches;
}
|
|
|
|
// Test assertion: passes when `a` == `b`, or when both are NaN. For
// std::complex values the check is done per component. On failure,
// test_is_equal_or_nans prints both values to std::cerr.
#define VERIFY_IS_EQUAL_OR_NANS(a, b) VERIFY(test_is_equal_or_nans(a, b))
|
|
|
|
template<typename T>
|
|
void check_abs() {
|
|
typedef typename NumTraits<T>::Real Real;
|
|
Real zero(0);
|
|
|
|
if(NumTraits<T>::IsSigned)
|
|
VERIFY_IS_EQUAL(numext::abs(-T(1)), T(1));
|
|
VERIFY_IS_EQUAL(numext::abs(T(0)), T(0));
|
|
VERIFY_IS_EQUAL(numext::abs(T(1)), T(1));
|
|
|
|
for(int k=0; k<100; ++k)
|
|
{
|
|
T x = internal::random<T>();
|
|
if(!internal::is_same<T,bool>::value)
|
|
x = x/Real(2);
|
|
if(NumTraits<T>::IsSigned)
|
|
{
|
|
VERIFY_IS_EQUAL(numext::abs(x), numext::abs(-x));
|
|
VERIFY( numext::abs(-x) >= zero );
|
|
}
|
|
VERIFY( numext::abs(x) >= zero );
|
|
VERIFY_IS_APPROX( numext::abs2(x), numext::abs2(numext::abs(x)) );
|
|
}
|
|
}
|
|
|
|
template<typename T>
|
|
struct check_sqrt_impl {
|
|
static void run() {
|
|
for (int i=0; i<1000; ++i) {
|
|
const T x = numext::abs(internal::random<T>());
|
|
const T sqrtx = numext::sqrt(x);
|
|
VERIFY_IS_APPROX(sqrtx*sqrtx, x);
|
|
}
|
|
|
|
// Corner cases.
|
|
const T zero = T(0);
|
|
const T one = T(1);
|
|
const T inf = std::numeric_limits<T>::infinity();
|
|
const T nan = std::numeric_limits<T>::quiet_NaN();
|
|
VERIFY_IS_EQUAL(numext::sqrt(zero), zero);
|
|
VERIFY_IS_EQUAL(numext::sqrt(inf), inf);
|
|
VERIFY((numext::isnan)(numext::sqrt(nan)));
|
|
VERIFY((numext::isnan)(numext::sqrt(-one)));
|
|
}
|
|
};
|
|
|
|
template<typename T>
|
|
struct check_sqrt_impl<std::complex<T> > {
|
|
static void run() {
|
|
typedef typename std::complex<T> ComplexT;
|
|
|
|
for (int i=0; i<1000; ++i) {
|
|
const ComplexT x = internal::random<ComplexT>();
|
|
const ComplexT sqrtx = numext::sqrt(x);
|
|
VERIFY_IS_APPROX(sqrtx*sqrtx, x);
|
|
}
|
|
|
|
// Corner cases.
|
|
const T zero = T(0);
|
|
const T one = T(1);
|
|
const T inf = std::numeric_limits<T>::infinity();
|
|
const T nan = std::numeric_limits<T>::quiet_NaN();
|
|
|
|
// Set of corner cases from https://en.cppreference.com/w/cpp/numeric/complex/sqrt
|
|
const int kNumCorners = 20;
|
|
const ComplexT corners[kNumCorners][2] = {
|
|
{ComplexT(zero, zero), ComplexT(zero, zero)},
|
|
{ComplexT(-zero, zero), ComplexT(zero, zero)},
|
|
{ComplexT(zero, -zero), ComplexT(zero, zero)},
|
|
{ComplexT(-zero, -zero), ComplexT(zero, zero)},
|
|
{ComplexT(one, inf), ComplexT(inf, inf)},
|
|
{ComplexT(nan, inf), ComplexT(inf, inf)},
|
|
{ComplexT(one, -inf), ComplexT(inf, -inf)},
|
|
{ComplexT(nan, -inf), ComplexT(inf, -inf)},
|
|
{ComplexT(-inf, one), ComplexT(zero, inf)},
|
|
{ComplexT(inf, one), ComplexT(inf, zero)},
|
|
{ComplexT(-inf, -one), ComplexT(zero, -inf)},
|
|
{ComplexT(inf, -one), ComplexT(inf, -zero)},
|
|
{ComplexT(-inf, nan), ComplexT(nan, inf)},
|
|
{ComplexT(inf, nan), ComplexT(inf, nan)},
|
|
{ComplexT(zero, nan), ComplexT(nan, nan)},
|
|
{ComplexT(one, nan), ComplexT(nan, nan)},
|
|
{ComplexT(nan, zero), ComplexT(nan, nan)},
|
|
{ComplexT(nan, one), ComplexT(nan, nan)},
|
|
{ComplexT(nan, -one), ComplexT(nan, nan)},
|
|
{ComplexT(nan, nan), ComplexT(nan, nan)},
|
|
};
|
|
|
|
for (int i=0; i<kNumCorners; ++i) {
|
|
const ComplexT& x = corners[i][0];
|
|
const ComplexT sqrtx = corners[i][1];
|
|
VERIFY_IS_EQUAL_OR_NANS(numext::sqrt(x), sqrtx);
|
|
}
|
|
}
|
|
};
|
|
|
|
template<typename T>
|
|
void check_sqrt() {
|
|
check_sqrt_impl<T>::run();
|
|
}
|
|
|
|
template<typename T>
|
|
struct check_rsqrt_impl {
|
|
static void run() {
|
|
const T zero = T(0);
|
|
const T one = T(1);
|
|
const T inf = std::numeric_limits<T>::infinity();
|
|
const T nan = std::numeric_limits<T>::quiet_NaN();
|
|
|
|
for (int i=0; i<1000; ++i) {
|
|
const T x = numext::abs(internal::random<T>());
|
|
const T rsqrtx = numext::rsqrt(x);
|
|
const T invx = one / x;
|
|
VERIFY_IS_APPROX(rsqrtx*rsqrtx, invx);
|
|
}
|
|
|
|
// Corner cases.
|
|
VERIFY_IS_EQUAL(numext::rsqrt(zero), inf);
|
|
VERIFY_IS_EQUAL(numext::rsqrt(inf), zero);
|
|
VERIFY((numext::isnan)(numext::rsqrt(nan)));
|
|
VERIFY((numext::isnan)(numext::rsqrt(-one)));
|
|
}
|
|
};
|
|
|
|
template<typename T>
|
|
struct check_rsqrt_impl<std::complex<T> > {
|
|
static void run() {
|
|
typedef typename std::complex<T> ComplexT;
|
|
const T zero = T(0);
|
|
const T one = T(1);
|
|
const T inf = std::numeric_limits<T>::infinity();
|
|
const T nan = std::numeric_limits<T>::quiet_NaN();
|
|
|
|
for (int i=0; i<1000; ++i) {
|
|
const ComplexT x = internal::random<ComplexT>();
|
|
const ComplexT invx = ComplexT(one, zero) / x;
|
|
const ComplexT rsqrtx = numext::rsqrt(x);
|
|
VERIFY_IS_APPROX(rsqrtx*rsqrtx, invx);
|
|
}
|
|
|
|
// GCC and MSVC differ in their treatment of 1/(0 + 0i)
|
|
// GCC/clang = (inf, nan)
|
|
// MSVC = (nan, nan)
|
|
// and 1 / (x + inf i)
|
|
// GCC/clang = (0, 0)
|
|
// MSVC = (nan, nan)
|
|
#if (EIGEN_COMP_GNUC)
|
|
{
|
|
const int kNumCorners = 20;
|
|
const ComplexT corners[kNumCorners][2] = {
|
|
// Only consistent across GCC, clang
|
|
{ComplexT(zero, zero), ComplexT(zero, zero)},
|
|
{ComplexT(-zero, zero), ComplexT(zero, zero)},
|
|
{ComplexT(zero, -zero), ComplexT(zero, zero)},
|
|
{ComplexT(-zero, -zero), ComplexT(zero, zero)},
|
|
{ComplexT(one, inf), ComplexT(inf, inf)},
|
|
{ComplexT(nan, inf), ComplexT(inf, inf)},
|
|
{ComplexT(one, -inf), ComplexT(inf, -inf)},
|
|
{ComplexT(nan, -inf), ComplexT(inf, -inf)},
|
|
// Consistent across GCC, clang, MSVC
|
|
{ComplexT(-inf, one), ComplexT(zero, inf)},
|
|
{ComplexT(inf, one), ComplexT(inf, zero)},
|
|
{ComplexT(-inf, -one), ComplexT(zero, -inf)},
|
|
{ComplexT(inf, -one), ComplexT(inf, -zero)},
|
|
{ComplexT(-inf, nan), ComplexT(nan, inf)},
|
|
{ComplexT(inf, nan), ComplexT(inf, nan)},
|
|
{ComplexT(zero, nan), ComplexT(nan, nan)},
|
|
{ComplexT(one, nan), ComplexT(nan, nan)},
|
|
{ComplexT(nan, zero), ComplexT(nan, nan)},
|
|
{ComplexT(nan, one), ComplexT(nan, nan)},
|
|
{ComplexT(nan, -one), ComplexT(nan, nan)},
|
|
{ComplexT(nan, nan), ComplexT(nan, nan)},
|
|
};
|
|
|
|
for (int i=0; i<kNumCorners; ++i) {
|
|
const ComplexT& x = corners[i][0];
|
|
const ComplexT rsqrtx = ComplexT(one, zero) / corners[i][1];
|
|
VERIFY_IS_EQUAL_OR_NANS(numext::rsqrt(x), rsqrtx);
|
|
}
|
|
}
|
|
#endif
|
|
}
|
|
};
|
|
|
|
template<typename T>
|
|
void check_rsqrt() {
|
|
check_rsqrt_impl<T>::run();
|
|
}
|
|
|
|
// Test driver: runs every numext check g_repeat times.
EIGEN_DECLARE_TEST(numext) {
  for (int repeat = 0; repeat < g_repeat; ++repeat)
  {
    // abs / abs2: boolean, integral, reduced-precision and floating-point types.
    CALL_SUBTEST( check_abs<bool>() );
    CALL_SUBTEST( check_abs<signed char>() );
    CALL_SUBTEST( check_abs<unsigned char>() );
    CALL_SUBTEST( check_abs<short>() );
    CALL_SUBTEST( check_abs<unsigned short>() );
    CALL_SUBTEST( check_abs<int>() );
    CALL_SUBTEST( check_abs<unsigned int>() );
    CALL_SUBTEST( check_abs<long>() );
    CALL_SUBTEST( check_abs<unsigned long>() );
    CALL_SUBTEST( check_abs<half>() );
    CALL_SUBTEST( check_abs<bfloat16>() );
    CALL_SUBTEST( check_abs<float>() );
    CALL_SUBTEST( check_abs<double>() );
    CALL_SUBTEST( check_abs<long double>() );

    // abs / abs2: complex types.
    CALL_SUBTEST( check_abs<std::complex<float> >() );
    CALL_SUBTEST( check_abs<std::complex<double> >() );

    // sqrt: real and complex.
    CALL_SUBTEST( check_sqrt<float>() );
    CALL_SUBTEST( check_sqrt<double>() );
    CALL_SUBTEST( check_sqrt<std::complex<float> >() );
    CALL_SUBTEST( check_sqrt<std::complex<double> >() );

    // rsqrt: real and complex.
    CALL_SUBTEST( check_rsqrt<float>() );
    CALL_SUBTEST( check_rsqrt<double>() );
    CALL_SUBTEST( check_rsqrt<std::complex<float> >() );
    CALL_SUBTEST( check_rsqrt<std::complex<double> >() );
  }
}
|