diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h b/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h
index 1b031b7a1..33e8c01c2 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h
@@ -17,6 +17,7 @@ namespace internal {
 template <typename T> struct SumReducer
 {
   static const bool PacketAccess = true;
+  static const bool IsStateful = false;
 
   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(const T t, T* accum) const {
     (*accum) += t;
@@ -49,6 +50,8 @@
 template <typename T> struct MeanReducer
 {
   static const bool PacketAccess = true;
+  static const bool IsStateful = true;
+  MeanReducer() : scalarCount_(0), packetCount_(0) { }
 
   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(const T t, T* accum) {
     (*accum) += t;
@@ -88,6 +91,7 @@ template <typename T> struct MeanReducer
 template <typename T> struct MaxReducer
 {
   static const bool PacketAccess = true;
+  static const bool IsStateful = false;
 
   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(const T t, T* accum) const {
     if (t > *accum) { *accum = t; }
@@ -120,6 +124,7 @@ template <typename T> struct MaxReducer
 template <typename T> struct MinReducer
 {
   static const bool PacketAccess = true;
+  static const bool IsStateful = false;
 
   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(const T t, T* accum) const {
     if (t < *accum) { *accum = t; }
@@ -153,6 +158,7 @@ template <typename T> struct MinReducer
 template <typename T> struct ProdReducer
 {
   static const bool PacketAccess = true;
+  static const bool IsStateful = false;
 
   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(const T t, T* accum) const {
     (*accum) *= t;
diff --git a/unsupported/test/cxx11_tensor_thread_pool.cpp b/unsupported/test/cxx11_tensor_thread_pool.cpp
index 0a20a01a4..5ec7c8bf4 100644
--- a/unsupported/test/cxx11_tensor_thread_pool.cpp
+++ b/unsupported/test/cxx11_tensor_thread_pool.cpp
@@ -228,6 +228,29 @@ static void test_multithread_contraction_agrees_with_singlethread() {
 }
 
 
+template <typename Type>
+static void test_multithreaded_reductions() {
+  const int num_threads = internal::random<int>(3, 11);
+  ThreadPool thread_pool(num_threads);
+  Eigen::ThreadPoolDevice thread_pool_device(&thread_pool, num_threads);
+
+  const int num_rows = internal::random<int>(13, 732);
+  const int num_cols = internal::random<int>(13, 732);
+  Tensor<Type, 2> t1(num_rows, num_cols);
+  t1.setRandom();
+
+  Tensor<Type, 1> full_redux(1);
+  full_redux = t1.sum();
+
+  Tensor<Type, 1> full_redux_tp(1);
+  full_redux_tp.device(thread_pool_device) = t1.sum();
+
+  // Check that the single threaded and the multi threaded reductions return
+  // the same result.
+  VERIFY_IS_APPROX(full_redux(0), full_redux_tp(0));
+}
+
+
 static void test_memcpy() {
 
   for (int i = 0; i < 5; ++i) {
@@ -271,6 +294,9 @@ void test_cxx11_tensor_thread_pool()
   CALL_SUBTEST(test_contraction_corner_cases<float>());
   CALL_SUBTEST(test_contraction_corner_cases<double>());
 
+  CALL_SUBTEST(test_multithreaded_reductions<float>());
+  CALL_SUBTEST(test_multithreaded_reductions<double>());
+
   CALL_SUBTEST(test_memcpy());
 
   CALL_SUBTEST(test_multithread_random());