From 4181556907fd29d6328fb718fa42cf9ce4734133 Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Wed, 15 Aug 2018 09:34:47 -0700 Subject: [PATCH] Fixed the tensor contraction code. --- unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h | 10 ++-------- unsupported/test/cxx11_tensor_contraction.cpp | 2 +- unsupported/test/cxx11_tensor_thread_pool.cpp | 2 +- 3 files changed, 4 insertions(+), 10 deletions(-) diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h index a023718c6..5d619efd8 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h @@ -152,13 +152,7 @@ struct TensorContractionParams { // 1. Elementwise Relu transformation following Conv2D. // 2. AddBias to the Conv2D output channels dimension. // -// See expected implementation in NoOpOutputKernel. -struct OutputKernel { - template - typedef internal::blas_data_mapper OutputMapper; -}; - -// Output kernel that does absolutely nothing. +// The NoOpOutputKernel implements an output kernel that does absolutely nothing. struct NoOpOutputKernel { /** * Tensor contraction evaluator calls this kernel after finishing each block @@ -177,7 +171,7 @@ struct NoOpOutputKernel { */ template EIGEN_ALWAYS_INLINE void operator()( - const OutputKernel::OutputMapper& /*output_mapper*/, + const internal::blas_data_mapper& /*output_mapper*/, const TensorContractionParams& /*params*/, Index /*i*/, Index /*j*/, Index /*num_rows*/, Index /*num_cols*/) const {} }; diff --git a/unsupported/test/cxx11_tensor_contraction.cpp b/unsupported/test/cxx11_tensor_contraction.cpp index 2e918eb30..928d20f6e 100644 --- a/unsupported/test/cxx11_tensor_contraction.cpp +++ b/unsupported/test/cxx11_tensor_contraction.cpp @@ -514,7 +514,7 @@ static void test_const_inputs() struct SqrtOutputKernel { template EIGEN_ALWAYS_INLINE void operator()( - const OutputKernel::OutputMapper& output_mapper, + const internal::blas_data_mapper& output_mapper, const TensorContractionParams&, Index, Index, Index num_rows, Index num_cols) const { for (int i = 0; i < num_rows; ++i) { diff --git a/unsupported/test/cxx11_tensor_thread_pool.cpp b/unsupported/test/cxx11_tensor_thread_pool.cpp index dd163c18a..7606b0abf 100644 --- a/unsupported/test/cxx11_tensor_thread_pool.cpp +++ b/unsupported/test/cxx11_tensor_thread_pool.cpp @@ -255,7 +255,7 @@ void test_multithread_contraction_agrees_with_singlethread() { struct SqrtOutputKernel { template EIGEN_ALWAYS_INLINE void operator()( - const OutputKernel::OutputMapper& output_mapper, + const internal::blas_data_mapper& output_mapper, const TensorContractionParams&, Index, Index, Index num_rows, Index num_cols) const { for (int i = 0; i < num_rows; ++i) {