Make product eval-at-once.

* Make product EvalAtOnce in the cases OuterProduct, GemmProduct and
  GemvProduct
* Ensure that product evaluators are nested inside an EvalToTemp
  evaluator
* As a temporary kludge, evaluate the expression to a temporary in AllAtOnce
  traversal and pass the expression operator to evalTo()
This commit is contained in:
Jitse Niesen 2012-06-29 13:49:25 +01:00
parent 2393ceb380
commit d0b873822f
4 changed files with 173 additions and 91 deletions

View File

@ -616,7 +616,13 @@ struct copy_using_evaluator_impl<DstXprType, SrcXprType, AllAtOnceTraversal, NoU
DstEvaluatorType dstEvaluator(dst);
SrcEvaluatorType srcEvaluator(src);
srcEvaluator.evalTo(dstEvaluator);
// Evaluate rhs in temporary to prevent aliasing problems in a = a * a;
// TODO: Be smarter about this
// TODO: Do not pass the xpr object to evalTo()
typename DstXprType::PlainObject tmp;
typename evaluator<typename DstXprType::PlainObject>::type tmpEvaluator(tmp);
srcEvaluator.evalTo(tmpEvaluator, tmp);
copy_using_evaluator(dst, tmp);
}
};

View File

@ -3,7 +3,7 @@
//
// Copyright (C) 2011 Benoit Jacob <jacob.benoit.1@gmail.com>
// Copyright (C) 2011 Gael Guennebaud <gael.guennebaud@inria.fr>
// Copyright (C) 2011 Jitse Niesen <jitse@maths.leeds.ac.uk>
// Copyright (C) 2011-2012 Jitse Niesen <jitse@maths.leeds.ac.uk>
//
// Eigen is free software; you can redistribute it and/or
// modify it under the terms of the GNU Lesser General Public
@ -42,24 +42,46 @@ struct evaluator_traits
static const int HasEvalTo = 0;
};
// expression class for evaluating nested expression to a temporary
template<typename ArgType>
class EvalToTemp;
// evaluator<T>::type is type of evaluator for T
// evaluator<T>::nestedType is type of evaluator if T is nested inside another evaluator
template<typename T>
struct evaluator_impl
{ };
template<typename T, int Nested = evaluator_traits<T>::HasEvalTo>
struct evaluator_nested_type;
template<typename T>
struct evaluator_impl {};
struct evaluator_nested_type<T, 0>
{
typedef evaluator_impl<T> type;
};
// Specialization for Nested == 1: the nested evaluator of T is the evaluator
// of EvalToTemp<T> rather than evaluator_impl<T> itself.
template<typename T>
struct evaluator_nested_type<T, 1>
{
typedef evaluator_impl<EvalToTemp<T> > type;
};
// Maps an expression type T to its evaluator types:
//  - type:       evaluator_impl<T>
//  - nestedType: whatever evaluator_nested_type<T> selects, for use when T is
//                held as a member of another evaluator
template<typename T>
struct evaluator
{
typedef evaluator_impl<T> type;
typedef typename evaluator_nested_type<T>::type nestedType;
};
// TODO: Think about const-correctness
template<typename T>
struct evaluator<const T>
{
typedef evaluator_impl<T> type;
};
: evaluator<T>
{ };
// ---------- base class for all writable evaluators ----------
@ -132,70 +154,6 @@ struct evaluator_impl_base
}
};
// -------------------- Transpose --------------------
// Evaluator for Transpose<ArgType>: holds the evaluator of the nested
// expression and forwards every access to it. Two-index accessors swap
// (row, col) to (col, row); single-index accessors pass the index unchanged.
template<typename ArgType>
struct evaluator_impl<Transpose<ArgType> >
: evaluator_impl_base<Transpose<ArgType> >
{
typedef Transpose<ArgType> XprType;
// Builds the nested evaluator from the expression being transposed.
evaluator_impl(const XprType& t) : m_argImpl(t.nestedExpression()) {}
typedef typename XprType::Index Index;
typedef typename XprType::Scalar Scalar;
typedef typename XprType::CoeffReturnType CoeffReturnType;
typedef typename XprType::PacketScalar PacketScalar;
typedef typename XprType::PacketReturnType PacketReturnType;
// Read access: coefficient (row, col) of the transpose is (col, row) of the argument.
CoeffReturnType coeff(Index row, Index col) const
{
return m_argImpl.coeff(col, row);
}
// Linear read access: the index is forwarded as is.
CoeffReturnType coeff(Index index) const
{
return m_argImpl.coeff(index);
}
// Write access with swapped indices.
Scalar& coeffRef(Index row, Index col)
{
return m_argImpl.coeffRef(col, row);
}
// Linear write access: the index is forwarded as is.
typename XprType::Scalar& coeffRef(Index index)
{
return m_argImpl.coeffRef(index);
}
// Packet read with swapped indices.
template<int LoadMode>
PacketReturnType packet(Index row, Index col) const
{
return m_argImpl.template packet<LoadMode>(col, row);
}
// Linear packet read: the index is forwarded as is.
template<int LoadMode>
PacketReturnType packet(Index index) const
{
return m_argImpl.template packet<LoadMode>(index);
}
// Packet write with swapped indices.
template<int StoreMode>
void writePacket(Index row, Index col, const PacketScalar& x)
{
m_argImpl.template writePacket<StoreMode>(col, row, x);
}
// Linear packet write: the index is forwarded as is.
template<int StoreMode>
void writePacket(Index index, const PacketScalar& x)
{
m_argImpl.template writePacket<StoreMode>(index, x);
}
protected:
// Evaluator of the nested (un-transposed) expression.
typename evaluator<ArgType>::type m_argImpl;
};
// -------------------- Matrix and Array --------------------
//
// evaluator_impl<PlainObjectBase> is a common base class for the
@ -285,6 +243,89 @@ struct evaluator_impl<Array<Scalar, Rows, Cols, Options, MaxRows, MaxCols> >
{ }
};
// -------------------- EvalToTemp --------------------
// Evaluator for EvalToTemp<ArgType>.
//
// On construction the wrapped expression is copied into a plain temporary
// (m_result) with copy_using_evaluator(); afterwards this object acts as the
// evaluator of that plain object, which is its base class.
//
// The constructor takes the wrapped expression (ArgType) directly.
template<typename ArgType>
struct evaluator_impl<EvalToTemp<ArgType> >
  : evaluator_impl<typename ArgType::PlainObject>
{
  typedef typename ArgType::PlainObject PlainObject;
  typedef evaluator_impl<PlainObject> BaseType;

  // NOTE(review): base classes are initialized before data members, so
  // BaseType receives a reference to m_result before m_result has been
  // constructed and before the copy below has filled it. This is only safe
  // if the base constructor does nothing more than keep the reference --
  // TODO confirm against evaluator_impl<PlainObject>.
  evaluator_impl(const ArgType& arg)
    : BaseType(m_result)
  {
    copy_using_evaluator(m_result, arg);
  }

protected:
  // Temporary holding the evaluated expression.
  PlainObject m_result;
};
// -------------------- Transpose --------------------
// Evaluator for Transpose<ArgType>.
//
// Keeps the (nested) evaluator of the argument and delegates everything to
// it: accessors taking (row, col) query the argument at (col, row), while
// accessors taking a single linear index hand that index through untouched.
template<typename ArgType>
struct evaluator_impl<Transpose<ArgType> >
  : evaluator_impl_base<Transpose<ArgType> >
{
  typedef Transpose<ArgType> XprType;

  typedef typename XprType::Index Index;
  typedef typename XprType::Scalar Scalar;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename XprType::PacketScalar PacketScalar;
  typedef typename XprType::PacketReturnType PacketReturnType;

  // The nested evaluator is built from the expression being transposed.
  evaluator_impl(const XprType& xpr)
    : m_argImpl(xpr.nestedExpression())
  { }

  // ---- coefficient reads ----

  CoeffReturnType coeff(Index i, Index j) const
  {
    return m_argImpl.coeff(j, i);
  }

  CoeffReturnType coeff(Index k) const
  {
    return m_argImpl.coeff(k);
  }

  // ---- coefficient writes ----

  Scalar& coeffRef(Index i, Index j)
  {
    return m_argImpl.coeffRef(j, i);
  }

  Scalar& coeffRef(Index k)
  {
    return m_argImpl.coeffRef(k);
  }

  // ---- packet reads ----

  template<int LoadMode>
  PacketReturnType packet(Index i, Index j) const
  {
    return m_argImpl.template packet<LoadMode>(j, i);
  }

  template<int LoadMode>
  PacketReturnType packet(Index k) const
  {
    return m_argImpl.template packet<LoadMode>(k);
  }

  // ---- packet writes ----

  template<int StoreMode>
  void writePacket(Index i, Index j, const PacketScalar& pkt)
  {
    m_argImpl.template writePacket<StoreMode>(j, i, pkt);
  }

  template<int StoreMode>
  void writePacket(Index k, const PacketScalar& pkt)
  {
    m_argImpl.template writePacket<StoreMode>(k, pkt);
  }

protected:
  // Evaluator of the argument, in its "nested" flavour.
  typename evaluator<ArgType>::nestedType m_argImpl;
};
// -------------------- CwiseNullaryOp --------------------
template<typename NullaryOp, typename PlainObjectType>
@ -366,7 +407,7 @@ struct evaluator_impl<CwiseUnaryOp<UnaryOp, ArgType> >
protected:
const UnaryOp m_functor;
typename evaluator<ArgType>::type m_argImpl;
typename evaluator<ArgType>::nestedType m_argImpl;
};
// -------------------- CwiseBinaryOp --------------------
@ -412,8 +453,8 @@ struct evaluator_impl<CwiseBinaryOp<BinaryOp, Lhs, Rhs> >
protected:
const BinaryOp m_functor;
typename evaluator<Lhs>::type m_lhsImpl;
typename evaluator<Rhs>::type m_rhsImpl;
typename evaluator<Lhs>::nestedType m_lhsImpl;
typename evaluator<Rhs>::nestedType m_rhsImpl;
};
// -------------------- CwiseUnaryView --------------------
@ -455,7 +496,7 @@ struct evaluator_impl<CwiseUnaryView<UnaryOp, ArgType> >
protected:
const UnaryOp m_unaryOp;
typename evaluator<ArgType>::type m_argImpl;
typename evaluator<ArgType>::nestedType m_argImpl;
};
// -------------------- Map --------------------
@ -626,7 +667,7 @@ struct evaluator_impl<Block<ArgType, BlockRows, BlockCols, InnerPanel, /* HasDir
}
protected:
typename evaluator<ArgType>::type m_argImpl;
typename evaluator<ArgType>::nestedType m_argImpl;
// TODO: Get rid of m_startRow, m_startCol if known at compile time
Index m_startRow;
@ -681,9 +722,9 @@ struct evaluator_impl<Select<ConditionMatrixType, ThenMatrixType, ElseMatrixType
}
protected:
typename evaluator<ConditionMatrixType>::type m_conditionImpl;
typename evaluator<ThenMatrixType>::type m_thenImpl;
typename evaluator<ElseMatrixType>::type m_elseImpl;
typename evaluator<ConditionMatrixType>::nestedType m_conditionImpl;
typename evaluator<ThenMatrixType>::nestedType m_thenImpl;
typename evaluator<ElseMatrixType>::nestedType m_elseImpl;
};
@ -731,7 +772,7 @@ struct evaluator_impl<Replicate<ArgType, RowFactor, ColFactor> >
}
protected:
typename evaluator<ArgType>::type m_argImpl;
typename evaluator<ArgType>::nestedType m_argImpl;
Index m_rows; // TODO: Get rid of this if known at compile time
Index m_cols;
};
@ -834,7 +875,7 @@ struct evaluator_impl_wrapper_base
}
protected:
typename evaluator<ArgType>::type m_argImpl;
typename evaluator<ArgType>::nestedType m_argImpl;
};
template<typename ArgType>
@ -949,7 +990,7 @@ struct evaluator_impl<Reverse<ArgType, Direction> >
}
protected:
typename evaluator<ArgType>::type m_argImpl;
typename evaluator<ArgType>::nestedType m_argImpl;
Index m_rows; // TODO: Don't use if known at compile time or not needed
Index m_cols;
};
@ -993,7 +1034,7 @@ struct evaluator_impl<Diagonal<ArgType, DiagIndex> >
}
protected:
typename evaluator<ArgType>::type m_argImpl;
typename evaluator<ArgType>::nestedType m_argImpl;
Index m_index; // TODO: Don't use if known at compile time
private:
@ -1069,7 +1110,7 @@ struct evaluator_impl<SwapWrapper<ArgType> >
}
protected:
typename evaluator<ArgType>::type m_argImpl;
typename evaluator<ArgType>::nestedType m_argImpl;
};
@ -1133,7 +1174,7 @@ struct evaluator_impl<SelfCwiseBinaryOp<BinaryOp, LhsXpr, RhsXpr> >
}
protected:
typename evaluator<LhsXpr>::type m_argImpl;
typename evaluator<LhsXpr>::nestedType m_argImpl;
const BinaryOp m_functor;
};

View File

@ -50,12 +50,26 @@ struct evaluator_impl<Product<Lhs, Rhs> >
{ }
};
// Traits dispatcher for products, specialized on the product implementation
// type (second template argument).
template<typename XprType, typename ProductType>
struct product_evaluator_traits_dispatcher;
// evaluator_traits of Product<Lhs, Rhs> are inherited from the dispatcher
// specialization selected by ProductReturnType<Lhs, Rhs>::Type.
template<typename Lhs, typename Rhs>
struct evaluator_traits<Product<Lhs, Rhs> >
: product_evaluator_traits_dispatcher<Product<Lhs, Rhs>, typename ProductReturnType<Lhs, Rhs>::Type>
{ };
// Case 1: Evaluate all at once
//
// We can view the GeneralProduct class as a part of the product evaluator.
// Four sub-cases: InnerProduct, OuterProduct, GemmProduct and GemvProduct.
// InnerProduct is special because GeneralProduct does not have an evalTo() method in this case.
// Traits for the InnerProduct sub-case: HasEvalTo is 0.
template<typename Lhs, typename Rhs>
struct product_evaluator_traits_dispatcher<Product<Lhs, Rhs>, GeneralProduct<Lhs, Rhs, InnerProduct> >
{
static const int HasEvalTo = 0;
};
template<typename Lhs, typename Rhs>
struct product_evaluator_dispatcher<Product<Lhs, Rhs>, GeneralProduct<Lhs, Rhs, InnerProduct> >
: public evaluator<typename Product<Lhs, Rhs>::PlainObject>::type
@ -63,7 +77,8 @@ struct product_evaluator_dispatcher<Product<Lhs, Rhs>, GeneralProduct<Lhs, Rhs,
typedef Product<Lhs, Rhs> XprType;
typedef typename XprType::PlainObject PlainObject;
typedef typename evaluator<PlainObject>::type evaluator_base;
// TODO: Computation is too early (?)
product_evaluator_dispatcher(const XprType& xpr) : evaluator_base(m_result)
{
m_result.coeffRef(0,0) = (xpr.lhs().transpose().cwiseProduct(xpr.rhs())).sum();
@ -76,22 +91,31 @@ protected:
// For the other three subcases, simply call the evalTo() method of GeneralProduct
// TODO: GeneralProduct should take evaluators, not expression objects.
// Traits for any GeneralProduct with a generic ProductType: HasEvalTo is 1.
template<typename Lhs, typename Rhs, int ProductType>
struct product_evaluator_traits_dispatcher<Product<Lhs, Rhs>, GeneralProduct<Lhs, Rhs, ProductType> >
{
static const int HasEvalTo = 1;
};
template<typename Lhs, typename Rhs, int ProductType>
struct product_evaluator_dispatcher<Product<Lhs, Rhs>, GeneralProduct<Lhs, Rhs, ProductType> >
: public evaluator<typename Product<Lhs, Rhs>::PlainObject>::type
{
typedef Product<Lhs, Rhs> XprType;
typedef typename XprType::PlainObject PlainObject;
typedef typename evaluator<PlainObject>::type evaluator_base;
product_evaluator_dispatcher(const XprType& xpr) : evaluator_base(m_result)
product_evaluator_dispatcher(const XprType& xpr) : m_xpr(xpr)
{ }
template<typename DstEvaluatorType, typename DstXprType>
void evalTo(DstEvaluatorType /* not used */, DstXprType& dst)
{
m_result.resize(xpr.rows(), xpr.cols());
GeneralProduct<Lhs, Rhs, ProductType>(xpr.lhs(), xpr.rhs()).evalTo(m_result);
dst.resize(m_xpr.rows(), m_xpr.cols());
GeneralProduct<Lhs, Rhs, ProductType>(m_xpr.lhs(), m_xpr.rhs()).evalTo(dst);
}
protected:
PlainObject m_result;
protected:
const XprType& m_xpr;
};
// Case 2: Evaluate coeff by coeff
@ -106,6 +130,12 @@ struct etor_product_coeff_impl;
template<int StorageOrder, int UnrollingIndex, typename Lhs, typename Rhs, typename Packet, int LoadMode>
struct etor_product_packet_impl;
// Traits for products implemented by CoeffBasedProduct: HasEvalTo is 0.
template<typename Lhs, typename Rhs, typename LhsNested, typename RhsNested, int Flags>
struct product_evaluator_traits_dispatcher<Product<Lhs, Rhs>, CoeffBasedProduct<LhsNested, RhsNested, Flags> >
{
static const int HasEvalTo = 0;
};
template<typename Lhs, typename Rhs, typename LhsNested, typename RhsNested, int Flags>
struct product_evaluator_dispatcher<Product<Lhs, Rhs>, CoeffBasedProduct<LhsNested, RhsNested, Flags> >
: evaluator_impl_base<Product<Lhs, Rhs> >

View File

@ -65,6 +65,11 @@ void test_evaluators()
VERIFY_IS_APPROX_EVALUATOR2(d, s * prod(a,b), s * a*b);
VERIFY_IS_APPROX_EVALUATOR2(d, prod(a,b).transpose(), (a*b).transpose());
VERIFY_IS_APPROX_EVALUATOR2(d, prod(a,b) + prod(b,c), a*b + b*c);
// check that prod works even with aliasing present
c = a*a;
copy_using_evaluator(a, prod(a,a));
VERIFY_IS_APPROX(a,c);
}
{