Make selfqdjoint products use evaluators

This commit is contained in:
Gael Guennebaud 2013-12-13 18:09:07 +01:00
parent e94fe4cc3e
commit 27c068e9d6
3 changed files with 113 additions and 4 deletions

View File

@ -628,6 +628,77 @@ protected:
/***************************************************************************
* SelfAdjoint products
***************************************************************************/
template<typename Lhs, typename Rhs, int ProductTag>
struct generic_product_impl<Lhs,Rhs,SelfAdjointShape,DenseShape,ProductTag>
: generic_product_impl_base<Lhs,Rhs,generic_product_impl<Lhs,Rhs,SelfAdjointShape,DenseShape,ProductTag> >
{
typedef typename Product<Lhs,Rhs>::Scalar Scalar;
template<typename Dest>
static void scaleAndAddTo(Dest& dst, const Lhs& lhs, const Rhs& rhs, const Scalar& alpha)
{// SelfadjointProductMatrix<MatrixType,Mode,false,OtherDerived,0,OtherDerived::IsVectorAtCompileTime>
// TODO bypass SelfadjointProductMatrix class
SelfadjointProductMatrix<typename Lhs::MatrixType,Lhs::Mode,false,Rhs,0,Rhs::IsVectorAtCompileTime>(lhs.nestedExpression(),rhs).scaleAndAddTo(dst, alpha);
}
};
template<typename Lhs, typename Rhs, int ProductTag>
struct product_evaluator<Product<Lhs, Rhs, DefaultProduct>, ProductTag, SelfAdjointShape, DenseShape, typename Lhs::Scalar, typename Rhs::Scalar>
: public evaluator<typename Product<Lhs, Rhs, DefaultProduct>::PlainObject>::type
{
typedef Product<Lhs, Rhs, DefaultProduct> XprType;
typedef typename XprType::PlainObject PlainObject;
typedef typename evaluator<PlainObject>::type Base;
product_evaluator(const XprType& xpr)
: m_result(xpr.rows(), xpr.cols())
{
::new (static_cast<Base*>(this)) Base(m_result);
generic_product_impl<Lhs, Rhs, SelfAdjointShape, DenseShape, ProductTag>::evalTo(m_result, xpr.lhs(), xpr.rhs());
}
protected:
PlainObject m_result;
};
template<typename Lhs, typename Rhs, int ProductTag>
struct generic_product_impl<Lhs,Rhs,DenseShape,SelfAdjointShape,ProductTag>
: generic_product_impl_base<Lhs,Rhs,generic_product_impl<Lhs,Rhs,DenseShape,SelfAdjointShape,ProductTag> >
{
typedef typename Product<Lhs,Rhs>::Scalar Scalar;
template<typename Dest>
static void scaleAndAddTo(Dest& dst, const Lhs& lhs, const Rhs& rhs, const Scalar& alpha)
{//SelfadjointProductMatrix<OtherDerived,0,OtherDerived::IsVectorAtCompileTime,MatrixType,Mode,false>
// TODO bypass SelfadjointProductMatrix class
SelfadjointProductMatrix<Lhs,0,Lhs::IsVectorAtCompileTime,typename Rhs::MatrixType,Rhs::Mode,false>(lhs,rhs.nestedExpression()).scaleAndAddTo(dst, alpha);
}
};
template<typename Lhs, typename Rhs, int ProductTag>
struct product_evaluator<Product<Lhs, Rhs, DefaultProduct>, ProductTag, DenseShape, SelfAdjointShape, typename Lhs::Scalar, typename Rhs::Scalar>
: public evaluator<typename Product<Lhs, Rhs, DefaultProduct>::PlainObject>::type
{
typedef Product<Lhs, Rhs, DefaultProduct> XprType;
typedef typename XprType::PlainObject PlainObject;
typedef typename evaluator<PlainObject>::type Base;
product_evaluator(const XprType& xpr)
: m_result(xpr.rows(), xpr.cols())
{
::new (static_cast<Base*>(this)) Base(m_result);
generic_product_impl<Lhs, Rhs, DenseShape, SelfAdjointShape, ProductTag>::evalTo(m_result, xpr.lhs(), xpr.rhs());
}
protected:
PlainObject m_result;
};
} // end namespace internal

View File

