Move implementation of coeff() &c to Matrix/Array evaluator.

This commit is contained in:
Jitse Niesen 2012-07-05 11:09:42 +01:00
parent 54d55aeaf6
commit cb9f3685d3

View File

@ -169,7 +169,19 @@ struct evaluator_impl<PlainObjectBase<Derived> >
{
typedef PlainObjectBase<Derived> PlainObjectType;
evaluator_impl(const PlainObjectType& m) : m_plainObject(m) {}
evaluator_impl(void)
{ }
evaluator_impl(const PlainObjectType& m)
{
init(m);
}
void init(const PlainObjectType& m)
{
m_data = m.data();
m_outerStride = m.outerStride();
}
typedef typename PlainObjectType::Index Index;
typedef typename PlainObjectType::Scalar Scalar;
@ -177,52 +189,71 @@ struct evaluator_impl<PlainObjectBase<Derived> >
typedef typename PlainObjectType::PacketScalar PacketScalar;
typedef typename PlainObjectType::PacketReturnType PacketReturnType;
CoeffReturnType coeff(Index i, Index j) const
enum {
IsRowMajor = PlainObjectType::IsRowMajor,
};
CoeffReturnType coeff(Index row, Index col) const
{
return m_plainObject.coeff(i, j);
if (IsRowMajor)
return m_data[row * m_outerStride + col];
else
return m_data[row + col * m_outerStride];
}
CoeffReturnType coeff(Index index) const
{
return m_plainObject.coeff(index);
return m_data[index];
}
Scalar& coeffRef(Index i, Index j)
Scalar& coeffRef(Index row, Index col)
{
return m_plainObject.const_cast_derived().coeffRef(i, j);
if (IsRowMajor)
return const_cast<Scalar*>(m_data)[row * m_outerStride + col];
else
return const_cast<Scalar*>(m_data)[row + col * m_outerStride];
}
Scalar& coeffRef(Index index)
{
return m_plainObject.const_cast_derived().coeffRef(index);
return const_cast<Scalar*>(m_data)[index];
}
template<int LoadMode>
PacketReturnType packet(Index row, Index col) const
{
return m_plainObject.template packet<LoadMode>(row, col);
if (IsRowMajor)
return ploadt<PacketScalar, LoadMode>(m_data + row * m_outerStride + col);
else
return ploadt<PacketScalar, LoadMode>(m_data + row + col * m_outerStride);
}
template<int LoadMode>
PacketReturnType packet(Index index) const
{
return m_plainObject.template packet<LoadMode>(index);
return ploadt<PacketScalar, LoadMode>(m_data + index);
}
template<int StoreMode>
void writePacket(Index row, Index col, const PacketScalar& x)
{
m_plainObject.const_cast_derived().template writePacket<StoreMode>(row, col, x);
if (IsRowMajor)
return pstoret<Scalar, PacketScalar, StoreMode>
(const_cast<Scalar*>(m_data) + row * m_outerStride + col, x);
else
return pstoret<Scalar, PacketScalar, StoreMode>
(const_cast<Scalar*>(m_data) + row + col * m_outerStride, x);
}
template<int StoreMode>
void writePacket(Index index, const PacketScalar& x)
{
m_plainObject.const_cast_derived().template writePacket<StoreMode>(index, x);
return pstoret<Scalar, PacketScalar, StoreMode>(const_cast<Scalar*>(m_data) + index, x);
}
protected:
const PlainObjectType &m_plainObject;
const Scalar *m_data;
Index m_outerStride;
};
template<typename Scalar, int Rows, int Cols, int Options, int MaxRows, int MaxCols>
@ -231,6 +262,9 @@ struct evaluator_impl<Matrix<Scalar, Rows, Cols, Options, MaxRows, MaxCols> >
{
typedef Matrix<Scalar, Rows, Cols, Options, MaxRows, MaxCols> XprType;
evaluator_impl(void)
{ }
evaluator_impl(const XprType& m)
: evaluator_impl<PlainObjectBase<XprType> >(m)
{ }
@ -242,6 +276,9 @@ struct evaluator_impl<Array<Scalar, Rows, Cols, Options, MaxRows, MaxCols> >
{
typedef Array<Scalar, Rows, Cols, Options, MaxRows, MaxCols> XprType;
evaluator_impl(void)
{ }
evaluator_impl(const XprType& m)
: evaluator_impl<PlainObjectBase<XprType> >(m)
{ }
@ -295,17 +332,25 @@ struct evaluator_impl<EvalToTemp<ArgType> >
typedef evaluator_impl<PlainObject> BaseType;
evaluator_impl(const XprType& xpr)
: BaseType(m_result)
{
noalias_copy_using_evaluator(m_result, xpr.arg());
};
init(xpr.arg());
}
// this constructor is used when nesting an EvalTo evaluator in another evaluator
// This constructor is used when nesting an EvalTo evaluator in another evaluator
evaluator_impl(const ArgType& arg)
: BaseType(m_result)
{
init(arg);
}
protected:
void init(const ArgType& arg)
{
// We can only initialize the base class evaluator after m_result is initialized.
// TODO: Redesign to get rid of inheritance, so that we can remove default c'tors in
// PlainObject, Matrix and Array evaluators.
noalias_copy_using_evaluator(m_result, arg);
};
BaseType::init(m_result);
}
protected:
PlainObject m_result;