diff --git a/Eigen/src/Core/CoreEvaluators.h b/Eigen/src/Core/CoreEvaluators.h index 214114ebe..b96ef99fa 100644 --- a/Eigen/src/Core/CoreEvaluators.h +++ b/Eigen/src/Core/CoreEvaluators.h @@ -965,17 +965,16 @@ protected: // -------------------- PartialReduxExpr -------------------- -// -// This is a wrapper around the expression object. -// TODO: Find out how to write a proper evaluator without duplicating -// the row() and col() member functions. template< typename ArgType, typename MemberOp, int Direction> struct evaluator > : evaluator_base > { typedef PartialReduxExpr XprType; - typedef typename XprType::Scalar InputScalar; + typedef typename internal::nested_eval::type ArgTypeNested; + typedef typename internal::remove_all::type ArgTypeNestedCleaned; + typedef typename ArgType::Scalar InputScalar; + typedef typename XprType::Scalar Scalar; enum { TraversalSize = Direction==int(Vertical) ? int(ArgType::RowsAtCompileTime) : int(XprType::ColsAtCompileTime) }; @@ -986,27 +985,34 @@ struct evaluator > Flags = (traits::Flags&RowMajorBit) | (evaluator::Flags&HereditaryBits), - Alignment = 0 // FIXME this could be improved + Alignment = 0 // FIXME this will need to be improved once PartialReduxExpr is vectorized }; - EIGEN_DEVICE_FUNC explicit evaluator(const XprType expr) - : m_expr(expr) + EIGEN_DEVICE_FUNC explicit evaluator(const XprType xpr) + : m_arg(xpr.nestedExpression()), m_functor(xpr.functor()) {} typedef typename XprType::CoeffReturnType CoeffReturnType; - - EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index row, Index col) const - { - return m_expr.coeff(row, col); + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar coeff(Index i, Index j) const + { + if (Direction==Vertical) + return m_functor(m_arg.col(j)); + else + return m_functor(m_arg.row(i)); } - - EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index index) const - { - return m_expr.coeff(index); + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar coeff(Index index) const + { + if (Direction==Vertical) + return m_functor(m_arg.col(index)); + else + return m_functor(m_arg.row(index)); } protected: - const XprType m_expr; + const ArgTypeNested m_arg; + const MemberOp m_functor; }; diff --git a/Eigen/src/Core/VectorwiseOp.h b/Eigen/src/Core/VectorwiseOp.h index 79c7d135d..5de53732e 100644 --- a/Eigen/src/Core/VectorwiseOp.h +++ b/Eigen/src/Core/VectorwiseOp.h @@ -41,8 +41,6 @@ struct traits > typedef typename traits::StorageKind StorageKind; typedef typename traits::XprKind XprKind; typedef typename MatrixType::Scalar InputScalar; - typedef typename ref_selector::type MatrixTypeNested; - typedef typename remove_all::type _MatrixTypeNested; enum { RowsAtCompileTime = Direction==Vertical ? 1 : MatrixType::RowsAtCompileTime, ColsAtCompileTime = Direction==Horizontal ? 1 : MatrixType::ColsAtCompileTime, @@ -62,8 +60,6 @@ class PartialReduxExpr : public internal::dense_xpr_base< PartialReduxExpr::type Base; EIGEN_DENSE_PUBLIC_INTERFACE(PartialReduxExpr) - typedef typename internal::traits::MatrixTypeNested MatrixTypeNested; - typedef typename internal::traits::_MatrixTypeNested _MatrixTypeNested; EIGEN_DEVICE_FUNC explicit PartialReduxExpr(const MatrixType& mat, const MemberOp& func = MemberOp()) @@ -74,24 +70,11 @@ class PartialReduxExpr : public internal::dense_xpr_base< PartialReduxExpr +// Copyright (C) 2015 Gael Guennebaud // // This Source Code Form is subject to the terms of the Mozilla // Public License v. 2.0. If a copy of the MPL was not distributed // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. +#define TEST_ENABLE_TEMPORARY_TRACKING #define EIGEN_NO_STATIC_ASSERT #include "main.h" @@ -209,14 +211,20 @@ template void vectorwiseop_matrix(const MatrixType& m) m2 = m1; m2.rowwise().normalize(); VERIFY_IS_APPROX(m2.row(r), m1.row(r).normalized()); + + // test with partial reduction of products + Matrix m1m1 = m1 * m1.transpose(); + VERIFY_IS_APPROX( (m1 * m1.transpose()).colwise().sum(), m1m1.colwise().sum()); + Matrix tmp(rows); + VERIFY_EVALUATION_COUNT( tmp = (m1 * m1.transpose()).colwise().sum(), (MatrixType::RowsAtCompileTime==Dynamic ? 1 : 0)); } void test_vectorwiseop() { - CALL_SUBTEST_1(vectorwiseop_array(Array22cd())); - CALL_SUBTEST_2(vectorwiseop_array(Array())); - CALL_SUBTEST_3(vectorwiseop_array(ArrayXXf(3, 4))); - CALL_SUBTEST_4(vectorwiseop_matrix(Matrix4cf())); - CALL_SUBTEST_5(vectorwiseop_matrix(Matrix())); - CALL_SUBTEST_6(vectorwiseop_matrix(MatrixXd(7,2))); + CALL_SUBTEST_1( vectorwiseop_array(Array22cd()) ); + CALL_SUBTEST_2( vectorwiseop_array(Array()) ); + CALL_SUBTEST_3( vectorwiseop_array(ArrayXXf(3, 4)) ); + CALL_SUBTEST_4( vectorwiseop_matrix(Matrix4cf()) ); + CALL_SUBTEST_5( vectorwiseop_matrix(Matrix()) ); + CALL_SUBTEST_6( vectorwiseop_matrix(MatrixXd(internal::random(1,EIGEN_TEST_MAX_SIZE), internal::random(1,EIGEN_TEST_MAX_SIZE))) ); }