From f0ce85b757ce237d763d7751bda61901e78d5dc8 Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Mon, 29 Jun 2015 14:04:15 -0700 Subject: [PATCH] Improved support for fixed size tensors --- unsupported/Eigen/CXX11/src/Tensor/Tensor.h | 22 ++++++++ .../Eigen/CXX11/src/Tensor/TensorDimensions.h | 51 +++++++++++++++++++ 2 files changed, 73 insertions(+) diff --git a/unsupported/Eigen/CXX11/src/Tensor/Tensor.h b/unsupported/Eigen/CXX11/src/Tensor/Tensor.h index 4dbdbfb3e..24953ec94 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/Tensor.h +++ b/unsupported/Eigen/CXX11/src/Tensor/Tensor.h @@ -375,6 +375,28 @@ class Tensor : public TensorBase + EIGEN_DEVICE_FUNC + void resize(const Sizes& dimensions) { + array dims; + for (int i = 0; i < NumIndices; ++i) { + dims[i] = dimensions[i]; + } + resize(dims); + } +#else + template + EIGEN_DEVICE_FUNC + void resize(const Sizes& dimensions) { + array dims; + for (int i = 0; i < NumIndices; ++i) { + dims[i] = dimensions[i]; + } + resize(dims); + } +#endif + protected: bool checkIndexRange(const array& indices) const diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h b/unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h index 836daea65..5928f0b0c 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h @@ -69,6 +69,31 @@ struct fixed_size_tensor_index_linearization_helper +struct fixed_size_tensor_index_extraction_helper +{ + template EIGEN_DEVICE_FUNC + static inline Index run(const Index index, + const Dimensions& dimensions) + { + const Index mult = (index == n) ? 1 : 0; + return array_get(dimensions) * mult + + fixed_size_tensor_index_extraction_helper::run(index, dimensions); + } +}; + +template +struct fixed_size_tensor_index_extraction_helper +{ + template EIGEN_DEVICE_FUNC + static inline Index run(const Index index, + const Dimensions& dimensions) + { + const Index mult = (index == 0) ? 1 : 0; + return array_get<0>(dimensions) * mult; + } +}; + } // end namespace internal @@ -99,6 +124,10 @@ struct Sizes : internal::numeric_list { } #endif + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::ptrdiff_t operator[] (const int index) const { + return internal::fixed_size_tensor_index_extraction_helper::run(index, *this); + } + template Sizes& operator = (const T& /*other*/) { // add assertion failure if the size of other is different return *this; @@ -114,10 +143,12 @@ struct Sizes : internal::numeric_list { } }; +namespace internal { template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::ptrdiff_t array_prod(const Sizes&) { return Sizes::total_size; } +} #else @@ -166,6 +197,24 @@ template ::value; + case 1: + return internal::get<1, Base>::value; + case 2: + return internal::get<2, Base>::value; + case 3: + return internal::get<3, Base>::value; + case 4: + return internal::get<4, Base>::value; + default: + eigen_assert(false && "index overflow"); + return static_cast(-1); + } + } + template Sizes& operator = (const T&) { // to do: check the size of other return *this; @@ -181,10 +230,12 @@ template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::size_t array_prod(const Sizes&) { return Sizes::total_size; } +} #endif