Implement a generic vectorized version of Smith's algorithms for complex division.

This commit is contained in:
Rasmus Munk Larsen 2021-06-30 15:53:06 -07:00
parent 6035da5283
commit 9312a5bf5c
9 changed files with 46 additions and 114 deletions

View File

@ -167,15 +167,12 @@ template<> EIGEN_STRONG_INLINE std::complex<float> predux_mul<Packet4cf>(const P
Packet2cf(_mm256_extractf128_ps(a.v, 1))));
}
EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(Packet4cf,Packet8f)
template<> EIGEN_STRONG_INLINE Packet4cf pdiv<Packet4cf>(const Packet4cf& a, const Packet4cf& b)
{
Packet4cf num = pmul(a, pconj(b));
__m256 tmp = _mm256_mul_ps(b.v, b.v);
__m256 tmp2 = _mm256_shuffle_ps(tmp,tmp,0xB1);
__m256 denom = _mm256_add_ps(tmp, tmp2);
return Packet4cf(_mm256_div_ps(num.v, denom));
return pdiv_complex(a, b);
}
template<> EIGEN_STRONG_INLINE Packet4cf pcplxflip<Packet4cf>(const Packet4cf& x)
@ -321,10 +318,7 @@ EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(Packet2cd,Packet4d)
template<> EIGEN_STRONG_INLINE Packet2cd pdiv<Packet2cd>(const Packet2cd& a, const Packet2cd& b)
{
Packet2cd num = pmul(a, pconj(b));
__m256d tmp = _mm256_mul_pd(b.v, b.v);
__m256d denom = _mm256_hadd_pd(tmp, tmp);
return Packet2cd(_mm256_div_pd(num.v, denom));
return pdiv_complex(a, b);
}
template<> EIGEN_STRONG_INLINE Packet2cd pcplxflip<Packet2cd>(const Packet2cd& x)

View File

@ -157,11 +157,7 @@ EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(Packet8cf,Packet16f)
template<> EIGEN_STRONG_INLINE Packet8cf pdiv<Packet8cf>(const Packet8cf& a, const Packet8cf& b)
{
Packet8cf num = pmul(a, pconj(b));
__m512 tmp = _mm512_mul_ps(b.v, b.v);
__m512 tmp2 = _mm512_shuffle_ps(tmp,tmp,0xB1);
__m512 denom = _mm512_add_ps(tmp, tmp2);
return Packet8cf(_mm512_div_ps(num.v, denom));
return pdiv_complex(a, b);
}
template<> EIGEN_STRONG_INLINE Packet8cf pcplxflip<Packet8cf>(const Packet8cf& x)
@ -309,47 +305,11 @@ template<> EIGEN_STRONG_INLINE std::complex<double> predux_mul<Packet4cd>(const
Packet2cd(_mm512_extractf64x4_pd(a.v,1))));
}
template<> struct conj_helper<Packet4cd, Packet4cd, false,true>
{
EIGEN_STRONG_INLINE Packet4cd pmadd(const Packet4cd& x, const Packet4cd& y, const Packet4cd& c) const
{ return padd(pmul(x,y),c); }
EIGEN_STRONG_INLINE Packet4cd pmul(const Packet4cd& a, const Packet4cd& b) const
{
return internal::pmul(a, pconj(b));
}
};
template<> struct conj_helper<Packet4cd, Packet4cd, true,false>
{
EIGEN_STRONG_INLINE Packet4cd pmadd(const Packet4cd& x, const Packet4cd& y, const Packet4cd& c) const
{ return padd(pmul(x,y),c); }
EIGEN_STRONG_INLINE Packet4cd pmul(const Packet4cd& a, const Packet4cd& b) const
{
return internal::pmul(pconj(a), b);
}
};
template<> struct conj_helper<Packet4cd, Packet4cd, true,true>
{
EIGEN_STRONG_INLINE Packet4cd pmadd(const Packet4cd& x, const Packet4cd& y, const Packet4cd& c) const
{ return padd(pmul(x,y),c); }
EIGEN_STRONG_INLINE Packet4cd pmul(const Packet4cd& a, const Packet4cd& b) const
{
return pconj(internal::pmul(a, b));
}
};
EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(Packet4cd,Packet8d)
template<> EIGEN_STRONG_INLINE Packet4cd pdiv<Packet4cd>(const Packet4cd& a, const Packet4cd& b)
{
Packet4cd num = pmul(a, pconj(b));
__m512d tmp = _mm512_mul_pd(b.v, b.v);
__m512d denom = padd(_mm512_permute_pd(tmp,0x55), tmp);
return Packet4cd(_mm512_div_pd(num.v, denom));
return pdiv_complex(a, b);
}
template<> EIGEN_STRONG_INLINE Packet4cd pcplxflip<Packet4cd>(const Packet4cd& x)

