Fix #1874: workaround MSVC 2017 compilation issue.

This commit is contained in:
Gael Guennebaud 2020-05-15 20:47:32 +02:00
parent 9b411757ab
commit 8ce9630ddb

View File

@@ -2682,100 +2682,96 @@ struct gemm_pack_rhs<Scalar, Index, DataMapper, nr, RowMajor, Conjugate, PanelMo
QuarterPacketSize = unpacket_traits<QuarterPacket>::size,
HasHalf = (int)HalfPacketSize < (int)PacketSize,
HasQuarter = (int)QuarterPacketSize < (int)HalfPacketSize };
EIGEN_DONT_INLINE void operator()(Scalar* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0);
};
template<typename Scalar, typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
EIGEN_DONT_INLINE void gemm_pack_rhs<Scalar, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>
::operator()(Scalar* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
{
EIGEN_ASM_COMMENT("EIGEN PRODUCT PACK RHS ROWMAJOR");
EIGEN_UNUSED_VARIABLE(stride);
EIGEN_UNUSED_VARIABLE(offset);
eigen_assert(((!PanelMode) && stride==0 && offset==0) || (PanelMode && stride>=depth && offset<=stride));
conj_if<NumTraits<Scalar>::IsComplex && Conjugate> cj;
Index packet_cols8 = nr>=8 ? (cols/8) * 8 : 0;
Index packet_cols4 = nr>=4 ? (cols/4) * 4 : 0;
Index count = 0;
// if(nr>=8)
// {
// for(Index j2=0; j2<packet_cols8; j2+=8)
// {
// // skip what we have before
// if(PanelMode) count += 8 * offset;
// for(Index k=0; k<depth; k++)
// {
// if (PacketSize==8) {
// Packet A = ploadu<Packet>(&rhs[k*rhsStride + j2]);
// pstoreu(blockB+count, cj.pconj(A));
// } else if (PacketSize==4) {
// Packet A = ploadu<Packet>(&rhs[k*rhsStride + j2]);
// Packet B = ploadu<Packet>(&rhs[k*rhsStride + j2 + PacketSize]);
// pstoreu(blockB+count, cj.pconj(A));
// pstoreu(blockB+count+PacketSize, cj.pconj(B));
// } else {
// const Scalar* b0 = &rhs[k*rhsStride + j2];
// blockB[count+0] = cj(b0[0]);
// blockB[count+1] = cj(b0[1]);
// blockB[count+2] = cj(b0[2]);
// blockB[count+3] = cj(b0[3]);
// blockB[count+4] = cj(b0[4]);
// blockB[count+5] = cj(b0[5]);
// blockB[count+6] = cj(b0[6]);
// blockB[count+7] = cj(b0[7]);
// }
// count += 8;
// }
// // skip what we have after
// if(PanelMode) count += 8 * (stride-offset-depth);
// }
// }
if(nr>=4)
EIGEN_DONT_INLINE void operator()(Scalar* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0)
{
for(Index j2=packet_cols8; j2<packet_cols4; j2+=4)
EIGEN_ASM_COMMENT("EIGEN PRODUCT PACK RHS ROWMAJOR");
EIGEN_UNUSED_VARIABLE(stride);
EIGEN_UNUSED_VARIABLE(offset);
eigen_assert(((!PanelMode) && stride==0 && offset==0) || (PanelMode && stride>=depth && offset<=stride));
conj_if<NumTraits<Scalar>::IsComplex && Conjugate> cj;
Index packet_cols8 = nr>=8 ? (cols/8) * 8 : 0;
Index packet_cols4 = nr>=4 ? (cols/4) * 4 : 0;
Index count = 0;
// if(nr>=8)
// {
// for(Index j2=0; j2<packet_cols8; j2+=8)
// {
// // skip what we have before
// if(PanelMode) count += 8 * offset;
// for(Index k=0; k<depth; k++)
// {
// if (PacketSize==8) {
// Packet A = ploadu<Packet>(&rhs[k*rhsStride + j2]);
// pstoreu(blockB+count, cj.pconj(A));
// } else if (PacketSize==4) {
// Packet A = ploadu<Packet>(&rhs[k*rhsStride + j2]);
// Packet B = ploadu<Packet>(&rhs[k*rhsStride + j2 + PacketSize]);
// pstoreu(blockB+count, cj.pconj(A));
// pstoreu(blockB+count+PacketSize, cj.pconj(B));
// } else {
// const Scalar* b0 = &rhs[k*rhsStride + j2];
// blockB[count+0] = cj(b0[0]);
// blockB[count+1] = cj(b0[1]);
// blockB[count+2] = cj(b0[2]);
// blockB[count+3] = cj(b0[3]);
// blockB[count+4] = cj(b0[4]);
// blockB[count+5] = cj(b0[5]);
// blockB[count+6] = cj(b0[6]);
// blockB[count+7] = cj(b0[7]);
// }
// count += 8;
// }
// // skip what we have after
// if(PanelMode) count += 8 * (stride-offset-depth);
// }
// }
if(nr>=4)
{
// skip what we have before
if(PanelMode) count += 4 * offset;
for(Index j2=packet_cols8; j2<packet_cols4; j2+=4)
{
// skip what we have before
if(PanelMode) count += 4 * offset;
for(Index k=0; k<depth; k++)
{
if (PacketSize==4) {
Packet A = rhs.template loadPacket<Packet>(k, j2);
pstoreu(blockB+count, cj.pconj(A));
count += PacketSize;
} else if (HasHalf && HalfPacketSize==4) {
HalfPacket A = rhs.template loadPacket<HalfPacket>(k, j2);
pstoreu(blockB+count, cj.pconj(A));
count += HalfPacketSize;
} else if (HasQuarter && QuarterPacketSize==4) {
QuarterPacket A = rhs.template loadPacket<QuarterPacket>(k, j2);
pstoreu(blockB+count, cj.pconj(A));
count += QuarterPacketSize;
} else {
const LinearMapper dm0 = rhs.getLinearMapper(k, j2);
blockB[count+0] = cj(dm0(0));
blockB[count+1] = cj(dm0(1));
blockB[count+2] = cj(dm0(2));
blockB[count+3] = cj(dm0(3));
count += 4;
}
}
// skip what we have after
if(PanelMode) count += 4 * (stride-offset-depth);
}
}
// copy the remaining columns one at a time (nr==1)
for(Index j2=packet_cols4; j2<cols; ++j2)
{
if(PanelMode) count += offset;
for(Index k=0; k<depth; k++)
{
if (PacketSize==4) {
Packet A = rhs.template loadPacket<Packet>(k, j2);
pstoreu(blockB+count, cj.pconj(A));
count += PacketSize;
} else if (HasHalf && HalfPacketSize==4) {
HalfPacket A = rhs.template loadPacket<HalfPacket>(k, j2);
pstoreu(blockB+count, cj.pconj(A));
count += HalfPacketSize;
} else if (HasQuarter && QuarterPacketSize==4) {
QuarterPacket A = rhs.template loadPacket<QuarterPacket>(k, j2);
pstoreu(blockB+count, cj.pconj(A));
count += QuarterPacketSize;
} else {
const LinearMapper dm0 = rhs.getLinearMapper(k, j2);
blockB[count+0] = cj(dm0(0));
blockB[count+1] = cj(dm0(1));
blockB[count+2] = cj(dm0(2));
blockB[count+3] = cj(dm0(3));
count += 4;
}
blockB[count] = cj(rhs(k, j2));
count += 1;
}
// skip what we have after
if(PanelMode) count += 4 * (stride-offset-depth);
if(PanelMode) count += stride-offset-depth;
}
}
// copy the remaining columns one at a time (nr==1)
for(Index j2=packet_cols4; j2<cols; ++j2)
{
if(PanelMode) count += offset;
for(Index k=0; k<depth; k++)
{
blockB[count] = cj(rhs(k, j2));
count += 1;
}
if(PanelMode) count += stride-offset-depth;
}
}
};
} // end namespace internal