Commit 52a5f982 broke conjhelper functionality for HIP GPUs.

This commit addresses this.
This commit is contained in:
Rohit Santhanam 2021-06-25 19:28:00 +00:00
parent bffd267d17
commit 2d132d1736

View File

@ -45,16 +45,16 @@ template<bool Conjugate> struct conj_if;
template<> struct conj_if<true> {
template<typename T>
inline T operator()(const T& x) const { return numext::conj(x); }
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x) const { return numext::conj(x); }
template<typename T>
inline T pconj(const T& x) const { return internal::pconj(x); }
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T pconj(const T& x) const { return internal::pconj(x); }
};
template<> struct conj_if<false> {
template<typename T>
inline const T& operator()(const T& x) const { return x; }
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T& operator()(const T& x) const { return x; }
template<typename T>
inline const T& pconj(const T& x) const { return x; }
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T& pconj(const T& x) const { return x; }
};
// Generic implementation.
@ -63,10 +63,10 @@ struct conj_helper
{
typedef typename ScalarBinaryOpTraits<LhsType,RhsType>::ReturnType ResultType;
EIGEN_STRONG_INLINE ResultType pmadd(const LhsType& x, const RhsType& y, const ResultType& c) const
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ResultType pmadd(const LhsType& x, const RhsType& y, const ResultType& c) const
{ return Eigen::internal::pmadd(conj_if<ConjLhs>().pconj(x), conj_if<ConjRhs>().pconj(y), c); }
EIGEN_STRONG_INLINE ResultType pmul(const LhsType& x, const RhsType& y) const
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ResultType pmul(const LhsType& x, const RhsType& y) const
{ return Eigen::internal::pmul(conj_if<ConjLhs>().pconj(x), conj_if<ConjRhs>().pconj(y)); }
};
@ -75,10 +75,10 @@ struct conj_helper<LhsType, RhsType, true, true>
{
typedef typename ScalarBinaryOpTraits<LhsType,RhsType>::ReturnType ResultType;
EIGEN_STRONG_INLINE ResultType pmadd(const LhsType& x, const RhsType& y, const ResultType& c) const
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ResultType pmadd(const LhsType& x, const RhsType& y, const ResultType& c) const
{ return Eigen::internal::pmadd(pconj(x), pconj(y), c); }
// We save a conjuation by using the identity conj(a)*conj(b) = conj(a*b).
EIGEN_STRONG_INLINE ResultType pmul(const LhsType& x, const RhsType& y) const
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ResultType pmul(const LhsType& x, const RhsType& y) const
{ return pconj(Eigen::internal::pmul(x, y)); }
};
@ -86,18 +86,18 @@ struct conj_helper<LhsType, RhsType, true, true>
template<typename RealScalar,bool Conj> struct conj_helper<std::complex<RealScalar>, RealScalar, Conj,false>
{
typedef std::complex<RealScalar> Scalar;
EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const RealScalar& y, const Scalar& c) const
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const RealScalar& y, const Scalar& c) const
{ return c + conj_if<Conj>().pconj(x) * y; }
EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const RealScalar& y) const
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const RealScalar& y) const
{ return conj_if<Conj>().pconj(x) * y; }
};
template<typename RealScalar,bool Conj> struct conj_helper<RealScalar, std::complex<RealScalar>, false,Conj>
{
typedef std::complex<RealScalar> Scalar;
EIGEN_STRONG_INLINE Scalar pmadd(const RealScalar& x, const Scalar& y, const Scalar& c) const
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar pmadd(const RealScalar& x, const Scalar& y, const Scalar& c) const
{ return c + pmul(x,y); }
EIGEN_STRONG_INLINE Scalar pmul(const RealScalar& x, const Scalar& y) const
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar pmul(const RealScalar& x, const Scalar& y) const
{ return x * conj_if<Conj>().pconj(y); }
};