mirror of
https://gitlab.com/libeigen/eigen.git
synced 2024-12-27 07:29:52 +08:00
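Fix GPU support: pass the packet type explicitly to loadPacket (loadPacket<float4,Unaligned> instead of loadPacket<Unaligned>) in the float contraction kernels.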
This commit is contained in:
parent
e0f6d352fb
commit
3c6dc93f99
@ -549,12 +549,12 @@ EigenFloatContractionKernelInternal16x16(const LhsMapper lhs, const RhsMapper rh
|
||||
#define prefetch_lhs(reg, row, col) \
|
||||
if (!CHECK_LHS_BOUNDARY) { \
|
||||
if (col < k_size) { \
|
||||
reg =lhs.template loadPacket<Unaligned>(row, col); \
|
||||
reg =lhs.template loadPacket<float4,Unaligned>(row, col); \
|
||||
} \
|
||||
} else { \
|
||||
if (col < k_size) { \
|
||||
if (row + 3 < m_size) { \
|
||||
reg =lhs.template loadPacket<Unaligned>(row, col); \
|
||||
reg =lhs.template loadPacket<float4,Unaligned>(row, col); \
|
||||
} else if (row + 2 < m_size) { \
|
||||
reg.x =lhs(row + 0, col); \
|
||||
reg.y =lhs(row + 1, col); \
|
||||
@ -584,7 +584,7 @@ EigenFloatContractionKernelInternal16x16(const LhsMapper lhs, const RhsMapper rh
|
||||
if (!CHECK_RHS_BOUNDARY) {
|
||||
if ((rhs_vert + 3) < k_size) {
|
||||
// just CHECK_RHS_BOUNDARY
|
||||
rhs_pf0 = rhs.template loadPacket<Unaligned>(rhs_vert, rhs_horiz0);
|
||||
rhs_pf0 = rhs.template loadPacket<float4,Unaligned>(rhs_vert, rhs_horiz0);
|
||||
} else if (rhs_vert + 2 < k_size) {
|
||||
// just CHECK_RHS_BOUNDARY
|
||||
rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
|
||||
@ -599,7 +599,7 @@ EigenFloatContractionKernelInternal16x16(const LhsMapper lhs, const RhsMapper rh
|
||||
} else {
|
||||
if (rhs_horiz0 < n_size) {
|
||||
if ((rhs_vert + 3) < k_size) {
|
||||
rhs_pf0 = rhs.template loadPacket<Unaligned>(rhs_vert, rhs_horiz0);
|
||||
rhs_pf0 = rhs.template loadPacket<float4,Unaligned>(rhs_vert, rhs_horiz0);
|
||||
} else if ((rhs_vert + 2) < k_size) {
|
||||
rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
|
||||
rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
|
||||
@ -799,37 +799,37 @@ EigenFloatContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs,
|
||||
|
||||
if (!CHECK_LHS_BOUNDARY) {
|
||||
if ((threadIdx.y/4+k+24) < k_size) {
|
||||
lhs_pf0 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
|
||||
lhs_pf1 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
|
||||
lhs_pf2 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
|
||||
lhs_pf3 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+24));
|
||||
lhs_pf0 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k));
|
||||
lhs_pf1 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
|
||||
lhs_pf2 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
|
||||
lhs_pf3 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+24));
|
||||
} else if ((threadIdx.y/4+k+16) < k_size) {
|
||||
lhs_pf0 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
|
||||
lhs_pf1 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
|
||||
lhs_pf2 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
|
||||
lhs_pf0 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k));
|
||||
lhs_pf1 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
|
||||
lhs_pf2 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
|
||||
} else if ((threadIdx.y/4+k+8) < k_size) {
|
||||
lhs_pf0 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
|
||||
lhs_pf1 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
|
||||
lhs_pf0 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k));
|
||||
lhs_pf1 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
|
||||
} else if ((threadIdx.y/4+k) < k_size) {
|
||||
lhs_pf0 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
|
||||
lhs_pf0 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k));
|
||||
}
|
||||
} else {
|
||||
// just CHECK_LHS_BOUNDARY
|
||||
if (lhs_vert + 3 < m_size) {
|
||||
if ((threadIdx.y/4+k+24) < k_size) {
|
||||
lhs_pf0 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
|
||||
lhs_pf1 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
|
||||
lhs_pf2 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
|
||||
lhs_pf3 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+24));
|
||||
lhs_pf0 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k));
|
||||
lhs_pf1 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
|
||||
lhs_pf2 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
|
||||
lhs_pf3 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+24));
|
||||
} else if ((threadIdx.y/4+k+16) < k_size) {
|
||||
lhs_pf0 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
|
||||
lhs_pf1 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
|
||||
lhs_pf2 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
|
||||
lhs_pf0 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k));
|
||||
lhs_pf1 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
|
||||
lhs_pf2 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
|
||||
} else if ((threadIdx.y/4+k+8) < k_size) {
|
||||
lhs_pf0 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
|
||||
lhs_pf1 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
|
||||
lhs_pf0 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k));
|
||||
lhs_pf1 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
|
||||
} else if ((threadIdx.y/4+k) < k_size) {
|
||||
lhs_pf0 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
|
||||
lhs_pf0 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k));
|
||||
}
|
||||
} else if (lhs_vert + 2 < m_size) {
|
||||
if ((threadIdx.y/4+k+24) < k_size) {
|
||||
@ -918,8 +918,8 @@ EigenFloatContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs,
|
||||
if (!CHECK_RHS_BOUNDARY) {
|
||||
if ((rhs_vert + 3) < k_size) {
|
||||
// just CHECK_RHS_BOUNDARY
|
||||
rhs_pf0 = rhs.template loadPacket<Unaligned>(rhs_vert, rhs_horiz0);
|
||||
rhs_pf1 = rhs.template loadPacket<Unaligned>(rhs_vert, rhs_horiz1);
|
||||
rhs_pf0 = rhs.template loadPacket<float4,Unaligned>(rhs_vert, rhs_horiz0);
|
||||
rhs_pf1 = rhs.template loadPacket<float4,Unaligned>(rhs_vert, rhs_horiz1);
|
||||
} else if (rhs_vert + 2 < k_size) {
|
||||
// just CHECK_RHS_BOUNDARY
|
||||
rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
|
||||
@ -941,8 +941,8 @@ EigenFloatContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs,
|
||||
if (rhs_horiz1 < n_size) {
|
||||
if ((rhs_vert + 3) < k_size) {
|
||||
// just CHECK_RHS_BOUNDARY
|
||||
rhs_pf0 = rhs.template loadPacket<Unaligned>(rhs_vert, rhs_horiz0);
|
||||
rhs_pf1 = rhs.template loadPacket<Unaligned>(rhs_vert, rhs_horiz1);
|
||||
rhs_pf0 = rhs.template loadPacket<float4,Unaligned>(rhs_vert, rhs_horiz0);
|
||||
rhs_pf1 = rhs.template loadPacket<float4,Unaligned>(rhs_vert, rhs_horiz1);
|
||||
} else if (rhs_vert + 2 < k_size) {
|
||||
// just CHECK_RHS_BOUNDARY
|
||||
rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
|
||||
@ -963,7 +963,7 @@ EigenFloatContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs,
|
||||
} else if (rhs_horiz0 < n_size) {
|
||||
if ((rhs_vert + 3) < k_size) {
|
||||
// just CHECK_RHS_BOUNDARY
|
||||
rhs_pf0 = rhs.template loadPacket<Unaligned>(rhs_vert, rhs_horiz0);
|
||||
rhs_pf0 = rhs.template loadPacket<float4,Unaligned>(rhs_vert, rhs_horiz0);
|
||||
} else if ((rhs_vert + 2) < k_size) {
|
||||
// just CHECK_RHS_BOUNDARY
|
||||
rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
|
||||
|
Loading…
Reference in New Issue
Block a user