diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionCuda.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionCuda.h index babe33fff..f6bd949bd 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionCuda.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionCuda.h @@ -1,7 +1,9 @@ // This file is part of Eigen, a lightweight C++ template library // for linear algebra. // -// Copyright (C) 2014 Benoit Steiner +// Copyright (C) 2014-2015 Benoit Steiner +// Copyright (C) 2015 Navdeep Jaitly +// Copyright (C) 2014 Eric Martin // // This Source Code Form is subject to the terms of the Mozilla // Public License v. 2.0. If a copy of the MPL was not distributed @@ -19,7 +21,7 @@ template +template __global__ void __launch_bounds__(512) - EigenContractionKernel(const LhsMapper lhs, const RhsMapper rhs, - const OutputMapper output, - const Index m_size, const Index n_size, const Index k_size) { - __shared__ volatile Scalar lhs_shmem[72 * 64]; - __shared__ volatile Scalar rhs_shmem[72 * 64]; +EigenContractionKernel(const LhsMapper lhs, const RhsMapper rhs, + const OutputMapper output, + const Index m_size, const Index n_size, const Index k_size) { + __shared__ volatile Scalar lhs_shmem[72 * 64]; + __shared__ volatile Scalar rhs_shmem[72 * 64]; - const Index m_block_idx = blockIdx.x; - const Index n_block_idx = blockIdx.y; + const Index m_block_idx = blockIdx.x; + const Index n_block_idx = blockIdx.y; - const Index base_m = 64 * m_block_idx; - const Index base_n = 64 * n_block_idx; + const Index base_m = 64 * m_block_idx; + const Index base_n = 64 * n_block_idx; - if (base_m + 63 < m_size && base_n + 63 < n_size) { - EigenContractionKernelInternal(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size); + if (base_m + 63 < m_size && base_n + 63 < n_size) { + EigenContractionKernelInternal(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size); + } else { + EigenContractionKernelInternal(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size); + } +} + + +template +__device__ EIGEN_STRONG_INLINE void +EigenFloatContractionKernelInternal16x16(const LhsMapper lhs, const RhsMapper rhs, + const OutputMapper output, float2 lhs_shmem2[][16], + float2 rhs_shmem2[][8], const Index m_size, + const Index n_size, const Index k_size, + const Index base_m, const Index base_n) { + typedef float Scalar; + + // prefetch registers + float4 lhs_pf0, rhs_pf0; + + float4 results[4]; + for (int i=0; i < 4; i++) { + results[i].x = results[i].y = results[i].z = results[i].w = 0; + } + + +#define prefetch_lhs(reg, row, col) \ + if (!CHECK_LHS_BOUNDARY) { \ + if (col < k_size) { \ + reg =lhs.loadPacket(row, col); \ + } \ + } else { \ + if (col < k_size) { \ + if (row + 3 < m_size) { \ + reg =lhs.loadPacket(row, col); \ + } else if (row + 2 < m_size) { \ + reg.x =lhs(row + 0, col); \ + reg.y =lhs(row + 1, col); \ + reg.z =lhs(row + 2, col); \ + } else if (row + 1 < m_size) { \ + reg.x =lhs(row + 0, col); \ + reg.y =lhs(row + 1, col); \ + } else if (row < m_size) { \ + reg.x =lhs(row + 0, col); \ + } \ + } \ + } \ + + + Index lhs_vert = base_m+threadIdx.x*4; + + for (Index k = 0; k < k_size; k += 16) { + lhs_pf0 = internal::pset1(0); + rhs_pf0 = internal::pset1(0); + + Index lhs_horiz = threadIdx.y+k; + prefetch_lhs(lhs_pf0, lhs_vert, lhs_horiz) + + Index rhs_vert = k+(threadIdx.x%4)*4; + Index rhs_horiz0 = (threadIdx.x>>2)+threadIdx.y*4+base_n; + + if (!CHECK_RHS_BOUNDARY) { + if ((rhs_vert + 3) < k_size) { + // just CHECK_RHS_BOUNDARY + rhs_pf0 = rhs.loadPacket(rhs_vert, rhs_horiz0); + } else if (rhs_vert + 2 < k_size) { + // just CHECK_RHS_BOUNDARY + rhs_pf0.x = rhs(rhs_vert, rhs_horiz0); + rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0); + rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0); + } else if (rhs_vert + 1 < k_size) { + rhs_pf0.x = rhs(rhs_vert, rhs_horiz0); + rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0); + } else if (rhs_vert < k_size) { + rhs_pf0.x = rhs(rhs_vert, rhs_horiz0); + } } else { - EigenContractionKernelInternal(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size); + if (rhs_horiz0 < n_size) { + if ((rhs_vert + 3) < k_size) { + rhs_pf0 = rhs.loadPacket(rhs_vert, rhs_horiz0); + } else if ((rhs_vert + 2) < k_size) { + rhs_pf0.x = rhs(rhs_vert, rhs_horiz0); + rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0); + rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0); + } else if ((rhs_vert + 1) < k_size) { + rhs_pf0.x = rhs(rhs_vert, rhs_horiz0); + rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0); + } else if (rhs_vert < k_size) { + rhs_pf0.x = rhs(rhs_vert, rhs_horiz0); + } + } + } + float x1, x2 ; + // the following can be a bitwise operation..... some day. + if((threadIdx.x%8) < 4) { + x1 = rhs_pf0.y; + x2 = rhs_pf0.w; + } else { + x1 = rhs_pf0.x; + x2 = rhs_pf0.z; + } + x1 = __shfl_xor(x1, 4); + x2 = __shfl_xor(x2, 4); + if((threadIdx.x%8) < 4) { + rhs_pf0.y = x1; + rhs_pf0.w = x2; + } else { + rhs_pf0.x = x1; + rhs_pf0.z = x2; + } + + // We have 64 features. + // Row 0 -> times (0, 4, 8, 12, 1, 5, 9, 13) for features 0, 1. + // Row 1 -> times (0, 4, 8, 12, 1, 5, 9, 13) for features 2, 3. + // ... + // Row 31 -> times (0, 4, 8, 12, 1, 5, 9, 13) for features 62, 63 + // Row 32 -> times (2, 6, 10, 14, 3, 7, 11, 15) for features 0, 1 + // ... + rhs_shmem2[(threadIdx.x>>3)+ threadIdx.y*2][threadIdx.x%8] = make_float2(rhs_pf0.x, rhs_pf0.y); + rhs_shmem2[(threadIdx.x>>3)+ threadIdx.y*2+32][threadIdx.x%8] = make_float2(rhs_pf0.z, rhs_pf0.w); + + // Row 0 (time 0) -> features (0, 1), (4, 5), .. (28, 29), (32, 33), .. (60, 61) + // Row 1 (time 1) -> features (0, 1), (4, 5), .. (28, 29), (32, 33), .. (60, 61) + // ... + // Row 15 (time 15) -> features (0, 1), (4, 5), .. (28, 29), (32, 33), .. (60, 61) + // Row 16 (time 0) -> features (2, 3), (6, 7), .. (30, 31), (34, 35), .. (62, 63) + // ... + + lhs_shmem2[threadIdx.y][threadIdx.x] = make_float2(lhs_pf0.x, lhs_pf0.y); + lhs_shmem2[threadIdx.y+16][threadIdx.x] = make_float2(lhs_pf0.z, lhs_pf0.w); + + +#define add_vals(fl1, fl2, fr1, fr2)\ + results[0].x += fl1.x * fr1.x;\ + results[0].y += fl1.y * fr1.x;\ + results[0].z += fl2.x * fr1.x;\ + results[0].w += fl2.y * fr1.x;\ +\ + results[1].x += fl1.x * fr1.y;\ + results[1].y += fl1.y * fr1.y;\ + results[1].z += fl2.x * fr1.y;\ + results[1].w += fl2.y * fr1.y;\ +\ + results[2].x += fl1.x * fr2.x;\ + results[2].y += fl1.y * fr2.x;\ + results[2].z += fl2.x * fr2.x;\ + results[2].w += fl2.y * fr2.x;\ +\ + results[3].x += fl1.x * fr2.y;\ + results[3].y += fl1.y * fr2.y;\ + results[3].z += fl2.x * fr2.y;\ + results[3].w += fl2.y * fr2.y;\ + + __syncthreads(); + + // Do the multiplies. + #pragma unroll + for (int koff = 0; koff < 16; koff ++) { + // 32 x threads. + float2 fl1 = lhs_shmem2[koff][threadIdx.x]; + float2 fl2 = lhs_shmem2[koff + 16][threadIdx.x]; + + int start_feature = threadIdx.y * 4; + float2 fr1 = rhs_shmem2[(start_feature>>1) + 32*((koff%4)/2)][koff/4 + (koff%2)*4]; + float2 fr2 = rhs_shmem2[(start_feature>>1) + 1 + 32*((koff%4)/2)][koff/4 + (koff%2)*4]; + + add_vals(fl1, fl2, fr1, fr2) + } + __syncthreads(); + } + +#undef prefetch_lhs +#undef add_vals + + Index horiz_base = threadIdx.y*4+base_n; + if (!CHECK_LHS_BOUNDARY && !CHECK_RHS_BOUNDARY) { + for (int i = 0; i < 4; i++) { + output(lhs_vert, horiz_base + i) = results[i].x; + output(lhs_vert + 1, horiz_base + i) = results[i].y; + output(lhs_vert + 2, horiz_base + i) = results[i].z; + output(lhs_vert + 3, horiz_base + i) = results[i].w; + } + } else if (!CHECK_RHS_BOUNDARY) { + // CHECK LHS + if (lhs_vert + 3 < m_size) { + for (int i = 0; i < 4; i++) { + output(lhs_vert, horiz_base + i) = results[i].x; + output(lhs_vert + 1, horiz_base + i) = results[i].y; + output(lhs_vert + 2, horiz_base + i) = results[i].z; + output(lhs_vert + 3, horiz_base + i) = results[i].w; + } + } else if (lhs_vert + 2 < m_size) { + for (int i = 0; i < 4; i++) { + output(lhs_vert, horiz_base + i) = results[i].x; + output(lhs_vert + 1, horiz_base + i) = results[i].y; + output(lhs_vert + 2, horiz_base + i) = results[i].z; + } + } else if (lhs_vert + 1 < m_size) { + for (int i = 0; i < 4; i++) { + output(lhs_vert, horiz_base + i) = results[i].x; + output(lhs_vert + 1, horiz_base + i) = results[i].y; + } + } else if (lhs_vert < m_size) { + for (int i = 0; i < 4; i++) { + output(lhs_vert, horiz_base + i) = results[i].x; + } + } + } else if (!CHECK_LHS_BOUNDARY) { + // CHECK RHS + /* + int ncols_rem = fminf(n_size- horiz_base, 4); + for (int i = 0; i < ncols_rem; i++) { + output(lhs_vert, horiz_base + i) = results[i].x; + output(lhs_vert + 1, horiz_base + i) = results[i].y; + output(lhs_vert + 2, horiz_base + i) = results[i].z; + output(lhs_vert + 3, horiz_base + i) = results[i].w; + }*/ + for (int i = 0; i < 4; i++) { + if (horiz_base+i < n_size) { + output(lhs_vert, horiz_base + i) = results[i].x; + output(lhs_vert + 1, horiz_base + i) = results[i].y; + output(lhs_vert + 2, horiz_base + i) = results[i].z; + output(lhs_vert + 3, horiz_base + i) = results[i].w; + } + } + } else { + // CHECK both boundaries. + for (int i = 0; i < 4; i++) { + if (horiz_base+i < n_size) { + if (lhs_vert < m_size) + output(lhs_vert, horiz_base + i) = results[i].x; + if (lhs_vert + 1 < m_size) + output(lhs_vert + 1, horiz_base + i) = results[i].y; + if (lhs_vert + 2 < m_size) + output(lhs_vert + 2, horiz_base + i) = results[i].z; + if (lhs_vert + 3 < m_size) + output(lhs_vert + 3, horiz_base + i) = results[i].w; + } + } + } +} + + +template +__device__ EIGEN_STRONG_INLINE void +EigenFloatContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs, + const OutputMapper output, float2 lhs_shmem2[][32], + float2 rhs_shmem2[][8], const Index m_size, + const Index n_size, const Index k_size, + const Index base_m, const Index base_n) { + typedef float Scalar; + + // prefetch registers + float4 lhs_pf0, lhs_pf1, lhs_pf2, lhs_pf3; + float4 rhs_pf0, rhs_pf1; + + float4 results[8]; + for (int i=0; i < 8; i++) { + results[i].x = results[i].y = results[i].z = results[i].w = 0; + } + + + Index lhs_vert = base_m+threadIdx.x*4+(threadIdx.y%4)*32; + for (Index k = 0; k < k_size; k += 32) { + lhs_pf0 = internal::pset1(0); + lhs_pf1 = internal::pset1(0); + lhs_pf2 = internal::pset1(0); + lhs_pf3 = internal::pset1(0); + + rhs_pf0 = internal::pset1(0); + rhs_pf1 = internal::pset1(0); + + if (!CHECK_LHS_BOUNDARY) { + if ((threadIdx.y/4+k+24) < k_size) { + lhs_pf0 =lhs.loadPacket(lhs_vert, (threadIdx.y/4+k)); + lhs_pf1 =lhs.loadPacket(lhs_vert, (threadIdx.y/4+k+8)); + lhs_pf2 =lhs.loadPacket(lhs_vert, (threadIdx.y/4+k+16)); + lhs_pf3 =lhs.loadPacket(lhs_vert, (threadIdx.y/4+k+24)); + } else if ((threadIdx.y/4+k+16) < k_size) { + lhs_pf0 =lhs.loadPacket(lhs_vert, (threadIdx.y/4+k)); + lhs_pf1 =lhs.loadPacket(lhs_vert, (threadIdx.y/4+k+8)); + lhs_pf2 =lhs.loadPacket(lhs_vert, (threadIdx.y/4+k+16)); + } else if ((threadIdx.y/4+k+8) < k_size) { + lhs_pf0 =lhs.loadPacket(lhs_vert, (threadIdx.y/4+k)); + lhs_pf1 =lhs.loadPacket(lhs_vert, (threadIdx.y/4+k+8)); + } else if ((threadIdx.y/4+k) < k_size) { + lhs_pf0 =lhs.loadPacket(lhs_vert, (threadIdx.y/4+k)); + } + } else { + // just CHECK_LHS_BOUNDARY + if (lhs_vert + 3 < m_size) { + if ((threadIdx.y/4+k+24) < k_size) { + lhs_pf0 =lhs.loadPacket(lhs_vert, (threadIdx.y/4+k)); + lhs_pf1 =lhs.loadPacket(lhs_vert, (threadIdx.y/4+k+8)); + lhs_pf2 =lhs.loadPacket(lhs_vert, (threadIdx.y/4+k+16)); + lhs_pf3 =lhs.loadPacket(lhs_vert, (threadIdx.y/4+k+24)); + } else if ((threadIdx.y/4+k+16) < k_size) { + lhs_pf0 =lhs.loadPacket(lhs_vert, (threadIdx.y/4+k)); + lhs_pf1 =lhs.loadPacket(lhs_vert, (threadIdx.y/4+k+8)); + lhs_pf2 =lhs.loadPacket(lhs_vert, (threadIdx.y/4+k+16)); + } else if ((threadIdx.y/4+k+8) < k_size) { + lhs_pf0 =lhs.loadPacket(lhs_vert, (threadIdx.y/4+k)); + lhs_pf1 =lhs.loadPacket(lhs_vert, (threadIdx.y/4+k+8)); + } else if ((threadIdx.y/4+k) < k_size) { + lhs_pf0 =lhs.loadPacket(lhs_vert, (threadIdx.y/4+k)); + } + } else if (lhs_vert + 2 < m_size) { + if ((threadIdx.y/4+k+24) < k_size) { + lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k)); + lhs_pf0.y =lhs(lhs_vert + 1, (threadIdx.y/4+k)); + lhs_pf0.z =lhs(lhs_vert + 2, (threadIdx.y/4+k)); + lhs_pf1.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+8)); + lhs_pf1.y =lhs(lhs_vert + 1, (threadIdx.y/4+k+8)); + lhs_pf1.z =lhs(lhs_vert + 2, (threadIdx.y/4+k+8)); + lhs_pf2.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+16)); + lhs_pf2.y =lhs(lhs_vert + 1, (threadIdx.y/4+k+16)); + lhs_pf2.z =lhs(lhs_vert + 2, (threadIdx.y/4+k+16)); + lhs_pf3.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+24)); + lhs_pf3.y =lhs(lhs_vert + 1, (threadIdx.y/4+k+24)); + lhs_pf3.z =lhs(lhs_vert + 2, (threadIdx.y/4+k+24)); + } else if ((threadIdx.y/4+k+16) < k_size) { + lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k)); + lhs_pf0.y =lhs(lhs_vert + 1, (threadIdx.y/4+k)); + lhs_pf0.z =lhs(lhs_vert + 2, (threadIdx.y/4+k)); + lhs_pf1.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+8)); + lhs_pf1.y =lhs(lhs_vert + 1, (threadIdx.y/4+k+8)); + lhs_pf1.z =lhs(lhs_vert + 2, (threadIdx.y/4+k+8)); + lhs_pf2.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+16)); + lhs_pf2.y =lhs(lhs_vert + 1, (threadIdx.y/4+k+16)); + lhs_pf2.z =lhs(lhs_vert + 2, (threadIdx.y/4+k+16)); + } else if ((threadIdx.y/4+k+8) < k_size) { + lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k)); + lhs_pf0.y =lhs(lhs_vert + 1, (threadIdx.y/4+k)); + lhs_pf0.z =lhs(lhs_vert + 2, (threadIdx.y/4+k)); + lhs_pf1.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+8)); + lhs_pf1.y =lhs(lhs_vert + 1, (threadIdx.y/4+k+8)); + lhs_pf1.z =lhs(lhs_vert + 2, (threadIdx.y/4+k+8)); + } else if ((threadIdx.y/4+k) < k_size) { + lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k)); + lhs_pf0.y =lhs(lhs_vert + 1, (threadIdx.y/4+k)); + lhs_pf0.z =lhs(lhs_vert + 2, (threadIdx.y/4+k)); + } + } else if (lhs_vert + 1 < m_size) { + if ((threadIdx.y/4+k+24) < k_size) { + lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k)); + lhs_pf0.y =lhs(lhs_vert + 1, (threadIdx.y/4+k)); + lhs_pf1.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+8)); + lhs_pf1.y =lhs(lhs_vert + 1, (threadIdx.y/4+k+8)); + lhs_pf2.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+16)); + lhs_pf2.y =lhs(lhs_vert + 1, (threadIdx.y/4+k+16)); + lhs_pf3.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+24)); + lhs_pf3.y =lhs(lhs_vert + 1, (threadIdx.y/4+k+24)); + } else if ((threadIdx.y/4+k+16) < k_size) { + lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k)); + lhs_pf0.y =lhs(lhs_vert + 1, (threadIdx.y/4+k)); + lhs_pf1.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+8)); + lhs_pf1.y =lhs(lhs_vert + 1, (threadIdx.y/4+k+8)); + lhs_pf2.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+16)); + lhs_pf2.y =lhs(lhs_vert + 1, (threadIdx.y/4+k+16)); + } else if ((threadIdx.y/4+k+8) < k_size) { + lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k)); + lhs_pf0.y =lhs(lhs_vert + 1, (threadIdx.y/4+k)); + lhs_pf1.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+8)); + lhs_pf1.y =lhs(lhs_vert + 1, (threadIdx.y/4+k+8)); + } else if ((threadIdx.y/4+k) < k_size) { + lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k)); + lhs_pf0.y =lhs(lhs_vert + 1, (threadIdx.y/4+k)); + } + } else if (lhs_vert < m_size) { + if ((threadIdx.y/4+k+24) < k_size) { + lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k)); + lhs_pf1.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+8)); + lhs_pf2.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+16)); + lhs_pf3.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+24)); + } else if ((threadIdx.y/4+k+16) < k_size) { + lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k)); + lhs_pf1.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+8)); + lhs_pf2.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+16)); + } else if ((threadIdx.y/4+k+8) < k_size) { + lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k)); + lhs_pf1.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+8)); + } else if ((threadIdx.y/4+k) < k_size) { + lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k)); + } + } + } + __syncthreads(); + Index rhs_vert = k+threadIdx.x*4; + Index rhs_horiz0 = threadIdx.y*2+base_n; + Index rhs_horiz1 = threadIdx.y*2+1+base_n; + if (!CHECK_RHS_BOUNDARY) { + if ((rhs_vert + 3) < k_size) { + // just CHECK_RHS_BOUNDARY + rhs_pf0 = rhs.loadPacket(rhs_vert, rhs_horiz0); + rhs_pf1 = rhs.loadPacket(rhs_vert, rhs_horiz1); + } else if (rhs_vert + 2 < k_size) { + // just CHECK_RHS_BOUNDARY + rhs_pf0.x = rhs(rhs_vert, rhs_horiz0); + rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0); + rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0); + rhs_pf1.x = rhs(rhs_vert, rhs_horiz1); + rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1); + rhs_pf1.z = rhs(rhs_vert + 2, rhs_horiz1); + } else if (rhs_vert + 1 < k_size) { + rhs_pf0.x = rhs(rhs_vert, rhs_horiz0); + rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0); + rhs_pf1.x = rhs(rhs_vert, rhs_horiz1); + rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1); + } else if (rhs_vert < k_size) { + rhs_pf0.x = rhs(rhs_vert, rhs_horiz0); + rhs_pf1.x = rhs(rhs_vert, rhs_horiz1); + } + } else { + if (rhs_horiz1 < n_size) { + if ((rhs_vert + 3) < k_size) { + // just CHECK_RHS_BOUNDARY + rhs_pf0 = rhs.loadPacket(rhs_vert, rhs_horiz0); + rhs_pf1 = rhs.loadPacket(rhs_vert, rhs_horiz1); + } else if (rhs_vert + 2 < k_size) { + // just CHECK_RHS_BOUNDARY + rhs_pf0.x = rhs(rhs_vert, rhs_horiz0); + rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0); + rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0); + rhs_pf1.x = rhs(rhs_vert, rhs_horiz1); + rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1); + rhs_pf1.z = rhs(rhs_vert + 2, rhs_horiz1); + } else if (k+threadIdx.x*4 + 1 < k_size) { + rhs_pf0.x = rhs(rhs_vert, rhs_horiz0); + rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0); + rhs_pf1.x = rhs(rhs_vert, rhs_horiz1); + rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1); + } else if (k+threadIdx.x*4 < k_size) { + rhs_pf0.x = rhs(rhs_vert, rhs_horiz0); + rhs_pf1.x = rhs(rhs_vert, rhs_horiz1); + } + } else if (rhs_horiz0 < n_size) { + if ((rhs_vert + 3) < k_size) { + // just CHECK_RHS_BOUNDARY + rhs_pf0 = rhs.loadPacket(rhs_vert, rhs_horiz0); + } else if ((rhs_vert + 2) < k_size) { + // just CHECK_RHS_BOUNDARY + rhs_pf0.x = rhs(rhs_vert, rhs_horiz0); + rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0); + rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0); + } else if ((rhs_vert + 1) < k_size) { + rhs_pf0.x = rhs(rhs_vert, rhs_horiz0); + rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0); + } else if (rhs_vert < k_size) { + rhs_pf0.x = rhs(rhs_vert, rhs_horiz0); + } + } + } + __syncthreads(); + // Loaded. Do computation + // Row 0 -> times (0, 4, 8, .. 28) for features 0, 1. + // Row 1 -> times (0, 4, 8, .. 28) for features 2, 3. + // .. + // Row 31 -> times (0, 4, 8, .. 28) for features 62, 63 + rhs_shmem2[threadIdx.y][threadIdx.x] = make_float2(rhs_pf0.x, rhs_pf1.x); + // Row 32 -> times (1, 5, 9, .. 29) for features 0, 1. + // Row 33 -> times (1, 5, 9, .. 29) for features 2, 3. + // .. + rhs_shmem2[threadIdx.y+32][threadIdx.x] = make_float2(rhs_pf0.y, rhs_pf1.y); + // Row 64 -> times (2, 6, 10, .. 30) for features 0, 1. + // Row 65 -> times (2, 6, 10, .. 30) for features 2, 3. + rhs_shmem2[threadIdx.y+64][threadIdx.x] = make_float2(rhs_pf0.z, rhs_pf1.z); + // Row 96 -> times (3, 7, 11, .. 31) for features 0, 1. + // Row 97 -> times (3, 7, 11, .. 31) for features 2, 3. + rhs_shmem2[threadIdx.y+96][threadIdx.x] = make_float2(rhs_pf0.w, rhs_pf1.w); + + // LHS. + // Row 0 (time 0) -> features (0, 1), (4, 5), .. (28, 29), (32, 33), .. (60, 61) .. (124, 125) + // Row 1 (time 1) -> features (0, 1), (4, 5), .. (28, 29), (32, 33), .. (60, 61) .. (124, 125) + // ... + // Row 8 (time 0) -> features (2, 3), (6, 7), .. (30, 31), (34, 35), .. (62, 63) .. (126, 127) + // Row 15 (time 7) -> features (2, 3), (6, 7), .. (30, 31), (34, 35), .. (62, 63) .. (126, 127) + + +#define add_vals(a_feat1, a_feat2, f1, f2, f3, f4)\ + results[0].x += a_feat1.x * f1.x;\ + results[1].x += a_feat1.x * f1.y;\ + results[2].x += a_feat1.x * f2.x;\ + results[3].x += a_feat1.x * f2.y;\ + results[4].x += a_feat1.x * f3.x;\ + results[5].x += a_feat1.x * f3.y;\ + results[6].x += a_feat1.x * f4.x;\ + results[7].x += a_feat1.x * f4.y;\ +\ + results[0].y += a_feat1.y * f1.x;\ + results[1].y += a_feat1.y * f1.y;\ + results[2].y += a_feat1.y * f2.x;\ + results[3].y += a_feat1.y * f2.y;\ + results[4].y += a_feat1.y * f3.x;\ + results[5].y += a_feat1.y * f3.y;\ + results[6].y += a_feat1.y * f4.x;\ + results[7].y += a_feat1.y * f4.y;\ +\ + results[0].z += a_feat2.x * f1.x;\ + results[1].z += a_feat2.x * f1.y;\ + results[2].z += a_feat2.x * f2.x;\ + results[3].z += a_feat2.x * f2.y;\ + results[4].z += a_feat2.x * f3.x;\ + results[5].z += a_feat2.x * f3.y;\ + results[6].z += a_feat2.x * f4.x;\ + results[7].z += a_feat2.x * f4.y;\ +\ + results[0].w += a_feat2.y * f1.x;\ + results[1].w += a_feat2.y * f1.y;\ + results[2].w += a_feat2.y * f2.x;\ + results[3].w += a_feat2.y * f2.y;\ + results[4].w += a_feat2.y * f3.x;\ + results[5].w += a_feat2.y * f3.y;\ + results[6].w += a_feat2.y * f4.x;\ + results[7].w += a_feat2.y * f4.y;\ + + lhs_shmem2[threadIdx.y/4][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf0.x, lhs_pf0.y); + lhs_shmem2[threadIdx.y/4+8][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf1.x, lhs_pf1.y); + lhs_shmem2[threadIdx.y/4+16][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf2.x, lhs_pf2.y); + lhs_shmem2[threadIdx.y/4+24][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf3.x, lhs_pf3.y); + + lhs_shmem2[threadIdx.y/4 + 32][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf0.z, lhs_pf0.w); + lhs_shmem2[threadIdx.y/4 + 40][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf1.z, lhs_pf1.w); + lhs_shmem2[threadIdx.y/4 + 48][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf2.z, lhs_pf2.w); + lhs_shmem2[threadIdx.y/4 + 56][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf3.z, lhs_pf3.w); + + __syncthreads(); + + // Do the multiplies. + #pragma unroll + for (int koff = 0; koff < 32; koff ++) { + float2 a3 = lhs_shmem2[koff][threadIdx.x + (threadIdx.y % 4) * 8]; + float2 a4 = lhs_shmem2[koff + 32][threadIdx.x + (threadIdx.y % 4) * 8]; + + // first feature is at (threadIdx.y/4) * 8 last is at start + 8. + int start_feature = (threadIdx.y / 4) * 8; + + float2 br1 = rhs_shmem2[start_feature/2 + (koff % 4) * 32][koff/4]; + float2 br2 = rhs_shmem2[start_feature/2 + 1 + (koff % 4) * 32][koff/4]; + float2 br3 = rhs_shmem2[start_feature/2 + 2 + (koff % 4) * 32][koff/4]; + float2 br4 = rhs_shmem2[start_feature/2 + 3 + (koff % 4) * 32][koff/4]; + + add_vals(a3, a4, br1, br2, br3, br4) + } + __syncthreads(); + } // end loop over k + + + __syncthreads(); + Index horiz_base = (threadIdx.y/4)*8+base_n; + if (!CHECK_LHS_BOUNDARY && !CHECK_RHS_BOUNDARY) { + for (int i = 0; i < 8; i++) { + output(lhs_vert, horiz_base + i) = results[i].x; + output(lhs_vert + 1, horiz_base + i) = results[i].y; + output(lhs_vert + 2, horiz_base + i) = results[i].z; + output(lhs_vert + 3, horiz_base + i) = results[i].w; + } + } else if (!CHECK_RHS_BOUNDARY) { + if (lhs_vert + 3 < m_size) { + for (int i = 0; i < 8; i++) { + output(lhs_vert, horiz_base + i) = results[i].x; + output(lhs_vert + 1, horiz_base + i) = results[i].y; + output(lhs_vert + 2, horiz_base + i) = results[i].z; + output(lhs_vert + 3, horiz_base + i) = results[i].w; + } + } else if (lhs_vert + 2 < m_size) { + for (int i = 0; i < 8; i++) { + output(lhs_vert, horiz_base + i) = results[i].x; + output(lhs_vert + 1, horiz_base + i) = results[i].y; + output(lhs_vert + 2, horiz_base + i) = results[i].z; + } + } else if (lhs_vert + 1 < m_size) { + for (int i = 0; i < 8; i++) { + output(lhs_vert, horiz_base + i) = results[i].x; + output(lhs_vert + 1, horiz_base + i) = results[i].y; + } + } else if (lhs_vert < m_size) { + for (int i = 0; i < 8; i++) { + output(lhs_vert, horiz_base + i) = results[i].x; + } + } + } else if (!CHECK_LHS_BOUNDARY) { + // CHECK BOUNDARY_B + for (int i = 0; i < 8; i++) { + if (horiz_base + i < n_size) { + output(lhs_vert, horiz_base + i) = results[i].x; + output(lhs_vert + 1, horiz_base + i) = results[i].y; + output(lhs_vert + 2, horiz_base + i) = results[i].z; + output(lhs_vert + 3, horiz_base + i) = results[i].w; + } + } + } else { + // CHECK both boundaries. + for (int i = 0; i < 8; i++) { + if (horiz_base + i < n_size) { + if (lhs_vert < m_size) + output(lhs_vert, horiz_base + i) = results[i].x; + if (lhs_vert + 1 < m_size) + output(lhs_vert + 1, horiz_base + i) = results[i].y; + if (lhs_vert + 2 < m_size) + output(lhs_vert + 2, horiz_base + i) = results[i].z; + if (lhs_vert + 3 < m_size) + output(lhs_vert + 3, horiz_base + i) = results[i].w; + } + } + } +} + + +template +__global__ void +__launch_bounds__(256) +EigenFloatContractionKernel(const LhsMapper lhs, const RhsMapper rhs, + const OutputMapper output, + const Index m_size, const Index n_size, const Index k_size) { + __shared__ float2 lhs_shmem[64*32]; + __shared__ float2 rhs_shmem[128*8]; + + typedef float2 LHS_MEM[64][32]; + typedef float2 RHS_MEM[128][8]; + + typedef float2 LHS_MEM16x16[32][16]; + typedef float2 RHS_MEM16x16[64][8]; + + const Index m_block_idx = blockIdx.x; + const Index n_block_idx = blockIdx.y; + + const Index base_m = 128 * m_block_idx; + const Index base_n = 64 * n_block_idx; + + bool check_rhs = (base_n + 63) >= n_size; + bool check_lhs128 = (base_m + 127) >= m_size; + bool check_lhs64 = (base_m + 63) >= m_size; + + if (!check_rhs) { + if (!check_lhs128) { + // >= 128 rows left + EigenFloatContractionKernelInternal( + lhs, rhs, output, *((LHS_MEM *) lhs_shmem), *((RHS_MEM *) rhs_shmem), m_size, n_size, k_size, base_m, base_n); + } else { + EigenFloatContractionKernelInternal( + lhs, rhs, output, *((LHS_MEM *) lhs_shmem), *((RHS_MEM *) rhs_shmem), m_size, n_size, k_size, base_m, base_n); + } + } else { + if (!check_lhs128) { + // >= 128 rows left + EigenFloatContractionKernelInternal( + lhs, rhs, output, *((LHS_MEM *) lhs_shmem), *((RHS_MEM *) rhs_shmem), m_size, n_size, k_size, base_m, base_n); + } else { + EigenFloatContractionKernelInternal( + lhs, rhs, output, *((LHS_MEM *) lhs_shmem), *((RHS_MEM *) rhs_shmem), m_size, n_size, k_size, base_m, base_n); + } + } +} + +template +__global__ void +__launch_bounds__(256) +EigenFloatContractionKernel16x16(const LhsMapper lhs, const RhsMapper rhs, + const OutputMapper output, + const Index m_size, const Index n_size, const Index k_size) { + __shared__ float2 lhs_shmem[32][16]; + __shared__ float2 rhs_shmem[64][8]; + + const Index m_block_idx = blockIdx.x; + const Index n_block_idx = blockIdx.y; + + const Index base_m = 64 * m_block_idx; + const Index base_n = 64 * n_block_idx; + + if (base_m + 63 < m_size) { + if (base_n + 63 < n_size) { + EigenFloatContractionKernelInternal16x16(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n); + } else { + EigenFloatContractionKernelInternal16x16(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n); + } + } else { + if (base_n + 63 < n_size) { + EigenFloatContractionKernelInternal16x16(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n); + } else { + EigenFloatContractionKernelInternal16x16(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n); + } + } +} + + +template +struct TensorEvaluator, GpuDevice> : + public TensorContractionEvaluatorBase, GpuDevice> > { + + typedef GpuDevice Device; + + typedef TensorEvaluator, Device> Self; + typedef TensorContractionEvaluatorBase Base; + + typedef TensorContractionOp XprType; + typedef typename internal::remove_const::type Scalar; + typedef typename XprType::Packet Packet; + typedef typename XprType::Index Index; + typedef typename XprType::CoeffReturnType CoeffReturnType; + typedef typename XprType::PacketReturnType PacketReturnType; + + enum { + Layout = TensorEvaluator::Layout, + }; + + // Most of the code is assuming that both input tensors are ColMajor. If the + // inputs are RowMajor, we will "cheat" by swapping the LHS and RHS: + // If we want to compute A * B = C, where A is LHS and B is RHS, the code + // will pretend B is LHS and A is RHS. + typedef typename internal::conditional< + Layout == ColMajor, LeftArgType, RightArgType>::type EvalLeftArgType; + typedef typename internal::conditional< + Layout == ColMajor, RightArgType, LeftArgType>::type EvalRightArgType; + + static const int LDims = + internal::array_size::Dimensions>::value; + static const int RDims = + internal::array_size::Dimensions>::value; + static const int ContractDims = internal::array_size::value; + + typedef array left_dim_mapper_t; + typedef array right_dim_mapper_t; + + typedef array contract_t; + typedef array::size> left_nocontract_t; + typedef array::size> right_nocontract_t; + + static const int NumDims = max_n_1::size; + + typedef DSizes Dimensions; + + // typedefs needed in evalTo + typedef typename internal::remove_const::type LhsScalar; + typedef typename internal::remove_const::type RhsScalar; + + typedef TensorEvaluator LeftEvaluator; + typedef TensorEvaluator RightEvaluator; + + typedef typename LeftEvaluator::Dimensions LeftDimensions; + typedef typename RightEvaluator::Dimensions RightDimensions; + + EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device) : + Base(op, device) {} + + // We need to redefine this method to make nvcc happy + EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* data) { + this->m_leftImpl.evalSubExprsIfNeeded(NULL); + this->m_rightImpl.evalSubExprsIfNeeded(NULL); + if (data) { + evalTo(data); + return false; + } else { + this->m_result = static_cast(this->m_device.allocate(this->dimensions().TotalSize() * sizeof(Scalar))); + evalTo(this->m_result); + return true; } } - - - template -__device__ EIGEN_STRONG_INLINE void - EigenFloatContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs, - const OutputMapper output, float4* lhs_shmem4, float2* rhs_shmem2, - const Index m_size, const Index n_size, const Index k_size) { - typedef float Scalar; - - const Index m_block_idx = blockIdx.x; - const Index n_block_idx = blockIdx.y; - - const Index base_m = 64 * m_block_idx; - const Index base_n = 64 * n_block_idx; - - const Index lane = threadIdx.x + 8 * (threadIdx.y % 4); - - // prefetch registers - float4 lhs_pf0; - float4 lhs_pf1; - - float4 rhs_pf0; - float4 rhs_pf1; - - // shared memory is formatted - // (contract idx in block, nocontract idx in block, block idx) - // where block idx is column major. This transposition limits the number of - // bank conflicts when reading the LHS. The core idea is that since the contracting - // index is shared by both sides, then the contracting index should be in threadIdx.x. - - // all of these indices assume float4 loading - // this thread loads the float4 starting at this index, and then also loads - // another float4 starting 32 columns to to the right - const Index horiz_block_idx = threadIdx.z / 2; - const Index vert_block_idx = threadIdx.x / 2 + 4 * (threadIdx.y % 2); - const Index horiz_idx_in_block = threadIdx.y / 2 + 4 * (threadIdx.z % 2); - const Index vert_idx_in_block = threadIdx.x % 2; - - // there's padding in both the LHS and RHS shared memory layouts. This padding - // allows for 0 bank conflicts on all shmem stores and loads. - // LHS padding: 1 float4 on each 8x8 block of floats - // RHS padding: 1 float2 on each block, and 12 additional float2s between vertical blocks - // 3 and 4 - - // storage indices - // lhs index with respect to float4s - const Index lhs_store_idx_base = - 136 * horiz_block_idx + - 17 * vert_block_idx + - 8 * vert_idx_in_block + - horiz_idx_in_block; - - // rhs index with respect to floats - const Index rhs_store_idx_base = - 552 * horiz_block_idx + - 66 * vert_block_idx + - 32 * (horiz_idx_in_block / 4) + (horiz_idx_in_block % 4) + - 16 * vert_idx_in_block + - ((vert_block_idx < 4) ? 0 : 24); - - const Index lhs_store_idx_0 = lhs_store_idx_base + 544 * 0; - const Index lhs_store_idx_1 = lhs_store_idx_base + 544 * 1; - - const Index rhs_store_idx_0 = (rhs_store_idx_base / 2) + ((lane < 16) ? 0 : 4); - const Index rhs_store_idx_1 = rhs_store_idx_0 + 2; - const Index rhs_store_idx_2 = rhs_store_idx_0 + 1104; - const Index rhs_store_idx_3 = rhs_store_idx_1 + 1104; - - // The below diagrams show which shmem index (with respect to floats) each element - // in an 8x8 input block gets packed into: - // LHS: - // 0 4 8 12 16 20 24 28 - // 1 5 9 13 17 21 25 29 - // 2 6 10 14 18 22 26 30 - // 3 7 11 15 19 23 27 31 - // 32 36 40 44 48 52 56 60 - // ... (pack as 2 rows of float4 indexed row major, each float4 is vertical) - // - // RHS: - // 0 1 2 3 32 33 34 35 - // 4 5 6 7 36 37 38 39 - // ... (pack as 2 cols of float4 indexed col major, each float4 is horizontal) - - // Each thread in a warp loads 2 float4s. This happens in 2 instructions. On each of these - // instruction, the warp loads 2 columns (2 cols * 64 elements / col = 128 elements = 32 threads - // * 4 elements/thread). For the LHS, we're able to store the loaded float4 directly into - // shmem (using a 128 bit store instruction). For the RHS, we need to transpose the data. - // This is done with warp shuffles. Furthermore, we only use 64 bit stores for the RHS, because - // 64 bits is only 2 columns (which is all we load in a warp), and the padding for the RHS - // doesn't meet 64 bit alignment requirements (namely, the 4 consecutive floats that we want - // to load on the RHS are 8 byte aligned, not 16 byte aligned, which is required for float4). - - const Index load_idx_vert = 4 * (threadIdx.x + 8 * (threadIdx.y % 2)); - const Index load_idx_horiz = (threadIdx.y / 2) + 4 * threadIdx.z; - - const Index lhs_vert = base_m + load_idx_vert; - const Index rhs_horiz_0 = base_n + load_idx_horiz; - const Index rhs_horiz_1 = base_n + load_idx_horiz + 32; - -#define prefetchIntoRegisters(base_k) \ - { \ - lhs_pf0 = internal::pset1(0); \ - lhs_pf1 = internal::pset1(0); \ - \ - rhs_pf0 = internal::pset1(0); \ - rhs_pf1 = internal::pset1(0); \ - \ - const Index lhs_horiz_0 = base_k + load_idx_horiz; \ - const Index lhs_horiz_1 = base_k + load_idx_horiz + 32; \ - if (!needs_edge_check || lhs_vert + 3 < m_size) { \ - if (lhs_horiz_1 < k_size) { \ - lhs_pf0 = lhs.loadPacket(lhs_vert, lhs_horiz_0); \ - lhs_pf1 = lhs.loadPacket(lhs_vert, lhs_horiz_1); \ - } else if (lhs_horiz_0 < k_size) { \ - lhs_pf0 = lhs.loadPacket(lhs_vert, lhs_horiz_0); \ - } \ - } else if (lhs_vert + 2 < m_size) { \ - if (lhs_horiz_1 < k_size) { \ - lhs_pf0.x = lhs(lhs_vert + 0, lhs_horiz_0); \ - lhs_pf0.y = lhs(lhs_vert + 1, lhs_horiz_0); \ - lhs_pf0.z = lhs(lhs_vert + 2, lhs_horiz_0); \ - \ - lhs_pf1.x = lhs(lhs_vert + 0, lhs_horiz_1); \ - lhs_pf1.y = lhs(lhs_vert + 1, lhs_horiz_1); \ - lhs_pf1.z = lhs(lhs_vert + 2, lhs_horiz_1); \ - } else if (lhs_horiz_0 < k_size) { \ - lhs_pf0.x = lhs(lhs_vert + 0, lhs_horiz_0); \ - lhs_pf0.y = lhs(lhs_vert + 1, lhs_horiz_0); \ - lhs_pf0.z = lhs(lhs_vert + 2, lhs_horiz_0); \ - } \ - } else if (lhs_vert + 1 < m_size) { \ - if (lhs_horiz_1 < k_size) { \ - lhs_pf0.x = lhs(lhs_vert + 0, lhs_horiz_0); \ - lhs_pf0.y = lhs(lhs_vert + 1, lhs_horiz_0); \ - \ - lhs_pf1.x = lhs(lhs_vert + 0, lhs_horiz_1); \ - lhs_pf1.y = lhs(lhs_vert + 1, lhs_horiz_1); \ - } else if (lhs_horiz_0 < k_size) { \ - lhs_pf0.x = lhs(lhs_vert + 0, lhs_horiz_0); \ - lhs_pf0.y = lhs(lhs_vert + 1, lhs_horiz_0); \ - } \ - } else if (lhs_vert < m_size) { \ - if (lhs_horiz_1 < k_size) { \ - lhs_pf0.x = lhs(lhs_vert + 0, lhs_horiz_0); \ - lhs_pf1.x = lhs(lhs_vert + 0, lhs_horiz_1); \ - } else if (lhs_horiz_0 < k_size) { \ - lhs_pf0.x = lhs(lhs_vert + 0, lhs_horiz_0); \ - } \ -} \ - \ - const Index rhs_vert = base_k + load_idx_vert; \ - if (rhs_vert + 3 < k_size) { \ - if (!needs_edge_check || rhs_horiz_1 < n_size) { \ - rhs_pf0 = rhs.loadPacket(rhs_vert, rhs_horiz_0); \ - rhs_pf1 = rhs.loadPacket(rhs_vert, rhs_horiz_1); \ - } else if (rhs_horiz_0 < n_size) { \ - rhs_pf0 = rhs.loadPacket(rhs_vert, rhs_horiz_0); \ - } \ - } else if (rhs_vert + 2 < k_size) { \ - if (!needs_edge_check || rhs_horiz_1 < n_size) { \ - rhs_pf0.x = rhs(rhs_vert + 0, rhs_horiz_0); \ - rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz_0); \ - rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz_0); \ - \ - rhs_pf1.x = rhs(rhs_vert + 0, rhs_horiz_1); \ - rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz_1); \ - rhs_pf1.z = rhs(rhs_vert + 2, rhs_horiz_1); \ - } else if (rhs_horiz_0 < n_size) { \ - rhs_pf0.x = rhs(rhs_vert + 0, rhs_horiz_0); \ - rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz_0); \ - rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz_0); \ - } \ - } else if (rhs_vert + 1 < k_size) { \ - if (!needs_edge_check || rhs_horiz_1 < n_size) { \ - rhs_pf0.x = rhs(rhs_vert + 0, rhs_horiz_0); \ - rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz_0); \ - \ - rhs_pf1.x = rhs(rhs_vert + 0, rhs_horiz_1); \ - rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz_1); \ - } else if (rhs_horiz_0 < n_size) { \ - rhs_pf0.x = rhs(rhs_vert + 0, rhs_horiz_0); \ - rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz_0); \ - } \ - } else if (rhs_vert < k_size) { \ - if (!needs_edge_check || rhs_horiz_1 < n_size) { \ - rhs_pf0.x = rhs(rhs_vert + 0, rhs_horiz_0); \ - rhs_pf1.x = rhs(rhs_vert + 0, rhs_horiz_1); \ - } else if (rhs_horiz_0 < n_size) { \ - rhs_pf0.x = rhs(rhs_vert + 0, rhs_horiz_0); \ - } \ -} \ - \ - float swap_val0 = (lane < 16) ? rhs_pf0.z : rhs_pf0.x; \ - float swap_val1 = (lane < 16) ? rhs_pf0.w : rhs_pf0.y; \ - float swap_val2 = (lane < 16) ? rhs_pf1.z : rhs_pf1.x; \ - float swap_val3 = (lane < 16) ? rhs_pf1.w : rhs_pf1.y; \ - \ - swap_val0 = __shfl_xor(swap_val0, 16); \ - swap_val1 = __shfl_xor(swap_val1, 16); \ - swap_val2 = __shfl_xor(swap_val2, 16); \ - swap_val3 = __shfl_xor(swap_val3, 16); \ - \ - if (lane < 16) { \ - rhs_pf0.z = swap_val0; \ - rhs_pf0.w = swap_val1; \ - rhs_pf1.z = swap_val2; \ - rhs_pf1.w = swap_val3; \ - } else { \ - rhs_pf0.x = swap_val0; \ - rhs_pf0.y = swap_val1; \ - rhs_pf1.x = swap_val2; \ - rhs_pf1.y = swap_val3; \ - } \ -} \ - - -#define writeRegToShmem(_) \ - lhs_shmem4[lhs_store_idx_0] = lhs_pf0; \ - \ - rhs_shmem2[rhs_store_idx_0] = make_float2(rhs_pf0.x, rhs_pf0.z); \ - rhs_shmem2[rhs_store_idx_1] = make_float2(rhs_pf0.y, rhs_pf0.w); \ - \ - lhs_shmem4[lhs_store_idx_1] = lhs_pf1; \ - \ - rhs_shmem2[rhs_store_idx_2] = make_float2(rhs_pf1.x, rhs_pf1.z); \ - rhs_shmem2[rhs_store_idx_3] = make_float2(rhs_pf1.y, rhs_pf1.w); \ - - // declare and initialize result array -#define res(i, j) _res_##i##j -#define initResultRow(i) \ - Scalar res(i, 0) = Scalar(0); \ - Scalar res(i, 1) = Scalar(0); \ - Scalar res(i, 2) = Scalar(0); \ - Scalar res(i, 3) = Scalar(0); \ - Scalar res(i, 4) = Scalar(0); \ - Scalar res(i, 5) = Scalar(0); \ - Scalar res(i, 6) = Scalar(0); \ - Scalar res(i, 7) = Scalar(0); \ - - initResultRow(0); - initResultRow(1); - initResultRow(2); - initResultRow(3); - initResultRow(4); - initResultRow(5); - initResultRow(6); - initResultRow(7); -#undef initResultRow - - for (Index base_k = 0; base_k < k_size; base_k += 64) { - // wait for previous iteration to finish with shmem. Despite common sense, - // the code is a bit faster with this here then at bottom of loop - __syncthreads(); - - prefetchIntoRegisters(base_k); - writeRegToShmem(); - -#undef prefetchIntoRegisters -#undef writeRegoToShmem - - // wait for shared mem packing to be done before starting computation - __syncthreads(); - - // compute 8x8 matrix product by outer product. This involves packing one column - // of LHS and one row of RHS into registers (takes 16 registers). - - float4 _lcol0; - float4 _lcol1; - float2 _rrow0; - float2 _rrow1; - float2 _rrow2; - float2 _rrow3; - -#define lcol0 _lcol0.x -#define lcol1 _lcol0.y -#define lcol2 _lcol0.z -#define lcol3 _lcol0.w -#define lcol4 _lcol1.x -#define lcol5 _lcol1.y -#define lcol6 _lcol1.z -#define lcol7 _lcol1.w -#define rrow0 _rrow0.x -#define rrow1 _rrow0.y -#define rrow2 _rrow1.x -#define rrow3 _rrow1.y -#define rrow4 _rrow2.x -#define rrow5 _rrow2.y -#define rrow6 _rrow3.x -#define rrow7 _rrow3.y - - // Now x corresponds to k, y to m, and z to n - const float4* lhs_block = &lhs_shmem4[threadIdx.x + 8 * (threadIdx.y % 2) + 17 * (threadIdx.y / 2)]; - const float2* rhs_block = &rhs_shmem2[2 * threadIdx.x + 16 * (threadIdx.z % 2) + 276 * (threadIdx.z / 2)]; - -#define lhs_element(i, k) lhs_block[68 * i + 136 * k] -#define rhs_element(k, j) rhs_block[33 * k + 1104 * j + ((k < 4) ? 0 : 12)] - -#define loadData(i) \ - _lcol0 = lhs_element(0, i); \ - _rrow0 = rhs_element(i, 0); \ - _rrow1 = *(&(rhs_element(i, 0)) + 1); \ - _lcol1 = lhs_element(1, i); \ - _rrow2 = rhs_element(i, 1); \ - _rrow3 = *(&(rhs_element(i, 1)) + 1); \ - -#define computeCol(j) \ - res(0, j) += lcol0 * rrow##j; \ - res(1, j) += lcol1 * rrow##j; \ - res(2, j) += lcol2 * rrow##j; \ - res(3, j) += lcol3 * rrow##j; \ - res(4, j) += lcol4 * rrow##j; \ - res(5, j) += lcol5 * rrow##j; \ - res(6, j) += lcol6 * rrow##j; \ - res(7, j) += lcol7 * rrow##j; \ - -#define computePass(i) \ - loadData(i); \ - \ - computeCol(0); \ - computeCol(1); \ - computeCol(2); \ - computeCol(3); \ - computeCol(4); \ - computeCol(5); \ - computeCol(6); \ - computeCol(7); \ - - computePass(0); - computePass(1); - computePass(2); - computePass(3); - computePass(4); - computePass(5); - computePass(6); - computePass(7); - -#undef lcol0 -#undef lcol1 -#undef lcol2 -#undef lcol3 -#undef lcol4 -#undef lcol5 -#undef lcol6 -#undef lcol7 -#undef rrow0 -#undef rrow1 -#undef rrow2 -#undef rrow3 -#undef rrow4 -#undef rrow5 -#undef rrow6 -#undef rrow7 - -#undef computePass -#undef computeCol -#undef loadData -#undef lhs_element -#undef rhs_element - - } // end loop over k - - // we've now iterated over all of the large (ie width 64) k blocks and - // accumulated results in registers. At this point thread (x, y, z) contains - // the sum across all big k blocks of the product of little k block of index (x, y) - // with block of index (y, z). To compute the final output, we need to reduce - // the 8 threads over y by summation. -#define shuffleInc(i, j, mask) res(i, j) += __shfl_xor(res(i, j), mask) - -#define reduceRow(i, mask) \ - shuffleInc(i, 0, mask); \ - shuffleInc(i, 1, mask); \ - shuffleInc(i, 2, mask); \ - shuffleInc(i, 3, mask); \ - shuffleInc(i, 4, mask); \ - shuffleInc(i, 5, mask); \ - shuffleInc(i, 6, mask); \ - shuffleInc(i, 7, mask); \ - -#define reduceMatrix(mask) \ - reduceRow(0, mask); \ - reduceRow(1, mask); \ - reduceRow(2, mask); \ - reduceRow(3, mask); \ - reduceRow(4, mask); \ - reduceRow(5, mask); \ - reduceRow(6, mask); \ - reduceRow(7, mask); \ - - // actually perform the reduction, now each thread of index (_, y, z) - // contains the correct values in its registers that belong in the output - // block - reduceMatrix(1); - reduceMatrix(2); - reduceMatrix(4); - -#undef shuffleInc -#undef reduceRow -#undef reduceMatrix - - // now we need to copy the 64 values into main memory. We can't split work - // among threads because all variables are in registers. There's 2 ways - // to do this: - // (1) have 1 thread do 64 writes from registers into global memory - // (2) have 1 thread do 64 writes into shared memory, and then 8 threads - // each do 8 writes into global memory. We can just overwrite the shared - // memory from the problem we just solved. - // (3) Copies the values into new registers using conditional logic. - -#define makeAssignments(i) \ - val0 = res(i, 0); \ - val1 = res(i, 1); \ - val2 = res(i, 2); \ - val3 = res(i, 3); \ - val4 = res(i, 4); \ - val5 = res(i, 5); \ - val6 = res(i, 6); \ - val7 = res(i, 7); \ - - Scalar val0; - Scalar val1; - Scalar val2; - Scalar val3; - Scalar val4; - Scalar val5; - Scalar val6; - Scalar val7; - - switch (threadIdx.x) { - case 0: - makeAssignments(0); - break; - case 1: - makeAssignments(1); - break; - case 2: - makeAssignments(2); - break; - case 3: - makeAssignments(3); - break; - case 4: - makeAssignments(4); - break; - case 5: - makeAssignments(5); - break; - case 6: - makeAssignments(6); - break; - case 7: - makeAssignments(7); - break; - } - -#undef res - - const Index vert_base = base_m + 4 * threadIdx.y + (threadIdx.x % 4) + 32 * (threadIdx.x / 4); - const Index horiz_base = base_n + 4 * threadIdx.z; - - if (!needs_edge_check || vert_base < m_size) { - if (!needs_edge_check || horiz_base + 35 < n_size) { - output(vert_base, horiz_base + 0) = val0; - output(vert_base, horiz_base + 1) = val1; - output(vert_base, horiz_base + 2) = val2; - output(vert_base, horiz_base + 3) = val3; - output(vert_base, horiz_base + 32) = val4; - output(vert_base, horiz_base + 33) = val5; - output(vert_base, horiz_base + 34) = val6; - output(vert_base, horiz_base + 35) = val7; - } else if (horiz_base + 34 < n_size) { - output(vert_base, horiz_base + 0) = val0; - output(vert_base, horiz_base + 1) = val1; - output(vert_base, horiz_base + 2) = val2; - output(vert_base, horiz_base + 3) = val3; - output(vert_base, horiz_base + 32) = val4; - output(vert_base, horiz_base + 33) = val5; - output(vert_base, horiz_base + 34) = val6; - } else if (horiz_base + 33 < n_size) { - output(vert_base, horiz_base + 0) = val0; - output(vert_base, horiz_base + 1) = val1; - output(vert_base, horiz_base + 2) = val2; - output(vert_base, horiz_base + 3) = val3; - output(vert_base, horiz_base + 32) = val4; - output(vert_base, horiz_base + 33) = val5; - } else if (horiz_base + 32 < n_size) { - output(vert_base, horiz_base + 0) = val0; - output(vert_base, horiz_base + 1) = val1; - output(vert_base, horiz_base + 2) = val2; - output(vert_base, horiz_base + 3) = val3; - output(vert_base, horiz_base + 32) = val4; - } else if (horiz_base + 3 < n_size) { - output(vert_base, horiz_base + 0) = val0; - output(vert_base, horiz_base + 1) = val1; - output(vert_base, horiz_base + 2) = val2; - output(vert_base, horiz_base + 3) = val3; - } else if (horiz_base + 2 < n_size) { - output(vert_base, horiz_base + 0) = val0; - output(vert_base, horiz_base + 1) = val1; - output(vert_base, horiz_base + 2) = val2; - } else if (horiz_base + 1 < n_size) { - output(vert_base, horiz_base + 0) = val0; - output(vert_base, horiz_base + 1) = val1; - } else if (horiz_base < n_size) { - output(vert_base, horiz_base + 0) = val0; - } - } - } - - - template -__global__ void - __launch_bounds__(512) - EigenFloatContractionKernel(const LhsMapper lhs, const RhsMapper rhs, - const OutputMapper output, - const Index m_size, const Index n_size, const Index k_size) { - __shared__ float4 lhs_shmem[(68 * 64) / 4]; - __shared__ float2 rhs_shmem[((66 * 8 + 24) * 8) / 2]; - - const Index m_block_idx = blockIdx.x; - const Index n_block_idx = blockIdx.y; - - const Index base_m = 64 * m_block_idx; - const Index base_n = 64 * n_block_idx; - - if (base_m + 63 < m_size && base_n + 63 < n_size) { - EigenFloatContractionKernelInternal(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size); - } else { - EigenFloatContractionKernelInternal(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size); - } - } - - - template - struct TensorEvaluator, GpuDevice> : - public TensorContractionEvaluatorBase, GpuDevice> > { - - typedef GpuDevice Device; - - typedef TensorEvaluator, Device> Self; - typedef TensorContractionEvaluatorBase Base; - - typedef TensorContractionOp XprType; - typedef typename internal::remove_const::type Scalar; - typedef typename XprType::Packet Packet; - typedef typename XprType::Index Index; - typedef typename XprType::CoeffReturnType CoeffReturnType; - typedef typename XprType::PacketReturnType PacketReturnType; - - typedef array::Dimensions::count> left_dim_mapper_t; - typedef array::Dimensions::count> right_dim_mapper_t; - - typedef array::value> contract_t; - typedef array::Dimensions::count - internal::array_size::value> left_nocontract_t; - typedef array::Dimensions::count - internal::array_size::value> right_nocontract_t; - - static const int NumDims = max_n_1::Dimensions::count + TensorEvaluator::Dimensions::count - 2 * internal::array_size::value>::size; - - typedef DSizes Dimensions; - - // typedefs needed in evalTo - typedef typename internal::remove_const::type LhsScalar; - typedef typename internal::remove_const::type RhsScalar; - - typedef TensorEvaluator LeftEvaluator; - typedef TensorEvaluator RightEvaluator; - - typedef typename LeftEvaluator::Dimensions LeftDimensions; - typedef typename RightEvaluator::Dimensions RightDimensions; - - EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device) : - Base(op, device) {} - - // We need to redefine this method to make nvcc happy - EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* data) { - this->m_leftImpl.evalSubExprsIfNeeded(NULL); - this->m_rightImpl.evalSubExprsIfNeeded(NULL); - if (data) { - evalTo(data); - return false; - } else { - this->m_result = static_cast(this->m_device.allocate(this->dimensions().TotalSize() * sizeof(Scalar))); - evalTo(this->m_result); - return true; - } - } - - void evalTo(Scalar* buffer) const { - if (this->m_lhs_inner_dim_contiguous) { - if (this->m_rhs_inner_dim_contiguous) { - if (this->m_rhs_inner_dim_reordered) { - evalTyped(buffer); - } - else { - evalTyped(buffer); - } - } - else { - if (this->m_rhs_inner_dim_reordered) { - evalTyped(buffer); - } - else { - evalTyped(buffer); - } - } + void evalTo(Scalar* buffer) const { + if (this->m_lhs_inner_dim_contiguous) { + if (this->m_rhs_inner_dim_contiguous) { + if (this->m_rhs_inner_dim_reordered) { + evalTyped(buffer); } else { - if (this->m_rhs_inner_dim_contiguous) { - if (this->m_rhs_inner_dim_reordered) { - evalTyped(buffer); - } - else { - evalTyped(buffer); - } - } - else { - if (this->m_rhs_inner_dim_reordered) { - evalTyped(buffer); - } - else { - evalTyped(buffer); - } - } + evalTyped(buffer); } } + else { + if (this->m_rhs_inner_dim_reordered) { + evalTyped(buffer); + } + else { + evalTyped(buffer); + } + } + } + else { + if (this->m_rhs_inner_dim_contiguous) { + if (this->m_rhs_inner_dim_reordered) { + evalTyped(buffer); + } + else { + evalTyped(buffer); + } + } + else { + if (this->m_rhs_inner_dim_reordered) { + evalTyped(buffer); + } + else { + evalTyped(buffer); + } + } + } + } - template - void evalTyped(Scalar* buffer) const { - // columns in left side, rows in right side - const Index k = this->m_k_size; + template + void evalTyped(Scalar* buffer) const { + // columns in left side, rows in right side + const Index k = this->m_k_size; - // rows in left side - const Index m = this->m_i_size; + // rows in left side + const Index m = this->m_i_size; - // columns in right side - const Index n = this->m_j_size; + // columns in right side + const Index n = this->m_j_size; - // zero out the result buffer (which must be of size at least m * n * sizeof(Scalar) - this->m_device.memset(buffer, 0, m * n * sizeof(Scalar)); + // zero out the result buffer (which must be of size at least m * n * sizeof(Scalar) + this->m_device.memset(buffer, 0, m * n * sizeof(Scalar)); - typedef internal::TensorContractionInputMapper LhsMapper; + typedef internal::TensorContractionInputMapper LhsMapper; - typedef internal::TensorContractionInputMapper RhsMapper; + typedef internal::TensorContractionInputMapper RhsMapper; - typedef internal::blas_data_mapper OutputMapper; + typedef internal::blas_data_mapper OutputMapper; - // initialize data mappers - LhsMapper lhs(this->m_leftImpl, this->m_left_nocontract_strides, this->m_i_strides, - this->m_left_contracting_strides, this->m_k_strides); + // initialize data mappers + LhsMapper lhs(this->m_leftImpl, this->m_left_nocontract_strides, this->m_i_strides, + this->m_left_contracting_strides, this->m_k_strides); - RhsMapper rhs(this->m_rightImpl, this->m_right_nocontract_strides, this->m_j_strides, - this->m_right_contracting_strides, this->m_k_strides); + RhsMapper rhs(this->m_rightImpl, this->m_right_nocontract_strides, this->m_j_strides, + this->m_right_contracting_strides, this->m_k_strides); - OutputMapper output(buffer, m); + OutputMapper output(buffer, m); + setCudaSharedMemConfig(cudaSharedMemBankSizeEightByte); + if (internal::is_same::value && + internal::is_same::value) { + if (m < 768 || n < 768) { const Index m_blocks = (m + 63) / 64; const Index n_blocks = (n + 63) / 64; const dim3 num_blocks(m_blocks, n_blocks, 1); - const dim3 block_size(8, 8, 8); - - cudaDeviceSetSharedMemConfig(cudaSharedMemBankSizeEightByte); - if (internal::is_same::value && - internal::is_same::value) { - EigenFloatContractionKernel - <<m_device.stream()>>>(lhs, rhs, output, m, n, k); - } else { - EigenContractionKernel - <<m_device.stream()>>>(lhs, rhs, output, m, n, k); - } - - assert(cudaGetLastError() == cudaSuccess); + const dim3 block_size(16, 16, 1); + LAUNCH_CUDA_KERNEL((EigenFloatContractionKernel16x16), num_blocks, block_size, 0, this->m_device, lhs, rhs, output, m, n, k); + } else { + const Index m_blocks = (m + 127) / 128; + const Index n_blocks = (n + 63) / 64; + const dim3 num_blocks(m_blocks, n_blocks, 1); + const dim3 block_size(8, 32, 1); + LAUNCH_CUDA_KERNEL((EigenFloatContractionKernel), num_blocks, block_size, 0, this->m_device, lhs, rhs, output, m, n, k); } - }; + } else { + const Index m_blocks = (m + 63) / 64; + const Index n_blocks = (n + 63) / 64; + const dim3 num_blocks(m_blocks, n_blocks, 1); + const dim3 block_size(8, 8, 8); + LAUNCH_CUDA_KERNEL((EigenContractionKernel), num_blocks, block_size, 0, this->m_device, lhs, rhs, output, m, n, k); + } + } +}; } // end namespace Eigen #endif // EIGEN_USE_GPU and __CUDACC__ - #endif // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_CUDA_H