Make evaluators for Matrix and Array inherit from common base class.

This gets rid of some code duplication.
This commit is contained in:
Jitse Niesen 2011-04-04 15:35:14 +01:00
parent afdd26f229
commit ae06b8af5c

View File

@ -106,128 +106,92 @@ protected:
typename evaluator<ExpressionType>::type m_argImpl;
};
// -------------------- Matrix --------------------
// -------------------- Matrix and Array--------------------
//
// evaluator_impl<PlainObjectBase> is a common base class for the
// Matrix and Array evaluators.
template<typename Derived>
struct evaluator_impl<PlainObjectBase<Derived> >
{
typedef PlainObjectBase<Derived> PlainObjectType;
evaluator_impl(const PlainObjectType& m) : m_plainObject(m) {}
typedef typename PlainObjectType::Index Index;
typedef typename PlainObjectType::Scalar Scalar;
typedef typename PlainObjectType::CoeffReturnType CoeffReturnType;
typedef typename PlainObjectType::PacketScalar PacketScalar;
typedef typename PlainObjectType::PacketReturnType PacketReturnType;
CoeffReturnType coeff(Index i, Index j) const
{
return m_plainObject.coeff(i, j);
}
CoeffReturnType coeff(Index index) const
{
return m_plainObject.coeff(index);
}
Scalar& coeffRef(Index i, Index j)
{
return m_plainObject.const_cast_derived().coeffRef(i, j);
}
Scalar& coeffRef(Index index)
{
return m_plainObject.const_cast_derived().coeffRef(index);
}
template<int LoadMode>
PacketReturnType packet(Index row, Index col) const
{
return m_plainObject.template packet<LoadMode>(row, col);
}
template<int LoadMode>
PacketReturnType packet(Index index) const
{
return m_plainObject.template packet<LoadMode>(index);
}
template<int StoreMode>
void writePacket(Index row, Index col, const PacketScalar& x)
{
m_plainObject.const_cast_derived().template writePacket<StoreMode>(row, col, x);
}
template<int StoreMode>
void writePacket(Index index, const PacketScalar& x)
{
m_plainObject.const_cast_derived().template writePacket<StoreMode>(index, x);
}
protected:
const PlainObjectType &m_plainObject;
};
template<typename Scalar, int Rows, int Cols, int Options, int MaxRows, int MaxCols>
struct evaluator_impl<Matrix<Scalar, Rows, Cols, Options, MaxRows, MaxCols> >
: evaluator_impl<PlainObjectBase<Matrix<Scalar, Rows, Cols, Options, MaxRows, MaxCols> > >
{
typedef Matrix<Scalar, Rows, Cols, Options, MaxRows, MaxCols> MatrixType;
evaluator_impl(const MatrixType& m) : m_matrix(m) {}
typedef typename MatrixType::Index Index;
typename MatrixType::CoeffReturnType coeff(Index i, Index j) const
{
return m_matrix.coeff(i, j);
}
typename MatrixType::CoeffReturnType coeff(Index index) const
{
return m_matrix.coeff(index);
}
typename MatrixType::Scalar& coeffRef(Index i, Index j)
{
return m_matrix.const_cast_derived().coeffRef(i, j);
}
typename MatrixType::Scalar& coeffRef(Index index)
{
return m_matrix.const_cast_derived().coeffRef(index);
}
template<int LoadMode>
typename MatrixType::PacketReturnType packet(Index row, Index col) const
{
return m_matrix.template packet<LoadMode>(row, col);
}
template<int LoadMode>
typename MatrixType::PacketReturnType packet(Index index) const
{
// eigen_internal_assert(index >= 0 && index < size());
return m_matrix.template packet<LoadMode>(index);
}
template<int StoreMode>
void writePacket(Index row, Index col, const typename MatrixType::PacketScalar& x)
{
m_matrix.const_cast_derived().template writePacket<StoreMode>(row, col, x);
}
template<int StoreMode>
void writePacket(Index index, const typename MatrixType::PacketScalar& x)
{
// eigen_internal_assert(index >= 0 && index < size());
m_matrix.const_cast_derived().template writePacket<StoreMode>(index, x);
}
protected:
const MatrixType &m_matrix;
evaluator_impl(const MatrixType& m)
: evaluator_impl<PlainObjectBase<MatrixType> >(m)
{ }
};
// -------------------- Array --------------------
// TODO: should be sharing code with Matrix case
template<typename Scalar, int Rows, int Cols, int Options, int MaxRows, int MaxCols>
struct evaluator_impl<Array<Scalar, Rows, Cols, Options, MaxRows, MaxCols> >
: evaluator_impl<PlainObjectBase<Array<Scalar, Rows, Cols, Options, MaxRows, MaxCols> > >
{
typedef Array<Scalar, Rows, Cols, Options, MaxRows, MaxCols> ArrayType;
evaluator_impl(const ArrayType& a) : m_array(a) {}
typedef typename ArrayType::Index Index;
typename ArrayType::CoeffReturnType coeff(Index i, Index j) const
{
return m_array.coeff(i, j);
}
typename ArrayType::CoeffReturnType coeff(Index index) const
{
return m_array.coeff(index);
}
typename ArrayType::Scalar& coeffRef(Index i, Index j)
{
return m_array.const_cast_derived().coeffRef(i, j);
}
typename ArrayType::Scalar& coeffRef(Index index)
{
return m_array.const_cast_derived().coeffRef(index);
}
template<int LoadMode>
typename ArrayType::PacketReturnType packet(Index row, Index col) const
{
return m_array.template packet<LoadMode>(row, col);
}
template<int LoadMode>
typename ArrayType::PacketReturnType packet(Index index) const
{
// eigen_internal_assert(index >= 0 && index < size());
return m_array.template packet<LoadMode>(index);
}
template<int StoreMode>
void writePacket(Index row, Index col, const typename ArrayType::PacketScalar& x)
{
m_array.const_cast_derived().template writePacket<StoreMode>(row, col, x);
}
template<int StoreMode>
void writePacket(Index index, const typename ArrayType::PacketScalar& x)
{
// eigen_internal_assert(index >= 0 && index < size());
m_array.const_cast_derived().template writePacket<StoreMode>(index, x);
}
protected:
const ArrayType &m_array;
evaluator_impl(const ArrayType& m)
: evaluator_impl<PlainObjectBase<ArrayType> >(m)
{ }
};
// -------------------- CwiseNullaryOp --------------------
@ -400,8 +364,8 @@ struct evaluator_impl<Block<XprType, BlockRows, BlockCols, InnerPanel, /* HasDir
CoeffReturnType coeff(Index index) const
{
return m_argImpl.coeff(m_startRow + (RowsAtCompileTime == 1 ? 0 : index),
m_startCol + (RowsAtCompileTime == 1 ? index : 0));
return coeff(RowsAtCompileTime == 1 ? 0 : index,
RowsAtCompileTime == 1 ? index : 0);
}
Scalar& coeffRef(Index row, Index col)
@ -411,8 +375,8 @@ struct evaluator_impl<Block<XprType, BlockRows, BlockCols, InnerPanel, /* HasDir
Scalar& coeffRef(Index index)
{
return m_argImpl.coeffRef(m_startRow + (RowsAtCompileTime == 1 ? 0 : index),
m_startCol + (RowsAtCompileTime == 1 ? index : 0));
return coeffRef(RowsAtCompileTime == 1 ? 0 : index,
RowsAtCompileTime == 1 ? index : 0);
}
template<int LoadMode>
@ -424,8 +388,8 @@ struct evaluator_impl<Block<XprType, BlockRows, BlockCols, InnerPanel, /* HasDir
template<int LoadMode>
PacketReturnType packet(Index index) const
{
return m_argImpl.template packet<LoadMode>(m_startRow + (RowsAtCompileTime == 1 ? 0 : index),
m_startCol + (RowsAtCompileTime == 1 ? index : 0));
return packet<LoadMode>(RowsAtCompileTime == 1 ? 0 : index,
RowsAtCompileTime == 1 ? index : 0);
}
template<int StoreMode>
@ -437,9 +401,9 @@ struct evaluator_impl<Block<XprType, BlockRows, BlockCols, InnerPanel, /* HasDir
template<int StoreMode>
void writePacket(Index index, const PacketScalar& x)
{
return m_argImpl.template writePacket<StoreMode>(m_startRow + (RowsAtCompileTime == 1 ? 0 : index),
m_startCol + (RowsAtCompileTime == 1 ? index : 0),
x);
return writePacket<StoreMode>(RowsAtCompileTime == 1 ? 0 : index,
RowsAtCompileTime == 1 ? index : 0,
x);
}
protected: