Add a plain_object_eval<> helper returning a plain object type based on evaluator's Flags,

and base nested_eval on it.
This commit is contained in:
Gael Guennebaud 2015-10-14 10:12:58 +02:00
parent b4c79ee1d3
commit 2598f3987e
3 changed files with 30 additions and 16 deletions

View File

@ -233,33 +233,33 @@ template<typename XprType> struct size_of_xpr_at_compile_time
*/ */
template<typename T, typename StorageKind = typename traits<T>::StorageKind> struct plain_matrix_type; template<typename T, typename StorageKind = typename traits<T>::StorageKind> struct plain_matrix_type;
template<typename T, typename BaseClassType> struct plain_matrix_type_dense; template<typename T, typename BaseClassType, int Flags> struct plain_matrix_type_dense;
template<typename T> struct plain_matrix_type<T,Dense> template<typename T> struct plain_matrix_type<T,Dense>
{ {
typedef typename plain_matrix_type_dense<T,typename traits<T>::XprKind>::type type; typedef typename plain_matrix_type_dense<T,typename traits<T>::XprKind, traits<T>::Flags>::type type;
}; };
template<typename T> struct plain_matrix_type<T,DiagonalShape> template<typename T> struct plain_matrix_type<T,DiagonalShape>
{ {
typedef typename T::PlainObject type; typedef typename T::PlainObject type;
}; };
template<typename T> struct plain_matrix_type_dense<T,MatrixXpr> template<typename T, int Flags> struct plain_matrix_type_dense<T,MatrixXpr,Flags>
{ {
typedef Matrix<typename traits<T>::Scalar, typedef Matrix<typename traits<T>::Scalar,
traits<T>::RowsAtCompileTime, traits<T>::RowsAtCompileTime,
traits<T>::ColsAtCompileTime, traits<T>::ColsAtCompileTime,
AutoAlign | (traits<T>::Flags&RowMajorBit ? RowMajor : ColMajor), AutoAlign | (Flags&RowMajorBit ? RowMajor : ColMajor),
traits<T>::MaxRowsAtCompileTime, traits<T>::MaxRowsAtCompileTime,
traits<T>::MaxColsAtCompileTime traits<T>::MaxColsAtCompileTime
> type; > type;
}; };
template<typename T> struct plain_matrix_type_dense<T,ArrayXpr> template<typename T, int Flags> struct plain_matrix_type_dense<T,ArrayXpr,Flags>
{ {
typedef Array<typename traits<T>::Scalar, typedef Array<typename traits<T>::Scalar,
traits<T>::RowsAtCompileTime, traits<T>::RowsAtCompileTime,
traits<T>::ColsAtCompileTime, traits<T>::ColsAtCompileTime,
AutoAlign | (traits<T>::Flags&RowMajorBit ? RowMajor : ColMajor), AutoAlign | (Flags&RowMajorBit ? RowMajor : ColMajor),
traits<T>::MaxRowsAtCompileTime, traits<T>::MaxRowsAtCompileTime,
traits<T>::MaxColsAtCompileTime traits<T>::MaxColsAtCompileTime
> type; > type;
@ -303,6 +303,15 @@ struct eval<Array<_Scalar, _Rows, _Cols, _Options, _MaxRows, _MaxCols>, Dense>
}; };
/* similar to plain_matrix_type, but using the evaluator's Flags */
template<typename T, typename StorageKind = typename traits<T>::StorageKind> struct plain_object_eval;
template<typename T>
struct plain_object_eval<T,Dense>
{
typedef typename plain_matrix_type_dense<T,typename traits<T>::XprKind, evaluator<T>::Flags>::type type;
};
/* plain_matrix_type_column_major : same as plain_matrix_type but guaranteed to be column-major /* plain_matrix_type_column_major : same as plain_matrix_type but guaranteed to be column-major
*/ */
@ -385,7 +394,7 @@ struct transfer_constness
* \param n the number of coefficient accesses in the nested expression for each coefficient access in the bigger expression. * \param n the number of coefficient accesses in the nested expression for each coefficient access in the bigger expression.
* \param PlainObject the type of the temporary if needed. * \param PlainObject the type of the temporary if needed.
*/ */
template<typename T, int n, typename PlainObject = typename eval<T>::type> struct nested_eval template<typename T, int n, typename PlainObject = typename plain_object_eval<T>::type> struct nested_eval
{ {
enum { enum {
// For the purpose of this test, to keep it reasonably simple, we arbitrarily choose a value of Dynamic values. // For the purpose of this test, to keep it reasonably simple, we arbitrarily choose a value of Dynamic values.

View File

@ -138,7 +138,7 @@ struct conservative_sparse_sparse_product_selector<Lhs,Rhs,ResultType,ColMajor,C
{ {
typedef SparseMatrix<typename ResultType::Scalar,RowMajor,typename ResultType::StorageIndex> RowMajorMatrix; typedef SparseMatrix<typename ResultType::Scalar,RowMajor,typename ResultType::StorageIndex> RowMajorMatrix;
typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename ResultType::StorageIndex> ColMajorMatrixAux; typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename ResultType::StorageIndex> ColMajorMatrixAux;
typedef typename sparse_eval<ColMajorMatrixAux,ResultType::RowsAtCompileTime,ResultType::ColsAtCompileTime>::type ColMajorMatrix; typedef typename sparse_eval<ColMajorMatrixAux,ResultType::RowsAtCompileTime,ResultType::ColsAtCompileTime,ColMajorMatrixAux::Flags>::type ColMajorMatrix;
// If the result is tall and thin (in the extreme case a column vector) // If the result is tall and thin (in the extreme case a column vector)
// then it is faster to sort the coefficients inplace instead of transposing twice. // then it is faster to sort the coefficients inplace instead of transposing twice.

View File

@ -74,20 +74,20 @@ template<typename MatrixType,int UpLo> class SparseSymmetricPermutationProduct;
namespace internal { namespace internal {
template<typename T,int Rows,int Cols> struct sparse_eval; template<typename T,int Rows,int Cols,int Flags> struct sparse_eval;
template<typename T> struct eval<T,Sparse> template<typename T> struct eval<T,Sparse>
: public sparse_eval<T, traits<T>::RowsAtCompileTime,traits<T>::ColsAtCompileTime> : sparse_eval<T, traits<T>::RowsAtCompileTime,traits<T>::ColsAtCompileTime,traits<T>::Flags>
{}; {};
template<typename T,int Cols> struct sparse_eval<T,1,Cols> { template<typename T,int Cols,int Flags> struct sparse_eval<T,1,Cols,Flags> {
typedef typename traits<T>::Scalar _Scalar; typedef typename traits<T>::Scalar _Scalar;
typedef typename traits<T>::StorageIndex _StorageIndex; typedef typename traits<T>::StorageIndex _StorageIndex;
public: public:
typedef SparseVector<_Scalar, RowMajor, _StorageIndex> type; typedef SparseVector<_Scalar, RowMajor, _StorageIndex> type;
}; };
template<typename T,int Rows> struct sparse_eval<T,Rows,1> { template<typename T,int Rows,int Flags> struct sparse_eval<T,Rows,1,Flags> {
typedef typename traits<T>::Scalar _Scalar; typedef typename traits<T>::Scalar _Scalar;
typedef typename traits<T>::StorageIndex _StorageIndex; typedef typename traits<T>::StorageIndex _StorageIndex;
public: public:
@ -95,15 +95,15 @@ template<typename T,int Rows> struct sparse_eval<T,Rows,1> {
}; };
// TODO this seems almost identical to plain_matrix_type<T, Sparse> // TODO this seems almost identical to plain_matrix_type<T, Sparse>
template<typename T,int Rows,int Cols> struct sparse_eval { template<typename T,int Rows,int Cols,int Flags> struct sparse_eval {
typedef typename traits<T>::Scalar _Scalar; typedef typename traits<T>::Scalar _Scalar;
typedef typename traits<T>::StorageIndex _StorageIndex; typedef typename traits<T>::StorageIndex _StorageIndex;
enum { _Options = ((traits<T>::Flags&RowMajorBit)==RowMajorBit) ? RowMajor : ColMajor }; enum { _Options = ((Flags&RowMajorBit)==RowMajorBit) ? RowMajor : ColMajor };
public: public:
typedef SparseMatrix<_Scalar, _Options, _StorageIndex> type; typedef SparseMatrix<_Scalar, _Options, _StorageIndex> type;
}; };
template<typename T> struct sparse_eval<T,1,1> { template<typename T,int Flags> struct sparse_eval<T,1,1,Flags> {
typedef typename traits<T>::Scalar _Scalar; typedef typename traits<T>::Scalar _Scalar;
public: public:
typedef Matrix<_Scalar, 1, 1> type; typedef Matrix<_Scalar, 1, 1> type;
@ -118,10 +118,15 @@ template<typename T> struct plain_matrix_type<T,Sparse>
typedef SparseMatrix<_Scalar, _Options, _StorageIndex> type; typedef SparseMatrix<_Scalar, _Options, _StorageIndex> type;
}; };
template<typename T>
struct plain_object_eval<T,Sparse>
: sparse_eval<T, traits<T>::RowsAtCompileTime,traits<T>::ColsAtCompileTime, evaluator<T>::Flags>
{};
template<typename Decomposition, typename RhsType> template<typename Decomposition, typename RhsType>
struct solve_traits<Decomposition,RhsType,Sparse> struct solve_traits<Decomposition,RhsType,Sparse>
{ {
typedef typename sparse_eval<RhsType, RhsType::RowsAtCompileTime, RhsType::ColsAtCompileTime>::type PlainObject; typedef typename sparse_eval<RhsType, RhsType::RowsAtCompileTime, RhsType::ColsAtCompileTime,traits<RhsType>::Flags>::type PlainObject;
}; };
template<typename Derived> template<typename Derived>