diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBlock.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBlock.h
index 222333847..dc9af3aa8 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorBlock.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBlock.h
@@ -282,19 +282,8 @@ class TensorBlockMapper {
   TensorBlockMapper(const DSizes<IndexType, NumDims>& dimensions,
                     const TensorBlockResourceRequirements& requirements)
       : m_tensor_dimensions(dimensions), m_requirements(requirements) {
-    // Initialize `m_block_dimensions`.
+    // Compute block dimensions and the total number of blocks.
     InitializeBlockDimensions();
-
-    // Calculate block counts by dimension and total block count.
-    DSizes<IndexType, NumDims> block_count;
-    for (int i = 0; i < NumDims; ++i) {
-      block_count[i] = divup(m_tensor_dimensions[i], m_block_dimensions[i]);
-    }
-    m_total_block_count = array_prod(block_count);
-
-    // Calculate block strides (used for enumerating blocks).
-    m_tensor_strides = strides<Layout>(m_tensor_dimensions);
-    m_block_strides = strides<Layout>(block_count);
   }
 
   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE IndexType blockCount() const {
@@ -339,23 +328,33 @@ class TensorBlockMapper {
   void InitializeBlockDimensions() {
     // Requested block shape and size.
     const TensorBlockShapeType shape_type = m_requirements.shape_type;
-    const IndexType target_block_size =
+    IndexType target_block_size =
         numext::maxi<IndexType>(1, static_cast<IndexType>(m_requirements.size));
 
+    IndexType tensor_size = m_tensor_dimensions.TotalSize();
+
     // Corner case: one of the dimensions is zero. Logic below is too complex
     // to handle this case on a general basis, just use unit block size.
     // Note: we must not yield blocks with zero dimensions (recipe for
     // overflows/underflows, divisions by zero and NaNs later).
-    if (m_tensor_dimensions.TotalSize() == 0) {
+    if (tensor_size == 0) {
       for (int i = 0; i < NumDims; ++i) {
         m_block_dimensions[i] = 1;
       }
+      m_total_block_count = 0;
       return;
     }
 
     // If tensor fits into a target block size, evaluate it as a single block.
-    if (m_tensor_dimensions.TotalSize() <= target_block_size) {
+    if (tensor_size <= target_block_size) {
       m_block_dimensions = m_tensor_dimensions;
+      m_total_block_count = 1;
+      // The only valid block index is `0`, and in this case we do not need
+      // to compute real strides for tensor or blocks (see blockDescriptor).
+      for (int i = 0; i < NumDims; ++i) {
+        m_tensor_strides[i] = 0;
+        m_block_strides[i] = 1;
+      }
       return;
     }
 
@@ -418,6 +417,17 @@ class TensorBlockMapper {
     eigen_assert(m_block_dimensions.TotalSize() >=
                  numext::mini<IndexType>(target_block_size,
                                          m_tensor_dimensions.TotalSize()));
+
+    // Calculate block counts by dimension and total block count.
+    DSizes<IndexType, NumDims> block_count;
+    for (int i = 0; i < NumDims; ++i) {
+      block_count[i] = divup(m_tensor_dimensions[i], m_block_dimensions[i]);
+    }
+    m_total_block_count = array_prod(block_count);
+
+    // Calculate block strides (used for enumerating blocks).
+    m_tensor_strides = strides<Layout>(m_tensor_dimensions);
+    m_block_strides = strides<Layout>(block_count);
   }
 
   DSizes<IndexType, NumDims> m_tensor_dimensions;
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h b/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h
index e2f1806cb..b90791d8d 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h
@@ -374,15 +374,23 @@ class TensorExecutor
-      ctx->device.parallelForAsync(ctx->tiling.block_mapper.blockCount(),
-                                   ctx->tiling.cost, eval_block,
-                                   [ctx]() { delete ctx; });
+
+      // Evaluate small expressions directly as a single block.
+      if (ctx->tiling.block_mapper.blockCount() == 1) {
+        TensorBlockScratch scratch(ctx->device);
+        TensorBlockDesc desc(0, ctx->tiling.block_mapper.blockDimensions());
+        ctx->evaluator.evalBlock(desc, scratch);
+        delete ctx;
+      } else {
+        ctx->device.parallelForAsync(ctx->tiling.block_mapper.blockCount(),
+                                     ctx->tiling.cost, eval_block,
+                                     [ctx]() { delete ctx; });
+      }
     };
 
     ctx->evaluator.evalSubExprsIfNeededAsync(nullptr, on_eval_subexprs);
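Note on the moved arithmetic: the block-count and stride computation now lives at the end of `InitializeBlockDimensions()`, after the early returns have handled the empty-tensor and single-block cases. For illustration only, here is a minimal standalone sketch of that computation (the shapes and names are made up; `divup` mirrors Eigen's `internal::divup` ceiling division):

```cpp
#include <array>
#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>

// Ceiling division, as in Eigen's internal::divup.
template <typename T>
T divup(T x, T y) { return (x + y - 1) / y; }

int main() {
  constexpr int kNumDims = 3;
  // Hypothetical shapes: a 100x64x7 tensor split into 32x32x7 blocks.
  const std::array<int64_t, kNumDims> tensor_dims = {100, 64, 7};
  const std::array<int64_t, kNumDims> block_dims = {32, 32, 7};

  // Per-dimension block counts: ceil(tensor_dim / block_dim).
  std::array<int64_t, kNumDims> block_count;
  for (int i = 0; i < kNumDims; ++i) {
    block_count[i] = divup(tensor_dims[i], block_dims[i]);
  }

  // The total block count is the product of the per-dimension counts.
  const int64_t total = std::accumulate(block_count.begin(), block_count.end(),
                                        int64_t{1}, std::multiplies<int64_t>());
  std::cout << "total blocks: " << total << "\n";  // 4 * 2 * 1 = 8
  return 0;
}
```

This ordering also explains the `m_total_block_count = 0` and `m_total_block_count = 1` assignments in the early returns: once the function can return before reaching the computation above, every exit path has to leave the count (and strides) in a usable state.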
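The TensorExecutor.h half of the patch is a fast path: when the whole expression fits in a single block (`blockCount() == 1`), it is evaluated inline on the calling thread instead of going through `parallelForAsync`, avoiding thread-pool dispatch overhead for tiny expressions. A rough standalone sketch of that dispatch shape, with plain `std::thread` standing in for Eigen's device parallel-for (all names here are hypothetical):

```cpp
#include <algorithm>
#include <cstdint>
#include <thread>
#include <vector>

using IndexType = int64_t;

// Placeholder for per-block evaluation (evalBlock + scratch reuse in Eigen).
void eval_block_range(IndexType first, IndexType last) {
  for (IndexType i = first; i < last; ++i) {
    // ... evaluate block `i` ...
  }
}

void run(IndexType block_count) {
  if (block_count == 1) {
    // Fast path, as in the patch: a single block is evaluated inline on
    // the calling thread; no pool dispatch at all.
    eval_block_range(0, 1);
    return;
  }
  // Crude stand-in for device.parallelForAsync: split the block range
  // evenly across hardware threads.
  const unsigned n = std::max(1u, std::thread::hardware_concurrency());
  const IndexType chunk = (block_count + n - 1) / n;
  std::vector<std::thread> workers;
  for (unsigned t = 0; t < n; ++t) {
    const IndexType first = static_cast<IndexType>(t) * chunk;
    const IndexType last = std::min<IndexType>(block_count, first + chunk);
    if (first >= last) break;
    workers.emplace_back(eval_block_range, first, last);
  }
  for (std::thread& w : workers) w.join();
}

int main() {
  run(1);   // single block: evaluated inline
  run(64);  // many blocks: split across threads
  return 0;
}
```

Note that the single-block branch still owns the context: it calls `delete ctx` itself, mirroring the completion callback that `parallelForAsync` would otherwise invoke.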