Fallback to Block<> when possible (Index, all, seq with > increment).

This is important to take advantage of the optimized implementations (evaluator, products, etc.),
and to support sparse matrices.
This commit is contained in:
Gael Guennebaud 2017-01-10 14:25:30 +01:00
parent a98c7efb16
commit 87963f441c
4 changed files with 53 additions and 3 deletions

View File

@ -272,6 +272,7 @@ public:
}; };
Index size() const { return m_size; } Index size() const { return m_size; }
Index first() const { return m_first; }
Index operator[](Index i) const { return m_first + i * m_incr; } Index operator[](Index i) const { return m_first + i * m_incr; }
const FirstType& firstObject() const { return m_first; } const FirstType& firstObject() const { return m_first; }
@ -414,6 +415,9 @@ Index size(const T& x) { return x.size(); }
template<typename T,std::size_t N> template<typename T,std::size_t N>
Index size(const T (&) [N]) { return N; } Index size(const T (&) [N]) { return N; }
template<typename T>
Index first(const T& x) { return x.first(); }
template<typename T, int XprSize, typename EnableIf = void> struct get_compile_time_size { template<typename T, int XprSize, typename EnableIf = void> struct get_compile_time_size {
enum { value = Dynamic }; enum { value = Dynamic };
}; };
@ -458,6 +462,7 @@ struct IntAsArray {
IntAsArray(Index val) : m_value(val) {} IntAsArray(Index val) : m_value(val) {}
Index operator[](Index) const { return m_value; } Index operator[](Index) const { return m_value; }
Index size() const { return 1; } Index size() const { return 1; }
Index first() const { return m_value; }
Index m_value; Index m_value;
}; };
@ -512,6 +517,7 @@ struct AllRange {
AllRange(Index size) : m_size(size) {} AllRange(Index size) : m_size(size) {}
Index operator[](Index i) const { return i; } Index operator[](Index i) const { return i; }
Index size() const { return m_size; } Index size() const { return m_size; }
Index first() const { return 0; }
Index m_size; Index m_size;
}; };

View File

@ -557,15 +557,36 @@ template<typename Derived> class DenseBase
} }
EIGEN_DEVICE_FUNC void reverseInPlace(); EIGEN_DEVICE_FUNC void reverseInPlace();
template<typename RowIndices, typename ColIndices>
struct IndexedViewType {
typedef IndexedView<const Derived,typename internal::MakeIndexing<RowIndices>::type,typename internal::MakeIndexing<ColIndices>::type> type;
};
template<typename RowIndices, typename ColIndices> template<typename RowIndices, typename ColIndices>
typename internal::enable_if< typename internal::enable_if<
!(internal::is_integral<RowIndices>::value && internal::is_integral<ColIndices>::value), ! (internal::traits<typename IndexedViewType<RowIndices,ColIndices>::type>::IsBlockAlike
IndexedView<const Derived,typename internal::MakeIndexing<RowIndices>::type,typename internal::MakeIndexing<ColIndices>::type> >::type || (internal::is_integral<RowIndices>::value && internal::is_integral<ColIndices>::value)),
typename IndexedViewType<RowIndices,ColIndices>::type >::type
operator()(const RowIndices& rowIndices, const ColIndices& colIndices) const { operator()(const RowIndices& rowIndices, const ColIndices& colIndices) const {
return IndexedView<const Derived,typename internal::MakeIndexing<RowIndices>::type,typename internal::MakeIndexing<ColIndices>::type>( return typename IndexedViewType<RowIndices,ColIndices>::type(
derived(), internal::make_indexing(rowIndices,derived().rows()), internal::make_indexing(colIndices,derived().cols())); derived(), internal::make_indexing(rowIndices,derived().rows()), internal::make_indexing(colIndices,derived().cols()));
} }
template<typename RowIndices, typename ColIndices>
typename internal::enable_if<
internal::traits<typename IndexedViewType<RowIndices,ColIndices>::type>::IsBlockAlike,
typename internal::traits<typename IndexedViewType<RowIndices,ColIndices>::type>::BlockType>::type
operator()(const RowIndices& rowIndices, const ColIndices& colIndices) const {
typedef typename internal::traits<typename IndexedViewType<RowIndices,ColIndices>::type>::BlockType BlockType;
typename internal::MakeIndexing<RowIndices>::type actualRowIndices = internal::make_indexing(rowIndices,derived().rows());
typename internal::MakeIndexing<ColIndices>::type actualColIndices = internal::make_indexing(colIndices,derived().cols());
return BlockType(derived(),
internal::first(actualRowIndices),
internal::first(actualColIndices),
internal::size(actualRowIndices),
internal::size(actualColIndices));
}
template<typename RowIndicesT, std::size_t RowIndicesN, typename ColIndices> template<typename RowIndicesT, std::size_t RowIndicesN, typename ColIndices>
IndexedView<const Derived,const RowIndicesT (&)[RowIndicesN],typename internal::MakeIndexing<ColIndices>::type> IndexedView<const Derived,const RowIndicesT (&)[RowIndicesN],typename internal::MakeIndexing<ColIndices>::type>
operator()(const RowIndicesT (&rowIndices)[RowIndicesN], const ColIndices& colIndices) const { operator()(const RowIndicesT (&rowIndices)[RowIndicesN], const ColIndices& colIndices) const {

View File

@ -38,6 +38,9 @@ struct traits<IndexedView<XprType, RowIndices, ColIndices> >
XprInnerStride = HasSameStorageOrderAsXprType ? int(inner_stride_at_compile_time<XprType>::ret) : int(outer_stride_at_compile_time<XprType>::ret), XprInnerStride = HasSameStorageOrderAsXprType ? int(inner_stride_at_compile_time<XprType>::ret) : int(outer_stride_at_compile_time<XprType>::ret),
XprOuterstride = HasSameStorageOrderAsXprType ? int(outer_stride_at_compile_time<XprType>::ret) : int(inner_stride_at_compile_time<XprType>::ret), XprOuterstride = HasSameStorageOrderAsXprType ? int(outer_stride_at_compile_time<XprType>::ret) : int(inner_stride_at_compile_time<XprType>::ret),
IsBlockAlike = InnerIncr==1 && OuterIncr==1,
IsInnerPannel = HasSameStorageOrderAsXprType && is_same<AllRange,typename conditional<XprTypeIsRowMajor,ColIndices,RowIndices>::type>::value,
InnerStrideAtCompileTime = InnerIncr<0 || InnerIncr==DynamicIndex || XprInnerStride==Dynamic ? Dynamic : XprInnerStride * InnerIncr, InnerStrideAtCompileTime = InnerIncr<0 || InnerIncr==DynamicIndex || XprInnerStride==Dynamic ? Dynamic : XprInnerStride * InnerIncr,
OuterStrideAtCompileTime = OuterIncr<0 || OuterIncr==DynamicIndex || XprOuterstride==Dynamic ? Dynamic : XprOuterstride * OuterIncr, OuterStrideAtCompileTime = OuterIncr<0 || OuterIncr==DynamicIndex || XprOuterstride==Dynamic ? Dynamic : XprOuterstride * OuterIncr,
@ -48,8 +51,11 @@ struct traits<IndexedView<XprType, RowIndices, ColIndices> >
FlagsLvalueBit = is_lvalue<XprType>::value ? LvalueBit : 0, FlagsLvalueBit = is_lvalue<XprType>::value ? LvalueBit : 0,
Flags = (traits<XprType>::Flags & (HereditaryBits | DirectAccessMask)) | FlagsLvalueBit | FlagsRowMajorBit Flags = (traits<XprType>::Flags & (HereditaryBits | DirectAccessMask)) | FlagsLvalueBit | FlagsRowMajorBit
}; };
typedef Block<XprType,RowsAtCompileTime,ColsAtCompileTime,IsInnerPannel> BlockType;
}; };
} }
template<typename XprType, typename RowIndices, typename ColIndices, typename StorageKind> template<typename XprType, typename RowIndices, typename ColIndices, typename StorageKind>

