mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-03-31 19:00:35 +08:00
allow mixed complex-real and real-complex dot products
This commit is contained in:
parent
fe3bb545e0
commit
0bfb78c824
@ -41,18 +41,20 @@ template<typename T, typename U,
|
||||
>
|
||||
struct dot_nocheck
|
||||
{
|
||||
static inline typename traits<T>::Scalar run(const MatrixBase<T>& a, const MatrixBase<U>& b)
|
||||
typedef typename scalar_product_traits<typename traits<T>::Scalar,typename traits<U>::Scalar>::ReturnType ResScalar;
|
||||
static inline ResScalar run(const MatrixBase<T>& a, const MatrixBase<U>& b)
|
||||
{
|
||||
return a.template binaryExpr<scalar_conj_product_op<typename traits<T>::Scalar> >(b).sum();
|
||||
return a.template binaryExpr<scalar_conj_product_op<typename traits<T>::Scalar,typename traits<U>::Scalar> >(b).sum();
|
||||
}
|
||||
};
|
||||
|
||||
template<typename T, typename U>
|
||||
struct dot_nocheck<T, U, true>
|
||||
{
|
||||
static inline typename traits<T>::Scalar run(const MatrixBase<T>& a, const MatrixBase<U>& b)
|
||||
typedef typename scalar_product_traits<typename traits<T>::Scalar,typename traits<U>::Scalar>::ReturnType ResScalar;
|
||||
static inline ResScalar run(const MatrixBase<T>& a, const MatrixBase<U>& b)
|
||||
{
|
||||
return a.transpose().template binaryExpr<scalar_conj_product_op<typename traits<T>::Scalar> >(b).sum();
|
||||
return a.transpose().template binaryExpr<scalar_conj_product_op<typename traits<T>::Scalar,typename traits<U>::Scalar> >(b).sum();
|
||||
}
|
||||
};
|
||||
|
||||
@ -70,14 +72,14 @@ struct dot_nocheck<T, U, true>
|
||||
*/
|
||||
template<typename Derived>
|
||||
template<typename OtherDerived>
|
||||
typename internal::traits<Derived>::Scalar
|
||||
typename internal::scalar_product_traits<typename internal::traits<Derived>::Scalar,typename internal::traits<OtherDerived>::Scalar>::ReturnType
|
||||
MatrixBase<Derived>::dot(const MatrixBase<OtherDerived>& other) const
|
||||
{
|
||||
EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived)
|
||||
EIGEN_STATIC_ASSERT_VECTOR_ONLY(OtherDerived)
|
||||
EIGEN_STATIC_ASSERT_SAME_VECTOR_SIZE(Derived,OtherDerived)
|
||||
EIGEN_STATIC_ASSERT((internal::is_same<Scalar, typename OtherDerived::Scalar>::value),
|
||||
YOU_MIXED_DIFFERENT_NUMERIC_TYPES__YOU_NEED_TO_USE_THE_CAST_METHOD_OF_MATRIXBASE_TO_CAST_NUMERIC_TYPES_EXPLICITLY)
|
||||
typedef internal::scalar_conj_product_op<Scalar,typename OtherDerived::Scalar> func;
|
||||
EIGEN_CHECK_BINARY_COMPATIBILIY(func,Scalar,typename OtherDerived::Scalar);
|
||||
|
||||
eigen_assert(size() == other.size());
|
||||
|
||||
|
@ -59,6 +59,7 @@ struct functor_traits<scalar_sum_op<Scalar> > {
|
||||
*/
|
||||
template<typename LhsScalar,typename RhsScalar> struct scalar_product_op {
|
||||
enum {
|
||||
// TODO vectorize mixed product
|
||||
Vectorizable = is_same<LhsScalar,RhsScalar>::value && packet_traits<LhsScalar>::HasMul && packet_traits<RhsScalar>::HasMul
|
||||
};
|
||||
typedef typename scalar_product_traits<LhsScalar,RhsScalar>::ReturnType result_type;
|
||||
@ -84,24 +85,27 @@ struct functor_traits<scalar_product_op<LhsScalar,RhsScalar> > {
|
||||
*
|
||||
* This is a short cut for conj(x) * y which is needed for optimization purpose; in Eigen2 support mode, this becomes x * conj(y)
|
||||
*/
|
||||
template<typename Scalar> struct scalar_conj_product_op {
|
||||
template<typename LhsScalar,typename RhsScalar> struct scalar_conj_product_op {
|
||||
|
||||
enum {
|
||||
Conj = NumTraits<Scalar>::IsComplex
|
||||
Conj = NumTraits<LhsScalar>::IsComplex
|
||||
};
|
||||
|
||||
typedef typename scalar_product_traits<LhsScalar,RhsScalar>::ReturnType result_type;
|
||||
|
||||
EIGEN_EMPTY_STRUCT_CTOR(scalar_conj_product_op)
|
||||
EIGEN_STRONG_INLINE const Scalar operator() (const Scalar& a, const Scalar& b) const
|
||||
{ return conj_helper<Scalar,Scalar,Conj,false>().pmul(a,b); }
|
||||
EIGEN_STRONG_INLINE const result_type operator() (const LhsScalar& a, const RhsScalar& b) const
|
||||
{ return conj_helper<LhsScalar,RhsScalar,Conj,false>().pmul(a,b); }
|
||||
|
||||
template<typename Packet>
|
||||
EIGEN_STRONG_INLINE const Packet packetOp(const Packet& a, const Packet& b) const
|
||||
{ return conj_helper<Packet,Packet,Conj,false>().pmul(a,b); }
|
||||
};
|
||||
template<typename Scalar>
|
||||
struct functor_traits<scalar_conj_product_op<Scalar> > {
|
||||
template<typename LhsScalar,typename RhsScalar>
|
||||
struct functor_traits<scalar_conj_product_op<LhsScalar,RhsScalar> > {
|
||||
enum {
|
||||
Cost = NumTraits<Scalar>::MulCost,
|
||||
PacketAccess = packet_traits<Scalar>::HasMul
|
||||
Cost = NumTraits<LhsScalar>::MulCost,
|
||||
PacketAccess = internal::is_same<LhsScalar, RhsScalar>::value && packet_traits<LhsScalar>::HasMul
|
||||
};
|
||||
};
|
||||
|
||||
@ -622,6 +626,7 @@ template<typename Scalar> struct functor_has_linear_access<scalar_identity_op<Sc
|
||||
// FIXME move this to functor_traits adding a functor_default
|
||||
template<typename Functor> struct functor_allows_mixing_real_and_complex { enum { ret = 0 }; };
|
||||
template<typename LhsScalar,typename RhsScalar> struct functor_allows_mixing_real_and_complex<scalar_product_op<LhsScalar,RhsScalar> > { enum { ret = 1 }; };
|
||||
template<typename LhsScalar,typename RhsScalar> struct functor_allows_mixing_real_and_complex<scalar_conj_product_op<LhsScalar,RhsScalar> > { enum { ret = 1 }; };
|
||||
|
||||
|
||||
/** \internal
|
||||
|
@ -202,10 +202,11 @@ template<typename Derived> class MatrixBase
|
||||
|
||||
#if EIGEN2_SUPPORT_STAGE != STAGE20_RESOLVE_API_CONFLICTS
|
||||
template<typename OtherDerived>
|
||||
typename internal::scalar_product_traits<typename internal::traits<Derived>::Scalar,typename internal::traits<OtherDerived>::Scalar>::ReturnType
|
||||
#if EIGEN2_SUPPORT_STAGE == STAGE15_RESOLVE_API_CONFLICTS_WARN
|
||||
EIGEN_DEPRECATED
|
||||
EIGEN_DEPRECATED Scalar
|
||||
#endif
|
||||
Scalar dot(const MatrixBase<OtherDerived>& other) const;
|
||||
dot(const MatrixBase<OtherDerived>& other) const;
|
||||
#endif
|
||||
|
||||
#ifdef EIGEN2_SUPPORT
|
||||
|
@ -155,7 +155,7 @@ template<typename LhsScalar, typename RhsScalar, bool ConjLhs=false, bool ConjRh
|
||||
|
||||
template<typename Scalar> struct scalar_sum_op;
|
||||
template<typename Scalar> struct scalar_difference_op;
|
||||
template<typename Scalar> struct scalar_conj_product_op;
|
||||
template<typename LhsScalar,typename RhsScalar> struct scalar_conj_product_op;
|
||||
template<typename Scalar> struct scalar_quotient_op;
|
||||
template<typename Scalar> struct scalar_opposite_op;
|
||||
template<typename Scalar> struct scalar_conjugate_op;
|
||||
|
@ -106,6 +106,11 @@ template<typename MatrixType> void adjoint(const MatrixType& m)
|
||||
m3.transposeInPlace();
|
||||
VERIFY_IS_APPROX(m3,m1.conjugate());
|
||||
|
||||
// check mixed dot product
|
||||
typedef Matrix<RealScalar, MatrixType::RowsAtCompileTime, 1> RealVectorType;
|
||||
RealVectorType rv1 = RealVectorType::Random(rows);
|
||||
VERIFY_IS_APPROX(v1.dot(rv1.template cast<Scalar>()), v1.dot(rv1));
|
||||
VERIFY_IS_APPROX(rv1.template cast<Scalar>().dot(v1), rv1.dot(v1));
|
||||
}
|
||||
|
||||
void test_adjoint()
|
||||
|
Loading…
x
Reference in New Issue
Block a user