2
0
mirror of https://gitlab.com/libeigen/eigen.git synced 2025-04-24 19:40:45 +08:00

Fix AVX512 builds with MSVC.

This commit is contained in:
Antonio Sánchez 2022-03-18 16:04:53 +00:00
parent 7b10795e39
commit 9a14d91a99
2 changed files with 26 additions and 26 deletions
CMakeLists.txt
Eigen/src/Core/arch/AVX512

@ -360,11 +360,19 @@ else()
endif()
option(EIGEN_TEST_FMA "Enable/Disable FMA/AVX2 in tests/examples" OFF)
if(EIGEN_TEST_FMA AND NOT EIGEN_TEST_NEON)
option(EIGEN_TEST_AVX2 "Enable/Disable FMA/AVX2 in tests/examples" OFF)
if((EIGEN_TEST_FMA AND NOT EIGEN_TEST_NEON) OR EIGEN_TEST_AVX2)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /arch:AVX2")
message(STATUS "Enabling FMA/AVX2 in tests/examples")
endif()
option(EIGEN_TEST_AVX512 "Enable/Disable AVX512 in tests/examples" OFF)
option(EIGEN_TEST_AVX512DQ "Enable/Disable AVX512DQ in tests/examples" OFF)
if(EIGEN_TEST_AVX512 OR EIGEN_TEST_AVX512DQ)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /arch:AVX512")
message(STATUS "Enabling AVX512 in tests/examples")
endif()
endif()
option(EIGEN_TEST_NO_EXPLICIT_VECTORIZATION "Disable explicit vectorization in tests/examples" OFF)

@ -784,7 +784,7 @@ EIGEN_STRONG_INLINE Packet8d pload<Packet8d>(const double* from) {
template <>
EIGEN_STRONG_INLINE Packet16i pload<Packet16i>(const int* from) {
EIGEN_DEBUG_ALIGNED_LOAD return _mm512_load_si512(
reinterpret_cast<const __m512i*>(from));
reinterpret_cast<const __m512i*>(from));
}
template <>
@ -1440,38 +1440,30 @@ EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet16f, 8>& kernel) {
__m512 T6 = _mm512_unpacklo_ps(kernel.packet[6],kernel.packet[7]);
__m512 T7 = _mm512_unpackhi_ps(kernel.packet[6],kernel.packet[7]);
kernel.packet[0] = reinterpret_cast<__m512>(
_mm512_unpacklo_pd(reinterpret_cast<__m512d>(T0),reinterpret_cast<__m512d>(T2)));
kernel.packet[1] = reinterpret_cast<__m512>(
_mm512_unpackhi_pd(reinterpret_cast<__m512d>(T0),reinterpret_cast<__m512d>(T2)));
kernel.packet[2] = reinterpret_cast<__m512>(
_mm512_unpacklo_pd(reinterpret_cast<__m512d>(T1),reinterpret_cast<__m512d>(T3)));
kernel.packet[3] = reinterpret_cast<__m512>(
_mm512_unpackhi_pd(reinterpret_cast<__m512d>(T1),reinterpret_cast<__m512d>(T3)));
kernel.packet[4] = reinterpret_cast<__m512>(
_mm512_unpacklo_pd(reinterpret_cast<__m512d>(T4),reinterpret_cast<__m512d>(T6)));
kernel.packet[5] = reinterpret_cast<__m512>(
_mm512_unpackhi_pd(reinterpret_cast<__m512d>(T4),reinterpret_cast<__m512d>(T6)));
kernel.packet[6] = reinterpret_cast<__m512>(
_mm512_unpacklo_pd(reinterpret_cast<__m512d>(T5),reinterpret_cast<__m512d>(T7)));
kernel.packet[7] = reinterpret_cast<__m512>(
_mm512_unpackhi_pd(reinterpret_cast<__m512d>(T5),reinterpret_cast<__m512d>(T7)));
kernel.packet[0] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T0),_mm512_castps_pd(T2)));
kernel.packet[1] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T0),_mm512_castps_pd(T2)));
kernel.packet[2] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T1),_mm512_castps_pd(T3)));
kernel.packet[3] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T1),_mm512_castps_pd(T3)));
kernel.packet[4] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T4),_mm512_castps_pd(T6)));
kernel.packet[5] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T4),_mm512_castps_pd(T6)));
kernel.packet[6] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T5),_mm512_castps_pd(T7)));
kernel.packet[7] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T5),_mm512_castps_pd(T7)));
T0 = reinterpret_cast<__m512>(_mm512_permutex_pd(reinterpret_cast<__m512d>(kernel.packet[4]), 0x4E));
T0 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[4]), 0x4E));
T0 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[0], T0);
T4 = reinterpret_cast<__m512>(_mm512_permutex_pd(reinterpret_cast<__m512d>(kernel.packet[0]), 0x4E));
T4 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[0]), 0x4E));
T4 = _mm512_mask_blend_ps(0xF0F0, T4, kernel.packet[4]);
T1 = reinterpret_cast<__m512>(_mm512_permutex_pd(reinterpret_cast<__m512d>(kernel.packet[5]), 0x4E));
T1 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[5]), 0x4E));
T1 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[1], T1);
T5 = reinterpret_cast<__m512>(_mm512_permutex_pd(reinterpret_cast<__m512d>(kernel.packet[1]), 0x4E));
T5 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[1]), 0x4E));
T5 = _mm512_mask_blend_ps(0xF0F0, T5, kernel.packet[5]);
T2 = reinterpret_cast<__m512>(_mm512_permutex_pd(reinterpret_cast<__m512d>(kernel.packet[6]), 0x4E));
T2 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[6]), 0x4E));
T2 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[2], T2);
T6 = reinterpret_cast<__m512>(_mm512_permutex_pd(reinterpret_cast<__m512d>(kernel.packet[2]), 0x4E));
T6 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[2]), 0x4E));
T6 = _mm512_mask_blend_ps(0xF0F0, T6, kernel.packet[6]);
T3 = reinterpret_cast<__m512>(_mm512_permutex_pd(reinterpret_cast<__m512d>(kernel.packet[7]), 0x4E));
T3 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[7]), 0x4E));
T3 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[3], T3);
T7 = reinterpret_cast<__m512>(_mm512_permutex_pd(reinterpret_cast<__m512d>(kernel.packet[3]), 0x4E));
T7 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[3]), 0x4E));
T7 = _mm512_mask_blend_ps(0xF0F0, T7, kernel.packet[7]);
kernel.packet[0] = T0; kernel.packet[1] = T1;