* add real scalar * complex matrix, real matrix * complex scalar,

and complex scalar * real matrix overloads
* allows the inner and outer product specialisations to mix real and complex
This commit is contained in:
Gael Guennebaud 2009-09-04 11:22:32 +02:00
parent 6902ef0824
commit b0aa2520f1
6 changed files with 70 additions and 29 deletions

View File

@ -232,7 +232,7 @@ Cwise<ExpressionType>::log() const
}
/** \relates MatrixBase */
/** \returns an expression of \c *this scaled by the scalar factor \a scalar */
template<typename Derived>
EIGEN_STRONG_INLINE const typename MatrixBase<Derived>::ScalarMultipleReturnType
MatrixBase<Derived>::operator*(const Scalar& scalar) const
@ -241,7 +241,17 @@ MatrixBase<Derived>::operator*(const Scalar& scalar) const
(derived(), ei_scalar_multiple_op<Scalar>(scalar));
}
/** \relates MatrixBase */
/** Overloaded for efficient real matrix times complex scalar value */
template<typename Derived>
EIGEN_STRONG_INLINE const CwiseUnaryOp<ei_scalar_multiple2_op<typename ei_traits<Derived>::Scalar,
std::complex<typename ei_traits<Derived>::Scalar> >, Derived>
MatrixBase<Derived>::operator*(const std::complex<Scalar>& scalar) const
{
return CwiseUnaryOp<ei_scalar_multiple2_op<Scalar,std::complex<Scalar> >, Derived>
(*static_cast<const Derived*>(this), ei_scalar_multiple2_op<Scalar,std::complex<Scalar> >(scalar));
}
/** \returns an expression of \c *this divided by the scalar value \a scalar */
template<typename Derived>
EIGEN_STRONG_INLINE const CwiseUnaryOp<ei_scalar_quotient1_op<typename ei_traits<Derived>::Scalar>, Derived>
MatrixBase<Derived>::operator/(const Scalar& scalar) const

View File

