More clever evaluation of arguments: now it occurs earlier, in operator*,

before the Product<> type is constructed. This resets template depth on each
intermediate evaluation, and gives simpler code. This commit introduces
ei_eval_if_expensive<Derived, n>, which evaluates Derived if it's worth it
given that each of its coeffs will be accessed n times. Operator*
uses this with suitable values of n to evaluate args exactly when needed.
This commit is contained in:
Benoit Jacob 2008-04-03 14:17:56 +00:00
parent 4448f2620d
commit b8900d0b80
2 changed files with 21 additions and 7 deletions

View File

@ -27,6 +27,7 @@
template<typename T> struct ei_traits;
template<typename Lhs, typename Rhs> struct ei_product_eval_mode;
template<typename T> struct NumTraits;
template<typename _Scalar, int _Rows, int _Cols, unsigned int _Flags, int _MaxRows, int _MaxCols> class Matrix;
template<typename ExpressionType> class Lazy;
@ -89,6 +90,13 @@ template<typename T> struct ei_eval
ei_traits<T>::MaxColsAtCompileTime> type;
};
/** \internal
  * Compile-time helper choosing between an expression type \a T and its
  * evaluated form T::Eval, given a repetition count \a n.
  *
  * - \c eval    is nonzero exactly when
  *              n * NumTraits<T::Scalar>::ReadCost < (n-1) * T::CoeffReadCost.
  * - \c type    is T::Eval when \c eval is nonzero, and T otherwise.
  * - \c reftype is T::Eval when \c eval is nonzero, and T& otherwise.
  */
template<typename T, int n> struct ei_eval_if_expensive
{
  enum {
    // Same test as "n * ReadCost < (n-1) * CoeffReadCost", written with the
    // expression's own cost on the left-hand side.
    eval = ((n-1) * T::CoeffReadCost) > (n * NumTraits<typename T::Scalar>::ReadCost)
  };

  typedef typename ei_meta_if<eval, typename T::Eval, T >::ret type;
  typedef typename ei_meta_if<eval, typename T::Eval, T&>::ret reftype;
};
template<typename T> struct ei_eval_unless_lazy
{
typedef typename ei_meta_if<ei_traits<T>::Flags & LazyBit,

View File

@ -78,6 +78,7 @@ template<typename Lhs, typename Rhs, int EvalMode>
struct ei_traits<Product<Lhs, Rhs, EvalMode> >
{
typedef typename Lhs::Scalar Scalar;
#if 0
typedef typename ei_meta_if<
(int)NumTraits<Scalar>::ReadCost < (int)Lhs::CoeffReadCost,
typename Lhs::Eval,
@ -95,6 +96,7 @@ struct ei_traits<Product<Lhs, Rhs, EvalMode> >
(int)NumTraits<Scalar>::ReadCost < (int)Rhs::CoeffReadCost,
typename Rhs::Eval,
typename Rhs::XprCopy>::ret ActualRhsXprCopy;
#endif
enum {
RowsAtCompileTime = Lhs::RowsAtCompileTime,
ColsAtCompileTime = Rhs::ColsAtCompileTime,
@ -107,7 +109,7 @@ struct ei_traits<Product<Lhs, Rhs, EvalMode> >
= Lhs::ColsAtCompileTime == Dynamic
? Dynamic
: Lhs::ColsAtCompileTime
* (NumTraits<Scalar>::MulCost + ActualLhs::CoeffReadCost + ActualRhs::CoeffReadCost)
* (NumTraits<Scalar>::MulCost + Lhs::CoeffReadCost + Rhs::CoeffReadCost)
+ (Lhs::ColsAtCompileTime - 1) * NumTraits<Scalar>::AddCost
};
};
@ -115,7 +117,7 @@ struct ei_traits<Product<Lhs, Rhs, EvalMode> >
// Compile-time selection of the evaluation mode for the product Lhs * Rhs:
// `value` is CacheOptimal when Lhs::MaxRowsAtCompileTime or
// Rhs::MaxColsAtCompileTime equals Dynamic, and UnrolledDotProduct otherwise.
// NOTE(review): the two identical "? CacheOptimal : UnrolledDotProduct };"
// lines below look like the old and new side of a whitespace-only change
// in this diff view — confirm against the actual header before relying on it.
template<typename Lhs, typename Rhs> struct ei_product_eval_mode
{
enum{ value = Lhs::MaxRowsAtCompileTime == Dynamic || Rhs::MaxColsAtCompileTime == Dynamic
? CacheOptimal : UnrolledDotProduct };
? CacheOptimal : UnrolledDotProduct };
};
template<typename Lhs, typename Rhs, int EvalMode> class Product : ei_no_assignment_operator,
@ -124,11 +126,12 @@ template<typename Lhs, typename Rhs, int EvalMode> class Product : ei_no_assignm
public:
EIGEN_GENERIC_PUBLIC_INTERFACE(Product)
#if 0
typedef typename ei_traits<Product>::ActualLhs ActualLhs;
typedef typename ei_traits<Product>::ActualRhs ActualRhs;
typedef typename ei_traits<Product>::ActualLhsXprCopy ActualLhsXprCopy;
typedef typename ei_traits<Product>::ActualRhsXprCopy ActualRhsXprCopy;
#endif
Product(const Lhs& lhs, const Rhs& rhs)
: m_lhs(lhs), m_rhs(rhs)
{
@ -153,7 +156,7 @@ template<typename Lhs, typename Rhs, int EvalMode> class Product : ei_no_assignm
ei_product_unroller<Lhs::ColsAtCompileTime-1,
Lhs::ColsAtCompileTime <= EIGEN_UNROLLING_LIMIT
? Lhs::ColsAtCompileTime : Dynamic,
ActualLhs, ActualRhs>
Lhs, Rhs>
::run(row, col, m_lhs, m_rhs, res);
else
{
@ -165,8 +168,8 @@ template<typename Lhs, typename Rhs, int EvalMode> class Product : ei_no_assignm
}
protected:
const ActualLhsXprCopy m_lhs;
const ActualRhsXprCopy m_rhs;
const typename Lhs::XprCopy m_lhs;
const typename Rhs::XprCopy m_rhs;
};
/** \returns the matrix product of \c *this and \a other.
@ -181,7 +184,10 @@ template<typename OtherDerived>
const typename ei_eval_unless_lazy<Product<Derived, OtherDerived> >::type
MatrixBase<Derived>::operator*(const MatrixBase<OtherDerived> &other) const
{
return Product<Derived, OtherDerived>(derived(), other.derived()).eval();
typedef ei_eval_if_expensive<Derived, OtherDerived::ColsAtCompileTime> Lhs;
typedef ei_eval_if_expensive<OtherDerived, Derived::RowsAtCompileTime> Rhs;
return Product<typename Lhs::type, typename Rhs::type>
(typename Lhs::reftype(derived()), typename Rhs::reftype(other.derived())).eval();
}
/** replaces \c *this by \c *this * \a other.