mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-03-07 18:27:40 +08:00
Add recursive work splitting to EvalShardedByInnerDimContext
This commit is contained in:
parent
25230d1862
commit
bb7ccac3af
@ -1159,16 +1159,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
|
||||
template <int Alignment>
|
||||
void run() {
|
||||
Barrier barrier(internal::convert_index<int>(num_blocks));
|
||||
for (Index block_idx = 0; block_idx < num_blocks; ++block_idx) {
|
||||
evaluator->m_device.enqueueNoNotification(
|
||||
[this, block_idx, &barrier]() {
|
||||
Index block_start = block_idx * block_size;
|
||||
Index block_end = block_start + actualBlockSize(block_idx);
|
||||
|
||||
processBlock<Alignment>(block_idx, block_start, block_end);
|
||||
barrier.Notify();
|
||||
});
|
||||
}
|
||||
eval<Alignment>(barrier, 0, num_blocks);
|
||||
barrier.Wait();
|
||||
|
||||
// Aggregate partial sums from l0 ranges.
|
||||
@ -1180,38 +1171,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
|
||||
|
||||
template <int Alignment>
|
||||
void runAsync() {
|
||||
for (Index block_idx = 0; block_idx < num_blocks; ++block_idx) {
|
||||
evaluator->m_device.enqueueNoNotification([this, block_idx]() {
|
||||
Index block_start = block_idx * block_size;
|
||||
Index block_end = block_start + actualBlockSize(block_idx);
|
||||
|
||||
processBlock<Alignment>(block_idx, block_start, block_end);
|
||||
|
||||
int v = num_pending_blocks.fetch_sub(1);
|
||||
eigen_assert(v >= 1);
|
||||
|
||||
if (v == 1) {
|
||||
// Aggregate partial sums from l0 ranges.
|
||||
aggregateL0Blocks<Alignment>();
|
||||
|
||||
// Apply output kernel.
|
||||
applyOutputKernel();
|
||||
|
||||
// NOTE: If we call `done` callback before deleting this (context),
|
||||
// it might deallocate Self* pointer captured by context, and we'll
|
||||
// fail in destructor trying to deallocate temporary buffers.
|
||||
|
||||
// Move done call back from context before it will be destructed.
|
||||
DoneCallback done_copy = std::move(done);
|
||||
|
||||
// We are confident that we are the last one who touches context.
|
||||
delete this;
|
||||
|
||||
// Now safely call the done callback.
|
||||
done_copy();
|
||||
}
|
||||
});
|
||||
}
|
||||
evalAsync<Alignment>(0, num_blocks);
|
||||
}
|
||||
|
||||
private:
|
||||
@ -1405,6 +1365,68 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
|
||||
}
|
||||
}
|
||||
|
||||
template <int Alignment>
|
||||
void eval(Barrier& barrier, Index start_block_idx, Index end_block_idx) {
|
||||
while (end_block_idx - start_block_idx > 1) {
|
||||
Index mid_block_idx = (start_block_idx + end_block_idx) / 2;
|
||||
evaluator->m_device.enqueueNoNotification(
|
||||
[this, &barrier, mid_block_idx, end_block_idx]() {
|
||||
eval<Alignment>(barrier, mid_block_idx, end_block_idx);
|
||||
});
|
||||
end_block_idx = mid_block_idx;
|
||||
}
|
||||
|
||||
Index block_idx = start_block_idx;
|
||||
Index block_start = block_idx * block_size;
|
||||
Index block_end = block_start + actualBlockSize(block_idx);
|
||||
|
||||
processBlock<Alignment>(block_idx, block_start, block_end);
|
||||
barrier.Notify();
|
||||
}
|
||||
|
||||
template <int Alignment>
|
||||
void evalAsync(Index start_block_idx, Index end_block_idx) {
|
||||
while (end_block_idx - start_block_idx > 1) {
|
||||
Index mid_block_idx = (start_block_idx + end_block_idx) / 2;
|
||||
evaluator->m_device.enqueueNoNotification(
|
||||
[this, mid_block_idx, end_block_idx]() {
|
||||
evalAsync<Alignment>(mid_block_idx, end_block_idx);
|
||||
});
|
||||
end_block_idx = mid_block_idx;
|
||||
}
|
||||
|
||||
Index block_idx = start_block_idx;
|
||||
|
||||
Index block_start = block_idx * block_size;
|
||||
Index block_end = block_start + actualBlockSize(block_idx);
|
||||
|
||||
processBlock<Alignment>(block_idx, block_start, block_end);
|
||||
|
||||
int v = num_pending_blocks.fetch_sub(1);
|
||||
eigen_assert(v >= 1);
|
||||
|
||||
if (v == 1) {
|
||||
// Aggregate partial sums from l0 ranges.
|
||||
aggregateL0Blocks<Alignment>();
|
||||
|
||||
// Apply output kernel.
|
||||
applyOutputKernel();
|
||||
|
||||
// NOTE: If we call `done` callback before deleting this (context),
|
||||
// it might deallocate Self* pointer captured by context, and we'll
|
||||
// fail in destructor trying to deallocate temporary buffers.
|
||||
|
||||
// Move done call back from context before it will be destructed.
|
||||
DoneCallback done_copy = std::move(done);
|
||||
|
||||
// We are confident that we are the last one who touches context.
|
||||
delete this;
|
||||
|
||||
// Now safely call the done callback.
|
||||
done_copy();
|
||||
}
|
||||
}
|
||||
|
||||
// Cost model doesn't capture well the cost associated with constructing
|
||||
// tensor contraction mappers and computing loop bounds in gemm_pack_lhs
|
||||
// and gemm_pack_rhs, so we specify minimum desired block size.
|
||||
|
Loading…
Reference in New Issue
Block a user