mirror of
https://gitlab.com/libeigen/eigen.git
synced 2024-12-27 07:29:52 +08:00
Added support for sigmoid function to the tensor module
This commit is contained in:
parent
979b73cebf
commit
06a22ca5bd
@ -380,6 +380,7 @@ UpperBidiagonalization<_MatrixType>& UpperBidiagonalization<_MatrixType>::comput
|
||||
{
|
||||
Index rows = matrix.rows();
|
||||
Index cols = matrix.cols();
|
||||
EIGEN_ONLY_USED_FOR_DEBUG(rows);
|
||||
EIGEN_ONLY_USED_FOR_DEBUG(cols);
|
||||
|
||||
eigen_assert(rows >= cols && "UpperBidiagonalization is only for Arices satisfying rows>=cols.");
|
||||
|
@ -116,6 +116,12 @@ class TensorBase<Derived, ReadOnlyAccessors>
|
||||
return unaryExpr(internal::scalar_tanh_op<Scalar>());
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC
|
||||
EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_sigmoid_op<Scalar>, const Derived>
|
||||
sigmoid() const {
|
||||
return unaryExpr(internal::scalar_sigmoid_op<Scalar>());
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC
|
||||
EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_exp_op<Scalar>, const Derived>
|
||||
exp() const {
|
||||
|
@ -13,6 +13,36 @@
|
||||
namespace Eigen {
|
||||
namespace internal {
|
||||
|
||||
|
||||
/** \internal
|
||||
* \brief Template functor to compute the sigmoid of a scalar
|
||||
* \sa class CwiseUnaryOp, ArrayBase::sigmoid()
|
||||
*/
|
||||
template <typename T>
|
||||
struct scalar_sigmoid_op {
|
||||
EIGEN_EMPTY_STRUCT_CTOR(scalar_sigmoid_op)
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x) const {
|
||||
const T one = T(1);
|
||||
return one / (one + std::exp(-x));
|
||||
}
|
||||
|
||||
template <typename Packet>
|
||||
inline Packet packetOp(const Packet& x) const {
|
||||
const Packet one = pset1<Packet>(1);
|
||||
return pdiv(one, padd(one, pexp(pnegate(x))));
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct functor_traits<scalar_sigmoid_op<T> > {
|
||||
enum {
|
||||
Cost = NumTraits<T>::AddCost * 2 + NumTraits<T>::MulCost * 6,
|
||||
PacketAccess = packet_traits<T>::HasAdd && packet_traits<T>::HasDiv &&
|
||||
packet_traits<T>::HasNegate && packet_traits<T>::HasExp
|
||||
};
|
||||
};
|
||||
|
||||
|
||||
// Standard reduction functors
|
||||
template <typename T> struct SumReducer
|
||||
{
|
||||
|
@ -109,6 +109,7 @@ if(EIGEN_TEST_CXX11)
|
||||
ei_add_test(cxx11_tensor_contraction "-std=c++0x")
|
||||
ei_add_test(cxx11_tensor_convolution "-std=c++0x")
|
||||
ei_add_test(cxx11_tensor_expr "-std=c++0x")
|
||||
ei_add_test(cxx11_tensor_math "-std=c++0x")
|
||||
ei_add_test(cxx11_tensor_forced_eval "-std=c++0x")
|
||||
ei_add_test(cxx11_tensor_fixed_size "-std=c++0x")
|
||||
ei_add_test(cxx11_tensor_const "-std=c++0x")
|
||||
|
46
unsupported/test/cxx11_tensor_math.cpp
Normal file
46
unsupported/test/cxx11_tensor_math.cpp
Normal file
@ -0,0 +1,46 @@
|
||||
// This file is part of Eigen, a lightweight C++ template library
|
||||
// for linear algebra.
|
||||
//
|
||||
// Copyright (C) 2015 Benoit Steiner <benoit.steiner.goog@gmail.com>
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla
|
||||
// Public License v. 2.0. If a copy of the MPL was not distributed
|
||||
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
#include "main.h"
|
||||
|
||||
#include <Eigen/CXX11/Tensor>
|
||||
|
||||
using Eigen::Tensor;
|
||||
using Eigen::RowMajor;
|
||||
|
||||
static void test_tanh()
|
||||
{
|
||||
Tensor<float, 1> vec1({6});
|
||||
vec1.setRandom();
|
||||
|
||||
Tensor<float, 1> vec2 = vec1.tanh();
|
||||
|
||||
for (int i = 0; i < 6; ++i) {
|
||||
VERIFY_IS_APPROX(vec2(i), tanhf(vec1(i)));
|
||||
}
|
||||
}
|
||||
|
||||
static void test_sigmoid()
|
||||
{
|
||||
Tensor<float, 1> vec1({6});
|
||||
vec1.setRandom();
|
||||
|
||||
Tensor<float, 1> vec2 = vec1.sigmoid();
|
||||
|
||||
for (int i = 0; i < 6; ++i) {
|
||||
VERIFY_IS_APPROX(vec2(i), 1.0f / (1.0f + std::exp(-vec1(i))));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void test_cxx11_tensor_math()
|
||||
{
|
||||
CALL_SUBTEST(test_tanh());
|
||||
CALL_SUBTEST(test_sigmoid());
|
||||
}
|
Loading…
Reference in New Issue
Block a user