diff --git a/Eigen/src/Core/GenericPacketMath.h b/Eigen/src/Core/GenericPacketMath.h index 888a3f7ea..55b6a89e2 100644 --- a/Eigen/src/Core/GenericPacketMath.h +++ b/Eigen/src/Core/GenericPacketMath.h @@ -85,6 +85,8 @@ struct default_packet_traits HasI0e = 0, HasI1e = 0, HasIGamma = 0, + HasIGammaDerA = 0, + HasGammaSampleDerAlpha = 0, HasIGammac = 0, HasBetaInc = 0, diff --git a/Eigen/src/Core/arch/CUDA/PacketMath.h b/Eigen/src/Core/arch/CUDA/PacketMath.h index 704a4e0d9..ab8e477f4 100644 --- a/Eigen/src/Core/arch/CUDA/PacketMath.h +++ b/Eigen/src/Core/arch/CUDA/PacketMath.h @@ -47,6 +47,8 @@ template<> struct packet_traits : default_packet_traits HasI0e = 1, HasI1e = 1, HasIGamma = 1, + HasIGammaDerA = 1, + HasGammaSampleDerAlpha = 1, HasIGammac = 1, HasBetaInc = 1, @@ -78,6 +80,8 @@ template<> struct packet_traits : default_packet_traits HasI0e = 1, HasI1e = 1, HasIGamma = 1, + HasIGammaDerA = 1, + HasGammaSampleDerAlpha = 1, HasIGammac = 1, HasBetaInc = 1, diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h index d88e0df71..bdc1a17a7 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h @@ -152,6 +152,20 @@ class TensorBase return binaryExpr(other.derived(), internal::scalar_igamma_op()); } + // igamma_der_a(a = this, x = other) + template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + const TensorCwiseBinaryOp, const Derived, const OtherDerived> + igamma_der_a(const OtherDerived& other) const { + return binaryExpr(other.derived(), internal::scalar_igamma_der_a_op()); + } + + // gamma_sample_der_alpha(alpha = this, sample = other) + template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + const TensorCwiseBinaryOp, const Derived, const OtherDerived> + gamma_sample_der_alpha(const OtherDerived& other) const { + return binaryExpr(other.derived(), internal::scalar_gamma_sample_der_alpha_op()); + } + // igammac(a = this, x = other) template EIGEN_DEVICE_FUNC 
EIGEN_STRONG_INLINE const TensorCwiseBinaryOp, const Derived, const OtherDerived> diff --git a/unsupported/Eigen/SpecialFunctions b/unsupported/Eigen/SpecialFunctions index 482ec6e6f..9441ba8f5 100644 --- a/unsupported/Eigen/SpecialFunctions +++ b/unsupported/Eigen/SpecialFunctions @@ -29,6 +29,8 @@ namespace Eigen { * - erfc * - lgamma * - igamma + * - igamma_der_a + * - gamma_sample_der_alpha * - igammac * - digamma * - polygamma diff --git a/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsArrayAPI.h b/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsArrayAPI.h index b7a9d035b..30cdf4751 100644 --- a/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsArrayAPI.h +++ b/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsArrayAPI.h @@ -33,6 +33,48 @@ igamma(const Eigen::ArrayBase& a, const Eigen::ArrayBase +inline const Eigen::CwiseBinaryOp, const Derived, const ExponentDerived> +igamma_der_a(const Eigen::ArrayBase& a, const Eigen::ArrayBase& x) { + return Eigen::CwiseBinaryOp, const Derived, const ExponentDerived>( + a.derived(), + x.derived()); +} + +/** \cpp11 \returns an expression of the coefficient-wise gamma_sample_der_alpha(\a alpha, \a sample) to the given arrays. + * + * This function computes the coefficient-wise derivative of the sample + * of a Gamma(alpha, 1) random variable with respect to the parameter alpha. + * + * \note This function supports only float and double scalar types in c++11 + * mode. To support other scalar types, + * or float/double in non c++11 mode, the user has to provide implementations + * of gamma_sample_der_alpha(T,T) for any scalar + * type T to be supported. 
+ * + * \sa Eigen::igamma(), Eigen::lgamma() + */ +template +inline const Eigen::CwiseBinaryOp, const AlphaDerived, const SampleDerived> +gamma_sample_der_alpha(const Eigen::ArrayBase& alpha, const Eigen::ArrayBase& sample) { + return Eigen::CwiseBinaryOp, const AlphaDerived, const SampleDerived>( + alpha.derived(), + sample.derived()); +} + /** \cpp11 \returns an expression of the coefficient-wise igammac(\a a, \a x) to the given arrays. * * This function computes the coefficient-wise complementary incomplete gamma function. diff --git a/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsFunctors.h b/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsFunctors.h index 8420f0174..3a63dcdd6 100644 --- a/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsFunctors.h +++ b/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsFunctors.h @@ -41,6 +41,60 @@ struct functor_traits > { }; }; +/** \internal + * \brief Template functor to compute the derivative of the incomplete gamma + * function igamma_der_a(a, x) + * + * \sa class CwiseBinaryOp, Cwise::igamma_der_a + */ +template +struct scalar_igamma_der_a_op { + EIGEN_EMPTY_STRUCT_CTOR(scalar_igamma_der_a_op) + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator()(const Scalar& a, const Scalar& x) const { + using numext::igamma_der_a; + return igamma_der_a(a, x); + } + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(const Packet& a, const Packet& x) const { + return internal::pigamma_der_a(a, x); + } +}; +template +struct functor_traits > { + enum { + // 2x the cost of igamma + Cost = 40 * NumTraits::MulCost + 20 * NumTraits::AddCost, + PacketAccess = packet_traits::HasIGammaDerA + }; +}; + +/** \internal + * \brief Template functor to compute the derivative of the sample + * of a Gamma(alpha, 1) random variable with respect to the parameter alpha + * gamma_sample_der_alpha(alpha, sample) + * + * \sa class CwiseBinaryOp, Cwise::gamma_sample_der_alpha + */ +template +struct 
scalar_gamma_sample_der_alpha_op { + EIGEN_EMPTY_STRUCT_CTOR(scalar_gamma_sample_der_alpha_op) + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator()(const Scalar& alpha, const Scalar& sample) const { + using numext::gamma_sample_der_alpha; + return gamma_sample_der_alpha(alpha, sample); + } + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(const Packet& alpha, const Packet& sample) const { + return internal::pgamma_sample_der_alpha(alpha, sample); + } +}; +template +struct functor_traits > { + enum { + // 2x the cost of igamma, minus the lgamma cost (the lgamma cancels out) + Cost = 30 * NumTraits::MulCost + 15 * NumTraits::AddCost, + PacketAccess = packet_traits::HasGammaSampleDerAlpha + }; +}; /** \internal * \brief Template functor to compute the complementary incomplete gamma function igammac(a, x) diff --git a/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsHalf.h b/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsHalf.h index c5867002e..fbdfd299e 100644 --- a/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsHalf.h +++ b/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsHalf.h @@ -33,6 +33,14 @@ template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half erfc(const Eigen::h template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half igamma(const Eigen::half& a, const Eigen::half& x) { return Eigen::half(Eigen::numext::igamma(static_cast(a), static_cast(x))); } +template <> +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half igamma_der_a(const Eigen::half& a, const Eigen::half& x) { + return Eigen::half(Eigen::numext::igamma_der_a(static_cast(a), static_cast(x))); +} +template <> +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half gamma_sample_der_alpha(const Eigen::half& alpha, const Eigen::half& sample) { + return Eigen::half(Eigen::numext::gamma_sample_der_alpha(static_cast(alpha), static_cast(sample))); +} template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half igammac(const Eigen::half& a, const 
Eigen::half& x) { return Eigen::half(Eigen::numext::igammac(static_cast(a), static_cast(x))); } diff --git a/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsImpl.h b/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsImpl.h index 293b0597b..444fd14d9 100644 --- a/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsImpl.h +++ b/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsImpl.h @@ -521,6 +521,197 @@ struct cephes_helper { } }; +enum IgammaComputationMode { VALUE, DERIVATIVE, SAMPLE_DERIVATIVE }; + +template +EIGEN_DEVICE_FUNC +int igamma_num_iterations() { + /* Returns the maximum number of internal iterations for igamma computation. + */ + if (mode == VALUE) { + return 2000; + } + + if (internal::is_same::value) { + return 200; + } else if (internal::is_same::value) { + return 500; + } else { + return 2000; + } +} + +template +struct igammac_cf_impl { + /* Computes igamc(a, x) or derivative (depending on the mode) + * using the continued fraction expansion of the complementary + * incomplete Gamma function. 
+ * + * Preconditions: + * a > 0 + * x >= 1 + * x >= a + */ + EIGEN_DEVICE_FUNC + static Scalar run(Scalar a, Scalar x) { + const Scalar zero = 0; + const Scalar one = 1; + const Scalar two = 2; + const Scalar machep = cephes_helper::machep(); + const Scalar big = cephes_helper::big(); + const Scalar biginv = cephes_helper::biginv(); + + if ((numext::isinf)(x)) { + return zero; + } + + // continued fraction + Scalar y = one - a; + Scalar z = x + y + one; + Scalar c = zero; + Scalar pkm2 = one; + Scalar qkm2 = x; + Scalar pkm1 = x + one; + Scalar qkm1 = z * x; + Scalar ans = pkm1 / qkm1; + + Scalar dpkm2_da = zero; + Scalar dqkm2_da = zero; + Scalar dpkm1_da = zero; + Scalar dqkm1_da = -x; + Scalar dans_da = (dpkm1_da - ans * dqkm1_da) / qkm1; + + for (int i = 0; i < igamma_num_iterations(); i++) { + c += one; + y += one; + z += two; + + Scalar yc = y * c; + Scalar pk = pkm1 * z - pkm2 * yc; + Scalar qk = qkm1 * z - qkm2 * yc; + + Scalar dpk_da = dpkm1_da * z - pkm1 - dpkm2_da * yc + pkm2 * c; + Scalar dqk_da = dqkm1_da * z - qkm1 - dqkm2_da * yc + qkm2 * c; + + if (qk != zero) { + Scalar ans_prev = ans; + ans = pk / qk; + + Scalar dans_da_prev = dans_da; + dans_da = (dpk_da - ans * dqk_da) / qk; + + if (mode == VALUE) { + if (numext::abs(ans_prev - ans) <= machep * numext::abs(ans)) { + break; + } + } else { + if (numext::abs(dans_da - dans_da_prev) <= machep) { + break; + } + } + } + + pkm2 = pkm1; + pkm1 = pk; + qkm2 = qkm1; + qkm1 = qk; + + dpkm2_da = dpkm1_da; + dpkm1_da = dpk_da; + dqkm2_da = dqkm1_da; + dqkm1_da = dqk_da; + + if (numext::abs(pk) > big) { + pkm2 *= biginv; + pkm1 *= biginv; + qkm2 *= biginv; + qkm1 *= biginv; + + dpkm2_da *= biginv; + dpkm1_da *= biginv; + dqkm2_da *= biginv; + dqkm1_da *= biginv; + } + } + + /* Compute x**a * exp(-x) / gamma(a) */ + Scalar logax = a * numext::log(x) - x - lgamma_impl::run(a); + Scalar dlogax_da = numext::log(x) - digamma_impl::run(a); + Scalar ax = numext::exp(logax); + Scalar dax_da = ax * dlogax_da; + + 
switch (mode) { + case VALUE: + return ans * ax; + case DERIVATIVE: + return ans * dax_da + dans_da * ax; + case SAMPLE_DERIVATIVE: default: /* `default` shares the last branch so that every control path returns a value */ + return -(dans_da + ans * dlogax_da) * x; + } + } +}; + +template +struct igamma_series_impl { + /* Computes igam(a, x) or its derivative (depending on the mode) + * using the series expansion of the incomplete Gamma function. + * + * Preconditions: + * x > 0 + * a > 0 + * !(x > 1 && x > a) + */ + EIGEN_DEVICE_FUNC + static Scalar run(Scalar a, Scalar x) { + const Scalar zero = 0; + const Scalar one = 1; + const Scalar machep = cephes_helper::machep(); + + /* power series; dc_da and dans_da carry the derivatives of c and ans with respect to a */ + Scalar r = a; + Scalar c = one; + Scalar ans = one; + + Scalar dc_da = zero; + Scalar dans_da = zero; + + for (int i = 0; i < igamma_num_iterations(); i++) { + r += one; + Scalar term = x / r; + Scalar dterm_da = -x / (r * r); + dc_da = term * dc_da + dterm_da * c; + dans_da += dc_da; + c *= term; + ans += c; + + if (mode == VALUE) { + if (c <= machep * ans) { + break; + } + } else { + if (numext::abs(dc_da) <= machep * numext::abs(dans_da)) { + break; + } + } + } + + /* Compute x**a * exp(-x) / gamma(a + 1) */ + Scalar logax = a * numext::log(x) - x - lgamma_impl::run(a + one); + Scalar dlogax_da = numext::log(x) - digamma_impl::run(a + one); + Scalar ax = numext::exp(logax); + Scalar dax_da = ax * dlogax_da; + + switch (mode) { + case VALUE: + return ans * ax; + case DERIVATIVE: + return ans * dax_da + dans_da * ax; + case SAMPLE_DERIVATIVE: default: /* `default` shares the last branch so that every control path returns a value */ + return -(dans_da + ans * dlogax_da) * x / a; + } + } +}; + #if !EIGEN_HAS_C99_MATH template @@ -535,8 +726,6 @@ struct igammac_impl { #else -template struct igamma_impl; // predeclare igamma_impl - template struct igammac_impl { EIGEN_DEVICE_FUNC @@ -604,97 +793,15 @@ struct igammac_impl { return nan; } - if ((numext::isnan)(a) || (numext::isnan)(x)) { // propagate nans + if ((numext::isnan)(a) || (numext::isnan)(x)) { // propagate nans return nan; } if ((x < one) || (x < a)) { - /* The checks above ensure that we meet 
the preconditions for - * igamma_impl::Impl(), so call it, rather than igamma_impl::Run(). - * Calling Run() would also work, but in that case the compiler may not be - * able to prove that igammac_impl::Run and igamma_impl::Run are not - * mutually recursive. This leads to worse code, particularly on - * platforms like nvptx, where recursion is allowed only begrudgingly. - */ - return (one - igamma_impl::Impl(a, x)); + return (one - igamma_series_impl::run(a, x)); } - return Impl(a, x); - } - - private: - /* igamma_impl calls igammac_impl::Impl. */ - friend struct igamma_impl; - - /* Actually computes igamc(a, x). - * - * Preconditions: - * a > 0 - * x >= 1 - * x >= a - */ - EIGEN_DEVICE_FUNC static Scalar Impl(Scalar a, Scalar x) { - const Scalar zero = 0; - const Scalar one = 1; - const Scalar two = 2; - const Scalar machep = cephes_helper::machep(); - const Scalar maxlog = numext::log(NumTraits::highest()); - const Scalar big = cephes_helper::big(); - const Scalar biginv = cephes_helper::biginv(); - const Scalar inf = NumTraits::infinity(); - - Scalar ans, ax, c, yc, r, t, y, z; - Scalar pk, pkm1, pkm2, qk, qkm1, qkm2; - - if (x == inf) return zero; // std::isinf crashes on CUDA - - /* Compute x**a * exp(-x) / gamma(a) */ - ax = a * numext::log(x) - x - lgamma_impl::run(a); - if (ax < -maxlog) { // underflow - return zero; - } - ax = numext::exp(ax); - - // continued fraction - y = one - a; - z = x + y + one; - c = zero; - pkm2 = one; - qkm2 = x; - pkm1 = x + one; - qkm1 = z * x; - ans = pkm1 / qkm1; - - for (int i = 0; i < 2000; i++) { - c += one; - y += one; - z += two; - yc = y * c; - pk = pkm1 * z - pkm2 * yc; - qk = qkm1 * z - qkm2 * yc; - if (qk != zero) { - r = pk / qk; - t = numext::abs((ans - r) / r); - ans = r; - } else { - t = one; - } - pkm2 = pkm1; - pkm1 = pk; - qkm2 = qkm1; - qkm1 = qk; - if (numext::abs(pk) > big) { - pkm2 *= biginv; - pkm1 *= biginv; - qkm2 *= biginv; - qkm1 *= biginv; - } - if (t <= machep) { - break; - } - } - - return (ans * 
ax); + return igammac_cf_impl::run(a, x); } }; @@ -704,15 +811,10 @@ struct igammac_impl { * Implementation of igamma (incomplete gamma integral), based on Cephes but requires C++11/C99 * ************************************************************************************************/ -template -struct igamma_retval { - typedef Scalar type; -}; - #if !EIGEN_HAS_C99_MATH -template -struct igamma_impl { +template +struct igamma_generic_impl { EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Scalar run(Scalar a, Scalar x) { EIGEN_STATIC_ASSERT((internal::is_same::value == false), @@ -723,69 +825,17 @@ struct igamma_impl { #else -template -struct igamma_impl { +template +struct igamma_generic_impl { EIGEN_DEVICE_FUNC static Scalar run(Scalar a, Scalar x) { - /* igam() - * Incomplete gamma integral - * - * - * - * SYNOPSIS: - * - * double a, x, y, igam(); - * - * y = igam( a, x ); - * - * DESCRIPTION: - * - * The function is defined by - * - * x - * - - * 1 | | -t a-1 - * igam(a,x) = ----- | e t dt. - * - | | - * | (a) - - * 0 - * - * - * In this implementation both arguments must be positive. - * The integral is evaluated by either a power series or - * continued fraction expansion, depending on the relative - * values of a and x. - * - * ACCURACY (double): - * - * Relative error: - * arithmetic domain # trials peak rms - * IEEE 0,30 200000 3.6e-14 2.9e-15 - * IEEE 0,100 300000 9.9e-14 1.5e-14 - * - * - * ACCURACY (float): - * - * Relative error: - * arithmetic domain # trials peak rms - * IEEE 0,30 20000 7.8e-6 5.9e-7 - * - */ - /* - Cephes Math Library Release 2.2: June, 1992 - Copyright 1985, 1987, 1992 by Stephen L. Moshier - Direct inquiries to 30 Frost Street, Cambridge, MA 02140 - */ - - - /* left tail of incomplete gamma function: - * - * inf. 
k - * a -x - x - * x e > ---------- - * - - - * k=0 | (a+k+1) + /* Depending on the mode, returns + * - VALUE: incomplete Gamma function igamma(a, x) + * - DERIVATIVE: derivative of incomplete Gamma function d/da igamma(a, x) + * - SAMPLE_DERIVATIVE: implicit derivative of a Gamma random variable + * x ~ Gamma(x | a, 1), dx/da = -(d igamma(a, x) / da) / Gamma(x | a, 1) * + * Derivatives are implemented by forward-mode differentiation. */ const Scalar zero = 0; const Scalar one = 1; @@ -797,71 +847,167 @@ struct igamma_impl { return nan; } - if ((numext::isnan)(a) || (numext::isnan)(x)) { // propagate nans + if ((numext::isnan)(a) || (numext::isnan)(x)) { // propagate nans return nan; } if ((x > one) && (x > a)) { - /* The checks above ensure that we meet the preconditions for - * igammac_impl::Impl(), so call it, rather than igamma_impl::Run(). - * Calling Run() would also work, but in that case the compiler may not be - * able to prove that igammac_impl::Run and igamma_impl::Run are not - * mutually recursive. This leads to worse code, particularly on - * platforms like nvptx, where recursion is allowed only begrudgingly. - */ - return (one - igammac_impl::Impl(a, x)); - } - - return Impl(a, x); - } - - private: - /* igammac_impl calls igamma_impl::Impl. */ - friend struct igamma_impl; - - /* Actually computes igam(a, x). 
- * - * Preconditions: - * x > 0 - * a > 0 - * !(x > 1 && x > a) - */ - EIGEN_DEVICE_FUNC static Scalar Impl(Scalar a, Scalar x) { - const Scalar zero = 0; - const Scalar one = 1; - const Scalar machep = cephes_helper::machep(); - const Scalar maxlog = numext::log(NumTraits::highest()); - - Scalar ans, ax, c, r; - - /* Compute x**a * exp(-x) / gamma(a) */ - ax = a * numext::log(x) - x - lgamma_impl::run(a); - if (ax < -maxlog) { - // underflow - return zero; - } - ax = numext::exp(ax); - - /* power series */ - r = a; - c = one; - ans = one; - - for (int i = 0; i < 2000; i++) { - r += one; - c *= x/r; - ans += c; - if (c/ans <= machep) { - break; + Scalar ret = igammac_cf_impl::run(a, x); + if (mode == VALUE) { + return one - ret; + } else { + return -ret; } } - return (ans * ax / a); + return igamma_series_impl::run(a, x); } }; #endif // EIGEN_HAS_C99_MATH +template +struct igamma_retval { + typedef Scalar type; +}; + +template +struct igamma_impl : igamma_generic_impl { + /* igam() + * Incomplete gamma integral. + * + * The CDF of Gamma(a, 1) random variable at the point x. + * + * Accuracy estimation. For each a in [10^-2, 10^-1...10^3] we sample + * 50 Gamma random variables x ~ Gamma(x | a, 1), a total of 300 points. + * The ground truth is computed by mpmath. Mean absolute error: + * float: 1.26713e-05 + * double: 2.33606e-12 + * + * Cephes documentation below. + * + * SYNOPSIS: + * + * double a, x, y, igam(); + * + * y = igam( a, x ); + * + * DESCRIPTION: + * + * The function is defined by + * + * x + * - + * 1 | | -t a-1 + * igam(a,x) = ----- | e t dt. + * - | | + * | (a) - + * 0 + * + * + * In this implementation both arguments must be positive. + * The integral is evaluated by either a power series or + * continued fraction expansion, depending on the relative + * values of a and x. 
+ * + * ACCURACY (double): + * + * Relative error: + * arithmetic domain # trials peak rms + * IEEE 0,30 200000 3.6e-14 2.9e-15 + * IEEE 0,100 300000 9.9e-14 1.5e-14 + * + * + * ACCURACY (float): + * + * Relative error: + * arithmetic domain # trials peak rms + * IEEE 0,30 20000 7.8e-6 5.9e-7 + * + */ + /* + Cephes Math Library Release 2.2: June, 1992 + Copyright 1985, 1987, 1992 by Stephen L. Moshier + Direct inquiries to 30 Frost Street, Cambridge, MA 02140 + */ + + /* left tail of incomplete gamma function: + * + * inf. k + * a -x - x + * x e > ---------- + * - - + * k=0 | (a+k+1) + * + */ +}; + +template +struct igamma_der_a_retval : igamma_retval {}; + +template +struct igamma_der_a_impl : igamma_generic_impl { + /* Derivative of the incomplete Gamma function with respect to a. + * + * Computes d/da igamma(a, x) by forward differentiation of the igamma code. + * + * Accuracy estimation. For each a in [10^-2, 10^-1...10^3] we sample + * 50 Gamma random variables x ~ Gamma(x | a, 1), a total of 300 points. + * The ground truth is computed by mpmath. Mean absolute error: + * float: 6.17992e-07 + * double: 4.60453e-12 + * + * Reference: + * R. Moore. "Algorithm AS 187: Derivatives of the incomplete gamma + * integral". Journal of the Royal Statistical Society. 1982 + */ +}; + +template +struct gamma_sample_der_alpha_retval : igamma_retval {}; + +template +struct gamma_sample_der_alpha_impl + : igamma_generic_impl { + /* Derivative of a Gamma random variable sample with respect to alpha. + * + * Consider a sample of a Gamma random variable with the concentration + * parameter alpha: sample ~ Gamma(alpha, 1). The reparameterization + * derivative that we want to compute is dsample / dalpha = + * d igammainv(alpha, u) / dalpha, where u = igamma(alpha, sample). + * However, this formula is numerically unstable and expensive, so instead + * we use implicit differentiation: + * + * igamma(alpha, sample) = u, where u ~ Uniform(0, 1). 
+ * Apply d / dalpha to both sides: + * d igamma(alpha, sample) / dalpha + * + d igamma(alpha, sample) / dsample * dsample/dalpha = 0 + * d igamma(alpha, sample) / dalpha + * + Gamma(sample | alpha, 1) dsample / dalpha = 0 + * dsample/dalpha = - (d igamma(alpha, sample) / dalpha) + * / Gamma(sample | alpha, 1) + * + * Here Gamma(sample | alpha, 1) is the PDF of the Gamma distribution + * (note that the derivative of the CDF w.r.t. sample is the PDF). + * See the reference below for more details. + * + * The derivative of igamma(alpha, sample) is computed by forward + * differentiation of the igamma code. Division by the Gamma PDF is performed + * in the same code, increasing the accuracy and speed due to cancellation + * of some terms. + * + * Accuracy estimation. For each alpha in [10^-2, 10^-1...10^3] we sample + * 50 Gamma random variables sample ~ Gamma(sample | alpha, 1), a total of 300 + * points. The ground truth is computed by mpmath. Mean absolute error: + * float: 2.1686e-06 + * double: 1.4774e-12 + * + * Reference: + * M. Figurnov, S. Mohamed, A. Mnih "Implicit Reparameterization Gradients". 
+ * 2018 + */ +}; + /***************************************************************************** * Implementation of Riemann zeta function of two arguments, based on Cephes * *****************************************************************************/ @@ -1950,6 +2096,18 @@ EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(igamma, Scalar) return EIGEN_MATHFUNC_IMPL(igamma, Scalar)::run(a, x); } +template +EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(igamma_der_a, Scalar) + igamma_der_a(const Scalar& a, const Scalar& x) { + return EIGEN_MATHFUNC_IMPL(igamma_der_a, Scalar)::run(a, x); +} + +template +EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(gamma_sample_der_alpha, Scalar) + gamma_sample_der_alpha(const Scalar& a, const Scalar& x) { + return EIGEN_MATHFUNC_IMPL(gamma_sample_der_alpha, Scalar)::run(a, x); +} + template EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(igammac, Scalar) igammac(const Scalar& a, const Scalar& x) { diff --git a/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsPacketMath.h b/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsPacketMath.h index 4c176716b..465f41d54 100644 --- a/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsPacketMath.h +++ b/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsPacketMath.h @@ -42,6 +42,21 @@ Packet perfc(const Packet& a) { using numext::erfc; return erfc(a); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet pigamma(const Packet& a, const Packet& x) { using numext::igamma; return igamma(a, x); } +/** \internal \returns the derivative of the incomplete gamma function with respect to \a a, + * igamma_der_a(\a a, \a x) */ +template +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet pigamma_der_a(const Packet& a, const Packet& x) { + using numext::igamma_der_a; return igamma_der_a(a, x); +} + +/** \internal \returns the derivative of the sample + * of a Gamma(alpha, 1) random variable with respect to the parameter alpha, + * gamma_sample_der_alpha(\a alpha, \a sample) */ +template +EIGEN_DEVICE_FUNC 
EIGEN_STRONG_INLINE Packet pgamma_sample_der_alpha(const Packet& alpha, const Packet& sample) { + using numext::gamma_sample_der_alpha; return gamma_sample_der_alpha(alpha, sample); +} + /** \internal \returns the complementary incomplete gamma function igammac(\a a, \a x) */ template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet pigammac(const Packet& a, const Packet& x) { using numext::igammac; return igammac(a, x); } diff --git a/unsupported/Eigen/src/SpecialFunctions/arch/CUDA/CudaSpecialFunctions.h b/unsupported/Eigen/src/SpecialFunctions/arch/CUDA/CudaSpecialFunctions.h index c25fea0b3..020ac1b62 100644 --- a/unsupported/Eigen/src/SpecialFunctions/arch/CUDA/CudaSpecialFunctions.h +++ b/unsupported/Eigen/src/SpecialFunctions/arch/CUDA/CudaSpecialFunctions.h @@ -120,6 +120,41 @@ double2 pigamma(const double2& a, const double2& x) return make_double2(igamma(a.x, x.x), igamma(a.y, x.y)); } +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 pigamma_der_a( + const float4& a, const float4& x) { + using numext::igamma_der_a; + return make_float4(igamma_der_a(a.x, x.x), igamma_der_a(a.y, x.y), + igamma_der_a(a.z, x.z), igamma_der_a(a.w, x.w)); +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2 +pigamma_der_a(const double2& a, const double2& x) { + using numext::igamma_der_a; + return make_double2(igamma_der_a(a.x, x.x), igamma_der_a(a.y, x.y)); +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 pgamma_sample_der_alpha( + const float4& alpha, const float4& sample) { + using numext::gamma_sample_der_alpha; + return make_float4( + gamma_sample_der_alpha(alpha.x, sample.x), + gamma_sample_der_alpha(alpha.y, sample.y), + gamma_sample_der_alpha(alpha.z, sample.z), + gamma_sample_der_alpha(alpha.w, sample.w)); +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2 +pgamma_sample_der_alpha(const double2& alpha, const double2& sample) { + using numext::gamma_sample_der_alpha; + return make_double2( + gamma_sample_der_alpha(alpha.x, 
sample.x), + gamma_sample_der_alpha(alpha.y, sample.y)); +} + template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 pigammac(const float4& a, const float4& x) { diff --git a/unsupported/test/cxx11_tensor_cuda.cu b/unsupported/test/cxx11_tensor_cuda.cu index 63d0a345a..f238ed5be 100644 --- a/unsupported/test/cxx11_tensor_cuda.cu +++ b/unsupported/test/cxx11_tensor_cuda.cu @@ -1318,6 +1318,157 @@ void test_cuda_i1e() cudaFree(d_out); } +template +void test_cuda_igamma_der_a() +{ + Tensor in_x(30); + Tensor in_a(30); + Tensor out(30); + Tensor expected_out(30); + out.setZero(); + + Array in_a_array(30); + Array in_x_array(30); + Array expected_out_array(30); + + // See special_functions.cpp for the Python code that generates the test data. + + in_a_array << 0.01, 0.01, 0.01, 0.01, 0.01, 0.1, 0.1, 0.1, 0.1, 0.1, 1.0, 1.0, + 1.0, 1.0, 1.0, 10.0, 10.0, 10.0, 10.0, 10.0, 100.0, 100.0, 100.0, 100.0, + 100.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0; + + in_x_array << 1.25668890405e-26, 1.17549435082e-38, 1.20938905072e-05, + 1.17549435082e-38, 1.17549435082e-38, 5.66572070696e-16, 0.0132865061065, + 0.0200034203853, 6.29263709118e-17, 1.37160367764e-06, 0.333412038288, + 1.18135687766, 0.580629033777, 0.170631439426, 0.786686768458, + 7.63873279537, 13.1944344379, 11.896042354, 10.5830172417, 10.5020942233, + 92.8918587747, 95.003720371, 86.3715926467, 96.0330217672, 82.6389930677, + 968.702906754, 969.463546828, 1001.79726022, 955.047416547, 1044.27458568; + + expected_out_array << -32.7256441441, -36.4394150514, -9.66467612263, + -36.4394150514, -36.4394150514, -1.0891900302, -2.66351229645, + -2.48666868596, -0.929700494428, -3.56327722764, -0.455320135314, + -0.391437214323, -0.491352055991, -0.350454834292, -0.471773162921, + -0.104084440522, -0.0723646747909, -0.0992828975532, -0.121638215446, + -0.122619605294, -0.0317670267286, -0.0359974812869, -0.0154359225363, + -0.0375775365921, -0.00794899153653, -0.00777303219211, -0.00796085782042, + -0.0125850719397, 
-0.00455500206958, -0.00476436993148; + + for (int i = 0; i < 30; ++i) { + in_x(i) = in_x_array(i); + in_a(i) = in_a_array(i); + expected_out(i) = expected_out_array(i); + } + + std::size_t bytes = in_x.size() * sizeof(Scalar); + + Scalar* d_a; + Scalar* d_x; + Scalar* d_out; + cudaMalloc((void**)(&d_a), bytes); + cudaMalloc((void**)(&d_x), bytes); + cudaMalloc((void**)(&d_out), bytes); + + cudaMemcpy(d_a, in_a.data(), bytes, cudaMemcpyHostToDevice); + cudaMemcpy(d_x, in_x.data(), bytes, cudaMemcpyHostToDevice); + + Eigen::CudaStreamDevice stream; + Eigen::GpuDevice gpu_device(&stream); + + Eigen::TensorMap > gpu_a(d_a, 30); + Eigen::TensorMap > gpu_x(d_x, 30); + Eigen::TensorMap > gpu_out(d_out, 30); + + gpu_out.device(gpu_device) = gpu_a.igamma_der_a(gpu_x); + + assert(cudaMemcpyAsync(out.data(), d_out, bytes, cudaMemcpyDeviceToHost, + gpu_device.stream()) == cudaSuccess); + assert(cudaStreamSynchronize(gpu_device.stream()) == cudaSuccess); + + for (int i = 0; i < 30; ++i) { + VERIFY_IS_APPROX(out(i), expected_out(i)); + } + + cudaFree(d_a); + cudaFree(d_x); + cudaFree(d_out); +} + +// GPU test for the tensor method gamma_sample_der_alpha: uploads 30 (alpha, sample) pairs, +// evaluates gpu_alpha.gamma_sample_der_alpha(gpu_sample) on a GpuDevice, copies the result back +// to the host and compares every entry with a precomputed reference via VERIFY_IS_APPROX. +// NOTE(review): the template parameter list and the angle-bracket arguments of Tensor, Array and +// Eigen::TensorMap appear to be missing from this listing (Scalar is used but not declared +// here) - confirm against the original patch. +template +void test_cuda_gamma_sample_der_alpha() +{ + Tensor in_alpha(30); + Tensor in_sample(30); + Tensor out(30); + Tensor expected_out(30); + out.setZero(); + + Array in_alpha_array(30); + Array in_sample_array(30); + Array expected_out_array(30); + + // See special_functions.cpp for the Python code that generates the test data. 
+ + in_alpha_array << 0.01, 0.01, 0.01, 0.01, 0.01, 0.1, 0.1, 0.1, 0.1, 0.1, 1.0, + 1.0, 1.0, 1.0, 1.0, 10.0, 10.0, 10.0, 10.0, 10.0, 100.0, 100.0, 100.0, + 100.0, 100.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0; + + in_sample_array << 1.25668890405e-26, 1.17549435082e-38, 1.20938905072e-05, + 1.17549435082e-38, 1.17549435082e-38, 5.66572070696e-16, 0.0132865061065, + 0.0200034203853, 6.29263709118e-17, 1.37160367764e-06, 0.333412038288, + 1.18135687766, 0.580629033777, 0.170631439426, 0.786686768458, + 7.63873279537, 13.1944344379, 11.896042354, 10.5830172417, 10.5020942233, + 92.8918587747, 95.003720371, 86.3715926467, 96.0330217672, 82.6389930677, + 968.702906754, 969.463546828, 1001.79726022, 955.047416547, 1044.27458568; + + expected_out_array << 7.42424742367e-23, 1.02004297287e-34, 0.0130155240738, + 1.02004297287e-34, 1.02004297287e-34, 1.96505168277e-13, 0.525575786243, + 0.713903991771, 2.32077561808e-14, 0.000179348049886, 0.635500453302, + 1.27561284917, 0.878125852156, 0.41565819538, 1.03606488534, + 0.885964824887, 1.16424049334, 1.10764479598, 1.04590810812, + 1.04193666963, 0.965193152414, 0.976217589464, 0.93008035061, + 0.98153216096, 0.909196397698, 0.98434963993, 0.984738050206, + 1.00106492525, 0.97734200649, 1.02198794179; + + for (int i = 0; i < 30; ++i) { + in_alpha(i) = in_alpha_array(i); + in_sample(i) = in_sample_array(i); + expected_out(i) = expected_out_array(i); + } + + std::size_t bytes = in_alpha.size() * sizeof(Scalar); + + // Device-side buffers for the two inputs and the result, `bytes` bytes each. 
+ Scalar* d_alpha; + Scalar* d_sample; + Scalar* d_out; + cudaMalloc((void**)(&d_alpha), bytes); + cudaMalloc((void**)(&d_sample), bytes); + cudaMalloc((void**)(&d_out), bytes); + + cudaMemcpy(d_alpha, in_alpha.data(), bytes, cudaMemcpyHostToDevice); + cudaMemcpy(d_sample, in_sample.data(), bytes, cudaMemcpyHostToDevice); + + Eigen::CudaStreamDevice stream; + Eigen::GpuDevice gpu_device(&stream); + + Eigen::TensorMap > gpu_alpha(d_alpha, 30); + Eigen::TensorMap > gpu_sample(d_sample, 30); + Eigen::TensorMap > 
gpu_out(d_out, 30); + + gpu_out.device(gpu_device) = gpu_alpha.gamma_sample_der_alpha(gpu_sample); + + // The CUDA calls are made outside of assert(): when NDEBUG is defined assert() does not + // evaluate its argument, which would silently drop the copy back into `out` and the stream + // synchronization. The statuses are still asserted in the original order. + const cudaError_t copy_status = cudaMemcpyAsync(out.data(), d_out, bytes, cudaMemcpyDeviceToHost, + gpu_device.stream()); + assert(copy_status == cudaSuccess); + (void)copy_status; // referenced even when assert() compiles to nothing + const cudaError_t sync_status = cudaStreamSynchronize(gpu_device.stream()); + assert(sync_status == cudaSuccess); + (void)sync_status; // referenced even when assert() compiles to nothing + + for (int i = 0; i < 30; ++i) { + VERIFY_IS_APPROX(out(i), expected_out(i)); + } + + cudaFree(d_alpha); + cudaFree(d_sample); + cudaFree(d_out); +} void test_cxx11_tensor_cuda() { @@ -1396,5 +1547,11 @@ void test_cxx11_tensor_cuda() CALL_SUBTEST_6(test_cuda_i1e()); CALL_SUBTEST_6(test_cuda_i1e()); + + CALL_SUBTEST_6(test_cuda_igamma_der_a()); + CALL_SUBTEST_6(test_cuda_igamma_der_a()); + + CALL_SUBTEST_6(test_cuda_gamma_sample_der_alpha()); + CALL_SUBTEST_6(test_cuda_gamma_sample_der_alpha()); #endif } diff --git a/unsupported/test/special_functions.cpp b/unsupported/test/special_functions.cpp index 0729dd4dc..802e16150 100644 --- a/unsupported/test/special_functions.cpp +++ b/unsupported/test/special_functions.cpp @@ -376,6 +376,100 @@ template void array_special_functions() CALL_SUBTEST(res = i1e(x); verify_component_wise(res, expected);); } + + /* Code to generate the data for the following two test cases. 
+ N = 5 + np.random.seed(3) + + a = np.logspace(-2, 3, 6) + a = np.ravel(np.tile(np.reshape(a, [-1, 1]), [1, N])) + x = np.random.gamma(a, 1.0) + x = np.maximum(x, np.finfo(np.float32).tiny) + + def igamma(a, x): + return mpmath.gammainc(a, 0, x, regularized=True) + + def igamma_der_a(a, x): + res = mpmath.diff(lambda a_prime: igamma(a_prime, x), a) + return np.float64(res) + + def gamma_sample_der_alpha(a, x): + igamma_x = igamma(a, x) + def igammainv_of_igamma(a_prime): + return mpmath.findroot(lambda x_prime: igamma(a_prime, x_prime) - + igamma_x, x, solver='newton') + return np.float64(mpmath.diff(igammainv_of_igamma, a)) + + v_igamma_der_a = np.vectorize(igamma_der_a)(a, x) + v_gamma_sample_der_alpha = np.vectorize(gamma_sample_der_alpha)(a, x) + */ + +#if EIGEN_HAS_C99_MATH + // Test igamma_der_a: evaluates igamma_der_a(a, x) and checks it component-wise against the + // reference values v. Per the Python snippet in the comment above, the references are the + // derivative with respect to a of mpmath.gammainc(a, 0, x, regularized=True). + { + ArrayType a(30); // parameter a: 0.01, 0.1, 1, 10, 100, 1000, each repeated five times + ArrayType x(30); // evaluation points; the generator clamps them from below to float32 tiny + ArrayType res(30); // computed values + ArrayType v(30); // expected values + + a << 0.01, 0.01, 0.01, 0.01, 0.01, 0.1, 0.1, 0.1, 0.1, 0.1, 1.0, 1.0, 1.0, + 1.0, 1.0, 10.0, 10.0, 10.0, 10.0, 10.0, 100.0, 100.0, 100.0, 100.0, + 100.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0; + + x << 1.25668890405e-26, 1.17549435082e-38, 1.20938905072e-05, + 1.17549435082e-38, 1.17549435082e-38, 5.66572070696e-16, + 0.0132865061065, 0.0200034203853, 6.29263709118e-17, 1.37160367764e-06, + 0.333412038288, 1.18135687766, 0.580629033777, 0.170631439426, + 0.786686768458, 7.63873279537, 13.1944344379, 11.896042354, + 10.5830172417, 10.5020942233, 92.8918587747, 95.003720371, + 86.3715926467, 96.0330217672, 82.6389930677, 968.702906754, + 969.463546828, 1001.79726022, 955.047416547, 1044.27458568; + + v << -32.7256441441, -36.4394150514, -9.66467612263, -36.4394150514, + -36.4394150514, -1.0891900302, -2.66351229645, -2.48666868596, + -0.929700494428, -3.56327722764, -0.455320135314, -0.391437214323, + -0.491352055991, -0.350454834292, -0.471773162921, -0.104084440522, + -0.0723646747909, -0.0992828975532, -0.121638215446, -0.122619605294, + -0.0317670267286, -0.0359974812869, 
-0.0154359225363, -0.0375775365921, + -0.00794899153653, -0.00777303219211, -0.00796085782042, + -0.0125850719397, -0.00455500206958, -0.00476436993148; + + CALL_SUBTEST(res = igamma_der_a(a, x); verify_component_wise(res, v);); + } + + // Test gamma_sample_der_alpha: evaluates gamma_sample_der_alpha(alpha, sample) and checks it + // component-wise against the reference values v. Per the Python snippet in the comment above, + // the references differentiate, with respect to alpha, the sample that keeps + // igamma(alpha, sample) at its original value (Newton root find started at the sample). + { + ArrayType alpha(30); // same six alpha values as above, each repeated five times + ArrayType sample(30); // Gamma samples drawn by the generator (np.random.gamma(a, 1.0)) + ArrayType res(30); // computed values + ArrayType v(30); // expected values + + alpha << 0.01, 0.01, 0.01, 0.01, 0.01, 0.1, 0.1, 0.1, 0.1, 0.1, 1.0, 1.0, + 1.0, 1.0, 1.0, 10.0, 10.0, 10.0, 10.0, 10.0, 100.0, 100.0, 100.0, 100.0, + 100.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0; + + sample << 1.25668890405e-26, 1.17549435082e-38, 1.20938905072e-05, + 1.17549435082e-38, 1.17549435082e-38, 5.66572070696e-16, + 0.0132865061065, 0.0200034203853, 6.29263709118e-17, 1.37160367764e-06, + 0.333412038288, 1.18135687766, 0.580629033777, 0.170631439426, + 0.786686768458, 7.63873279537, 13.1944344379, 11.896042354, + 10.5830172417, 10.5020942233, 92.8918587747, 95.003720371, + 86.3715926467, 96.0330217672, 82.6389930677, 968.702906754, + 969.463546828, 1001.79726022, 955.047416547, 1044.27458568; + + v << 7.42424742367e-23, 1.02004297287e-34, 0.0130155240738, + 1.02004297287e-34, 1.02004297287e-34, 1.96505168277e-13, 0.525575786243, + 0.713903991771, 2.32077561808e-14, 0.000179348049886, 0.635500453302, + 1.27561284917, 0.878125852156, 0.41565819538, 1.03606488534, + 0.885964824887, 1.16424049334, 1.10764479598, 1.04590810812, + 1.04193666963, 0.965193152414, 0.976217589464, 0.93008035061, + 0.98153216096, 0.909196397698, 0.98434963993, 0.984738050206, + 1.00106492525, 0.97734200649, 1.02198794179; + + CALL_SUBTEST(res = gamma_sample_der_alpha(alpha, sample); + verify_component_wise(res, v);); + } +#endif // EIGEN_HAS_C99_MATH } void test_special_functions()