mirror of
https://gitlab.com/libeigen/eigen.git
synced 2024-12-21 07:19:46 +08:00
Invert rows and depth in non-vectorized portion of packing (PowerPC).
This commit is contained in:
parent
e1cb6369b0
commit
9cf34ee0ae
@ -129,20 +129,20 @@ template<> EIGEN_STRONG_INLINE Packet2cf ploaddup<Packet2cf>(const std::complex<
|
||||
template<> EIGEN_STRONG_INLINE void pstore <std::complex<float> >(std::complex<float> * to, const Packet2cf& from) { pstore((float*)to, from.v); }
|
||||
template<> EIGEN_STRONG_INLINE void pstoreu<std::complex<float> >(std::complex<float> * to, const Packet2cf& from) { pstoreu((float*)to, from.v); }
|
||||
|
||||
EIGEN_STRONG_INLINE Packet2cf pload2(const std::complex<float>* from0, const std::complex<float>* from1)
|
||||
EIGEN_STRONG_INLINE Packet2cf pload2(const std::complex<float>& from0, const std::complex<float>& from1)
|
||||
{
|
||||
Packet4f res0, res1;
|
||||
#ifdef __VSX__
|
||||
__asm__ ("lxsdx %x0,%y1" : "=wa" (res0) : "Z" (*from0));
|
||||
__asm__ ("lxsdx %x0,%y1" : "=wa" (res1) : "Z" (*from1));
|
||||
__asm__ ("lxsdx %x0,%y1" : "=wa" (res0) : "Z" (from0));
|
||||
__asm__ ("lxsdx %x0,%y1" : "=wa" (res1) : "Z" (from1));
|
||||
#ifdef _BIG_ENDIAN
|
||||
__asm__ ("xxpermdi %x0, %x1, %x2, 0" : "=wa" (res0) : "wa" (res0), "wa" (res1));
|
||||
#else
|
||||
__asm__ ("xxpermdi %x0, %x2, %x1, 0" : "=wa" (res0) : "wa" (res0), "wa" (res1));
|
||||
#endif
|
||||
#else
|
||||
*reinterpret_cast<std::complex<float> *>(&res0) = *from0;
|
||||
*reinterpret_cast<std::complex<float> *>(&res1) = *from1;
|
||||
*reinterpret_cast<std::complex<float> *>(&res0) = from0;
|
||||
*reinterpret_cast<std::complex<float> *>(&res1) = from1;
|
||||
res0 = vec_perm(res0, res1, p16uc_TRANSPOSE64_HI);
|
||||
#endif
|
||||
return Packet2cf(res0);
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -11,22 +11,8 @@ namespace Eigen {
|
||||
|
||||
namespace internal {
|
||||
|
||||
template<typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accRows>
|
||||
EIGEN_STRONG_INLINE void gemm_extra_col(
|
||||
const DataMapper& res,
|
||||
const Scalar* lhs_base,
|
||||
const Scalar* rhs_base,
|
||||
Index depth,
|
||||
Index strideA,
|
||||
Index offsetA,
|
||||
Index row,
|
||||
Index col,
|
||||
Index remaining_rows,
|
||||
Index remaining_cols,
|
||||
const Packet& pAlpha);
|
||||
|
||||
template<typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accRows, const Index accCols>
|
||||
EIGEN_STRONG_INLINE void gemm_extra_row(
|
||||
EIGEN_ALWAYS_INLINE void gemm_extra_row(
|
||||
const DataMapper& res,
|
||||
const Scalar* lhs_base,
|
||||
const Scalar* rhs_base,
|
||||
@ -41,41 +27,28 @@ EIGEN_STRONG_INLINE void gemm_extra_row(
|
||||
const Packet& pAlpha,
|
||||
const Packet& pMask);
|
||||
|
||||
template<typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accCols>
|
||||
EIGEN_STRONG_INLINE void gemm_unrolled_col(
|
||||
template<typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
|
||||
EIGEN_STRONG_INLINE void gemm_extra_cols(
|
||||
const DataMapper& res,
|
||||
const Scalar* lhs_base,
|
||||
const Scalar* rhs_base,
|
||||
const Scalar* blockA,
|
||||
const Scalar* blockB,
|
||||
Index depth,
|
||||
Index strideA,
|
||||
Index offsetA,
|
||||
Index& row,
|
||||
Index rows,
|
||||
Index strideB,
|
||||
Index offsetB,
|
||||
Index col,
|
||||
Index remaining_cols,
|
||||
const Packet& pAlpha);
|
||||
Index rows,
|
||||
Index cols,
|
||||
Index remaining_rows,
|
||||
const Packet& pAlpha,
|
||||
const Packet& pMask);
|
||||
|
||||
template<typename Packet>
|
||||
EIGEN_ALWAYS_INLINE Packet bmask(const int remaining_rows);
|
||||
|
||||
template<typename Scalar, typename Packet, typename Packetc, typename DataMapper, typename Index, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
|
||||
EIGEN_STRONG_INLINE void gemm_complex_extra_col(
|
||||
const DataMapper& res,
|
||||
const Scalar* lhs_base,
|
||||
const Scalar* rhs_base,
|
||||
Index depth,
|
||||
Index strideA,
|
||||
Index offsetA,
|
||||
Index strideB,
|
||||
Index row,
|
||||
Index col,
|
||||
Index remaining_rows,
|
||||
Index remaining_cols,
|
||||
const Packet& pAlphaReal,
|
||||
const Packet& pAlphaImag);
|
||||
|
||||
template<typename Scalar, typename Packet, typename Packetc, typename DataMapper, typename Index, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
|
||||
EIGEN_STRONG_INLINE void gemm_complex_extra_row(
|
||||
EIGEN_ALWAYS_INLINE void gemm_complex_extra_row(
|
||||
const DataMapper& res,
|
||||
const Scalar* lhs_base,
|
||||
const Scalar* rhs_base,
|
||||
@ -93,123 +66,88 @@ EIGEN_STRONG_INLINE void gemm_complex_extra_row(
|
||||
const Packet& pMask);
|
||||
|
||||
template<typename Scalar, typename Packet, typename Packetc, typename DataMapper, typename Index, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
|
||||
EIGEN_STRONG_INLINE void gemm_complex_unrolled_col(
|
||||
EIGEN_STRONG_INLINE void gemm_complex_extra_cols(
|
||||
const DataMapper& res,
|
||||
const Scalar* lhs_base,
|
||||
const Scalar* rhs_base,
|
||||
const Scalar* blockA,
|
||||
const Scalar* blockB,
|
||||
Index depth,
|
||||
Index strideA,
|
||||
Index offsetA,
|
||||
Index strideB,
|
||||
Index& row,
|
||||
Index rows,
|
||||
Index offsetB,
|
||||
Index col,
|
||||
Index remaining_cols,
|
||||
Index rows,
|
||||
Index cols,
|
||||
Index remaining_rows,
|
||||
const Packet& pAlphaReal,
|
||||
const Packet& pAlphaImag);
|
||||
const Packet& pAlphaImag,
|
||||
const Packet& pMask);
|
||||
|
||||
template<typename Scalar, typename Packet>
|
||||
EIGEN_ALWAYS_INLINE Packet ploadLhs(const Scalar* lhs);
|
||||
|
||||
template<typename DataMapper, typename Packet, typename Index, const Index accCols, int N, int StorageOrder>
|
||||
EIGEN_ALWAYS_INLINE void bload(PacketBlock<Packet,4>& acc, const DataMapper& res, Index row, Index col);
|
||||
template<typename DataMapper, typename Packet, typename Index, const Index accCols, int StorageOrder, bool Complex, int N>
|
||||
EIGEN_ALWAYS_INLINE void bload(PacketBlock<Packet,N>& acc, const DataMapper& res, Index row, Index col);
|
||||
|
||||
template<typename DataMapper, typename Packet, typename Index, const Index accCols, int N, int StorageOrder>
|
||||
EIGEN_ALWAYS_INLINE void bload(PacketBlock<Packet,8>& acc, const DataMapper& res, Index row, Index col);
|
||||
|
||||
template<typename Packet>
|
||||
EIGEN_ALWAYS_INLINE void bscale(PacketBlock<Packet,4>& acc, PacketBlock<Packet,4>& accZ, const Packet& pAlpha);
|
||||
template<typename Packet, int N>
|
||||
EIGEN_ALWAYS_INLINE void bscale(PacketBlock<Packet,N>& acc, PacketBlock<Packet,N>& accZ, const Packet& pAlpha);
|
||||
|
||||
template<typename Packet, int N>
|
||||
EIGEN_ALWAYS_INLINE void bscalec(PacketBlock<Packet,N>& aReal, PacketBlock<Packet,N>& aImag, const Packet& bReal, const Packet& bImag, PacketBlock<Packet,N>& cReal, PacketBlock<Packet,N>& cImag);
|
||||
|
||||
const static Packet16uc p16uc_SETCOMPLEX32_FIRST = { 0, 1, 2, 3,
|
||||
16, 17, 18, 19,
|
||||
4, 5, 6, 7,
|
||||
20, 21, 22, 23};
|
||||
|
||||
const static Packet16uc p16uc_SETCOMPLEX32_SECOND = { 8, 9, 10, 11,
|
||||
24, 25, 26, 27,
|
||||
12, 13, 14, 15,
|
||||
28, 29, 30, 31};
|
||||
//[a,b],[ai,bi] = [a,ai] - This is equivalent to p16uc_GETREAL64
|
||||
const static Packet16uc p16uc_SETCOMPLEX64_FIRST = { 0, 1, 2, 3, 4, 5, 6, 7,
|
||||
16, 17, 18, 19, 20, 21, 22, 23};
|
||||
|
||||
//[a,b],[ai,bi] = [b,bi] - This is equivalent to p16uc_GETIMAG64
|
||||
const static Packet16uc p16uc_SETCOMPLEX64_SECOND = { 8, 9, 10, 11, 12, 13, 14, 15,
|
||||
24, 25, 26, 27, 28, 29, 30, 31};
|
||||
|
||||
|
||||
// Grab two decouples real/imaginary PacketBlocks and return two coupled (real/imaginary pairs) PacketBlocks.
|
||||
template<typename Packet, typename Packetc>
|
||||
EIGEN_ALWAYS_INLINE void bcouple_common(PacketBlock<Packet,4>& taccReal, PacketBlock<Packet,4>& taccImag, PacketBlock<Packetc, 4>& acc1, PacketBlock<Packetc, 4>& acc2)
|
||||
template<typename Packet, typename Packetc, int N>
|
||||
EIGEN_ALWAYS_INLINE void bcouple_common(PacketBlock<Packet,N>& taccReal, PacketBlock<Packet,N>& taccImag, PacketBlock<Packetc, N>& acc1, PacketBlock<Packetc, N>& acc2)
|
||||
{
|
||||
acc1.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX32_FIRST);
|
||||
acc1.packet[1].v = vec_perm(taccReal.packet[1], taccImag.packet[1], p16uc_SETCOMPLEX32_FIRST);
|
||||
acc1.packet[2].v = vec_perm(taccReal.packet[2], taccImag.packet[2], p16uc_SETCOMPLEX32_FIRST);
|
||||
acc1.packet[3].v = vec_perm(taccReal.packet[3], taccImag.packet[3], p16uc_SETCOMPLEX32_FIRST);
|
||||
acc1.packet[0].v = vec_mergeh(taccReal.packet[0], taccImag.packet[0]);
|
||||
if (N > 1) {
|
||||
acc1.packet[1].v = vec_mergeh(taccReal.packet[1], taccImag.packet[1]);
|
||||
}
|
||||
if (N > 2) {
|
||||
acc1.packet[2].v = vec_mergeh(taccReal.packet[2], taccImag.packet[2]);
|
||||
}
|
||||
if (N > 3) {
|
||||
acc1.packet[3].v = vec_mergeh(taccReal.packet[3], taccImag.packet[3]);
|
||||
}
|
||||
|
||||
acc2.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX32_SECOND);
|
||||
acc2.packet[1].v = vec_perm(taccReal.packet[1], taccImag.packet[1], p16uc_SETCOMPLEX32_SECOND);
|
||||
acc2.packet[2].v = vec_perm(taccReal.packet[2], taccImag.packet[2], p16uc_SETCOMPLEX32_SECOND);
|
||||
acc2.packet[3].v = vec_perm(taccReal.packet[3], taccImag.packet[3], p16uc_SETCOMPLEX32_SECOND);
|
||||
acc2.packet[0].v = vec_mergel(taccReal.packet[0], taccImag.packet[0]);
|
||||
if (N > 1) {
|
||||
acc2.packet[1].v = vec_mergel(taccReal.packet[1], taccImag.packet[1]);
|
||||
}
|
||||
if (N > 2) {
|
||||
acc2.packet[2].v = vec_mergel(taccReal.packet[2], taccImag.packet[2]);
|
||||
}
|
||||
if (N > 3) {
|
||||
acc2.packet[3].v = vec_mergel(taccReal.packet[3], taccImag.packet[3]);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename Packet, typename Packetc>
|
||||
EIGEN_ALWAYS_INLINE void bcouple(PacketBlock<Packet,4>& taccReal, PacketBlock<Packet,4>& taccImag, PacketBlock<Packetc,8>& tRes, PacketBlock<Packetc, 4>& acc1, PacketBlock<Packetc, 4>& acc2)
|
||||
template<typename Packet, typename Packetc, int N>
|
||||
EIGEN_ALWAYS_INLINE void bcouple(PacketBlock<Packet,N>& taccReal, PacketBlock<Packet,N>& taccImag, PacketBlock<Packetc,N*2>& tRes, PacketBlock<Packetc, N>& acc1, PacketBlock<Packetc, N>& acc2)
|
||||
{
|
||||
bcouple_common<Packet, Packetc>(taccReal, taccImag, acc1, acc2);
|
||||
bcouple_common<Packet, Packetc, N>(taccReal, taccImag, acc1, acc2);
|
||||
|
||||
acc1.packet[0] = padd<Packetc>(tRes.packet[0], acc1.packet[0]);
|
||||
acc1.packet[1] = padd<Packetc>(tRes.packet[1], acc1.packet[1]);
|
||||
acc1.packet[2] = padd<Packetc>(tRes.packet[2], acc1.packet[2]);
|
||||
acc1.packet[3] = padd<Packetc>(tRes.packet[3], acc1.packet[3]);
|
||||
if (N > 1) {
|
||||
acc1.packet[1] = padd<Packetc>(tRes.packet[1], acc1.packet[1]);
|
||||
}
|
||||
if (N > 2) {
|
||||
acc1.packet[2] = padd<Packetc>(tRes.packet[2], acc1.packet[2]);
|
||||
}
|
||||
if (N > 3) {
|
||||
acc1.packet[3] = padd<Packetc>(tRes.packet[3], acc1.packet[3]);
|
||||
}
|
||||
|
||||
acc2.packet[0] = padd<Packetc>(tRes.packet[4], acc2.packet[0]);
|
||||
acc2.packet[1] = padd<Packetc>(tRes.packet[5], acc2.packet[1]);
|
||||
acc2.packet[2] = padd<Packetc>(tRes.packet[6], acc2.packet[2]);
|
||||
acc2.packet[3] = padd<Packetc>(tRes.packet[7], acc2.packet[3]);
|
||||
}
|
||||
|
||||
template<typename Packet, typename Packetc>
|
||||
EIGEN_ALWAYS_INLINE void bcouple_common(PacketBlock<Packet,1>& taccReal, PacketBlock<Packet,1>& taccImag, PacketBlock<Packetc, 1>& acc1, PacketBlock<Packetc, 1>& acc2)
|
||||
{
|
||||
acc1.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX32_FIRST);
|
||||
|
||||
acc2.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX32_SECOND);
|
||||
}
|
||||
|
||||
template<typename Packet, typename Packetc>
|
||||
EIGEN_ALWAYS_INLINE void bcouple(PacketBlock<Packet,1>& taccReal, PacketBlock<Packet,1>& taccImag, PacketBlock<Packetc,2>& tRes, PacketBlock<Packetc, 1>& acc1, PacketBlock<Packetc, 1>& acc2)
|
||||
{
|
||||
bcouple_common<Packet, Packetc>(taccReal, taccImag, acc1, acc2);
|
||||
|
||||
acc1.packet[0] = padd<Packetc>(tRes.packet[0], acc1.packet[0]);
|
||||
|
||||
acc2.packet[0] = padd<Packetc>(tRes.packet[1], acc2.packet[0]);
|
||||
}
|
||||
|
||||
template<>
|
||||
EIGEN_ALWAYS_INLINE void bcouple_common<Packet2d, Packet1cd>(PacketBlock<Packet2d,4>& taccReal, PacketBlock<Packet2d,4>& taccImag, PacketBlock<Packet1cd, 4>& acc1, PacketBlock<Packet1cd, 4>& acc2)
|
||||
{
|
||||
acc1.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX64_FIRST);
|
||||
acc1.packet[1].v = vec_perm(taccReal.packet[1], taccImag.packet[1], p16uc_SETCOMPLEX64_FIRST);
|
||||
acc1.packet[2].v = vec_perm(taccReal.packet[2], taccImag.packet[2], p16uc_SETCOMPLEX64_FIRST);
|
||||
acc1.packet[3].v = vec_perm(taccReal.packet[3], taccImag.packet[3], p16uc_SETCOMPLEX64_FIRST);
|
||||
|
||||
acc2.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX64_SECOND);
|
||||
acc2.packet[1].v = vec_perm(taccReal.packet[1], taccImag.packet[1], p16uc_SETCOMPLEX64_SECOND);
|
||||
acc2.packet[2].v = vec_perm(taccReal.packet[2], taccImag.packet[2], p16uc_SETCOMPLEX64_SECOND);
|
||||
acc2.packet[3].v = vec_perm(taccReal.packet[3], taccImag.packet[3], p16uc_SETCOMPLEX64_SECOND);
|
||||
}
|
||||
|
||||
template<>
|
||||
EIGEN_ALWAYS_INLINE void bcouple_common<Packet2d, Packet1cd>(PacketBlock<Packet2d,1>& taccReal, PacketBlock<Packet2d,1>& taccImag, PacketBlock<Packet1cd, 1>& acc1, PacketBlock<Packet1cd, 1>& acc2)
|
||||
{
|
||||
acc1.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX64_FIRST);
|
||||
|
||||
acc2.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX64_SECOND);
|
||||
acc2.packet[0] = padd<Packetc>(tRes.packet[0+N], acc2.packet[0]);
|
||||
if (N > 1) {
|
||||
acc2.packet[1] = padd<Packetc>(tRes.packet[1+N], acc2.packet[1]);
|
||||
}
|
||||
if (N > 2) {
|
||||
acc2.packet[2] = padd<Packetc>(tRes.packet[2+N], acc2.packet[2]);
|
||||
}
|
||||
if (N > 3) {
|
||||
acc2.packet[3] = padd<Packetc>(tRes.packet[3+N], acc2.packet[3]);
|
||||
}
|
||||
}
|
||||
|
||||
// This is necessary because ploadRhs for double returns a pair of vectors when MMA is enabled.
|
||||
|
@ -11,7 +11,7 @@
|
||||
#ifndef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
|
||||
#define EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
|
||||
|
||||
#pragma GCC target("cpu=power10")
|
||||
#pragma GCC target("cpu=power10,htm")
|
||||
|
||||
#ifdef __has_builtin
|
||||
#if !__has_builtin(__builtin_vsx_assemble_pair)
|
||||
@ -32,37 +32,37 @@ EIGEN_ALWAYS_INLINE void bsetzeroMMA(__vector_quad* acc)
|
||||
}
|
||||
|
||||
template<typename DataMapper, typename Index, typename Packet, const Index accCols>
|
||||
EIGEN_ALWAYS_INLINE void storeAccumulator(Index i, Index j, const DataMapper& data, const Packet& alpha, __vector_quad* acc)
|
||||
EIGEN_ALWAYS_INLINE void storeAccumulator(Index i, const DataMapper& data, const Packet& alpha, __vector_quad* acc)
|
||||
{
|
||||
PacketBlock<Packet, 4> result;
|
||||
__builtin_mma_disassemble_acc(&result.packet, acc);
|
||||
|
||||
PacketBlock<Packet, 4> tRes;
|
||||
bload<DataMapper, Packet, Index, accCols, 0, ColMajor>(tRes, data, i, j);
|
||||
bload<DataMapper, Packet, Index, accCols, ColMajor, false, 4>(tRes, data, i, 0);
|
||||
|
||||
bscale<Packet>(tRes, result, alpha);
|
||||
bscale<Packet, 4>(tRes, result, alpha);
|
||||
|
||||
data.template storePacketBlock<Packet, 4>(i, j, tRes);
|
||||
data.template storePacketBlock<Packet, 4>(i, 0, tRes);
|
||||
}
|
||||
|
||||
template<typename DataMapper, typename Index, typename Packet, typename Packetc, const Index accColsC, int N>
|
||||
EIGEN_ALWAYS_INLINE void storeComplexAccumulator(Index i, Index j, const DataMapper& data, const Packet& alphaReal, const Packet& alphaImag, __vector_quad* accReal, __vector_quad* accImag)
|
||||
template<typename DataMapper, typename Index, typename Packet, typename Packetc, const Index accColsC>
|
||||
EIGEN_ALWAYS_INLINE void storeComplexAccumulator(Index i, const DataMapper& data, const Packet& alphaReal, const Packet& alphaImag, __vector_quad* accReal, __vector_quad* accImag)
|
||||
{
|
||||
PacketBlock<Packet, 4> resultReal, resultImag;
|
||||
__builtin_mma_disassemble_acc(&resultReal.packet, accReal);
|
||||
__builtin_mma_disassemble_acc(&resultImag.packet, accImag);
|
||||
|
||||
PacketBlock<Packetc, 8> tRes;
|
||||
bload<DataMapper, Packetc, Index, accColsC, N, ColMajor>(tRes, data, i, j);
|
||||
bload<DataMapper, Packetc, Index, accColsC, ColMajor, true, 4>(tRes, data, i, 0);
|
||||
|
||||
PacketBlock<Packet,4> taccReal, taccImag;
|
||||
bscalec<Packet,4>(resultReal, resultImag, alphaReal, alphaImag, taccReal, taccImag);
|
||||
|
||||
PacketBlock<Packetc, 4> acc1, acc2;
|
||||
bcouple<Packet, Packetc>(taccReal, taccImag, tRes, acc1, acc2);
|
||||
bcouple<Packet, Packetc, 4>(taccReal, taccImag, tRes, acc1, acc2);
|
||||
|
||||
data.template storePacketBlock<Packetc, 4>(i + N*accColsC, j, acc1);
|
||||
data.template storePacketBlock<Packetc, 4>(i + (N+1)*accColsC, j, acc2);
|
||||
data.template storePacketBlock<Packetc, 4>(i, 0, acc1);
|
||||
data.template storePacketBlock<Packetc, 4>(i + accColsC, 0, acc2);
|
||||
}
|
||||
|
||||
// Defaults to float32, since Eigen still supports C++03 we can't use default template arguments
|
||||
@ -127,7 +127,7 @@ EIGEN_ALWAYS_INLINE void pgercMMA(__vector_quad* accReal, __vector_quad* accImag
|
||||
template<typename Scalar, typename Packet>
|
||||
EIGEN_ALWAYS_INLINE void ploadRhsMMA(const Scalar* rhs, Packet& rhsV)
|
||||
{
|
||||
rhsV = ploadRhs<Scalar, Packet>((const Scalar*)(rhs));
|
||||
rhsV = ploadRhs<Scalar, Packet>(rhs);
|
||||
}
|
||||
|
||||
template<>
|
||||
@ -186,12 +186,11 @@ EIGEN_ALWAYS_INLINE void ploadRhsMMA(const float*, __vector_pair&)
|
||||
}
|
||||
|
||||
#define MICRO_MMA_UNROLL_TYPE_PEEL(func, func2, type) \
|
||||
type rhsV0, rhsV1, rhsV2, rhsV3, rhsV4, rhsV5, rhsV6, rhsV7, rhsV8, rhsV9; \
|
||||
type rhsV0, rhsV1, rhsV2, rhsV3, rhsV4, rhsV5, rhsV6, rhsV7; \
|
||||
MICRO_MMA_TYPE_PEEL(func,func2,type,0); MICRO_MMA_TYPE_PEEL(func,func2,type,1); \
|
||||
MICRO_MMA_TYPE_PEEL(func,func2,type,2); MICRO_MMA_TYPE_PEEL(func,func2,type,3); \
|
||||
MICRO_MMA_TYPE_PEEL(func,func2,type,4); MICRO_MMA_TYPE_PEEL(func,func2,type,5); \
|
||||
MICRO_MMA_TYPE_PEEL(func,func2,type,6); MICRO_MMA_TYPE_PEEL(func,func2,type,7); \
|
||||
MICRO_MMA_TYPE_PEEL(func,func2,type,8); MICRO_MMA_TYPE_PEEL(func,func2,type,9);
|
||||
MICRO_MMA_TYPE_PEEL(func,func2,type,6); MICRO_MMA_TYPE_PEEL(func,func2,type,7);
|
||||
|
||||
#define MICRO_MMA_UNROLL_TYPE_ONE(func, func2, type) \
|
||||
type rhsV0; \
|
||||
@ -224,7 +223,7 @@ EIGEN_ALWAYS_INLINE void ploadRhsMMA(const float*, __vector_pair&)
|
||||
|
||||
#define MICRO_MMA_SRC_PTR_ONE(iter) \
|
||||
if (unroll_factor > iter) { \
|
||||
lhs_ptr##iter = lhs_base + ( (row/accCols) + iter )*strideA*accCols + accCols*offsetA; \
|
||||
lhs_ptr##iter = lhs_base + ( (row/accCols) + iter )*strideA*accCols; \
|
||||
} else { \
|
||||
EIGEN_UNUSED_VARIABLE(lhs_ptr##iter); \
|
||||
}
|
||||
@ -240,21 +239,19 @@ EIGEN_ALWAYS_INLINE void ploadRhsMMA(const float*, __vector_pair&)
|
||||
|
||||
#define MICRO_MMA_STORE_ONE(iter) \
|
||||
if (unroll_factor > iter) { \
|
||||
storeAccumulator<DataMapper, Index, Packet, accCols>(row + iter*accCols, col, res, pAlpha, &accZero##iter); \
|
||||
storeAccumulator<DataMapper, Index, Packet, accCols>(row + iter*accCols, res, pAlpha, &accZero##iter); \
|
||||
}
|
||||
|
||||
#define MICRO_MMA_STORE MICRO_MMA_UNROLL(MICRO_MMA_STORE_ONE)
|
||||
|
||||
template<int unroll_factor, typename Scalar, typename Packet, typename RhsPacket, typename DataMapper, typename Index, const Index accRows, const Index accCols>
|
||||
EIGEN_STRONG_INLINE void gemm_unrolled_MMA_iteration(
|
||||
EIGEN_ALWAYS_INLINE void gemm_unrolled_MMA_iteration(
|
||||
const DataMapper& res,
|
||||
const Scalar* lhs_base,
|
||||
const Scalar* rhs_base,
|
||||
Index depth,
|
||||
Index strideA,
|
||||
Index offsetA,
|
||||
Index& row,
|
||||
Index col,
|
||||
const Packet& pAlpha)
|
||||
{
|
||||
const Scalar* rhs_ptr = rhs_base;
|
||||
@ -280,11 +277,84 @@ EIGEN_STRONG_INLINE void gemm_unrolled_MMA_iteration(
|
||||
row += unroll_factor*accCols;
|
||||
}
|
||||
|
||||
template<typename Scalar, typename Packet, typename RhsPacket, typename DataMapper, typename Index, const Index accRows, const Index accCols>
|
||||
EIGEN_ALWAYS_INLINE void gemmMMA_cols(
|
||||
const DataMapper& res,
|
||||
const Scalar* blockA,
|
||||
const Scalar* blockB,
|
||||
Index depth,
|
||||
Index strideA,
|
||||
Index offsetA,
|
||||
Index strideB,
|
||||
Index offsetB,
|
||||
Index col,
|
||||
Index rows,
|
||||
Index cols,
|
||||
Index remaining_rows,
|
||||
const Packet& pAlpha,
|
||||
const Packet& pMask)
|
||||
{
|
||||
const DataMapper res3 = res.getSubMapper(0, col);
|
||||
|
||||
const Scalar* rhs_base = blockB + col*strideB + accRows*offsetB;
|
||||
const Scalar* lhs_base = blockA + accCols*offsetA;
|
||||
Index row = 0;
|
||||
|
||||
#define MAX_MMA_UNROLL 7
|
||||
while(row + MAX_MMA_UNROLL*accCols <= rows) {
|
||||
gemm_unrolled_MMA_iteration<MAX_MMA_UNROLL, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
|
||||
}
|
||||
switch( (rows-row)/accCols ) {
|
||||
#if MAX_MMA_UNROLL > 7
|
||||
case 7:
|
||||
gemm_unrolled_MMA_iteration<7, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
|
||||
break;
|
||||
#endif
|
||||
#if MAX_MMA_UNROLL > 6
|
||||
case 6:
|
||||
gemm_unrolled_MMA_iteration<6, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
|
||||
break;
|
||||
#endif
|
||||
#if MAX_MMA_UNROLL > 5
|
||||
case 5:
|
||||
gemm_unrolled_MMA_iteration<5, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
|
||||
break;
|
||||
#endif
|
||||
#if MAX_MMA_UNROLL > 4
|
||||
case 4:
|
||||
gemm_unrolled_MMA_iteration<4, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
|
||||
break;
|
||||
#endif
|
||||
#if MAX_MMA_UNROLL > 3
|
||||
case 3:
|
||||
gemm_unrolled_MMA_iteration<3, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
|
||||
break;
|
||||
#endif
|
||||
#if MAX_MMA_UNROLL > 2
|
||||
case 2:
|
||||
gemm_unrolled_MMA_iteration<2, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
|
||||
break;
|
||||
#endif
|
||||
#if MAX_MMA_UNROLL > 1
|
||||
case 1:
|
||||
gemm_unrolled_MMA_iteration<1, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
|
||||
break;
|
||||
#endif
|
||||
default:
|
||||
break;
|
||||
}
|
||||
#undef MAX_MMA_UNROLL
|
||||
|
||||
if(remaining_rows > 0)
|
||||
{
|
||||
gemm_extra_row<Scalar, Packet, DataMapper, Index, accRows, accCols>(res3, blockA, rhs_base, depth, strideA, offsetA, row, col, rows, cols, remaining_rows, pAlpha, pMask);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename Scalar, typename Index, typename Packet, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols>
|
||||
void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, Index rows, Index depth, Index cols, Scalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
|
||||
{
|
||||
const Index remaining_rows = rows % accCols;
|
||||
const Index remaining_cols = cols % accRows;
|
||||
|
||||
if( strideA == -1 ) strideA = depth;
|
||||
if( strideB == -1 ) strideB = depth;
|
||||
@ -295,79 +365,10 @@ void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB,
|
||||
Index col = 0;
|
||||
for(; col + accRows <= cols; col += accRows)
|
||||
{
|
||||
const Scalar* rhs_base = blockB + col*strideB + accRows*offsetB;
|
||||
const Scalar* lhs_base = blockA;
|
||||
|
||||
Index row = 0;
|
||||
#define MAX_MMA_UNROLL 7
|
||||
while(row + MAX_MMA_UNROLL*accCols <= rows) {
|
||||
gemm_unrolled_MMA_iteration<MAX_MMA_UNROLL, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
|
||||
}
|
||||
switch( (rows-row)/accCols ) {
|
||||
#if MAX_MMA_UNROLL > 7
|
||||
case 7:
|
||||
gemm_unrolled_MMA_iteration<7, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
|
||||
break;
|
||||
#endif
|
||||
#if MAX_MMA_UNROLL > 6
|
||||
case 6:
|
||||
gemm_unrolled_MMA_iteration<6, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
|
||||
break;
|
||||
#endif
|
||||
#if MAX_MMA_UNROLL > 5
|
||||
case 5:
|
||||
gemm_unrolled_MMA_iteration<5, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
|
||||
break;
|
||||
#endif
|
||||
#if MAX_MMA_UNROLL > 4
|
||||
case 4:
|
||||
gemm_unrolled_MMA_iteration<4, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
|
||||
break;
|
||||
#endif
|
||||
#if MAX_MMA_UNROLL > 3
|
||||
case 3:
|
||||
gemm_unrolled_MMA_iteration<3, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
|
||||
break;
|
||||
#endif
|
||||
#if MAX_MMA_UNROLL > 2
|
||||
case 2:
|
||||
gemm_unrolled_MMA_iteration<2, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
|
||||
break;
|
||||
#endif
|
||||
#if MAX_MMA_UNROLL > 1
|
||||
case 1:
|
||||
gemm_unrolled_MMA_iteration<1, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
|
||||
break;
|
||||
#endif
|
||||
default:
|
||||
break;
|
||||
}
|
||||
#undef MAX_MMA_UNROLL
|
||||
|
||||
if(remaining_rows > 0)
|
||||
{
|
||||
gemm_extra_row<Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, rows, cols, remaining_rows, pAlpha, pMask);
|
||||
}
|
||||
gemmMMA_cols<Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlpha, pMask);
|
||||
}
|
||||
|
||||
if(remaining_cols > 0)
|
||||
{
|
||||
const Scalar* rhs_base = blockB + col*strideB + remaining_cols*offsetB;
|
||||
const Scalar* lhs_base = blockA;
|
||||
|
||||
for(; col < cols; col++)
|
||||
{
|
||||
Index row = 0;
|
||||
|
||||
gemm_unrolled_col<Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, rows, col, remaining_cols, pAlpha);
|
||||
|
||||
if (remaining_rows > 0)
|
||||
{
|
||||
gemm_extra_col<Scalar, Packet, DataMapper, Index, accRows>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_rows, remaining_cols, pAlpha);
|
||||
}
|
||||
rhs_base++;
|
||||
}
|
||||
}
|
||||
gemm_extra_cols<Scalar, Packet, DataMapper, Index, accCols>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlpha, pMask);
|
||||
}
|
||||
|
||||
#define accColsC (accCols / 2)
|
||||
@ -375,21 +376,20 @@ void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB,
|
||||
#define advanceCols ((RhsIsReal) ? 1 : 2)
|
||||
|
||||
// PEEL_COMPLEX_MMA loop factor.
|
||||
#define PEEL_COMPLEX_MMA 7
|
||||
#define PEEL_COMPLEX_MMA 3
|
||||
|
||||
#define MICRO_COMPLEX_MMA_UNROLL(func) \
|
||||
func(0) func(1) func(2) func(3) func(4)
|
||||
func(0) func(1) func(2) func(3)
|
||||
|
||||
#define MICRO_COMPLEX_MMA_LOAD_ONE(iter) \
|
||||
if (unroll_factor > iter) { \
|
||||
lhsV##iter = ploadLhs<Scalar, Packet>(lhs_ptr_real##iter); \
|
||||
lhs_ptr_real##iter += accCols; \
|
||||
if(!LhsIsReal) { \
|
||||
lhsVi##iter = ploadLhs<Scalar, Packet>(lhs_ptr_imag##iter); \
|
||||
lhs_ptr_imag##iter += accCols; \
|
||||
lhsVi##iter = ploadLhs<Scalar, Packet>(lhs_ptr_real##iter + imag_delta); \
|
||||
} else { \
|
||||
EIGEN_UNUSED_VARIABLE(lhsVi##iter); \
|
||||
} \
|
||||
lhs_ptr_real##iter += accCols; \
|
||||
} else { \
|
||||
EIGEN_UNUSED_VARIABLE(lhsV##iter); \
|
||||
EIGEN_UNUSED_VARIABLE(lhsVi##iter); \
|
||||
@ -402,8 +402,8 @@ void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB,
|
||||
|
||||
#define MICRO_COMPLEX_MMA_TYPE_PEEL(func, func2, type, peel) \
|
||||
if (PEEL_COMPLEX_MMA > peel) { \
|
||||
Packet lhsV0, lhsV1, lhsV2, lhsV3, lhsV4; \
|
||||
Packet lhsVi0, lhsVi1, lhsVi2, lhsVi3, lhsVi4; \
|
||||
Packet lhsV0, lhsV1, lhsV2, lhsV3; \
|
||||
Packet lhsVi0, lhsVi1, lhsVi2, lhsVi3; \
|
||||
ploadRhsMMA<Scalar, type>(rhs_ptr_real + (accRows * peel), rhsV##peel); \
|
||||
if(!RhsIsReal) { \
|
||||
ploadRhsMMA<Scalar, type>(rhs_ptr_imag + (accRows * peel), rhsVi##peel); \
|
||||
@ -411,20 +411,17 @@ void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB,
|
||||
EIGEN_UNUSED_VARIABLE(rhsVi##peel); \
|
||||
} \
|
||||
MICRO_COMPLEX_MMA_UNROLL(func2); \
|
||||
func(0,type,peel) func(1,type,peel) func(2,type,peel) func(3,type,peel) func(4,type,peel) \
|
||||
func(0,type,peel) func(1,type,peel) func(2,type,peel) func(3,type,peel) \
|
||||
} else { \
|
||||
EIGEN_UNUSED_VARIABLE(rhsV##peel); \
|
||||
EIGEN_UNUSED_VARIABLE(rhsVi##peel); \
|
||||
}
|
||||
|
||||
#define MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL(func, func2, type) \
|
||||
type rhsV0, rhsV1, rhsV2, rhsV3, rhsV4, rhsV5, rhsV6, rhsV7, rhsV8, rhsV9; \
|
||||
type rhsVi0, rhsVi1, rhsVi2, rhsVi3, rhsVi4, rhsVi5, rhsVi6, rhsVi7, rhsVi8, rhsVi9; \
|
||||
type rhsV0, rhsV1, rhsV2, rhsV3; \
|
||||
type rhsVi0, rhsVi1, rhsVi2, rhsVi3; \
|
||||
MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,0); MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,1); \
|
||||
MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,2); MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,3); \
|
||||
MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,4); MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,5); \
|
||||
MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,6); MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,7); \
|
||||
MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,8); MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,9);
|
||||
MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,2); MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,3);
|
||||
|
||||
#define MICRO_COMPLEX_MMA_UNROLL_TYPE_ONE(func, func2, type) \
|
||||
type rhsV0, rhsVi0; \
|
||||
@ -461,15 +458,9 @@ void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB,
|
||||
|
||||
#define MICRO_COMPLEX_MMA_SRC_PTR_ONE(iter) \
|
||||
if (unroll_factor > iter) { \
|
||||
lhs_ptr_real##iter = lhs_base + ( ((advanceRows*row)/accCols) + iter*advanceRows )*strideA*accCols + accCols*offsetA; \
|
||||
if(!LhsIsReal) { \
|
||||
lhs_ptr_imag##iter = lhs_ptr_real##iter + accCols*strideA; \
|
||||
} else { \
|
||||
EIGEN_UNUSED_VARIABLE(lhs_ptr_imag##iter); \
|
||||
} \
|
||||
lhs_ptr_real##iter = lhs_base + ( ((advanceRows*row)/accCols) + iter*advanceRows )*strideA*accCols; \
|
||||
} else { \
|
||||
EIGEN_UNUSED_VARIABLE(lhs_ptr_real##iter); \
|
||||
EIGEN_UNUSED_VARIABLE(lhs_ptr_imag##iter); \
|
||||
}
|
||||
|
||||
#define MICRO_COMPLEX_MMA_SRC_PTR MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_MMA_SRC_PTR_ONE)
|
||||
@ -477,45 +468,40 @@ void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB,
|
||||
#define MICRO_COMPLEX_MMA_PREFETCH_ONE(iter) \
|
||||
if (unroll_factor > iter) { \
|
||||
EIGEN_POWER_PREFETCH(lhs_ptr_real##iter); \
|
||||
if(!LhsIsReal) { \
|
||||
EIGEN_POWER_PREFETCH(lhs_ptr_imag##iter); \
|
||||
} \
|
||||
}
|
||||
|
||||
#define MICRO_COMPLEX_MMA_PREFETCH MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_MMA_PREFETCH_ONE)
|
||||
|
||||
#define MICRO_COMPLEX_MMA_STORE_ONE(iter) \
|
||||
if (unroll_factor > iter) { \
|
||||
storeComplexAccumulator<DataMapper, Index, Packet, Packetc, accColsC, 0>(row + iter*accCols, col, res, pAlphaReal, pAlphaImag, &accReal##iter, &accImag##iter); \
|
||||
storeComplexAccumulator<DataMapper, Index, Packet, Packetc, accColsC>(row + iter*accCols, res, pAlphaReal, pAlphaImag, &accReal##iter, &accImag##iter); \
|
||||
}
|
||||
|
||||
#define MICRO_COMPLEX_MMA_STORE MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_MMA_STORE_ONE)
|
||||
|
||||
template<int unroll_factor, typename Scalar, typename Packet, typename Packetc, typename RhsPacket, typename DataMapper, typename Index, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
|
||||
EIGEN_STRONG_INLINE void gemm_complex_unrolled_MMA_iteration(
|
||||
EIGEN_ALWAYS_INLINE void gemm_complex_unrolled_MMA_iteration(
|
||||
const DataMapper& res,
|
||||
const Scalar* lhs_base,
|
||||
const Scalar* rhs_base,
|
||||
Index depth,
|
||||
Index strideA,
|
||||
Index offsetA,
|
||||
Index strideB,
|
||||
Index& row,
|
||||
Index col,
|
||||
const Packet& pAlphaReal,
|
||||
const Packet& pAlphaImag)
|
||||
{
|
||||
const Scalar* rhs_ptr_real = rhs_base;
|
||||
const Scalar* rhs_ptr_imag;
|
||||
const Scalar* rhs_ptr_imag = NULL;
|
||||
const Index imag_delta = accCols*strideA;
|
||||
if(!RhsIsReal) {
|
||||
rhs_ptr_imag = rhs_base + accRows*strideB;
|
||||
} else {
|
||||
EIGEN_UNUSED_VARIABLE(rhs_ptr_imag);
|
||||
}
|
||||
const Scalar* lhs_ptr_real0 = NULL, * lhs_ptr_imag0 = NULL, * lhs_ptr_real1 = NULL, * lhs_ptr_imag1 = NULL;
|
||||
const Scalar* lhs_ptr_real2 = NULL, * lhs_ptr_imag2 = NULL, * lhs_ptr_real3 = NULL, * lhs_ptr_imag3 = NULL;
|
||||
const Scalar* lhs_ptr_real4 = NULL, * lhs_ptr_imag4 = NULL;
|
||||
__vector_quad accReal0, accImag0, accReal1, accImag1, accReal2, accImag2, accReal3, accImag3, accReal4, accImag4;
|
||||
const Scalar* lhs_ptr_real0 = NULL, * lhs_ptr_real1 = NULL;
|
||||
const Scalar* lhs_ptr_real2 = NULL, * lhs_ptr_real3 = NULL;
|
||||
__vector_quad accReal0, accImag0, accReal1, accImag1, accReal2, accImag2, accReal3, accImag3;
|
||||
|
||||
MICRO_COMPLEX_MMA_SRC_PTR
|
||||
MICRO_COMPLEX_MMA_DST_PTR
|
||||
@ -539,11 +525,70 @@ EIGEN_STRONG_INLINE void gemm_complex_unrolled_MMA_iteration(
|
||||
row += unroll_factor*accCols;
|
||||
}
|
||||
|
||||
template<typename Scalar, typename Packet, typename Packetc, typename RhsPacket, typename DataMapper, typename Index, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
|
||||
EIGEN_ALWAYS_INLINE void gemmMMA_complex_cols(
|
||||
const DataMapper& res,
|
||||
const Scalar* blockA,
|
||||
const Scalar* blockB,
|
||||
Index depth,
|
||||
Index strideA,
|
||||
Index offsetA,
|
||||
Index strideB,
|
||||
Index offsetB,
|
||||
Index col,
|
||||
Index rows,
|
||||
Index cols,
|
||||
Index remaining_rows,
|
||||
const Packet& pAlphaReal,
|
||||
const Packet& pAlphaImag,
|
||||
const Packet& pMask)
|
||||
{
|
||||
const DataMapper res3 = res.getSubMapper(0, col);
|
||||
|
||||
const Scalar* rhs_base = blockB + advanceCols*col*strideB + accRows*offsetB;
|
||||
const Scalar* lhs_base = blockA + accCols*offsetA;
|
||||
Index row = 0;
|
||||
|
||||
#define MAX_COMPLEX_MMA_UNROLL 4
|
||||
while(row + MAX_COMPLEX_MMA_UNROLL*accCols <= rows) {
|
||||
gemm_complex_unrolled_MMA_iteration<MAX_COMPLEX_MMA_UNROLL, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, strideB, row, pAlphaReal, pAlphaImag);
|
||||
}
|
||||
switch( (rows-row)/accCols ) {
|
||||
#if MAX_COMPLEX_MMA_UNROLL > 4
|
||||
case 4:
|
||||
gemm_complex_unrolled_MMA_iteration<4, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, strideB, row, pAlphaReal, pAlphaImag);
|
||||
break;
|
||||
#endif
|
||||
#if MAX_COMPLEX_MMA_UNROLL > 3
|
||||
case 3:
|
||||
gemm_complex_unrolled_MMA_iteration<3, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, strideB, row, pAlphaReal, pAlphaImag);
|
||||
break;
|
||||
#endif
|
||||
#if MAX_COMPLEX_MMA_UNROLL > 2
|
||||
case 2:
|
||||
gemm_complex_unrolled_MMA_iteration<2, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, strideB, row, pAlphaReal, pAlphaImag);
|
||||
break;
|
||||
#endif
|
||||
#if MAX_COMPLEX_MMA_UNROLL > 1
|
||||
case 1:
|
||||
gemm_complex_unrolled_MMA_iteration<1, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, strideB, row, pAlphaReal, pAlphaImag);
|
||||
break;
|
||||
#endif
|
||||
default:
|
||||
break;
|
||||
}
|
||||
#undef MAX_COMPLEX_MMA_UNROLL
|
||||
|
||||
if(remaining_rows > 0)
|
||||
{
|
||||
gemm_complex_extra_row<Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, blockA, rhs_base, depth, strideA, offsetA, strideB, row, col, rows, cols, remaining_rows, pAlphaReal, pAlphaImag, pMask);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename LhsScalar, typename RhsScalar, typename Scalarc, typename Scalar, typename Index, typename Packet, typename Packetc, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
|
||||
void gemm_complexMMA(const DataMapper& res, const LhsScalar* blockAc, const RhsScalar* blockBc, Index rows, Index depth, Index cols, Scalarc alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
|
||||
{
|
||||
const Index remaining_rows = rows % accCols;
|
||||
const Index remaining_cols = cols % accRows;
|
||||
|
||||
if( strideA == -1 ) strideA = depth;
|
||||
if( strideB == -1 ) strideB = depth;
|
||||
@ -558,64 +603,10 @@ void gemm_complexMMA(const DataMapper& res, const LhsScalar* blockAc, const RhsS
|
||||
Index col = 0;
|
||||
for(; col + accRows <= cols; col += accRows)
|
||||
{
|
||||
const Scalar* rhs_base = blockB + advanceCols*col*strideB + accRows*offsetB;
|
||||
const Scalar* lhs_base = blockA;
|
||||
Index row = 0;
|
||||
|
||||
#define MAX_COMPLEX_MMA_UNROLL 4
|
||||
while(row + MAX_COMPLEX_MMA_UNROLL*accCols <= rows) {
|
||||
gemm_complex_unrolled_MMA_iteration<MAX_COMPLEX_MMA_UNROLL, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag);
|
||||
}
|
||||
switch( (rows-row)/accCols ) {
|
||||
#if MAX_COMPLEX_MMA_UNROLL > 4
|
||||
case 4:
|
||||
gemm_complex_unrolled_MMA_iteration<4, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag);
|
||||
break;
|
||||
#endif
|
||||
#if MAX_COMPLEX_MMA_UNROLL > 3
|
||||
case 3:
|
||||
gemm_complex_unrolled_MMA_iteration<3, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag);
|
||||
break;
|
||||
#endif
|
||||
#if MAX_COMPLEX_MMA_UNROLL > 2
|
||||
case 2:
|
||||
gemm_complex_unrolled_MMA_iteration<2, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag);
|
||||
break;
|
||||
#endif
|
||||
#if MAX_COMPLEX_MMA_UNROLL > 1
|
||||
case 1:
|
||||
gemm_complex_unrolled_MMA_iteration<1, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag);
|
||||
break;
|
||||
#endif
|
||||
default:
|
||||
break;
|
||||
}
|
||||
#undef MAX_COMPLEX_MMA_UNROLL
|
||||
|
||||
if(remaining_rows > 0)
|
||||
{
|
||||
gemm_complex_extra_row<Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, rows, cols, remaining_rows, pAlphaReal, pAlphaImag, pMask);
|
||||
}
|
||||
gemmMMA_complex_cols<Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlphaReal, pAlphaImag, pMask);
|
||||
}
|
||||
|
||||
if(remaining_cols > 0)
|
||||
{
|
||||
const Scalar* rhs_base = blockB + advanceCols*col*strideB + remaining_cols*offsetB;
|
||||
const Scalar* lhs_base = blockA;
|
||||
|
||||
for(; col < cols; col++)
|
||||
{
|
||||
Index row = 0;
|
||||
|
||||
gemm_complex_unrolled_col<Scalar, Packet, Packetc, DataMapper, Index, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, rows, col, remaining_cols, pAlphaReal, pAlphaImag);
|
||||
|
||||
if (remaining_rows > 0)
|
||||
{
|
||||
gemm_complex_extra_col<Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, remaining_rows, remaining_cols, pAlphaReal, pAlphaImag);
|
||||
}
|
||||
rhs_base++;
|
||||
}
|
||||
}
|
||||
gemm_complex_extra_cols<Scalar, Packet, Packetc, DataMapper, Index, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlphaReal, pAlphaImag, pMask);
|
||||
}
|
||||
|
||||
#undef accColsC
|
||||
|
Loading…
Reference in New Issue
Block a user