Fix GPU support.

This commit is contained in:
Gael Guennebaud 2018-09-20 18:29:21 +02:00
parent e0f6d352fb
commit 3c6dc93f99

View File

@ -549,12 +549,12 @@ EigenFloatContractionKernelInternal16x16(const LhsMapper lhs, const RhsMapper rh
#define prefetch_lhs(reg, row, col) \
if (!CHECK_LHS_BOUNDARY) { \
if (col < k_size) { \
reg =lhs.template loadPacket<Unaligned>(row, col); \
reg =lhs.template loadPacket<float4,Unaligned>(row, col); \
} \
} else { \
if (col < k_size) { \
if (row + 3 < m_size) { \
reg =lhs.template loadPacket<Unaligned>(row, col); \
reg =lhs.template loadPacket<float4,Unaligned>(row, col); \
} else if (row + 2 < m_size) { \
reg.x =lhs(row + 0, col); \
reg.y =lhs(row + 1, col); \
@ -584,7 +584,7 @@ EigenFloatContractionKernelInternal16x16(const LhsMapper lhs, const RhsMapper rh
if (!CHECK_RHS_BOUNDARY) {
if ((rhs_vert + 3) < k_size) {
// just CHECK_RHS_BOUNDARY
rhs_pf0 = rhs.template loadPacket<Unaligned>(rhs_vert, rhs_horiz0);
rhs_pf0 = rhs.template loadPacket<float4,Unaligned>(rhs_vert, rhs_horiz0);
} else if (rhs_vert + 2 < k_size) {
// just CHECK_RHS_BOUNDARY
rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
@ -599,7 +599,7 @@ EigenFloatContractionKernelInternal16x16(const LhsMapper lhs, const RhsMapper rh
} else {
if (rhs_horiz0 < n_size) {
if ((rhs_vert + 3) < k_size) {
rhs_pf0 = rhs.template loadPacket<Unaligned>(rhs_vert, rhs_horiz0);
rhs_pf0 = rhs.template loadPacket<float4,Unaligned>(rhs_vert, rhs_horiz0);
} else if ((rhs_vert + 2) < k_size) {
rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
@ -799,37 +799,37 @@ EigenFloatContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs,
if (!CHECK_LHS_BOUNDARY) {
if ((threadIdx.y/4+k+24) < k_size) {
lhs_pf0 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
lhs_pf1 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
lhs_pf2 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
lhs_pf3 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+24));
lhs_pf0 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k));
lhs_pf1 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
lhs_pf2 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
lhs_pf3 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+24));
} else if ((threadIdx.y/4+k+16) < k_size) {
lhs_pf0 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
lhs_pf1 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
lhs_pf2 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
lhs_pf0 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k));
lhs_pf1 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
lhs_pf2 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
} else if ((threadIdx.y/4+k+8) < k_size) {
lhs_pf0 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
lhs_pf1 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
lhs_pf0 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k));
lhs_pf1 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
} else if ((threadIdx.y/4+k) < k_size) {
lhs_pf0 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
lhs_pf0 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k));
}
} else {
// just CHECK_LHS_BOUNDARY
if (lhs_vert + 3 < m_size) {
if ((threadIdx.y/4+k+24) < k_size) {
lhs_pf0 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
lhs_pf1 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
lhs_pf2 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
lhs_pf3 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+24));
lhs_pf0 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k));
lhs_pf1 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
lhs_pf2 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
lhs_pf3 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+24));
} else if ((threadIdx.y/4+k+16) < k_size) {
lhs_pf0 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
lhs_pf1 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
lhs_pf2 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
lhs_pf0 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k));
lhs_pf1 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
lhs_pf2 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
} else if ((threadIdx.y/4+k+8) < k_size) {
lhs_pf0 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
lhs_pf1 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
lhs_pf0 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k));
lhs_pf1 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
} else if ((threadIdx.y/4+k) < k_size) {
lhs_pf0 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
lhs_pf0 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k));
}
} else if (lhs_vert + 2 < m_size) {
if ((threadIdx.y/4+k+24) < k_size) {
@ -918,8 +918,8 @@ EigenFloatContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs,
if (!CHECK_RHS_BOUNDARY) {
if ((rhs_vert + 3) < k_size) {
// just CHECK_RHS_BOUNDARY
rhs_pf0 = rhs.template loadPacket<Unaligned>(rhs_vert, rhs_horiz0);
rhs_pf1 = rhs.template loadPacket<Unaligned>(rhs_vert, rhs_horiz1);
rhs_pf0 = rhs.template loadPacket<float4,Unaligned>(rhs_vert, rhs_horiz0);
rhs_pf1 = rhs.template loadPacket<float4,Unaligned>(rhs_vert, rhs_horiz1);
} else if (rhs_vert + 2 < k_size) {
// just CHECK_RHS_BOUNDARY
rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
@ -941,8 +941,8 @@ EigenFloatContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs,
if (rhs_horiz1 < n_size) {
if ((rhs_vert + 3) < k_size) {
// just CHECK_RHS_BOUNDARY
rhs_pf0 = rhs.template loadPacket<Unaligned>(rhs_vert, rhs_horiz0);
rhs_pf1 = rhs.template loadPacket<Unaligned>(rhs_vert, rhs_horiz1);
rhs_pf0 = rhs.template loadPacket<float4,Unaligned>(rhs_vert, rhs_horiz0);
rhs_pf1 = rhs.template loadPacket<float4,Unaligned>(rhs_vert, rhs_horiz1);
} else if (rhs_vert + 2 < k_size) {
// just CHECK_RHS_BOUNDARY
rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
@ -963,7 +963,7 @@ EigenFloatContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs,
} else if (rhs_horiz0 < n_size) {
if ((rhs_vert + 3) < k_size) {
// just CHECK_RHS_BOUNDARY
rhs_pf0 = rhs.template loadPacket<Unaligned>(rhs_vert, rhs_horiz0);
rhs_pf0 = rhs.template loadPacket<float4,Unaligned>(rhs_vert, rhs_horiz0);
} else if ((rhs_vert + 2) < k_size) {
// just CHECK_RHS_BOUNDARY
rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);