mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-01-12 14:25:16 +08:00
Implemented preduxp for AVX512
This commit is contained in:
parent
5f85662ad8
commit
5e89ded685
@ -651,13 +651,221 @@ EIGEN_STRONG_INLINE Packet8d pabs(const Packet8d& a) {
|
||||
_mm512_set1_epi64(0x7fffffffffffffff));
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet16f preduxp<Packet16f>(const Packet16f* vecs)
|
||||
#ifdef EIGEN_VECTORIZE_AVX512DQ
|
||||
// AVX512F does not define _mm512_extractf32x8_ps to extract _m256 from _m512
|
||||
#define EIGEN_EXTRACT_8f_FROM_16f(INPUT, OUTPUT) \
|
||||
__m256 OUTPUT##_0 = _mm512_extractf32x8_ps(INPUT, 0) __m256 OUTPUT##_1 = \
|
||||
_mm512_extractf32x8_ps(INPUT, 1)
|
||||
#else
|
||||
#define EIGEN_EXTRACT_8f_FROM_16f(INPUT, OUTPUT) \
|
||||
__m256 OUTPUT##_0 = _mm256_insertf128_ps( \
|
||||
_mm256_castps128_ps256(_mm512_extractf32x4_ps(INPUT, 0)), \
|
||||
_mm512_extractf32x4_ps(INPUT, 1), 1); \
|
||||
__m256 OUTPUT##_1 = _mm256_insertf128_ps( \
|
||||
_mm256_castps128_ps256(_mm512_extractf32x4_ps(INPUT, 2)), \
|
||||
_mm512_extractf32x4_ps(INPUT, 3), 1);
|
||||
#endif
|
||||
|
||||
#ifdef EIGEN_VECTORIZE_AVX512DQ
|
||||
#define EIGEN_INSERT_8f_INTO_16f(OUTPUT, INPUTA, INPUTB) \
|
||||
OUTPUT = _mm512_insertf32x8(OUTPUT, INPUTA, 0); \
|
||||
OUTPUT = _mm512_insertf32x8(OUTPUT, INPUTB, 1);
|
||||
#else
|
||||
#define EIGEN_INSERT_8f_INTO_16f(OUTPUT, INPUTA, INPUTB) \
|
||||
OUTPUT = _mm512_insertf32x4(OUTPUT, _mm256_extractf128_ps(INPUTA, 0), 0); \
|
||||
OUTPUT = _mm512_insertf32x4(OUTPUT, _mm256_extractf128_ps(INPUTA, 1), 1); \
|
||||
OUTPUT = _mm512_insertf32x4(OUTPUT, _mm256_extractf128_ps(INPUTB, 0), 2); \
|
||||
OUTPUT = _mm512_insertf32x4(OUTPUT, _mm256_extractf128_ps(INPUTB, 1), 3);
|
||||
#endif
|
||||
template<> EIGEN_STRONG_INLINE Packet16f preduxp<Packet16f>(const Packet16f*
|
||||
vecs)
|
||||
{
|
||||
assert(false && "To be implemented");
|
||||
EIGEN_EXTRACT_8f_FROM_16f(vecs[0], vecs0);
|
||||
EIGEN_EXTRACT_8f_FROM_16f(vecs[1], vecs1);
|
||||
EIGEN_EXTRACT_8f_FROM_16f(vecs[2], vecs2);
|
||||
EIGEN_EXTRACT_8f_FROM_16f(vecs[3], vecs3);
|
||||
EIGEN_EXTRACT_8f_FROM_16f(vecs[4], vecs4);
|
||||
EIGEN_EXTRACT_8f_FROM_16f(vecs[5], vecs5);
|
||||
EIGEN_EXTRACT_8f_FROM_16f(vecs[6], vecs6);
|
||||
EIGEN_EXTRACT_8f_FROM_16f(vecs[7], vecs7);
|
||||
EIGEN_EXTRACT_8f_FROM_16f(vecs[8], vecs8);
|
||||
EIGEN_EXTRACT_8f_FROM_16f(vecs[9], vecs9);
|
||||
EIGEN_EXTRACT_8f_FROM_16f(vecs[10], vecs10);
|
||||
EIGEN_EXTRACT_8f_FROM_16f(vecs[11], vecs11);
|
||||
EIGEN_EXTRACT_8f_FROM_16f(vecs[12], vecs12);
|
||||
EIGEN_EXTRACT_8f_FROM_16f(vecs[13], vecs13);
|
||||
EIGEN_EXTRACT_8f_FROM_16f(vecs[14], vecs14);
|
||||
EIGEN_EXTRACT_8f_FROM_16f(vecs[15], vecs15);
|
||||
|
||||
__m256 hsum1 = _mm256_hadd_ps(vecs0_0, vecs1_0);
|
||||
__m256 hsum2 = _mm256_hadd_ps(vecs2_0, vecs3_0);
|
||||
__m256 hsum3 = _mm256_hadd_ps(vecs4_0, vecs5_0);
|
||||
__m256 hsum4 = _mm256_hadd_ps(vecs6_0, vecs7_0);
|
||||
|
||||
__m256 hsum5 = _mm256_hadd_ps(hsum1, hsum1);
|
||||
__m256 hsum6 = _mm256_hadd_ps(hsum2, hsum2);
|
||||
__m256 hsum7 = _mm256_hadd_ps(hsum3, hsum3);
|
||||
__m256 hsum8 = _mm256_hadd_ps(hsum4, hsum4);
|
||||
|
||||
__m256 perm1 = _mm256_permute2f128_ps(hsum5, hsum5, 0x23);
|
||||
__m256 perm2 = _mm256_permute2f128_ps(hsum6, hsum6, 0x23);
|
||||
__m256 perm3 = _mm256_permute2f128_ps(hsum7, hsum7, 0x23);
|
||||
__m256 perm4 = _mm256_permute2f128_ps(hsum8, hsum8, 0x23);
|
||||
|
||||
__m256 sum1 = _mm256_add_ps(perm1, hsum5);
|
||||
__m256 sum2 = _mm256_add_ps(perm2, hsum6);
|
||||
__m256 sum3 = _mm256_add_ps(perm3, hsum7);
|
||||
__m256 sum4 = _mm256_add_ps(perm4, hsum8);
|
||||
|
||||
__m256 blend1 = _mm256_blend_ps(sum1, sum2, 0xcc);
|
||||
__m256 blend2 = _mm256_blend_ps(sum3, sum4, 0xcc);
|
||||
|
||||
__m256 final = _mm256_blend_ps(blend1, blend2, 0xf0);
|
||||
|
||||
hsum1 = _mm256_hadd_ps(vecs0_1, vecs1_1);
|
||||
hsum2 = _mm256_hadd_ps(vecs2_1, vecs3_1);
|
||||
hsum3 = _mm256_hadd_ps(vecs4_1, vecs5_1);
|
||||
hsum4 = _mm256_hadd_ps(vecs6_1, vecs7_1);
|
||||
|
||||
hsum5 = _mm256_hadd_ps(hsum1, hsum1);
|
||||
hsum6 = _mm256_hadd_ps(hsum2, hsum2);
|
||||
hsum7 = _mm256_hadd_ps(hsum3, hsum3);
|
||||
hsum8 = _mm256_hadd_ps(hsum4, hsum4);
|
||||
|
||||
perm1 = _mm256_permute2f128_ps(hsum5, hsum5, 0x23);
|
||||
perm2 = _mm256_permute2f128_ps(hsum6, hsum6, 0x23);
|
||||
perm3 = _mm256_permute2f128_ps(hsum7, hsum7, 0x23);
|
||||
perm4 = _mm256_permute2f128_ps(hsum8, hsum8, 0x23);
|
||||
|
||||
sum1 = _mm256_add_ps(perm1, hsum5);
|
||||
sum2 = _mm256_add_ps(perm2, hsum6);
|
||||
sum3 = _mm256_add_ps(perm3, hsum7);
|
||||
sum4 = _mm256_add_ps(perm4, hsum8);
|
||||
|
||||
blend1 = _mm256_blend_ps(sum1, sum2, 0xcc);
|
||||
blend2 = _mm256_blend_ps(sum3, sum4, 0xcc);
|
||||
|
||||
final = padd(final, _mm256_blend_ps(blend1, blend2, 0xf0));
|
||||
|
||||
hsum1 = _mm256_hadd_ps(vecs8_0, vecs9_0);
|
||||
hsum2 = _mm256_hadd_ps(vecs10_0, vecs11_0);
|
||||
hsum3 = _mm256_hadd_ps(vecs12_0, vecs13_0);
|
||||
hsum4 = _mm256_hadd_ps(vecs14_0, vecs15_0);
|
||||
|
||||
hsum5 = _mm256_hadd_ps(hsum1, hsum1);
|
||||
hsum6 = _mm256_hadd_ps(hsum2, hsum2);
|
||||
hsum7 = _mm256_hadd_ps(hsum3, hsum3);
|
||||
hsum8 = _mm256_hadd_ps(hsum4, hsum4);
|
||||
|
||||
perm1 = _mm256_permute2f128_ps(hsum5, hsum5, 0x23);
|
||||
perm2 = _mm256_permute2f128_ps(hsum6, hsum6, 0x23);
|
||||
perm3 = _mm256_permute2f128_ps(hsum7, hsum7, 0x23);
|
||||
perm4 = _mm256_permute2f128_ps(hsum8, hsum8, 0x23);
|
||||
|
||||
sum1 = _mm256_add_ps(perm1, hsum5);
|
||||
sum2 = _mm256_add_ps(perm2, hsum6);
|
||||
sum3 = _mm256_add_ps(perm3, hsum7);
|
||||
sum4 = _mm256_add_ps(perm4, hsum8);
|
||||
|
||||
blend1 = _mm256_blend_ps(sum1, sum2, 0xcc);
|
||||
blend2 = _mm256_blend_ps(sum3, sum4, 0xcc);
|
||||
|
||||
__m256 final_1 = _mm256_blend_ps(blend1, blend2, 0xf0);
|
||||
|
||||
hsum1 = _mm256_hadd_ps(vecs8_1, vecs9_1);
|
||||
hsum2 = _mm256_hadd_ps(vecs10_1, vecs11_1);
|
||||
hsum3 = _mm256_hadd_ps(vecs12_1, vecs13_1);
|
||||
hsum4 = _mm256_hadd_ps(vecs14_1, vecs15_1);
|
||||
|
||||
hsum5 = _mm256_hadd_ps(hsum1, hsum1);
|
||||
hsum6 = _mm256_hadd_ps(hsum2, hsum2);
|
||||
hsum7 = _mm256_hadd_ps(hsum3, hsum3);
|
||||
hsum8 = _mm256_hadd_ps(hsum4, hsum4);
|
||||
|
||||
perm1 = _mm256_permute2f128_ps(hsum5, hsum5, 0x23);
|
||||
perm2 = _mm256_permute2f128_ps(hsum6, hsum6, 0x23);
|
||||
perm3 = _mm256_permute2f128_ps(hsum7, hsum7, 0x23);
|
||||
perm4 = _mm256_permute2f128_ps(hsum8, hsum8, 0x23);
|
||||
|
||||
sum1 = _mm256_add_ps(perm1, hsum5);
|
||||
sum2 = _mm256_add_ps(perm2, hsum6);
|
||||
sum3 = _mm256_add_ps(perm3, hsum7);
|
||||
sum4 = _mm256_add_ps(perm4, hsum8);
|
||||
|
||||
blend1 = _mm256_blend_ps(sum1, sum2, 0xcc);
|
||||
blend2 = _mm256_blend_ps(sum3, sum4, 0xcc);
|
||||
|
||||
final_1 = padd(final_1, _mm256_blend_ps(blend1, blend2, 0xf0));
|
||||
|
||||
__m512 final_output;
|
||||
|
||||
EIGEN_INSERT_8f_INTO_16f(final_output, final, final_1);
|
||||
return final_output;
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet8d preduxp<Packet8d>(const Packet8d* vecs)
|
||||
{
|
||||
assert(false && "To be implemented");
|
||||
Packet4d vecs0_0 = _mm512_extractf64x4_pd(vecs[0], 0);
|
||||
Packet4d vecs0_1 = _mm512_extractf64x4_pd(vecs[0], 1);
|
||||
|
||||
Packet4d vecs1_0 = _mm512_extractf64x4_pd(vecs[1], 0);
|
||||
Packet4d vecs1_1 = _mm512_extractf64x4_pd(vecs[1], 1);
|
||||
|
||||
Packet4d vecs2_0 = _mm512_extractf64x4_pd(vecs[2], 0);
|
||||
Packet4d vecs2_1 = _mm512_extractf64x4_pd(vecs[2], 1);
|
||||
|
||||
Packet4d vecs3_0 = _mm512_extractf64x4_pd(vecs[3], 0);
|
||||
Packet4d vecs3_1 = _mm512_extractf64x4_pd(vecs[3], 1);
|
||||
|
||||
Packet4d vecs4_0 = _mm512_extractf64x4_pd(vecs[4], 0);
|
||||
Packet4d vecs4_1 = _mm512_extractf64x4_pd(vecs[4], 1);
|
||||
|
||||
Packet4d vecs5_0 = _mm512_extractf64x4_pd(vecs[5], 0);
|
||||
Packet4d vecs5_1 = _mm512_extractf64x4_pd(vecs[5], 1);
|
||||
|
||||
Packet4d vecs6_0 = _mm512_extractf64x4_pd(vecs[6], 0);
|
||||
Packet4d vecs6_1 = _mm512_extractf64x4_pd(vecs[6], 1);
|
||||
|
||||
Packet4d vecs7_0 = _mm512_extractf64x4_pd(vecs[7], 0);
|
||||
Packet4d vecs7_1 = _mm512_extractf64x4_pd(vecs[7], 1);
|
||||
|
||||
Packet4d tmp0, tmp1;
|
||||
|
||||
tmp0 = _mm256_hadd_pd(vecs0_0, vecs1_0);
|
||||
tmp0 = _mm256_add_pd(tmp0, _mm256_permute2f128_pd(tmp0, tmp0, 1));
|
||||
|
||||
tmp1 = _mm256_hadd_pd(vecs2_0, vecs3_0);
|
||||
tmp1 = _mm256_add_pd(tmp1, _mm256_permute2f128_pd(tmp1, tmp1, 1));
|
||||
|
||||
__m256d final_0 = _mm256_blend_pd(tmp0, tmp1, 0xC);
|
||||
|
||||
tmp0 = _mm256_hadd_pd(vecs0_1, vecs1_1);
|
||||
tmp0 = _mm256_add_pd(tmp0, _mm256_permute2f128_pd(tmp0, tmp0, 1));
|
||||
|
||||
tmp1 = _mm256_hadd_pd(vecs2_1, vecs3_1);
|
||||
tmp1 = _mm256_add_pd(tmp1, _mm256_permute2f128_pd(tmp1, tmp1, 1));
|
||||
|
||||
final_0 = padd(final_0, _mm256_blend_pd(tmp0, tmp1, 0xC));
|
||||
|
||||
tmp0 = _mm256_hadd_pd(vecs4_0, vecs5_0);
|
||||
tmp0 = _mm256_add_pd(tmp0, _mm256_permute2f128_pd(tmp0, tmp0, 1));
|
||||
|
||||
tmp1 = _mm256_hadd_pd(vecs6_0, vecs7_0);
|
||||
tmp1 = _mm256_add_pd(tmp1, _mm256_permute2f128_pd(tmp1, tmp1, 1));
|
||||
|
||||
__m256d final_1 = _mm256_blend_pd(tmp0, tmp1, 0xC);
|
||||
|
||||
tmp0 = _mm256_hadd_pd(vecs4_1, vecs5_1);
|
||||
tmp0 = _mm256_add_pd(tmp0, _mm256_permute2f128_pd(tmp0, tmp0, 1));
|
||||
|
||||
tmp1 = _mm256_hadd_pd(vecs6_1, vecs7_1);
|
||||
tmp1 = _mm256_add_pd(tmp1, _mm256_permute2f128_pd(tmp1, tmp1, 1));
|
||||
|
||||
final_1 = padd(final_1, _mm256_blend_pd(tmp0, tmp1, 0xC));
|
||||
|
||||
__m512d final_output = _mm512_insertf64x4(final_output, final_0, 0);
|
||||
|
||||
return _mm512_insertf64x4(final_output, final_1, 1);
|
||||
}
|
||||
|
||||
template <>
|
||||
@ -798,20 +1006,6 @@ struct palign_impl<Offset, Packet8d> {
|
||||
}
|
||||
};
|
||||
|
||||
// AVX512F does not define _mm512_extractf32x8_ps to extract _m256 from _m512
|
||||
#define EIGEN_EXTRACT_8f_FROM_16f(INPUT) \
|
||||
__m256 INPUT##_0 = _mm256_insertf128_ps( \
|
||||
_mm256_castps128_ps256(_mm512_extractf32x4_ps(INPUT, 0)), \
|
||||
_mm512_extractf32x4_ps(INPUT, 1), 1); \
|
||||
__m256 INPUT##_1 = _mm256_insertf128_ps( \
|
||||
_mm256_castps128_ps256(_mm512_extractf32x4_ps(INPUT, 2)), \
|
||||
_mm512_extractf32x4_ps(INPUT, 3), 1);
|
||||
|
||||
#define EIGEN_INSERT_8f_INTO_16f(OUTPUT, INPUTA, INPUTB) \
|
||||
OUTPUT = _mm512_insertf32x4(OUTPUT, _mm256_extractf128_ps(INPUTA, 0), 0); \
|
||||
OUTPUT = _mm512_insertf32x4(OUTPUT, _mm256_extractf128_ps(INPUTA, 1), 1); \
|
||||
OUTPUT = _mm512_insertf32x4(OUTPUT, _mm256_extractf128_ps(INPUTB, 0), 2); \
|
||||
OUTPUT = _mm512_insertf32x4(OUTPUT, _mm256_extractf128_ps(INPUTB, 1), 3);
|
||||
|
||||
#define PACK_OUTPUT(OUTPUT, INPUT, INDEX, STRIDE) \
|
||||
EIGEN_INSERT_8f_INTO_16f(OUTPUT[INDEX], INPUT[INDEX], INPUT[INDEX + STRIDE]);
|
||||
@ -850,22 +1044,22 @@ EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet16f, 16>& kernel) {
|
||||
__m512 S14 = _mm512_shuffle_ps(T13, T15, _MM_SHUFFLE(1, 0, 1, 0));
|
||||
__m512 S15 = _mm512_shuffle_ps(T13, T15, _MM_SHUFFLE(3, 2, 3, 2));
|
||||
|
||||
EIGEN_EXTRACT_8f_FROM_16f(S0);
|
||||
EIGEN_EXTRACT_8f_FROM_16f(S1);
|
||||
EIGEN_EXTRACT_8f_FROM_16f(S2);
|
||||
EIGEN_EXTRACT_8f_FROM_16f(S3);
|
||||
EIGEN_EXTRACT_8f_FROM_16f(S4);
|
||||
EIGEN_EXTRACT_8f_FROM_16f(S5);
|
||||
EIGEN_EXTRACT_8f_FROM_16f(S6);
|
||||
EIGEN_EXTRACT_8f_FROM_16f(S7);
|
||||
EIGEN_EXTRACT_8f_FROM_16f(S8);
|
||||
EIGEN_EXTRACT_8f_FROM_16f(S9);
|
||||
EIGEN_EXTRACT_8f_FROM_16f(S10);
|
||||
EIGEN_EXTRACT_8f_FROM_16f(S11);
|
||||
EIGEN_EXTRACT_8f_FROM_16f(S12);
|
||||
EIGEN_EXTRACT_8f_FROM_16f(S13);
|
||||
EIGEN_EXTRACT_8f_FROM_16f(S14);
|
||||
EIGEN_EXTRACT_8f_FROM_16f(S15);
|
||||
EIGEN_EXTRACT_8f_FROM_16f(S0, S0);
|
||||
EIGEN_EXTRACT_8f_FROM_16f(S1, S1);
|
||||
EIGEN_EXTRACT_8f_FROM_16f(S2, S2);
|
||||
EIGEN_EXTRACT_8f_FROM_16f(S3, S3);
|
||||
EIGEN_EXTRACT_8f_FROM_16f(S4, S4);
|
||||
EIGEN_EXTRACT_8f_FROM_16f(S5, S5);
|
||||
EIGEN_EXTRACT_8f_FROM_16f(S6, S6);
|
||||
EIGEN_EXTRACT_8f_FROM_16f(S7, S7);
|
||||
EIGEN_EXTRACT_8f_FROM_16f(S8, S8);
|
||||
EIGEN_EXTRACT_8f_FROM_16f(S9, S9);
|
||||
EIGEN_EXTRACT_8f_FROM_16f(S10, S10);
|
||||
EIGEN_EXTRACT_8f_FROM_16f(S11, S11);
|
||||
EIGEN_EXTRACT_8f_FROM_16f(S12, S12);
|
||||
EIGEN_EXTRACT_8f_FROM_16f(S13, S13);
|
||||
EIGEN_EXTRACT_8f_FROM_16f(S14, S14);
|
||||
EIGEN_EXTRACT_8f_FROM_16f(S15, S15);
|
||||
|
||||
PacketBlock<Packet8f, 32> tmp;
|
||||
|
||||
@ -942,10 +1136,10 @@ EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet16f, 4>& kernel) {
|
||||
__m512 S2 = _mm512_shuffle_ps(T1, T3, _MM_SHUFFLE(1, 0, 1, 0));
|
||||
__m512 S3 = _mm512_shuffle_ps(T1, T3, _MM_SHUFFLE(3, 2, 3, 2));
|
||||
|
||||
EIGEN_EXTRACT_8f_FROM_16f(S0);
|
||||
EIGEN_EXTRACT_8f_FROM_16f(S1);
|
||||
EIGEN_EXTRACT_8f_FROM_16f(S2);
|
||||
EIGEN_EXTRACT_8f_FROM_16f(S3);
|
||||
EIGEN_EXTRACT_8f_FROM_16f(S0, S0);
|
||||
EIGEN_EXTRACT_8f_FROM_16f(S1, S1);
|
||||
EIGEN_EXTRACT_8f_FROM_16f(S2, S2);
|
||||
EIGEN_EXTRACT_8f_FROM_16f(S3, S3);
|
||||
|
||||
PacketBlock<Packet8f, 8> tmp;
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user