mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-03-31 19:00:35 +08:00
Add TernaryFunctors and the betainc SpecialFunction.
TernaryFunctors and their executors allow operations on 3-tuples of inputs. API fully implemented for Arrays and Tensors based on binary functors. Ported the cephes betainc function (regularized incomplete beta integral) to Eigen, with support for CPU and GPU, floats, doubles, and half types. Added unit tests in array.cpp and cxx11_tensor_cuda.cu Collapsed revision * Merged helper methods for betainc across floats and doubles. * Added TensorGlobalFunctions with betainc(). Removed betainc() from TensorBase. * Clean up CwiseTernaryOp checks, change igamma_helper to cephes_helper. * betainc: merge incbcf and incbd into incbeta_cfe. and more cleanup. * Update TernaryOp and SpecialFunctions (betainc) based on review comments.
This commit is contained in:
parent
02db4e1a82
commit
39baff850c
@ -371,6 +371,7 @@ using std::ptrdiff_t;
|
||||
|
||||
#include "src/Core/arch/Default/Settings.h"
|
||||
|
||||
#include "src/Core/functors/TernaryFunctors.h"
|
||||
#include "src/Core/functors/BinaryFunctors.h"
|
||||
#include "src/Core/functors/UnaryFunctors.h"
|
||||
#include "src/Core/functors/NullaryFunctors.h"
|
||||
@ -403,6 +404,7 @@ using std::ptrdiff_t;
|
||||
#include "src/Core/PlainObjectBase.h"
|
||||
#include "src/Core/Matrix.h"
|
||||
#include "src/Core/Array.h"
|
||||
#include "src/Core/CwiseTernaryOp.h"
|
||||
#include "src/Core/CwiseBinaryOp.h"
|
||||
#include "src/Core/CwiseUnaryOp.h"
|
||||
#include "src/Core/CwiseNullaryOp.h"
|
||||
|
@ -41,10 +41,19 @@ template<> struct storage_kind_to_shape<TranspositionsStorage> { typedef Transp
|
||||
// We currently distinguish the following kind of evaluators:
|
||||
// - unary_evaluator for expressions taking only one argument (CwiseUnaryOp, CwiseUnaryView, Transpose, MatrixWrapper, ArrayWrapper, Reverse, Replicate)
|
||||
// - binary_evaluator for expressions taking two arguments (CwiseBinaryOp)
|
||||
// - ternary_evaluator for expressions taking three arguments (CwiseTernaryOp)
|
||||
// - product_evaluator for linear algebra products (Product); special case of binary_evaluator because it requires additional tags for dispatching.
|
||||
// - mapbase_evaluator for Map, Block, Ref
|
||||
// - block_evaluator for Block (special dispatching to a mapbase_evaluator or unary_evaluator)
|
||||
|
||||
template< typename T,
|
||||
typename Arg1Kind = typename evaluator_traits<typename T::Arg1>::Kind,
|
||||
typename Arg2Kind = typename evaluator_traits<typename T::Arg2>::Kind,
|
||||
typename Arg3Kind = typename evaluator_traits<typename T::Arg3>::Kind,
|
||||
typename Arg1Scalar = typename traits<typename T::Arg1>::Scalar,
|
||||
typename Arg2Scalar = typename traits<typename T::Arg2>::Scalar,
|
||||
typename Arg3Scalar = typename traits<typename T::Arg3>::Scalar> struct ternary_evaluator;
|
||||
|
||||
template< typename T,
|
||||
typename LhsKind = typename evaluator_traits<typename T::Lhs>::Kind,
|
||||
typename RhsKind = typename evaluator_traits<typename T::Rhs>::Kind,
|
||||
@ -442,6 +451,96 @@ protected:
|
||||
evaluator<ArgType> m_argImpl;
|
||||
};
|
||||
|
||||
// -------------------- CwiseTernaryOp --------------------
|
||||
|
||||
// this is a ternary expression
|
||||
template<typename TernaryOp, typename Arg1, typename Arg2, typename Arg3>
|
||||
struct evaluator<CwiseTernaryOp<TernaryOp, Arg1, Arg2, Arg3> >
|
||||
: public ternary_evaluator<CwiseTernaryOp<TernaryOp, Arg1, Arg2, Arg3> >
|
||||
{
|
||||
typedef CwiseTernaryOp<TernaryOp, Arg1, Arg2, Arg3> XprType;
|
||||
typedef ternary_evaluator<CwiseTernaryOp<TernaryOp, Arg1, Arg2, Arg3> > Base;
|
||||
|
||||
EIGEN_DEVICE_FUNC explicit evaluator(const XprType& xpr) : Base(xpr) {}
|
||||
};
|
||||
|
||||
template<typename TernaryOp, typename Arg1, typename Arg2, typename Arg3>
|
||||
struct ternary_evaluator<CwiseTernaryOp<TernaryOp, Arg1, Arg2, Arg3>, IndexBased, IndexBased>
|
||||
: evaluator_base<CwiseTernaryOp<TernaryOp, Arg1, Arg2, Arg3> >
|
||||
{
|
||||
typedef CwiseTernaryOp<TernaryOp, Arg1, Arg2, Arg3> XprType;
|
||||
|
||||
enum {
|
||||
CoeffReadCost = evaluator<Arg1>::CoeffReadCost + evaluator<Arg2>::CoeffReadCost + evaluator<Arg3>::CoeffReadCost + functor_traits<TernaryOp>::Cost,
|
||||
|
||||
Arg1Flags = evaluator<Arg1>::Flags,
|
||||
Arg2Flags = evaluator<Arg2>::Flags,
|
||||
Arg3Flags = evaluator<Arg3>::Flags,
|
||||
SameType = is_same<typename Arg1::Scalar,typename Arg2::Scalar>::value && is_same<typename Arg1::Scalar,typename Arg3::Scalar>::value,
|
||||
StorageOrdersAgree = (int(Arg1Flags)&RowMajorBit)==(int(Arg2Flags)&RowMajorBit) && (int(Arg1Flags)&RowMajorBit)==(int(Arg3Flags)&RowMajorBit),
|
||||
Flags0 = (int(Arg1Flags) | int(Arg2Flags) | int(Arg3Flags)) & (
|
||||
HereditaryBits
|
||||
| (int(Arg1Flags) & int(Arg2Flags) & int(Arg3Flags) &
|
||||
( (StorageOrdersAgree ? LinearAccessBit : 0)
|
||||
| (functor_traits<TernaryOp>::PacketAccess && StorageOrdersAgree && SameType ? PacketAccessBit : 0)
|
||||
)
|
||||
)
|
||||
),
|
||||
Flags = (Flags0 & ~RowMajorBit) | (Arg1Flags & RowMajorBit),
|
||||
Alignment = EIGEN_PLAIN_ENUM_MIN(
|
||||
EIGEN_PLAIN_ENUM_MIN(evaluator<Arg1>::Alignment, evaluator<Arg2>::Alignment),
|
||||
evaluator<Arg3>::Alignment)
|
||||
};
|
||||
|
||||
EIGEN_DEVICE_FUNC explicit ternary_evaluator(const XprType& xpr)
|
||||
: m_functor(xpr.functor()),
|
||||
m_arg1Impl(xpr.arg1()),
|
||||
m_arg2Impl(xpr.arg2()),
|
||||
m_arg3Impl(xpr.arg3())
|
||||
{
|
||||
EIGEN_INTERNAL_CHECK_COST_VALUE(functor_traits<TernaryOp>::Cost);
|
||||
EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost);
|
||||
}
|
||||
|
||||
typedef typename XprType::CoeffReturnType CoeffReturnType;
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||
CoeffReturnType coeff(Index row, Index col) const
|
||||
{
|
||||
return m_functor(m_arg1Impl.coeff(row, col), m_arg2Impl.coeff(row, col), m_arg3Impl.coeff(row, col));
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||
CoeffReturnType coeff(Index index) const
|
||||
{
|
||||
return m_functor(m_arg1Impl.coeff(index), m_arg2Impl.coeff(index), m_arg3Impl.coeff(index));
|
||||
}
|
||||
|
||||
template<int LoadMode, typename PacketType>
|
||||
EIGEN_STRONG_INLINE
|
||||
PacketType packet(Index row, Index col) const
|
||||
{
|
||||
return m_functor.packetOp(m_arg1Impl.template packet<LoadMode,PacketType>(row, col),
|
||||
m_arg2Impl.template packet<LoadMode,PacketType>(row, col),
|
||||
m_arg3Impl.template packet<LoadMode,PacketType>(row, col));
|
||||
}
|
||||
|
||||
template<int LoadMode, typename PacketType>
|
||||
EIGEN_STRONG_INLINE
|
||||
PacketType packet(Index index) const
|
||||
{
|
||||
return m_functor.packetOp(m_arg1Impl.template packet<LoadMode,PacketType>(index),
|
||||
m_arg2Impl.template packet<LoadMode,PacketType>(index),
|
||||
m_arg3Impl.template packet<LoadMode,PacketType>(index));
|
||||
}
|
||||
|
||||
protected:
|
||||
const TernaryOp m_functor;
|
||||
evaluator<Arg1> m_arg1Impl;
|
||||
evaluator<Arg2> m_arg2Impl;
|
||||
evaluator<Arg3> m_arg3Impl;
|
||||
};
|
||||
|
||||
// -------------------- CwiseBinaryOp --------------------
|
||||
|
||||
// this is a binary expression
|
||||
|
212
Eigen/src/Core/CwiseTernaryOp.h
Normal file
212
Eigen/src/Core/CwiseTernaryOp.h
Normal file
@ -0,0 +1,212 @@
|
||||
// This file is part of Eigen, a lightweight C++ template library
|
||||
// for linear algebra.
|
||||
//
|
||||
// Copyright (C) 2008-2014 Gael Guennebaud <gael.guennebaud@inria.fr>
|
||||
// Copyright (C) 2006-2008 Benoit Jacob <jacob.benoit.1@gmail.com>
|
||||
// Copyright (C) 2016 Eugene Brevdo <ebrevdo@gmail.com>
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla
|
||||
// Public License v. 2.0. If a copy of the MPL was not distributed
|
||||
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
#ifndef EIGEN_CWISE_TERNARY_OP_H
|
||||
#define EIGEN_CWISE_TERNARY_OP_H
|
||||
|
||||
namespace Eigen {
|
||||
|
||||
namespace internal {
|
||||
template <typename TernaryOp, typename Arg1, typename Arg2, typename Arg3>
|
||||
struct traits<CwiseTernaryOp<TernaryOp, Arg1, Arg2, Arg3> > {
|
||||
// we must not inherit from traits<Arg1> since it has
|
||||
// the potential to cause problems with MSVC
|
||||
typedef typename remove_all<Arg1>::type Ancestor;
|
||||
typedef typename traits<Ancestor>::XprKind XprKind;
|
||||
enum {
|
||||
RowsAtCompileTime = traits<Ancestor>::RowsAtCompileTime,
|
||||
ColsAtCompileTime = traits<Ancestor>::ColsAtCompileTime,
|
||||
MaxRowsAtCompileTime = traits<Ancestor>::MaxRowsAtCompileTime,
|
||||
MaxColsAtCompileTime = traits<Ancestor>::MaxColsAtCompileTime
|
||||
};
|
||||
|
||||
// even though we require Arg1, Arg2, and Arg3 to have the same scalar type
|
||||
// (see CwiseTernaryOp constructor),
|
||||
// we still want to handle the case when the result type is different.
|
||||
typedef typename result_of<TernaryOp(
|
||||
const typename Arg1::Scalar&, const typename Arg2::Scalar&,
|
||||
const typename Arg3::Scalar&)>::type Scalar;
|
||||
EIGEN_STATIC_ASSERT(
|
||||
(internal::is_same<typename internal::traits<Arg1>::StorageKind,
|
||||
typename internal::traits<Arg2>::StorageKind>::value),
|
||||
STORAGE_KIND_MUST_MATCH)
|
||||
EIGEN_STATIC_ASSERT(
|
||||
(internal::is_same<typename internal::traits<Arg1>::StorageKind,
|
||||
typename internal::traits<Arg3>::StorageKind>::value),
|
||||
STORAGE_KIND_MUST_MATCH)
|
||||
EIGEN_STATIC_ASSERT(
|
||||
(internal::is_same<typename internal::traits<Arg1>::StorageIndex,
|
||||
typename internal::traits<Arg3>::StorageIndex>::value),
|
||||
STORAGE_INDEX_MUST_MATCH)
|
||||
EIGEN_STATIC_ASSERT(
|
||||
(internal::is_same<typename internal::traits<Arg1>::StorageIndex,
|
||||
typename internal::traits<Arg3>::StorageIndex>::value),
|
||||
STORAGE_INDEX_MUST_MATCH)
|
||||
|
||||
typedef typename internal::traits<Arg1>::StorageKind StorageKind;
|
||||
typedef typename internal::traits<Arg1>::StorageIndex StorageIndex;
|
||||
|
||||
typedef typename Arg1::Nested Arg1Nested;
|
||||
typedef typename Arg2::Nested Arg2Nested;
|
||||
typedef typename Arg3::Nested Arg3Nested;
|
||||
typedef typename remove_reference<Arg1Nested>::type _Arg1Nested;
|
||||
typedef typename remove_reference<Arg2Nested>::type _Arg2Nested;
|
||||
typedef typename remove_reference<Arg3Nested>::type _Arg3Nested;
|
||||
enum { Flags = _Arg1Nested::Flags & RowMajorBit };
|
||||
};
|
||||
} // end namespace internal
|
||||
|
||||
template <typename TernaryOp, typename Arg1, typename Arg2, typename Arg3,
|
||||
typename StorageKind>
|
||||
class CwiseTernaryOpImpl;
|
||||
|
||||
/** \class CwiseTernaryOp
|
||||
* \ingroup Core_Module
|
||||
*
|
||||
* \brief Generic expression where a coefficient-wise ternary operator is
|
||||
* applied to three expressions
|
||||
*
|
||||
* \tparam TernaryOp template functor implementing the operator
|
||||
* \tparam Arg1Type the type of the first argument
|
||||
* \tparam Arg2Type the type of the second argument
|
||||
* \tparam Arg3Type the type of the third argument
|
||||
*
|
||||
* This class represents an expression where a coefficient-wise ternary
|
||||
* operator is applied to three expressions.
|
||||
* It is the return type of ternary operators, by which we mean only those
|
||||
* ternary operators where
|
||||
* all three arguments are Eigen expressions.
|
||||
* For example, the return type of betainc(matrix1, matrix2, matrix3) is a
|
||||
* CwiseTernaryOp.
|
||||
*
|
||||
* Most of the time, this is the only way that it is used, so you typically
|
||||
* don't have to name
|
||||
* CwiseTernaryOp types explicitly.
|
||||
*
|
||||
* \sa MatrixBase::ternaryExpr(const MatrixBase<Argument2> &, const
|
||||
* MatrixBase<Argument3> &, const CustomTernaryOp &) const, class CwiseBinaryOp,
|
||||
* class CwiseUnaryOp, class CwiseNullaryOp
|
||||
*/
|
||||
template <typename TernaryOp, typename Arg1Type, typename Arg2Type,
|
||||
typename Arg3Type>
|
||||
class CwiseTernaryOp : public CwiseTernaryOpImpl<
|
||||
TernaryOp, Arg1Type, Arg2Type, Arg3Type,
|
||||
typename internal::traits<Arg1Type>::StorageKind>,
|
||||
internal::no_assignment_operator {
|
||||
EIGEN_STATIC_ASSERT(
|
||||
(internal::is_same<
|
||||
typename internal::traits<Arg1Type>::StorageKind,
|
||||
typename internal::traits<Arg2Type>::StorageKind>::value),
|
||||
STORAGE_KIND_MUST_MATCH)
|
||||
EIGEN_STATIC_ASSERT(
|
||||
(internal::is_same<
|
||||
typename internal::traits<Arg1Type>::StorageKind,
|
||||
typename internal::traits<Arg3Type>::StorageKind>::value),
|
||||
STORAGE_KIND_MUST_MATCH)
|
||||
|
||||
public:
|
||||
typedef typename internal::remove_all<Arg1Type>::type Arg1;
|
||||
typedef typename internal::remove_all<Arg2Type>::type Arg2;
|
||||
typedef typename internal::remove_all<Arg3Type>::type Arg3;
|
||||
|
||||
typedef typename CwiseTernaryOpImpl<
|
||||
TernaryOp, Arg1Type, Arg2Type, Arg3Type,
|
||||
typename internal::traits<Arg1Type>::StorageKind>::Base Base;
|
||||
EIGEN_GENERIC_PUBLIC_INTERFACE(CwiseTernaryOp)
|
||||
|
||||
typedef typename internal::ref_selector<Arg1Type>::type Arg1Nested;
|
||||
typedef typename internal::ref_selector<Arg2Type>::type Arg2Nested;
|
||||
typedef typename internal::ref_selector<Arg3Type>::type Arg3Nested;
|
||||
typedef typename internal::remove_reference<Arg1Nested>::type _Arg1Nested;
|
||||
typedef typename internal::remove_reference<Arg2Nested>::type _Arg2Nested;
|
||||
typedef typename internal::remove_reference<Arg3Nested>::type _Arg3Nested;
|
||||
|
||||
EIGEN_DEVICE_FUNC
|
||||
EIGEN_STRONG_INLINE CwiseTernaryOp(const Arg1& a1, const Arg2& a2,
|
||||
const Arg3& a3,
|
||||
const TernaryOp& func = TernaryOp())
|
||||
: m_arg1(a1), m_arg2(a2), m_arg3(a3), m_functor(func) {
|
||||
// require the sizes to match
|
||||
EIGEN_STATIC_ASSERT_SAME_MATRIX_SIZE(Arg1, Arg2)
|
||||
EIGEN_STATIC_ASSERT_SAME_MATRIX_SIZE(Arg1, Arg3)
|
||||
eigen_assert(a1.rows() == a2.rows() && a1.cols() == a2.cols() &&
|
||||
a1.rows() == a3.rows() && a1.cols() == a3.cols());
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC
|
||||
EIGEN_STRONG_INLINE Index rows() const {
|
||||
// return the fixed size type if available to enable compile time
|
||||
// optimizations
|
||||
if (internal::traits<typename internal::remove_all<Arg1Nested>::type>::
|
||||
RowsAtCompileTime == Dynamic &&
|
||||
internal::traits<typename internal::remove_all<Arg2Nested>::type>::
|
||||
RowsAtCompileTime == Dynamic)
|
||||
return m_arg3.rows();
|
||||
else if (internal::traits<typename internal::remove_all<Arg1Nested>::type>::
|
||||
RowsAtCompileTime == Dynamic &&
|
||||
internal::traits<typename internal::remove_all<Arg3Nested>::type>::
|
||||
RowsAtCompileTime == Dynamic)
|
||||
return m_arg2.rows();
|
||||
else
|
||||
return m_arg1.rows();
|
||||
}
|
||||
EIGEN_DEVICE_FUNC
|
||||
EIGEN_STRONG_INLINE Index cols() const {
|
||||
// return the fixed size type if available to enable compile time
|
||||
// optimizations
|
||||
if (internal::traits<typename internal::remove_all<Arg1Nested>::type>::
|
||||
ColsAtCompileTime == Dynamic &&
|
||||
internal::traits<typename internal::remove_all<Arg2Nested>::type>::
|
||||
ColsAtCompileTime == Dynamic)
|
||||
return m_arg3.cols();
|
||||
else if (internal::traits<typename internal::remove_all<Arg1Nested>::type>::
|
||||
ColsAtCompileTime == Dynamic &&
|
||||
internal::traits<typename internal::remove_all<Arg3Nested>::type>::
|
||||
ColsAtCompileTime == Dynamic)
|
||||
return m_arg2.cols();
|
||||
else
|
||||
return m_arg1.cols();
|
||||
}
|
||||
|
||||
/** \returns the first argument nested expression */
|
||||
EIGEN_DEVICE_FUNC
|
||||
const _Arg1Nested& arg1() const { return m_arg1; }
|
||||
/** \returns the second argument nested expression */
|
||||
EIGEN_DEVICE_FUNC
|
||||
const _Arg2Nested& arg2() const { return m_arg2; }
|
||||
/** \returns the third argument nested expression */
|
||||
EIGEN_DEVICE_FUNC
|
||||
const _Arg3Nested& arg3() const { return m_arg3; }
|
||||
/** \returns the functor representing the ternary operation */
|
||||
EIGEN_DEVICE_FUNC
|
||||
const TernaryOp& functor() const { return m_functor; }
|
||||
|
||||
protected:
|
||||
Arg1Nested m_arg1;
|
||||
Arg2Nested m_arg2;
|
||||
Arg3Nested m_arg3;
|
||||
const TernaryOp m_functor;
|
||||
};
|
||||
|
||||
// Generic API dispatcher
|
||||
template <typename TernaryOp, typename Arg1, typename Arg2, typename Arg3,
|
||||
typename StorageKind>
|
||||
class CwiseTernaryOpImpl
|
||||
: public internal::generic_xpr_base<
|
||||
CwiseTernaryOp<TernaryOp, Arg1, Arg2, Arg3> >::type {
|
||||
public:
|
||||
typedef typename internal::generic_xpr_base<
|
||||
CwiseTernaryOp<TernaryOp, Arg1, Arg2, Arg3> >::type Base;
|
||||
};
|
||||
|
||||
} // end namespace Eigen
|
||||
|
||||
#endif // EIGEN_CWISE_TERNARY_OP_H
|
@ -83,6 +83,7 @@ struct default_packet_traits
|
||||
HasErfc = 0,
|
||||
HasIGamma = 0,
|
||||
HasIGammac = 0,
|
||||
HasBetaInc = 0,
|
||||
|
||||
HasRound = 0,
|
||||
HasFloor = 0,
|
||||
@ -466,6 +467,10 @@ Packet pigamma(const Packet& a, const Packet& x) { using numext::igamma; return
|
||||
template<typename Packet> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||
Packet pigammac(const Packet& a, const Packet& x) { using numext::igammac; return igammac(a, x); }
|
||||
|
||||
/** \internal \returns the regularized incomplete beta function betainc(\a a, \a b, \a x) */
|
||||
template<typename Packet> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||
Packet pbetainc(const Packet& a, const Packet& b,const Packet& x) { using numext::betainc; return betainc(a, b, x); }
|
||||
|
||||
/***************************************************************************
|
||||
* The following functions might not have to be overwritten for vectorized types
|
||||
***************************************************************************/
|
||||
|
@ -213,6 +213,28 @@ namespace Eigen
|
||||
);
|
||||
}
|
||||
|
||||
/** \cpp11 \returns an expression of the coefficient-wise betainc(\a a, \a b, \a x) to the given arrays.
|
||||
*
|
||||
* This function computes the regularized incomplete beta function (integral).
|
||||
*
|
||||
* \note This function supports only float and double scalar types in c++11 mode. To support other scalar types,
|
||||
* or float/double in non c++11 mode, the user has to provide implementations of betainc(T,T,T) for any scalar
|
||||
* type T to be supported.
|
||||
*
|
||||
* \sa Eigen::betainc(), Eigen::lgamma()
|
||||
*/
|
||||
template<typename ArgADerived, typename ArgBDerived, typename ArgXDerived>
|
||||
inline const Eigen::CwiseTernaryOp<Eigen::internal::scalar_betainc_op<typename ArgXDerived::Scalar>, const ArgADerived, const ArgBDerived, const ArgXDerived>
|
||||
betainc(const Eigen::ArrayBase<ArgADerived>& a, const Eigen::ArrayBase<ArgBDerived>& b, const Eigen::ArrayBase<ArgXDerived>& x)
|
||||
{
|
||||
return Eigen::CwiseTernaryOp<Eigen::internal::scalar_betainc_op<typename ArgXDerived::Scalar>, const ArgADerived, const ArgBDerived, const ArgXDerived>(
|
||||
a.derived(),
|
||||
b.derived(),
|
||||
x.derived()
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
/** \returns an expression of the coefficient-wise zeta(\a x, \a q) to the given arrays.
|
||||
*
|
||||
* It returns the Riemann zeta function of two arguments \a x and \a q:
|
||||
|
@ -392,17 +392,19 @@ struct igammac_retval {
|
||||
typedef Scalar type;
|
||||
};
|
||||
|
||||
// NOTE: igamma_helper is also used to implement zeta
|
||||
// NOTE: cephes_helper is also used to implement zeta
|
||||
template <typename Scalar>
|
||||
struct igamma_helper {
|
||||
struct cephes_helper {
|
||||
EIGEN_DEVICE_FUNC
|
||||
static EIGEN_STRONG_INLINE Scalar machep() { assert(false && "machep not supported for this type"); return 0.0; }
|
||||
EIGEN_DEVICE_FUNC
|
||||
static EIGEN_STRONG_INLINE Scalar big() { assert(false && "big not supported for this type"); return 0.0; }
|
||||
EIGEN_DEVICE_FUNC
|
||||
static EIGEN_STRONG_INLINE Scalar biginv() { assert(false && "biginv not supported for this type"); return 0.0; }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct igamma_helper<float> {
|
||||
struct cephes_helper<float> {
|
||||
EIGEN_DEVICE_FUNC
|
||||
static EIGEN_STRONG_INLINE float machep() {
|
||||
return NumTraits<float>::epsilon() / 2; // 1.0 - machep == 1.0
|
||||
@ -412,10 +414,15 @@ struct igamma_helper<float> {
|
||||
// use epsneg (1.0 - epsneg == 1.0)
|
||||
return 1.0f / (NumTraits<float>::epsilon() / 2);
|
||||
}
|
||||
EIGEN_DEVICE_FUNC
|
||||
static EIGEN_STRONG_INLINE float biginv() {
|
||||
// epsneg
|
||||
return machep();
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct igamma_helper<double> {
|
||||
struct cephes_helper<double> {
|
||||
EIGEN_DEVICE_FUNC
|
||||
static EIGEN_STRONG_INLINE double machep() {
|
||||
return NumTraits<double>::epsilon() / 2; // 1.0 - machep == 1.0
|
||||
@ -424,6 +431,11 @@ struct igamma_helper<double> {
|
||||
static EIGEN_STRONG_INLINE double big() {
|
||||
return 1.0 / NumTraits<double>::epsilon();
|
||||
}
|
||||
EIGEN_DEVICE_FUNC
|
||||
static EIGEN_STRONG_INLINE double biginv() {
|
||||
// inverse of big(), i.e. eps
|
||||
return NumTraits<double>::epsilon();
|
||||
}
|
||||
};
|
||||
|
||||
#if !EIGEN_HAS_C99_MATH
|
||||
@ -538,10 +550,10 @@ struct igammac_impl {
|
||||
const Scalar zero = 0;
|
||||
const Scalar one = 1;
|
||||
const Scalar two = 2;
|
||||
const Scalar machep = igamma_helper<Scalar>::machep();
|
||||
const Scalar machep = cephes_helper<Scalar>::machep();
|
||||
const Scalar maxlog = numext::log(NumTraits<Scalar>::highest());
|
||||
const Scalar big = igamma_helper<Scalar>::big();
|
||||
const Scalar biginv = 1 / big;
|
||||
const Scalar big = cephes_helper<Scalar>::big();
|
||||
const Scalar biginv = cephes_helper<Scalar>::biginv();
|
||||
const Scalar inf = NumTraits<Scalar>::infinity();
|
||||
|
||||
Scalar ans, ax, c, yc, r, t, y, z;
|
||||
@ -590,7 +602,9 @@ struct igammac_impl {
|
||||
qkm2 *= biginv;
|
||||
qkm1 *= biginv;
|
||||
}
|
||||
if (t <= machep) break;
|
||||
if (t <= machep) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return (ans * ax);
|
||||
@ -724,7 +738,7 @@ struct igamma_impl {
|
||||
EIGEN_DEVICE_FUNC static Scalar Impl(Scalar a, Scalar x) {
|
||||
const Scalar zero = 0;
|
||||
const Scalar one = 1;
|
||||
const Scalar machep = igamma_helper<Scalar>::machep();
|
||||
const Scalar machep = cephes_helper<Scalar>::machep();
|
||||
const Scalar maxlog = numext::log(NumTraits<Scalar>::highest());
|
||||
|
||||
Scalar ans, ax, c, r;
|
||||
@ -746,7 +760,9 @@ struct igamma_impl {
|
||||
r += one;
|
||||
c *= x/r;
|
||||
ans += c;
|
||||
if (c/ans <= machep) break;
|
||||
if (c/ans <= machep) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return (ans * ax / a);
|
||||
@ -899,7 +915,7 @@ struct zeta_impl {
|
||||
|
||||
const Scalar maxnum = NumTraits<Scalar>::infinity();
|
||||
const Scalar zero = 0.0, half = 0.5, one = 1.0;
|
||||
const Scalar machep = igamma_helper<Scalar>::machep();
|
||||
const Scalar machep = cephes_helper<Scalar>::machep();
|
||||
const Scalar nan = NumTraits<Scalar>::quiet_NaN();
|
||||
|
||||
if( x == one )
|
||||
@ -947,8 +963,9 @@ struct zeta_impl {
|
||||
t = a*b/A[i];
|
||||
s = s + t;
|
||||
t = numext::abs(t/s);
|
||||
if( t < machep )
|
||||
return s;
|
||||
if( t < machep ) {
|
||||
break;
|
||||
}
|
||||
k += one;
|
||||
a *= x + k;
|
||||
b /= w;
|
||||
@ -1007,6 +1024,467 @@ struct polygamma_impl {
|
||||
|
||||
#endif // EIGEN_HAS_C99_MATH
|
||||
|
||||
/************************************************************************************************
|
||||
* Implementation of betainc (incomplete beta integral), based on Cephes but requires C++11/C99 *
|
||||
************************************************************************************************/
|
||||
|
||||
template <typename Scalar>
|
||||
struct betainc_retval {
|
||||
typedef Scalar type;
|
||||
};
|
||||
|
||||
#if !EIGEN_HAS_C99_MATH
|
||||
|
||||
template <typename Scalar>
|
||||
struct betainc_impl {
|
||||
EIGEN_DEVICE_FUNC
|
||||
static EIGEN_STRONG_INLINE Scalar run(Scalar a, Scalar b, Scalar x) {
|
||||
EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value == false),
|
||||
THIS_TYPE_IS_NOT_SUPPORTED);
|
||||
return Scalar(0);
|
||||
}
|
||||
};
|
||||
|
||||
#else
|
||||
|
||||
template <typename Scalar>
|
||||
struct betainc_impl {
|
||||
EIGEN_DEVICE_FUNC
|
||||
static EIGEN_STRONG_INLINE Scalar run(Scalar a, Scalar b, Scalar x) {
|
||||
/* betaincf.c
|
||||
*
|
||||
* Incomplete beta integral
|
||||
*
|
||||
*
|
||||
* SYNOPSIS:
|
||||
*
|
||||
* float a, b, x, y, betaincf();
|
||||
*
|
||||
* y = betaincf( a, b, x );
|
||||
*
|
||||
*
|
||||
* DESCRIPTION:
|
||||
*
|
||||
* Returns incomplete beta integral of the arguments, evaluated
|
||||
* from zero to x. The function is defined as
|
||||
*
|
||||
* x
|
||||
* - -
|
||||
* | (a+b) | | a-1 b-1
|
||||
* ----------- | t (1-t) dt.
|
||||
* - - | |
|
||||
* | (a) | (b) -
|
||||
* 0
|
||||
*
|
||||
* The domain of definition is 0 <= x <= 1. In this
|
||||
* implementation a and b are restricted to positive values.
|
||||
* The integral from x to 1 may be obtained by the symmetry
|
||||
* relation
|
||||
*
|
||||
* 1 - betainc( a, b, x ) = betainc( b, a, 1-x ).
|
||||
*
|
||||
* The integral is evaluated by a continued fraction expansion.
|
||||
* If a < 1, the function calls itself recursively after a
|
||||
* transformation to increase a to a+1.
|
||||
*
|
||||
* ACCURACY (float):
|
||||
*
|
||||
* Tested at random points (a,b,x) with a and b in the indicated
|
||||
* interval and x between 0 and 1.
|
||||
*
|
||||
* arithmetic domain # trials peak rms
|
||||
* Relative error:
|
||||
* IEEE 0,30 10000 3.7e-5 5.1e-6
|
||||
* IEEE 0,100 10000 1.7e-4 2.5e-5
|
||||
* The useful domain for relative error is limited by underflow
|
||||
* of the single precision exponential function.
|
||||
* Absolute error:
|
||||
* IEEE 0,30 100000 2.2e-5 9.6e-7
|
||||
* IEEE 0,100 10000 6.5e-5 3.7e-6
|
||||
*
|
||||
* Larger errors may occur for extreme ratios of a and b.
|
||||
*
|
||||
* ACCURACY (double):
|
||||
* arithmetic domain # trials peak rms
|
||||
* IEEE 0,5 10000 6.9e-15 4.5e-16
|
||||
* IEEE 0,85 250000 2.2e-13 1.7e-14
|
||||
* IEEE 0,1000 30000 5.3e-12 6.3e-13
|
||||
* IEEE 0,10000 250000 9.3e-11 7.1e-12
|
||||
* IEEE 0,100000 10000 8.7e-10 4.8e-11
|
||||
* Outputs smaller than the IEEE gradual underflow threshold
|
||||
* were excluded from these statistics.
|
||||
*
|
||||
* ERROR MESSAGES:
|
||||
* message condition value returned
|
||||
* incbet domain x<0, x>1 nan
|
||||
* incbet underflow nan
|
||||
*/
|
||||
|
||||
EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value == false),
|
||||
THIS_TYPE_IS_NOT_SUPPORTED);
|
||||
return Scalar(0);
|
||||
}
|
||||
};
|
||||
|
||||
/* Continued fraction expansion #1 for incomplete beta integral (small_branch = True)
|
||||
* Continued fraction expansion #2 for incomplete beta integral (small_branch = False)
|
||||
*/
|
||||
template <typename Scalar>
|
||||
struct incbeta_cfe {
|
||||
EIGEN_DEVICE_FUNC
|
||||
static EIGEN_STRONG_INLINE Scalar run(Scalar a, Scalar b, Scalar x, bool small_branch) {
|
||||
EIGEN_STATIC_ASSERT((internal::is_same<Scalar, float>::value ||
|
||||
internal::is_same<Scalar, double>::value),
|
||||
THIS_TYPE_IS_NOT_SUPPORTED);
|
||||
const Scalar big = cephes_helper<Scalar>::big();
|
||||
const Scalar machep = cephes_helper<Scalar>::machep();
|
||||
const Scalar biginv = cephes_helper<Scalar>::biginv();
|
||||
|
||||
const Scalar zero = 0;
|
||||
const Scalar one = 1;
|
||||
const Scalar two = 2;
|
||||
|
||||
Scalar xk, pk, pkm1, pkm2, qk, qkm1, qkm2;
|
||||
Scalar k1, k2, k3, k4, k5, k6, k7, k8, k26update;
|
||||
Scalar ans;
|
||||
int n;
|
||||
|
||||
const int num_iters = (internal::is_same<Scalar, float>::value) ? 100 : 300;
|
||||
const Scalar thresh =
|
||||
(internal::is_same<Scalar, float>::value) ? machep : Scalar(3) * machep;
|
||||
Scalar r = (internal::is_same<Scalar, float>::value) ? zero : one;
|
||||
|
||||
if (small_branch) {
|
||||
k1 = a;
|
||||
k2 = a + b;
|
||||
k3 = a;
|
||||
k4 = a + one;
|
||||
k5 = one;
|
||||
k6 = b - one;
|
||||
k7 = k4;
|
||||
k8 = a + two;
|
||||
k26update = one;
|
||||
} else {
|
||||
k1 = a;
|
||||
k2 = b - one;
|
||||
k3 = a;
|
||||
k4 = a + one;
|
||||
k5 = one;
|
||||
k6 = a + b;
|
||||
k7 = a + one;
|
||||
k8 = a + two;
|
||||
k26update = -one;
|
||||
x = x / (one - x);
|
||||
}
|
||||
|
||||
pkm2 = zero;
|
||||
qkm2 = one;
|
||||
pkm1 = one;
|
||||
qkm1 = one;
|
||||
ans = one;
|
||||
n = 0;
|
||||
|
||||
do {
|
||||
xk = -(x * k1 * k2) / (k3 * k4);
|
||||
pk = pkm1 + pkm2 * xk;
|
||||
qk = qkm1 + qkm2 * xk;
|
||||
pkm2 = pkm1;
|
||||
pkm1 = pk;
|
||||
qkm2 = qkm1;
|
||||
qkm1 = qk;
|
||||
|
||||
xk = (x * k5 * k6) / (k7 * k8);
|
||||
pk = pkm1 + pkm2 * xk;
|
||||
qk = qkm1 + qkm2 * xk;
|
||||
pkm2 = pkm1;
|
||||
pkm1 = pk;
|
||||
qkm2 = qkm1;
|
||||
qkm1 = qk;
|
||||
|
||||
if (qk != zero) {
|
||||
r = pk / qk;
|
||||
if (numext::abs(ans - r) < numext::abs(r) * thresh) {
|
||||
return r;
|
||||
}
|
||||
ans = r;
|
||||
}
|
||||
|
||||
k1 += one;
|
||||
k2 += k26update;
|
||||
k3 += two;
|
||||
k4 += two;
|
||||
k5 += one;
|
||||
k6 -= k26update;
|
||||
k7 += two;
|
||||
k8 += two;
|
||||
|
||||
if ((numext::abs(qk) + numext::abs(pk)) > big) {
|
||||
pkm2 *= biginv;
|
||||
pkm1 *= biginv;
|
||||
qkm2 *= biginv;
|
||||
qkm1 *= biginv;
|
||||
}
|
||||
if ((numext::abs(qk) < biginv) || (numext::abs(pk) < biginv)) {
|
||||
pkm2 *= big;
|
||||
pkm1 *= big;
|
||||
qkm2 *= big;
|
||||
qkm1 *= big;
|
||||
}
|
||||
} while (++n < num_iters);
|
||||
|
||||
return ans;
|
||||
}
|
||||
};
|
||||
|
||||
/* Helper functions depending on the Scalar type */
|
||||
template <typename Scalar>
|
||||
struct betainc_helper {};
|
||||
|
||||
template <>
|
||||
struct betainc_helper<float> {
|
||||
/* Core implementation, assumes a large (> 1.0) */
|
||||
EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE float incbsa(float aa, float bb,
|
||||
float xx) {
|
||||
float ans, a, b, t, x, onemx;
|
||||
bool reversed_a_b = false;
|
||||
|
||||
onemx = 1.0f - xx;
|
||||
|
||||
/* see if x is greater than the mean */
|
||||
if (xx > (aa / (aa + bb))) {
|
||||
reversed_a_b = true;
|
||||
a = bb;
|
||||
b = aa;
|
||||
t = xx;
|
||||
x = onemx;
|
||||
} else {
|
||||
a = aa;
|
||||
b = bb;
|
||||
t = onemx;
|
||||
x = xx;
|
||||
}
|
||||
|
||||
/* Choose expansion for optimal convergence */
|
||||
if (b > 10.0f) {
|
||||
if (numext::abs(b * x / a) < 0.3f) {
|
||||
t = betainc_helper<float>::incbps(a, b, x);
|
||||
if (reversed_a_b) t = 1.0f - t;
|
||||
return t;
|
||||
}
|
||||
}
|
||||
|
||||
ans = x * (a + b - 2.0f) / (a - 1.0f);
|
||||
if (ans < 1.0f) {
|
||||
ans = incbeta_cfe<float>::run(a, b, x, true /* small_branch */);
|
||||
t = b * numext::log(t);
|
||||
} else {
|
||||
ans = incbeta_cfe<float>::run(a, b, x, false /* small_branch */);
|
||||
t = (b - 1.0f) * numext::log(t);
|
||||
}
|
||||
|
||||
t += a * numext::log(x) + lgamma_impl<float>::run(a + b) -
|
||||
lgamma_impl<float>::run(a) - lgamma_impl<float>::run(b);
|
||||
t += numext::log(ans / a);
|
||||
t = numext::exp(t);
|
||||
|
||||
if (reversed_a_b) t = 1.0f - t;
|
||||
return t;
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC
|
||||
static EIGEN_STRONG_INLINE float incbps(float a, float b, float x) {
|
||||
float t, u, y, s;
|
||||
const float machep = cephes_helper<float>::machep();
|
||||
|
||||
y = a * numext::log(x) + (b - 1.0f) * numext::log1p(-x) - numext::log(a);
|
||||
y -= lgamma_impl<float>::run(a) + lgamma_impl<float>::run(b);
|
||||
y += lgamma_impl<float>::run(a + b);
|
||||
|
||||
t = x / (1.0f - x);
|
||||
s = 0.0f;
|
||||
u = 1.0f;
|
||||
do {
|
||||
b -= 1.0f;
|
||||
if (b == 0.0f) {
|
||||
break;
|
||||
}
|
||||
a += 1.0f;
|
||||
u *= t * b / a;
|
||||
s += u;
|
||||
} while (numext::abs(u) > machep);
|
||||
|
||||
return numext::exp(y) * (1.0f + s);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct betainc_impl<float> {
|
||||
EIGEN_DEVICE_FUNC
|
||||
static float run(float a, float b, float x) {
|
||||
const float nan = NumTraits<float>::quiet_NaN();
|
||||
float ans, t;
|
||||
|
||||
if (a <= 0.0f) return nan;
|
||||
if (b <= 0.0f) return nan;
|
||||
if ((x <= 0.0f) || (x >= 1.0f)) {
|
||||
if (x == 0.0f) return 0.0f;
|
||||
if (x == 1.0f) return 1.0f;
|
||||
// mtherr("betaincf", DOMAIN);
|
||||
return nan;
|
||||
}
|
||||
|
||||
/* transformation for small aa */
|
||||
if (a <= 1.0f) {
|
||||
ans = betainc_helper<float>::incbsa(a + 1.0f, b, x);
|
||||
t = a * numext::log(x) + b * numext::log1p(-x) +
|
||||
lgamma_impl<float>::run(a + b) - lgamma_impl<float>::run(a + 1.0f) -
|
||||
lgamma_impl<float>::run(b);
|
||||
return (ans + numext::exp(t));
|
||||
} else {
|
||||
return betainc_helper<float>::incbsa(a, b, x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct betainc_helper<double> {
|
||||
EIGEN_DEVICE_FUNC
|
||||
static EIGEN_STRONG_INLINE double incbps(double a, double b, double x) {
|
||||
const double machep = cephes_helper<double>::machep();
|
||||
|
||||
double s, t, u, v, n, t1, z, ai;
|
||||
|
||||
ai = 1.0 / a;
|
||||
u = (1.0 - b) * x;
|
||||
v = u / (a + 1.0);
|
||||
t1 = v;
|
||||
t = u;
|
||||
n = 2.0;
|
||||
s = 0.0;
|
||||
z = machep * ai;
|
||||
while (numext::abs(v) > z) {
|
||||
u = (n - b) * x / n;
|
||||
t *= u;
|
||||
v = t / (a + n);
|
||||
s += v;
|
||||
n += 1.0;
|
||||
}
|
||||
s += t1;
|
||||
s += ai;
|
||||
|
||||
u = a * numext::log(x);
|
||||
// TODO: gamma() is not directly implemented in Eigen.
|
||||
/*
|
||||
if ((a + b) < maxgam && numext::abs(u) < maxlog) {
|
||||
t = gamma(a + b) / (gamma(a) * gamma(b));
|
||||
s = s * t * pow(x, a);
|
||||
} else {
|
||||
*/
|
||||
t = lgamma_impl<double>::run(a + b) - lgamma_impl<double>::run(a) -
|
||||
lgamma_impl<double>::run(b) + u + numext::log(s);
|
||||
return s = exp(t);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct betainc_impl<double> {
|
||||
EIGEN_DEVICE_FUNC
|
||||
static double run(double aa, double bb, double xx) {
|
||||
const double nan = NumTraits<double>::quiet_NaN();
|
||||
const double machep = cephes_helper<double>::machep();
|
||||
// const double maxgam = 171.624376956302725;
|
||||
|
||||
double a, b, t, x, xc, w, y;
|
||||
bool reversed_a_b = false;
|
||||
|
||||
if (aa <= 0.0 || bb <= 0.0) {
|
||||
return nan; // goto domerr;
|
||||
}
|
||||
|
||||
if ((xx <= 0.0) || (xx >= 1.0)) {
|
||||
if (xx == 0.0) return (0.0);
|
||||
if (xx == 1.0) return (1.0);
|
||||
// mtherr("incbet", DOMAIN);
|
||||
return nan;
|
||||
}
|
||||
|
||||
if ((bb * xx) <= 1.0 && xx <= 0.95) {
|
||||
return betainc_helper<double>::incbps(aa, bb, xx);
|
||||
}
|
||||
|
||||
w = 1.0 - xx;
|
||||
|
||||
/* Reverse a and b if x is greater than the mean. */
|
||||
if (xx > (aa / (aa + bb))) {
|
||||
reversed_a_b = true;
|
||||
a = bb;
|
||||
b = aa;
|
||||
xc = xx;
|
||||
x = w;
|
||||
} else {
|
||||
a = aa;
|
||||
b = bb;
|
||||
xc = w;
|
||||
x = xx;
|
||||
}
|
||||
|
||||
if (reversed_a_b && (b * x) <= 1.0 && x <= 0.95) {
|
||||
t = betainc_helper<double>::incbps(a, b, x);
|
||||
if (t <= machep) {
|
||||
t = 1.0 - machep;
|
||||
} else {
|
||||
t = 1.0 - t;
|
||||
}
|
||||
return t;
|
||||
}
|
||||
|
||||
/* Choose expansion for better convergence. */
|
||||
y = x * (a + b - 2.0) - (a - 1.0);
|
||||
if (y < 0.0) {
|
||||
w = incbeta_cfe<double>::run(a, b, x, true /* small_branch */);
|
||||
} else {
|
||||
w = incbeta_cfe<double>::run(a, b, x, false /* small_branch */) / xc;
|
||||
}
|
||||
|
||||
/* Multiply w by the factor
|
||||
a b _ _ _
|
||||
x (1-x) | (a+b) / ( a | (a) | (b) ) . */
|
||||
|
||||
y = a * numext::log(x);
|
||||
t = b * numext::log(xc);
|
||||
// TODO: gamma is not directly implemented in Eigen.
|
||||
/*
|
||||
if ((a + b) < maxgam && numext::abs(y) < maxlog && numext::abs(t) < maxlog)
|
||||
{
|
||||
t = pow(xc, b);
|
||||
t *= pow(x, a);
|
||||
t /= a;
|
||||
t *= w;
|
||||
t *= gamma(a + b) / (gamma(a) * gamma(b));
|
||||
} else {
|
||||
*/
|
||||
/* Resort to logarithms. */
|
||||
y += t + lgamma_impl<double>::run(a + b) - lgamma_impl<double>::run(a) -
|
||||
lgamma_impl<double>::run(b);
|
||||
y += numext::log(w / a);
|
||||
t = numext::exp(y);
|
||||
|
||||
/* } */
|
||||
// done:
|
||||
|
||||
if (reversed_a_b) {
|
||||
if (t <= machep) {
|
||||
t = 1.0 - machep;
|
||||
} else {
|
||||
t = 1.0 - t;
|
||||
}
|
||||
}
|
||||
return t;
|
||||
}
|
||||
};
|
||||
|
||||
#endif // EIGEN_HAS_C99_MATH
|
||||
|
||||
} // end namespace internal
|
||||
|
||||
namespace numext {
|
||||
@ -1022,7 +1500,7 @@ EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(digamma, Scalar)
|
||||
digamma(const Scalar& x) {
|
||||
return EIGEN_MATHFUNC_IMPL(digamma, Scalar)::run(x);
|
||||
}
|
||||
|
||||
|
||||
template <typename Scalar>
|
||||
EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(zeta, Scalar)
|
||||
zeta(const Scalar& x, const Scalar& q) {
|
||||
@ -1059,6 +1537,12 @@ EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(igammac, Scalar)
|
||||
return EIGEN_MATHFUNC_IMPL(igammac, Scalar)::run(a, x);
|
||||
}
|
||||
|
||||
template <typename Scalar>
|
||||
EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(betainc, Scalar)
|
||||
betainc(const Scalar& a, const Scalar& b, const Scalar& x) {
|
||||
return EIGEN_MATHFUNC_IMPL(betainc, Scalar)::run(a, b, x);
|
||||
}
|
||||
|
||||
} // end namespace numext
|
||||
|
||||
|
||||
|
@ -480,6 +480,9 @@ template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half igamma(const Eigen:
|
||||
// Half-precision igammac: evaluate in float and convert the result back.
template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half igammac(const Eigen::half& a, const Eigen::half& x) {
  const float a_as_float = static_cast<float>(a);
  const float x_as_float = static_cast<float>(x);
  return Eigen::half(Eigen::numext::igammac(a_as_float, x_as_float));
}
|
||||
// Half-precision betainc: evaluate in float and convert the result back.
template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half betainc(const Eigen::half& a, const Eigen::half& b, const Eigen::half& x) {
  const float a_as_float = static_cast<float>(a);
  const float b_as_float = static_cast<float>(b);
  const float x_as_float = static_cast<float>(x);
  return Eigen::half(Eigen::numext::betainc(a_as_float, b_as_float, x_as_float));
}
|
||||
#endif
|
||||
} // end namespace numext
|
||||
|
||||
|
@ -181,6 +181,24 @@ double2 pigammac<double2>(const double2& a, const double2& x)
|
||||
return make_double2(igammac(a.x, x.x), igammac(a.y, x.y));
|
||||
}
|
||||
|
||||
// Packet betainc for float4: applies the scalar betainc to each of the four
// components independently.
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
float4 pbetainc<float4>(const float4& a, const float4& b, const float4& x)
{
  using numext::betainc;
  const float rx = betainc(a.x, b.x, x.x);
  const float ry = betainc(a.y, b.y, x.y);
  const float rz = betainc(a.z, b.z, x.z);
  const float rw = betainc(a.w, b.w, x.w);
  return make_float4(rx, ry, rz, rw);
}
|
||||
|
||||
// Packet betainc for double2: applies the scalar betainc to each of the two
// components independently.
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
double2 pbetainc<double2>(const double2& a, const double2& b, const double2& x)
{
  using numext::betainc;
  const double rx = betainc(a.x, b.x, x.x);
  const double ry = betainc(a.y, b.y, x.y);
  return make_double2(rx, ry);
}
|
||||
|
||||
#endif
|
||||
|
||||
} // end namespace internal
|
||||
|
@ -44,8 +44,9 @@ template<> struct packet_traits<float> : default_packet_traits
|
||||
HasPolygamma = 1,
|
||||
HasErf = 1,
|
||||
HasErfc = 1,
|
||||
HasIgamma = 1,
|
||||
HasIGamma = 1,
|
||||
HasIGammac = 1,
|
||||
HasBetaInc = 1,
|
||||
|
||||
HasBlend = 0,
|
||||
};
|
||||
@ -68,10 +69,13 @@ template<> struct packet_traits<double> : default_packet_traits
|
||||
HasRsqrt = 1,
|
||||
HasLGamma = 1,
|
||||
HasDiGamma = 1,
|
||||
HasZeta = 1,
|
||||
HasPolygamma = 1,
|
||||
HasErf = 1,
|
||||
HasErfc = 1,
|
||||
HasIGamma = 1,
|
||||
HasIGammac = 1,
|
||||
HasBetaInc = 1,
|
||||
|
||||
HasBlend = 0,
|
||||
};
|
||||
|
@ -372,7 +372,7 @@ template<typename Scalar> struct scalar_igamma_op {
|
||||
}
|
||||
template<typename Packet>
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(const Packet& a, const Packet& x) const {
|
||||
return internal::pigammac(a, x);
|
||||
return internal::pigamma(a, x);
|
||||
}
|
||||
};
|
||||
template<typename Scalar>
|
||||
|
47
Eigen/src/Core/functors/TernaryFunctors.h
Normal file
47
Eigen/src/Core/functors/TernaryFunctors.h
Normal file
@ -0,0 +1,47 @@
|
||||
// This file is part of Eigen, a lightweight C++ template library
|
||||
// for linear algebra.
|
||||
//
|
||||
// Copyright (C) 2016 Eugene Brevdo <ebrevdo@gmail.com>
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla
|
||||
// Public License v. 2.0. If a copy of the MPL was not distributed
|
||||
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
#ifndef EIGEN_TERNARY_FUNCTORS_H
|
||||
#define EIGEN_TERNARY_FUNCTORS_H
|
||||
|
||||
namespace Eigen {
|
||||
|
||||
namespace internal {
|
||||
|
||||
//---------- ternary functors ----------
|
||||
|
||||
/** \internal
|
||||
* \brief Template functor to compute the incomplete beta integral betainc(a, b, x)
|
||||
*
|
||||
*/
|
||||
template<typename Scalar> struct scalar_betainc_op {
|
||||
EIGEN_EMPTY_STRUCT_CTOR(scalar_betainc_op)
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator() (const Scalar& x, const Scalar& a, const Scalar& b) const {
|
||||
using numext::betainc; return betainc(x, a, b);
|
||||
}
|
||||
template<typename Packet>
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(const Packet& x, const Packet& a, const Packet& b) const
|
||||
{
|
||||
return internal::pbetainc(x, a, b);
|
||||
}
|
||||
};
|
||||
template<typename Scalar>
|
||||
struct functor_traits<scalar_betainc_op<Scalar> > {
|
||||
enum {
|
||||
// Guesstimate
|
||||
Cost = 400 * NumTraits<Scalar>::MulCost + 400 * NumTraits<Scalar>::AddCost,
|
||||
PacketAccess = packet_traits<Scalar>::HasBetaInc
|
||||
};
|
||||
};
|
||||
|
||||
} // end namespace internal
|
||||
|
||||
} // end namespace Eigen
|
||||
|
||||
#endif // EIGEN_TERNARY_FUNCTORS_H
|
@ -91,6 +91,7 @@ template<typename NullaryOp, typename MatrixType> class CwiseNullaryOp;
|
||||
template<typename UnaryOp, typename MatrixType> class CwiseUnaryOp;
|
||||
template<typename ViewOp, typename MatrixType> class CwiseUnaryView;
|
||||
template<typename BinaryOp, typename Lhs, typename Rhs> class CwiseBinaryOp;
|
||||
template<typename TernaryOp, typename Arg1, typename Arg2, typename Arg3> class CwiseTernaryOp;
|
||||
template<typename Decomposition, typename Rhstype> class Solve;
|
||||
template<typename XprType> class Inverse;
|
||||
|
||||
@ -208,6 +209,7 @@ template<typename Scalar> struct scalar_identity_op;
|
||||
template<typename Scalar,bool iscpx> struct scalar_sign_op;
|
||||
template<typename Scalar> struct scalar_igamma_op;
|
||||
template<typename Scalar> struct scalar_igammac_op;
|
||||
template<typename Scalar> struct scalar_betainc_op;
|
||||
|
||||
template<typename LhsScalar,typename RhsScalar=LhsScalar> struct scalar_product_op;
|
||||
template<typename LhsScalar,typename RhsScalar> struct scalar_multiple2_op;
|
||||
|
@ -98,7 +98,9 @@
|
||||
EIGEN_INTERNAL_ERROR_PLEASE_FILE_A_BUG_REPORT__INVALID_COST_VALUE,
|
||||
THIS_COEFFICIENT_ACCESSOR_TAKING_ONE_ACCESS_IS_ONLY_FOR_EXPRESSIONS_ALLOWING_LINEAR_ACCESS,
|
||||
MATRIX_FREE_CONJUGATE_GRADIENT_IS_COMPATIBLE_WITH_UPPER_UNION_LOWER_MODE_ONLY,
|
||||
THIS_TYPE_IS_NOT_SUPPORTED
|
||||
THIS_TYPE_IS_NOT_SUPPORTED,
|
||||
STORAGE_KIND_MUST_MATCH,
|
||||
STORAGE_INDEX_MUST_MATCH
|
||||
};
|
||||
};
|
||||
|
||||
|
113
test/array.cpp
113
test/array.cpp
@ -592,16 +592,123 @@ template<typename ArrayType> void array_special_functions()
|
||||
ref << 0.644934066848, 0.394934066848, 0.0399946696496, nan, 293.334565435, 0.445487887616, -2.47810300902e-07, -8.29668781082e-09, -0.434562276666, 0.567742190178, -0.0108615497927;
|
||||
CALL_SUBTEST( verify_component_wise(ref, ref); );
|
||||
|
||||
if(sizeof(RealScalar)>=64) {
|
||||
// CALL_SUBTEST( res = x.polygamma(n); verify_component_wise(res, ref); );
|
||||
if(sizeof(RealScalar)>=8) { // double
|
||||
// Reason for commented line: http://eigen.tuxfamily.org/bz/show_bug.cgi?id=1232
|
||||
// CALL_SUBTEST( res = x.polygamma(n); verify_component_wise(res, ref); );
|
||||
CALL_SUBTEST( res = polygamma(n,x); verify_component_wise(res, ref); );
|
||||
}
|
||||
else {
|
||||
// CALL_SUBTEST( res = x.polygamma(n); verify_component_wise(res.head(8), ref.head(8)); );
|
||||
// CALL_SUBTEST( res = x.polygamma(n); verify_component_wise(res.head(8), ref.head(8)); );
|
||||
CALL_SUBTEST( res = polygamma(n,x); verify_component_wise(res.head(8), ref.head(8)); );
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#if EIGEN_HAS_C99_MATH
|
||||
{
|
||||
// Inputs and ground truth generated with scipy via:
|
||||
// a = np.logspace(-3, 3, 5) - 1e-3
|
||||
// b = np.logspace(-3, 3, 5) - 1e-3
|
||||
// x = np.linspace(-0.1, 1.1, 5)
|
||||
// (full_a, full_b, full_x) = np.vectorize(lambda a, b, x: (a, b, x))(*np.ix_(a, b, x))
|
||||
// full_a = full_a.flatten().tolist() # same for full_b, full_x
|
||||
// v = scipy.special.betainc(full_a, full_b, full_x).flatten().tolist()
|
||||
//
|
||||
// Note: as in scipy, Eigen's betainc is called with arguments in the order (a, b, x).
|
||||
ArrayType a(125);
|
||||
ArrayType b(125);
|
||||
ArrayType x(125);
|
||||
ArrayType v(125);
|
||||
ArrayType res(125);
|
||||
|
||||
a << 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
|
||||
0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
|
||||
0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
|
||||
0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
|
||||
0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
|
||||
0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
|
||||
0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
|
||||
0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
|
||||
0.03062277660168379, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999,
|
||||
0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999,
|
||||
0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999,
|
||||
31.62177660168379, 31.62177660168379, 31.62177660168379,
|
||||
31.62177660168379, 31.62177660168379, 31.62177660168379,
|
||||
31.62177660168379, 31.62177660168379, 31.62177660168379,
|
||||
31.62177660168379, 31.62177660168379, 31.62177660168379,
|
||||
31.62177660168379, 31.62177660168379, 31.62177660168379,
|
||||
31.62177660168379, 31.62177660168379, 31.62177660168379,
|
||||
31.62177660168379, 31.62177660168379, 31.62177660168379,
|
||||
31.62177660168379, 31.62177660168379, 31.62177660168379,
|
||||
31.62177660168379, 999.999, 999.999, 999.999, 999.999, 999.999, 999.999,
|
||||
999.999, 999.999, 999.999, 999.999, 999.999, 999.999, 999.999, 999.999,
|
||||
999.999, 999.999, 999.999, 999.999, 999.999, 999.999, 999.999, 999.999,
|
||||
999.999, 999.999, 999.999;
|
||||
|
||||
b << 0.0, 0.0, 0.0, 0.0, 0.0, 0.03062277660168379, 0.03062277660168379,
|
||||
0.03062277660168379, 0.03062277660168379, 0.03062277660168379, 0.999,
|
||||
0.999, 0.999, 0.999, 0.999, 31.62177660168379, 31.62177660168379,
|
||||
31.62177660168379, 31.62177660168379, 31.62177660168379, 999.999,
|
||||
999.999, 999.999, 999.999, 999.999, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
|
||||
0.03062277660168379, 0.03062277660168379, 0.999, 0.999, 0.999, 0.999,
|
||||
0.999, 31.62177660168379, 31.62177660168379, 31.62177660168379,
|
||||
31.62177660168379, 31.62177660168379, 999.999, 999.999, 999.999,
|
||||
999.999, 999.999, 0.0, 0.0, 0.0, 0.0, 0.0, 0.03062277660168379,
|
||||
0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
|
||||
0.03062277660168379, 0.999, 0.999, 0.999, 0.999, 0.999,
|
||||
31.62177660168379, 31.62177660168379, 31.62177660168379,
|
||||
31.62177660168379, 31.62177660168379, 999.999, 999.999, 999.999,
|
||||
999.999, 999.999, 0.0, 0.0, 0.0, 0.0, 0.0, 0.03062277660168379,
|
||||
0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
|
||||
0.03062277660168379, 0.999, 0.999, 0.999, 0.999, 0.999,
|
||||
31.62177660168379, 31.62177660168379, 31.62177660168379,
|
||||
31.62177660168379, 31.62177660168379, 999.999, 999.999, 999.999,
|
||||
999.999, 999.999, 0.0, 0.0, 0.0, 0.0, 0.0, 0.03062277660168379,
|
||||
0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
|
||||
0.03062277660168379, 0.999, 0.999, 0.999, 0.999, 0.999,
|
||||
31.62177660168379, 31.62177660168379, 31.62177660168379,
|
||||
31.62177660168379, 31.62177660168379, 999.999, 999.999, 999.999,
|
||||
999.999, 999.999;
|
||||
|
||||
x << -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5,
|
||||
0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2,
|
||||
0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1,
|
||||
0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1,
|
||||
-0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8,
|
||||
1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5,
|
||||
0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2,
|
||||
0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1,
|
||||
0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5,
|
||||
0.8, 1.1;
|
||||
|
||||
v << nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
|
||||
nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
|
||||
nan, nan, nan, 0.47972119876364683, 0.5, 0.5202788012363533, nan, nan,
|
||||
0.9518683957740043, 0.9789663010413743, 0.9931729188073435, nan, nan,
|
||||
0.999995949033062, 0.9999999999993698, 0.9999999999999999, nan, nan,
|
||||
0.9999999999999999, 0.9999999999999999, 0.9999999999999999, nan, nan,
|
||||
nan, nan, nan, nan, nan, 0.006827081192655869, 0.0210336989586256,
|
||||
0.04813160422599567, nan, nan, 0.20014344256217678, 0.5000000000000001,
|
||||
0.7998565574378232, nan, nan, 0.9991401428435834, 0.999999999698403,
|
||||
0.9999999999999999, nan, nan, 0.9999999999999999, 0.9999999999999999,
|
||||
0.9999999999999999, nan, nan, nan, nan, nan, nan, nan,
|
||||
1.0646600232370887e-25, 6.301722877826246e-13, 4.050966937974938e-06,
|
||||
nan, nan, 7.864342668429763e-23, 3.015969667594166e-10,
|
||||
0.0008598571564165444, nan, nan, 6.031987710123844e-08,
|
||||
0.5000000000000007, 0.9999999396801229, nan, nan, 0.9999999999999999,
|
||||
0.9999999999999999, 0.9999999999999999, nan, nan, nan, nan, nan, nan,
|
||||
nan, 0.0, 7.029920380986636e-306, 2.2450728208591345e-101, nan, nan,
|
||||
0.0, 9.275871147869727e-302, 1.2232913026152827e-97, nan, nan, 0.0,
|
||||
3.0891393081932924e-252, 2.9303043666183996e-60, nan, nan,
|
||||
2.248913486879199e-196, 0.5000000000004947, 0.9999999999999999, nan;
|
||||
|
||||
CALL_SUBTEST(res = betainc(a, b, x);
|
||||
verify_component_wise(res, v););
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
void test_array()
|
||||
|
@ -80,6 +80,7 @@ typedef unsigned __int64 uint64_t;
|
||||
#include "src/Tensor/TensorTraits.h"
|
||||
#include "src/Tensor/TensorUInt128.h"
|
||||
#include "src/Tensor/TensorIntDiv.h"
|
||||
#include "src/Tensor/TensorGlobalFunctions.h"
|
||||
|
||||
#include "src/Tensor/TensorBase.h"
|
||||
|
||||
|
@ -307,7 +307,6 @@ class TensorBase<Derived, ReadOnlyAccessors>
|
||||
return unaryExpr(internal::scalar_floor_op<Scalar>());
|
||||
}
|
||||
|
||||
|
||||
// Generic binary operation support.
|
||||
template <typename CustomBinaryOp, typename OtherDerived> EIGEN_DEVICE_FUNC
|
||||
EIGEN_STRONG_INLINE const TensorCwiseBinaryOp<CustomBinaryOp, const Derived, const OtherDerived>
|
||||
|
@ -403,6 +403,87 @@ struct TensorEvaluator<const TensorCwiseBinaryOp<BinaryOp, LeftArgType, RightArg
|
||||
TensorEvaluator<RightArgType, Device> m_rightImpl;
|
||||
};
|
||||
|
||||
// -------------------- CwiseTernaryOp --------------------
|
||||
|
||||
template<typename TernaryOp, typename Arg1Type, typename Arg2Type, typename Arg3Type, typename Device>
|
||||
struct TensorEvaluator<const TensorCwiseTernaryOp<TernaryOp, Arg1Type, Arg2Type, Arg3Type>, Device>
|
||||
{
|
||||
typedef TensorCwiseTernaryOp<TernaryOp, Arg1Type, Arg2Type, Arg3Type> XprType;
|
||||
|
||||
enum {
|
||||
IsAligned = TensorEvaluator<Arg1Type, Device>::IsAligned & TensorEvaluator<Arg2Type, Device>::IsAligned & TensorEvaluator<Arg3Type, Device>::IsAligned,
|
||||
PacketAccess = TensorEvaluator<Arg1Type, Device>::PacketAccess & TensorEvaluator<Arg2Type, Device>::PacketAccess & TensorEvaluator<Arg3Type, Device>::PacketAccess &
|
||||
internal::functor_traits<TernaryOp>::PacketAccess,
|
||||
Layout = TensorEvaluator<Arg1Type, Device>::Layout,
|
||||
CoordAccess = false, // to be implemented
|
||||
RawAccess = false
|
||||
};
|
||||
|
||||
EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device)
|
||||
: m_functor(op.functor()),
|
||||
m_arg1Impl(op.arg1Expression(), device),
|
||||
m_arg2Impl(op.arg2Expression(), device),
|
||||
m_arg3Impl(op.arg3Expression(), device)
|
||||
{
|
||||
EIGEN_STATIC_ASSERT((static_cast<int>(TensorEvaluator<Arg1Type, Device>::Layout) == static_cast<int>(TensorEvaluator<Arg3Type, Device>::Layout) || internal::traits<XprType>::NumDimensions <= 1), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||
eigen_assert(dimensions_match(m_arg1Impl.dimensions(), m_arg2Impl.dimensions()) && dimensions_match(m_arg1Impl.dimensions(), m_arg3Impl.dimensions()));
|
||||
}
|
||||
|
||||
typedef typename XprType::Index Index;
|
||||
typedef typename XprType::Scalar Scalar;
|
||||
typedef typename internal::traits<XprType>::Scalar CoeffReturnType;
|
||||
typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
|
||||
static const int PacketSize = internal::unpacket_traits<PacketReturnType>::size;
|
||||
typedef typename TensorEvaluator<Arg1Type, Device>::Dimensions Dimensions;
|
||||
|
||||
EIGEN_DEVICE_FUNC const Dimensions& dimensions() const
|
||||
{
|
||||
// TODO: use arg2 or arg3 dimensions if they are known at compile time.
|
||||
return m_arg1Impl.dimensions();
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(CoeffReturnType*) {
|
||||
m_arg1Impl.evalSubExprsIfNeeded(NULL);
|
||||
m_arg2Impl.evalSubExprsIfNeeded(NULL);
|
||||
m_arg3Impl.evalSubExprsIfNeeded(NULL);
|
||||
return true;
|
||||
}
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
|
||||
m_arg1Impl.cleanup();
|
||||
m_arg2Impl.cleanup();
|
||||
m_arg3Impl.cleanup();
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index index) const
|
||||
{
|
||||
return m_functor(m_arg1Impl.coeff(index), m_arg2Impl.coeff(index), m_arg3Impl.coeff(index));
|
||||
}
|
||||
template<int LoadMode>
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const
|
||||
{
|
||||
return m_functor.packetOp(m_arg1Impl.template packet<LoadMode>(index),
|
||||
m_arg2Impl.template packet<LoadMode>(index),
|
||||
m_arg3Impl.template packet<LoadMode>(index));
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost
|
||||
costPerCoeff(bool vectorized) const {
|
||||
const double functor_cost = internal::functor_traits<TernaryOp>::Cost;
|
||||
return m_arg1Impl.costPerCoeff(vectorized) +
|
||||
m_arg2Impl.costPerCoeff(vectorized) +
|
||||
m_arg3Impl.costPerCoeff(vectorized) +
|
||||
TensorOpCost(0, 0, functor_cost, vectorized, PacketSize);
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC CoeffReturnType* data() const { return NULL; }
|
||||
|
||||
private:
|
||||
const TernaryOp m_functor;
|
||||
TensorEvaluator<Arg1Type, Device> m_arg1Impl;
|
||||
TensorEvaluator<Arg1Type, Device> m_arg2Impl;
|
||||
TensorEvaluator<Arg3Type, Device> m_arg3Impl;
|
||||
};
|
||||
|
||||
|
||||
// -------------------- SelectOp --------------------
|
||||
|
||||
|
@ -218,6 +218,102 @@ class TensorCwiseBinaryOp : public TensorBase<TensorCwiseBinaryOp<BinaryOp, LhsX
|
||||
};
|
||||
|
||||
|
||||
namespace internal {
|
||||
template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType>
|
||||
struct traits<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType> >
|
||||
{
|
||||
// Type promotion to handle the case where the types of the args are different.
|
||||
typedef typename result_of<
|
||||
TernaryOp(typename Arg1XprType::Scalar,
|
||||
typename Arg2XprType::Scalar,
|
||||
typename Arg3XprType::Scalar)>::type Scalar;
|
||||
EIGEN_STATIC_ASSERT(
|
||||
(internal::is_same<typename traits<Arg1XprType>::StorageKind,
|
||||
typename traits<Arg2XprType>::StorageKind>::value),
|
||||
STORAGE_KIND_MUST_MATCH)
|
||||
EIGEN_STATIC_ASSERT(
|
||||
(internal::is_same<typename traits<Arg1XprType>::StorageKind,
|
||||
typename traits<Arg3XprType>::StorageKind>::value),
|
||||
STORAGE_KIND_MUST_MATCH)
|
||||
EIGEN_STATIC_ASSERT(
|
||||
(internal::is_same<typename traits<Arg1XprType>::Index,
|
||||
typename traits<Arg2XprType>::Index>::value),
|
||||
STORAGE_INDEX_MUST_MATCH)
|
||||
EIGEN_STATIC_ASSERT(
|
||||
(internal::is_same<typename traits<Arg1XprType>::Index,
|
||||
typename traits<Arg3XprType>::Index>::value),
|
||||
STORAGE_INDEX_MUST_MATCH)
|
||||
typedef traits<Arg1XprType> XprTraits;
|
||||
typedef typename traits<Arg1XprType>::StorageKind StorageKind;
|
||||
typedef typename traits<Arg1XprType>::Index Index;
|
||||
typedef typename Arg1XprType::Nested Arg1Nested;
|
||||
typedef typename Arg2XprType::Nested Arg2Nested;
|
||||
typedef typename Arg3XprType::Nested Arg3Nested;
|
||||
typedef typename remove_reference<Arg1Nested>::type _Arg1Nested;
|
||||
typedef typename remove_reference<Arg2Nested>::type _Arg2Nested;
|
||||
typedef typename remove_reference<Arg3Nested>::type _Arg3Nested;
|
||||
static const int NumDimensions = XprTraits::NumDimensions;
|
||||
static const int Layout = XprTraits::Layout;
|
||||
|
||||
enum {
|
||||
Flags = 0
|
||||
};
|
||||
};
|
||||
|
||||
template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType>
|
||||
struct eval<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>, Eigen::Dense>
|
||||
{
|
||||
typedef const TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>& type;
|
||||
};
|
||||
|
||||
template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType>
|
||||
struct nested<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>, 1, typename eval<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType> >::type>
|
||||
{
|
||||
typedef TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType> type;
|
||||
};
|
||||
|
||||
} // end namespace internal
|
||||
|
||||
|
||||
|
||||
template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType>
|
||||
class TensorCwiseTernaryOp : public TensorBase<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>, ReadOnlyAccessors>
|
||||
{
|
||||
public:
|
||||
typedef typename Eigen::internal::traits<TensorCwiseTernaryOp>::Scalar Scalar;
|
||||
typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
|
||||
typedef Scalar CoeffReturnType;
|
||||
typedef typename Eigen::internal::nested<TensorCwiseTernaryOp>::type Nested;
|
||||
typedef typename Eigen::internal::traits<TensorCwiseTernaryOp>::StorageKind StorageKind;
|
||||
typedef typename Eigen::internal::traits<TensorCwiseTernaryOp>::Index Index;
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseTernaryOp(const Arg1XprType& arg1, const Arg2XprType& arg2, const Arg3XprType& arg3, const TernaryOp& func = TernaryOp())
|
||||
: m_arg1_xpr(arg1), m_arg2_xpr(arg2), m_arg3_xpr(arg3), m_functor(func) {}
|
||||
|
||||
EIGEN_DEVICE_FUNC
|
||||
const TernaryOp& functor() const { return m_functor; }
|
||||
|
||||
/** \returns the nested expressions */
|
||||
EIGEN_DEVICE_FUNC
|
||||
const typename internal::remove_all<typename Arg1XprType::Nested>::type&
|
||||
arg1Expression() const { return m_arg1_xpr; }
|
||||
|
||||
EIGEN_DEVICE_FUNC
|
||||
const typename internal::remove_all<typename Arg1XprType::Nested>::type&
|
||||
arg2Expression() const { return m_arg2_xpr; }
|
||||
|
||||
EIGEN_DEVICE_FUNC
|
||||
const typename internal::remove_all<typename Arg3XprType::Nested>::type&
|
||||
arg3Expression() const { return m_arg3_xpr; }
|
||||
|
||||
protected:
|
||||
typename Arg1XprType::Nested m_arg1_xpr;
|
||||
typename Arg1XprType::Nested m_arg2_xpr;
|
||||
typename Arg3XprType::Nested m_arg3_xpr;
|
||||
const TernaryOp m_functor;
|
||||
};
|
||||
|
||||
|
||||
namespace internal {
|
||||
template<typename IfXprType, typename ThenXprType, typename ElseXprType>
|
||||
struct traits<TensorSelectOp<IfXprType, ThenXprType, ElseXprType> >
|
||||
|
@ -21,6 +21,7 @@ template<typename Derived, int AccessLevel = internal::accessors_level<Derived>:
|
||||
template<typename NullaryOp, typename PlainObjectType> class TensorCwiseNullaryOp;
|
||||
template<typename UnaryOp, typename XprType> class TensorCwiseUnaryOp;
|
||||
template<typename BinaryOp, typename LeftXprType, typename RightXprType> class TensorCwiseBinaryOp;
|
||||
template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType> class TensorCwiseTernaryOp;
|
||||
template<typename IfXprType, typename ThenXprType, typename ElseXprType> class TensorSelectOp;
|
||||
template<typename Op, typename Dims, typename XprType> class TensorReductionOp;
|
||||
template<typename XprType> class TensorIndexTupleOp;
|
||||
|
33
unsupported/Eigen/CXX11/src/Tensor/TensorGlobalFunctions.h
Normal file
33
unsupported/Eigen/CXX11/src/Tensor/TensorGlobalFunctions.h
Normal file
@ -0,0 +1,33 @@
|
||||
// This file is part of Eigen, a lightweight C++ template library
|
||||
// for linear algebra.
|
||||
//
|
||||
// Copyright (C) 2016 Eugene Brevdo <ebrevdo@gmail.com>
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla
|
||||
// Public License v. 2.0. If a copy of the MPL was not distributed
|
||||
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
#ifndef EIGEN_CXX11_TENSOR_TENSOR_GLOBAL_FUNCTIONS_H
|
||||
#define EIGEN_CXX11_TENSOR_TENSOR_GLOBAL_FUNCTIONS_H
|
||||
|
||||
namespace Eigen {
|
||||
|
||||
/** \cpp11 \returns an expression of the coefficient-wise betainc(\a x, \a a, \a b) to the given tensors.
|
||||
*
|
||||
* This function computes the regularized incomplete beta function (integral).
|
||||
*
|
||||
*/
|
||||
template <typename ADerived, typename BDerived, typename XDerived>
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const
|
||||
TensorCwiseTernaryOp<internal::scalar_betainc_op<typename XDerived::Scalar>,
|
||||
const ADerived, const BDerived, const XDerived>
|
||||
betainc(const ADerived& a, const BDerived& b, const XDerived& x) {
|
||||
return TensorCwiseTernaryOp<
|
||||
internal::scalar_betainc_op<typename XDerived::Scalar>, const ADerived,
|
||||
const BDerived, const XDerived>(
|
||||
a, b, x, internal::scalar_betainc_op<typename XDerived::Scalar>());
|
||||
}
|
||||
|
||||
} // end namespace Eigen
|
||||
|
||||
#endif // EIGEN_CXX11_TENSOR_TENSOR_GLOBAL_FUNCTIONS_H
|
@ -1019,6 +1019,153 @@ void test_cuda_erfc(const Scalar stddev)
|
||||
cudaFree(d_out);
|
||||
}
|
||||
|
||||
template <typename Scalar>
|
||||
void test_cuda_betainc()
|
||||
{
|
||||
Tensor<Scalar, 1> in_x(125);
|
||||
Tensor<Scalar, 1> in_a(125);
|
||||
Tensor<Scalar, 1> in_b(125);
|
||||
Tensor<Scalar, 1> out(125);
|
||||
Tensor<Scalar, 1> expected_out(125);
|
||||
out.setZero();
|
||||
|
||||
Scalar nan = std::numeric_limits<Scalar>::quiet_NaN();
|
||||
|
||||
Array<Scalar, 1, Dynamic> x(125);
|
||||
Array<Scalar, 1, Dynamic> a(125);
|
||||
Array<Scalar, 1, Dynamic> b(125);
|
||||
Array<Scalar, 1, Dynamic> v(125);
|
||||
|
||||
a << 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
|
||||
0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
|
||||
0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
|
||||
0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
|
||||
0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
|
||||
0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
|
||||
0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
|
||||
0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
|
||||
0.03062277660168379, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999,
|
||||
0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999,
|
||||
0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 31.62177660168379,
|
||||
31.62177660168379, 31.62177660168379, 31.62177660168379,
|
||||
31.62177660168379, 31.62177660168379, 31.62177660168379,
|
||||
31.62177660168379, 31.62177660168379, 31.62177660168379,
|
||||
31.62177660168379, 31.62177660168379, 31.62177660168379,
|
||||
31.62177660168379, 31.62177660168379, 31.62177660168379,
|
||||
31.62177660168379, 31.62177660168379, 31.62177660168379,
|
||||
31.62177660168379, 31.62177660168379, 31.62177660168379,
|
||||
31.62177660168379, 31.62177660168379, 31.62177660168379, 999.999, 999.999,
|
||||
999.999, 999.999, 999.999, 999.999, 999.999, 999.999, 999.999, 999.999,
|
||||
999.999, 999.999, 999.999, 999.999, 999.999, 999.999, 999.999, 999.999,
|
||||
999.999, 999.999, 999.999, 999.999, 999.999, 999.999, 999.999;
|
||||
|
||||
b << 0.0, 0.0, 0.0, 0.0, 0.0, 0.03062277660168379, 0.03062277660168379,
|
||||
0.03062277660168379, 0.03062277660168379, 0.03062277660168379, 0.999,
|
||||
0.999, 0.999, 0.999, 0.999, 31.62177660168379, 31.62177660168379,
|
||||
31.62177660168379, 31.62177660168379, 31.62177660168379, 999.999, 999.999,
|
||||
999.999, 999.999, 999.999, 0.0, 0.0, 0.0, 0.0, 0.0, 0.03062277660168379,
|
||||
0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
|
||||
0.03062277660168379, 0.999, 0.999, 0.999, 0.999, 0.999, 31.62177660168379,
|
||||
31.62177660168379, 31.62177660168379, 31.62177660168379,
|
||||
31.62177660168379, 999.999, 999.999, 999.999, 999.999, 999.999, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 0.03062277660168379, 0.03062277660168379,
|
||||
0.03062277660168379, 0.03062277660168379, 0.03062277660168379, 0.999,
|
||||
0.999, 0.999, 0.999, 0.999, 31.62177660168379, 31.62177660168379,
|
||||
31.62177660168379, 31.62177660168379, 31.62177660168379, 999.999, 999.999,
|
||||
999.999, 999.999, 999.999, 0.0, 0.0, 0.0, 0.0, 0.0, 0.03062277660168379,
|
||||
0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
|
||||
0.03062277660168379, 0.999, 0.999, 0.999, 0.999, 0.999, 31.62177660168379,
|
||||
31.62177660168379, 31.62177660168379, 31.62177660168379,
|
||||
31.62177660168379, 999.999, 999.999, 999.999, 999.999, 999.999, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 0.03062277660168379, 0.03062277660168379,
|
||||
0.03062277660168379, 0.03062277660168379, 0.03062277660168379, 0.999,
|
||||
0.999, 0.999, 0.999, 0.999, 31.62177660168379, 31.62177660168379,
|
||||
31.62177660168379, 31.62177660168379, 31.62177660168379, 999.999, 999.999,
|
||||
999.999, 999.999, 999.999;
|
||||
|
||||
x << -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8,
|
||||
1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5,
|
||||
0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2,
|
||||
0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1,
|
||||
0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1,
|
||||
-0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8,
|
||||
1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5,
|
||||
0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2,
|
||||
0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1;
|
||||
|
||||
v << nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
|
||||
nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
|
||||
nan, nan, 0.47972119876364683, 0.5, 0.5202788012363533, nan, nan,
|
||||
0.9518683957740043, 0.9789663010413743, 0.9931729188073435, nan, nan,
|
||||
0.999995949033062, 0.9999999999993698, 0.9999999999999999, nan, nan,
|
||||
0.9999999999999999, 0.9999999999999999, 0.9999999999999999, nan, nan, nan,
|
||||
nan, nan, nan, nan, 0.006827081192655869, 0.0210336989586256,
|
||||
0.04813160422599567, nan, nan, 0.20014344256217678, 0.5000000000000001,
|
||||
0.7998565574378232, nan, nan, 0.9991401428435834, 0.999999999698403,
|
||||
0.9999999999999999, nan, nan, 0.9999999999999999, 0.9999999999999999,
|
||||
0.9999999999999999, nan, nan, nan, nan, nan, nan, nan,
|
||||
1.0646600232370887e-25, 6.301722877826246e-13, 4.050966937974938e-06, nan,
|
||||
nan, 7.864342668429763e-23, 3.015969667594166e-10, 0.0008598571564165444,
|
||||
nan, nan, 6.031987710123844e-08, 0.5000000000000007, 0.9999999396801229,
|
||||
nan, nan, 0.9999999999999999, 0.9999999999999999, 0.9999999999999999, nan,
|
||||
nan, nan, nan, nan, nan, nan, 0.0, 7.029920380986636e-306,
|
||||
2.2450728208591345e-101, nan, nan, 0.0, 9.275871147869727e-302,
|
||||
1.2232913026152827e-97, nan, nan, 0.0, 3.0891393081932924e-252,
|
||||
2.9303043666183996e-60, nan, nan, 2.248913486879199e-196,
|
||||
0.5000000000004947, 0.9999999999999999, nan;
|
||||
|
||||
for (int i = 0; i < 125; ++i) {
|
||||
in_x(i) = x(i);
|
||||
in_a(i) = a(i);
|
||||
in_b(i) = b(i);
|
||||
expected_out(i) = v(i);
|
||||
}
|
||||
|
||||
std::size_t bytes = in_x.size() * sizeof(Scalar);
|
||||
|
||||
Scalar* d_in_x;
|
||||
Scalar* d_in_a;
|
||||
Scalar* d_in_b;
|
||||
Scalar* d_out;
|
||||
cudaMalloc((void**)(&d_in_x), bytes);
|
||||
cudaMalloc((void**)(&d_in_a), bytes);
|
||||
cudaMalloc((void**)(&d_in_b), bytes);
|
||||
cudaMalloc((void**)(&d_out), bytes);
|
||||
|
||||
cudaMemcpy(d_in_x, in_x.data(), bytes, cudaMemcpyHostToDevice);
|
||||
cudaMemcpy(d_in_a, in_a.data(), bytes, cudaMemcpyHostToDevice);
|
||||
cudaMemcpy(d_in_b, in_b.data(), bytes, cudaMemcpyHostToDevice);
|
||||
|
||||
Eigen::CudaStreamDevice stream;
|
||||
Eigen::GpuDevice gpu_device(&stream);
|
||||
|
||||
Eigen::TensorMap<Eigen::Tensor<Scalar, 1> > gpu_in_x(d_in_x, 125);
|
||||
Eigen::TensorMap<Eigen::Tensor<Scalar, 1> > gpu_in_a(d_in_a, 125);
|
||||
Eigen::TensorMap<Eigen::Tensor<Scalar, 1> > gpu_in_b(d_in_b, 125);
|
||||
Eigen::TensorMap<Eigen::Tensor<Scalar, 1> > gpu_out(d_out, 125);
|
||||
|
||||
gpu_out.device(gpu_device) = betainc(gpu_in_a, gpu_in_b, gpu_in_x);
|
||||
|
||||
assert(cudaMemcpyAsync(out.data(), d_out, bytes, cudaMemcpyDeviceToHost, gpu_device.stream()) == cudaSuccess);
|
||||
assert(cudaStreamSynchronize(gpu_device.stream()) == cudaSuccess);
|
||||
|
||||
for (int i = 1; i < 125; ++i) {
|
||||
if ((std::isnan)(expected_out(i))) {
|
||||
VERIFY((std::isnan)(out(i)));
|
||||
} else {
|
||||
VERIFY_IS_APPROX(out(i), expected_out(i));
|
||||
}
|
||||
}
|
||||
|
||||
cudaFree(d_in_x);
|
||||
cudaFree(d_in_a);
|
||||
cudaFree(d_in_b);
|
||||
cudaFree(d_out);
|
||||
}
|
||||
|
||||
|
||||
void test_cxx11_tensor_cuda()
|
||||
{
|
||||
CALL_SUBTEST_1(test_cuda_elementwise_small());
|
||||
@ -1086,5 +1233,8 @@ void test_cxx11_tensor_cuda()
|
||||
|
||||
CALL_SUBTEST_5(test_cuda_igamma<double>());
|
||||
CALL_SUBTEST_5(test_cuda_igammac<double>());
|
||||
|
||||
CALL_SUBTEST_6(test_cuda_betainc<float>());
|
||||
CALL_SUBTEST_6(test_cuda_betainc<double>());
|
||||
#endif
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user