Fixes #2735: Component-wise cbrt

This commit is contained in:
Kyle Macfarlan 2023-10-25 03:06:13 +00:00 committed by Charles Schlosser
parent 48b254a4bc
commit 5de0f2f89e
14 changed files with 87 additions and 2 deletions

View File

@ -140,6 +140,7 @@ EIGEN_MKL_VML_DECLARE_UNARY_CALLS_CPLX(arg, Arg, _)
EIGEN_MKL_VML_DECLARE_UNARY_CALLS_REAL(round, Round, _)
EIGEN_MKL_VML_DECLARE_UNARY_CALLS_REAL(floor, Floor, _)
EIGEN_MKL_VML_DECLARE_UNARY_CALLS_REAL(ceil, Ceil, _)
EIGEN_MKL_VML_DECLARE_UNARY_CALLS_REAL(cbrt, Cbrt, _)
#define EIGEN_MKL_VML_DECLARE_POW_CALL(EIGENOP, VMLOP, EIGENTYPE, VMLTYPE, VMLMODE) \
template< typename DstXprType, typename SrcXprNested, typename Plain> \

View File

@ -1017,6 +1017,10 @@ Packet plog2(const Packet& a) {
template<typename Packet> EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
Packet psqrt(const Packet& a) { return numext::sqrt(a); }
/** \internal \returns the cube-root of \a a (coeff-wise) */
template<typename Packet> EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
Packet pcbrt(const Packet& a) { return numext::cbrt(a); }
/** \internal \returns the rounded value of \a a (coeff-wise) */
template<typename Packet> EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
Packet pround(const Packet& a) { using numext::round; return round(a); }

View File

@ -89,6 +89,7 @@ namespace Eigen
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(arg,scalar_arg_op,complex argument,\sa ArrayBase::arg DOXCOMMA MatrixBase::cwiseArg)
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(carg, scalar_carg_op, complex argument, \sa ArrayBase::carg DOXCOMMA MatrixBase::cwiseCArg)
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(sqrt,scalar_sqrt_op,square root,\sa ArrayBase::sqrt DOXCOMMA MatrixBase::cwiseSqrt)
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(cbrt,scalar_cbrt_op,cube root,\sa ArrayBase::cbrt DOXCOMMA MatrixBase::cwiseCbrt)
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(rsqrt,scalar_rsqrt_op,reciprocal square root,\sa ArrayBase::rsqrt)
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(square,scalar_square_op,square (power 2),\sa Eigen::abs2 DOXCOMMA Eigen::pow DOXCOMMA ArrayBase::square)
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(cube,scalar_cube_op,cube (power 3),\sa Eigen::pow DOXCOMMA ArrayBase::cube)

View File

@ -1394,6 +1394,14 @@ bool sqrt<bool>(const bool &x) { return x; }
SYCL_SPECIALIZE_FLOATING_TYPES_UNARY(sqrt, sqrt)
#endif
/** \returns the cube root of \a x. **/
template<typename T>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
T cbrt(const T &x) {
EIGEN_USING_STD(cbrt);
return static_cast<T>(cbrt(x));
}
/** \returns the reciprocal square root of \a x. **/
template<typename T>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE

View File

@ -471,6 +471,20 @@ struct functor_traits<scalar_sqrt_op<bool> > {
enum { Cost = 1, PacketAccess = packet_traits<bool>::Vectorizable };
};
/** \internal
* \brief Template functor to compute the cube root of a scalar
* \sa class CwiseUnaryOp, Cwise::sqrt()
*/
template <typename Scalar>
struct scalar_cbrt_op {
EIGEN_DEVICE_FUNC inline const Scalar operator()(const Scalar& a) const { return numext::cbrt(a); }
};
template <typename Scalar>
struct functor_traits<scalar_cbrt_op<Scalar> > {
enum { Cost = 5 * NumTraits<Scalar>::MulCost, PacketAccess = false };
};
/** \internal
* \brief Template functor to compute the reciprocal square root of a scalar
* \sa class CwiseUnaryOp, Cwise::rsqrt()

View File

@ -185,6 +185,7 @@ template<typename Scalar> struct scalar_abs_op;
template<typename Scalar> struct scalar_abs2_op;
template<typename LhsScalar,typename RhsScalar=LhsScalar> struct scalar_absolute_difference_op;
template<typename Scalar> struct scalar_sqrt_op;
template<typename Scalar> struct scalar_cbrt_op;
template<typename Scalar> struct scalar_rsqrt_op;
template<typename Scalar> struct scalar_exp_op;
template<typename Scalar> struct scalar_log_op;

View File

@ -5,6 +5,7 @@ typedef CwiseUnaryOp<internal::scalar_arg_op<Scalar>, const Derived> ArgReturnTy
typedef CwiseUnaryOp<internal::scalar_carg_op<Scalar>, const Derived> CArgReturnType;
typedef CwiseUnaryOp<internal::scalar_abs2_op<Scalar>, const Derived> Abs2ReturnType;
typedef CwiseUnaryOp<internal::scalar_sqrt_op<Scalar>, const Derived> SqrtReturnType;
typedef CwiseUnaryOp<internal::scalar_cbrt_op<Scalar>, const Derived> CbrtReturnType;
typedef CwiseUnaryOp<internal::scalar_rsqrt_op<Scalar>, const Derived> RsqrtReturnType;
typedef CwiseUnaryOp<internal::scalar_sign_op<Scalar>, const Derived> SignReturnType;
typedef CwiseUnaryOp<internal::scalar_inverse_op<Scalar>, const Derived> InverseReturnType;
@ -184,7 +185,7 @@ log2() const
* Example: \include Cwise_sqrt.cpp
* Output: \verbinclude Cwise_sqrt.out
*
* \sa <a href="group__CoeffwiseMathFunctions.html#cwisetable_sqrt">Math functions</a>, pow(), square()
* \sa <a href="group__CoeffwiseMathFunctions.html#cwisetable_sqrt">Math functions</a>, pow(), square(), cbrt()
*/
EIGEN_DEVICE_FUNC
inline const SqrtReturnType
@ -193,6 +194,22 @@ sqrt() const
return SqrtReturnType(derived());
}
/** \returns an expression of the coefficient-wise cube root of *this.
*
* This function computes the coefficient-wise cube root.
*
* Example: \include Cwise_cbrt.cpp
* Output: \verbinclude Cwise_cbrt.out
*
* \sa <a href="group__CoeffwiseMathFunctions.html#cwisetable_cbrt">Math functions</a>, sqrt(), pow(), square()
*/
EIGEN_DEVICE_FUNC
inline const CbrtReturnType
cbrt() const
{
return CbrtReturnType(derived());
}
/** \returns an expression of the coefficient-wise inverse square root of *this.
*
* This function computes the coefficient-wise inverse square root.

View File

@ -17,6 +17,7 @@ typedef CwiseUnaryOp<internal::scalar_abs2_op<Scalar>, const Derived> CwiseAbs2R
typedef CwiseUnaryOp<internal::scalar_arg_op<Scalar>, const Derived> CwiseArgReturnType;
typedef CwiseUnaryOp<internal::scalar_carg_op<Scalar>, const Derived> CwiseCArgReturnType;
typedef CwiseUnaryOp<internal::scalar_sqrt_op<Scalar>, const Derived> CwiseSqrtReturnType;
typedef CwiseUnaryOp<internal::scalar_cbrt_op<Scalar>, const Derived> CwiseCbrtReturnType;
typedef CwiseUnaryOp<internal::scalar_sign_op<Scalar>, const Derived> CwiseSignReturnType;
typedef CwiseUnaryOp<internal::scalar_inverse_op<Scalar>, const Derived> CwiseInverseReturnType;
@ -53,12 +54,25 @@ cwiseAbs2() const { return CwiseAbs2ReturnType(derived()); }
///
EIGEN_DOC_UNARY_ADDONS(cwiseSqrt,square-root)
///
/// \sa cwisePow(), cwiseSquare()
/// \sa cwisePow(), cwiseSquare(), cwiseCbrt()
///
EIGEN_DEVICE_FUNC
inline const CwiseSqrtReturnType
cwiseSqrt() const { return CwiseSqrtReturnType(derived()); }
/// \returns an expression of the coefficient-wise cube root of *this.
///
/// Example: \include MatrixBase_cwiseCbrt.cpp
/// Output: \verbinclude MatrixBase_cwiseCbrt.out
///
EIGEN_DOC_UNARY_ADDONS(cwiseCbrt,cube-root)
///
/// \sa cwiseSqrt(), cwiseSquare(), cwisePow()
///
EIGEN_DEVICE_FUNC
inline const CwiseCbrtReturnType
cwiseCbrt() const { return CwiseSCbrtReturnType(derived()); }
/// \returns an expression of the coefficient-wise signum of *this.
///
/// Example: \include MatrixBase_cwiseSign.cpp

View File

@ -140,6 +140,8 @@ R.array().square() // P .^ 2
R.array().cube() // P .^ 3
R.cwiseSqrt() // sqrt(P)
R.array().sqrt() // sqrt(P)
R.cwiseCbrt() // cbrt(P)
R.array().cbrt() // cbrt(P)
R.array().exp() // exp(P)
R.array().log() // log(P)
R.cwiseMax(P) // max(R, P)

View File

@ -170,6 +170,19 @@ This also means that, unless specified, if the function \c std::foo is available
sqrt(a[i]);</td>
<td>SSE2, AVX (f,d)</td>
</tr>
<tr>
<td class="code">
\anchor cwisetable_cbrt
a.\link ArrayBase::cbrt cbrt\endlink(); \n
\link Eigen::cbrt cbrt\endlink(a);\n
m.\link MatrixBase::cwiseCbrt cwiseCbrt\endlink();
</td>
<td>computes cube root (\f$ \cbrt a_i \f$)</td>
<td class="code">
using <a href="http://en.cppreference.com/w/cpp/numeric/math/cbrt">std::cbrt</a>; \n
cbrt(a[i]);</td>
<td></td>
</tr>
<tr>
<td class="code">
\anchor cwisetable_rsqrt

View File

@ -458,6 +458,7 @@ mat1.cwiseMax(mat2) mat1.cwiseMax(scalar)
mat1.cwiseAbs2()
mat1.cwiseAbs()
mat1.cwiseSqrt()
mat1.cwiseCbrt()
mat1.cwiseInverse()
mat1.cwiseProduct(mat2)
mat1.cwiseQuotient(mat2)
@ -470,6 +471,7 @@ mat1.array().max(mat2.array()) mat1.array().max(scalar)
mat1.array().abs2()
mat1.array().abs()
mat1.array().sqrt()
mat1.array().cbrt()
mat1.array().inverse()
mat1.array() * mat2.array()
mat1.array() / mat2.array()

View File

@ -172,6 +172,7 @@ sm2 = perm * sm1; // Permute the columns
sm1.cwiseMax(sm2);
sm1.cwiseAbs();
sm1.cwiseSqrt();
sm1.cwiseCbrt();
\endcode</td>
<td>
sm1 and sm2 should have the same storage order

View File

@ -0,0 +1,2 @@
Array3d v(1,2,4);
cout << v.cbrt() << endl;

View File

@ -169,7 +169,9 @@ void unary_op_test(std::string name, Fn fun, RefFn ref) {
template <typename Scalar>
void unary_ops_test() {
unary_op_test<Scalar>(UNARY_FUNCTOR_TEST_ARGS(sqrt));
unary_op_test<Scalar>(UNARY_FUNCTOR_TEST_ARGS(cbrt));
unary_op_test<Scalar>(UNARY_FUNCTOR_TEST_ARGS(exp));
unary_op_test<Scalar>(UNARY_FUNCTOR_TEST_ARGS(log));
unary_op_test<Scalar>(UNARY_FUNCTOR_TEST_ARGS(sin));
@ -821,6 +823,7 @@ template<typename ArrayType> void array_real(const ArrayType& m)
m3 = m4.abs();
VERIFY_IS_APPROX(m3.sqrt(), sqrt(abs(m3)));
VERIFY_IS_APPROX(m3.cbrt(), cbrt(m3));
VERIFY_IS_APPROX(m3.rsqrt(), Scalar(1)/sqrt(abs(m3)));
VERIFY_IS_APPROX(rsqrt(m3), Scalar(1)/sqrt(abs(m3)));
VERIFY_IS_APPROX(m3.log(), log(m3));
@ -882,6 +885,8 @@ template<typename ArrayType> void array_real(const ArrayType& m)
VERIFY_IS_APPROX(m3.pow(RealScalar(0.5)), m3.sqrt());
VERIFY_IS_APPROX(pow(m3,RealScalar(0.5)), m3.sqrt());
VERIFY_IS_APPROX(m3.pow(RealScalar(1.0/3.0)), m3.cbrt());
VERIFY_IS_APPROX(pow(m3,RealScalar(1.0/3.0)), m3.cbrt());
VERIFY_IS_APPROX(m3.pow(RealScalar(-0.5)), m3.rsqrt());
VERIFY_IS_APPROX(pow(m3,RealScalar(-0.5)), m3.rsqrt());