Ensure Igamma does not NaN or Inf for large values.

This commit is contained in:
Srinivas Vasudevan 2020-01-14 21:32:48 +00:00 committed by Rasmus Munk Larsen
parent 6601abce86
commit f6c6de5d63
2 changed files with 48 additions and 5 deletions

View File

@ -713,6 +713,18 @@ struct cephes_helper<double> {
enum IgammaComputationMode { VALUE, DERIVATIVE, SAMPLE_DERIVATIVE };
template <typename Scalar>
static EIGEN_STRONG_INLINE Scalar main_igamma_term(Scalar a, Scalar x) {
/* Compute x**a * exp(-x) / gamma(a) */
Scalar logax = a * numext::log(x) - x - lgamma_impl<Scalar>::run(a);
if (logax < -numext::log(NumTraits<Scalar>::highest()) ||
// Assuming x and a aren't Nan.
(numext::isnan)(logax)) {
return Scalar(0);
}
return numext::exp(logax);
}
template <typename Scalar, IgammaComputationMode mode>
EIGEN_DEVICE_FUNC
int igamma_num_iterations() {
@ -755,6 +767,15 @@ struct igammac_cf_impl {
return zero;
}
Scalar ax = main_igamma_term<Scalar>(a, x);
// This is independent of mode. If this value is zero,
// then the function value is zero. If the function value is zero,
// then we are in a neighborhood where the function value evalutes to zero,
// so the derivative is zero.
if (ax == zero) {
return zero;
}
// continued fraction
Scalar y = one - a;
Scalar z = x + y + one;
@ -825,9 +846,7 @@ struct igammac_cf_impl {
}
/* Compute x**a * exp(-x) / gamma(a) */
Scalar logax = a * numext::log(x) - x - lgamma_impl<Scalar>::run(a);
Scalar dlogax_da = numext::log(x) - digamma_impl<Scalar>::run(a);
Scalar ax = numext::exp(logax);
Scalar dax_da = ax * dlogax_da;
switch (mode) {
@ -858,6 +877,18 @@ struct igamma_series_impl {
const Scalar one = 1;
const Scalar machep = cephes_helper<Scalar>::machep();
Scalar ax = main_igamma_term<Scalar>(a, x);
// This is independent of mode. If this value is zero,
// then the function value is zero. If the function value is zero,
// then we are in a neighborhood where the function value evalutes to zero,
// so the derivative is zero.
if (ax == zero) {
return zero;
}
ax /= a;
/* power series */
Scalar r = a;
Scalar c = one;
@ -886,10 +917,7 @@ struct igamma_series_impl {
}
}
/* Compute x**a * exp(-x) / gamma(a + 1) */
Scalar logax = a * numext::log(x) - x - lgamma_impl<Scalar>::run(a + one);
Scalar dlogax_da = numext::log(x) - digamma_impl<Scalar>::run(a + one);
Scalar ax = numext::exp(logax);
Scalar dax_da = ax * dlogax_da;
switch (mode) {

View File

@ -7,6 +7,7 @@
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#include <limits.h>
#include "main.h"
#include "../Eigen/SpecialFunctions"
@ -74,6 +75,7 @@ template<typename ArrayType> void array_special_functions()
ArrayType gamma_a_x = Eigen::igamma(a, x) * a.lgamma().exp();
ArrayType gamma_a_m1_x = Eigen::igamma(a_m1, x) * a_m1.lgamma().exp();
// Gamma(a, 0) == Gamma(a)
VERIFY_IS_APPROX(Eigen::igammac(a, zero), one);
@ -86,6 +88,19 @@ template<typename ArrayType> void array_special_functions()
// gamma(a, x) == (a - 1) * gamma(a-1, x) - x^(a-1) * exp(-x)
VERIFY_IS_APPROX(gamma_a_x, (a - 1) * gamma_a_m1_x - x.pow(a-1) * (-x).exp());
}
{
// Verify for large a and x that values are between 0 and 1.
ArrayType m1 = ArrayType::Random(rows,cols);
ArrayType m2 = ArrayType::Random(rows,cols);
Scalar max_exponent = std::numeric_limits<Scalar>::max_exponent10;
ArrayType a = m1.abs() * pow(10., max_exponent - 1);
ArrayType x = m2.abs() * pow(10., max_exponent - 1);
for (int i = 0; i < a.size(); ++i) {
Scalar igam = numext::igamma(a(i), x(i));
VERIFY(0 <= igam);
VERIFY(igam <= 1);
}
}
{
// Check exact values of igamma and igammac against a third party calculation.