View File

@ -41,6 +41,13 @@ bool match(const T& xpr, std::string ref, std::string str_xpr = "") {
#define MATCH(X,R) match(X, R, #X) #define MATCH(X,R) match(X, R, #X)
template<typename T1,typename T2>
typename internal::enable_if<internal::is_same<T1,T2>::value,bool>::type
is_same_type(const T1& a, const T2& b)
{
return (a == b).all();
}
void check_indexed_view() void check_indexed_view()
{ {
using Eigen::placeholders::last; using Eigen::placeholders::last;
@ -159,6 +166,16 @@ void check_indexed_view()
VERIFY_IS_APPROX( A(seq(1,n-1-2), seq(n-1-5,7)), A(seq(1,last-2), seq(last-5,7)) ); VERIFY_IS_APPROX( A(seq(1,n-1-2), seq(n-1-5,7)), A(seq(1,last-2), seq(last-5,7)) );
VERIFY_IS_APPROX( A(seq(n-1-5,n-1-2), seq(n-1-5,n-1-2)), A(seq(last-5,last-2), seq(last-5,last-2)) ); VERIFY_IS_APPROX( A(seq(n-1-5,n-1-2), seq(n-1-5,n-1-2)), A(seq(last-5,last-2), seq(last-5,last-2)) );
// Check fall-back to Block
{
const ArrayXXi& cA(A);
VERIFY( is_same_type(cA.col(0), cA(all,0)) );
VERIFY( is_same_type(cA.row(0), cA(0,all)) );
VERIFY( is_same_type(cA.block(0,0,2,2), cA(seqN(0,2),seq(0,1))) );
VERIFY( is_same_type(cA.middleRows(2,4), cA(seqN(2,4),all)) );
VERIFY( is_same_type(cA.middleCols(2,4), cA(all,seqN(2,4))) );
}
#if EIGEN_HAS_CXX11 #if EIGEN_HAS_CXX11
VERIFY( (A(all, std::array<int,4>{{1,3,2,4}})).ColsAtCompileTime == 4); VERIFY( (A(all, std::array<int,4>{{1,3,2,4}})).ColsAtCompileTime == 4);