mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-03-07 18:27:40 +08:00
Updated the contraction code to make it compatible with half floats.
This commit is contained in:
parent
180156ba1a
commit
670db7988d
@@ -99,23 +99,23 @@ EigenContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs,
|
||||
|
||||
#define prefetchIntoRegisters(base_k) \
|
||||
{ \
|
||||
lhs_pf0 = Scalar(0); \
|
||||
lhs_pf1 = Scalar(0); \
|
||||
lhs_pf2 = Scalar(0); \
|
||||
lhs_pf3 = Scalar(0); \
|
||||
lhs_pf4 = Scalar(0); \
|
||||
lhs_pf5 = Scalar(0); \
|
||||
lhs_pf6 = Scalar(0); \
|
||||
lhs_pf7 = Scalar(0); \
|
||||
lhs_pf0 = conv(0); \
|
||||
lhs_pf1 = conv(0); \
|
||||
lhs_pf2 = conv(0); \
|
||||
lhs_pf3 = conv(0); \
|
||||
lhs_pf4 = conv(0); \
|
||||
lhs_pf5 = conv(0); \
|
||||
lhs_pf6 = conv(0); \
|
||||
lhs_pf7 = conv(0); \
|
||||
\
|
||||
rhs_pf0 = Scalar(0); \
|
||||
rhs_pf1 = Scalar(0); \
|
||||
rhs_pf2 = Scalar(0); \
|
||||
rhs_pf3 = Scalar(0); \
|
||||
rhs_pf4 = Scalar(0); \
|
||||
rhs_pf5 = Scalar(0); \
|
||||
rhs_pf6 = Scalar(0); \
|
||||
rhs_pf7 = Scalar(0); \
|
||||
rhs_pf0 = conv(0); \
|
||||
rhs_pf1 = conv(0); \
|
||||
rhs_pf2 = conv(0); \
|
||||
rhs_pf3 = conv(0); \
|
||||
rhs_pf4 = conv(0); \
|
||||
rhs_pf5 = conv(0); \
|
||||
rhs_pf6 = conv(0); \
|
||||
rhs_pf7 = conv(0); \
|
||||
\
|
||||
if (!needs_edge_check || lhs_vert < m_size) { \
|
||||
const Index lhs_horiz_0 = base_k + threadIdx.z + 0 * 8; \
|
||||
@@ -261,15 +261,16 @@ EigenContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs,
|
||||
// declare and initialize result array
|
||||
#define res(i, j) _res_##i##j
|
||||
#define initResultRow(i) \
|
||||
Scalar res(i, 0) = Scalar(0); \
|
||||
Scalar res(i, 1) = Scalar(0); \
|
||||
Scalar res(i, 2) = Scalar(0); \
|
||||
Scalar res(i, 3) = Scalar(0); \
|
||||
Scalar res(i, 4) = Scalar(0); \
|
||||
Scalar res(i, 5) = Scalar(0); \
|
||||
Scalar res(i, 6) = Scalar(0); \
|
||||
Scalar res(i, 7) = Scalar(0); \
|
||||
Scalar res(i, 0) = conv(0); \
|
||||
Scalar res(i, 1) = conv(0); \
|
||||
Scalar res(i, 2) = conv(0); \
|
||||
Scalar res(i, 3) = conv(0); \
|
||||
Scalar res(i, 4) = conv(0); \
|
||||
Scalar res(i, 5) = conv(0); \
|
||||
Scalar res(i, 6) = conv(0); \
|
||||
Scalar res(i, 7) = conv(0); \
|
||||
|
||||
internal::scalar_cast_op<int, Scalar> conv;
|
||||
initResultRow(0);
|
||||
initResultRow(1);
|
||||
initResultRow(2);
|
||||
@@ -1313,6 +1314,34 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
|
||||
}
|
||||
}
|
||||
|
||||
template <typename LhsScalar, typename RhsScalar, typename Index, typename LhsMapper, typename RhsMapper, typename OutputMapper> struct LaunchKernels {
|
||||
static void Run(const LhsMapper& lhs, const RhsMapper& rhs, const OutputMapper& output, Index m, Index n, Index k, const GpuDevice& device) {
|
||||
const Index m_blocks = (m + 63) / 64;
|
||||
const Index n_blocks = (n + 63) / 64;
|
||||
const dim3 num_blocks(m_blocks, n_blocks, 1);
|
||||
const dim3 block_size(8, 8, 8);
|
||||
LAUNCH_CUDA_KERNEL((EigenContractionKernel<Scalar, Index, LhsMapper, RhsMapper, OutputMapper>), num_blocks, block_size, 0, device, lhs, rhs, output, m, n, k);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Index, typename LhsMapper, typename RhsMapper, typename OutputMapper> struct LaunchKernels<float, float, Index, LhsMapper, RhsMapper, OutputMapper> {
|
||||
static void Run(const LhsMapper& lhs, const RhsMapper& rhs, const OutputMapper& output, Index m, Index n, Index k, const GpuDevice& device) {
|
||||
if (m < 768 || n < 768) {
|
||||
const Index m_blocks = (m + 63) / 64;
|
||||
const Index n_blocks = (n + 63) / 64;
|
||||
const dim3 num_blocks(m_blocks, n_blocks, 1);
|
||||
const dim3 block_size(16, 16, 1);
|
||||
LAUNCH_CUDA_KERNEL((EigenFloatContractionKernel16x16<Index, LhsMapper, RhsMapper, OutputMapper>), num_blocks, block_size, 0, device, lhs, rhs, output, m, n, k);
|
||||
} else {
|
||||
const Index m_blocks = (m + 127) / 128;
|
||||
const Index n_blocks = (n + 63) / 64;
|
||||
const dim3 num_blocks(m_blocks, n_blocks, 1);
|
||||
const dim3 block_size(8, 32, 1);
|
||||
LAUNCH_CUDA_KERNEL((EigenFloatContractionKernel<Index, LhsMapper, RhsMapper, OutputMapper>), num_blocks, block_size, 0, device, lhs, rhs, output, m, n, k);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
|
||||
void evalTyped(Scalar* buffer) const {
|
||||
// columns in left side, rows in right side
|
||||
@@ -1353,28 +1382,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
|
||||
OutputMapper output(buffer, m);
|
||||
|
||||
setCudaSharedMemConfig(cudaSharedMemBankSizeEightByte);
|
||||
if (internal::is_same<LhsScalar, float>::value &&
|
||||
internal::is_same<RhsScalar, float>::value) {
|
||||
if (m < 768 || n < 768) {
|
||||
const Index m_blocks = (m + 63) / 64;
|
||||
const Index n_blocks = (n + 63) / 64;
|
||||
const dim3 num_blocks(m_blocks, n_blocks, 1);
|
||||
const dim3 block_size(16, 16, 1);
|
||||
LAUNCH_CUDA_KERNEL((EigenFloatContractionKernel16x16<Index, LhsMapper, RhsMapper, OutputMapper>), num_blocks, block_size, 0, this->m_device, lhs, rhs, output, m, n, k);
|
||||
} else {
|
||||
const Index m_blocks = (m + 127) / 128;
|
||||
const Index n_blocks = (n + 63) / 64;
|
||||
const dim3 num_blocks(m_blocks, n_blocks, 1);
|
||||
const dim3 block_size(8, 32, 1);
|
||||
LAUNCH_CUDA_KERNEL((EigenFloatContractionKernel<Index, LhsMapper, RhsMapper, OutputMapper>), num_blocks, block_size, 0, this->m_device, lhs, rhs, output, m, n, k);
|
||||
}
|
||||
} else {
|
||||
const Index m_blocks = (m + 63) / 64;
|
||||
const Index n_blocks = (n + 63) / 64;
|
||||
const dim3 num_blocks(m_blocks, n_blocks, 1);
|
||||
const dim3 block_size(8, 8, 8);
|
||||
LAUNCH_CUDA_KERNEL((EigenContractionKernel<Scalar, Index, LhsMapper, RhsMapper, OutputMapper>), num_blocks, block_size, 0, this->m_device, lhs, rhs, output, m, n, k);
|
||||
}
|
||||
LaunchKernels<LhsScalar, RhsScalar, Index, LhsMapper, RhsMapper, OutputMapper>::Run(lhs, rhs, output, m, n, k, this->m_device);
|
||||
}
|
||||
};
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user