mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-01-18 14:34:17 +08:00
overload operartor* with a ProductBase such that "scalar * (mat * mat)" is optimized
as one could naturally expect
This commit is contained in:
parent
a4f6642518
commit
afbd73b5cd
@ -153,7 +153,7 @@ class GeneralProduct<Lhs, Rhs, InnerProduct>
|
||||
|
||||
GeneralProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) {}
|
||||
|
||||
template<typename Dest> void addTo(Dest& dst, Scalar alpha) const
|
||||
template<typename Dest> void scaleAndAddTo(Dest& dst, Scalar alpha) const
|
||||
{
|
||||
ei_assert(dst.rows()==1 && dst.cols()==1);
|
||||
dst.coeffRef(0,0) += alpha * (m_lhs.cwise()*m_rhs).sum();
|
||||
@ -179,7 +179,7 @@ class GeneralProduct<Lhs, Rhs, OuterProduct>
|
||||
|
||||
GeneralProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) {}
|
||||
|
||||
template<typename Dest> void addTo(Dest& dest, Scalar alpha) const
|
||||
template<typename Dest> void scaleAndAddTo(Dest& dest, Scalar alpha) const
|
||||
{
|
||||
ei_outer_product_selector<Dest::Flags&RowMajorBit>::run(*this, dest, alpha);
|
||||
}
|
||||
@ -236,7 +236,7 @@ class GeneralProduct<Lhs, Rhs, GemvProduct>
|
||||
enum { Side = Lhs::IsVectorAtCompileTime ? OnTheLeft : OnTheRight };
|
||||
typedef typename ei_meta_if<int(Side)==OnTheRight,_LhsNested,_RhsNested>::ret MatrixType;
|
||||
|
||||
template<typename Dest> void addTo(Dest& dst, Scalar alpha) const
|
||||
template<typename Dest> void scaleAndAddTo(Dest& dst, Scalar alpha) const
|
||||
{
|
||||
ei_assert(m_lhs.rows() == dst.rows() && m_rhs.cols() == dst.cols());
|
||||
ei_gemv_selector<Side,int(MatrixType::Flags)&RowMajorBit,
|
||||
|
@ -100,16 +100,16 @@ class ProductBase : public MatrixBase<Derived>
|
||||
inline int cols() const { return m_rhs.cols(); }
|
||||
|
||||
template<typename Dest>
|
||||
inline void evalTo(Dest& dst) const { dst.setZero(); addTo(dst,1); }
|
||||
inline void evalTo(Dest& dst) const { dst.setZero(); scaleAndAddTo(dst,1); }
|
||||
|
||||
template<typename Dest>
|
||||
inline void addTo(Dest& dst) const { addTo(dst,1); }
|
||||
inline void addTo(Dest& dst) const { scaleAndAddTo(dst,1); }
|
||||
|
||||
template<typename Dest>
|
||||
inline void subTo(Dest& dst) const { addTo(dst,-1); }
|
||||
inline void subTo(Dest& dst) const { scaleAndAddTo(dst,-1); }
|
||||
|
||||
template<typename Dest>
|
||||
inline void addTo(Dest& dst,Scalar alpha) const { derived().addTo(dst,alpha); }
|
||||
inline void scaleAndAddTo(Dest& dst,Scalar alpha) const { derived().scaleAndAddTo(dst,alpha); }
|
||||
|
||||
PlainMatrixType eval() const
|
||||
{
|
||||
@ -141,13 +141,68 @@ class ProductBase : public MatrixBase<Derived>
|
||||
void coeffRef(int);
|
||||
};
|
||||
|
||||
template<typename NestedProduct>
|
||||
class ScaledProduct;
|
||||
|
||||
// Note that these two operator* functions are not defined as member
|
||||
// functions of ProductBase, because, otherwise we would have to
|
||||
// define all overloads defined in MatrixBase. Furthermore, Using
|
||||
// "using Base::operator*" would not work with MSVC.
|
||||
template<typename Derived,typename Lhs,typename Rhs>
|
||||
const ScaledProduct<Derived> operator*(const ProductBase<Derived,Lhs,Rhs>& prod, typename Derived::Scalar x)
|
||||
{ return ScaledProduct<Derived>(prod.derived(), x); }
|
||||
|
||||
template<typename Derived,typename Lhs,typename Rhs>
|
||||
const ScaledProduct<Derived> operator*(typename Derived::Scalar x,const ProductBase<Derived,Lhs,Rhs>& prod)
|
||||
{ return ScaledProduct<Derived>(prod.derived(), x); }
|
||||
|
||||
template<typename NestedProduct>
|
||||
struct ei_traits<ScaledProduct<NestedProduct> >
|
||||
: ei_traits<ProductBase<ScaledProduct<NestedProduct>,
|
||||
typename NestedProduct::_LhsNested,
|
||||
typename NestedProduct::_RhsNested> >
|
||||
{};
|
||||
|
||||
template<typename NestedProduct>
|
||||
class ScaledProduct
|
||||
: public ProductBase<ScaledProduct<NestedProduct>,
|
||||
typename NestedProduct::_LhsNested,
|
||||
typename NestedProduct::_RhsNested>
|
||||
{
|
||||
public:
|
||||
typedef ProductBase<ScaledProduct<NestedProduct>,
|
||||
typename NestedProduct::_LhsNested,
|
||||
typename NestedProduct::_RhsNested> Base;
|
||||
typedef typename Base::Scalar Scalar;
|
||||
// EIGEN_PRODUCT_PUBLIC_INTERFACE(ScaledProduct)
|
||||
|
||||
ScaledProduct(const NestedProduct& prod, Scalar& x)
|
||||
: Base(prod.lhs(),prod.rhs()), m_prod(prod), m_alpha(x) {}
|
||||
|
||||
template<typename Dest>
|
||||
inline void evalTo(Dest& dst) const { dst.setZero(); scaleAndAddTo(dst,m_alpha); }
|
||||
|
||||
template<typename Dest>
|
||||
inline void addTo(Dest& dst) const { scaleAndAddTo(dst,m_alpha); }
|
||||
|
||||
template<typename Dest>
|
||||
inline void subTo(Dest& dst) const { scaleAndAddTo(dst,-m_alpha); }
|
||||
|
||||
template<typename Dest>
|
||||
inline void scaleAndAddTo(Dest& dst,Scalar alpha) const { m_prod.derived().scaleAndAddTo(dst,alpha); }
|
||||
|
||||
protected:
|
||||
const NestedProduct& m_prod;
|
||||
Scalar m_alpha;
|
||||
};
|
||||
|
||||
/** \internal
|
||||
* Overloaded to perform an efficient C = (A*B).lazy() */
|
||||
template<typename Derived>
|
||||
template<typename ProductDerived, typename Lhs, typename Rhs>
|
||||
Derived& MatrixBase<Derived>::lazyAssign(const ProductBase<ProductDerived, Lhs,Rhs>& other)
|
||||
{
|
||||
other.evalTo(derived()); return derived();
|
||||
other.derived().evalTo(derived()); return derived();
|
||||
}
|
||||
|
||||
/** \internal
|
||||
@ -157,7 +212,7 @@ template<typename ProductDerived, typename Lhs, typename Rhs>
|
||||
Derived& MatrixBase<Derived>::operator+=(const Flagged<ProductBase<ProductDerived, Lhs,Rhs>, 0,
|
||||
EvalBeforeNestingBit | EvalBeforeAssigningBit>& other)
|
||||
{
|
||||
other._expression().addTo(derived()); return derived();
|
||||
other._expression().derived().addTo(derived()); return derived();
|
||||
}
|
||||
|
||||
/** \internal
|
||||
@ -167,7 +222,7 @@ template<typename ProductDerived, typename Lhs, typename Rhs>
|
||||
Derived& MatrixBase<Derived>::operator-=(const Flagged<ProductBase<ProductDerived, Lhs,Rhs>, 0,
|
||||
EvalBeforeNestingBit | EvalBeforeAssigningBit>& other)
|
||||
{
|
||||
other._expression().subTo(derived()); return derived();
|
||||
other._expression().derived().subTo(derived()); return derived();
|
||||
}
|
||||
|
||||
#endif // EIGEN_PRODUCTBASE_H
|
||||
|
@ -137,7 +137,7 @@ class GeneralProduct<Lhs, Rhs, GemmProduct>
|
||||
|
||||
GeneralProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) {}
|
||||
|
||||
template<typename Dest> void addTo(Dest& dst, Scalar alpha) const
|
||||
template<typename Dest> void scaleAndAddTo(Dest& dst, Scalar alpha) const
|
||||
{
|
||||
ei_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols());
|
||||
|
||||
|
@ -375,7 +375,7 @@ struct SelfadjointProductMatrix<Lhs,LhsMode,false,Rhs,RhsMode,false>
|
||||
RhsIsSelfAdjoint = (RhsMode&SelfAdjointBit)==SelfAdjointBit
|
||||
};
|
||||
|
||||
template<typename Dest> void addTo(Dest& dst, Scalar alpha) const
|
||||
template<typename Dest> void scaleAndAddTo(Dest& dst, Scalar alpha) const
|
||||
{
|
||||
ei_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols());
|
||||
|
||||
|
@ -175,7 +175,7 @@ struct SelfadjointProductMatrix<Lhs,LhsMode,false,Rhs,0,true>
|
||||
|
||||
SelfadjointProductMatrix(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) {}
|
||||
|
||||
template<typename Dest> void addTo(Dest& dst, Scalar alpha) const
|
||||
template<typename Dest> void scaleAndAddTo(Dest& dst, Scalar alpha) const
|
||||
{
|
||||
ei_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols());
|
||||
|
||||
|
@ -333,7 +333,7 @@ struct TriangularProduct<Mode,LhsIsTriangular,Lhs,false,Rhs,false>
|
||||
|
||||
TriangularProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) {}
|
||||
|
||||
template<typename Dest> void addTo(Dest& dst, Scalar alpha) const
|
||||
template<typename Dest> void scaleAndAddTo(Dest& dst, Scalar alpha) const
|
||||
{
|
||||
const ActualLhsType lhs = LhsBlasTraits::extract(m_lhs);
|
||||
const ActualRhsType rhs = RhsBlasTraits::extract(m_rhs);
|
||||
|
@ -130,7 +130,7 @@ struct TriangularProduct<Mode,true,Lhs,false,Rhs,true>
|
||||
|
||||
TriangularProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) {}
|
||||
|
||||
template<typename Dest> void addTo(Dest& dst, Scalar alpha) const
|
||||
template<typename Dest> void scaleAndAddTo(Dest& dst, Scalar alpha) const
|
||||
{
|
||||
ei_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols());
|
||||
|
||||
|
@ -79,7 +79,7 @@ template<typename MatrixType> void product_notemporary(const MatrixType& m)
|
||||
VERIFY_EVALUATION_COUNT( m3 = (s1 * m1 * s2 * (m1*s3+m2*s2).adjoint()).lazy(), 1);
|
||||
VERIFY_EVALUATION_COUNT( m3 = ((s1 * m1).adjoint() * s2 * m2).lazy(), 0);
|
||||
VERIFY_EVALUATION_COUNT( m3 -= (s1 * (-m1*s3).adjoint() * (s2 * m2 * s3)).lazy(), 0);
|
||||
VERIFY_EVALUATION_COUNT( m3 -= (s1 * (m1.transpose() * m2)).lazy(), 1);
|
||||
VERIFY_EVALUATION_COUNT( m3 -= (s1 * (m1.transpose() * m2)).lazy(), 0);
|
||||
|
||||
VERIFY_EVALUATION_COUNT(( m3.block(r0,r0,r1,r1) += (-m1.block(r0,c0,r1,c1) * (s2*m2.block(r0,c0,r1,c1)).adjoint()).lazy() ), 0);
|
||||
VERIFY_EVALUATION_COUNT(( m3.block(r0,r0,r1,r1) -= (s1 * m1.block(r0,c0,r1,c1) * m2.block(c0,r0,c1,r1)).lazy() ), 0);
|
||||
|
@ -94,6 +94,11 @@ template<typename Scalar, int Size, int OtherSize> void symm(int size = Size, in
|
||||
VERIFY_IS_APPROX(rhs12 = (s1*m2.adjoint()).template selfadjointView<LowerTriangular>() * (s2*rhs3).conjugate(),
|
||||
rhs13 = (s1*m1.adjoint()) * (s2*rhs3).conjugate());
|
||||
|
||||
|
||||
m2 = m1.template triangularView<UpperTriangular>(); rhs13 = rhs12;
|
||||
VERIFY_IS_APPROX(rhs12 += (s1 * ((m2.adjoint()).template selfadjointView<LowerTriangular>() * (s2*rhs3).conjugate())).lazy(),
|
||||
rhs13 += (s1*m1.adjoint()) * (s2*rhs3).conjugate());
|
||||
|
||||
// test matrix * selfadjoint
|
||||
symm_extra<OtherSize>::run(m1,m2,rhs2,rhs22,rhs23,s1,s2);
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user