Add a plain_object_eval<> helper returning a plain object type based on evaluator's Flags,

and base nested_eval on it.
This commit is contained in:
Gael Guennebaud 2015-10-14 10:12:58 +02:00
parent b4c79ee1d3
commit 2598f3987e
3 changed files with 30 additions and 16 deletions

View File

@ -233,33 +233,33 @@ template<typename XprType> struct size_of_xpr_at_compile_time
*/
template<typename T, typename StorageKind = typename traits<T>::StorageKind> struct plain_matrix_type;
template<typename T, typename BaseClassType> struct plain_matrix_type_dense;
template<typename T, typename BaseClassType, int Flags> struct plain_matrix_type_dense;
template<typename T> struct plain_matrix_type<T,Dense>
{
typedef typename plain_matrix_type_dense<T,typename traits<T>::XprKind>::type type;
typedef typename plain_matrix_type_dense<T,typename traits<T>::XprKind, traits<T>::Flags>::type type;
};
template<typename T> struct plain_matrix_type<T,DiagonalShape>
{
typedef typename T::PlainObject type;
};
template<typename T> struct plain_matrix_type_dense<T,MatrixXpr>
template<typename T, int Flags> struct plain_matrix_type_dense<T,MatrixXpr,Flags>
{
typedef Matrix<typename traits<T>::Scalar,
traits<T>::RowsAtCompileTime,
traits<T>::ColsAtCompileTime,
AutoAlign | (traits<T>::Flags&RowMajorBit ? RowMajor : ColMajor),
AutoAlign | (Flags&RowMajorBit ? RowMajor : ColMajor),
traits<T>::MaxRowsAtCompileTime,
traits<T>::MaxColsAtCompileTime
> type;
};
template<typename T> struct plain_matrix_type_dense<T,ArrayXpr>
template<typename T, int Flags> struct plain_matrix_type_dense<T,ArrayXpr,Flags>
{
typedef Array<typename traits<T>::Scalar,
traits<T>::RowsAtCompileTime,
traits<T>::ColsAtCompileTime,
AutoAlign | (traits<T>::Flags&RowMajorBit ? RowMajor : ColMajor),
AutoAlign | (Flags&RowMajorBit ? RowMajor : ColMajor),
traits<T>::MaxRowsAtCompileTime,
traits<T>::MaxColsAtCompileTime
> type;
@ -303,6 +303,15 @@ struct eval<Array<_Scalar, _Rows, _Cols, _Options, _MaxRows, _MaxCols>, Dense>
};
/* similar to plain_matrix_type, but using the evaluator's Flags */
template<typename T, typename StorageKind = typename traits<T>::StorageKind> struct plain_object_eval;
template<typename T>
struct plain_object_eval<T,Dense>
{
typedef typename plain_matrix_type_dense<T,typename traits<T>::XprKind, evaluator<T>::Flags>::type type;
};
/* plain_matrix_type_column_major : same as plain_matrix_type but guaranteed to be column-major
*/
@ -385,7 +394,7 @@ struct transfer_constness
* \param n the number of coefficient accesses in the nested expression for each coefficient access in the bigger expression.
* \param PlainObject the type of the temporary if needed.
*/
template<typename T, int n, typename PlainObject = typename eval<T>::type> struct nested_eval
template<typename T, int n, typename PlainObject = typename plain_object_eval<T>::type> struct nested_eval
{
enum {
// For the purpose of this test, to keep it reasonably simple, we arbitrarily choose a value of Dynamic values.

View File

@ -138,7 +138,7 @@ struct conservative_sparse_sparse_product_selector<Lhs,Rhs,ResultType,ColMajor,C
{
typedef SparseMatrix<typename ResultType::Scalar,RowMajor,typename ResultType::StorageIndex> RowMajorMatrix;
typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename ResultType::StorageIndex> ColMajorMatrixAux;
typedef typename sparse_eval<ColMajorMatrixAux,ResultType::RowsAtCompileTime,ResultType::ColsAtCompileTime>::type ColMajorMatrix;
typedef typename sparse_eval<ColMajorMatrixAux,ResultType::RowsAtCompileTime,ResultType::ColsAtCompileTime,ColMajorMatrixAux::Flags>::type ColMajorMatrix;
// If the result is tall and thin (in the extreme case a column vector)
// then it is faster to sort the coefficients inplace instead of transposing twice.

View File

@ -74,20 +74,20 @@ template<typename MatrixType,int UpLo> class SparseSymmetricPermutationProduct;
namespace internal {
template<typename T,int Rows,int Cols> struct sparse_eval;
template<typename T,int Rows,int Cols,int Flags> struct sparse_eval;
template<typename T> struct eval<T,Sparse>
: public sparse_eval<T, traits<T>::RowsAtCompileTime,traits<T>::ColsAtCompileTime>
: sparse_eval<T, traits<T>::RowsAtCompileTime,traits<T>::ColsAtCompileTime,traits<T>::Flags>
{};
template<typename T,int Cols> struct sparse_eval<T,1,Cols> {
template<typename T,int Cols,int Flags> struct sparse_eval<T,1,Cols,Flags> {
typedef typename traits<T>::Scalar _Scalar;
typedef typename traits<T>::StorageIndex _StorageIndex;
public:
typedef SparseVector<_Scalar, RowMajor, _StorageIndex> type;
};
template<typename T,int Rows> struct sparse_eval<T,Rows,1> {
template<typename T,int Rows,int Flags> struct sparse_eval<T,Rows,1,Flags> {
typedef typename traits<T>::Scalar _Scalar;
typedef typename traits<T>::StorageIndex _StorageIndex;
public:
@ -95,15 +95,15 @@ template<typename T,int Rows> struct sparse_eval<T,Rows,1> {
};
// TODO this seems almost identical to plain_matrix_type<T, Sparse>
template<typename T,int Rows,int Cols> struct sparse_eval {
template<typename T,int Rows,int Cols,int Flags> struct sparse_eval {
typedef typename traits<T>::Scalar _Scalar;
typedef typename traits<T>::StorageIndex _StorageIndex;
enum { _Options = ((traits<T>::Flags&RowMajorBit)==RowMajorBit) ? RowMajor : ColMajor };
enum { _Options = ((Flags&RowMajorBit)==RowMajorBit) ? RowMajor : ColMajor };
public:
typedef SparseMatrix<_Scalar, _Options, _StorageIndex> type;
};
template<typename T> struct sparse_eval<T,1,1> {
template<typename T,int Flags> struct sparse_eval<T,1,1,Flags> {
typedef typename traits<T>::Scalar _Scalar;
public:
typedef Matrix<_Scalar, 1, 1> type;
@ -118,10 +118,15 @@ template<typename T> struct plain_matrix_type<T,Sparse>
typedef SparseMatrix<_Scalar, _Options, _StorageIndex> type;
};
template<typename T>
struct plain_object_eval<T,Sparse>
: sparse_eval<T, traits<T>::RowsAtCompileTime,traits<T>::ColsAtCompileTime, evaluator<T>::Flags>
{};
template<typename Decomposition, typename RhsType>
struct solve_traits<Decomposition,RhsType,Sparse>
{
typedef typename sparse_eval<RhsType, RhsType::RowsAtCompileTime, RhsType::ColsAtCompileTime>::type PlainObject;
typedef typename sparse_eval<RhsType, RhsType::RowsAtCompileTime, RhsType::ColsAtCompileTime,traits<RhsType>::Flags>::type PlainObject;
};
template<typename Derived>