Reduce block evaluation overhead for small tensor expressions

Eugene Zhulenev 2019-12-17 19:06:14 +00:00 committed by Rasmus Munk Larsen
parent 7252163335
commit 788bef6ab5
2 changed files with 48 additions and 20 deletions


@@ -282,19 +282,8 @@ class TensorBlockMapper {
   TensorBlockMapper(const DSizes<IndexType, NumDims>& dimensions,
                     const TensorBlockResourceRequirements& requirements)
       : m_tensor_dimensions(dimensions), m_requirements(requirements) {
-    // Initialize `m_block_dimensions`.
+    // Compute block dimensions and the total number of blocks.
     InitializeBlockDimensions();
-
-    // Calculate block counts by dimension and total block count.
-    DSizes<IndexType, NumDims> block_count;
-    for (int i = 0; i < NumDims; ++i) {
-      block_count[i] = divup(m_tensor_dimensions[i], m_block_dimensions[i]);
-    }
-    m_total_block_count = array_prod(block_count);
-
-    // Calculate block strides (used for enumerating blocks).
-    m_tensor_strides = strides<Layout>(m_tensor_dimensions);
-    m_block_strides = strides<Layout>(block_count);
   }

   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE IndexType blockCount() const {
@@ -339,23 +328,33 @@ class TensorBlockMapper {
   void InitializeBlockDimensions() {
     // Requested block shape and size.
     const TensorBlockShapeType shape_type = m_requirements.shape_type;
-    const IndexType target_block_size =
+    IndexType target_block_size =
         numext::maxi<IndexType>(1, static_cast<IndexType>(m_requirements.size));

+    IndexType tensor_size = m_tensor_dimensions.TotalSize();
+
     // Corner case: one of the dimensions is zero. Logic below is too complex
     // to handle this case on a general basis, just use unit block size.
     // Note: we must not yield blocks with zero dimensions (recipe for
     // overflows/underflows, divisions by zero and NaNs later).
-    if (m_tensor_dimensions.TotalSize() == 0) {
+    if (tensor_size == 0) {
       for (int i = 0; i < NumDims; ++i) {
         m_block_dimensions[i] = 1;
       }
+      m_total_block_count = 0;
       return;
     }

     // If tensor fits into a target block size, evaluate it as a single block.
-    if (m_tensor_dimensions.TotalSize() <= target_block_size) {
+    if (tensor_size <= target_block_size) {
       m_block_dimensions = m_tensor_dimensions;
+      m_total_block_count = 1;
+
+      // The only valid block index is `0`, and in this case we do not need
+      // to compute real strides for tensor or blocks (see blockDescriptor).
+      for (int i = 0; i < NumDims; ++i) {
+        m_tensor_strides[i] = 0;
+        m_block_strides[i] = 1;
+      }
       return;
     }
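The placeholder strides in the single-block branch are safe because a block's offset is computed by decomposing the block index against the block strides and dotting the resulting coordinates with the tensor strides; for block index 0 every coordinate is 0, so the offset comes out 0 no matter what the strides are. A minimal standalone sketch of that arithmetic, assuming a fixed 2-D column-major case with simplified names (the real logic lives in blockDescriptor):

    #include <array>
    #include <cstdint>

    // Offset of block `block_idx` in a 2-D column-major tensor: decompose the
    // index into per-dimension block coordinates, then dot with tensor strides.
    static std::int64_t BlockOffset(std::int64_t block_idx,
                                    const std::array<std::int64_t, 2>& block_strides,
                                    const std::array<std::int64_t, 2>& block_dims,
                                    const std::array<std::int64_t, 2>& tensor_strides) {
      std::int64_t offset = 0;
      for (int i = 1; i >= 0; --i) {  // largest block stride first (column-major)
        const std::int64_t coord = block_idx / block_strides[i];
        block_idx -= coord * block_strides[i];
        offset += coord * block_dims[i] * tensor_strides[i];
      }
      return offset;
    }

    int main() {
      const std::array<std::int64_t, 2> dims = {64, 64};
      // Single-block setup from the diff: block strides 1, tensor strides 0.
      // The only valid block index is 0, so the offset is always 0.
      return BlockOffset(0, /*block_strides=*/{1, 1}, dims,
                         /*tensor_strides=*/{0, 0}) == 0 ? 0 : 1;
    }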
@@ -418,6 +417,17 @@ class TensorBlockMapper {
     eigen_assert(m_block_dimensions.TotalSize() >=
                  numext::mini<IndexType>(target_block_size,
                                          m_tensor_dimensions.TotalSize()));
+
+    // Calculate block counts by dimension and total block count.
+    DSizes<IndexType, NumDims> block_count;
+    for (int i = 0; i < NumDims; ++i) {
+      block_count[i] = divup(m_tensor_dimensions[i], m_block_dimensions[i]);
+    }
+    m_total_block_count = array_prod(block_count);
+
+    // Calculate block strides (used for enumerating blocks).
+    m_tensor_strides = strides<Layout>(m_tensor_dimensions);
+    m_block_strides = strides<Layout>(block_count);
   }

   DSizes<IndexType, NumDims> m_tensor_dimensions;
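With the single-block case handled early, the general path splits the tensor into a grid: each per-dimension block count is a ceiling division of tensor extent by block extent, and the total count is their product. A small self-contained illustration of that arithmetic, using hypothetical shapes and mirroring the divup/array_prod calls above:

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Ceiling division: number of blocks of extent `b` needed to cover extent `a`.
    static std::int64_t divup(std::int64_t a, std::int64_t b) {
      return (a + b - 1) / b;
    }

    int main() {
      const std::vector<std::int64_t> tensor_dims = {100, 50};  // hypothetical shape
      const std::vector<std::int64_t> block_dims = {32, 32};    // hypothetical block

      std::int64_t total_block_count = 1;
      for (std::size_t i = 0; i < tensor_dims.size(); ++i) {
        total_block_count *= divup(tensor_dims[i], block_dims[i]);
      }
      // divup(100, 32) = 4 and divup(50, 32) = 2, so the tensor splits into 8 blocks.
      return total_block_count == 8 ? 0 : 1;
    }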


@@ -374,15 +374,23 @@ class TensorExecutor<Expression, ThreadPoolDevice, Vectorizable,
                             IndexType lastBlockIdx) {
         TensorBlockScratch scratch(device);

-        for (IndexType block_idx = firstBlockIdx; block_idx < lastBlockIdx; ++block_idx) {
+        for (IndexType block_idx = firstBlockIdx; block_idx < lastBlockIdx;
+             ++block_idx) {
           TensorBlockDesc desc = tiling.block_mapper.blockDescriptor(block_idx);
           evaluator.evalBlock(desc, scratch);
           scratch.reset();
         }
       };

-      device.parallelFor(tiling.block_mapper.blockCount(), tiling.cost,
-                         eval_block);
+      // Evaluate small expressions directly as a single block.
+      if (tiling.block_mapper.blockCount() == 1) {
+        TensorBlockScratch scratch(device);
+        TensorBlockDesc desc(0, tiling.block_mapper.blockDimensions());
+        evaluator.evalBlock(desc, scratch);
+      } else {
+        device.parallelFor(tiling.block_mapper.blockCount(), tiling.cost,
+                           eval_block);
+      }
     }

     evaluator.cleanup();
   }
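The guard above avoids a round trip through the thread pool when there is only one block: parallelFor would pay task-scheduling and thread-wakeup cost just to run a single task. The shape of that dispatch, as a hedged standalone sketch with hypothetical names (the real code uses Eigen's device.parallelFor together with a cost model):

    #include <cstddef>

    // Dispatch pattern: evaluate inline when there is exactly one block,
    // otherwise let the pool amortize scheduling overhead across many blocks.
    template <typename EvalOneBlock, typename ParallelFor>
    void DispatchBlocks(std::size_t block_count, EvalOneBlock eval_one_block,
                        ParallelFor parallel_for) {
      if (block_count == 1) {
        eval_one_block(0);          // run on the calling thread, no task hop
      } else {
        parallel_for(block_count);  // thread-pool path for larger workloads
      }
    }

Note that the inline branch in the diff still constructs a scratch allocator and an explicit descriptor for block 0, so the evaluator sees exactly the same interface on both paths.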
@@ -486,8 +494,18 @@ class TensorAsyncExecutor<Expression, ThreadPoolDevice, DoneCallback,
           scratch.reset();
         }
       };

-      ctx->device.parallelForAsync(ctx->tiling.block_mapper.blockCount(),
-          ctx->tiling.cost, eval_block, [ctx]() { delete ctx; });
+      // Evaluate small expressions directly as a single block.
+      if (ctx->tiling.block_mapper.blockCount() == 1) {
+        TensorBlockScratch scratch(ctx->device);
+        TensorBlockDesc desc(0, ctx->tiling.block_mapper.blockDimensions());
+        ctx->evaluator.evalBlock(desc, scratch);
+        delete ctx;
+      } else {
+        ctx->device.parallelForAsync(ctx->tiling.block_mapper.blockCount(),
+                                     ctx->tiling.cost, eval_block,
+                                     [ctx]() { delete ctx; });
+      }
     };

     ctx->evaluator.evalSubExprsIfNeededAsync(nullptr, on_eval_subexprs);
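The async fast path deletes the context directly instead of routing through parallelForAsync. That works because completion is tied to the context's lifetime: the multi-block branch's completion callback is likewise just `delete ctx`, so presumably the context destructor is what runs cleanup and the done callback. A hedged sketch of that ownership pattern, with hypothetical types rather than Eigen's actual context struct:

    #include <cstddef>
    #include <functional>

    // Hypothetical async context: completion is tied to the context's lifetime.
    struct AsyncContext {
      std::function<void()> on_done;
      ~AsyncContext() {
        if (on_done) on_done();  // signal completion when the context dies
      }
    };

    template <typename EvalOneBlock, typename ParallelForAsync>
    void RunBlocksAsync(AsyncContext* ctx, std::size_t block_count,
                        EvalOneBlock eval_one_block,
                        ParallelForAsync parallel_for_async) {
      if (block_count == 1) {
        eval_one_block(0);
        delete ctx;  // synchronous completion: the destructor fires on_done
      } else {
        // The pool owns completion: its callback deletes ctx after the last block.
        parallel_for_async(block_count, [ctx]() { delete ctx; });
      }
    }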