From 27c068e9d6230398b74a1c7b7146d7842c509de7 Mon Sep 17 00:00:00 2001 From: Gael Guennebaud Date: Fri, 13 Dec 2013 18:09:07 +0100 Subject: [PATCH] Make selfqdjoint products use evaluators --- Eigen/src/Core/ProductEvaluators.h | 71 ++++++++++++++++++++++++++++++ Eigen/src/Core/SelfAdjointView.h | 43 ++++++++++++++++-- test/evaluators.cpp | 3 +- 3 files changed, 113 insertions(+), 4 deletions(-) diff --git a/Eigen/src/Core/ProductEvaluators.h b/Eigen/src/Core/ProductEvaluators.h index c3a9f0db4..f0eb57d67 100644 --- a/Eigen/src/Core/ProductEvaluators.h +++ b/Eigen/src/Core/ProductEvaluators.h @@ -628,6 +628,77 @@ protected: +/*************************************************************************** +* SelfAdjoint products +***************************************************************************/ + +template +struct generic_product_impl + : generic_product_impl_base > +{ + typedef typename Product::Scalar Scalar; + + template + static void scaleAndAddTo(Dest& dst, const Lhs& lhs, const Rhs& rhs, const Scalar& alpha) + {// SelfadjointProductMatrix + // TODO bypass SelfadjointProductMatrix class + SelfadjointProductMatrix(lhs.nestedExpression(),rhs).scaleAndAddTo(dst, alpha); + } +}; + +template +struct product_evaluator, ProductTag, SelfAdjointShape, DenseShape, typename Lhs::Scalar, typename Rhs::Scalar> + : public evaluator::PlainObject>::type +{ + typedef Product XprType; + typedef typename XprType::PlainObject PlainObject; + typedef typename evaluator::type Base; + + product_evaluator(const XprType& xpr) + : m_result(xpr.rows(), xpr.cols()) + { + ::new (static_cast(this)) Base(m_result); + generic_product_impl::evalTo(m_result, xpr.lhs(), xpr.rhs()); + } + +protected: + PlainObject m_result; +}; + + +template +struct generic_product_impl +: generic_product_impl_base > +{ + typedef typename Product::Scalar Scalar; + + template + static void scaleAndAddTo(Dest& dst, const Lhs& lhs, const Rhs& rhs, const Scalar& alpha) + {//SelfadjointProductMatrix + // TODO bypass SelfadjointProductMatrix class + SelfadjointProductMatrix(lhs,rhs.nestedExpression()).scaleAndAddTo(dst, alpha); + } +}; + +template +struct product_evaluator, ProductTag, DenseShape, SelfAdjointShape, typename Lhs::Scalar, typename Rhs::Scalar> + : public evaluator::PlainObject>::type +{ + typedef Product XprType; + typedef typename XprType::PlainObject PlainObject; + typedef typename evaluator::type Base; + + product_evaluator(const XprType& xpr) + : m_result(xpr.rows(), xpr.cols()) + { + ::new (static_cast(this)) Base(m_result); + generic_product_impl::evalTo(m_result, xpr.lhs(), xpr.rhs()); + } + +protected: + PlainObject m_result; +}; + } // end namespace internal diff --git a/Eigen/src/Core/SelfAdjointView.h b/Eigen/src/Core/SelfAdjointView.h index 8231e3f5c..079b987f8 100644 --- a/Eigen/src/Core/SelfAdjointView.h +++ b/Eigen/src/Core/SelfAdjointView.h @@ -50,11 +50,12 @@ template class SelfAdjointView - : public TriangularBase > +template class SelfAdjointView + : public TriangularBase > { public: + typedef _MatrixType MatrixType; typedef TriangularBase Base; typedef typename internal::traits::MatrixTypeNested MatrixTypeNested; typedef typename internal::traits::MatrixTypeNestedCleaned MatrixTypeNestedCleaned; @@ -65,7 +66,8 @@ template class SelfAdjointView typedef typename MatrixType::Index Index; enum { - Mode = internal::traits::Mode + Mode = internal::traits::Mode, + Flags = internal::traits::Flags }; typedef typename MatrixType::PlainObject PlainObject; @@ -111,6 +113,28 @@ template class SelfAdjointView EIGEN_DEVICE_FUNC MatrixTypeNestedCleaned& nestedExpression() { return *const_cast(&m_matrix); } +#ifdef EIGEN_TEST_EVALUATORS + + /** Efficient triangular matrix times vector/matrix product */ + template + EIGEN_DEVICE_FUNC + const Product + operator*(const MatrixBase& rhs) const + { + return Product(*this, rhs.derived()); + } + + /** Efficient vector/matrix times triangular matrix product */ + template friend + EIGEN_DEVICE_FUNC + const Product + operator*(const MatrixBase& lhs, const SelfAdjointView& rhs) + { + return Product(lhs.derived(),rhs); + } + +#else // EIGEN_TEST_EVALUATORS + /** Efficient self-adjoint matrix times vector/matrix product */ template EIGEN_DEVICE_FUNC @@ -132,6 +156,7 @@ template class SelfAdjointView (lhs.derived(),rhs.m_matrix); } +#endif /** Perform a symmetric rank 2 update of the selfadjoint matrix \c *this: * \f$ this = this + \alpha u v^* + conj(\alpha) v u^* \f$ @@ -311,6 +336,18 @@ struct triangular_assignment_selector +// in the future selfadjoint-ness should be defined by the expression traits +// such that Transpose > is valid. (currently TriangularBase::transpose() is overloaded to make it work) +template +struct evaluator_traits > +{ + typedef typename storage_kind_to_evaluator_kind::Kind Kind; + typedef SelfAdjointShape Shape; + + static const int AssumeAliasing = 0; +}; + } // end namespace internal /*************************************************************************** diff --git a/test/evaluators.cpp b/test/evaluators.cpp index d4b737348..69a45661f 100644 --- a/test/evaluators.cpp +++ b/test/evaluators.cpp @@ -455,6 +455,7 @@ void test_evaluators() VERIFY_IS_APPROX_EVALUATOR2(B, prod(A.triangularView(),A), MatrixXd(A.triangularView()*A)); - B.col(0).noalias() = prod( (2.1 * A.adjoint()).triangularView() , (A.row(0)).adjoint() ); + VERIFY_IS_APPROX_EVALUATOR2(B, prod(A.selfadjointView(),A), MatrixXd(A.selfadjointView()*A)); + } }