Merged in mfigurnov/eigen/gamma-der-a (pull request PR-403)

Derivative of the incomplete Gamma function and the sample of a Gamma random variable

Approved-by: Benoit Steiner <benoit.steiner.goog@gmail.com>
This commit is contained in:
Benoit Steiner 2018-06-11 17:57:47 +00:00
commit d3a380af4d
12 changed files with 792 additions and 207 deletions

View File

@ -85,6 +85,8 @@ struct default_packet_traits
HasI0e = 0,
HasI1e = 0,
HasIGamma = 0,
HasIGammaDerA = 0,
HasGammaSampleDerAlpha = 0,
HasIGammac = 0,
HasBetaInc = 0,

View File

@ -47,6 +47,8 @@ template<> struct packet_traits<float> : default_packet_traits
HasI0e = 1,
HasI1e = 1,
HasIGamma = 1,
HasIGammaDerA = 1,
HasGammaSampleDerAlpha = 1,
HasIGammac = 1,
HasBetaInc = 1,
@ -78,6 +80,8 @@ template<> struct packet_traits<double> : default_packet_traits
HasI0e = 1,
HasI1e = 1,
HasIGamma = 1,
HasIGammaDerA = 1,
HasGammaSampleDerAlpha = 1,
HasIGammac = 1,
HasBetaInc = 1,

View File

@ -152,6 +152,20 @@ class TensorBase<Derived, ReadOnlyAccessors>
return binaryExpr(other.derived(), internal::scalar_igamma_op<Scalar>());
}
// igamma_der_a(a = this, x = other)
template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const TensorCwiseBinaryOp<internal::scalar_igamma_der_a_op<Scalar>, const Derived, const OtherDerived>
igamma_der_a(const OtherDerived& other) const {
return binaryExpr(other.derived(), internal::scalar_igamma_der_a_op<Scalar>());
}
// gamma_sample_der_alpha(alpha = this, sample = other)
template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const TensorCwiseBinaryOp<internal::scalar_gamma_sample_der_alpha_op<Scalar>, const Derived, const OtherDerived>
gamma_sample_der_alpha(const OtherDerived& other) const {
return binaryExpr(other.derived(), internal::scalar_gamma_sample_der_alpha_op<Scalar>());
}
// igammac(a = this, x = other)
template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const TensorCwiseBinaryOp<internal::scalar_igammac_op<Scalar>, const Derived, const OtherDerived>

View File