View File

@ -210,10 +210,7 @@ EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(Packet2cf,Packet4f)
template<> EIGEN_STRONG_INLINE Packet2cf pdiv<Packet2cf>(const Packet2cf& a, const Packet2cf& b)
{
// TODO optimize it for AltiVec
Packet2cf res = pmul(a, pconj(b));
Packet4f s = pmul<Packet4f>(b.v, b.v);
return Packet2cf(pdiv(res.v, padd<Packet4f>(s, vec_perm(s, s, p16uc_COMPLEX32_REV))));
return pdiv_complex(a, b);
}
template<> EIGEN_STRONG_INLINE Packet2cf pcplxflip<Packet2cf>(const Packet2cf& x)
@ -375,10 +372,7 @@ EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(Packet1cd,Packet2d)
template<> EIGEN_STRONG_INLINE Packet1cd pdiv<Packet1cd>(const Packet1cd& a, const Packet1cd& b)
{
// TODO optimize it for AltiVec
Packet1cd res = pmul(a,pconj(b));
Packet2d s = pmul<Packet2d>(b.v, b.v);
return Packet1cd(pdiv(res.v, padd<Packet2d>(s, vec_perm(s, s, p16uc_REVERSE64))));
return pdiv_complex(a, b);
}
EIGEN_STRONG_INLINE Packet1cd pcplxflip/*<Packet1cd>*/(const Packet1cd& x)

View File

@ -757,6 +757,26 @@ Packet pcos_float(const Packet& x)
return psincos_float<false>(x);
}
template<typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
EIGEN_UNUSED Packet pdiv_complex(const Packet& x, const Packet& y) {
typedef typename unpacket_traits<Packet>::as_real RealPacket;
// In the following we annotate the code for the case where the inputs
// are a pair length-2 SIMD vectors representing a single pair of complex
// numbers x = a + i*b, y = c + i*d.
const RealPacket y_abs = pabs(y.v); // |c|, |d|
const RealPacket y_abs_flip = pcplxflip(Packet(y_abs)).v; // |d|, |c|
const RealPacket y_max = pmax(y_abs, y_abs_flip); // max(|c|, |d|), max(|c|, |d|)
const RealPacket y_scaled = pdiv(y.v, y_max); // c / max(|c|, |d|), d / max(|c|, |d|)
// Compute scaled denominator.
const RealPacket y_scaled_sq = pmul(y_scaled, y_scaled); // c'**2, d'**2
const RealPacket denom = y_scaled_sq + pcplxflip(Packet(y_scaled_sq)).v;
Packet result_scaled = pmul(x, pconj(Packet(y_scaled))); // a * c' + b * d', -a * d + b * c
// Divide elementwise by denom.
result_scaled = Packet(pdiv(result_scaled.v, denom));
// Rescale result
return Packet(pdiv(result_scaled.v, y_max));
}
template<typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS

View File

@ -101,6 +101,12 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
EIGEN_UNUSED
Packet psqrt_complex(const Packet& a);
/** \internal \returns x / y for complex types */
template<typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
EIGEN_UNUSED
Packet pdiv_complex(const Packet& x, const Packet& y);
template <typename Packet, int N> struct ppolevl;

View File

