diff --git a/unsupported/Eigen/CXX11/Tensor b/unsupported/Eigen/CXX11/Tensor
index 50753a53a..3f9aa1026 100644
--- a/unsupported/Eigen/CXX11/Tensor
+++ b/unsupported/Eigen/CXX11/Tensor
@@ -91,6 +91,7 @@
 #include "unsupported/Eigen/CXX11/src/Tensor/TensorReverse.h"
 #include "unsupported/Eigen/CXX11/src/Tensor/TensorShuffling.h"
 #include "unsupported/Eigen/CXX11/src/Tensor/TensorStriding.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorCustomOp.h"
 #include "unsupported/Eigen/CXX11/src/Tensor/TensorEvalTo.h"
 #include "unsupported/Eigen/CXX11/src/Tensor/TensorForcedEval.h"
 #include "unsupported/Eigen/CXX11/src/Tensor/TensorGenerator.h"
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h
index b5bdf9d8e..1c31b9d28 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h
@@ -481,6 +481,18 @@ class TensorBase
       return TensorStridingOp<const Strides, const Derived>(derived(), strides);
     }
 
+    // Support for custom unary and binary operations
+    template <typename CustomUnaryFunc>
+    EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+    const TensorCustomUnaryOp<const CustomUnaryFunc, const Derived> customOp(const CustomUnaryFunc& op) const {
+      return TensorCustomUnaryOp<const CustomUnaryFunc, const Derived>(derived(), op);
+    }
+    template <typename CustomBinaryFunc, typename OtherDerived>
+    EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+    const TensorCustomBinaryOp<const CustomBinaryFunc, const Derived, const OtherDerived> customOp(const OtherDerived& other, const CustomBinaryFunc& op) const {
+      return TensorCustomBinaryOp<const CustomBinaryFunc, const Derived, const OtherDerived>(derived(), other, op);
+    }
+
     // Force the evaluation of the expression.
     EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
     const TensorForcedEvalOp<const Derived> eval() const {
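Note on the API above: a functor passed to customOp() must provide two members, a dimensions() method mapping the input shape(s) to the output shape, and a device-templated eval() that writes every coefficient of the output. That contract is exactly what the InsertZeros and BatchMatMul functors exercise in the new test at the end of this patch. A minimal sketch of the unary form, using a hypothetical Doubler functor that is not part of this patch:

    #include <Eigen/CXX11/Tensor>

    // Hypothetical example functor, for illustration only.
    struct Doubler {
      Eigen::DSizes<Eigen::DenseIndex, 2> dimensions(const Eigen::Tensor<float, 2>& input) const {
        // The output has the same shape as the input.
        return Eigen::DSizes<Eigen::DenseIndex, 2>(input.dimension(0), input.dimension(1));
      }
      template <typename Output, typename Device>
      void eval(const Eigen::Tensor<float, 2>& input, Output& output, const Device& device) const {
        // The output arrives as a writable tensor expression; fill all of it.
        output.device(device) = input + input;
      }
    };

    Eigen::Tensor<float, 2> t(3, 5);
    t.setRandom();
    Eigen::Tensor<float, 2> doubled = t.customOp(Doubler());  // unary form
    // binary form: lhs.customOp(rhs, SomeBinaryFunctor());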
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorCustomOp.h b/unsupported/Eigen/CXX11/src/Tensor/TensorCustomOp.h
new file mode 100644
index 000000000..0157f6fab
--- /dev/null
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorCustomOp.h
@@ -0,0 +1,310 @@
+// This file is part of Eigen, a lightweight C++ template library
+// for linear algebra.
+//
+// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
+//
+// This Source Code Form is subject to the terms of the Mozilla
+// Public License v. 2.0. If a copy of the MPL was not distributed
+// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+#ifndef EIGEN_CXX11_TENSOR_TENSOR_CUSTOM_OP_H
+#define EIGEN_CXX11_TENSOR_TENSOR_CUSTOM_OP_H
+
+namespace Eigen {
+
+/** \class TensorCustomUnaryOp
+  * \ingroup CXX11_Tensor_Module
+  *
+  * \brief Tensor custom op, wrapping a user-defined unary functor that is
+  * given full control over the evaluation of its result.
+  *
+  */
+namespace internal {
+template<typename CustomUnaryFunc, typename XprType>
+struct traits<TensorCustomUnaryOp<CustomUnaryFunc, XprType> >
+{
+  typedef typename XprType::Scalar Scalar;
+  typedef typename packet_traits<Scalar>::type Packet;
+  typedef typename XprType::StorageKind StorageKind;
+  typedef typename XprType::Index Index;
+  typedef typename XprType::Nested Nested;
+  typedef typename remove_reference<Nested>::type _Nested;
+  static const int NumDimensions = traits<XprType>::NumDimensions;
+  static const int Layout = traits<XprType>::Layout;
+};
+
+template<typename CustomUnaryFunc, typename XprType>
+struct eval<TensorCustomUnaryOp<CustomUnaryFunc, XprType>, Eigen::Dense>
+{
+  typedef const TensorCustomUnaryOp<CustomUnaryFunc, XprType>& type;
+};
+
+template<typename CustomUnaryFunc, typename XprType>
+struct nested<TensorCustomUnaryOp<CustomUnaryFunc, XprType> >
+{
+  typedef TensorCustomUnaryOp<CustomUnaryFunc, XprType> type;
+};
+
+}  // end namespace internal
+
+
+
+template<typename CustomUnaryFunc, typename XprType>
+class TensorCustomUnaryOp : public TensorBase<TensorCustomUnaryOp<CustomUnaryFunc, XprType>, ReadOnlyAccessors>
+{
+  public:
+  typedef typename internal::traits<TensorCustomUnaryOp>::Scalar Scalar;
+  typedef typename internal::traits<TensorCustomUnaryOp>::Packet Packet;
+  typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
+  typedef typename XprType::CoeffReturnType CoeffReturnType;
+  typedef typename XprType::PacketReturnType PacketReturnType;
+  typedef typename internal::nested<TensorCustomUnaryOp>::type Nested;
+  typedef typename internal::traits<TensorCustomUnaryOp>::StorageKind StorageKind;
+  typedef typename internal::traits<TensorCustomUnaryOp>::Index Index;
+
+  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCustomUnaryOp(const XprType& expr, const CustomUnaryFunc& func)
+      : m_expr(expr), m_func(func) {}
+
+  EIGEN_DEVICE_FUNC
+  const CustomUnaryFunc& func() const { return m_func; }
+
+  EIGEN_DEVICE_FUNC
+  const typename internal::remove_all<typename XprType::Nested>::type&
+  expression() const { return m_expr; }
+
+  protected:
+    typename XprType::Nested m_expr;
+    const CustomUnaryFunc m_func;
+};
+
+
+// Eval as rvalue
+template<typename CustomUnaryFunc, typename XprType, typename Device>
+struct TensorEvaluator<const TensorCustomUnaryOp<CustomUnaryFunc, XprType>, Device>
+{
+  typedef TensorCustomUnaryOp<CustomUnaryFunc, XprType> ArgType;
+  typedef typename internal::traits<ArgType>::Index Index;
+  static const int NumDims = internal::traits<ArgType>::NumDimensions;
+  typedef DSizes<Index, NumDims> Dimensions;
+  typedef
+      typename internal::remove_const<typename ArgType::Scalar>::type Scalar;
+
+  enum {
+    IsAligned = false,
+    PacketAccess = (internal::packet_traits<Scalar>::size > 1),
+    BlockAccess = false,
+    Layout = TensorEvaluator<XprType, Device>::Layout,
+    CoordAccess = false,  // to be implemented
+  };
+
+  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const ArgType& op, const Device& device)
+      : m_op(op), m_device(device), m_result(NULL)
+  {
+    m_dimensions = op.func().dimensions(op.expression());
+  }
+
+  typedef typename internal::remove_const<typename XprType::CoeffReturnType>::type CoeffReturnType;
+  typedef typename XprType::PacketReturnType PacketReturnType;
+
+  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }
+
+  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(CoeffReturnType* data) {
+    if (data) {
+      evalTo(data);
+      return false;
+    } else {
+      m_result = static_cast<CoeffReturnType*>(
+          m_device.allocate(dimensions().TotalSize() * sizeof(Scalar)));
+      evalTo(m_result);
+      return true;
+    }
+  }
+
+  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
+    if (m_result != NULL) {
+      m_device.deallocate(m_result);
+      m_result = NULL;
+    }
+  }
+
+  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
+    return m_result[index];
+  }
+
+  template<int LoadMode>
+  EIGEN_DEVICE_FUNC PacketReturnType packet(Index index) const {
+    return internal::ploadt<PacketReturnType, LoadMode>(m_result + index);
+  }
+
+  EIGEN_DEVICE_FUNC CoeffReturnType* data() const { return m_result; }
+
+ protected:
+  EIGEN_DEVICE_FUNC void evalTo(Scalar* data) {
+    TensorMap<Tensor<Scalar, NumDims, Layout, Index> > result(
+        data, m_dimensions);
+    m_op.func().eval(m_op.expression(), result, m_device);
+  }
+
+  Dimensions m_dimensions;
+  const ArgType m_op;
+  const Device& m_device;
+  CoeffReturnType* m_result;
+};
+
+
+
+/** \class TensorCustomBinaryOp
+  * \ingroup CXX11_Tensor_Module
+  *
+  * \brief Tensor custom op, wrapping a user-defined binary functor that is
+  * given full control over the evaluation of its result.
+  *
+  */
+namespace internal {
+template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType>
+struct traits<TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType> >
+{
+  typedef typename internal::promote_storage_type<typename LhsXprType::Scalar,
+                                                  typename RhsXprType::Scalar>::ret Scalar;
+  typedef typename packet_traits<Scalar>::type Packet;
+  typedef typename internal::promote_storage_type<typename LhsXprType::CoeffReturnType,
+                                                  typename RhsXprType::CoeffReturnType>::ret CoeffReturnType;
+  typedef typename internal::promote_storage_type<typename LhsXprType::PacketReturnType,
+                                                  typename RhsXprType::PacketReturnType>::ret PacketReturnType;
+  typedef typename promote_storage_type<typename traits<LhsXprType>::StorageKind,
+                                        typename traits<RhsXprType>::StorageKind>::ret StorageKind;
+  typedef typename promote_index_type<typename traits<LhsXprType>::Index,
+                                      typename traits<RhsXprType>::Index>::type Index;
+  typedef typename LhsXprType::Nested LhsNested;
+  typedef typename RhsXprType::Nested RhsNested;
+  typedef typename remove_reference<LhsNested>::type _LhsNested;
+  typedef typename remove_reference<RhsNested>::type _RhsNested;
+  static const int NumDimensions = traits<LhsXprType>::NumDimensions;
+  static const int Layout = traits<LhsXprType>::Layout;
+};
+
+template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType>
+struct eval<TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType>, Eigen::Dense>
+{
+  typedef const TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType>& type;
+};
+
+template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType>
+struct nested<TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType> >
+{
+  typedef TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType> type;
+};
+
+}  // end namespace internal
+
+
+
+template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType>
+class TensorCustomBinaryOp : public TensorBase<TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType>, ReadOnlyAccessors>
+{
+  public:
+  typedef typename internal::traits<TensorCustomBinaryOp>::Scalar Scalar;
+  typedef typename internal::traits<TensorCustomBinaryOp>::Packet Packet;
+  typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
+  typedef typename internal::traits<TensorCustomBinaryOp>::CoeffReturnType CoeffReturnType;
+  typedef typename internal::traits<TensorCustomBinaryOp>::PacketReturnType PacketReturnType;
+  typedef typename internal::nested<TensorCustomBinaryOp>::type Nested;
+  typedef typename internal::traits<TensorCustomBinaryOp>::StorageKind StorageKind;
+  typedef typename internal::traits<TensorCustomBinaryOp>::Index Index;
+
+  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCustomBinaryOp(const LhsXprType& lhs, const RhsXprType& rhs, const CustomBinaryFunc& func)
+      : m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_func(func) {}
+
+  EIGEN_DEVICE_FUNC
+  const CustomBinaryFunc& func() const { return m_func; }
+
+  EIGEN_DEVICE_FUNC
+  const typename internal::remove_all<typename LhsXprType::Nested>::type&
+  lhsExpression() const { return m_lhs_xpr; }
+
+  EIGEN_DEVICE_FUNC
+  const typename internal::remove_all<typename RhsXprType::Nested>::type&
+  rhsExpression() const { return m_rhs_xpr; }
+
+  protected:
+    typename LhsXprType::Nested m_lhs_xpr;
+    typename RhsXprType::Nested m_rhs_xpr;
+    const CustomBinaryFunc m_func;
+};
+
+
+// Eval as rvalue
+template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType, typename Device>
+struct TensorEvaluator<const TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType>, Device>
+{
+  typedef TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType> XprType;
+  typedef typename internal::traits<XprType>::Index Index;
+  static const int NumDims = internal::traits<XprType>::NumDimensions;
+  typedef DSizes<Index, NumDims> Dimensions;
+  typedef typename XprType::Scalar Scalar;
+
+  enum {
+    IsAligned = false,
+    PacketAccess = (internal::packet_traits<Scalar>::size > 1),
+    BlockAccess = false,
+    Layout = TensorEvaluator<LhsXprType, Device>::Layout,
+    CoordAccess = false,  // to be implemented
+  };
+
+  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
+      : m_op(op), m_device(device), m_result(NULL)
+  {
+    m_dimensions = op.func().dimensions(op.lhsExpression(), op.rhsExpression());
+  }
+
+  typedef typename internal::remove_const<typename XprType::CoeffReturnType>::type CoeffReturnType;
+  typedef typename XprType::PacketReturnType PacketReturnType;
+
+  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }
+
+  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(CoeffReturnType* data) {
+    if (data) {
+      evalTo(data);
+      return false;
+    } else {
+      m_result = static_cast<CoeffReturnType*>(m_device.allocate(dimensions().TotalSize() * sizeof(Scalar)));
+      evalTo(m_result);
+      return true;
+    }
+  }
+
+  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
+    if (m_result != NULL) {
+      m_device.deallocate(m_result);
+      m_result = NULL;
+    }
+  }
+
+  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
+    return m_result[index];
+  }
+
+  template<int LoadMode>
+  EIGEN_DEVICE_FUNC PacketReturnType packet(Index index) const {
+    return internal::ploadt<PacketReturnType, LoadMode>(m_result + index);
+  }
+
+  EIGEN_DEVICE_FUNC CoeffReturnType* data() const { return m_result; }
+
+ protected:
+  EIGEN_DEVICE_FUNC void evalTo(Scalar* data) {
+    TensorMap<Tensor<Scalar, NumDims, Layout, Index> > result(data, m_dimensions);
+    m_op.func().eval(m_op.lhsExpression(), m_op.rhsExpression(), result, m_device);
+  }
+
+  Dimensions m_dimensions;
+  const XprType m_op;
+  const Device& m_device;
+  CoeffReturnType* m_result;
+};
+
+
+} // end namespace Eigen
+
+#endif // EIGEN_CXX11_TENSOR_TENSOR_CUSTOM_OP_H
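A note on the evaluation strategy implemented in the file above: a custom op is never evaluated coefficient by coefficient. evalSubExprsIfNeeded() either receives the destination buffer of an enclosing assignment and lets the functor write straight into it (returning false), or it allocates a temporary on the device, evaluates the functor into that buffer, and serves the subsequent coeff()/packet() reads from it until cleanup(). In both cases evalTo() wraps the raw memory in a TensorMap, so the functor sees an ordinary writable tensor expression. A sketch of the two call-site shapes this produces, reusing the hypothetical Doubler functor from earlier:

    Eigen::Tensor<float, 2> t(3, 5), out(3, 5);
    t.setRandom();
    out = t.customOp(Doubler());      // functor writes directly into out's storage
    Eigen::Tensor<float, 2> s = t.customOp(Doubler()) + t;
    // nested use: the custom op is materialized into a temporary buffer first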
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h b/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h
index a3fb26acc..7df8d1453 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h
@@ -42,6 +42,9 @@ template<typename Strides, typename XprType> class TensorStridingOp;
 template<typename Generator, typename XprType> class TensorGeneratorOp;
 template<typename LhsXprType, typename RhsXprType> class TensorAssignOp;
 
+template<typename CustomUnaryFunc, typename XprType> class TensorCustomUnaryOp;
+template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType> class TensorCustomBinaryOp;
+
 template<typename XprType> class TensorEvalToOp;
 template<typename XprType> class TensorForcedEvalOp;
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorTraits.h b/unsupported/Eigen/CXX11/src/Tensor/TensorTraits.h
index 757d6cb38..2e6b93bfb 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorTraits.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorTraits.h
@@ -156,9 +156,9 @@ struct eval<const TensorRef<PlainObjectType>, Eigen::Dense>
 };
 
 // TODO nested<> does not exist anymore in Eigen/Core, and it thus has to be removed in favor of ref_selector.
-template <typename T, int n, typename PlainObject> struct nested
-{
-  typedef typename ref_selector<T>::type type;
+template <typename T, int n=1, typename PlainObject = void> struct nested
+{
+  typedef typename ref_selector<T>::type type;
 };
 
 template <typename PlainObjectType>
diff --git a/unsupported/test/CMakeLists.txt b/unsupported/test/CMakeLists.txt
index 5ab034a5e..155bfcd76 100644
--- a/unsupported/test/CMakeLists.txt
+++ b/unsupported/test/CMakeLists.txt
@@ -137,6 +137,7 @@ if(EIGEN_TEST_CXX11)
   ei_add_test(cxx11_tensor_layout_swap "-std=c++0x")
   ei_add_test(cxx11_tensor_io "-std=c++0x")
   ei_add_test(cxx11_tensor_generator "-std=c++0x")
+  ei_add_test(cxx11_tensor_custom_op "-std=c++0x")
 
   # These tests needs nvcc
   # ei_add_test(cxx11_tensor_device "-std=c++0x")
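With the test registered above, it builds and runs like the other unsupported C++11 tests; a sketch, assuming an already configured CMake build directory (ei_add_test generates a target named after its first argument):

    cd build
    make cxx11_tensor_custom_op
    ctest -R cxx11_tensor_custom_op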
diff --git a/unsupported/test/cxx11_tensor_custom_op.cpp b/unsupported/test/cxx11_tensor_custom_op.cpp
new file mode 100644
index 000000000..7e33c9580
--- /dev/null
+++ b/unsupported/test/cxx11_tensor_custom_op.cpp
@@ -0,0 +1,107 @@
+// This file is part of Eigen, a lightweight C++ template library
+// for linear algebra.
+//
+// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
+//
+// This Source Code Form is subject to the terms of the Mozilla
+// Public License v. 2.0. If a copy of the MPL was not distributed
+// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+#include "main.h"
+
+#include <Eigen/CXX11/Tensor>
+
+using Eigen::Tensor;
+
+
+struct InsertZeros {
+  DSizes<DenseIndex, 2> dimensions(const Tensor<float, 2>& input) const {
+    DSizes<DenseIndex, 2> result;
+    result[0] = input.dimension(0) * 2;
+    result[1] = input.dimension(1) * 2;
+    return result;
+  }
+
+  template <typename Output, typename Device>
+  void eval(const Tensor<float, 2>& input, Output& output, const Device& device) const
+  {
+    array<DenseIndex, 2> strides{{2, 2}};
+    output.stride(strides).device(device) = input;
+
+    Eigen::DSizes<DenseIndex, 2> offsets(1,1);
+    Eigen::DSizes<DenseIndex, 2> extents(output.dimension(0)-1, output.dimension(1)-1);
+    output.slice(offsets, extents).stride(strides).device(device) = input.constant(0.0f);
+  }
+};
+
+static void test_custom_unary_op()
+{
+  Tensor<float, 2> tensor(3,5);
+  tensor.setRandom();
+
+  Tensor<float, 2> result = tensor.customOp(InsertZeros());
+  VERIFY_IS_EQUAL(result.dimension(0), 6);
+  VERIFY_IS_EQUAL(result.dimension(1), 10);
+
+  for (int i = 0; i < 6; i+=2) {
+    for (int j = 0; j < 10; j+=2) {
+      VERIFY_IS_EQUAL(result(i, j), tensor(i/2, j/2));
+    }
+  }
+  for (int i = 1; i < 6; i+=2) {
+    for (int j = 1; j < 10; j+=2) {
+      VERIFY_IS_EQUAL(result(i, j), 0);
+    }
+  }
+}
+
+
+struct BatchMatMul {
+  DSizes<DenseIndex, 3> dimensions(const Tensor<float, 3>& input1, const Tensor<float, 3>& input2) const {
+    DSizes<DenseIndex, 3> result;
+    result[0] = input1.dimension(0);
+    result[1] = input2.dimension(1);
+    result[2] = input2.dimension(2);
+    return result;
+  }
+
+  template <typename Output, typename Device>
+  void eval(const Tensor<float, 3>& input1, const Tensor<float, 3>& input2,
+            Output& output, const Device& device) const
+  {
+    typedef Tensor<float, 3>::DimensionPair DimPair;
+    array<DimPair, 1> dims({{DimPair(1, 0)}});
+    for (int i = 0; i < output.dimension(2); ++i) {
+      output.template chip<2>(i).device(device) = input1.chip<2>(i).contract(input2.chip<2>(i), dims);
+    }
+  }
+};
+
+
+static void test_custom_binary_op()
+{
+  Tensor<float, 3> tensor1(2,3,5);
+  tensor1.setRandom();
+  Tensor<float, 3> tensor2(3,7,5);
+  tensor2.setRandom();
+
+  Tensor<float, 3> result = tensor1.customOp(tensor2, BatchMatMul());
+  for (int i = 0; i < 5; ++i) {
+    typedef Tensor<float, 3>::DimensionPair DimPair;
+    array<DimPair, 1> dims({{DimPair(1, 0)}});
+    Tensor<float, 2> reference = tensor1.chip<2>(i).contract(tensor2.chip<2>(i), dims);
+    TensorRef<Tensor<float, 2> > val = result.chip<2>(i);
+    for (int j = 0; j < 2; ++j) {
+      for (int k = 0; k < 7; ++k) {
+        VERIFY_IS_APPROX(val(j, k), reference(j, k));
+      }
+    }
+  }
+}
+
+
+void test_cxx11_tensor_custom_op()
+{
+  CALL_SUBTEST(test_custom_unary_op());
+  CALL_SUBTEST(test_custom_binary_op());
+}
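One property of the functors above worth noting: because eval() is templated on Device, the same functor runs unchanged on whichever device evaluates the enclosing expression, and the device object supplied via .device() at the assignment site is the one forwarded into eval(). A minimal sketch with the default single-threaded device (spelled out here for illustration; a plain assignment uses it implicitly):

    Eigen::DefaultDevice d;
    Eigen::Tensor<float, 2> input(3, 5), output(6, 10);
    input.setRandom();
    output.device(d) = input.customOp(InsertZeros());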