add a conj_product functor and optimize dot products

This commit is contained in:
Gael Guennebaud 2010-07-07 10:00:08 +02:00
parent f8d3b4c060
commit e38fc9692d
3 changed files with 25 additions and 2 deletions

View File

@ -41,7 +41,7 @@ struct ei_dot_nocheck
{
static inline typename ei_traits<T>::Scalar run(const MatrixBase<T>& a, const MatrixBase<U>& b)
{
return a.conjugate().cwiseProduct(b).sum();
return a.template binaryExpr<ei_scalar_conj_product_op<typename ei_traits<T>::Scalar> >(b).sum();
}
};
@ -50,7 +50,7 @@ struct ei_dot_nocheck<T, U, true>
{
static inline typename ei_traits<T>::Scalar run(const MatrixBase<T>& a, const MatrixBase<U>& b)
{
return a.adjoint().cwiseProduct(b).sum();
return a.transpose().template binaryExpr<ei_scalar_conj_product_op<typename ei_traits<T>::Scalar> >(b).sum();
}
};

View File

@ -73,6 +73,28 @@ struct ei_functor_traits<ei_scalar_product_op<Scalar> > {
};
};
/** \internal
* \brief Template functor to compute the conjugate product of two scalars
*
* This is a short cut for ei_conj(x) * y which is needed for optimization purpose
*/
template<typename Scalar> struct ei_scalar_conj_product_op {
enum { Conj = NumTraits<Scalar>::IsComplex };
EIGEN_EMPTY_STRUCT_CTOR(ei_scalar_conj_product_op)
EIGEN_STRONG_INLINE const Scalar operator() (const Scalar& a, const Scalar& b) const
{ return ei_conj_helper<Scalar,Scalar,Conj,false>().pmul(a,b); }
template<typename PacketScalar>
EIGEN_STRONG_INLINE const PacketScalar packetOp(const PacketScalar& a, const PacketScalar& b) const
{ return ei_conj_helper<PacketScalar,PacketScalar,Conj,false>().pmul(a,b); }
};
template<typename Scalar>
struct ei_functor_traits<ei_scalar_conj_product_op<Scalar> > {
enum {
Cost = NumTraits<Scalar>::MulCost,
PacketAccess = ei_packet_traits<Scalar>::HasMul
};
};
/** \internal
* \brief Template functor to compute the min of two scalars
*

View File

@ -113,6 +113,7 @@ template<typename LhsScalar, typename RhsScalar, bool ConjLhs, bool ConjRhs> str
template<typename Scalar> struct ei_scalar_sum_op;
template<typename Scalar> struct ei_scalar_difference_op;
template<typename Scalar> struct ei_scalar_product_op;
template<typename Scalar> struct ei_scalar_conj_product_op;
template<typename Scalar> struct ei_scalar_quotient_op;
template<typename Scalar> struct ei_scalar_opposite_op;
template<typename Scalar> struct ei_scalar_conjugate_op;