Fixed the tensor contraction code.

This commit is contained in:
Benoit Steiner 2018-08-15 09:34:47 -07:00
parent b6f96cf7dd
commit 4181556907
3 changed files with 4 additions and 10 deletions

View File

@ -152,13 +152,7 @@ struct TensorContractionParams {
// 1. Elementwise Relu transformation following Conv2D.
// 2. AddBias to the Conv2D output channels dimension.
//
// See expected implementation in NoOpOutputKernel.
struct OutputKernel {
template <typename Index, typename Scalar>
typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;
};
// Output kernel that does absolutely nothing.
// The NoOpOutputKernel implements an output kernel that does absolutely nothing.
struct NoOpOutputKernel {
/**
* Tensor contraction evaluator calls this kernel after finishing each block
@ -177,7 +171,7 @@ struct NoOpOutputKernel {
*/
template <typename Index, typename Scalar>
EIGEN_ALWAYS_INLINE void operator()(
const OutputKernel::OutputMapper<Index, Scalar>& /*output_mapper*/,
const internal::blas_data_mapper<Scalar, Index, ColMajor>& /*output_mapper*/,
const TensorContractionParams& /*params*/, Index /*i*/,
Index /*j*/, Index /*num_rows*/, Index /*num_cols*/) const {}
};

View File

@ -514,7 +514,7 @@ static void test_const_inputs()
struct SqrtOutputKernel {
template <typename Index, typename Scalar>
EIGEN_ALWAYS_INLINE void operator()(
const OutputKernel::OutputMapper<Index, Scalar>& output_mapper,
const internal::blas_data_mapper<Scalar, Index, ColMajor>& output_mapper,
const TensorContractionParams&, Index, Index, Index num_rows,
Index num_cols) const {
for (int i = 0; i < num_rows; ++i) {

View File

@ -255,7 +255,7 @@ void test_multithread_contraction_agrees_with_singlethread() {
struct SqrtOutputKernel {
template <typename Index, typename Scalar>
EIGEN_ALWAYS_INLINE void operator()(
const OutputKernel::OutputMapper<Index, Scalar>& output_mapper,
const internal::blas_data_mapper<Scalar, Index, ColMajor>& output_mapper,
const TensorContractionParams&, Index, Index, Index num_rows,
Index num_cols) const {
for (int i = 0; i < num_rows; ++i) {