From 670c71d906a4f0adc7edf266c996183ae8e4a2cc Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Mon, 29 Jun 2015 11:30:36 -0700 Subject: [PATCH] Express the full reduction operations (such as sum, max, min) using TensorDimensionList --- .../Eigen/CXX11/src/Tensor/TensorBase.h | 35 ++++++++----------- 1 file changed, 15 insertions(+), 20 deletions(-) diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h index 944dbf03f..30432fbc8 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h @@ -301,11 +301,10 @@ class TensorBase return TensorReductionOp, const Dims, const Derived>(derived(), dims, internal::SumReducer()); } - const TensorReductionOp, const array, const Derived> + const TensorReductionOp, const DimensionList, const Derived> sum() const { - array in_dims; - for (int i = 0; i < NumDimensions; ++i) in_dims[i] = i; - return TensorReductionOp, const array, const Derived>(derived(), in_dims, internal::SumReducer()); + DimensionList in_dims; + return TensorReductionOp, const DimensionList, const Derived>(derived(), in_dims, internal::SumReducer()); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE @@ -314,11 +313,10 @@ class TensorBase return TensorReductionOp, const Dims, const Derived>(derived(), dims, internal::MeanReducer()); } - const TensorReductionOp, const array, const Derived> + const TensorReductionOp, const DimensionList, const Derived> mean() const { - array in_dims; - for (int i = 0; i < NumDimensions; ++i) in_dims[i] = i; - return TensorReductionOp, const array, const Derived>(derived(), in_dims, internal::MeanReducer()); + DimensionList in_dims; + return TensorReductionOp, const DimensionList, const Derived>(derived(), in_dims, internal::MeanReducer()); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE @@ -327,11 +325,10 @@ class TensorBase return TensorReductionOp, const Dims, const Derived>(derived(), dims, internal::ProdReducer()); } - const TensorReductionOp, const array, const Derived> + const TensorReductionOp, const DimensionList, const Derived> prod() const { - array in_dims; - for (int i = 0; i < NumDimensions; ++i) in_dims[i] = i; - return TensorReductionOp, const array, const Derived>(derived(), in_dims, internal::ProdReducer()); + DimensionList in_dims; + return TensorReductionOp, const DimensionList, const Derived>(derived(), in_dims, internal::ProdReducer()); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE @@ -340,11 +337,10 @@ class TensorBase return TensorReductionOp, const Dims, const Derived>(derived(), dims, internal::MaxReducer()); } - const TensorReductionOp, const array, const Derived> + const TensorReductionOp, const DimensionList, const Derived> maximum() const { - array in_dims; - for (int i = 0; i < NumDimensions; ++i) in_dims[i] = i; - return TensorReductionOp, const array, const Derived>(derived(), in_dims, internal::MaxReducer()); + DimensionList in_dims; + return TensorReductionOp, const DimensionList, const Derived>(derived(), in_dims, internal::MaxReducer()); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE @@ -353,11 +349,10 @@ class TensorBase return TensorReductionOp, const Dims, const Derived>(derived(), dims, internal::MinReducer()); } - const TensorReductionOp, const array, const Derived> + const TensorReductionOp, const DimensionList, const Derived> minimum() const { - array in_dims; - for (int i = 0; i < NumDimensions; ++i) in_dims[i] = i; - return TensorReductionOp, const array, const Derived>(derived(), in_dims, internal::MinReducer()); + DimensionList in_dims; + return TensorReductionOp, const DimensionList, const Derived>(derived(), in_dims, internal::MinReducer()); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE