mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-01-30 17:40:05 +08:00
Get rid of GeneralProduct for outer-products, and get rid of ScaledProduct
This commit is contained in:
parent
af31b6c37a
commit
728c3d2cb9
@ -247,6 +247,7 @@ class GeneralProduct<Lhs, Rhs, InnerProduct>
|
||||
* Implementation of Outer Vector Vector Product
|
||||
***********************************************************************/
|
||||
|
||||
#ifndef EIGEN_TEST_EVALUATORS
|
||||
namespace internal {
|
||||
|
||||
// Column major
|
||||
@ -326,6 +327,8 @@ class GeneralProduct<Lhs, Rhs, OuterProduct>
|
||||
}
|
||||
};
|
||||
|
||||
#endif // EIGEN_TEST_EVALUATORS
|
||||
|
||||
/***********************************************************************
|
||||
* Implementation of General Matrix Vector Product
|
||||
***********************************************************************/
|
||||
|
@ -174,6 +174,7 @@ class ProductBase : public MatrixBase<Derived>
|
||||
mutable PlainObject m_result;
|
||||
};
|
||||
|
||||
#ifndef EIGEN_TEST_EVALUATORS
|
||||
// here we need to overload the nested rule for products
|
||||
// such that the nested type is a const reference to a plain matrix
|
||||
namespace internal {
|
||||
@ -263,6 +264,8 @@ class ScaledProduct
|
||||
Scalar m_alpha;
|
||||
};
|
||||
|
||||
#endif // EIGEN_TEST_EVALUATORS
|
||||
|
||||
/** \internal
|
||||
* Overloaded to perform an efficient C = (A*B).lazy() */
|
||||
template<typename Derived>
|
||||
|
@ -17,7 +17,11 @@ namespace Eigen {
|
||||
|
||||
namespace internal {
|
||||
|
||||
// Like more general binary expressions, products need their own evaluator:
|
||||
/** \internal
|
||||
* \class product_evaluator
|
||||
* Products need their own evaluator with more template arguments allowing for
|
||||
* easier partial template specializations.
|
||||
*/
|
||||
template< typename T,
|
||||
int ProductTag = internal::product_type<typename T::Lhs,typename T::Rhs>::ret,
|
||||
typename LhsShape = typename evaluator_traits<typename T::Lhs>::Shape,
|
||||
@ -26,6 +30,14 @@ template< typename T,
|
||||
typename RhsScalar = typename T::Rhs::Scalar
|
||||
> struct product_evaluator;
|
||||
|
||||
/** \internal
|
||||
* Evaluator of a product expression.
|
||||
* Since products require special treatments to handle all possible cases,
|
||||
* we simply deffer the evaluation logic to a product_evaluator class
|
||||
* which offers more partial specialization possibilities.
|
||||
*
|
||||
* \sa class product_evaluator
|
||||
*/
|
||||
template<typename Lhs, typename Rhs, int Options>
|
||||
struct evaluator<Product<Lhs, Rhs, Options> >
|
||||
: public product_evaluator<Product<Lhs, Rhs, Options> >
|
||||
@ -40,7 +52,7 @@ struct evaluator<Product<Lhs, Rhs, Options> >
|
||||
};
|
||||
|
||||
// Catch scalar * ( A * B ) and transform it to (A*scalar) * B
|
||||
// TODO we should apply that rule if that's really helpful
|
||||
// TODO we should apply that rule only if that's really helpful
|
||||
template<typename Lhs, typename Rhs, typename Scalar>
|
||||
struct evaluator<CwiseUnaryOp<internal::scalar_multiple_op<Scalar>, const Product<Lhs, Rhs, DefaultProduct> > >
|
||||
: public evaluator<Product<CwiseUnaryOp<internal::scalar_multiple_op<Scalar>,const Lhs>, Rhs, DefaultProduct> >
|
||||
@ -66,7 +78,7 @@ struct evaluator<Diagonal<const Product<Lhs, Rhs, DefaultProduct>, DiagIndex> >
|
||||
|
||||
typedef evaluator type;
|
||||
typedef evaluator nestedType;
|
||||
//
|
||||
|
||||
evaluator(const XprType& xpr)
|
||||
: Base(Diagonal<const Product<Lhs, Rhs, LazyProduct>, DiagIndex>(
|
||||
Product<Lhs, Rhs, LazyProduct>(xpr.nestedExpression().lhs(), xpr.nestedExpression().rhs()),
|
||||
@ -183,38 +195,75 @@ struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,InnerProduct>
|
||||
};
|
||||
|
||||
|
||||
/***********************************************************************
|
||||
* Implementation of outer dense * dense vector product
|
||||
***********************************************************************/
|
||||
|
||||
// Column major result
|
||||
template<typename Dst, typename Lhs, typename Rhs, typename Func>
|
||||
EIGEN_DONT_INLINE void outer_product_selector_run(Dst& dst, const Lhs &lhs, const Rhs &rhs, const Func& func, const false_type&)
|
||||
{
|
||||
typedef typename Dst::Index Index;
|
||||
// FIXME make sure lhs is sequentially stored
|
||||
// FIXME not very good if rhs is real and lhs complex while alpha is real too
|
||||
// FIXME we should probably build an evaluator for dst and rhs
|
||||
const Index cols = dst.cols();
|
||||
for (Index j=0; j<cols; ++j)
|
||||
func(dst.col(j), rhs.coeff(j) * lhs);
|
||||
}
|
||||
|
||||
// Row major result
|
||||
template<typename Dst, typename Lhs, typename Rhs, typename Func>
|
||||
EIGEN_DONT_INLINE void outer_product_selector_run(Dst& dst, const Lhs &lhs, const Rhs &rhs, const Func& func, const true_type&) {
|
||||
typedef typename Dst::Index Index;
|
||||
// FIXME make sure rhs is sequentially stored
|
||||
// FIXME not very good if lhs is real and rhs complex while alpha is real too
|
||||
// FIXME we should probably build an evaluator for dst and lhs
|
||||
const Index rows = dst.rows();
|
||||
for (Index i=0; i<rows; ++i)
|
||||
func(dst.row(i), lhs.coeff(i) * rhs);
|
||||
}
|
||||
|
||||
template<typename Lhs, typename Rhs>
|
||||
struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,OuterProduct>
|
||||
{
|
||||
template<typename T> struct IsRowMajor : internal::conditional<(int(T::Flags)&RowMajorBit), internal::true_type, internal::false_type>::type {};
|
||||
typedef typename Product<Lhs,Rhs>::Scalar Scalar;
|
||||
|
||||
// TODO it would be nice to be able to exploit our *_assign_op functors for that purpose
|
||||
struct set { template<typename Dst, typename Src> void operator()(const Dst& dst, const Src& src) const { dst.const_cast_derived() = src; } };
|
||||
struct add { template<typename Dst, typename Src> void operator()(const Dst& dst, const Src& src) const { dst.const_cast_derived() += src; } };
|
||||
struct sub { template<typename Dst, typename Src> void operator()(const Dst& dst, const Src& src) const { dst.const_cast_derived() -= src; } };
|
||||
struct adds {
|
||||
Scalar m_scale;
|
||||
adds(const Scalar& s) : m_scale(s) {}
|
||||
template<typename Dst, typename Src> void operator()(const Dst& dst, const Src& src) const {
|
||||
dst.const_cast_derived() += m_scale * src;
|
||||
}
|
||||
};
|
||||
|
||||
template<typename Dst>
|
||||
static inline void evalTo(Dst& dst, const Lhs& lhs, const Rhs& rhs)
|
||||
{
|
||||
// TODO bypass GeneralProduct class
|
||||
GeneralProduct<Lhs, Rhs, OuterProduct>(lhs,rhs).evalTo(dst);
|
||||
internal::outer_product_selector_run(dst, lhs, rhs, set(), IsRowMajor<Dst>());
|
||||
}
|
||||
|
||||
template<typename Dst>
|
||||
static inline void addTo(Dst& dst, const Lhs& lhs, const Rhs& rhs)
|
||||
{
|
||||
// TODO bypass GeneralProduct class
|
||||
GeneralProduct<Lhs, Rhs, OuterProduct>(lhs,rhs).addTo(dst);
|
||||
internal::outer_product_selector_run(dst, lhs, rhs, add(), IsRowMajor<Dst>());
|
||||
}
|
||||
|
||||
template<typename Dst>
|
||||
static inline void subTo(Dst& dst, const Lhs& lhs, const Rhs& rhs)
|
||||
{
|
||||
// TODO bypass GeneralProduct class
|
||||
GeneralProduct<Lhs, Rhs, OuterProduct>(lhs,rhs).subTo(dst);
|
||||
internal::outer_product_selector_run(dst, lhs, rhs, sub(), IsRowMajor<Dst>());
|
||||
}
|
||||
|
||||
template<typename Dst>
|
||||
static inline void scaleAndAddTo(Dst& dst, const Lhs& lhs, const Rhs& rhs, const Scalar& alpha)
|
||||
{
|
||||
// TODO bypass GeneralProduct class
|
||||
GeneralProduct<Lhs, Rhs, OuterProduct>(lhs,rhs).scaleAndAddTo(dst, alpha);
|
||||
internal::outer_product_selector_run(dst, lhs, rhs, adds(alpha), IsRowMajor<Dst>());
|
||||
}
|
||||
|
||||
};
|
||||
|
Loading…
Reference in New Issue
Block a user