@ -35,7 +35,7 @@
*
* Notice that this class is trivial, it is only used to disambiguate overloaded functions.
*/
template<typename Derived> struct AnyMatrixBase
template<typename Derived> struct AnyMatrixBase
: public ei_special_scalar_op_base<Derived,typename ei_traits<Derived>::Scalar,
typename NumTraits<typename ei_traits<Derived>::Scalar>::Real>
{
@ -93,7 +93,7 @@ template<typename Derived> struct AnyMatrixBase
*/
template<typename Derived> class MatrixBase
#ifndef EIGEN_PARSED_BY_DOXYGEN
: public AnyMatrixBase<Derived>
: public AnyMatrixBase<Derived>
#endif // not EIGEN_PARSED_BY_DOXYGEN
{
public:
@ -419,10 +419,17 @@ template<typename Derived> class MatrixBase
const CwiseUnaryOp<ei_scalar_quotient1_op<typename ei_traits<Derived>::Scalar>, Derived>
operator/(const Scalar& scalar) const;
inline friend const CwiseUnaryOp<ei_scalar_multiple_op<typename ei_traits<Derived>::Scalar>, Derived>
const CwiseUnaryOp<ei_scalar_multiple2_op<Scalar,std::complex<Scalar> >, Derived>
operator*(const std::complex<Scalar>& scalar) const;
inline friend const ScalarMultipleReturnType
operator*(const Scalar& scalar, const MatrixBase& matrix)
{ return matrix*scalar; }
inline friend const CwiseUnaryOp<ei_scalar_multiple2_op<Scalar,std::complex<Scalar> >, Derived>
operator*(const std::complex<Scalar>& scalar, const MatrixBase& matrix)
{ return matrix*scalar; }
template<typename OtherDerived>
const typename ProductReturnType<Derived,OtherDerived>::Type
operator*(const MatrixBase<OtherDerived> &other) const;

View File

@ -84,18 +84,18 @@ public:
* based on the three dimensions of the product.
* This is a compile time mapping from {1,Small,Large}^3 -> {product types} */
// FIXME I'm not sure the current mapping is the ideal one.
template<int Rows, int Cols> struct ei_product_type_selector<Rows,Cols,1> { enum { ret = OuterProduct }; };
template<int Depth> struct ei_product_type_selector<1,1,Depth> { enum { ret = InnerProduct }; };
template<> struct ei_product_type_selector<1,1,1> { enum { ret = InnerProduct }; };
template<> struct ei_product_type_selector<Small,1,Small> { enum { ret = UnrolledProduct }; };
template<> struct ei_product_type_selector<1,Small,Small> { enum { ret = UnrolledProduct }; };
template<int Rows, int Cols> struct ei_product_type_selector<Rows, Cols, 1> { enum { ret = OuterProduct }; };
template<int Depth> struct ei_product_type_selector<1, 1, Depth> { enum { ret = InnerProduct }; };
template<> struct ei_product_type_selector<1, 1, 1> { enum { ret = InnerProduct }; };
template<> struct ei_product_type_selector<Small,1, Small> { enum { ret = UnrolledProduct }; };
template<> struct ei_product_type_selector<1, Small,Small> { enum { ret = UnrolledProduct }; };
template<> struct ei_product_type_selector<Small,Small,Small> { enum { ret = UnrolledProduct }; };
template<> struct ei_product_type_selector<1,Large,Small> { enum { ret = GemvProduct }; };
template<> struct ei_product_type_selector<1,Large,Large> { enum { ret = GemvProduct }; };
template<> struct ei_product_type_selector<1,Small,Large> { enum { ret = GemvProduct }; };
template<> struct ei_product_type_selector<Large,1,Small> { enum { ret = GemvProduct }; };
template<> struct ei_product_type_selector<Large,1,Large> { enum { ret = GemvProduct }; };
template<> struct ei_product_type_selector<Small,1,Large> { enum { ret = GemvProduct }; };
template<> struct ei_product_type_selector<1, Large,Small> { enum { ret = GemvProduct }; };
template<> struct ei_product_type_selector<1, Large,Large> { enum { ret = GemvProduct }; };
template<> struct ei_product_type_selector<1, Small,Large> { enum { ret = GemvProduct }; };
template<> struct ei_product_type_selector<Large,1, Small> { enum { ret = GemvProduct }; };
template<> struct ei_product_type_selector<Large,1, Large> { enum { ret = GemvProduct }; };
template<> struct ei_product_type_selector<Small,1, Large> { enum { ret = GemvProduct }; };
template<> struct ei_product_type_selector<Small,Small,Large> { enum { ret = GemmProduct }; };
template<> struct ei_product_type_selector<Large,Small,Large> { enum { ret = GemmProduct }; };
template<> struct ei_product_type_selector<Small,Large,Large> { enum { ret = GemmProduct }; };
@ -164,7 +164,7 @@ class GeneralProduct<Lhs, Rhs, InnerProduct>
GeneralProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs)
{
EIGEN_STATIC_ASSERT((ei_is_same_type<typename Lhs::Scalar, typename Rhs::Scalar>::ret),
EIGEN_STATIC_ASSERT((ei_is_same_type<typename Lhs::RealScalar, typename Rhs::RealScalar>::ret),
YOU_MIXED_DIFFERENT_NUMERIC_TYPES__YOU_NEED_TO_USE_THE_CAST_METHOD_OF_MATRIXBASE_TO_CAST_NUMERIC_TYPES_EXPLICITLY)
}
@ -203,7 +203,7 @@ class GeneralProduct<Lhs, Rhs, OuterProduct>
GeneralProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs)
{
EIGEN_STATIC_ASSERT((ei_is_same_type<typename Lhs::Scalar, typename Rhs::Scalar>::ret),
EIGEN_STATIC_ASSERT((ei_is_same_type<typename Lhs::RealScalar, typename Rhs::RealScalar>::ret),
YOU_MIXED_DIFFERENT_NUMERIC_TYPES__YOU_NEED_TO_USE_THE_CAST_METHOD_OF_MATRIXBASE_TO_CAST_NUMERIC_TYPES_EXPLICITLY)
}
@ -217,6 +217,7 @@ template<> struct ei_outer_product_selector<ColMajor> {
template<typename ProductType, typename Dest>
EIGEN_DONT_INLINE static void run(const ProductType& prod, Dest& dest, typename ProductType::Scalar alpha) {
// FIXME make sure lhs is sequentially stored
// FIXME not very good if rhs is real and lhs complex while alpha is real too
const int cols = dest.cols();
for (int j=0; j<cols; ++j)
dest.col(j) += (alpha * prod.rhs().coeff(j)) * prod.lhs();
@ -227,6 +228,7 @@ template<> struct ei_outer_product_selector<RowMajor> {
template<typename ProductType, typename Dest>
EIGEN_DONT_INLINE static void run(const ProductType& prod, Dest& dest, typename ProductType::Scalar alpha) {
// FIXME make sure rhs is sequentially stored
// FIXME not very good if lhs is real and rhs complex while alpha is real too
const int rows = dest.rows();
for (int i=0; i<rows; ++i)
dest.row(i) += (alpha * prod.lhs().coeff(i)) * prod.rhs();

View File

@ -33,7 +33,7 @@ struct ei_traits<ProductBase<Derived,_Lhs,_Rhs> >
{
typedef typename ei_cleantype<_Lhs>::type Lhs;
typedef typename ei_cleantype<_Rhs>::type Rhs;
typedef typename ei_traits<Lhs>::Scalar Scalar;
typedef typename ei_scalar_product_traits<typename Lhs::Scalar, typename Rhs::Scalar>::ReturnType Scalar;
enum {
RowsAtCompileTime = ei_traits<Lhs>::RowsAtCompileTime,
ColsAtCompileTime = ei_traits<Rhs>::ColsAtCompileTime,
@ -146,7 +146,7 @@ class ScaledProduct;
// functions of ProductBase, because, otherwise we would have to
// define all overloads defined in MatrixBase. Furthermore, Using
// "using Base::operator*" would not work with MSVC.
//
//
// Also note that here we accept any compatible scalar types
template<typename Derived,typename Lhs,typename Rhs>
const ScaledProduct<Derived>

View File

@ -233,6 +233,10 @@ struct ei_special_scalar_op_base<Derived,Scalar,OtherScalar,true>
return CwiseUnaryOp<ei_scalar_multiple2_op<Scalar,OtherScalar>, Derived>
(*static_cast<const Derived*>(this), ei_scalar_multiple2_op<Scalar,OtherScalar>(scalar));
}
inline friend const CwiseUnaryOp<ei_scalar_multiple2_op<Scalar,OtherScalar>, Derived>
operator*(const OtherScalar& scalar, const Derived& matrix)
{ return matrix*scalar; }
};
/** \internal Gives the type of a sub-matrix or sub-vector of a matrix of type \a ExpressionType and size \a Size

View File

@ -54,6 +54,11 @@ template<int SizeAtCompileType> void mixingtypes(int size = SizeAtCompileType)
Vec_d vd = vf.template cast<double>();
Vec_cf vcf = Vec_cf::Random(size,1);
Vec_cd vcd = vcf.template cast<complex<double> >();
float sf = ei_random<float>();
double sd = ei_random<double>();
complex<float> scf = ei_random<complex<float> >();
complex<double> scd = ei_random<complex<double> >();
mf+mf;
VERIFY_RAISES_ASSERT(mf+md);
@ -62,18 +67,31 @@ template<int SizeAtCompileType> void mixingtypes(int size = SizeAtCompileType)
VERIFY_RAISES_ASSERT(vf+=vd);
VERIFY_RAISES_ASSERT(mcd=md);
// check scalar products
VERIFY_IS_APPROX(vcf * sf , vcf * complex<float>(sf));
VERIFY_IS_APPROX(sd * vcd, complex<double>(sd) * vcd);
VERIFY_IS_APPROX(vf * scf , vf.template cast<complex<float> >() * scf);
VERIFY_IS_APPROX(scd * vd, scd * vd.template cast<complex<double> >());
// check dot product
vf.dot(vf);
VERIFY_RAISES_ASSERT(vd.dot(vf));
VERIFY_RAISES_ASSERT(vcf.dot(vf)); // yeah eventually we should allow this but i'm too lazy to make that change now in Dot.h
// especially as that might be rewritten as cwise product .sum() which would make that automatic.
// check diagonal product
VERIFY_IS_APPROX(vf.asDiagonal() * mcf, vf.template cast<complex<float> >().asDiagonal() * mcf);
VERIFY_IS_APPROX(vcd.asDiagonal() * md, vcd.asDiagonal() * md.template cast<complex<double> >());
VERIFY_IS_APPROX(mcf * vf.asDiagonal(), mcf * vf.template cast<complex<float> >().asDiagonal());
VERIFY_IS_APPROX(md * vcd.asDiagonal(), md.template cast<complex<double> >() * vcd.asDiagonal());
// vd.asDiagonal() * mf; // does not even compile
// vcd.asDiagonal() * mf; // does not even compile
// check inner product
VERIFY_IS_APPROX((vf.transpose() * vcf).value(), (vf.template cast<complex<float> >().transpose() * vcf).value());
// check outer product
VERIFY_IS_APPROX((vf * vcf.transpose()).eval(), (vf.template cast<complex<float> >() * vcf.transpose()).eval());
}
@ -108,9 +126,9 @@ void mixingtypes_large(int size)
// VERIFY_RAISES_ASSERT(vcd = md*vcd); // does not even compile (cannot convert complex to double)
VERIFY_RAISES_ASSERT(vcf = mcf*vf);
VERIFY_RAISES_ASSERT(mf*md);
VERIFY_RAISES_ASSERT(mcf*mcd);
VERIFY_RAISES_ASSERT(mcf*vcd);
// VERIFY_RAISES_ASSERT(mf*md); // does not even compile
// VERIFY_RAISES_ASSERT(mcf*mcd); // does not even compile
// VERIFY_RAISES_ASSERT(mcf*vcd); // does not even compile
VERIFY_RAISES_ASSERT(vcf = mf*vf);
}
@ -157,9 +175,9 @@ void test_mixingtypes()
{
// check that our operator new is indeed called:
CALL_SUBTEST(mixingtypes<3>());
CALL_SUBTEST(mixingtypes<4>());
CALL_SUBTEST(mixingtypes<Dynamic>(20));
CALL_SUBTEST(mixingtypes_small<4>());
CALL_SUBTEST(mixingtypes_large(20));
// CALL_SUBTEST(mixingtypes<4>());
// CALL_SUBTEST(mixingtypes<Dynamic>(20));
//
// CALL_SUBTEST(mixingtypes_small<4>());
// CALL_SUBTEST(mixingtypes_large(20));
}