@ -29,6 +29,8 @@ namespace Eigen {
* - erfc
* - lgamma
* - igamma
* - igamma_der_a
* - gamma_sample_der_alpha
* - igammac
* - digamma
* - polygamma

View File

@ -33,6 +33,48 @@ igamma(const Eigen::ArrayBase<Derived>& a, const Eigen::ArrayBase<ExponentDerive
);
}
/** \cpp11 \returns an expression of the coefficient-wise igamma_der_a(\a a, \a x) to the given arrays.
*
* This function computes the coefficient-wise derivative of the incomplete
* gamma function with respect to the parameter a.
*
* \note This function supports only float and double scalar types in c++11
* mode. To support other scalar types,
* or float/double in non c++11 mode, the user has to provide implementations
* of igamma_der_a(T,T) for any scalar
* type T to be supported.
*
* \sa Eigen::igamma(), Eigen::lgamma()
*/
template <typename Derived, typename ExponentDerived>
inline const Eigen::CwiseBinaryOp<Eigen::internal::scalar_igamma_der_a_op<typename Derived::Scalar>, const Derived, const ExponentDerived>
igamma_der_a(const Eigen::ArrayBase<Derived>& a, const Eigen::ArrayBase<ExponentDerived>& x) {
return Eigen::CwiseBinaryOp<Eigen::internal::scalar_igamma_der_a_op<typename Derived::Scalar>, const Derived, const ExponentDerived>(
a.derived(),
x.derived());
}
/** \cpp11 \returns an expression of the coefficient-wise gamma_sample_der_alpha(\a alpha, \a sample) to the given arrays.
*
* This function computes the coefficient-wise derivative of the sample
* of a Gamma(alpha, 1) random variable with respect to the parameter alpha.
*
* \note This function supports only float and double scalar types in c++11
* mode. To support other scalar types,
* or float/double in non c++11 mode, the user has to provide implementations
* of gamma_sample_der_alpha(T,T) for any scalar
* type T to be supported.
*
* \sa Eigen::igamma(), Eigen::lgamma()
*/
template <typename AlphaDerived, typename SampleDerived>
inline const Eigen::CwiseBinaryOp<Eigen::internal::scalar_gamma_sample_der_alpha_op<typename AlphaDerived::Scalar>, const AlphaDerived, const SampleDerived>
gamma_sample_der_alpha(const Eigen::ArrayBase<AlphaDerived>& alpha, const Eigen::ArrayBase<SampleDerived>& sample) {
return Eigen::CwiseBinaryOp<Eigen::internal::scalar_gamma_sample_der_alpha_op<typename AlphaDerived::Scalar>, const AlphaDerived, const SampleDerived>(
alpha.derived(),
sample.derived());
}
/** \cpp11 \returns an expression of the coefficient-wise igammac(\a a, \a x) to the given arrays.
*
* This function computes the coefficient-wise complementary incomplete gamma function.

View File

@ -41,6 +41,60 @@ struct functor_traits<scalar_igamma_op<Scalar> > {
};
};
/** \internal
* \brief Template functor to compute the derivative of the incomplete gamma
* function igamma_der_a(a, x)
*
* \sa class CwiseBinaryOp, Cwise::igamma_der_a
*/
template <typename Scalar>
struct scalar_igamma_der_a_op {
EIGEN_EMPTY_STRUCT_CTOR(scalar_igamma_der_a_op)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator()(const Scalar& a, const Scalar& x) const {
using numext::igamma_der_a;
return igamma_der_a(a, x);
}
template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(const Packet& a, const Packet& x) const {
return internal::pigamma_der_a(a, x);
}
};
template <typename Scalar>
struct functor_traits<scalar_igamma_der_a_op<Scalar> > {
enum {
// 2x the cost of igamma
Cost = 40 * NumTraits<Scalar>::MulCost + 20 * NumTraits<Scalar>::AddCost,
PacketAccess = packet_traits<Scalar>::HasIGammaDerA
};
};
/** \internal
* \brief Template functor to compute the derivative of the sample
* of a Gamma(alpha, 1) random variable with respect to the parameter alpha
* gamma_sample_der_alpha(alpha, sample)
*
* \sa class CwiseBinaryOp, Cwise::gamma_sample_der_alpha
*/
template <typename Scalar>
struct scalar_gamma_sample_der_alpha_op {
EIGEN_EMPTY_STRUCT_CTOR(scalar_gamma_sample_der_alpha_op)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator()(const Scalar& alpha, const Scalar& sample) const {
using numext::gamma_sample_der_alpha;
return gamma_sample_der_alpha(alpha, sample);
}
template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(const Packet& alpha, const Packet& sample) const {
return internal::pgamma_sample_der_alpha(alpha, sample);
}
};
template <typename Scalar>
struct functor_traits<scalar_gamma_sample_der_alpha_op<Scalar> > {
enum {
// 2x the cost of igamma, minus the lgamma cost (the lgamma cancels out)
Cost = 30 * NumTraits<Scalar>::MulCost + 15 * NumTraits<Scalar>::AddCost,
PacketAccess = packet_traits<Scalar>::HasGammaSampleDerAlpha
};
};
/** \internal
* \brief Template functor to compute the complementary incomplete gamma function igammac(a, x)

View File

@ -33,6 +33,14 @@ template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half erfc(const Eigen::h
template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half igamma(const Eigen::half& a, const Eigen::half& x) {
return Eigen::half(Eigen::numext::igamma(static_cast<float>(a), static_cast<float>(x)));
}
template <>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half igamma_der_a(const Eigen::half& a, const Eigen::half& x) {
return Eigen::half(Eigen::numext::igamma_der_a(static_cast<float>(a), static_cast<float>(x)));
}
template <>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half gamma_sample_der_alpha(const Eigen::half& alpha, const Eigen::half& sample) {
return Eigen::half(Eigen::numext::gamma_sample_der_alpha(static_cast<float>(alpha), static_cast<float>(sample)));
}
template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half igammac(const Eigen::half& a, const Eigen::half& x) {
return Eigen::half(Eigen::numext::igammac(static_cast<float>(a), static_cast<float>(x)));
}

View File

@ -521,6 +521,197 @@ struct cephes_helper<double> {
}
};
enum IgammaComputationMode { VALUE, DERIVATIVE, SAMPLE_DERIVATIVE };
template <typename Scalar, IgammaComputationMode mode>
EIGEN_DEVICE_FUNC
int igamma_num_iterations() {
/* Returns the maximum number of internal iterations for igamma computation.
*/
if (mode == VALUE) {
return 2000;
}
if (internal::is_same<Scalar, float>::value) {
return 200;
} else if (internal::is_same<Scalar, double>::value) {
return 500;
} else {
return 2000;
}
}
template <typename Scalar, IgammaComputationMode mode>
struct igammac_cf_impl {
/* Computes igamc(a, x) or derivative (depending on the mode)
* using the continued fraction expansion of the complementary
* incomplete Gamma function.
*
* Preconditions:
* a > 0
* x >= 1
* x >= a
*/
EIGEN_DEVICE_FUNC
static Scalar run(Scalar a, Scalar x) {
const Scalar zero = 0;
const Scalar one = 1;
const Scalar two = 2;
const Scalar machep = cephes_helper<Scalar>::machep();
const Scalar big = cephes_helper<Scalar>::big();
const Scalar biginv = cephes_helper<Scalar>::biginv();
if ((numext::isinf)(x)) {
return zero;
}
// continued fraction
Scalar y = one - a;
Scalar z = x + y + one;
Scalar c = zero;
Scalar pkm2 = one;
Scalar qkm2 = x;
Scalar pkm1 = x + one;
Scalar qkm1 = z * x;
Scalar ans = pkm1 / qkm1;
Scalar dpkm2_da = zero;
Scalar dqkm2_da = zero;
Scalar dpkm1_da = zero;
Scalar dqkm1_da = -x;
Scalar dans_da = (dpkm1_da - ans * dqkm1_da) / qkm1;
for (int i = 0; i < igamma_num_iterations<Scalar, mode>(); i++) {
c += one;
y += one;
z += two;
Scalar yc = y * c;
Scalar pk = pkm1 * z - pkm2 * yc;
Scalar qk = qkm1 * z - qkm2 * yc;
Scalar dpk_da = dpkm1_da * z - pkm1 - dpkm2_da * yc + pkm2 * c;
Scalar dqk_da = dqkm1_da * z - qkm1 - dqkm2_da * yc + qkm2 * c;
if (qk != zero) {
Scalar ans_prev = ans;
ans = pk / qk;
Scalar dans_da_prev = dans_da;
dans_da = (dpk_da - ans * dqk_da) / qk;
if (mode == VALUE) {
if (numext::abs(ans_prev - ans) <= machep * numext::abs(ans)) {
break;
}
} else {
if (numext::abs(dans_da - dans_da_prev) <= machep) {
break;
}
}
}
pkm2 = pkm1;
pkm1 = pk;
qkm2 = qkm1;
qkm1 = qk;
dpkm2_da = dpkm1_da;
dpkm1_da = dpk_da;
dqkm2_da = dqkm1_da;
dqkm1_da = dqk_da;
if (numext::abs(pk) > big) {
pkm2 *= biginv;
pkm1 *= biginv;
qkm2 *= biginv;
qkm1 *= biginv;
dpkm2_da *= biginv;
dpkm1_da *= biginv;
dqkm2_da *= biginv;
dqkm1_da *= biginv;
}
}
/* Compute x**a * exp(-x) / gamma(a) */
Scalar logax = a * numext::log(x) - x - lgamma_impl<Scalar>::run(a);
Scalar dlogax_da = numext::log(x) - digamma_impl<Scalar>::run(a);
Scalar ax = numext::exp(logax);
Scalar dax_da = ax * dlogax_da;
switch (mode) {
case VALUE:
return ans * ax;
case DERIVATIVE:
return ans * dax_da + dans_da * ax;
case SAMPLE_DERIVATIVE:
return -(dans_da + ans * dlogax_da) * x;
}
}
};
template <typename Scalar, IgammaComputationMode mode>
struct igamma_series_impl {
/* Computes igam(a, x) or its derivative (depending on the mode)
* using the series expansion of the incomplete Gamma function.
*
* Preconditions:
* x > 0
* a > 0
* !(x > 1 && x > a)
*/
EIGEN_DEVICE_FUNC
static Scalar run(Scalar a, Scalar x) {
const Scalar zero = 0;
const Scalar one = 1;
const Scalar machep = cephes_helper<Scalar>::machep();
/* power series */
Scalar r = a;
Scalar c = one;
Scalar ans = one;
Scalar dc_da = zero;
Scalar dans_da = zero;
for (int i = 0; i < igamma_num_iterations<Scalar, mode>(); i++) {
r += one;
Scalar term = x / r;
Scalar dterm_da = -x / (r * r);
dc_da = term * dc_da + dterm_da * c;
dans_da += dc_da;
c *= term;
ans += c;
if (mode == VALUE) {
if (c <= machep * ans) {
break;
}
} else {
if (numext::abs(dc_da) <= machep * numext::abs(dans_da)) {
break;
}
}
}
/* Compute x**a * exp(-x) / gamma(a + 1) */
Scalar logax = a * numext::log(x) - x - lgamma_impl<Scalar>::run(a + one);
Scalar dlogax_da = numext::log(x) - digamma_impl<Scalar>::run(a + one);
Scalar ax = numext::exp(logax);
Scalar dax_da = ax * dlogax_da;
switch (mode) {
case VALUE:
return ans * ax;
case DERIVATIVE:
return ans * dax_da + dans_da * ax;
case SAMPLE_DERIVATIVE:
return -(dans_da + ans * dlogax_da) * x / a;
}
}
};
#if !EIGEN_HAS_C99_MATH
template <typename Scalar>
@ -535,8 +726,6 @@ struct igammac_impl {
#else
template <typename Scalar> struct igamma_impl; // predeclare igamma_impl
template <typename Scalar>
struct igammac_impl {
EIGEN_DEVICE_FUNC
@ -604,97 +793,15 @@ struct igammac_impl {
return nan;
}
if ((numext::isnan)(a) || (numext::isnan)(x)) { // propagate nans
if ((numext::isnan)(a) || (numext::isnan)(x)) { // propagate nans
return nan;
}
if ((x < one) || (x < a)) {
/* The checks above ensure that we meet the preconditions for
* igamma_impl::Impl(), so call it, rather than igamma_impl::Run().
* Calling Run() would also work, but in that case the compiler may not be
* able to prove that igammac_impl::Run and igamma_impl::Run are not
* mutually recursive. This leads to worse code, particularly on
* platforms like nvptx, where recursion is allowed only begrudgingly.
*/
return (one - igamma_impl<Scalar>::Impl(a, x));
return (one - igamma_series_impl<Scalar, VALUE>::run(a, x));
}
return Impl(a, x);
}
private:
/* igamma_impl calls igammac_impl::Impl. */
friend struct igamma_impl<Scalar>;
/* Actually computes igamc(a, x).
*
* Preconditions:
* a > 0
* x >= 1
* x >= a
*/
EIGEN_DEVICE_FUNC static Scalar Impl(Scalar a, Scalar x) {
const Scalar zero = 0;
const Scalar one = 1;
const Scalar two = 2;
const Scalar machep = cephes_helper<Scalar>::machep();
const Scalar maxlog = numext::log(NumTraits<Scalar>::highest());
const Scalar big = cephes_helper<Scalar>::big();
const Scalar biginv = cephes_helper<Scalar>::biginv();
const Scalar inf = NumTraits<Scalar>::infinity();
Scalar ans, ax, c, yc, r, t, y, z;
Scalar pk, pkm1, pkm2, qk, qkm1, qkm2;
if (x == inf) return zero; // std::isinf crashes on CUDA
/* Compute x**a * exp(-x) / gamma(a) */
ax = a * numext::log(x) - x - lgamma_impl<Scalar>::run(a);
if (ax < -maxlog) { // underflow
return zero;
}
ax = numext::exp(ax);
// continued fraction
y = one - a;
z = x + y + one;
c = zero;
pkm2 = one;
qkm2 = x;
pkm1 = x + one;
qkm1 = z * x;
ans = pkm1 / qkm1;
for (int i = 0; i < 2000; i++) {
c += one;
y += one;
z += two;
yc = y * c;
pk = pkm1 * z - pkm2 * yc;
qk = qkm1 * z - qkm2 * yc;
if (qk != zero) {
r = pk / qk;
t = numext::abs((ans - r) / r);
ans = r;
} else {
t = one;
}
pkm2 = pkm1;
pkm1 = pk;
qkm2 = qkm1;
qkm1 = qk;
if (numext::abs(pk) > big) {
pkm2 *= biginv;
pkm1 *= biginv;
qkm2 *= biginv;
qkm1 *= biginv;
}
if (t <= machep) {
break;
}
}
return (ans * ax);
return igammac_cf_impl<Scalar, VALUE>::run(a, x);
}
};
@ -704,15 +811,10 @@ struct igammac_impl {
* Implementation of igamma (incomplete gamma integral), based on Cephes but requires C++11/C99 *
************************************************************************************************/
template <typename Scalar>
struct igamma_retval {
typedef Scalar type;
};
#if !EIGEN_HAS_C99_MATH
template <typename Scalar>
struct igamma_impl {
template <typename Scalar, IgammaComputationMode mode>
struct igamma_generic_impl {
EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE Scalar run(Scalar a, Scalar x) {
EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value == false),
@ -723,69 +825,17 @@ struct igamma_impl {
#else
template <typename Scalar>
struct igamma_impl {
template <typename Scalar, IgammaComputationMode mode>
struct igamma_generic_impl {
EIGEN_DEVICE_FUNC
static Scalar run(Scalar a, Scalar x) {
/* igam()
* Incomplete gamma integral
*
*
*
* SYNOPSIS:
*
* double a, x, y, igam();
*
* y = igam( a, x );
*
* DESCRIPTION:
*
* The function is defined by
*
* x
* -
* 1 | | -t a-1
* igam(a,x) = ----- | e t dt.
* - | |
* | (a) -
* 0
*
*
* In this implementation both arguments must be positive.
* The integral is evaluated by either a power series or
* continued fraction expansion, depending on the relative
* values of a and x.
*
* ACCURACY (double):
*
* Relative error:
* arithmetic domain # trials peak rms
* IEEE 0,30 200000 3.6e-14 2.9e-15
* IEEE 0,100 300000 9.9e-14 1.5e-14
*
*
* ACCURACY (float):
*
* Relative error:
* arithmetic domain # trials peak rms
* IEEE 0,30 20000 7.8e-6 5.9e-7
*
*/
/*
Cephes Math Library Release 2.2: June, 1992
Copyright 1985, 1987, 1992 by Stephen L. Moshier
Direct inquiries to 30 Frost Street, Cambridge, MA 02140
*/
/* left tail of incomplete gamma function:
*
* inf. k
* a -x - x
* x e > ----------
* - -
* k=0 | (a+k+1)
/* Depending on the mode, returns
* - VALUE: incomplete Gamma function igamma(a, x)
* - DERIVATIVE: derivative of incomplete Gamma function d/da igamma(a, x)
* - SAMPLE_DERIVATIVE: implicit derivative of a Gamma random variable
* x ~ Gamma(x | a, 1), dx/da = -1 / Gamma(x | a, 1) * d igamma(a, x) / dx
*
* Derivatives are implemented by forward-mode differentiation.
*/
const Scalar zero = 0;
const Scalar one = 1;
@ -797,71 +847,167 @@ struct igamma_impl {
return nan;
}
if ((numext::isnan)(a) || (numext::isnan)(x)) { // propagate nans
if ((numext::isnan)(a) || (numext::isnan)(x)) { // propagate nans
return nan;
}
if ((x > one) && (x > a)) {
/* The checks above ensure that we meet the preconditions for
* igammac_impl::Impl(), so call it, rather than igammac_impl::Run().
* Calling Run() would also work, but in that case the compiler may not be
* able to prove that igammac_impl::Run and igamma_impl::Run are not
* mutually recursive. This leads to worse code, particularly on
* platforms like nvptx, where recursion is allowed only begrudgingly.
*/
return (one - igammac_impl<Scalar>::Impl(a, x));
}
return Impl(a, x);
}
private:
/* igammac_impl calls igamma_impl::Impl. */
friend struct igammac_impl<Scalar>;
/* Actually computes igam(a, x).
*
* Preconditions:
* x > 0
* a > 0
* !(x > 1 && x > a)
*/
EIGEN_DEVICE_FUNC static Scalar Impl(Scalar a, Scalar x) {
const Scalar zero = 0;
const Scalar one = 1;
const Scalar machep = cephes_helper<Scalar>::machep();
const Scalar maxlog = numext::log(NumTraits<Scalar>::highest());
Scalar ans, ax, c, r;
/* Compute x**a * exp(-x) / gamma(a) */
ax = a * numext::log(x) - x - lgamma_impl<Scalar>::run(a);
if (ax < -maxlog) {
// underflow
return zero;
}
ax = numext::exp(ax);
/* power series */
r = a;
c = one;
ans = one;
for (int i = 0; i < 2000; i++) {
r += one;
c *= x/r;
ans += c;
if (c/ans <= machep) {
break;
Scalar ret = igammac_cf_impl<Scalar, mode>::run(a, x);
if (mode == VALUE) {
return one - ret;
} else {
return -ret;
}
}
return (ans * ax / a);
return igamma_series_impl<Scalar, mode>::run(a, x);
}
};
#endif // EIGEN_HAS_C99_MATH
template <typename Scalar>
struct igamma_retval {
typedef Scalar type;
};
template <typename Scalar>
struct igamma_impl : igamma_generic_impl<Scalar, VALUE> {
/* igam()
* Incomplete gamma integral.
*
* The CDF of Gamma(a, 1) random variable at the point x.
*
* Accuracy estimation. For each a in [10^-2, 10^-1...10^3] we sample
* 50 Gamma random variables x ~ Gamma(x | a, 1), a total of 300 points.
* The ground truth is computed by mpmath. Mean absolute error:
* float: 1.26713e-05
* double: 2.33606e-12
*
* Cephes documentation below.
*
* SYNOPSIS:
*
* double a, x, y, igam();
*
* y = igam( a, x );
*
* DESCRIPTION:
*
* The function is defined by
*
* x
* -
* 1 | | -t a-1
* igam(a,x) = ----- | e t dt.
* - | |
* | (a) -
* 0
*
*
* In this implementation both arguments must be positive.
* The integral is evaluated by either a power series or
* continued fraction expansion, depending on the relative
* values of a and x.
*
* ACCURACY (double):
*
* Relative error:
* arithmetic domain # trials peak rms
* IEEE 0,30 200000 3.6e-14 2.9e-15
* IEEE 0,100 300000 9.9e-14 1.5e-14
*
*
* ACCURACY (float):
*
* Relative error:
* arithmetic domain # trials peak rms
* IEEE 0,30 20000 7.8e-6 5.9e-7
*
*/
/*
Cephes Math Library Release 2.2: June, 1992
Copyright 1985, 1987, 1992 by Stephen L. Moshier
Direct inquiries to 30 Frost Street, Cambridge, MA 02140
*/
/* left tail of incomplete gamma function:
*
* inf. k
* a -x - x
* x e > ----------
* - -
* k=0 | (a+k+1)
*
*/
};
template <typename Scalar>
struct igamma_der_a_retval : igamma_retval<Scalar> {};
template <typename Scalar>
struct igamma_der_a_impl : igamma_generic_impl<Scalar, DERIVATIVE> {
/* Derivative of the incomplete Gamma function with respect to a.
*
* Computes d/da igamma(a, x) by forward differentiation of the igamma code.
*
* Accuracy estimation. For each a in [10^-2, 10^-1...10^3] we sample
* 50 Gamma random variables x ~ Gamma(x | a, 1), a total of 300 points.
* The ground truth is computed by mpmath. Mean absolute error:
* float: 6.17992e-07
* double: 4.60453e-12
*
* Reference:
* R. Moore. "Algorithm AS 187: Derivatives of the incomplete gamma
* integral". Journal of the Royal Statistical Society. 1982
*/
};
template <typename Scalar>
struct gamma_sample_der_alpha_retval : igamma_retval<Scalar> {};
template <typename Scalar>
struct gamma_sample_der_alpha_impl
: igamma_generic_impl<Scalar, SAMPLE_DERIVATIVE> {
/* Derivative of a Gamma random variable sample with respect to alpha.
*
* Consider a sample of a Gamma random variable with the concentration
* parameter alpha: sample ~ Gamma(alpha, 1). The reparameterization
* derivative that we want to compute is dsample / dalpha =
* d igammainv(alpha, u) / dalpha, where u = igamma(alpha, sample).
* However, this formula is numerically unstable and expensive, so instead
* we use implicit differentiation:
*
* igamma(alpha, sample) = u, where u ~ Uniform(0, 1).
* Apply d / dalpha to both sides:
* d igamma(alpha, sample) / dalpha
* + d igamma(alpha, sample) / dsample * dsample/dalpha = 0
* d igamma(alpha, sample) / dalpha
* + Gamma(sample | alpha, 1) dsample / dalpha = 0
* dsample/dalpha = - (d igamma(alpha, sample) / dalpha)
* / Gamma(sample | alpha, 1)
*
* Here Gamma(sample | alpha, 1) is the PDF of the Gamma distribution
* (note that the derivative of the CDF w.r.t. sample is the PDF).
* See the reference below for more details.
*
* The derivative of igamma(alpha, sample) is computed by forward
* differentiation of the igamma code. Division by the Gamma PDF is performed
* in the same code, increasing the accuracy and speed due to cancellation
* of some terms.
*
* Accuracy estimation. For each alpha in [10^-2, 10^-1...10^3] we sample
* 50 Gamma random variables sample ~ Gamma(sample | alpha, 1), a total of 300
* points. The ground truth is computed by mpmath. Mean absolute error:
* float: 2.1686e-06
* double: 1.4774e-12
*
* Reference:
* M. Figurnov, S. Mohamed, A. Mnih "Implicit Reparameterization Gradients".
* 2018
*/
};
/*****************************************************************************
* Implementation of Riemann zeta function of two arguments, based on Cephes *
*****************************************************************************/
@ -1950,6 +2096,18 @@ EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(igamma, Scalar)
return EIGEN_MATHFUNC_IMPL(igamma, Scalar)::run(a, x);
}
template <typename Scalar>
EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(igamma_der_a, Scalar)
igamma_der_a(const Scalar& a, const Scalar& x) {
return EIGEN_MATHFUNC_IMPL(igamma_der_a, Scalar)::run(a, x);
}
template <typename Scalar>
EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(gamma_sample_der_alpha, Scalar)
gamma_sample_der_alpha(const Scalar& a, const Scalar& x) {
return EIGEN_MATHFUNC_IMPL(gamma_sample_der_alpha, Scalar)::run(a, x);
}
template <typename Scalar>
EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(igammac, Scalar)
igammac(const Scalar& a, const Scalar& x) {

View File

@ -42,6 +42,21 @@ Packet perfc(const Packet& a) { using numext::erfc; return erfc(a); }
template<typename Packet> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
Packet pigamma(const Packet& a, const Packet& x) { using numext::igamma; return igamma(a, x); }
/** \internal \returns the derivative of the incomplete gamma function
* igamma_der_a(\a a, \a x) */
template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet pigamma_der_a(const Packet& a, const Packet& x) {
using numext::igamma_der_a; return igamma_der_a(a, x);
}
/** \internal \returns compute the derivative of the sample
* of Gamma(alpha, 1) random variable with respect to the parameter a
* gamma_sample_der_alpha(\a alpha, \a sample) */
template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet pgamma_sample_der_alpha(const Packet& alpha, const Packet& sample) {
using numext::gamma_sample_der_alpha; return gamma_sample_der_alpha(alpha, sample);
}
/** \internal \returns the complementary incomplete gamma function igammac(\a a, \a x) */
template<typename Packet> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
Packet pigammac(const Packet& a, const Packet& x) { using numext::igammac; return igammac(a, x); }

View File

@ -120,6 +120,41 @@ double2 pigamma<double2>(const double2& a, const double2& x)
return make_double2(igamma(a.x, x.x), igamma(a.y, x.y));
}
template <>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 pigamma_der_a<float4>(
const float4& a, const float4& x) {
using numext::igamma_der_a;
return make_float4(igamma_der_a(a.x, x.x), igamma_der_a(a.y, x.y),
igamma_der_a(a.z, x.z), igamma_der_a(a.w, x.w));
}
template <>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2
pigamma_der_a<double2>(const double2& a, const double2& x) {
using numext::igamma_der_a;
return make_double2(igamma_der_a(a.x, x.x), igamma_der_a(a.y, x.y));
}
template <>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 pgamma_sample_der_alpha<float4>(
const float4& alpha, const float4& sample) {
using numext::gamma_sample_der_alpha;
return make_float4(
gamma_sample_der_alpha(alpha.x, sample.x),
gamma_sample_der_alpha(alpha.y, sample.y),
gamma_sample_der_alpha(alpha.z, sample.z),
gamma_sample_der_alpha(alpha.w, sample.w));
}
template <>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2
pgamma_sample_der_alpha<double2>(const double2& alpha, const double2& sample) {
using numext::gamma_sample_der_alpha;
return make_double2(
gamma_sample_der_alpha(alpha.x, sample.x),
gamma_sample_der_alpha(alpha.y, sample.y));
}
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
float4 pigammac<float4>(const float4& a, const float4& x)
{

View File

@ -1318,6 +1318,157 @@ void test_cuda_i1e()
cudaFree(d_out);
}
template <typename Scalar>
void test_cuda_igamma_der_a()
{
Tensor<Scalar, 1> in_x(30);
Tensor<Scalar, 1> in_a(30);
Tensor<Scalar, 1> out(30);
Tensor<Scalar, 1> expected_out(30);
out.setZero();
Array<Scalar, 1, Dynamic> in_a_array(30);
Array<Scalar, 1, Dynamic> in_x_array(30);
Array<Scalar, 1, Dynamic> expected_out_array(30);
// See special_functions.cpp for the Python code that generates the test data.
in_a_array << 0.01, 0.01, 0.01, 0.01, 0.01, 0.1, 0.1, 0.1, 0.1, 0.1, 1.0, 1.0,
1.0, 1.0, 1.0, 10.0, 10.0, 10.0, 10.0, 10.0, 100.0, 100.0, 100.0, 100.0,
100.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0;
in_x_array << 1.25668890405e-26, 1.17549435082e-38, 1.20938905072e-05,
1.17549435082e-38, 1.17549435082e-38, 5.66572070696e-16, 0.0132865061065,
0.0200034203853, 6.29263709118e-17, 1.37160367764e-06, 0.333412038288,
1.18135687766, 0.580629033777, 0.170631439426, 0.786686768458,
7.63873279537, 13.1944344379, 11.896042354, 10.5830172417, 10.5020942233,
92.8918587747, 95.003720371, 86.3715926467, 96.0330217672, 82.6389930677,
968.702906754, 969.463546828, 1001.79726022, 955.047416547, 1044.27458568;
expected_out_array << -32.7256441441, -36.4394150514, -9.66467612263,
-36.4394150514, -36.4394150514, -1.0891900302, -2.66351229645,
-2.48666868596, -0.929700494428, -3.56327722764, -0.455320135314,
-0.391437214323, -0.491352055991, -0.350454834292, -0.471773162921,
-0.104084440522, -0.0723646747909, -0.0992828975532, -0.121638215446,
-0.122619605294, -0.0317670267286, -0.0359974812869, -0.0154359225363,
-0.0375775365921, -0.00794899153653, -0.00777303219211, -0.00796085782042,
-0.0125850719397, -0.00455500206958, -0.00476436993148;
for (int i = 0; i < 30; ++i) {
in_x(i) = in_x_array(i);
in_a(i) = in_a_array(i);
expected_out(i) = expected_out_array(i);
}
std::size_t bytes = in_x.size() * sizeof(Scalar);
Scalar* d_a;
Scalar* d_x;
Scalar* d_out;
cudaMalloc((void**)(&d_a), bytes);
cudaMalloc((void**)(&d_x), bytes);
cudaMalloc((void**)(&d_out), bytes);
cudaMemcpy(d_a, in_a.data(), bytes, cudaMemcpyHostToDevice);
cudaMemcpy(d_x, in_x.data(), bytes, cudaMemcpyHostToDevice);
Eigen::CudaStreamDevice stream;
Eigen::GpuDevice gpu_device(&stream);
Eigen::TensorMap<Eigen::Tensor<Scalar, 1> > gpu_a(d_a, 30);
Eigen::TensorMap<Eigen::Tensor<Scalar, 1> > gpu_x(d_x, 30);
Eigen::TensorMap<Eigen::Tensor<Scalar, 1> > gpu_out(d_out, 30);
gpu_out.device(gpu_device) = gpu_a.igamma_der_a(gpu_x);
assert(cudaMemcpyAsync(out.data(), d_out, bytes, cudaMemcpyDeviceToHost,
gpu_device.stream()) == cudaSuccess);
assert(cudaStreamSynchronize(gpu_device.stream()) == cudaSuccess);
for (int i = 0; i < 30; ++i) {
VERIFY_IS_APPROX(out(i), expected_out(i));
}
cudaFree(d_a);
cudaFree(d_x);
cudaFree(d_out);
}
template <typename Scalar>
void test_cuda_gamma_sample_der_alpha()
{
Tensor<Scalar, 1> in_alpha(30);
Tensor<Scalar, 1> in_sample(30);
Tensor<Scalar, 1> out(30);
Tensor<Scalar, 1> expected_out(30);
out.setZero();
Array<Scalar, 1, Dynamic> in_alpha_array(30);
Array<Scalar, 1, Dynamic> in_sample_array(30);
Array<Scalar, 1, Dynamic> expected_out_array(30);
// See special_functions.cpp for the Python code that generates the test data.
in_alpha_array << 0.01, 0.01, 0.01, 0.01, 0.01, 0.1, 0.1, 0.1, 0.1, 0.1, 1.0,
1.0, 1.0, 1.0, 1.0, 10.0, 10.0, 10.0, 10.0, 10.0, 100.0, 100.0, 100.0,
100.0, 100.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0;
in_sample_array << 1.25668890405e-26, 1.17549435082e-38, 1.20938905072e-05,
1.17549435082e-38, 1.17549435082e-38, 5.66572070696e-16, 0.0132865061065,
0.0200034203853, 6.29263709118e-17, 1.37160367764e-06, 0.333412038288,
1.18135687766, 0.580629033777, 0.170631439426, 0.786686768458,
7.63873279537, 13.1944344379, 11.896042354, 10.5830172417, 10.5020942233,
92.8918587747, 95.003720371, 86.3715926467, 96.0330217672, 82.6389930677,
968.702906754, 969.463546828, 1001.79726022, 955.047416547, 1044.27458568;
expected_out_array << 7.42424742367e-23, 1.02004297287e-34, 0.0130155240738,
1.02004297287e-34, 1.02004297287e-34, 1.96505168277e-13, 0.525575786243,
0.713903991771, 2.32077561808e-14, 0.000179348049886, 0.635500453302,
1.27561284917, 0.878125852156, 0.41565819538, 1.03606488534,
0.885964824887, 1.16424049334, 1.10764479598, 1.04590810812,
1.04193666963, 0.965193152414, 0.976217589464, 0.93008035061,
0.98153216096, 0.909196397698, 0.98434963993, 0.984738050206,
1.00106492525, 0.97734200649, 1.02198794179;
for (int i = 0; i < 30; ++i) {
in_alpha(i) = in_alpha_array(i);
in_sample(i) = in_sample_array(i);
expected_out(i) = expected_out_array(i);
}
std::size_t bytes = in_alpha.size() * sizeof(Scalar);
Scalar* d_alpha;
Scalar* d_sample;
Scalar* d_out;
cudaMalloc((void**)(&d_alpha), bytes);
cudaMalloc((void**)(&d_sample), bytes);
cudaMalloc((void**)(&d_out), bytes);
cudaMemcpy(d_alpha, in_alpha.data(), bytes, cudaMemcpyHostToDevice);
cudaMemcpy(d_sample, in_sample.data(), bytes, cudaMemcpyHostToDevice);
Eigen::CudaStreamDevice stream;
Eigen::GpuDevice gpu_device(&stream);
Eigen::TensorMap<Eigen::Tensor<Scalar, 1> > gpu_alpha(d_alpha, 30);
Eigen::TensorMap<Eigen::Tensor<Scalar, 1> > gpu_sample(d_sample, 30);
Eigen::TensorMap<Eigen::Tensor<Scalar, 1> > gpu_out(d_out, 30);
gpu_out.device(gpu_device) = gpu_alpha.gamma_sample_der_alpha(gpu_sample);
assert(cudaMemcpyAsync(out.data(), d_out, bytes, cudaMemcpyDeviceToHost,
gpu_device.stream()) == cudaSuccess);
assert(cudaStreamSynchronize(gpu_device.stream()) == cudaSuccess);
for (int i = 0; i < 30; ++i) {
VERIFY_IS_APPROX(out(i), expected_out(i));
}
cudaFree(d_alpha);
cudaFree(d_sample);
cudaFree(d_out);
}
void test_cxx11_tensor_cuda()
{
@ -1396,5 +1547,11 @@ void test_cxx11_tensor_cuda()
CALL_SUBTEST_6(test_cuda_i1e<float>());
CALL_SUBTEST_6(test_cuda_i1e<double>());
CALL_SUBTEST_6(test_cuda_igamma_der_a<float>());
CALL_SUBTEST_6(test_cuda_igamma_der_a<double>());
CALL_SUBTEST_6(test_cuda_gamma_sample_der_alpha<float>());
CALL_SUBTEST_6(test_cuda_gamma_sample_der_alpha<double>());
#endif
}

View File

@ -376,6 +376,100 @@ template<typename ArrayType> void array_special_functions()
CALL_SUBTEST(res = i1e(x);
verify_component_wise(res, expected););
}
/* Code to generate the data for the following two test cases.
N = 5
np.random.seed(3)
a = np.logspace(-2, 3, 6)
a = np.ravel(np.tile(np.reshape(a, [-1, 1]), [1, N]))
x = np.random.gamma(a, 1.0)
x = np.maximum(x, np.finfo(np.float32).tiny)
def igamma(a, x):
return mpmath.gammainc(a, 0, x, regularized=True)
def igamma_der_a(a, x):
res = mpmath.diff(lambda a_prime: igamma(a_prime, x), a)
return np.float64(res)
def gamma_sample_der_alpha(a, x):
igamma_x = igamma(a, x)
def igammainv_of_igamma(a_prime):
return mpmath.findroot(lambda x_prime: igamma(a_prime, x_prime) -
igamma_x, x, solver='newton')
return np.float64(mpmath.diff(igammainv_of_igamma, a))
v_igamma_der_a = np.vectorize(igamma_der_a)(a, x)
v_gamma_sample_der_alpha = np.vectorize(gamma_sample_der_alpha)(a, x)
*/
#if EIGEN_HAS_C99_MATH
// Test igamma_der_a
{
ArrayType a(30);
ArrayType x(30);
ArrayType res(30);
ArrayType v(30);
a << 0.01, 0.01, 0.01, 0.01, 0.01, 0.1, 0.1, 0.1, 0.1, 0.1, 1.0, 1.0, 1.0,
1.0, 1.0, 10.0, 10.0, 10.0, 10.0, 10.0, 100.0, 100.0, 100.0, 100.0,
100.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0;
x << 1.25668890405e-26, 1.17549435082e-38, 1.20938905072e-05,
1.17549435082e-38, 1.17549435082e-38, 5.66572070696e-16,
0.0132865061065, 0.0200034203853, 6.29263709118e-17, 1.37160367764e-06,
0.333412038288, 1.18135687766, 0.580629033777, 0.170631439426,
0.786686768458, 7.63873279537, 13.1944344379, 11.896042354,
10.5830172417, 10.5020942233, 92.8918587747, 95.003720371,
86.3715926467, 96.0330217672, 82.6389930677, 968.702906754,
969.463546828, 1001.79726022, 955.047416547, 1044.27458568;
v << -32.7256441441, -36.4394150514, -9.66467612263, -36.4394150514,
-36.4394150514, -1.0891900302, -2.66351229645, -2.48666868596,
-0.929700494428, -3.56327722764, -0.455320135314, -0.391437214323,
-0.491352055991, -0.350454834292, -0.471773162921, -0.104084440522,
-0.0723646747909, -0.0992828975532, -0.121638215446, -0.122619605294,
-0.0317670267286, -0.0359974812869, -0.0154359225363, -0.0375775365921,
-0.00794899153653, -0.00777303219211, -0.00796085782042,
-0.0125850719397, -0.00455500206958, -0.00476436993148;
CALL_SUBTEST(res = igamma_der_a(a, x); verify_component_wise(res, v););
}
// Test gamma_sample_der_alpha
{
ArrayType alpha(30);
ArrayType sample(30);
ArrayType res(30);
ArrayType v(30);
alpha << 0.01, 0.01, 0.01, 0.01, 0.01, 0.1, 0.1, 0.1, 0.1, 0.1, 1.0, 1.0,
1.0, 1.0, 1.0, 10.0, 10.0, 10.0, 10.0, 10.0, 100.0, 100.0, 100.0, 100.0,
100.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0;
sample << 1.25668890405e-26, 1.17549435082e-38, 1.20938905072e-05,
1.17549435082e-38, 1.17549435082e-38, 5.66572070696e-16,
0.0132865061065, 0.0200034203853, 6.29263709118e-17, 1.37160367764e-06,
0.333412038288, 1.18135687766, 0.580629033777, 0.170631439426,
0.786686768458, 7.63873279537, 13.1944344379, 11.896042354,
10.5830172417, 10.5020942233, 92.8918587747, 95.003720371,
86.3715926467, 96.0330217672, 82.6389930677, 968.702906754,
969.463546828, 1001.79726022, 955.047416547, 1044.27458568;
v << 7.42424742367e-23, 1.02004297287e-34, 0.0130155240738,
1.02004297287e-34, 1.02004297287e-34, 1.96505168277e-13, 0.525575786243,
0.713903991771, 2.32077561808e-14, 0.000179348049886, 0.635500453302,
1.27561284917, 0.878125852156, 0.41565819538, 1.03606488534,
0.885964824887, 1.16424049334, 1.10764479598, 1.04590810812,
1.04193666963, 0.965193152414, 0.976217589464, 0.93008035061,
0.98153216096, 0.909196397698, 0.98434963993, 0.984738050206,
1.00106492525, 0.97734200649, 1.02198794179;
CALL_SUBTEST(res = gamma_sample_der_alpha(alpha, sample);
verify_component_wise(res, v););
}
#endif // EIGEN_HAS_C99_MATH
}
void test_special_functions()