Repository: https://gitlab.com/libeigen/eigen.git (mirror)
Commit:     e70ffef967 "Optimize evalShardedByInnerDim"
Parent:     190d053e41
@@ -756,6 +756,36 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
     }
   }
 
+  template <int Alignment>
+  EIGEN_STRONG_INLINE void addAllToBuffer(size_t n, const Scalar* src_buf0,
+                                          const Scalar* src_buf1,
+                                          const Scalar* src_buf2,
+                                          Scalar* dst_buf) const {
+    using ::Eigen::internal::padd;
+    using ::Eigen::internal::pload;
+    using ::Eigen::internal::ploadt;
+    using ::Eigen::internal::pstoret;
+
+    const int output_packet_size =
+        internal::unpacket_traits<PacketReturnType>::size;
+
+    size_t i = 0;
+    const size_t num_packets = n / output_packet_size;
+    for (; i < output_packet_size * num_packets; i += output_packet_size) {
+      const auto src_val0 = pload<PacketReturnType>(src_buf0 + i);
+      const auto src_val1 = pload<PacketReturnType>(src_buf1 + i);
+      const auto src_val2 = pload<PacketReturnType>(src_buf2 + i);
+
+      const auto dst_val = ploadt<PacketReturnType, Alignment>(dst_buf + i);
+      const auto sum = padd(padd(dst_val, src_val0), padd(src_val1, src_val2));
+
+      pstoret<Scalar, PacketReturnType, Alignment>(dst_buf + i, sum);
+    }
+    for (; i < n; ++i) {
+      dst_buf[i] += src_buf0[i] + src_buf1[i] + src_buf2[i];
+    }
+  }
+
   // Decide whether we want to shard m x k x n contraction over the inner
   // (contraction) dimension (k).
   static bool shardByInnerDim(Index m, Index n, Index k, int num_threads,
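Note on the hunk above: addAllToBuffer is a three-source variant of the existing addToBuffer helper. It folds three partial-result buffers into the destination in a single pass, walking the vectorizable prefix one packet at a time and finishing the remaining n % packet_size elements with a scalar loop. Below is a minimal, Eigen-free sketch of that same prefix/tail structure; the function name, the fixed stand-in packet width, and the test values are illustrative and not part of the commit.

// Simplified analogue of the addAllToBuffer pattern (illustrative only).
#include <cassert>
#include <cstddef>
#include <iostream>
#include <vector>

static void add_all_to_buffer(std::size_t n, const float* src0,
                              const float* src1, const float* src2,
                              float* dst) {
  constexpr std::size_t kPacket = 4;  // stand-in for the SIMD packet size
  std::size_t i = 0;
  const std::size_t num_packets = n / kPacket;
  // "Vectorizable" prefix: process kPacket elements per iteration.
  for (; i < kPacket * num_packets; i += kPacket) {
    for (std::size_t j = 0; j < kPacket; ++j) {
      dst[i + j] += src0[i + j] + src1[i + j] + src2[i + j];
    }
  }
  // Scalar tail for the remaining n % kPacket elements.
  for (; i < n; ++i) {
    dst[i] += src0[i] + src1[i] + src2[i];
  }
}

int main() {
  const std::size_t n = 10;
  std::vector<float> a(n, 1.f), b(n, 2.f), c(n, 3.f), dst(n, 0.5f);
  add_all_to_buffer(n, a.data(), b.data(), c.data(), dst.data());
  for (float v : dst) assert(v == 6.5f);  // 0.5 + 1 + 2 + 3
  std::cout << "ok\n";
  return 0;
}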
@@ -788,50 +818,145 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
     const Index m = this->m_i_size;
     const Index n = this->m_j_size;
     const Index k = this->m_k_size;
-    const Index packet_size = internal::packet_traits<RhsScalar>::size;
-    const Index kmultiple = packet_size <= 8 ? 8 : packet_size;
-    // The underlying GEMM kernel assumes that k is a multiple of
-    // the packet size and subtle breakage occurs if this is violated.
-    Index block_size = kmultiple * divup<Index>(k, kmultiple * num_threads);
-    Index num_blocks = divup<Index>(k, block_size);
-    // we use 'result' for the first block's partial result.
-    MaxSizeVector<Scalar*> block_buffers(num_blocks - 1);
-    Barrier barrier(internal::convert_index<int>(num_blocks));
-    auto process_block = [=, &barrier](Scalar* buf, Index begin, Index end) {
-      ::memset(buf, 0, m * n * sizeof(Scalar));
-      TENSOR_CONTRACTION_DISPATCH(
-          this->template evalGemmPartialWithoutOutputKernel, Alignment,
-          (buf, begin, end, this->m_device.numThreads()));
-      barrier.Notify();
-    };
-    Index start = 0;
-    for (Index blocks_left = num_blocks; blocks_left > 0; --blocks_left) {
-      // The underlying GEMM kernel assumes that k is a multiple of packet size
-      // (currently largest packet size is 16) and subtle breakage occurs if
-      // this is violated.
-      block_size = kmultiple * divup<Index>(k - start, kmultiple * blocks_left);
-      Scalar* buf;
-      if (start == 0) {
-        buf = result;
-      } else {
-        buf = static_cast<Scalar*>(
-            this->m_device.allocate(m * n * sizeof(Scalar)));
-        block_buffers.push_back(buf);
-      }
-      Index end = start + block_size;
-      if (end > k) {
-        end = k;
-      }
-      this->m_device.enqueueNoNotification(
-          [=, &process_block]() { process_block(buf, start, end); });
-      start = end;
-    }
-    barrier.Wait();
-
-    // Add other partial results into first partial result.
-    for (const auto& buf : block_buffers) {
-      addToBuffer<Alignment>(m * n, buf, result);
-      this->m_device.deallocate(buf);
-    }
+
+    // We will compute partial results into the buffers of this size.
+    const Index buffer_size_bytes = m * n * sizeof(Scalar);
+
+    const Index packet_size = internal::packet_traits<RhsScalar>::size;
+
+    const auto round_up = [=](Index index) -> Index {
+      const Index kmultiple = packet_size <= 8 ? 8 : packet_size;
+      return divup<Index>(index, kmultiple) * kmultiple;
+    };
+
+    // Cost model doesn't capture well the cost associated with constructing
+    // tensor contraction mappers and computing loop bounds in gemm_pack_lhs and
+    // gemm_pack_rhs, so we specify minimum desired block size.
+    const Index target_block_size = round_up(divup<Index>(k, num_threads));
+    const Index desired_min_block_size = 12 * packet_size;
+
+    const Index block_size = numext::mini<Index>(
+        k, numext::maxi<Index>(desired_min_block_size, target_block_size));
+    const Index num_blocks = divup<Index>(k, block_size);
+
+    // Compute block size with accounting for potentially incomplete last block.
+    const auto actual_block_size = [=](Index block_idx) -> Index {
+      return block_idx + 1 < num_blocks
+                 ? block_size
+                 : k + block_size - block_size * num_blocks;
+    };
+
+    // We compute partial gemm results in parallel, and to get the final result
+    // we need to add them all together. For a large number of threads (>= 48)
+    // this adds a very expensive sequential step at the end.
+    //
+    // We split the [0, num_blocks) range into small sub-ranges, and when a task
+    // for a block finishes its partial gemm computation, it checks if it was
+    // the last gemm in its range, and if so, it adds all blocks of the range.
+    //
+    // After all tasks finish, we need to add only these pre-aggregated blocks.
+
+    // Compute range size with accounting for potentially incomplete last range.
+    const auto actual_range_size = [=](Index num_ranges, Index range_size,
+                                       Index range_idx) -> Index {
+      eigen_assert(range_idx < num_ranges);
+      return range_idx + 1 < num_ranges
+                 ? range_size
+                 : num_blocks + range_size - range_size * num_ranges;
+    };
+
+    // For now we use just a single level of ranges to compute pre-aggregated
+    // partial sums, but in general we can use more layers to compute tree
+    // aggregation in parallel and reduce the size of the sequential step.
+    //
+    // TODO(ezhulenev): Add multilevel tree aggregation? Probably will make
+    // sense only if number of threads >= ~128?
+    static const Index l0_size = 4;
+    const Index l0_ranges = divup<Index>(num_blocks, l0_size);
+
+    // Keep count of pending gemm tasks for each l0 range.
+    MaxSizeVector<std::atomic<int>> l0_state(l0_ranges);
+    for (int i = 0; i < l0_ranges; ++i) {
+      l0_state.emplace_back(actual_range_size(l0_ranges, l0_size, i));
+    }
+
+    MaxSizeVector<Scalar*> block_buffers(num_blocks);
+
+    auto process_block = [&, this](Index block_idx, Index begin, Index end) {
+      Scalar* buf = block_buffers[block_idx];
+      ::memset(buf, 0, buffer_size_bytes);
+
+      TENSOR_CONTRACTION_DISPATCH(
+          this->template evalGemmPartialWithoutOutputKernel, Alignment,
+          (buf, begin, end, /*num_threads=*/num_blocks));
+
+      // Check if it was the last task in l0 range.
+      const Index l0_index = block_idx / l0_size;
+      const int v = l0_state[l0_index].fetch_sub(1);
+      eigen_assert(v >= 1);
+
+      // If we processed the last block of the range, we can aggregate all
+      // partial results into the first block of the range.
+      if (v == 1) {
+        const Index rng_size = actual_range_size(l0_ranges, l0_size, l0_index);
+        const Index dst_block_idx = l0_index * l0_size;
+
+        if (rng_size == l0_size) {
+          addAllToBuffer<Alignment>(
+              m * n,
+              /*src_buf0=*/block_buffers[dst_block_idx + 1],
+              /*src_buf1=*/block_buffers[dst_block_idx + 2],
+              /*src_buf2=*/block_buffers[dst_block_idx + 3],
+              /*dst_buf= */ block_buffers[dst_block_idx]);
+        } else {
+          // Aggregate blocks of potentially incomplete last range.
+          for (int i = 1; i < rng_size; ++i) {
+            addToBuffer<Alignment>(m * n,
+                                   /*src_buf=*/block_buffers[dst_block_idx + i],
+                                   /*dst_buf=*/block_buffers[dst_block_idx]);
+          }
+        }
+      }
+    };
+
+    Barrier barrier(internal::convert_index<int>(num_blocks));
+    for (Index block_idx = 0; block_idx < num_blocks; ++block_idx) {
+      Scalar* buf = block_idx == 0
+                        ? result
+                        : static_cast<Scalar*>(
+                              this->m_device.allocate(buffer_size_bytes));
+      block_buffers.push_back(buf);
+
+      Index block_start = block_idx * block_size;
+      Index block_end = block_start + actual_block_size(block_idx);
+
+      this->m_device.enqueueNoNotification([=, &barrier, &process_block]() {
+        process_block(block_idx, block_start, block_end);
+        barrier.Notify();
+      });
+    }
+    barrier.Wait();
+
+    // Aggregate partial sums from l0 ranges.
+    Index l0_index = 1;
+    for (; l0_index + 2 < l0_ranges; l0_index += 3) {
+      addAllToBuffer<Alignment>(
+          m * n,
+          /*src_buf0=*/block_buffers[(l0_index + 0) * l0_size],
+          /*src_buf1=*/block_buffers[(l0_index + 1) * l0_size],
+          /*src_buf2=*/block_buffers[(l0_index + 2) * l0_size],
+          /*dst_buf= */ block_buffers[0]);
+    }
+    for (; l0_index < l0_ranges; ++l0_index) {
+      addToBuffer<Alignment>(m * n, block_buffers[l0_index * l0_size],
+                             block_buffers[0]);
+    }
+
+    // Don't forget to deallocate ALL temporary buffers.
+    for (Index i = 1; i < num_blocks; ++i) {
+      this->m_device.deallocate(block_buffers[i]);
+    }
 
     // Finally call output kernel with finalized output buffer.
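Two notes on the hunk above. First, the block sizing: instead of the old one-block-per-thread split (kmultiple * divup(k, kmultiple * num_threads)), the new code rounds k / num_threads up to a packet multiple and clamps it between a desired minimum of 12 packets and k itself, which yields fewer, larger blocks when many threads are available. A small standalone sketch of that arithmetic follows; divup, the constants, and the example numbers are illustrative only, not taken from the commit.

// Sketch of the new block-size heuristic with made-up example values.
#include <algorithm>
#include <iostream>

int main() {
  const long k = 1000, num_threads = 48, packet_size = 8;
  auto divup = [](long a, long b) { return (a + b - 1) / b; };
  const long kmultiple = packet_size <= 8 ? 8 : packet_size;
  // Even split across threads, rounded up to a packet multiple...
  const long target_block_size = divup(divup(k, num_threads), kmultiple) * kmultiple;
  // ...but never below 12 packets per block, and never above k itself.
  const long desired_min_block_size = 12 * packet_size;
  const long block_size =
      std::min(k, std::max(desired_min_block_size, target_block_size));
  const long num_blocks = divup(k, block_size);
  // With these numbers: target=24, min=96 -> block_size=96, num_blocks=11,
  // i.e. fewer, larger blocks than the old one-block-per-thread split (42 blocks of 24).
  std::cout << "block_size=" << block_size << " num_blocks=" << num_blocks << "\n";
  return 0;
}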
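Second, the pre-aggregation scheme: blocks are grouped into l0 ranges of up to four, each range keeps an atomic count of unfinished gemm tasks, and the task that decrements the count to zero folds that range's buffers into the range's first buffer. The final sequential step then only adds one pre-aggregated buffer per range into the result. Below is a self-contained sketch of just this counting and aggregation logic, using scalar stand-ins for the m*n partial-result buffers; all names are illustrative and none of this is Eigen code.

// Sketch: last-finisher-in-range pre-aggregation of partial results.
#include <algorithm>
#include <atomic>
#include <cassert>
#include <iostream>
#include <thread>
#include <vector>

int main() {
  const int num_blocks = 10;  // partial results, one per k-block
  const int l0_size = 4;      // blocks per pre-aggregation range
  const int l0_ranges = (num_blocks + l0_size - 1) / l0_size;

  // Each "partial result" is a single value here instead of an m*n buffer.
  std::vector<double> block(num_blocks, 1.0);

  // Pending-block counters, one per range (the last range may be incomplete).
  std::vector<std::atomic<int>> pending(l0_ranges);
  for (int r = 0; r < l0_ranges; ++r) {
    const int size = (r + 1 < l0_ranges)
                         ? l0_size
                         : num_blocks - l0_size * (l0_ranges - 1);
    pending[r].store(size);
  }

  std::vector<std::thread> tasks;
  for (int b = 0; b < num_blocks; ++b) {
    tasks.emplace_back([&, b] {
      // ... the partial gemm for block b would run here ...
      const int r = b / l0_size;
      if (pending[r].fetch_sub(1) == 1) {
        // Last finisher in the range: fold the range into its first block.
        const int first = r * l0_size;
        const int last = std::min(first + l0_size, num_blocks);
        for (int i = first + 1; i < last; ++i) block[first] += block[i];
      }
    });
  }
  for (auto& t : tasks) t.join();

  // The sequential step now only touches one block per range.
  double total = 0;
  for (int r = 0; r < l0_ranges; ++r) total += block[r * l0_size];
  assert(total == num_blocks);
  std::cout << "total = " << total << "\n";
  return 0;
}

The committed code additionally special-cases full ranges so they can be folded with a single three-source addAllToBuffer call (the rng_size == l0_size branch), falling back to repeated addToBuffer calls for the incomplete last range.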