From 7f3162f7071db63bdbdc21f4c101543df00e4661 Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Thu, 27 Mar 2014 17:42:25 -0700 Subject: [PATCH] Implemented the AVX version of the gather and scatter packet primitives. --- Eigen/src/Core/arch/AVX/Complex.h | 39 +++++++++++++++++++++++++++- Eigen/src/Core/arch/AVX/PacketMath.h | 35 +++++++++++++++++++++++++ 2 files changed, 73 insertions(+), 1 deletion(-) diff --git a/Eigen/src/Core/arch/AVX/Complex.h b/Eigen/src/Core/arch/AVX/Complex.h index 0121cec86..ec9c861f9 100644 --- a/Eigen/src/Core/arch/AVX/Complex.h +++ b/Eigen/src/Core/arch/AVX/Complex.h @@ -99,6 +99,29 @@ template<> EIGEN_STRONG_INLINE Packet4cf ploaddup(const std::complex< template<> EIGEN_STRONG_INLINE void pstore >(std::complex* to, const Packet4cf& from) { EIGEN_DEBUG_ALIGNED_STORE pstore(&numext::real_ref(*to), from.v); } template<> EIGEN_STRONG_INLINE void pstoreu >(std::complex* to, const Packet4cf& from) { EIGEN_DEBUG_UNALIGNED_STORE pstoreu(&numext::real_ref(*to), from.v); } +template<> EIGEN_DEVICE_FUNC inline Packet4cf pgather, Packet4cf>(const std::complex* from, int stride) +{ + return Packet4cf(_mm256_set_ps(std::imag(from[3*stride]), std::real(from[3*stride]), + std::imag(from[2*stride]), std::real(from[2*stride]), + std::imag(from[1*stride]), std::real(from[1*stride]), + std::imag(from[0*stride]), std::real(from[0*stride]))); +} + +template<> EIGEN_DEVICE_FUNC inline void pscatter, Packet4cf>(std::complex* to, const Packet4cf& from, int stride) +{ + __m128 low = _mm256_extractf128_ps(from.v, 0); + to[stride*0] = std::complex(_mm_cvtss_f32(_mm_shuffle_ps(low, low, 0)), + _mm_cvtss_f32(_mm_shuffle_ps(low, low, 1))); + to[stride*1] = std::complex(_mm_cvtss_f32(_mm_shuffle_ps(low, low, 2)), + _mm_cvtss_f32(_mm_shuffle_ps(low, low, 3))); + + __m128 high = _mm256_extractf128_ps(from.v, 1); + to[stride*2] = std::complex(_mm_cvtss_f32(_mm_shuffle_ps(high, high, 0)), + _mm_cvtss_f32(_mm_shuffle_ps(high, high, 1))); + to[stride*3] = std::complex(_mm_cvtss_f32(_mm_shuffle_ps(high, high, 2)), + _mm_cvtss_f32(_mm_shuffle_ps(high, high, 3))); + +} template<> EIGEN_STRONG_INLINE std::complex pfirst(const Packet4cf& a) { @@ -297,7 +320,21 @@ template<> EIGEN_STRONG_INLINE Packet2cd ploaddup(const std::complex< template<> EIGEN_STRONG_INLINE void pstore >(std::complex * to, const Packet2cd& from) { EIGEN_DEBUG_ALIGNED_STORE pstore((double*)to, from.v); } template<> EIGEN_STRONG_INLINE void pstoreu >(std::complex * to, const Packet2cd& from) { EIGEN_DEBUG_UNALIGNED_STORE pstoreu((double*)to, from.v); } -template<> EIGEN_STRONG_INLINE std::complex pfirst(const Packet2cd& a) +template<> EIGEN_DEVICE_FUNC inline Packet2cd pgather, Packet2cd>(const std::complex* from, int stride) +{ + return Packet2cd(_mm256_set_pd(std::imag(from[1*stride]), std::real(from[1*stride]), + std::imag(from[0*stride]), std::real(from[0*stride]))); +} + +template<> EIGEN_DEVICE_FUNC inline void pscatter, Packet2cd>(std::complex* to, const Packet2cd& from, int stride) +{ + __m128d low = _mm256_extractf128_pd(from.v, 0); + to[stride*0] = std::complex(_mm_cvtsd_f64(low), _mm_cvtsd_f64(_mm_shuffle_pd(low, low, 1))); + __m128d high = _mm256_extractf128_pd(from.v, 1); + to[stride*1] = std::complex(_mm_cvtsd_f64(high), _mm_cvtsd_f64(_mm_shuffle_pd(high, high, 1))); +} + +template<> EIGEN_STRONG_INLINE std::complex pfirst(const Packet2cd& a) { __m128d low = _mm256_extractf128_pd(a.v, 0); EIGEN_ALIGN16 double res[2]; diff --git a/Eigen/src/Core/arch/AVX/PacketMath.h b/Eigen/src/Core/arch/AVX/PacketMath.h index 96a4bc08c..aa2ac3b0b 100644 --- a/Eigen/src/Core/arch/AVX/PacketMath.h +++ b/Eigen/src/Core/arch/AVX/PacketMath.h @@ -179,6 +179,41 @@ template<> EIGEN_STRONG_INLINE void pstoreu(float* to, const Packet8f& template<> EIGEN_STRONG_INLINE void pstoreu(double* to, const Packet4d& from) { EIGEN_DEBUG_UNALIGNED_STORE _mm256_storeu_pd(to, from); } template<> EIGEN_STRONG_INLINE void pstoreu(int* to, const Packet8i& from) { EIGEN_DEBUG_UNALIGNED_STORE _mm256_storeu_si256(reinterpret_cast<__m256i*>(to), from); } +// TODO: leverage _mm256_i32gather_ps and _mm256_i32gather_pd if AVX2 instructions are available +template<> EIGEN_DEVICE_FUNC inline Packet8f pgather(const float* from, int stride) +{ + return _mm256_set_ps(from[7*stride], from[6*stride], from[5*stride], from[4*stride], + from[3*stride], from[2*stride], from[1*stride], from[0*stride]); +} +template<> EIGEN_DEVICE_FUNC inline Packet4d pgather(const double* from, int stride) +{ + return _mm256_set_pd(from[3*stride], from[2*stride], from[1*stride], from[0*stride]); +} + +template<> EIGEN_DEVICE_FUNC inline void pscatter(float* to, const Packet8f& from, int stride) +{ + __m128 low = _mm256_extractf128_ps(from, 0); + to[stride*0] = _mm_cvtss_f32(low); + to[stride*1] = _mm_cvtss_f32(_mm_shuffle_ps(low, low, 1)); + to[stride*2] = _mm_cvtss_f32(_mm_shuffle_ps(low, low, 2)); + to[stride*3] = _mm_cvtss_f32(_mm_shuffle_ps(low, low, 3)); + + __m128 high = _mm256_extractf128_ps(from, 1); + to[stride*4] = _mm_cvtss_f32(high); + to[stride*5] = _mm_cvtss_f32(_mm_shuffle_ps(high, high, 1)); + to[stride*6] = _mm_cvtss_f32(_mm_shuffle_ps(high, high, 2)); + to[stride*7] = _mm_cvtss_f32(_mm_shuffle_ps(high, high, 3)); +} +template<> EIGEN_DEVICE_FUNC inline void pscatter(double* to, const Packet4d& from, int stride) +{ + __m128d low = _mm256_extractf128_pd(from, 0); + to[stride*0] = _mm_cvtsd_f64(low); + to[stride*1] = _mm_cvtsd_f64(_mm_shuffle_pd(low, low, 1)); + __m128d high = _mm256_extractf128_pd(from, 1); + to[stride*2] = _mm_cvtsd_f64(high); + to[stride*3] = _mm_cvtsd_f64(_mm_shuffle_pd(high, high, 1)); +} + template<> EIGEN_STRONG_INLINE void pstore1(float* to, const float& a) { Packet8f pa = pset1(a);