Add recursive work splitting to EvalShardedByInnerDimContext

Eugene Zhulenev 2019-12-05 14:50:19 -08:00
parent 25230d1862
commit bb7ccac3af


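The change below replaces the flat dispatch loops in run() and runAsync(), which enqueued one closure per block from the calling thread, with recursive work splitting: each task repeatedly hands the upper half of its [start, end) block range to the thread pool and keeps the lower half, until only one block remains to process locally. The following is a minimal standalone sketch of that pattern, not part of the patch; it uses a detached std::thread per task and placeholder Barrier/processBlock/enqueue stand-ins instead of Eigen's ThreadPoolDevice and contraction internals.

// Standalone sketch (hypothetical names, not Eigen's API). `enqueue` stands in
// for a thread-pool submission call; a real pool reuses and joins its worker
// threads instead of detaching a new one per task.
#include <condition_variable>
#include <functional>
#include <mutex>
#include <thread>

void enqueue(std::function<void()> fn) { std::thread(std::move(fn)).detach(); }

void processBlock(int block_idx) { (void)block_idx; /* compute one block */ }

// Counts down once per processed block; Wait() returns when all are done.
class Barrier {
 public:
  explicit Barrier(int count) : pending_(count) {}
  void Notify() {
    std::lock_guard<std::mutex> lock(mu_);
    if (--pending_ == 0) cv_.notify_all();
  }
  void Wait() {
    std::unique_lock<std::mutex> lock(mu_);
    cv_.wait(lock, [this] { return pending_ == 0; });
  }
 private:
  std::mutex mu_;
  std::condition_variable cv_;
  int pending_;
};

// Recursive work splitting: push the upper half of the range to the pool,
// keep the lower half, and repeat until a single block is left to process
// on this thread.
void eval(Barrier& barrier, int start_block, int end_block) {
  while (end_block - start_block > 1) {
    int mid_block = (start_block + end_block) / 2;
    enqueue([&barrier, mid_block, end_block] {
      eval(barrier, mid_block, end_block);
    });
    end_block = mid_block;
  }
  processBlock(start_block);
  barrier.Notify();
}

int main() {
  const int num_blocks = 8;
  Barrier barrier(num_blocks);
  eval(barrier, 0, num_blocks);  // replaces the old "enqueue every block" loop
  barrier.Wait();                // returns once all 8 blocks are processed
}

As in the patch, the thread entering eval() enqueues only about log2(num_blocks) closures before it starts computing a block of its own, and every spawned task contributes compute work rather than acting purely as a dispatcher.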
@@ -1159,16 +1159,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
    template <int Alignment>
    void run() {
      Barrier barrier(internal::convert_index<int>(num_blocks));
      for (Index block_idx = 0; block_idx < num_blocks; ++block_idx) {
        evaluator->m_device.enqueueNoNotification(
            [this, block_idx, &barrier]() {
              Index block_start = block_idx * block_size;
              Index block_end = block_start + actualBlockSize(block_idx);
              processBlock<Alignment>(block_idx, block_start, block_end);
              barrier.Notify();
            });
      }
      eval<Alignment>(barrier, 0, num_blocks);
      barrier.Wait();

      // Aggregate partial sums from l0 ranges.
@@ -1180,38 +1171,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
    template <int Alignment>
    void runAsync() {
      for (Index block_idx = 0; block_idx < num_blocks; ++block_idx) {
        evaluator->m_device.enqueueNoNotification([this, block_idx]() {
          Index block_start = block_idx * block_size;
          Index block_end = block_start + actualBlockSize(block_idx);
          processBlock<Alignment>(block_idx, block_start, block_end);

          int v = num_pending_blocks.fetch_sub(1);
          eigen_assert(v >= 1);

          if (v == 1) {
            // Aggregate partial sums from l0 ranges.
            aggregateL0Blocks<Alignment>();

            // Apply output kernel.
            applyOutputKernel();

            // NOTE: If we call `done` callback before deleting this (context),
            // it might deallocate Self* pointer captured by context, and we'll
            // fail in destructor trying to deallocate temporary buffers.

            // Move done call back from context before it will be destructed.
            DoneCallback done_copy = std::move(done);

            // We are confident that we are the last one who touches context.
            delete this;

            // Now safely call the done callback.
            done_copy();
          }
        });
      }

      evalAsync<Alignment>(0, num_blocks);
    }

   private:
@@ -1405,6 +1365,68 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
      }
    }

    template <int Alignment>
    void eval(Barrier& barrier, Index start_block_idx, Index end_block_idx) {
      while (end_block_idx - start_block_idx > 1) {
        Index mid_block_idx = (start_block_idx + end_block_idx) / 2;

        evaluator->m_device.enqueueNoNotification(
            [this, &barrier, mid_block_idx, end_block_idx]() {
              eval<Alignment>(barrier, mid_block_idx, end_block_idx);
            });

        end_block_idx = mid_block_idx;
      }

      Index block_idx = start_block_idx;
      Index block_start = block_idx * block_size;
      Index block_end = block_start + actualBlockSize(block_idx);

      processBlock<Alignment>(block_idx, block_start, block_end);
      barrier.Notify();
    }

    template <int Alignment>
    void evalAsync(Index start_block_idx, Index end_block_idx) {
      while (end_block_idx - start_block_idx > 1) {
        Index mid_block_idx = (start_block_idx + end_block_idx) / 2;

        evaluator->m_device.enqueueNoNotification(
            [this, mid_block_idx, end_block_idx]() {
              evalAsync<Alignment>(mid_block_idx, end_block_idx);
            });

        end_block_idx = mid_block_idx;
      }

      Index block_idx = start_block_idx;
      Index block_start = block_idx * block_size;
      Index block_end = block_start + actualBlockSize(block_idx);

      processBlock<Alignment>(block_idx, block_start, block_end);

      int v = num_pending_blocks.fetch_sub(1);
      eigen_assert(v >= 1);

      if (v == 1) {
        // Aggregate partial sums from l0 ranges.
        aggregateL0Blocks<Alignment>();

        // Apply output kernel.
        applyOutputKernel();

        // NOTE: If we call `done` callback before deleting this (context),
        // it might deallocate Self* pointer captured by context, and we'll
        // fail in destructor trying to deallocate temporary buffers.

        // Move done call back from context before it will be destructed.
        DoneCallback done_copy = std::move(done);

        // We are confident that we are the last one who touches context.
        delete this;

        // Now safely call the done callback.
        done_copy();
      }
    }

    // Cost model doesn't capture well the cost associated with constructing
    // tensor contraction mappers and computing loop bounds in gemm_pack_lhs
    // and gemm_pack_rhs, so we specify minimum desired block size.
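For the asynchronous path, the completion protocol itself is unchanged by this commit: a shared atomic counter of pending blocks is decremented once per finished block, and the worker that brings it to zero aggregates the partial sums, applies the output kernel, moves the done callback out of the heap-allocated context, deletes the context, and only then invokes the callback. Below is a condensed sketch of that ordering; AsyncContext and finishBlock are illustrative names, not Eigen's types.

// Sketch of the completion protocol with placeholder types (not Eigen's).
#include <atomic>
#include <functional>
#include <utility>

struct AsyncContext {
  AsyncContext(int num_blocks, std::function<void()> on_done)
      : num_pending_blocks(num_blocks), done(std::move(on_done)) {}

  // Called exactly once per finished block, possibly from many threads.
  void finishBlock() {
    int v = num_pending_blocks.fetch_sub(1);  // returns the previous value
    if (v == 1) {  // this call retired the last pending block
      // ... aggregate partial sums and apply the output kernel here ...

      // Mirrors the NOTE in the patch: invoking `done` before deleting the
      // context could free state the destructor still relies on, so move the
      // callback out, destroy the context, and only then call it.
      std::function<void()> done_copy = std::move(done);
      delete this;   // we are the last ones touching the context
      done_copy();   // now it is safe to run the user callback
    }
  }

  std::atomic<int> num_pending_blocks;
  std::function<void()> done;
};

A caller would allocate the context on the heap, e.g. new AsyncContext(num_blocks, std::move(done)), and have each worker call finishBlock() once per completed block; moving the callback out before delete this is what keeps the callback's side effects from racing with the context's destructor.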