mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-02-17 18:09:55 +08:00
Make product eval-at-once.
* Make product EvalAtOnce in cases OuterProduct, GemmProduct and GemvProduct * Ensure that product evaluators are nested inside EvalToTemp evaluator * As temporary kludge, evaluate expression to temporary in AllAtOnce traversal and pass expression operator to evalTo()
This commit is contained in:
parent
2393ceb380
commit
d0b873822f
@ -616,7 +616,13 @@ struct copy_using_evaluator_impl<DstXprType, SrcXprType, AllAtOnceTraversal, NoU
|
||||
DstEvaluatorType dstEvaluator(dst);
|
||||
SrcEvaluatorType srcEvaluator(src);
|
||||
|
||||
srcEvaluator.evalTo(dstEvaluator);
|
||||
// Evaluate rhs in temporary to prevent aliasing problems in a = a * a;
|
||||
// TODO: Be smarter about this
|
||||
// TODO: Do not pass the xpr object to evalTo()
|
||||
typename DstXprType::PlainObject tmp;
|
||||
typename evaluator<typename DstXprType::PlainObject>::type tmpEvaluator(tmp);
|
||||
srcEvaluator.evalTo(tmpEvaluator, tmp);
|
||||
copy_using_evaluator(dst, tmp);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -3,7 +3,7 @@
|
||||
//
|
||||
// Copyright (C) 2011 Benoit Jacob <jacob.benoit.1@gmail.com>
|
||||
// Copyright (C) 2011 Gael Guennebaud <gael.guennebaud@inria.fr>
|
||||
// Copyright (C) 2011 Jitse Niesen <jitse@maths.leeds.ac.uk>
|
||||
// Copyright (C) 2011-2012 Jitse Niesen <jitse@maths.leeds.ac.uk>
|
||||
//
|
||||
// Eigen is free software; you can redistribute it and/or
|
||||
// modify it under the terms of the GNU Lesser General Public
|
||||
@ -42,24 +42,46 @@ struct evaluator_traits
|
||||
static const int HasEvalTo = 0;
|
||||
};
|
||||
|
||||
// expression class for evaluating nested expression to a temporary
|
||||
|
||||
template<typename ArgType>
|
||||
class EvalToTemp;
|
||||
|
||||
// evaluator<T>::type is type of evaluator for T
|
||||
// evaluator<T>::nestedType is type of evaluator if T is nested inside another evaluator
|
||||
|
||||
template<typename T>
|
||||
struct evaluator_impl
|
||||
{ };
|
||||
|
||||
template<typename T, int Nested = evaluator_traits<T>::HasEvalTo>
|
||||
struct evaluator_nested_type;
|
||||
|
||||
template<typename T>
|
||||
struct evaluator_impl {};
|
||||
struct evaluator_nested_type<T, 0>
|
||||
{
|
||||
typedef evaluator_impl<T> type;
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
struct evaluator_nested_type<T, 1>
|
||||
{
|
||||
typedef evaluator_impl<EvalToTemp<T> > type;
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
struct evaluator
|
||||
{
|
||||
typedef evaluator_impl<T> type;
|
||||
typedef typename evaluator_nested_type<T>::type nestedType;
|
||||
};
|
||||
|
||||
// TODO: Think about const-correctness
|
||||
|
||||
template<typename T>
|
||||
struct evaluator<const T>
|
||||
{
|
||||
typedef evaluator_impl<T> type;
|
||||
};
|
||||
: evaluator<T>
|
||||
{ };
|
||||
|
||||
// ---------- base class for all writable evaluators ----------
|
||||
|
||||
@ -132,70 +154,6 @@ struct evaluator_impl_base
|
||||
}
|
||||
};
|
||||
|
||||
// -------------------- Transpose --------------------
|
||||
|
||||
template<typename ArgType>
|
||||
struct evaluator_impl<Transpose<ArgType> >
|
||||
: evaluator_impl_base<Transpose<ArgType> >
|
||||
{
|
||||
typedef Transpose<ArgType> XprType;
|
||||
|
||||
evaluator_impl(const XprType& t) : m_argImpl(t.nestedExpression()) {}
|
||||
|
||||
typedef typename XprType::Index Index;
|
||||
typedef typename XprType::Scalar Scalar;
|
||||
typedef typename XprType::CoeffReturnType CoeffReturnType;
|
||||
typedef typename XprType::PacketScalar PacketScalar;
|
||||
typedef typename XprType::PacketReturnType PacketReturnType;
|
||||
|
||||
CoeffReturnType coeff(Index row, Index col) const
|
||||
{
|
||||
return m_argImpl.coeff(col, row);
|
||||
}
|
||||
|
||||
CoeffReturnType coeff(Index index) const
|
||||
{
|
||||
return m_argImpl.coeff(index);
|
||||
}
|
||||
|
||||
Scalar& coeffRef(Index row, Index col)
|
||||
{
|
||||
return m_argImpl.coeffRef(col, row);
|
||||
}
|
||||
|
||||
typename XprType::Scalar& coeffRef(Index index)
|
||||
{
|
||||
return m_argImpl.coeffRef(index);
|
||||
}
|
||||
|
||||
template<int LoadMode>
|
||||
PacketReturnType packet(Index row, Index col) const
|
||||
{
|
||||
return m_argImpl.template packet<LoadMode>(col, row);
|
||||
}
|
||||
|
||||
template<int LoadMode>
|
||||
PacketReturnType packet(Index index) const
|
||||
{
|
||||
return m_argImpl.template packet<LoadMode>(index);
|
||||
}
|
||||
|
||||
template<int StoreMode>
|
||||
void writePacket(Index row, Index col, const PacketScalar& x)
|
||||
{
|
||||
m_argImpl.template writePacket<StoreMode>(col, row, x);
|
||||
}
|
||||
|
||||
template<int StoreMode>
|
||||
void writePacket(Index index, const PacketScalar& x)
|
||||
{
|
||||
m_argImpl.template writePacket<StoreMode>(index, x);
|
||||
}
|
||||
|
||||
protected:
|
||||
typename evaluator<ArgType>::type m_argImpl;
|
||||
};
|
||||
|
||||
// -------------------- Matrix and Array --------------------
|
||||
//
|
||||
// evaluator_impl<PlainObjectBase> is a common base class for the
|
||||
@ -285,6 +243,89 @@ struct evaluator_impl<Array<Scalar, Rows, Cols, Options, MaxRows, MaxCols> >
|
||||
{ }
|
||||
};
|
||||
|
||||
// -------------------- EvalToTemp --------------------
|
||||
|
||||
template<typename ArgType>
|
||||
struct evaluator_impl<EvalToTemp<ArgType> >
|
||||
: evaluator_impl<typename ArgType::PlainObject>
|
||||
{
|
||||
typedef typename ArgType::PlainObject PlainObject;
|
||||
typedef evaluator_impl<PlainObject> BaseType;
|
||||
|
||||
evaluator_impl(const ArgType& arg)
|
||||
: BaseType(m_result)
|
||||
{
|
||||
copy_using_evaluator(m_result, arg);
|
||||
};
|
||||
|
||||
protected:
|
||||
PlainObject m_result;
|
||||
};
|
||||
|
||||
// -------------------- Transpose --------------------
|
||||
|
||||
template<typename ArgType>
|
||||
struct evaluator_impl<Transpose<ArgType> >
|
||||
: evaluator_impl_base<Transpose<ArgType> >
|
||||
{
|
||||
typedef Transpose<ArgType> XprType;
|
||||
|
||||
evaluator_impl(const XprType& t) : m_argImpl(t.nestedExpression()) {}
|
||||
|
||||
typedef typename XprType::Index Index;
|
||||
typedef typename XprType::Scalar Scalar;
|
||||
typedef typename XprType::CoeffReturnType CoeffReturnType;
|
||||
typedef typename XprType::PacketScalar PacketScalar;
|
||||
typedef typename XprType::PacketReturnType PacketReturnType;
|
||||
|
||||
CoeffReturnType coeff(Index row, Index col) const
|
||||
{
|
||||
return m_argImpl.coeff(col, row);
|
||||
}
|
||||
|
||||
CoeffReturnType coeff(Index index) const
|
||||
{
|
||||
return m_argImpl.coeff(index);
|
||||
}
|
||||
|
||||
Scalar& coeffRef(Index row, Index col)
|
||||
{
|
||||
return m_argImpl.coeffRef(col, row);
|
||||
}
|
||||
|
||||
typename XprType::Scalar& coeffRef(Index index)
|
||||
{
|
||||
return m_argImpl.coeffRef(index);
|
||||
}
|
||||
|
||||
template<int LoadMode>
|
||||
PacketReturnType packet(Index row, Index col) const
|
||||
{
|
||||
return m_argImpl.template packet<LoadMode>(col, row);
|
||||
}
|
||||
|
||||
template<int LoadMode>
|
||||
PacketReturnType packet(Index index) const
|
||||
{
|
||||
return m_argImpl.template packet<LoadMode>(index);
|
||||
}
|
||||
|
||||
template<int StoreMode>
|
||||
void writePacket(Index row, Index col, const PacketScalar& x)
|
||||
{
|
||||
m_argImpl.template writePacket<StoreMode>(col, row, x);
|
||||
}
|
||||
|
||||
template<int StoreMode>
|
||||
void writePacket(Index index, const PacketScalar& x)
|
||||
{
|
||||
m_argImpl.template writePacket<StoreMode>(index, x);
|
||||
}
|
||||
|
||||
protected:
|
||||
typename evaluator<ArgType>::nestedType m_argImpl;
|
||||
};
|
||||
|
||||
// -------------------- CwiseNullaryOp --------------------
|
||||
|
||||
template<typename NullaryOp, typename PlainObjectType>
|
||||
@ -366,7 +407,7 @@ struct evaluator_impl<CwiseUnaryOp<UnaryOp, ArgType> >
|
||||
|
||||
protected:
|
||||
const UnaryOp m_functor;
|
||||
typename evaluator<ArgType>::type m_argImpl;
|
||||
typename evaluator<ArgType>::nestedType m_argImpl;
|
||||
};
|
||||
|
||||
// -------------------- CwiseBinaryOp --------------------
|
||||
@ -412,8 +453,8 @@ struct evaluator_impl<CwiseBinaryOp<BinaryOp, Lhs, Rhs> >
|
||||
|
||||
protected:
|
||||
const BinaryOp m_functor;
|
||||
typename evaluator<Lhs>::type m_lhsImpl;
|
||||
typename evaluator<Rhs>::type m_rhsImpl;
|
||||
typename evaluator<Lhs>::nestedType m_lhsImpl;
|
||||
typename evaluator<Rhs>::nestedType m_rhsImpl;
|
||||
};
|
||||
|
||||
// -------------------- CwiseUnaryView --------------------
|
||||
@ -455,7 +496,7 @@ struct evaluator_impl<CwiseUnaryView<UnaryOp, ArgType> >
|
||||
|
||||
protected:
|
||||
const UnaryOp m_unaryOp;
|
||||
typename evaluator<ArgType>::type m_argImpl;
|
||||
typename evaluator<ArgType>::nestedType m_argImpl;
|
||||
};
|
||||
|
||||
// -------------------- Map --------------------
|
||||
@ -626,7 +667,7 @@ struct evaluator_impl<Block<ArgType, BlockRows, BlockCols, InnerPanel, /* HasDir
|
||||
}
|
||||
|
||||
protected:
|
||||
typename evaluator<ArgType>::type m_argImpl;
|
||||
typename evaluator<ArgType>::nestedType m_argImpl;
|
||||
|
||||
// TODO: Get rid of m_startRow, m_startCol if known at compile time
|
||||
Index m_startRow;
|
||||
@ -681,9 +722,9 @@ struct evaluator_impl<Select<ConditionMatrixType, ThenMatrixType, ElseMatrixType
|
||||
}
|
||||
|
||||
protected:
|
||||
typename evaluator<ConditionMatrixType>::type m_conditionImpl;
|
||||
typename evaluator<ThenMatrixType>::type m_thenImpl;
|
||||
typename evaluator<ElseMatrixType>::type m_elseImpl;
|
||||
typename evaluator<ConditionMatrixType>::nestedType m_conditionImpl;
|
||||
typename evaluator<ThenMatrixType>::nestedType m_thenImpl;
|
||||
typename evaluator<ElseMatrixType>::nestedType m_elseImpl;
|
||||
};
|
||||
|
||||
|
||||
@ -731,7 +772,7 @@ struct evaluator_impl<Replicate<ArgType, RowFactor, ColFactor> >
|
||||
}
|
||||
|
||||
protected:
|
||||
typename evaluator<ArgType>::type m_argImpl;
|
||||
typename evaluator<ArgType>::nestedType m_argImpl;
|
||||
Index m_rows; // TODO: Get rid of this if known at compile time
|
||||
Index m_cols;
|
||||
};
|
||||
@ -834,7 +875,7 @@ struct evaluator_impl_wrapper_base
|
||||
}
|
||||
|
||||
protected:
|
||||
typename evaluator<ArgType>::type m_argImpl;
|
||||
typename evaluator<ArgType>::nestedType m_argImpl;
|
||||
};
|
||||
|
||||
template<typename ArgType>
|
||||
@ -949,7 +990,7 @@ struct evaluator_impl<Reverse<ArgType, Direction> >
|
||||
}
|
||||
|
||||
protected:
|
||||
typename evaluator<ArgType>::type m_argImpl;
|
||||
typename evaluator<ArgType>::nestedType m_argImpl;
|
||||
Index m_rows; // TODO: Don't use if known at compile time or not needed
|
||||
Index m_cols;
|
||||
};
|
||||
@ -993,7 +1034,7 @@ struct evaluator_impl<Diagonal<ArgType, DiagIndex> >
|
||||
}
|
||||
|
||||
protected:
|
||||
typename evaluator<ArgType>::type m_argImpl;
|
||||
typename evaluator<ArgType>::nestedType m_argImpl;
|
||||
Index m_index; // TODO: Don't use if known at compile time
|
||||
|
||||
private:
|
||||
@ -1069,7 +1110,7 @@ struct evaluator_impl<SwapWrapper<ArgType> >
|
||||
}
|
||||
|
||||
protected:
|
||||
typename evaluator<ArgType>::type m_argImpl;
|
||||
typename evaluator<ArgType>::nestedType m_argImpl;
|
||||
};
|
||||
|
||||
|
||||
@ -1133,7 +1174,7 @@ struct evaluator_impl<SelfCwiseBinaryOp<BinaryOp, LhsXpr, RhsXpr> >
|
||||
}
|
||||
|
||||
protected:
|
||||
typename evaluator<LhsXpr>::type m_argImpl;
|
||||
typename evaluator<LhsXpr>::nestedType m_argImpl;
|
||||
const BinaryOp m_functor;
|
||||
};
|
||||
|
||||
|
@ -50,12 +50,26 @@ struct evaluator_impl<Product<Lhs, Rhs> >
|
||||
{ }
|
||||
};
|
||||
|
||||
template<typename XprType, typename ProductType>
|
||||
struct product_evaluator_traits_dispatcher;
|
||||
|
||||
template<typename Lhs, typename Rhs>
|
||||
struct evaluator_traits<Product<Lhs, Rhs> >
|
||||
: product_evaluator_traits_dispatcher<Product<Lhs, Rhs>, typename ProductReturnType<Lhs, Rhs>::Type>
|
||||
{ };
|
||||
|
||||
// Case 1: Evaluate all at once
|
||||
//
|
||||
// We can view the GeneralProduct class as a part of the product evaluator.
|
||||
// Four sub-cases: InnerProduct, OuterProduct, GemmProduct and GemvProduct.
|
||||
// InnerProduct is special because GeneralProduct does not have an evalTo() method in this case.
|
||||
|
||||
template<typename Lhs, typename Rhs>
|
||||
struct product_evaluator_traits_dispatcher<Product<Lhs, Rhs>, GeneralProduct<Lhs, Rhs, InnerProduct> >
|
||||
{
|
||||
static const int HasEvalTo = 0;
|
||||
};
|
||||
|
||||
template<typename Lhs, typename Rhs>
|
||||
struct product_evaluator_dispatcher<Product<Lhs, Rhs>, GeneralProduct<Lhs, Rhs, InnerProduct> >
|
||||
: public evaluator<typename Product<Lhs, Rhs>::PlainObject>::type
|
||||
@ -63,7 +77,8 @@ struct product_evaluator_dispatcher<Product<Lhs, Rhs>, GeneralProduct<Lhs, Rhs,
|
||||
typedef Product<Lhs, Rhs> XprType;
|
||||
typedef typename XprType::PlainObject PlainObject;
|
||||
typedef typename evaluator<PlainObject>::type evaluator_base;
|
||||
|
||||
|
||||
// TODO: Computation is too early (?)
|
||||
product_evaluator_dispatcher(const XprType& xpr) : evaluator_base(m_result)
|
||||
{
|
||||
m_result.coeffRef(0,0) = (xpr.lhs().transpose().cwiseProduct(xpr.rhs())).sum();
|
||||
@ -76,22 +91,31 @@ protected:
|
||||
// For the other three subcases, simply call the evalTo() method of GeneralProduct
|
||||
// TODO: GeneralProduct should take evaluators, not expression objects.
|
||||
|
||||
template<typename Lhs, typename Rhs, int ProductType>
|
||||
struct product_evaluator_traits_dispatcher<Product<Lhs, Rhs>, GeneralProduct<Lhs, Rhs, ProductType> >
|
||||
{
|
||||
static const int HasEvalTo = 1;
|
||||
};
|
||||
|
||||
template<typename Lhs, typename Rhs, int ProductType>
|
||||
struct product_evaluator_dispatcher<Product<Lhs, Rhs>, GeneralProduct<Lhs, Rhs, ProductType> >
|
||||
: public evaluator<typename Product<Lhs, Rhs>::PlainObject>::type
|
||||
{
|
||||
typedef Product<Lhs, Rhs> XprType;
|
||||
typedef typename XprType::PlainObject PlainObject;
|
||||
typedef typename evaluator<PlainObject>::type evaluator_base;
|
||||
|
||||
product_evaluator_dispatcher(const XprType& xpr) : evaluator_base(m_result)
|
||||
product_evaluator_dispatcher(const XprType& xpr) : m_xpr(xpr)
|
||||
{ }
|
||||
|
||||
template<typename DstEvaluatorType, typename DstXprType>
|
||||
void evalTo(DstEvaluatorType /* not used */, DstXprType& dst)
|
||||
{
|
||||
m_result.resize(xpr.rows(), xpr.cols());
|
||||
GeneralProduct<Lhs, Rhs, ProductType>(xpr.lhs(), xpr.rhs()).evalTo(m_result);
|
||||
dst.resize(m_xpr.rows(), m_xpr.cols());
|
||||
GeneralProduct<Lhs, Rhs, ProductType>(m_xpr.lhs(), m_xpr.rhs()).evalTo(dst);
|
||||
}
|
||||
|
||||
protected:
|
||||
PlainObject m_result;
|
||||
protected:
|
||||
const XprType& m_xpr;
|
||||
};
|
||||
|
||||
// Case 2: Evaluate coeff by coeff
|
||||
@ -106,6 +130,12 @@ struct etor_product_coeff_impl;
|
||||
template<int StorageOrder, int UnrollingIndex, typename Lhs, typename Rhs, typename Packet, int LoadMode>
|
||||
struct etor_product_packet_impl;
|
||||
|
||||
template<typename Lhs, typename Rhs, typename LhsNested, typename RhsNested, int Flags>
|
||||
struct product_evaluator_traits_dispatcher<Product<Lhs, Rhs>, CoeffBasedProduct<LhsNested, RhsNested, Flags> >
|
||||
{
|
||||
static const int HasEvalTo = 0;
|
||||
};
|
||||
|
||||
template<typename Lhs, typename Rhs, typename LhsNested, typename RhsNested, int Flags>
|
||||
struct product_evaluator_dispatcher<Product<Lhs, Rhs>, CoeffBasedProduct<LhsNested, RhsNested, Flags> >
|
||||
: evaluator_impl_base<Product<Lhs, Rhs> >
|
||||
|
@ -65,6 +65,11 @@ void test_evaluators()
|
||||
VERIFY_IS_APPROX_EVALUATOR2(d, s * prod(a,b), s * a*b);
|
||||
VERIFY_IS_APPROX_EVALUATOR2(d, prod(a,b).transpose(), (a*b).transpose());
|
||||
VERIFY_IS_APPROX_EVALUATOR2(d, prod(a,b) + prod(b,c), a*b + b*c);
|
||||
|
||||
// check that prod works even with aliasing present
|
||||
c = a*a;
|
||||
copy_using_evaluator(a, prod(a,a));
|
||||
VERIFY_IS_APPROX(a,c);
|
||||
}
|
||||
|
||||
{
|
||||
|
Loading…
Reference in New Issue
Block a user