mirror of
https://gitlab.com/libeigen/eigen.git
synced 2024-12-15 07:10:37 +08:00
Properly implement PartialReduxExpr on top of evaluators, and fix multiple evaluation of nested expression
This commit is contained in:
parent
5cc7251188
commit
aa6b1aebf3
@ -965,17 +965,16 @@ protected:
|
||||
|
||||
|
||||
// -------------------- PartialReduxExpr --------------------
|
||||
//
|
||||
// This is a wrapper around the expression object.
|
||||
// TODO: Find out how to write a proper evaluator without duplicating
|
||||
// the row() and col() member functions.
|
||||
|
||||
template< typename ArgType, typename MemberOp, int Direction>
|
||||
struct evaluator<PartialReduxExpr<ArgType, MemberOp, Direction> >
|
||||
: evaluator_base<PartialReduxExpr<ArgType, MemberOp, Direction> >
|
||||
{
|
||||
typedef PartialReduxExpr<ArgType, MemberOp, Direction> XprType;
|
||||
typedef typename XprType::Scalar InputScalar;
|
||||
typedef typename internal::nested_eval<ArgType,1>::type ArgTypeNested;
|
||||
typedef typename internal::remove_all<ArgTypeNested>::type ArgTypeNestedCleaned;
|
||||
typedef typename ArgType::Scalar InputScalar;
|
||||
typedef typename XprType::Scalar Scalar;
|
||||
enum {
|
||||
TraversalSize = Direction==int(Vertical) ? int(ArgType::RowsAtCompileTime) : int(XprType::ColsAtCompileTime)
|
||||
};
|
||||
@ -986,27 +985,34 @@ struct evaluator<PartialReduxExpr<ArgType, MemberOp, Direction> >
|
||||
|
||||
Flags = (traits<XprType>::Flags&RowMajorBit) | (evaluator<ArgType>::Flags&HereditaryBits),
|
||||
|
||||
Alignment = 0 // FIXME this could be improved
|
||||
Alignment = 0 // FIXME this will need to be improved once PartialReduxExpr is vectorized
|
||||
};
|
||||
|
||||
EIGEN_DEVICE_FUNC explicit evaluator(const XprType expr)
|
||||
: m_expr(expr)
|
||||
EIGEN_DEVICE_FUNC explicit evaluator(const XprType xpr)
|
||||
: m_arg(xpr.nestedExpression()), m_functor(xpr.functor())
|
||||
{}
|
||||
|
||||
typedef typename XprType::CoeffReturnType CoeffReturnType;
|
||||
|
||||
EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index row, Index col) const
|
||||
{
|
||||
return m_expr.coeff(row, col);
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar coeff(Index i, Index j) const
|
||||
{
|
||||
if (Direction==Vertical)
|
||||
return m_functor(m_arg.col(j));
|
||||
else
|
||||
return m_functor(m_arg.row(i));
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index index) const
|
||||
{
|
||||
return m_expr.coeff(index);
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar coeff(Index index) const
|
||||
{
|
||||
if (Direction==Vertical)
|
||||
return m_functor(m_arg.col(index));
|
||||
else
|
||||
return m_functor(m_arg.row(index));
|
||||
}
|
||||
|
||||
protected:
|
||||
const XprType m_expr;
|
||||
const ArgTypeNested m_arg;
|
||||
const MemberOp m_functor;
|
||||
};
|
||||
|
||||
|
||||
|
@ -41,8 +41,6 @@ struct traits<PartialReduxExpr<MatrixType, MemberOp, Direction> >
|
||||
typedef typename traits<MatrixType>::StorageKind StorageKind;
|
||||
typedef typename traits<MatrixType>::XprKind XprKind;
|
||||
typedef typename MatrixType::Scalar InputScalar;
|
||||
typedef typename ref_selector<MatrixType>::type MatrixTypeNested;
|
||||
typedef typename remove_all<MatrixTypeNested>::type _MatrixTypeNested;
|
||||
enum {
|
||||
RowsAtCompileTime = Direction==Vertical ? 1 : MatrixType::RowsAtCompileTime,
|
||||
ColsAtCompileTime = Direction==Horizontal ? 1 : MatrixType::ColsAtCompileTime,
|
||||
@ -62,8 +60,6 @@ class PartialReduxExpr : public internal::dense_xpr_base< PartialReduxExpr<Matri
|
||||
|
||||
typedef typename internal::dense_xpr_base<PartialReduxExpr>::type Base;
|
||||
EIGEN_DENSE_PUBLIC_INTERFACE(PartialReduxExpr)
|
||||
typedef typename internal::traits<PartialReduxExpr>::MatrixTypeNested MatrixTypeNested;
|
||||
typedef typename internal::traits<PartialReduxExpr>::_MatrixTypeNested _MatrixTypeNested;
|
||||
|
||||
EIGEN_DEVICE_FUNC
|
||||
explicit PartialReduxExpr(const MatrixType& mat, const MemberOp& func = MemberOp())
|
||||
@ -74,24 +70,11 @@ class PartialReduxExpr : public internal::dense_xpr_base< PartialReduxExpr<Matri
|
||||
EIGEN_DEVICE_FUNC
|
||||
Index cols() const { return (Direction==Horizontal ? 1 : m_matrix.cols()); }
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar coeff(Index i, Index j) const
|
||||
{
|
||||
if (Direction==Vertical)
|
||||
return m_functor(m_matrix.col(j));
|
||||
else
|
||||
return m_functor(m_matrix.row(i));
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar coeff(Index index) const
|
||||
{
|
||||
if (Direction==Vertical)
|
||||
return m_functor(m_matrix.col(index));
|
||||
else
|
||||
return m_functor(m_matrix.row(index));
|
||||
}
|
||||
typename MatrixType::Nested nestedExpression() const { return m_matrix; }
|
||||
const MemberOp& functor() const { return m_functor; }
|
||||
|
||||
protected:
|
||||
MatrixTypeNested m_matrix;
|
||||
typename MatrixType::Nested m_matrix;
|
||||
const MemberOp m_functor;
|
||||
};
|
||||
|
||||
|
@ -2,11 +2,13 @@
|
||||
// for linear algebra.
|
||||
//
|
||||
// Copyright (C) 2011 Benoit Jacob <jacob.benoit.1@gmail.com>
|
||||
// Copyright (C) 2015 Gael Guennebaud <gael.guennebaud@inria.fr>
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla
|
||||
// Public License v. 2.0. If a copy of the MPL was not distributed
|
||||
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
#define TEST_ENABLE_TEMPORARY_TRACKING
|
||||
#define EIGEN_NO_STATIC_ASSERT
|
||||
|
||||
#include "main.h"
|
||||
@ -209,14 +211,20 @@ template<typename MatrixType> void vectorwiseop_matrix(const MatrixType& m)
|
||||
m2 = m1;
|
||||
m2.rowwise().normalize();
|
||||
VERIFY_IS_APPROX(m2.row(r), m1.row(r).normalized());
|
||||
|
||||
// test with partial reduction of products
|
||||
Matrix<Scalar,MatrixType::RowsAtCompileTime,MatrixType::RowsAtCompileTime> m1m1 = m1 * m1.transpose();
|
||||
VERIFY_IS_APPROX( (m1 * m1.transpose()).colwise().sum(), m1m1.colwise().sum());
|
||||
Matrix<Scalar,1,MatrixType::RowsAtCompileTime> tmp(rows);
|
||||
VERIFY_EVALUATION_COUNT( tmp = (m1 * m1.transpose()).colwise().sum(), (MatrixType::RowsAtCompileTime==Dynamic ? 1 : 0));
|
||||
}
|
||||
|
||||
void test_vectorwiseop()
|
||||
{
|
||||
CALL_SUBTEST_1(vectorwiseop_array(Array22cd()));
|
||||
CALL_SUBTEST_2(vectorwiseop_array(Array<double, 3, 2>()));
|
||||
CALL_SUBTEST_3(vectorwiseop_array(ArrayXXf(3, 4)));
|
||||
CALL_SUBTEST_4(vectorwiseop_matrix(Matrix4cf()));
|
||||
CALL_SUBTEST_5(vectorwiseop_matrix(Matrix<float,4,5>()));
|
||||
CALL_SUBTEST_6(vectorwiseop_matrix(MatrixXd(7,2)));
|
||||
CALL_SUBTEST_1( vectorwiseop_array(Array22cd()) );
|
||||
CALL_SUBTEST_2( vectorwiseop_array(Array<double, 3, 2>()) );
|
||||
CALL_SUBTEST_3( vectorwiseop_array(ArrayXXf(3, 4)) );
|
||||
CALL_SUBTEST_4( vectorwiseop_matrix(Matrix4cf()) );
|
||||
CALL_SUBTEST_5( vectorwiseop_matrix(Matrix<float,4,5>()) );
|
||||
CALL_SUBTEST_6( vectorwiseop_matrix(MatrixXd(internal::random<int>(1,EIGEN_TEST_MAX_SIZE), internal::random<int>(1,EIGEN_TEST_MAX_SIZE))) );
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user