diff --git a/Eigen/src/Core/Product.h b/Eigen/src/Core/Product.h index 18f14f75a..610d5c84a 100644 --- a/Eigen/src/Core/Product.h +++ b/Eigen/src/Core/Product.h @@ -153,7 +153,7 @@ class GeneralProduct GeneralProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) {} - template void addTo(Dest& dst, Scalar alpha) const + template void scaleAndAddTo(Dest& dst, Scalar alpha) const { ei_assert(dst.rows()==1 && dst.cols()==1); dst.coeffRef(0,0) += alpha * (m_lhs.cwise()*m_rhs).sum(); @@ -179,7 +179,7 @@ class GeneralProduct GeneralProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) {} - template void addTo(Dest& dest, Scalar alpha) const + template void scaleAndAddTo(Dest& dest, Scalar alpha) const { ei_outer_product_selector::run(*this, dest, alpha); } @@ -236,7 +236,7 @@ class GeneralProduct enum { Side = Lhs::IsVectorAtCompileTime ? OnTheLeft : OnTheRight }; typedef typename ei_meta_if::ret MatrixType; - template void addTo(Dest& dst, Scalar alpha) const + template void scaleAndAddTo(Dest& dst, Scalar alpha) const { ei_assert(m_lhs.rows() == dst.rows() && m_rhs.cols() == dst.cols()); ei_gemv_selector inline int cols() const { return m_rhs.cols(); } template - inline void evalTo(Dest& dst) const { dst.setZero(); addTo(dst,1); } + inline void evalTo(Dest& dst) const { dst.setZero(); scaleAndAddTo(dst,1); } template - inline void addTo(Dest& dst) const { addTo(dst,1); } + inline void addTo(Dest& dst) const { scaleAndAddTo(dst,1); } template - inline void subTo(Dest& dst) const { addTo(dst,-1); } + inline void subTo(Dest& dst) const { scaleAndAddTo(dst,-1); } template - inline void addTo(Dest& dst,Scalar alpha) const { derived().addTo(dst,alpha); } + inline void scaleAndAddTo(Dest& dst,Scalar alpha) const { derived().scaleAndAddTo(dst,alpha); } PlainMatrixType eval() const { @@ -141,13 +141,68 @@ class ProductBase : public MatrixBase void coeffRef(int); }; +template +class ScaledProduct; + +// Note that these two operator* functions are not defined as member +// functions of ProductBase, because, otherwise we would have to +// define all overloads defined in MatrixBase. Furthermore, Using +// "using Base::operator*" would not work with MSVC. +template +const ScaledProduct operator*(const ProductBase& prod, typename Derived::Scalar x) +{ return ScaledProduct(prod.derived(), x); } + +template +const ScaledProduct operator*(typename Derived::Scalar x,const ProductBase& prod) +{ return ScaledProduct(prod.derived(), x); } + +template +struct ei_traits > + : ei_traits, + typename NestedProduct::_LhsNested, + typename NestedProduct::_RhsNested> > +{}; + +template +class ScaledProduct + : public ProductBase, + typename NestedProduct::_LhsNested, + typename NestedProduct::_RhsNested> +{ + public: + typedef ProductBase, + typename NestedProduct::_LhsNested, + typename NestedProduct::_RhsNested> Base; + typedef typename Base::Scalar Scalar; +// EIGEN_PRODUCT_PUBLIC_INTERFACE(ScaledProduct) + + ScaledProduct(const NestedProduct& prod, Scalar& x) + : Base(prod.lhs(),prod.rhs()), m_prod(prod), m_alpha(x) {} + + template + inline void evalTo(Dest& dst) const { dst.setZero(); scaleAndAddTo(dst,m_alpha); } + + template + inline void addTo(Dest& dst) const { scaleAndAddTo(dst,m_alpha); } + + template + inline void subTo(Dest& dst) const { scaleAndAddTo(dst,-m_alpha); } + + template + inline void scaleAndAddTo(Dest& dst,Scalar alpha) const { m_prod.derived().scaleAndAddTo(dst,alpha); } + + protected: + const NestedProduct& m_prod; + Scalar m_alpha; +}; + /** \internal * Overloaded to perform an efficient C = (A*B).lazy() */ template template Derived& MatrixBase::lazyAssign(const ProductBase& other) { - other.evalTo(derived()); return derived(); + other.derived().evalTo(derived()); return derived(); } /** \internal @@ -157,7 +212,7 @@ template Derived& MatrixBase::operator+=(const Flagged, 0, EvalBeforeNestingBit | EvalBeforeAssigningBit>& other) { - other._expression().addTo(derived()); return derived(); + other._expression().derived().addTo(derived()); return derived(); } /** \internal @@ -167,7 +222,7 @@ template Derived& MatrixBase::operator-=(const Flagged, 0, EvalBeforeNestingBit | EvalBeforeAssigningBit>& other) { - other._expression().subTo(derived()); return derived(); + other._expression().derived().subTo(derived()); return derived(); } #endif // EIGEN_PRODUCTBASE_H diff --git a/Eigen/src/Core/products/GeneralMatrixMatrix.h b/Eigen/src/Core/products/GeneralMatrixMatrix.h index ff0f2c1b4..8b3b13266 100644 --- a/Eigen/src/Core/products/GeneralMatrixMatrix.h +++ b/Eigen/src/Core/products/GeneralMatrixMatrix.h @@ -137,7 +137,7 @@ class GeneralProduct GeneralProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) {} - template void addTo(Dest& dst, Scalar alpha) const + template void scaleAndAddTo(Dest& dst, Scalar alpha) const { ei_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols()); diff --git a/Eigen/src/Core/products/SelfadjointMatrixMatrix.h b/Eigen/src/Core/products/SelfadjointMatrixMatrix.h index 358da3752..5e025b90b 100644 --- a/Eigen/src/Core/products/SelfadjointMatrixMatrix.h +++ b/Eigen/src/Core/products/SelfadjointMatrixMatrix.h @@ -375,7 +375,7 @@ struct SelfadjointProductMatrix RhsIsSelfAdjoint = (RhsMode&SelfAdjointBit)==SelfAdjointBit }; - template void addTo(Dest& dst, Scalar alpha) const + template void scaleAndAddTo(Dest& dst, Scalar alpha) const { ei_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols()); diff --git a/Eigen/src/Core/products/SelfadjointMatrixVector.h b/Eigen/src/Core/products/SelfadjointMatrixVector.h index f0004cdb9..c2c33d5b8 100644 --- a/Eigen/src/Core/products/SelfadjointMatrixVector.h +++ b/Eigen/src/Core/products/SelfadjointMatrixVector.h @@ -175,7 +175,7 @@ struct SelfadjointProductMatrix SelfadjointProductMatrix(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) {} - template void addTo(Dest& dst, Scalar alpha) const + template void scaleAndAddTo(Dest& dst, Scalar alpha) const { ei_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols()); diff --git a/Eigen/src/Core/products/TriangularMatrixMatrix.h b/Eigen/src/Core/products/TriangularMatrixMatrix.h index c2ee39e79..701ccb644 100644 --- a/Eigen/src/Core/products/TriangularMatrixMatrix.h +++ b/Eigen/src/Core/products/TriangularMatrixMatrix.h @@ -333,7 +333,7 @@ struct TriangularProduct TriangularProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) {} - template void addTo(Dest& dst, Scalar alpha) const + template void scaleAndAddTo(Dest& dst, Scalar alpha) const { const ActualLhsType lhs = LhsBlasTraits::extract(m_lhs); const ActualRhsType rhs = RhsBlasTraits::extract(m_rhs); diff --git a/Eigen/src/Core/products/TriangularMatrixVector.h b/Eigen/src/Core/products/TriangularMatrixVector.h index a21afa2f6..620b090b9 100644 --- a/Eigen/src/Core/products/TriangularMatrixVector.h +++ b/Eigen/src/Core/products/TriangularMatrixVector.h @@ -130,7 +130,7 @@ struct TriangularProduct TriangularProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) {} - template void addTo(Dest& dst, Scalar alpha) const + template void scaleAndAddTo(Dest& dst, Scalar alpha) const { ei_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols()); diff --git a/test/product_notemporary.cpp b/test/product_notemporary.cpp index f4311e495..d5d996e49 100644 --- a/test/product_notemporary.cpp +++ b/test/product_notemporary.cpp @@ -79,7 +79,7 @@ template void product_notemporary(const MatrixType& m) VERIFY_EVALUATION_COUNT( m3 = (s1 * m1 * s2 * (m1*s3+m2*s2).adjoint()).lazy(), 1); VERIFY_EVALUATION_COUNT( m3 = ((s1 * m1).adjoint() * s2 * m2).lazy(), 0); VERIFY_EVALUATION_COUNT( m3 -= (s1 * (-m1*s3).adjoint() * (s2 * m2 * s3)).lazy(), 0); - VERIFY_EVALUATION_COUNT( m3 -= (s1 * (m1.transpose() * m2)).lazy(), 1); + VERIFY_EVALUATION_COUNT( m3 -= (s1 * (m1.transpose() * m2)).lazy(), 0); VERIFY_EVALUATION_COUNT(( m3.block(r0,r0,r1,r1) += (-m1.block(r0,c0,r1,c1) * (s2*m2.block(r0,c0,r1,c1)).adjoint()).lazy() ), 0); VERIFY_EVALUATION_COUNT(( m3.block(r0,r0,r1,r1) -= (s1 * m1.block(r0,c0,r1,c1) * m2.block(c0,r0,c1,r1)).lazy() ), 0); diff --git a/test/product_symm.cpp b/test/product_symm.cpp index 88bac878b..1300928a2 100644 --- a/test/product_symm.cpp +++ b/test/product_symm.cpp @@ -94,6 +94,11 @@ template void symm(int size = Size, in VERIFY_IS_APPROX(rhs12 = (s1*m2.adjoint()).template selfadjointView() * (s2*rhs3).conjugate(), rhs13 = (s1*m1.adjoint()) * (s2*rhs3).conjugate()); + + m2 = m1.template triangularView(); rhs13 = rhs12; + VERIFY_IS_APPROX(rhs12 += (s1 * ((m2.adjoint()).template selfadjointView() * (s2*rhs3).conjugate())).lazy(), + rhs13 += (s1*m1.adjoint()) * (s2*rhs3).conjugate()); + // test matrix * selfadjoint symm_extra::run(m1,m2,rhs2,rhs22,rhs23,s1,s2);