Mirror of https://gitlab.com/libeigen/eigen.git (synced 2025-01-30 17:40:05 +08:00)
Reduce block evaluation overhead for small tensor expressions

commit 788bef6ab5
parent 7252163335
@@ -282,19 +282,8 @@ class TensorBlockMapper {
   TensorBlockMapper(const DSizes<IndexType, NumDims>& dimensions,
                     const TensorBlockResourceRequirements& requirements)
       : m_tensor_dimensions(dimensions), m_requirements(requirements) {
-    // Initialize `m_block_dimensions`.
+    // Compute block dimensions and the total number of blocks.
     InitializeBlockDimensions();
-
-    // Calculate block counts by dimension and total block count.
-    DSizes<IndexType, NumDims> block_count;
-    for (int i = 0; i < NumDims; ++i) {
-      block_count[i] = divup(m_tensor_dimensions[i], m_block_dimensions[i]);
-    }
-    m_total_block_count = array_prod(block_count);
-
-    // Calculate block strides (used for enumerating blocks).
-    m_tensor_strides = strides<Layout>(m_tensor_dimensions);
-    m_block_strides = strides<Layout>(block_count);
   }
 
   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE IndexType blockCount() const {
@@ -339,23 +328,33 @@ class TensorBlockMapper {
   void InitializeBlockDimensions() {
     // Requested block shape and size.
     const TensorBlockShapeType shape_type = m_requirements.shape_type;
-    const IndexType target_block_size =
+    IndexType target_block_size =
         numext::maxi<IndexType>(1, static_cast<IndexType>(m_requirements.size));
 
+    IndexType tensor_size = m_tensor_dimensions.TotalSize();
+
     // Corner case: one of the dimensions is zero. Logic below is too complex
     // to handle this case on a general basis, just use unit block size.
     // Note: we must not yield blocks with zero dimensions (recipe for
     // overflows/underflows, divisions by zero and NaNs later).
-    if (m_tensor_dimensions.TotalSize() == 0) {
+    if (tensor_size == 0) {
       for (int i = 0; i < NumDims; ++i) {
         m_block_dimensions[i] = 1;
       }
+      m_total_block_count = 0;
       return;
     }
 
     // If tensor fits into a target block size, evaluate it as a single block.
-    if (m_tensor_dimensions.TotalSize() <= target_block_size) {
+    if (tensor_size <= target_block_size) {
       m_block_dimensions = m_tensor_dimensions;
+      m_total_block_count = 1;
+      // The only valid block index is `0`, and in this case we do not need
+      // to compute real strides for tensor or blocks (see blockDescriptor).
+      for (int i = 0; i < NumDims; ++i) {
+        m_tensor_strides[i] = 0;
+        m_block_strides[i] = 1;
+      }
       return;
     }
 
@@ -418,6 +417,17 @@ class TensorBlockMapper {
     eigen_assert(m_block_dimensions.TotalSize() >=
                  numext::mini<IndexType>(target_block_size,
                                          m_tensor_dimensions.TotalSize()));
+
+    // Calculate block counts by dimension and total block count.
+    DSizes<IndexType, NumDims> block_count;
+    for (int i = 0; i < NumDims; ++i) {
+      block_count[i] = divup(m_tensor_dimensions[i], m_block_dimensions[i]);
+    }
+    m_total_block_count = array_prod(block_count);
+
+    // Calculate block strides (used for enumerating blocks).
+    m_tensor_strides = strides<Layout>(m_tensor_dimensions);
+    m_block_strides = strides<Layout>(block_count);
   }
 
   DSizes<IndexType, NumDims> m_tensor_dimensions;
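For context, the arithmetic that this hunk moves from the constructor into `InitializeBlockDimensions()` is a per-dimension ceiling division followed by column-major stride computation. The sketch below is a minimal standalone approximation (plain `std::array` instead of `DSizes`, hand-written `divup` and stride helpers, made-up names), not the Eigen implementation; it only illustrates the numbers being computed. In the single-block and empty-tensor cases added above, the diff skips this work entirely and fills in trivial strides, because a block count of zero or one never needs a real index-to-coordinates mapping.

```cpp
// Standalone sketch of the block-count / stride arithmetic (not Eigen code).
#include <array>
#include <cstdint>
#include <iostream>

using Index = std::int64_t;
constexpr int kNumDims = 3;
using Dims = std::array<Index, kNumDims>;

// Integer ceiling division, same idea as Eigen's internal divup().
Index divup(Index a, Index b) { return (a + b - 1) / b; }

// Column-major strides for a set of dimensions (Eigen's default layout).
Dims column_major_strides(const Dims& dims) {
  Dims s{};
  Index stride = 1;
  for (int i = 0; i < kNumDims; ++i) {
    s[i] = stride;
    stride *= dims[i];
  }
  return s;
}

int main() {
  const Dims tensor_dims = {100, 64, 7};
  const Dims block_dims = {32, 32, 7};

  // Per-dimension block counts and the total block count.
  Dims block_count{};
  Index total_block_count = 1;
  for (int i = 0; i < kNumDims; ++i) {
    block_count[i] = divup(tensor_dims[i], block_dims[i]);  // 4, 2, 1
    total_block_count *= block_count[i];
  }
  std::cout << "total blocks: " << total_block_count << "\n";  // 8

  // Strides used to map a linear block index back to block coordinates.
  const Dims tensor_strides = column_major_strides(tensor_dims);
  const Dims block_strides = column_major_strides(block_count);
  std::cout << "block strides: " << block_strides[0] << ", "
            << block_strides[1] << ", " << block_strides[2] << "\n";  // 1, 4, 8
}
```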
@@ -374,15 +374,23 @@ class TensorExecutor<Expression, ThreadPoolDevice, Vectorizable,
                               IndexType lastBlockIdx) {
       TensorBlockScratch scratch(device);
 
-      for (IndexType block_idx = firstBlockIdx; block_idx < lastBlockIdx; ++block_idx) {
+      for (IndexType block_idx = firstBlockIdx; block_idx < lastBlockIdx;
+           ++block_idx) {
         TensorBlockDesc desc = tiling.block_mapper.blockDescriptor(block_idx);
         evaluator.evalBlock(desc, scratch);
         scratch.reset();
       }
     };
 
-    device.parallelFor(tiling.block_mapper.blockCount(), tiling.cost,
-                       eval_block);
+    // Evaluate small expressions directly as a single block.
+    if (tiling.block_mapper.blockCount() == 1) {
+      TensorBlockScratch scratch(device);
+      TensorBlockDesc desc(0, tiling.block_mapper.blockDimensions());
+      evaluator.evalBlock(desc, scratch);
+    } else {
+      device.parallelFor(tiling.block_mapper.blockCount(), tiling.cost,
+                         eval_block);
+    }
   }
   evaluator.cleanup();
 }
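The executor change follows a common pattern: when the whole expression fits into one block, dispatching through `parallelFor` only adds thread-pool overhead, so the single block is evaluated inline on the calling thread. The sketch below shows that pattern with stand-in names (`parallel_for`, `eval_block`, `run` are placeholders, not Eigen's API); it is a rough illustration rather than the TensorExecutor itself.

```cpp
// Sketch of the "single block => evaluate inline" fast path (stand-in types).
#include <algorithm>
#include <cstdint>
#include <functional>
#include <iostream>
#include <thread>
#include <vector>

using Index = std::int64_t;

// Stand-in for device.parallelFor: splits [0, n) across hardware threads.
void parallel_for(Index n, const std::function<void(Index, Index)>& fn) {
  const Index num_threads =
      std::max<Index>(1, std::thread::hardware_concurrency());
  const Index chunk = (n + num_threads - 1) / num_threads;
  std::vector<std::thread> workers;
  for (Index start = 0; start < n; start += chunk) {
    const Index end = std::min(n, start + chunk);
    workers.emplace_back(fn, start, end);
  }
  for (auto& t : workers) t.join();
}

// Stand-in for evaluating one block of a tensor expression.
void eval_block(Index block_idx) {
  std::cout << "evaluating block " << block_idx << "\n";
}

void run(Index block_count) {
  if (block_count == 1) {
    // Fast path: a single block is evaluated directly on the calling thread,
    // skipping thread-pool dispatch entirely.
    eval_block(0);
  } else {
    parallel_for(block_count, [](Index first, Index last) {
      for (Index i = first; i < last; ++i) eval_block(i);
    });
  }
}

int main() {
  run(1);  // inline
  run(8);  // dispatched across worker threads
}
```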
@@ -486,8 +494,18 @@ class TensorAsyncExecutor<Expression, ThreadPoolDevice, DoneCallback,
         scratch.reset();
       }
     };
-    ctx->device.parallelForAsync(ctx->tiling.block_mapper.blockCount(),
-                                 ctx->tiling.cost, eval_block, [ctx]() { delete ctx; });
+
+    // Evaluate small expressions directly as a single block.
+    if (ctx->tiling.block_mapper.blockCount() == 1) {
+      TensorBlockScratch scratch(ctx->device);
+      TensorBlockDesc desc(0, ctx->tiling.block_mapper.blockDimensions());
+      ctx->evaluator.evalBlock(desc, scratch);
+      delete ctx;
+    } else {
+      ctx->device.parallelForAsync(ctx->tiling.block_mapper.blockCount(),
+                                   ctx->tiling.cost, eval_block,
+                                   [ctx]() { delete ctx; });
+    }
   };
 
   ctx->evaluator.evalSubExprsIfNeededAsync(nullptr, on_eval_subexprs);
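The asynchronous variant adds one consideration: the heap-allocated context must be destroyed exactly once, either right after the inline single-block evaluation or in the completion callback handed to `parallelForAsync`. Below is a minimal sketch of that ownership pattern under assumed, simplified types (`Context`, `parallel_for_async`, `run_async` are hypothetical names, not the Eigen TensorAsyncExecutor); the detached thread and sleep are a deliberately crude stand-in for a real device's async machinery.

```cpp
// Sketch of single-owner context cleanup in an async executor (stand-in types).
#include <chrono>
#include <functional>
#include <iostream>
#include <thread>
#include <utility>

using Index = long long;

struct Context {
  std::function<void()> done;        // user-supplied completion callback
  ~Context() { if (done) done(); }   // signal completion when destroyed
};

// Stand-in for device.parallelForAsync: runs the work on a detached thread
// and invokes `on_done` once the work has finished.
void parallel_for_async(Index n, std::function<void(Index, Index)> fn,
                        std::function<void()> on_done) {
  std::thread([n, fn = std::move(fn), on_done = std::move(on_done)] {
    fn(0, n);
    on_done();
  }).detach();
}

void eval_block(Index i) { std::cout << "block " << i << "\n"; }

void run_async(Index block_count, std::function<void()> done) {
  auto* ctx = new Context{std::move(done)};
  if (block_count == 1) {
    eval_block(0);
    delete ctx;  // inline path: destroy the context immediately
  } else {
    parallel_for_async(
        block_count,
        [](Index first, Index last) {
          for (Index i = first; i < last; ++i) eval_block(i);
        },
        [ctx] { delete ctx; });  // async path: destroy in the completion callback
  }
}

int main() {
  run_async(1, [] { std::cout << "single-block done\n"; });
  run_async(4, [] { std::cout << "multi-block done\n"; });
  // Crude wait for the detached worker to finish before main exits.
  std::this_thread::sleep_for(std::chrono::milliseconds(100));
}
```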