Properly implement PartialReduxExpr on top of evaluators, and fix multiple evaluation of nested expression

This commit is contained in:
Gael Guennebaud 2015-10-08 15:57:05 +02:00
parent 5cc7251188
commit aa6b1aebf3
3 changed files with 40 additions and 43 deletions

View File

@ -965,17 +965,16 @@ protected:
// -------------------- PartialReduxExpr --------------------
//
// This is a wrapper around the expression object.
// TODO: Find out how to write a proper evaluator without duplicating
// the row() and col() member functions.
template< typename ArgType, typename MemberOp, int Direction>
struct evaluator<PartialReduxExpr<ArgType, MemberOp, Direction> >
: evaluator_base<PartialReduxExpr<ArgType, MemberOp, Direction> >
{
typedef PartialReduxExpr<ArgType, MemberOp, Direction> XprType;
typedef typename XprType::Scalar InputScalar;
typedef typename internal::nested_eval<ArgType,1>::type ArgTypeNested;
typedef typename internal::remove_all<ArgTypeNested>::type ArgTypeNestedCleaned;
typedef typename ArgType::Scalar InputScalar;
typedef typename XprType::Scalar Scalar;
enum {
TraversalSize = Direction==int(Vertical) ? int(ArgType::RowsAtCompileTime) : int(XprType::ColsAtCompileTime)
};
@ -986,27 +985,34 @@ struct evaluator<PartialReduxExpr<ArgType, MemberOp, Direction> >
Flags = (traits<XprType>::Flags&RowMajorBit) | (evaluator<ArgType>::Flags&HereditaryBits),
Alignment = 0 // FIXME this could be improved
Alignment = 0 // FIXME this will need to be improved once PartialReduxExpr is vectorized
};
EIGEN_DEVICE_FUNC explicit evaluator(const XprType expr)
: m_expr(expr)
EIGEN_DEVICE_FUNC explicit evaluator(const XprType xpr)
: m_arg(xpr.nestedExpression()), m_functor(xpr.functor())
{}
typedef typename XprType::CoeffReturnType CoeffReturnType;
EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index row, Index col) const
{
return m_expr.coeff(row, col);
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar coeff(Index i, Index j) const
{
if (Direction==Vertical)
return m_functor(m_arg.col(j));
else
return m_functor(m_arg.row(i));
}
EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index index) const
{
return m_expr.coeff(index);
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar coeff(Index index) const
{
if (Direction==Vertical)
return m_functor(m_arg.col(index));
else
return m_functor(m_arg.row(index));
}
protected:
const XprType m_expr;
const ArgTypeNested m_arg;
const MemberOp m_functor;
};

View File

@ -41,8 +41,6 @@ struct traits<PartialReduxExpr<MatrixType, MemberOp, Direction> >
typedef typename traits<MatrixType>::StorageKind StorageKind;
typedef typename traits<MatrixType>::XprKind XprKind;
typedef typename MatrixType::Scalar InputScalar;
typedef typename ref_selector<MatrixType>::type MatrixTypeNested;
typedef typename remove_all<MatrixTypeNested>::type _MatrixTypeNested;
enum {
RowsAtCompileTime = Direction==Vertical ? 1 : MatrixType::RowsAtCompileTime,
ColsAtCompileTime = Direction==Horizontal ? 1 : MatrixType::ColsAtCompileTime,
@ -62,8 +60,6 @@ class PartialReduxExpr : public internal::dense_xpr_base< PartialReduxExpr<Matri
typedef typename internal::dense_xpr_base<PartialReduxExpr>::type Base;
EIGEN_DENSE_PUBLIC_INTERFACE(PartialReduxExpr)
typedef typename internal::traits<PartialReduxExpr>::MatrixTypeNested MatrixTypeNested;
typedef typename internal::traits<PartialReduxExpr>::_MatrixTypeNested _MatrixTypeNested;
EIGEN_DEVICE_FUNC
explicit PartialReduxExpr(const MatrixType& mat, const MemberOp& func = MemberOp())
@ -74,24 +70,11 @@ class PartialReduxExpr : public internal::dense_xpr_base< PartialReduxExpr<Matri
EIGEN_DEVICE_FUNC
Index cols() const { return (Direction==Horizontal ? 1 : m_matrix.cols()); }
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar coeff(Index i, Index j) const
{
if (Direction==Vertical)
return m_functor(m_matrix.col(j));
else
return m_functor(m_matrix.row(i));
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar coeff(Index index) const
{
if (Direction==Vertical)
return m_functor(m_matrix.col(index));
else
return m_functor(m_matrix.row(index));
}
typename MatrixType::Nested nestedExpression() const { return m_matrix; }
const MemberOp& functor() const { return m_functor; }
protected:
MatrixTypeNested m_matrix;
typename MatrixType::Nested m_matrix;
const MemberOp m_functor;
};

View File

@ -2,11 +2,13 @@
// for linear algebra.
//
// Copyright (C) 2011 Benoit Jacob <jacob.benoit.1@gmail.com>
// Copyright (C) 2015 Gael Guennebaud <gael.guennebaud@inria.fr>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#define TEST_ENABLE_TEMPORARY_TRACKING
#define EIGEN_NO_STATIC_ASSERT
#include "main.h"
@ -209,14 +211,20 @@ template<typename MatrixType> void vectorwiseop_matrix(const MatrixType& m)
m2 = m1;
m2.rowwise().normalize();
VERIFY_IS_APPROX(m2.row(r), m1.row(r).normalized());
// test with partial reduction of products
Matrix<Scalar,MatrixType::RowsAtCompileTime,MatrixType::RowsAtCompileTime> m1m1 = m1 * m1.transpose();
VERIFY_IS_APPROX( (m1 * m1.transpose()).colwise().sum(), m1m1.colwise().sum());
Matrix<Scalar,1,MatrixType::RowsAtCompileTime> tmp(rows);
VERIFY_EVALUATION_COUNT( tmp = (m1 * m1.transpose()).colwise().sum(), (MatrixType::RowsAtCompileTime==Dynamic ? 1 : 0));
}
void test_vectorwiseop()
{
CALL_SUBTEST_1(vectorwiseop_array(Array22cd()));
CALL_SUBTEST_2(vectorwiseop_array(Array<double, 3, 2>()));
CALL_SUBTEST_3(vectorwiseop_array(ArrayXXf(3, 4)));
CALL_SUBTEST_4(vectorwiseop_matrix(Matrix4cf()));
CALL_SUBTEST_5(vectorwiseop_matrix(Matrix<float,4,5>()));
CALL_SUBTEST_6(vectorwiseop_matrix(MatrixXd(7,2)));
CALL_SUBTEST_1( vectorwiseop_array(Array22cd()) );
CALL_SUBTEST_2( vectorwiseop_array(Array<double, 3, 2>()) );
CALL_SUBTEST_3( vectorwiseop_array(ArrayXXf(3, 4)) );
CALL_SUBTEST_4( vectorwiseop_matrix(Matrix4cf()) );
CALL_SUBTEST_5( vectorwiseop_matrix(Matrix<float,4,5>()) );
CALL_SUBTEST_6( vectorwiseop_matrix(MatrixXd(internal::random<int>(1,EIGEN_TEST_MAX_SIZE), internal::random<int>(1,EIGEN_TEST_MAX_SIZE))) );
}