mirror of
https://gitlab.com/libeigen/eigen.git
synced 2024-12-27 07:29:52 +08:00
Ensure Igamma does not NaN or Inf for large values.
This commit is contained in:
parent
6601abce86
commit
f6c6de5d63
@ -713,6 +713,18 @@ struct cephes_helper<double> {
|
||||
|
||||
enum IgammaComputationMode { VALUE, DERIVATIVE, SAMPLE_DERIVATIVE };
|
||||
|
||||
template <typename Scalar>
|
||||
static EIGEN_STRONG_INLINE Scalar main_igamma_term(Scalar a, Scalar x) {
|
||||
/* Compute x**a * exp(-x) / gamma(a) */
|
||||
Scalar logax = a * numext::log(x) - x - lgamma_impl<Scalar>::run(a);
|
||||
if (logax < -numext::log(NumTraits<Scalar>::highest()) ||
|
||||
// Assuming x and a aren't Nan.
|
||||
(numext::isnan)(logax)) {
|
||||
return Scalar(0);
|
||||
}
|
||||
return numext::exp(logax);
|
||||
}
|
||||
|
||||
template <typename Scalar, IgammaComputationMode mode>
|
||||
EIGEN_DEVICE_FUNC
|
||||
int igamma_num_iterations() {
|
||||
@ -755,6 +767,15 @@ struct igammac_cf_impl {
|
||||
return zero;
|
||||
}
|
||||
|
||||
Scalar ax = main_igamma_term<Scalar>(a, x);
|
||||
// This is independent of mode. If this value is zero,
|
||||
// then the function value is zero. If the function value is zero,
|
||||
// then we are in a neighborhood where the function value evalutes to zero,
|
||||
// so the derivative is zero.
|
||||
if (ax == zero) {
|
||||
return zero;
|
||||
}
|
||||
|
||||
// continued fraction
|
||||
Scalar y = one - a;
|
||||
Scalar z = x + y + one;
|
||||
@ -825,9 +846,7 @@ struct igammac_cf_impl {
|
||||
}
|
||||
|
||||
/* Compute x**a * exp(-x) / gamma(a) */
|
||||
Scalar logax = a * numext::log(x) - x - lgamma_impl<Scalar>::run(a);
|
||||
Scalar dlogax_da = numext::log(x) - digamma_impl<Scalar>::run(a);
|
||||
Scalar ax = numext::exp(logax);
|
||||
Scalar dax_da = ax * dlogax_da;
|
||||
|
||||
switch (mode) {
|
||||
@ -858,6 +877,18 @@ struct igamma_series_impl {
|
||||
const Scalar one = 1;
|
||||
const Scalar machep = cephes_helper<Scalar>::machep();
|
||||
|
||||
Scalar ax = main_igamma_term<Scalar>(a, x);
|
||||
|
||||
// This is independent of mode. If this value is zero,
|
||||
// then the function value is zero. If the function value is zero,
|
||||
// then we are in a neighborhood where the function value evalutes to zero,
|
||||
// so the derivative is zero.
|
||||
if (ax == zero) {
|
||||
return zero;
|
||||
}
|
||||
|
||||
ax /= a;
|
||||
|
||||
/* power series */
|
||||
Scalar r = a;
|
||||
Scalar c = one;
|
||||
@ -886,10 +917,7 @@ struct igamma_series_impl {
|
||||
}
|
||||
}
|
||||
|
||||
/* Compute x**a * exp(-x) / gamma(a + 1) */
|
||||
Scalar logax = a * numext::log(x) - x - lgamma_impl<Scalar>::run(a + one);
|
||||
Scalar dlogax_da = numext::log(x) - digamma_impl<Scalar>::run(a + one);
|
||||
Scalar ax = numext::exp(logax);
|
||||
Scalar dax_da = ax * dlogax_da;
|
||||
|
||||
switch (mode) {
|
||||
|
@ -7,6 +7,7 @@
|
||||
// Public License v. 2.0. If a copy of the MPL was not distributed
|
||||
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
#include <limits.h>
|
||||
#include "main.h"
|
||||
#include "../Eigen/SpecialFunctions"
|
||||
|
||||
@ -74,6 +75,7 @@ template<typename ArrayType> void array_special_functions()
|
||||
ArrayType gamma_a_x = Eigen::igamma(a, x) * a.lgamma().exp();
|
||||
ArrayType gamma_a_m1_x = Eigen::igamma(a_m1, x) * a_m1.lgamma().exp();
|
||||
|
||||
|
||||
// Gamma(a, 0) == Gamma(a)
|
||||
VERIFY_IS_APPROX(Eigen::igammac(a, zero), one);
|
||||
|
||||
@ -86,6 +88,19 @@ template<typename ArrayType> void array_special_functions()
|
||||
// gamma(a, x) == (a - 1) * gamma(a-1, x) - x^(a-1) * exp(-x)
|
||||
VERIFY_IS_APPROX(gamma_a_x, (a - 1) * gamma_a_m1_x - x.pow(a-1) * (-x).exp());
|
||||
}
|
||||
{
|
||||
// Verify for large a and x that values are between 0 and 1.
|
||||
ArrayType m1 = ArrayType::Random(rows,cols);
|
||||
ArrayType m2 = ArrayType::Random(rows,cols);
|
||||
Scalar max_exponent = std::numeric_limits<Scalar>::max_exponent10;
|
||||
ArrayType a = m1.abs() * pow(10., max_exponent - 1);
|
||||
ArrayType x = m2.abs() * pow(10., max_exponent - 1);
|
||||
for (int i = 0; i < a.size(); ++i) {
|
||||
Scalar igam = numext::igamma(a(i), x(i));
|
||||
VERIFY(0 <= igam);
|
||||
VERIFY(igam <= 1);
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
// Check exact values of igamma and igammac against a third party calculation.
|
||||
|
Loading…
Reference in New Issue
Block a user