Optimize evalShardedByInnerDim

Eugene Zhulenev 2019-01-08 16:26:31 -08:00
parent 190d053e41
commit e70ffef967


@@ -756,6 +756,36 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
}
}
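// Adds the contents of three source buffers to the destination buffer,
// processing one full output packet per iteration and falling back to scalar
// additions for the remaining tail elements.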
template <int Alignment>
EIGEN_STRONG_INLINE void addAllToBuffer(size_t n, const Scalar* src_buf0,
const Scalar* src_buf1,
const Scalar* src_buf2,
Scalar* dst_buf) const {
using ::Eigen::internal::padd;
using ::Eigen::internal::pload;
using ::Eigen::internal::ploadt;
using ::Eigen::internal::pstoret;
const int output_packet_size =
internal::unpacket_traits<PacketReturnType>::size;
size_t i = 0;
const size_t num_packets = n / output_packet_size;
for (; i < output_packet_size * num_packets; i += output_packet_size) {
const auto src_val0 = pload<PacketReturnType>(src_buf0 + i);
const auto src_val1 = pload<PacketReturnType>(src_buf1 + i);
const auto src_val2 = pload<PacketReturnType>(src_buf2 + i);
const auto dst_val = ploadt<PacketReturnType, Alignment>(dst_buf + i);
const auto sum = padd(padd(dst_val, src_val0), padd(src_val1, src_val2));
pstoret<Scalar, PacketReturnType, Alignment>(dst_buf + i, sum);
}
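// Scalar tail: handle the last n % output_packet_size elements.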
for (; i < n; ++i) {
dst_buf[i] += src_buf0[i] + src_buf1[i] + src_buf2[i];
}
}
// Decide whether we want to shard m x k x n contraction over the inner
// (contraction) dimension (k).
static bool shardByInnerDim(Index m, Index n, Index k, int num_threads,
@@ -788,50 +818,145 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
const Index m = this->m_i_size;
const Index n = this->m_j_size;
const Index k = this->m_k_size;
const Index packet_size = internal::packet_traits<RhsScalar>::size;
const Index kmultiple = packet_size <= 8 ? 8 : packet_size;
// We will compute partial results into the buffers of this size.
const Index buffer_size_bytes = m * n * sizeof(Scalar);
// The underlying GEMM kernel assumes that k is a multiple of
// the packet size and subtle breakage occurs if this is violated.
Index block_size = kmultiple * divup<Index>(k, kmultiple * num_threads);
Index num_blocks = divup<Index>(k, block_size);
// we use 'result' for the first block's partial result.
MaxSizeVector<Scalar*> block_buffers(num_blocks - 1);
Barrier barrier(internal::convert_index<int>(num_blocks));
auto process_block = [=, &barrier](Scalar* buf, Index begin, Index end) {
::memset(buf, 0, m * n * sizeof(Scalar));
const Index packet_size = internal::packet_traits<RhsScalar>::size;
const auto round_up = [=](Index index) -> Index {
const Index kmultiple = packet_size <= 8 ? 8 : packet_size;
return divup<Index>(index, kmultiple) * kmultiple;
};
// The cost model doesn't capture well the cost associated with constructing
// tensor contraction mappers and computing loop bounds in gemm_pack_lhs and
// gemm_pack_rhs, so we specify a minimum desired block size.
const Index target_block_size = round_up(divup<Index>(k, num_threads));
const Index desired_min_block_size = 12 * packet_size;
const Index block_size = numext::mini<Index>(
k, numext::maxi<Index>(desired_min_block_size, target_block_size));
const Index num_blocks = divup<Index>(k, block_size);
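// E.g., assuming packet_size == 8 (float with AVX), k == 1024 and 32 threads:
// target_block_size = 32, desired_min_block_size = 96, so block_size = 96 and
// num_blocks = divup(1024, 96) = 11.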
// Compute the block size, accounting for a potentially incomplete last block.
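// For the last block this evaluates to k - block_size * (num_blocks - 1).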
const auto actual_block_size = [=](Index block_idx) -> Index {
return block_idx + 1 < num_blocks
? block_size
: k + block_size - block_size * num_blocks;
};
// We compute partial gemm results in parallel, and to get the final result
// we need to add them all together. For a large number of threads (>= 48)
// this adds a very expensive sequential step at the end.
//
// We split [0, num_blocks) into small ranges, and when a task for a block
// finishes its partial gemm computation, it checks whether it was the last
// gemm in its range; if so, it adds up all blocks of the range.
//
// After all tasks finish, we only need to add these pre-aggregated blocks.
// Compute the range size, accounting for a potentially incomplete last range.
const auto actual_range_size = [=](Index num_ranges, Index range_size,
Index range_idx) -> Index {
eigen_assert(range_idx < num_ranges);
return range_idx + 1 < num_ranges
? range_size
: num_blocks + range_size - range_size * num_ranges;
};
// For now we use just a single level of ranges to compute pre-aggregated
// partial sums, but in general we can use more layers to compute tree
// aggregation in parallel and reduce the size of the sequential step.
//
// TODO(ezhulenev): Add multilevel tree aggregation? Probably will make
// sense only if number of threads >= ~128?
static const Index l0_size = 4;
const Index l0_ranges = divup<Index>(num_blocks, l0_size);
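// E.g., with num_blocks == 11 and l0_size == 4, l0_ranges == 3 and the ranges
// cover blocks [0, 4), [4, 8) and [8, 11).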
// Keep count of pending gemm tasks for each l0 range.
MaxSizeVector<std::atomic<int>> l0_state(l0_ranges);
for (int i = 0; i < l0_ranges; ++i) {
l0_state.emplace_back(actual_range_size(l0_ranges, l0_size, i));
}
MaxSizeVector<Scalar*> block_buffers(num_blocks);
auto process_block = [&, this](Index block_idx, Index begin, Index end) {
Scalar* buf = block_buffers[block_idx];
::memset(buf, 0, buffer_size_bytes);
TENSOR_CONTRACTION_DISPATCH(
this->template evalGemmPartialWithoutOutputKernel, Alignment,
(buf, begin, end, this->m_device.numThreads()));
barrier.Notify();
(buf, begin, end, /*num_threads=*/num_blocks));
// Check if it was the last task in l0 range.
const Index l0_index = block_idx / l0_size;
const int v = l0_state[l0_index].fetch_sub(1);
eigen_assert(v >= 1);
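// fetch_sub returns the previous value, so v == 1 means this task finished
// the last pending block of its l0 range.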
// If we processed the last block of the range, we can aggregate all
// partial results into the first block of the range.
if (v == 1) {
const Index rng_size = actual_range_size(l0_ranges, l0_size, l0_index);
const Index dst_block_idx = l0_index * l0_size;
if (rng_size == l0_size) {
addAllToBuffer<Alignment>(
m * n,
/*src_buf0=*/block_buffers[dst_block_idx + 1],
/*src_buf1=*/block_buffers[dst_block_idx + 2],
/*src_buf2=*/block_buffers[dst_block_idx + 3],
/*dst_buf= */ block_buffers[dst_block_idx]);
} else {
// Aggregate blocks of potentially incomplete last range.
for (int i = 1; i < rng_size; ++i) {
addToBuffer<Alignment>(m * n,
/*src_buf=*/block_buffers[dst_block_idx + i],
/*dst_buf=*/block_buffers[dst_block_idx]);
}
}
}
};
Index start = 0;
for (Index blocks_left = num_blocks; blocks_left > 0; --blocks_left) {
// The underlying GEMM kernel assumes that k is a multiple of packet size
// (currently largest packet size is 16) and subtle breakage occurs if
// this is violated.
block_size = kmultiple * divup<Index>(k - start, kmultiple * blocks_left);
Scalar* buf;
if (start == 0) {
buf = result;
} else {
buf = static_cast<Scalar*>(
this->m_device.allocate(m * n * sizeof(Scalar)));
block_buffers.push_back(buf);
}
Index end = start + block_size;
if (end > k) {
end = k;
}
this->m_device.enqueueNoNotification(
[=, &process_block]() { process_block(buf, start, end); });
start = end;
Barrier barrier(internal::convert_index<int>(num_blocks));
for (Index block_idx = 0; block_idx < num_blocks; ++block_idx) {
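// The first block accumulates directly into the final 'result' buffer; every
// other block gets its own temporary buffer, deallocated after aggregation.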
Scalar* buf = block_idx == 0
? result
: static_cast<Scalar*>(
this->m_device.allocate(buffer_size_bytes));
block_buffers.push_back(buf);
Index block_start = block_idx * block_size;
Index block_end = block_start + actual_block_size(block_idx);
this->m_device.enqueueNoNotification([=, &barrier, &process_block]() {
process_block(block_idx, block_start, block_end);
barrier.Notify();
});
}
barrier.Wait();
// Add other partial results into first partial result.
for (const auto& buf : block_buffers) {
addToBuffer<Alignment>(m * n, buf, result);
this->m_device.deallocate(buf);
// Aggregate partial sums from l0 ranges.
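// The first block of each l0 range now holds that range's partial sum, and
// block_buffers[0] is the 'result' buffer itself, so we fold the remaining
// ranges into it, three at a time where possible.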
Index l0_index = 1;
for (; l0_index + 2 < l0_ranges; l0_index += 3) {
addAllToBuffer<Alignment>(
m * n,
/*src_buf0=*/block_buffers[(l0_index + 0) * l0_size],
/*src_buf1=*/block_buffers[(l0_index + 1) * l0_size],
/*src_buf2=*/block_buffers[(l0_index + 2) * l0_size],
/*dst_buf= */ block_buffers[0]);
}
for (; l0_index < l0_ranges; ++l0_index) {
addToBuffer<Alignment>(m * n, block_buffers[l0_index * l0_size],
block_buffers[0]);
}
// Don't forget to deallocate ALL temporary buffers.
for (Index i = 1; i < num_blocks; ++i) {
this->m_device.deallocate(block_buffers[i]);
}
// Finally call output kernel with finalized output buffer.