@ -50,11 +50,12 @@ template <typename Lhs, int LhsMode, bool LhsIsVector,
struct SelfadjointProductMatrix;
// FIXME could also be called SelfAdjointWrapper to be consistent with DiagonalWrapper ??
template<typename MatrixType, unsigned int UpLo> class SelfAdjointView
: public TriangularBase<SelfAdjointView<MatrixType, UpLo> >
template<typename _MatrixType, unsigned int UpLo> class SelfAdjointView
: public TriangularBase<SelfAdjointView<_MatrixType, UpLo> >
{
public:
typedef _MatrixType MatrixType;
typedef TriangularBase<SelfAdjointView> Base;
typedef typename internal::traits<SelfAdjointView>::MatrixTypeNested MatrixTypeNested;
typedef typename internal::traits<SelfAdjointView>::MatrixTypeNestedCleaned MatrixTypeNestedCleaned;
@ -65,7 +66,8 @@ template<typename MatrixType, unsigned int UpLo> class SelfAdjointView
typedef typename MatrixType::Index Index;
enum {
Mode = internal::traits<SelfAdjointView>::Mode
Mode = internal::traits<SelfAdjointView>::Mode,
Flags = internal::traits<SelfAdjointView>::Flags
};
typedef typename MatrixType::PlainObject PlainObject;
@ -111,6 +113,28 @@ template<typename MatrixType, unsigned int UpLo> class SelfAdjointView
EIGEN_DEVICE_FUNC
MatrixTypeNestedCleaned& nestedExpression() { return *const_cast<MatrixTypeNestedCleaned*>(&m_matrix); }
#ifdef EIGEN_TEST_EVALUATORS
/** Efficient triangular matrix times vector/matrix product */
template<typename OtherDerived>
EIGEN_DEVICE_FUNC
const Product<SelfAdjointView,OtherDerived>
operator*(const MatrixBase<OtherDerived>& rhs) const
{
return Product<SelfAdjointView,OtherDerived>(*this, rhs.derived());
}
/** Efficient vector/matrix times triangular matrix product */
template<typename OtherDerived> friend
EIGEN_DEVICE_FUNC
const Product<OtherDerived,SelfAdjointView>
operator*(const MatrixBase<OtherDerived>& lhs, const SelfAdjointView& rhs)
{
return Product<OtherDerived,SelfAdjointView>(lhs.derived(),rhs);
}
#else // EIGEN_TEST_EVALUATORS
/** Efficient self-adjoint matrix times vector/matrix product */
template<typename OtherDerived>
EIGEN_DEVICE_FUNC
@ -132,6 +156,7 @@ template<typename MatrixType, unsigned int UpLo> class SelfAdjointView
<OtherDerived,0,OtherDerived::IsVectorAtCompileTime,MatrixType,Mode,false>
(lhs.derived(),rhs.m_matrix);
}
#endif
/** Perform a symmetric rank 2 update of the selfadjoint matrix \c *this:
* \f$ this = this + \alpha u v^* + conj(\alpha) v u^* \f$
@ -311,6 +336,18 @@ struct triangular_assignment_selector<Derived1, Derived2, SelfAdjoint|Lower, Dyn
}
};
// TODO currently a selfadjoint expression has the form SelfAdjointView<.,.>
// in the future selfadjoint-ness should be defined by the expression traits
// such that Transpose<SelfAdjointView<.,.> > is valid. (currently TriangularBase::transpose() is overloaded to make it work)
template<typename MatrixType, unsigned int Mode>
struct evaluator_traits<SelfAdjointView<MatrixType,Mode> >
{
typedef typename storage_kind_to_evaluator_kind<typename MatrixType::StorageKind>::Kind Kind;
typedef SelfAdjointShape Shape;
static const int AssumeAliasing = 0;
};
} // end namespace internal
/***************************************************************************

View File

@ -455,6 +455,7 @@ void test_evaluators()
VERIFY_IS_APPROX_EVALUATOR2(B, prod(A.triangularView<Upper>(),A), MatrixXd(A.triangularView<Upper>()*A));
B.col(0).noalias() = prod( (2.1 * A.adjoint()).triangularView<UnitUpper>() , (A.row(0)).adjoint() );
VERIFY_IS_APPROX_EVALUATOR2(B, prod(A.selfadjointView<Upper>(),A), MatrixXd(A.selfadjointView<Upper>()*A));
}
}