Move implementation of coeff() &c to Matrix/Array evaluator.

This commit is contained in:
Jitse Niesen 2012-07-05 11:09:42 +01:00
parent 54d55aeaf6
commit cb9f3685d3

View File

@ -169,7 +169,19 @@ struct evaluator_impl<PlainObjectBase<Derived> >
{ {
typedef PlainObjectBase<Derived> PlainObjectType; typedef PlainObjectBase<Derived> PlainObjectType;
evaluator_impl(const PlainObjectType& m) : m_plainObject(m) {} evaluator_impl(void)
{ }
evaluator_impl(const PlainObjectType& m)
{
init(m);
}
void init(const PlainObjectType& m)
{
m_data = m.data();
m_outerStride = m.outerStride();
}
typedef typename PlainObjectType::Index Index; typedef typename PlainObjectType::Index Index;
typedef typename PlainObjectType::Scalar Scalar; typedef typename PlainObjectType::Scalar Scalar;
@ -177,52 +189,71 @@ struct evaluator_impl<PlainObjectBase<Derived> >
typedef typename PlainObjectType::PacketScalar PacketScalar; typedef typename PlainObjectType::PacketScalar PacketScalar;
typedef typename PlainObjectType::PacketReturnType PacketReturnType; typedef typename PlainObjectType::PacketReturnType PacketReturnType;
CoeffReturnType coeff(Index i, Index j) const enum {
IsRowMajor = PlainObjectType::IsRowMajor,
};
CoeffReturnType coeff(Index row, Index col) const
{ {
return m_plainObject.coeff(i, j); if (IsRowMajor)
return m_data[row * m_outerStride + col];
else
return m_data[row + col * m_outerStride];
} }
CoeffReturnType coeff(Index index) const CoeffReturnType coeff(Index index) const
{ {
return m_plainObject.coeff(index); return m_data[index];
} }
Scalar& coeffRef(Index i, Index j) Scalar& coeffRef(Index row, Index col)
{ {
return m_plainObject.const_cast_derived().coeffRef(i, j); if (IsRowMajor)
return const_cast<Scalar*>(m_data)[row * m_outerStride + col];
else
return const_cast<Scalar*>(m_data)[row + col * m_outerStride];
} }
Scalar& coeffRef(Index index) Scalar& coeffRef(Index index)
{ {
return m_plainObject.const_cast_derived().coeffRef(index); return const_cast<Scalar*>(m_data)[index];
} }
template<int LoadMode> template<int LoadMode>
PacketReturnType packet(Index row, Index col) const PacketReturnType packet(Index row, Index col) const
{ {
return m_plainObject.template packet<LoadMode>(row, col); if (IsRowMajor)
return ploadt<PacketScalar, LoadMode>(m_data + row * m_outerStride + col);
else
return ploadt<PacketScalar, LoadMode>(m_data + row + col * m_outerStride);
} }
template<int LoadMode> template<int LoadMode>
PacketReturnType packet(Index index) const PacketReturnType packet(Index index) const
{ {
return m_plainObject.template packet<LoadMode>(index); return ploadt<PacketScalar, LoadMode>(m_data + index);
} }
template<int StoreMode> template<int StoreMode>
void writePacket(Index row, Index col, const PacketScalar& x) void writePacket(Index row, Index col, const PacketScalar& x)
{ {
m_plainObject.const_cast_derived().template writePacket<StoreMode>(row, col, x); if (IsRowMajor)
return pstoret<Scalar, PacketScalar, StoreMode>
(const_cast<Scalar*>(m_data) + row * m_outerStride + col, x);
else
return pstoret<Scalar, PacketScalar, StoreMode>
(const_cast<Scalar*>(m_data) + row + col * m_outerStride, x);
} }
template<int StoreMode> template<int StoreMode>
void writePacket(Index index, const PacketScalar& x) void writePacket(Index index, const PacketScalar& x)
{ {
m_plainObject.const_cast_derived().template writePacket<StoreMode>(index, x); return pstoret<Scalar, PacketScalar, StoreMode>(const_cast<Scalar*>(m_data) + index, x);
} }
protected: protected:
const PlainObjectType &m_plainObject; const Scalar *m_data;
Index m_outerStride;
}; };
template<typename Scalar, int Rows, int Cols, int Options, int MaxRows, int MaxCols> template<typename Scalar, int Rows, int Cols, int Options, int MaxRows, int MaxCols>
@ -231,6 +262,9 @@ struct evaluator_impl<Matrix<Scalar, Rows, Cols, Options, MaxRows, MaxCols> >
{ {
typedef Matrix<Scalar, Rows, Cols, Options, MaxRows, MaxCols> XprType; typedef Matrix<Scalar, Rows, Cols, Options, MaxRows, MaxCols> XprType;
evaluator_impl(void)
{ }
evaluator_impl(const XprType& m) evaluator_impl(const XprType& m)
: evaluator_impl<PlainObjectBase<XprType> >(m) : evaluator_impl<PlainObjectBase<XprType> >(m)
{ } { }
@ -242,6 +276,9 @@ struct evaluator_impl<Array<Scalar, Rows, Cols, Options, MaxRows, MaxCols> >
{ {
typedef Array<Scalar, Rows, Cols, Options, MaxRows, MaxCols> XprType; typedef Array<Scalar, Rows, Cols, Options, MaxRows, MaxCols> XprType;
evaluator_impl(void)
{ }
evaluator_impl(const XprType& m) evaluator_impl(const XprType& m)
: evaluator_impl<PlainObjectBase<XprType> >(m) : evaluator_impl<PlainObjectBase<XprType> >(m)
{ } { }
@ -295,17 +332,25 @@ struct evaluator_impl<EvalToTemp<ArgType> >
typedef evaluator_impl<PlainObject> BaseType; typedef evaluator_impl<PlainObject> BaseType;
evaluator_impl(const XprType& xpr) evaluator_impl(const XprType& xpr)
: BaseType(m_result)
{ {
noalias_copy_using_evaluator(m_result, xpr.arg()); init(xpr.arg());
}; }
// this constructor is used when nesting an EvalTo evaluator in another evaluator // This constructor is used when nesting an EvalTo evaluator in another evaluator
evaluator_impl(const ArgType& arg) evaluator_impl(const ArgType& arg)
: BaseType(m_result)
{ {
init(arg);
}
protected:
void init(const ArgType& arg)
{
// We can only initialize the base class evaluator after m_result is initialized.
// TODO: Redesign to get rid of inheritance, so that we can remove default c'tors in
// PlainObject, Matrix and Array evaluators.
noalias_copy_using_evaluator(m_result, arg); noalias_copy_using_evaluator(m_result, arg);
}; BaseType::init(m_result);
}
protected: protected:
PlainObject m_result; PlainObject m_result;