Evaluators: Remove member variables if known at compile time.

Also, use composition instead of inheritance in EvalToTemp evaluator.
Jitse Niesen 2012-07-06 14:50:03 +01:00
parent 7bfd8eabff
commit b1b6864c88
2 changed files with 116 additions and 81 deletions
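
The patch leans throughout on Eigen's internal::variable_if_dynamic, defined in Eigen/src/Core/util/XprHelper.h. A minimal sketch of that idiom (simplified; the bodies here are illustrative rather than the shipped code):

const int Dynamic = -1;  // stand-in for Eigen's Dynamic sentinel

template<typename T, int Value>
class variable_if_dynamic                    // Value is known at compile time
{
public:
  explicit variable_if_dynamic(T) { }        // run-time argument is discarded
  static T value() { return T(Value); }      // folds to a constant at each call site
};

template<typename T>
class variable_if_dynamic<T, Dynamic>        // value only known at run time
{
  T m_value;
public:
  explicit variable_if_dynamic(T value) : m_value(value) { }
  T value() const { return m_value; }        // ordinary stored member
};

Call sites uniformly write m_member.value(); whenever the template argument is not Dynamic, the compiler sees a literal and the member carries no run-time state. That is what lets the members below be elided when their value is known at compile time.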


@@ -169,19 +169,16 @@ struct evaluator_impl<PlainObjectBase<Derived> >
{
typedef PlainObjectBase<Derived> PlainObjectType;
evaluator_impl(void)
{ }
enum {
IsRowMajor = PlainObjectType::IsRowMajor,
IsVectorAtCompileTime = PlainObjectType::IsVectorAtCompileTime,
RowsAtCompileTime = PlainObjectType::RowsAtCompileTime,
ColsAtCompileTime = PlainObjectType::ColsAtCompileTime
};
evaluator_impl(const PlainObjectType& m)
{
init(m);
}
void init(const PlainObjectType& m)
{
m_data = m.data();
m_outerStride = m.outerStride();
}
: m_data(m.data()), m_outerStride(IsVectorAtCompileTime ? 0 : m.outerStride())
{ }
typedef typename PlainObjectType::Index Index;
typedef typename PlainObjectType::Scalar Scalar;
@@ -189,16 +186,12 @@ struct evaluator_impl<PlainObjectBase<Derived> >
typedef typename PlainObjectType::PacketScalar PacketScalar;
typedef typename PlainObjectType::PacketReturnType PacketReturnType;
enum {
IsRowMajor = PlainObjectType::IsRowMajor,
};
CoeffReturnType coeff(Index row, Index col) const
{
if (IsRowMajor)
return m_data[row * m_outerStride + col];
return m_data[row * m_outerStride.value() + col];
else
return m_data[row + col * m_outerStride];
return m_data[row + col * m_outerStride.value()];
}
CoeffReturnType coeff(Index index) const
@@ -209,9 +202,9 @@ struct evaluator_impl<PlainObjectBase<Derived> >
Scalar& coeffRef(Index row, Index col)
{
if (IsRowMajor)
return const_cast<Scalar*>(m_data)[row * m_outerStride + col];
return const_cast<Scalar*>(m_data)[row * m_outerStride.value() + col];
else
return const_cast<Scalar*>(m_data)[row + col * m_outerStride];
return const_cast<Scalar*>(m_data)[row + col * m_outerStride.value()];
}
Scalar& coeffRef(Index index)
@@ -223,9 +216,9 @@ struct evaluator_impl<PlainObjectBase<Derived> >
PacketReturnType packet(Index row, Index col) const
{
if (IsRowMajor)
return ploadt<PacketScalar, LoadMode>(m_data + row * m_outerStride + col);
return ploadt<PacketScalar, LoadMode>(m_data + row * m_outerStride.value() + col);
else
return ploadt<PacketScalar, LoadMode>(m_data + row + col * m_outerStride);
return ploadt<PacketScalar, LoadMode>(m_data + row + col * m_outerStride.value());
}
template<int LoadMode>
@@ -239,10 +232,10 @@ struct evaluator_impl<PlainObjectBase<Derived> >
{
if (IsRowMajor)
return pstoret<Scalar, PacketScalar, StoreMode>
(const_cast<Scalar*>(m_data) + row * m_outerStride + col, x);
(const_cast<Scalar*>(m_data) + row * m_outerStride.value() + col, x);
else
return pstoret<Scalar, PacketScalar, StoreMode>
(const_cast<Scalar*>(m_data) + row + col * m_outerStride, x);
(const_cast<Scalar*>(m_data) + row + col * m_outerStride.value(), x);
}
template<int StoreMode>
@@ -253,7 +246,11 @@ struct evaluator_impl<PlainObjectBase<Derived> >
protected:
const Scalar *m_data;
Index m_outerStride;
// We do not need to know the outer stride for vectors
variable_if_dynamic<Index, IsVectorAtCompileTime ? 0
: int(IsRowMajor) ? ColsAtCompileTime
: RowsAtCompileTime> m_outerStride;
};
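
What the new m_outerStride declaration means for a few concrete types (a worked example, not part of the patch):

// Matrix4f  (fixed size, column-major): the template argument is
//           RowsAtCompileTime == 4, so m_outerStride.value() is the
//           literal 4 and no stride is stored.
// MatrixXf  (dynamic): the argument is Dynamic, so the stride is stored
//           and read at run time, exactly as before this commit.
// Vector4f  (column vector): the argument is 0 and the constructor passes
//           0; this is safe because col == 0 for every access, so the
//           term col * m_outerStride.value() vanishes from coeff(row, col).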
template<typename Scalar, int Rows, int Cols, int Options, int MaxRows, int MaxCols>
@@ -262,9 +259,6 @@ struct evaluator_impl<Matrix<Scalar, Rows, Cols, Options, MaxRows, MaxCols> >
{
typedef Matrix<Scalar, Rows, Cols, Options, MaxRows, MaxCols> XprType;
evaluator_impl(void)
{ }
evaluator_impl(const XprType& m)
: evaluator_impl<PlainObjectBase<XprType> >(m)
{ }
@@ -276,9 +270,6 @@ struct evaluator_impl<Array<Scalar, Rows, Cols, Options, MaxRows, MaxCols> >
{
typedef Array<Scalar, Rows, Cols, Options, MaxRows, MaxCols> XprType;
evaluator_impl(void)
{ }
evaluator_impl(const XprType& m)
: evaluator_impl<PlainObjectBase<XprType> >(m)
{ }
@@ -325,35 +316,78 @@ class EvalToTemp
template<typename ArgType>
struct evaluator_impl<EvalToTemp<ArgType> >
: evaluator_impl<typename ArgType::PlainObject>
{
typedef EvalToTemp<ArgType> XprType;
typedef typename ArgType::PlainObject PlainObject;
typedef evaluator_impl<PlainObject> BaseType;
evaluator_impl(const XprType& xpr)
: m_result(xpr.rows(), xpr.cols()), m_resultImpl(m_result)
{
init(xpr.arg());
copy_using_evaluator_without_resizing(m_result, xpr.arg());
}
// This constructor is used when nesting an EvalTo evaluator in another evaluator
evaluator_impl(const ArgType& arg)
: m_result(arg.rows(), arg.cols()), m_resultImpl(m_result)
{
init(arg);
copy_using_evaluator_without_resizing(m_result, arg);
}
protected:
void init(const ArgType& arg)
typedef typename PlainObject::Index Index;
typedef typename PlainObject::Scalar Scalar;
typedef typename PlainObject::CoeffReturnType CoeffReturnType;
typedef typename PlainObject::PacketScalar PacketScalar;
typedef typename PlainObject::PacketReturnType PacketReturnType;
// All other functions are forwarded to m_resultImpl
CoeffReturnType coeff(Index row, Index col) const
{
return m_resultImpl.coeff(row, col);
}
CoeffReturnType coeff(Index index) const
{
return m_resultImpl.coeff(index);
}
Scalar& coeffRef(Index row, Index col)
{
return m_resultImpl.coeffRef(row, col);
}
Scalar& coeffRef(Index index)
{
return m_resultImpl.coeffRef(index);
}
template<int LoadMode>
PacketReturnType packet(Index row, Index col) const
{
// We can only initialize the base class evaluator after m_result is initialized.
// TODO: Redesign to get rid of inheritance, so that we can remove default c'tors in
// PlainObject, Matrix and Array evaluators.
noalias_copy_using_evaluator(m_result, arg);
BaseType::init(m_result);
return m_resultImpl.packet<LoadMode>(row, col);
}
template<int LoadMode>
PacketReturnType packet(Index index) const
{
return m_resultImpl.packet<LoadMode>(index);
}
template<int StoreMode>
void writePacket(Index row, Index col, const PacketScalar& x)
{
m_resultImpl.writePacket<StoreMode>(row, col, x);
}
template<int StoreMode>
void writePacket(Index index, const PacketScalar& x)
{
m_resultImpl.writePacket<StoreMode>(index, x);
}
protected:
PlainObject m_result;
typename evaluator<PlainObject>::nestedType m_resultImpl;
};
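
Composition also removes the need for the old two-phase init(): non-static data members are initialized in declaration order, so m_result is fully constructed before m_resultImpl(m_result) runs. The guarantee in miniature (hypothetical types, nothing from the patch):

#include <vector>

struct Holder
{
  std::vector<double> m_result;  // declared first, therefore built first
  const double*       m_data;    // may safely depend on m_result above

  Holder() : m_result(4, 0.0), m_data(m_result.data()) { }
};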
// -------------------- Transpose --------------------
@@ -713,7 +747,7 @@ struct evaluator_impl<Block<ArgType, BlockRows, BlockCols, InnerPanel, /* HasDir
CoeffReturnType coeff(Index row, Index col) const
{
return m_argImpl.coeff(m_startRow + row, m_startCol + col);
return m_argImpl.coeff(m_startRow.value() + row, m_startCol.value() + col);
}
CoeffReturnType coeff(Index index) const
@@ -724,7 +758,7 @@ struct evaluator_impl<Block<ArgType, BlockRows, BlockCols, InnerPanel, /* HasDir
Scalar& coeffRef(Index row, Index col)
{
return m_argImpl.coeffRef(m_startRow + row, m_startCol + col);
return m_argImpl.coeffRef(m_startRow.value() + row, m_startCol.value() + col);
}
Scalar& coeffRef(Index index)
@@ -736,7 +770,7 @@ struct evaluator_impl<Block<ArgType, BlockRows, BlockCols, InnerPanel, /* HasDir
template<int LoadMode>
PacketReturnType packet(Index row, Index col) const
{
return m_argImpl.template packet<LoadMode>(m_startRow + row, m_startCol + col);
return m_argImpl.template packet<LoadMode>(m_startRow.value() + row, m_startCol.value() + col);
}
template<int LoadMode>
@@ -749,7 +783,7 @@ struct evaluator_impl<Block<ArgType, BlockRows, BlockCols, InnerPanel, /* HasDir
template<int StoreMode>
void writePacket(Index row, Index col, const PacketScalar& x)
{
return m_argImpl.template writePacket<StoreMode>(m_startRow + row, m_startCol + col, x);
return m_argImpl.template writePacket<StoreMode>(m_startRow.value() + row, m_startCol.value() + col, x);
}
template<int StoreMode>
@@ -762,10 +796,8 @@ struct evaluator_impl<Block<ArgType, BlockRows, BlockCols, InnerPanel, /* HasDir
protected:
typename evaluator<ArgType>::nestedType m_argImpl;
// TODO: Get rid of m_startRow, m_startCol if known at compile time
Index m_startRow;
Index m_startCol;
const variable_if_dynamic<Index, ArgType::RowsAtCompileTime == 1 ? 0 : Dynamic> m_startRow;
const variable_if_dynamic<Index, ArgType::ColsAtCompileTime == 1 ? 0 : Dynamic> m_startCol;
};
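
An illustrative case for the two template arguments above (not from the patch): any block of a row vector necessarily starts at row 0, so m_startRow collapses to a compile-time 0 and only the column offset occupies storage.

// RowVector4f v;                 // ArgType::RowsAtCompileTime == 1
// v.segment<2>(1)                // a 1 x 2 Block of v
//   m_startCol : stored Index with value 1
//   m_startRow : variable_if_dynamic<Index, 0>; its value() folds to 0 in
//                every coeff/packet index computation shown above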
// TODO: This evaluator does not actually use the child evaluator;
@@ -844,10 +876,10 @@ struct evaluator_impl<Replicate<ArgType, RowFactor, ColFactor> >
// try to avoid using modulo; this is a pure optimization strategy
const Index actual_row = internal::traits<XprType>::RowsAtCompileTime==1 ? 0
: RowFactor==1 ? row
: row % m_rows;
: row % m_rows.value();
const Index actual_col = internal::traits<XprType>::ColsAtCompileTime==1 ? 0
: ColFactor==1 ? col
: col % m_cols;
: col % m_cols.value();
return m_argImpl.coeff(actual_row, actual_col);
}
@@ -857,18 +889,18 @@ struct evaluator_impl<Replicate<ArgType, RowFactor, ColFactor> >
{
const Index actual_row = internal::traits<XprType>::RowsAtCompileTime==1 ? 0
: RowFactor==1 ? row
: row % m_rows;
: row % m_rows.value();
const Index actual_col = internal::traits<XprType>::ColsAtCompileTime==1 ? 0
: ColFactor==1 ? col
: col % m_cols;
: col % m_cols.value();
return m_argImpl.template packet<LoadMode>(actual_row, actual_col);
}
protected:
typename evaluator<ArgType>::nestedType m_argImpl;
Index m_rows; // TODO: Get rid of this if known at compile time
Index m_cols;
const variable_if_dynamic<Index, XprType::RowsAtCompileTime> m_rows;
const variable_if_dynamic<Index, XprType::ColsAtCompileTime> m_cols;
};
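
Two concrete readings of the modulo-avoidance branches above, assuming v is a fixed-size column vector such as Vector4f (illustrative only):

// v.replicate<1, 3>() : RowFactor == 1, so actual_row is simply row and
//                       no modulo is emitted for the row index.
// v.replicate<3, 1>() : the result still has exactly one column, so the
//                       ColsAtCompileTime == 1 test turns actual_col into
//                       a plain 0; only the row index pays for a modulo.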
@@ -1005,13 +1037,6 @@ struct evaluator_impl<Reverse<ArgType, Direction> >
: evaluator_impl_base<Reverse<ArgType, Direction> >
{
typedef Reverse<ArgType, Direction> XprType;
evaluator_impl(const XprType& reverse)
: m_argImpl(reverse.nestedExpression()),
m_rows(reverse.nestedExpression().rows()),
m_cols(reverse.nestedExpression().cols())
{ }
typedef typename XprType::Index Index;
typedef typename XprType::Scalar Scalar;
typedef typename XprType::CoeffReturnType CoeffReturnType;
@@ -1032,61 +1057,70 @@ struct evaluator_impl<Reverse<ArgType, Direction> >
};
typedef internal::reverse_packet_cond<PacketScalar,ReversePacket> reverse_packet;
evaluator_impl(const XprType& reverse)
: m_argImpl(reverse.nestedExpression()),
m_rows(ReverseRow ? reverse.nestedExpression().rows() : 0),
m_cols(ReverseCol ? reverse.nestedExpression().cols() : 0)
{ }
CoeffReturnType coeff(Index row, Index col) const
{
return m_argImpl.coeff(ReverseRow ? m_rows - row - 1 : row,
ReverseCol ? m_cols - col - 1 : col);
return m_argImpl.coeff(ReverseRow ? m_rows.value() - row - 1 : row,
ReverseCol ? m_cols.value() - col - 1 : col);
}
CoeffReturnType coeff(Index index) const
{
return m_argImpl.coeff(m_rows * m_cols - index - 1);
return m_argImpl.coeff(m_rows.value() * m_cols.value() - index - 1);
}
Scalar& coeffRef(Index row, Index col)
{
return m_argImpl.coeffRef(ReverseRow ? m_rows - row - 1 : row,
ReverseCol ? m_cols - col - 1 : col);
return m_argImpl.coeffRef(ReverseRow ? m_rows.value() - row - 1 : row,
ReverseCol ? m_cols.value() - col - 1 : col);
}
Scalar& coeffRef(Index index)
{
return m_argImpl.coeffRef(m_rows * m_cols - index - 1);
return m_argImpl.coeffRef(m_rows.value() * m_cols.value() - index - 1);
}
template<int LoadMode>
PacketScalar packet(Index row, Index col) const
{
return reverse_packet::run(m_argImpl.template packet<LoadMode>(
ReverseRow ? m_rows - row - OffsetRow : row,
ReverseCol ? m_cols - col - OffsetCol : col));
ReverseRow ? m_rows.value() - row - OffsetRow : row,
ReverseCol ? m_cols.value() - col - OffsetCol : col));
}
template<int LoadMode>
PacketScalar packet(Index index) const
{
return preverse(m_argImpl.template packet<LoadMode>(m_rows * m_cols - index - PacketSize));
return preverse(m_argImpl.template packet<LoadMode>(m_rows.value() * m_cols.value() - index - PacketSize));
}
template<int LoadMode>
void writePacket(Index row, Index col, const PacketScalar& x)
{
m_argImpl.template writePacket<LoadMode>(
ReverseRow ? m_rows - row - OffsetRow : row,
ReverseCol ? m_cols - col - OffsetCol : col,
ReverseRow ? m_rows.value() - row - OffsetRow : row,
ReverseCol ? m_cols.value() - col - OffsetCol : col,
reverse_packet::run(x));
}
template<int LoadMode>
void writePacket(Index index, const PacketScalar& x)
{
m_argImpl.template writePacket<LoadMode>(m_rows * m_cols - index - PacketSize, preverse(x));
m_argImpl.template writePacket<LoadMode>
(m_rows.value() * m_cols.value() - index - PacketSize, preverse(x));
}
protected:
typename evaluator<ArgType>::nestedType m_argImpl;
Index m_rows; // TODO: Don't use if known at compile time or not needed
Index m_cols;
// If we do not reverse rows, then we do not need to know the number of rows; same for columns
const variable_if_dynamic<Index, ReverseRow ? ArgType::RowsAtCompileTime : 0> m_rows;
const variable_if_dynamic<Index, ReverseCol ? ArgType::ColsAtCompileTime : 0> m_cols;
};
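
A concrete instance of the comment above (illustrative): m.colwise().reverse() builds Reverse<..., Vertical>, which flips only the ordering of the rows.

// ReverseRow == true  : m_rows holds the argument's row count (stored, or
//                       folded to a constant if the size is fixed).
// ReverseCol == false : m_cols is variable_if_dynamic<Index, 0>; the 0
//                       passed by the constructor above is never consulted.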
@@ -1129,11 +1163,11 @@ struct evaluator_impl<Diagonal<ArgType, DiagIndex> >
protected:
typename evaluator<ArgType>::nestedType m_argImpl;
Index m_index; // TODO: Don't use if known at compile time
const internal::variable_if_dynamic<Index, XprType::DiagIndex> m_index;
private:
EIGEN_STRONG_INLINE Index rowOffset() const { return m_index>0 ? 0 : -m_index; }
EIGEN_STRONG_INLINE Index colOffset() const { return m_index>0 ? m_index : 0; }
EIGEN_STRONG_INLINE Index rowOffset() const { return m_index.value() > 0 ? 0 : -m_index.value(); }
EIGEN_STRONG_INLINE Index colOffset() const { return m_index.value() > 0 ? m_index.value() : 0; }
};
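
A worked example for the folded offsets (it depends on the Diagonal change in the second file below):

// m.diagonal<1>()  : DiagIndex == 1 at compile time, so m_index.value() is
//                    the literal 1; rowOffset() folds to 0 and colOffset()
//                    to 1 (the first superdiagonal).
// m.diagonal<-2>() : rowOffset() folds to 2, colOffset() to 0.
// m.diagonal(i)    : the index is not a compile-time constant, so m_index
//                    stores i and the offsets are computed at run time.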


@@ -77,11 +77,12 @@ struct traits<Diagonal<MatrixType,DiagIndex> >
};
}
template<typename MatrixType, int DiagIndex> class Diagonal
: public internal::dense_xpr_base< Diagonal<MatrixType,DiagIndex> >::type
template<typename MatrixType, int _DiagIndex> class Diagonal
: public internal::dense_xpr_base< Diagonal<MatrixType,_DiagIndex> >::type
{
public:
enum { DiagIndex = _DiagIndex };
typedef typename internal::dense_xpr_base<Diagonal>::type Base;
EIGEN_DENSE_PUBLIC_INTERFACE(Diagonal)
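
The rename to _DiagIndex exists so that the class can re-export the value under its old name; that is what lets the evaluator in the first file refer to XprType::DiagIndex. The pattern in miniature (hypothetical names, not Eigen code):

template<typename T, int _N>
struct Holder
{
  enum { N = _N };   // expose the non-type template parameter as a member constant
};
// Holder<float, 3>::N is now available to code that only sees the type,
// e.g. as the template argument of a variable_if_dynamic member.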