Implement evaluators for ArrayWrapper and MatrixWrapper.

This commit is contained in:
Jitse Niesen 2011-04-22 22:36:45 +01:00
parent 6441e8727b
commit bb2d70d211
4 changed files with 109 additions and 1 deletions

View File

@ -119,6 +119,12 @@ class ArrayWrapper : public ArrayBase<ArrayWrapper<ExpressionType> >
template<typename Dest>
inline void evalTo(Dest& dst) const { dst = m_expression; }
const typename internal::remove_all<NestedExpressionType>::type&
nestedExpression() const
{
return m_expression;
}
protected:
const NestedExpressionType m_expression;
};
@ -214,6 +220,12 @@ class MatrixWrapper : public MatrixBase<MatrixWrapper<ExpressionType> >
m_expression.const_cast_derived().template writePacket<LoadMode>(index, x);
}
const typename internal::remove_all<NestedExpressionType>::type&
nestedExpression() const
{
return m_expression;
}
protected:
const NestedExpressionType m_expression;
};

View File

@ -106,7 +106,7 @@ protected:
typename evaluator<ExpressionType>::type m_argImpl;
};
// -------------------- Matrix and Array--------------------
// -------------------- Matrix and Array --------------------
//
// evaluator_impl<PlainObjectBase> is a common base class for the
// Matrix and Array evaluators.
@ -704,6 +704,89 @@ protected:
};
// -------------------- MatrixWrapper and ArrayWrapper --------------------
//
// evaluator_impl_wrapper_base<T> is a common base class for the
// MatrixWrapper and ArrayWrapper evaluators.
template<typename ArgType>
struct evaluator_impl_wrapper_base
{
evaluator_impl_wrapper_base(const ArgType& arg) : m_argImpl(arg) {}
typedef typename ArgType::Index Index;
typedef typename ArgType::Scalar Scalar;
typedef typename ArgType::CoeffReturnType CoeffReturnType;
typedef typename ArgType::PacketScalar PacketScalar;
typedef typename ArgType::PacketReturnType PacketReturnType;
CoeffReturnType coeff(Index row, Index col) const
{
return m_argImpl.coeff(row, col);
}
CoeffReturnType coeff(Index index) const
{
return m_argImpl.coeff(index);
}
Scalar& coeffRef(Index row, Index col)
{
return m_argImpl.coeffRef(row, col);
}
Scalar& coeffRef(Index index)
{
return m_argImpl.coeffRef(index);
}
template<int LoadMode>
PacketReturnType packet(Index row, Index col) const
{
return m_argImpl.template packet<LoadMode>(row, col);
}
template<int LoadMode>
PacketReturnType packet(Index index) const
{
return m_argImpl.template packet<LoadMode>(index);
}
template<int StoreMode>
void writePacket(Index row, Index col, const PacketScalar& x)
{
m_argImpl.template writePacket<StoreMode>(row, col, x);
}
template<int StoreMode>
void writePacket(Index index, const PacketScalar& x)
{
m_argImpl.template writePacket<StoreMode>(index, x);
}
protected:
typename evaluator<ArgType>::type m_argImpl;
};
template<typename ArgType>
struct evaluator_impl<MatrixWrapper<ArgType> >
: evaluator_impl_wrapper_base<ArgType>
{
evaluator_impl(const MatrixWrapper<ArgType>& wrapper)
: evaluator_impl_wrapper_base<ArgType>(wrapper.nestedExpression())
{ }
};
template<typename ArgType>
struct evaluator_impl<ArrayWrapper<ArgType> >
: evaluator_impl_wrapper_base<ArgType>
{
evaluator_impl(const ArrayWrapper<ArgType>& wrapper)
: evaluator_impl_wrapper_base<ArgType>(wrapper.nestedExpression())
{ }
};
} // namespace internal
#endif // EIGEN_COREEVALUATORS_H

View File

@ -133,6 +133,7 @@ template<typename ExpressionType> class WithFormat;
template<typename MatrixType> struct CommaInitializer;
template<typename Derived> class ReturnByValue;
template<typename ExpressionType> class ArrayWrapper;
template<typename ExpressionType> class MatrixWrapper;
namespace internal {
template<typename DecompositionType, typename Rhs> struct solve_retval_base;

View File

@ -180,4 +180,16 @@ void test_evaluators()
VectorXd vec1(6);
VERIFY_IS_APPROX_EVALUATOR(vec1, mat1.rowwise().sum());
VERIFY_IS_APPROX_EVALUATOR(vec1, mat1.colwise().sum().transpose());
// test MatrixWrapper and ArrayWrapper
mat1.setRandom(6,6);
arr1.setRandom(6,6);
VERIFY_IS_APPROX_EVALUATOR(mat2, arr1.matrix());
VERIFY_IS_APPROX_EVALUATOR(arr2, mat1.array());
VERIFY_IS_APPROX_EVALUATOR(mat2, (arr1 + 2).matrix());
VERIFY_IS_APPROX_EVALUATOR(arr2, mat1.array() + 2);
mat2.array() = arr1 * arr1;
VERIFY_IS_APPROX(mat2, (arr1 * arr1).matrix());
arr2.matrix() = MatrixXd::Identity(6,6);
VERIFY_IS_APPROX(arr2, MatrixXd::Identity(6,6).array());
}