overload operartor* with a ProductBase such that "scalar * (mat * mat)" is optimized

as one could naturally expect
This commit is contained in:
Gael Guennebaud 2009-08-11 15:15:06 +02:00
parent a4f6642518
commit afbd73b5cd
9 changed files with 76 additions and 16 deletions

View File

@ -153,7 +153,7 @@ class GeneralProduct<Lhs, Rhs, InnerProduct>
GeneralProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) {}
template<typename Dest> void addTo(Dest& dst, Scalar alpha) const
template<typename Dest> void scaleAndAddTo(Dest& dst, Scalar alpha) const
{
ei_assert(dst.rows()==1 && dst.cols()==1);
dst.coeffRef(0,0) += alpha * (m_lhs.cwise()*m_rhs).sum();
@ -179,7 +179,7 @@ class GeneralProduct<Lhs, Rhs, OuterProduct>
GeneralProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) {}
template<typename Dest> void addTo(Dest& dest, Scalar alpha) const
template<typename Dest> void scaleAndAddTo(Dest& dest, Scalar alpha) const
{
ei_outer_product_selector<Dest::Flags&RowMajorBit>::run(*this, dest, alpha);
}
@ -236,7 +236,7 @@ class GeneralProduct<Lhs, Rhs, GemvProduct>
enum { Side = Lhs::IsVectorAtCompileTime ? OnTheLeft : OnTheRight };
typedef typename ei_meta_if<int(Side)==OnTheRight,_LhsNested,_RhsNested>::ret MatrixType;
template<typename Dest> void addTo(Dest& dst, Scalar alpha) const
template<typename Dest> void scaleAndAddTo(Dest& dst, Scalar alpha) const
{
ei_assert(m_lhs.rows() == dst.rows() && m_rhs.cols() == dst.cols());
ei_gemv_selector<Side,int(MatrixType::Flags)&RowMajorBit,

View File

@ -100,16 +100,16 @@ class ProductBase : public MatrixBase<Derived>
inline int cols() const { return m_rhs.cols(); }
template<typename Dest>
inline void evalTo(Dest& dst) const { dst.setZero(); addTo(dst,1); }
inline void evalTo(Dest& dst) const { dst.setZero(); scaleAndAddTo(dst,1); }
template<typename Dest>
inline void addTo(Dest& dst) const { addTo(dst,1); }
inline void addTo(Dest& dst) const { scaleAndAddTo(dst,1); }
template<typename Dest>
inline void subTo(Dest& dst) const { addTo(dst,-1); }
inline void subTo(Dest& dst) const { scaleAndAddTo(dst,-1); }
template<typename Dest>
inline void addTo(Dest& dst,Scalar alpha) const { derived().addTo(dst,alpha); }
inline void scaleAndAddTo(Dest& dst,Scalar alpha) const { derived().scaleAndAddTo(dst,alpha); }
PlainMatrixType eval() const
{
@ -141,13 +141,68 @@ class ProductBase : public MatrixBase<Derived>
void coeffRef(int);
};
template<typename NestedProduct>
class ScaledProduct;
// Note that these two operator* functions are not defined as member
// functions of ProductBase, because, otherwise we would have to
// define all overloads defined in MatrixBase. Furthermore, Using
// "using Base::operator*" would not work with MSVC.
template<typename Derived,typename Lhs,typename Rhs>
const ScaledProduct<Derived> operator*(const ProductBase<Derived,Lhs,Rhs>& prod, typename Derived::Scalar x)
{ return ScaledProduct<Derived>(prod.derived(), x); }
template<typename Derived,typename Lhs,typename Rhs>
const ScaledProduct<Derived> operator*(typename Derived::Scalar x,const ProductBase<Derived,Lhs,Rhs>& prod)
{ return ScaledProduct<Derived>(prod.derived(), x); }
template<typename NestedProduct>
struct ei_traits<ScaledProduct<NestedProduct> >
: ei_traits<ProductBase<ScaledProduct<NestedProduct>,
typename NestedProduct::_LhsNested,
typename NestedProduct::_RhsNested> >
{};
template<typename NestedProduct>
class ScaledProduct
: public ProductBase<ScaledProduct<NestedProduct>,
typename NestedProduct::_LhsNested,
typename NestedProduct::_RhsNested>
{
public:
typedef ProductBase<ScaledProduct<NestedProduct>,
typename NestedProduct::_LhsNested,
typename NestedProduct::_RhsNested> Base;
typedef typename Base::Scalar Scalar;
// EIGEN_PRODUCT_PUBLIC_INTERFACE(ScaledProduct)
ScaledProduct(const NestedProduct& prod, Scalar& x)
: Base(prod.lhs(),prod.rhs()), m_prod(prod), m_alpha(x) {}
template<typename Dest>
inline void evalTo(Dest& dst) const { dst.setZero(); scaleAndAddTo(dst,m_alpha); }
template<typename Dest>
inline void addTo(Dest& dst) const { scaleAndAddTo(dst,m_alpha); }
template<typename Dest>
inline void subTo(Dest& dst) const { scaleAndAddTo(dst,-m_alpha); }
template<typename Dest>
inline void scaleAndAddTo(Dest& dst,Scalar alpha) const { m_prod.derived().scaleAndAddTo(dst,alpha); }
protected:
const NestedProduct& m_prod;
Scalar m_alpha;
};
/** \internal
* Overloaded to perform an efficient C = (A*B).lazy() */
template<typename Derived>
template<typename ProductDerived, typename Lhs, typename Rhs>
Derived& MatrixBase<Derived>::lazyAssign(const ProductBase<ProductDerived, Lhs,Rhs>& other)
{
other.evalTo(derived()); return derived();
other.derived().evalTo(derived()); return derived();
}
/** \internal
@ -157,7 +212,7 @@ template<typename ProductDerived, typename Lhs, typename Rhs>
Derived& MatrixBase<Derived>::operator+=(const Flagged<ProductBase<ProductDerived, Lhs,Rhs>, 0,
EvalBeforeNestingBit | EvalBeforeAssigningBit>& other)
{
other._expression().addTo(derived()); return derived();
other._expression().derived().addTo(derived()); return derived();
}
/** \internal
@ -167,7 +222,7 @@ template<typename ProductDerived, typename Lhs, typename Rhs>
Derived& MatrixBase<Derived>::operator-=(const Flagged<ProductBase<ProductDerived, Lhs,Rhs>, 0,
EvalBeforeNestingBit | EvalBeforeAssigningBit>& other)
{
other._expression().subTo(derived()); return derived();
other._expression().derived().subTo(derived()); return derived();
}
#endif // EIGEN_PRODUCTBASE_H

View File

@ -137,7 +137,7 @@ class GeneralProduct<Lhs, Rhs, GemmProduct>
GeneralProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) {}
template<typename Dest> void addTo(Dest& dst, Scalar alpha) const
template<typename Dest> void scaleAndAddTo(Dest& dst, Scalar alpha) const
{
ei_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols());

View File

@ -375,7 +375,7 @@ struct SelfadjointProductMatrix<Lhs,LhsMode,false,Rhs,RhsMode,false>
RhsIsSelfAdjoint = (RhsMode&SelfAdjointBit)==SelfAdjointBit
};
template<typename Dest> void addTo(Dest& dst, Scalar alpha) const
template<typename Dest> void scaleAndAddTo(Dest& dst, Scalar alpha) const
{
ei_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols());

View File

@ -175,7 +175,7 @@ struct SelfadjointProductMatrix<Lhs,LhsMode,false,Rhs,0,true>
SelfadjointProductMatrix(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) {}
template<typename Dest> void addTo(Dest& dst, Scalar alpha) const
template<typename Dest> void scaleAndAddTo(Dest& dst, Scalar alpha) const
{
ei_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols());

View File

@ -333,7 +333,7 @@ struct TriangularProduct<Mode,LhsIsTriangular,Lhs,false,Rhs,false>
TriangularProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) {}
template<typename Dest> void addTo(Dest& dst, Scalar alpha) const
template<typename Dest> void scaleAndAddTo(Dest& dst, Scalar alpha) const
{
const ActualLhsType lhs = LhsBlasTraits::extract(m_lhs);
const ActualRhsType rhs = RhsBlasTraits::extract(m_rhs);

View File

@ -130,7 +130,7 @@ struct TriangularProduct<Mode,true,Lhs,false,Rhs,true>
TriangularProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) {}
template<typename Dest> void addTo(Dest& dst, Scalar alpha) const
template<typename Dest> void scaleAndAddTo(Dest& dst, Scalar alpha) const
{
ei_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols());

View File

@ -79,7 +79,7 @@ template<typename MatrixType> void product_notemporary(const MatrixType& m)
VERIFY_EVALUATION_COUNT( m3 = (s1 * m1 * s2 * (m1*s3+m2*s2).adjoint()).lazy(), 1);
VERIFY_EVALUATION_COUNT( m3 = ((s1 * m1).adjoint() * s2 * m2).lazy(), 0);
VERIFY_EVALUATION_COUNT( m3 -= (s1 * (-m1*s3).adjoint() * (s2 * m2 * s3)).lazy(), 0);
VERIFY_EVALUATION_COUNT( m3 -= (s1 * (m1.transpose() * m2)).lazy(), 1);
VERIFY_EVALUATION_COUNT( m3 -= (s1 * (m1.transpose() * m2)).lazy(), 0);
VERIFY_EVALUATION_COUNT(( m3.block(r0,r0,r1,r1) += (-m1.block(r0,c0,r1,c1) * (s2*m2.block(r0,c0,r1,c1)).adjoint()).lazy() ), 0);
VERIFY_EVALUATION_COUNT(( m3.block(r0,r0,r1,r1) -= (s1 * m1.block(r0,c0,r1,c1) * m2.block(c0,r0,c1,r1)).lazy() ), 0);

View File

@ -94,6 +94,11 @@ template<typename Scalar, int Size, int OtherSize> void symm(int size = Size, in
VERIFY_IS_APPROX(rhs12 = (s1*m2.adjoint()).template selfadjointView<LowerTriangular>() * (s2*rhs3).conjugate(),
rhs13 = (s1*m1.adjoint()) * (s2*rhs3).conjugate());
m2 = m1.template triangularView<UpperTriangular>(); rhs13 = rhs12;
VERIFY_IS_APPROX(rhs12 += (s1 * ((m2.adjoint()).template selfadjointView<LowerTriangular>() * (s2*rhs3).conjugate())).lazy(),
rhs13 += (s1*m1.adjoint()) * (s2*rhs3).conjugate());
// test matrix * selfadjoint
symm_extra<OtherSize>::run(m1,m2,rhs2,rhs22,rhs23,s1,s2);