Get rid of GeneralProduct for outer-products, and get rid of ScaledProduct

This commit is contained in:
Gael Guennebaud 2014-02-21 16:27:24 +01:00
parent af31b6c37a
commit 728c3d2cb9
3 changed files with 66 additions and 11 deletions

View File

@ -247,6 +247,7 @@ class GeneralProduct<Lhs, Rhs, InnerProduct>
* Implementation of Outer Vector Vector Product
***********************************************************************/
#ifndef EIGEN_TEST_EVALUATORS
namespace internal {
// Column major
@ -326,6 +327,8 @@ class GeneralProduct<Lhs, Rhs, OuterProduct>
}
};
#endif // EIGEN_TEST_EVALUATORS
/***********************************************************************
* Implementation of General Matrix Vector Product
***********************************************************************/

View File

@ -174,6 +174,7 @@ class ProductBase : public MatrixBase<Derived>
mutable PlainObject m_result;
};
#ifndef EIGEN_TEST_EVALUATORS
// here we need to overload the nested rule for products
// such that the nested type is a const reference to a plain matrix
namespace internal {
@ -263,6 +264,8 @@ class ScaledProduct
Scalar m_alpha;
};
#endif // EIGEN_TEST_EVALUATORS
/** \internal
* Overloaded to perform an efficient C = (A*B).lazy() */
template<typename Derived>

View File

@ -17,7 +17,11 @@ namespace Eigen {
namespace internal {
// Like more general binary expressions, products need their own evaluator:
/** \internal
* \class product_evaluator
* Products need their own evaluator with more template arguments allowing for
* easier partial template specializations.
*/
template< typename T,
int ProductTag = internal::product_type<typename T::Lhs,typename T::Rhs>::ret,
typename LhsShape = typename evaluator_traits<typename T::Lhs>::Shape,
@ -26,6 +30,14 @@ template< typename T,
typename RhsScalar = typename T::Rhs::Scalar
> struct product_evaluator;
/** \internal
* Evaluator of a product expression.
* Since products require special treatments to handle all possible cases,
* we simply deffer the evaluation logic to a product_evaluator class
* which offers more partial specialization possibilities.
*
* \sa class product_evaluator
*/
template<typename Lhs, typename Rhs, int Options>
struct evaluator<Product<Lhs, Rhs, Options> >
: public product_evaluator<Product<Lhs, Rhs, Options> >
@ -40,7 +52,7 @@ struct evaluator<Product<Lhs, Rhs, Options> >
};
// Catch scalar * ( A * B ) and transform it to (A*scalar) * B
// TODO we should apply that rule if that's really helpful
// TODO we should apply that rule only if that's really helpful
template<typename Lhs, typename Rhs, typename Scalar>
struct evaluator<CwiseUnaryOp<internal::scalar_multiple_op<Scalar>, const Product<Lhs, Rhs, DefaultProduct> > >
: public evaluator<Product<CwiseUnaryOp<internal::scalar_multiple_op<Scalar>,const Lhs>, Rhs, DefaultProduct> >
@ -66,7 +78,7 @@ struct evaluator<Diagonal<const Product<Lhs, Rhs, DefaultProduct>, DiagIndex> >
typedef evaluator type;
typedef evaluator nestedType;
//
evaluator(const XprType& xpr)
: Base(Diagonal<const Product<Lhs, Rhs, LazyProduct>, DiagIndex>(
Product<Lhs, Rhs, LazyProduct>(xpr.nestedExpression().lhs(), xpr.nestedExpression().rhs()),
@ -183,38 +195,75 @@ struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,InnerProduct>
};
/***********************************************************************
* Implementation of outer dense * dense vector product
***********************************************************************/
// Column major result
template<typename Dst, typename Lhs, typename Rhs, typename Func>
EIGEN_DONT_INLINE void outer_product_selector_run(Dst& dst, const Lhs &lhs, const Rhs &rhs, const Func& func, const false_type&)
{
typedef typename Dst::Index Index;
// FIXME make sure lhs is sequentially stored
// FIXME not very good if rhs is real and lhs complex while alpha is real too
// FIXME we should probably build an evaluator for dst and rhs
const Index cols = dst.cols();
for (Index j=0; j<cols; ++j)
func(dst.col(j), rhs.coeff(j) * lhs);
}
// Row major result
template<typename Dst, typename Lhs, typename Rhs, typename Func>
EIGEN_DONT_INLINE void outer_product_selector_run(Dst& dst, const Lhs &lhs, const Rhs &rhs, const Func& func, const true_type&) {
typedef typename Dst::Index Index;
// FIXME make sure rhs is sequentially stored
// FIXME not very good if lhs is real and rhs complex while alpha is real too
// FIXME we should probably build an evaluator for dst and lhs
const Index rows = dst.rows();
for (Index i=0; i<rows; ++i)
func(dst.row(i), lhs.coeff(i) * rhs);
}
template<typename Lhs, typename Rhs>
struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,OuterProduct>
{
template<typename T> struct IsRowMajor : internal::conditional<(int(T::Flags)&RowMajorBit), internal::true_type, internal::false_type>::type {};
typedef typename Product<Lhs,Rhs>::Scalar Scalar;
// TODO it would be nice to be able to exploit our *_assign_op functors for that purpose
struct set { template<typename Dst, typename Src> void operator()(const Dst& dst, const Src& src) const { dst.const_cast_derived() = src; } };
struct add { template<typename Dst, typename Src> void operator()(const Dst& dst, const Src& src) const { dst.const_cast_derived() += src; } };
struct sub { template<typename Dst, typename Src> void operator()(const Dst& dst, const Src& src) const { dst.const_cast_derived() -= src; } };
struct adds {
Scalar m_scale;
adds(const Scalar& s) : m_scale(s) {}
template<typename Dst, typename Src> void operator()(const Dst& dst, const Src& src) const {
dst.const_cast_derived() += m_scale * src;
}
};
template<typename Dst>
static inline void evalTo(Dst& dst, const Lhs& lhs, const Rhs& rhs)
{
// TODO bypass GeneralProduct class
GeneralProduct<Lhs, Rhs, OuterProduct>(lhs,rhs).evalTo(dst);
internal::outer_product_selector_run(dst, lhs, rhs, set(), IsRowMajor<Dst>());
}
template<typename Dst>
static inline void addTo(Dst& dst, const Lhs& lhs, const Rhs& rhs)
{
// TODO bypass GeneralProduct class
GeneralProduct<Lhs, Rhs, OuterProduct>(lhs,rhs).addTo(dst);
internal::outer_product_selector_run(dst, lhs, rhs, add(), IsRowMajor<Dst>());
}
template<typename Dst>
static inline void subTo(Dst& dst, const Lhs& lhs, const Rhs& rhs)
{
// TODO bypass GeneralProduct class
GeneralProduct<Lhs, Rhs, OuterProduct>(lhs,rhs).subTo(dst);
internal::outer_product_selector_run(dst, lhs, rhs, sub(), IsRowMajor<Dst>());
}
template<typename Dst>
static inline void scaleAndAddTo(Dst& dst, const Lhs& lhs, const Rhs& rhs, const Scalar& alpha)
{
// TODO bypass GeneralProduct class
GeneralProduct<Lhs, Rhs, OuterProduct>(lhs,rhs).scaleAndAddTo(dst, alpha);
internal::outer_product_selector_run(dst, lhs, rhs, adds(alpha), IsRowMajor<Dst>());
}
};