@ -75,15 +75,12 @@ struct Packet2cf {
EIGEN_STRONG_INLINE Packet2cf operator-(const Packet2cf& b) const {
return Packet2cf(*this) -= b;
}
EIGEN_STRONG_INLINE Packet2cf& operator/=(const Packet2cf& b) {
*this *= b.conjugate();
Packet4f s = pmul<Packet4f>(b.v, b.v);
s = padd(s, (Packet4f)__builtin_msa_shf_w((v4i32)s, EIGEN_MSA_SHF_I8(1, 0, 3, 2)));
v = pdiv(v, s);
return *this;
}
EIGEN_STRONG_INLINE Packet2cf operator/(const Packet2cf& b) const {
return Packet2cf(*this) /= b;
return pdiv_complex(Packet2cf(*this), b);
}
EIGEN_STRONG_INLINE Packet2cf& operator/=(const Packet2cf& b) {
*this = Packet2cf(*this) / b;
return *this;
}
EIGEN_STRONG_INLINE Packet2cf operator-(void) const {
return Packet2cf(pnegate(v));

View File

@ -347,27 +347,11 @@ EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(Packet2cf,Packet4f)
template<> EIGEN_STRONG_INLINE Packet1cf pdiv<Packet1cf>(const Packet1cf& a, const Packet1cf& b)
{
// TODO optimize it for NEON
Packet1cf res = pmul(a, pconj(b));
Packet2f s, rev_s;
// this computes the norm
s = vmul_f32(b.v, b.v);
rev_s = vrev64_f32(s);
return Packet1cf(pdiv<Packet2f>(res.v, vadd_f32(s, rev_s)));
return pdiv_complex(a, b);
}
template<> EIGEN_STRONG_INLINE Packet2cf pdiv<Packet2cf>(const Packet2cf& a, const Packet2cf& b)
{
// TODO optimize it for NEON
Packet2cf res = pmul(a,pconj(b));
Packet4f s, rev_s;
// this computes the norm
s = vmulq_f32(b.v, b.v);
rev_s = vrev64q_f32(s);
return Packet2cf(pdiv<Packet4f>(res.v, vaddq_f32(s, rev_s)));
return pdiv_complex(a, b);
}
EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet1cf, 1>& /*kernel*/) {}
@ -553,12 +537,7 @@ EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(Packet1cd,Packet2d)
template<> EIGEN_STRONG_INLINE Packet1cd pdiv<Packet1cd>(const Packet1cd& a, const Packet1cd& b)
{
// TODO optimize it for NEON
Packet1cd res = pmul(a,pconj(b));
Packet2d s = pmul<Packet2d>(b.v, b.v);
Packet2d rev_s = preverse<Packet2d>(s);
return Packet1cd(pdiv(res.v, padd<Packet2d>(s,rev_s)));
return pdiv_complex(a, b);
}
EIGEN_STRONG_INLINE Packet1cd pcplxflip/*<Packet1cd>*/(const Packet1cd& x)

View File

@ -174,14 +174,9 @@ EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(Packet2cf,Packet4f)
template<> EIGEN_STRONG_INLINE Packet2cf pdiv<Packet2cf>(const Packet2cf& a, const Packet2cf& b)
{
// TODO optimize it for SSE3 and 4
Packet2cf res = pmul(a, pconj(b));
__m128 s = _mm_mul_ps(b.v,b.v);
return Packet2cf(_mm_div_ps(res.v,_mm_add_ps(s,vec4f_swizzle1(s, 1, 0, 3, 2))));
return pdiv_complex(a, b);
}
//---------- double ----------
struct Packet1cd
{
@ -299,10 +294,7 @@ EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(Packet1cd,Packet2d)
template<> EIGEN_STRONG_INLINE Packet1cd pdiv<Packet1cd>(const Packet1cd& a, const Packet1cd& b)
{
// TODO optimize it for SSE3 and 4
Packet1cd res = pmul(a,pconj(b));
__m128d s = _mm_mul_pd(b.v,b.v);
return Packet1cd(_mm_div_pd(res.v, _mm_add_pd(s,_mm_shuffle_pd(s, s, 0x1))));
return pdiv_complex(a, b);
}
EIGEN_STRONG_INLINE Packet1cd pcplxflip/* <Packet1cd> */(const Packet1cd& x)

View File

@ -169,10 +169,7 @@ EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(Packet1cd,Packet2d)
template<> EIGEN_STRONG_INLINE Packet1cd pdiv<Packet1cd>(const Packet1cd& a, const Packet1cd& b)
{
// TODO optimize it for AltiVec
Packet1cd res = pmul(a,pconj(b));
Packet2d s = vec_madd(b.v, b.v, p2d_ZERO_);
return Packet1cd(pdiv(res.v, s + vec_perm(s, s, p16uc_REVERSE64)));
return pdiv_complex(a, b);
}
EIGEN_STRONG_INLINE Packet1cd pcplxflip/*<Packet1cd>*/(const Packet1cd& x)
@ -308,11 +305,7 @@ EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(Packet2cf,Packet4f)
template<> EIGEN_STRONG_INLINE Packet2cf pdiv<Packet2cf>(const Packet2cf& a, const Packet2cf& b)
{
// TODO optimize it for AltiVec
Packet2cf res;
res.cd[0] = pdiv<Packet1cd>(a.cd[0], b.cd[0]);
res.cd[1] = pdiv<Packet1cd>(a.cd[1], b.cd[1]);
return res;
return pdiv_complex(a, b);
}
EIGEN_STRONG_INLINE Packet2cf pcplxflip/*<Packet2cf>*/(const Packet2cf& x)
@ -394,10 +387,7 @@ EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(Packet2cf,Packet4f)
template<> EIGEN_STRONG_INLINE Packet2cf pdiv<Packet2cf>(const Packet2cf& a, const Packet2cf& b)
{
// TODO optimize it for AltiVec
Packet2cf res = pmul(a, pconj(b));
Packet4f s = pmul<Packet4f>(b.v, b.v);
return Packet2cf(pdiv(res.v, padd<Packet4f>(s, vec_perm(s, s, p16uc_COMPLEX32_REV))));
return pdiv_complex(a, b);
}
template<> EIGEN_STRONG_INLINE Packet2cf pcplxflip<Packet2cf>(const Packet2cf& x)