mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-03-07 18:27:40 +08:00
Add pointer-based iterator for direct-access expressions
This commit is contained in:
parent
8c38528168
commit
0481900e25
@ -572,12 +572,21 @@ template<typename Derived> class DenseBase
|
||||
}
|
||||
EIGEN_DEVICE_FUNC void reverseInPlace();
|
||||
|
||||
inline DenseStlIterator<Derived> begin();
|
||||
inline DenseStlIterator<const Derived> begin() const;
|
||||
inline DenseStlIterator<const Derived> cbegin() const;
|
||||
inline DenseStlIterator<Derived> end();
|
||||
inline DenseStlIterator<const Derived> end() const;
|
||||
inline DenseStlIterator<const Derived> cend() const;
|
||||
typedef typename internal::conditional< (Flags&DirectAccessBit)==DirectAccessBit,
|
||||
PointerBasedStlIterator<Derived>,
|
||||
DenseStlIterator<Derived>
|
||||
>::type iterator;
|
||||
|
||||
typedef typename internal::conditional< (Flags&DirectAccessBit)==DirectAccessBit,
|
||||
PointerBasedStlIterator<const Derived>,
|
||||
DenseStlIterator<const Derived>
|
||||
>::type const_iterator;
|
||||
inline iterator begin();
|
||||
inline const_iterator begin() const;
|
||||
inline const_iterator cbegin() const;
|
||||
inline iterator end();
|
||||
inline const_iterator end() const;
|
||||
inline const_iterator cend() const;
|
||||
inline SubVectorsProxy<Derived,Vertical> allCols();
|
||||
inline SubVectorsProxy<const Derived,Vertical> allCols() const;
|
||||
inline SubVectorsProxy<Derived,Horizontal> allRows();
|
||||
|
@ -56,6 +56,58 @@ protected:
|
||||
Index m_index;
|
||||
};
|
||||
|
||||
template<typename XprType>
|
||||
class PointerBasedStlIterator
|
||||
{
|
||||
enum { is_lvalue = internal::is_lvalue<XprType>::value };
|
||||
public:
|
||||
typedef Index difference_type;
|
||||
typedef typename XprType::Scalar value_type;
|
||||
typedef std::random_access_iterator_tag iterator_category;
|
||||
typedef typename internal::conditional<bool(is_lvalue), value_type*, const value_type*>::type pointer;
|
||||
typedef typename internal::conditional<bool(is_lvalue), value_type&, const value_type&>::type reference;
|
||||
|
||||
PointerBasedStlIterator() : m_ptr(0) {}
|
||||
PointerBasedStlIterator(XprType& xpr, Index index) : m_incr(xpr.innerStride())
|
||||
{
|
||||
m_ptr = xpr.data() + index * m_incr.value();
|
||||
}
|
||||
|
||||
reference operator*() const { return *m_ptr; }
|
||||
reference operator[](Index i) const { return *(m_ptr+i*m_incr.value()); }
|
||||
pointer operator->() const { return m_ptr; }
|
||||
|
||||
PointerBasedStlIterator& operator++() { m_ptr += m_incr.value(); return *this; }
|
||||
PointerBasedStlIterator& operator--() { m_ptr -= m_incr.value(); return *this; }
|
||||
|
||||
PointerBasedStlIterator operator++(int) { PointerBasedStlIterator prev(*this); operator++(); return prev;}
|
||||
PointerBasedStlIterator operator--(int) { PointerBasedStlIterator prev(*this); operator--(); return prev;}
|
||||
|
||||
friend PointerBasedStlIterator operator+(const PointerBasedStlIterator& a, Index b) { PointerBasedStlIterator ret(a); ret += b; return ret; }
|
||||
friend PointerBasedStlIterator operator-(const PointerBasedStlIterator& a, Index b) { PointerBasedStlIterator ret(a); ret -= b; return ret; }
|
||||
friend PointerBasedStlIterator operator+(Index a, const PointerBasedStlIterator& b) { PointerBasedStlIterator ret(b); ret += a; return ret; }
|
||||
friend PointerBasedStlIterator operator-(Index a, const PointerBasedStlIterator& b) { PointerBasedStlIterator ret(b); ret -= a; return ret; }
|
||||
|
||||
PointerBasedStlIterator& operator+=(Index b) { m_ptr += b*m_incr.value(); return *this; }
|
||||
PointerBasedStlIterator& operator-=(Index b) { m_ptr -= b*m_incr.value(); return *this; }
|
||||
|
||||
difference_type operator-(const PointerBasedStlIterator& other) const {
|
||||
return (m_ptr - other.m_ptr)/m_incr.value();
|
||||
}
|
||||
|
||||
bool operator==(const PointerBasedStlIterator& other) { return m_ptr == other.m_ptr; }
|
||||
bool operator!=(const PointerBasedStlIterator& other) { return m_ptr != other.m_ptr; }
|
||||
bool operator< (const PointerBasedStlIterator& other) { return m_ptr < other.m_ptr; }
|
||||
bool operator<=(const PointerBasedStlIterator& other) { return m_ptr <= other.m_ptr; }
|
||||
bool operator> (const PointerBasedStlIterator& other) { return m_ptr > other.m_ptr; }
|
||||
bool operator>=(const PointerBasedStlIterator& other) { return m_ptr >= other.m_ptr; }
|
||||
|
||||
protected:
|
||||
|
||||
pointer m_ptr;
|
||||
internal::variable_if_dynamic<Index, XprType::InnerStrideAtCompileTime> m_incr;
|
||||
};
|
||||
|
||||
template<typename XprType>
|
||||
class DenseStlIterator : public IndexedBasedStlIteratorBase<XprType, DenseStlIterator<XprType> >
|
||||
{
|
||||
@ -66,19 +118,22 @@ protected:
|
||||
|
||||
enum {
|
||||
has_direct_access = (internal::traits<XprType>::Flags & DirectAccessBit) ? 1 : 0,
|
||||
has_write_access = internal::is_lvalue<XprType>::value
|
||||
is_lvalue = internal::is_lvalue<XprType>::value
|
||||
};
|
||||
|
||||
typedef IndexedBasedStlIteratorBase<XprType,DenseStlIterator> Base;
|
||||
using Base::m_index;
|
||||
using Base::mp_xpr;
|
||||
|
||||
typedef typename internal::conditional<bool(has_direct_access), const value_type&, const value_type>::type read_only_ref_t;
|
||||
// TODO currently const Transpose/Reshape expressions never returns const references,
|
||||
// so lets return by value too.
|
||||
//typedef typename internal::conditional<bool(has_direct_access), const value_type&, const value_type>::type read_only_ref_t;
|
||||
typedef const value_type read_only_ref_t;
|
||||
|
||||
public:
|
||||
|
||||
typedef typename internal::conditional<bool(has_write_access), value_type *, const value_type *>::type pointer;
|
||||
typedef typename internal::conditional<bool(has_write_access), value_type&, read_only_ref_t>::type reference;
|
||||
typedef typename internal::conditional<bool(is_lvalue), value_type *, const value_type *>::type pointer;
|
||||
typedef typename internal::conditional<bool(is_lvalue), value_type&, read_only_ref_t>::type reference;
|
||||
|
||||
DenseStlIterator() : Base() {}
|
||||
DenseStlIterator(XprType& xpr, Index index) : Base(xpr,index) {}
|
||||
@ -94,43 +149,43 @@ void swap(IndexedBasedStlIteratorBase<XprType,Derived>& a, IndexedBasedStlIterat
|
||||
}
|
||||
|
||||
template<typename Derived>
|
||||
inline DenseStlIterator<Derived> DenseBase<Derived>::begin()
|
||||
inline typename DenseBase<Derived>::iterator DenseBase<Derived>::begin()
|
||||
{
|
||||
EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived);
|
||||
return DenseStlIterator<Derived>(derived(), 0);
|
||||
return iterator(derived(), 0);
|
||||
}
|
||||
|
||||
template<typename Derived>
|
||||
inline DenseStlIterator<const Derived> DenseBase<Derived>::begin() const
|
||||
inline typename DenseBase<Derived>::const_iterator DenseBase<Derived>::begin() const
|
||||
{
|
||||
return cbegin();
|
||||
}
|
||||
|
||||
template<typename Derived>
|
||||
inline DenseStlIterator<const Derived> DenseBase<Derived>::cbegin() const
|
||||
inline typename DenseBase<Derived>::const_iterator DenseBase<Derived>::cbegin() const
|
||||
{
|
||||
EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived);
|
||||
return DenseStlIterator<const Derived>(derived(), 0);
|
||||
return const_iterator(derived(), 0);
|
||||
}
|
||||
|
||||
template<typename Derived>
|
||||
inline DenseStlIterator<Derived> DenseBase<Derived>::end()
|
||||
inline typename DenseBase<Derived>::iterator DenseBase<Derived>::end()
|
||||
{
|
||||
EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived);
|
||||
return DenseStlIterator<Derived>(derived(), size());
|
||||
return iterator(derived(), size());
|
||||
}
|
||||
|
||||
template<typename Derived>
|
||||
inline DenseStlIterator<const Derived> DenseBase<Derived>::end() const
|
||||
inline typename DenseBase<Derived>::const_iterator DenseBase<Derived>::end() const
|
||||
{
|
||||
return cend();
|
||||
}
|
||||
|
||||
template<typename Derived>
|
||||
inline DenseStlIterator<const Derived> DenseBase<Derived>::cend() const
|
||||
inline typename DenseBase<Derived>::const_iterator DenseBase<Derived>::cend() const
|
||||
{
|
||||
EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived);
|
||||
return DenseStlIterator<const Derived>(derived(), size());
|
||||
return const_iterator(derived(), size());
|
||||
}
|
||||
|
||||
template<typename XprType, DirectionType Direction>
|
||||
|
@ -134,6 +134,7 @@ template<typename ExpressionType> class MatrixWrapper;
|
||||
template<typename Derived> class SolverBase;
|
||||
template<typename XprType> class InnerIterator;
|
||||
template<typename XprType> class DenseStlIterator;
|
||||
template<typename XprType> class PointerBasedStlIterator;
|
||||
template<typename XprType, DirectionType Direction> class SubVectorsProxy;
|
||||
|
||||
namespace internal {
|
||||
|
@ -7,6 +7,7 @@
|
||||
// Public License v. 2.0. If a copy of the MPL was not distributed
|
||||
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
#include <numeric>
|
||||
#include "main.h"
|
||||
|
||||
template< class Iterator >
|
||||
@ -16,6 +17,12 @@ make_reverse_iterator( Iterator i )
|
||||
return std::reverse_iterator<Iterator>(i);
|
||||
}
|
||||
|
||||
template<typename XprType>
|
||||
bool is_PointerBasedStlIterator(const PointerBasedStlIterator<XprType> &) { return true; }
|
||||
|
||||
template<typename XprType>
|
||||
bool is_DenseStlIterator(const DenseStlIterator<XprType> &) { return true; }
|
||||
|
||||
template<typename Scalar, int Rows, int Cols>
|
||||
void test_range_for_loop(int rows=Rows, int cols=Cols)
|
||||
{
|
||||
@ -26,10 +33,37 @@ void test_range_for_loop(int rows=Rows, int cols=Cols)
|
||||
typedef Matrix<Scalar,Rows,Cols,ColMajor> ColMatrixType;
|
||||
typedef Matrix<Scalar,Rows,Cols,RowMajor> RowMatrixType;
|
||||
VectorType v = VectorType::Random(rows);
|
||||
const VectorType& cv(v);
|
||||
ColMatrixType A = ColMatrixType::Random(rows,cols);
|
||||
const ColMatrixType& cA(A);
|
||||
RowMatrixType B = RowMatrixType::Random(rows,cols);
|
||||
|
||||
Index i, j;
|
||||
|
||||
VERIFY( is_PointerBasedStlIterator(v.begin()) );
|
||||
VERIFY( is_PointerBasedStlIterator(v.end()) );
|
||||
VERIFY( is_PointerBasedStlIterator(cv.begin()) );
|
||||
VERIFY( is_PointerBasedStlIterator(cv.end()) );
|
||||
|
||||
j = internal::random<Index>(0,A.cols()-1);
|
||||
VERIFY( is_PointerBasedStlIterator(A.col(j).begin()) );
|
||||
VERIFY( is_PointerBasedStlIterator(A.col(j).end()) );
|
||||
VERIFY( is_PointerBasedStlIterator(cA.col(j).begin()) );
|
||||
VERIFY( is_PointerBasedStlIterator(cA.col(j).end()) );
|
||||
|
||||
i = internal::random<Index>(0,A.rows()-1);
|
||||
VERIFY( is_PointerBasedStlIterator(A.row(i).begin()) );
|
||||
VERIFY( is_PointerBasedStlIterator(A.row(i).end()) );
|
||||
VERIFY( is_PointerBasedStlIterator(cA.row(i).begin()) );
|
||||
VERIFY( is_PointerBasedStlIterator(cA.row(i).end()) );
|
||||
|
||||
VERIFY( is_PointerBasedStlIterator(A.reshaped().begin()) );
|
||||
VERIFY( is_PointerBasedStlIterator(A.reshaped().end()) );
|
||||
VERIFY( is_PointerBasedStlIterator(cA.reshaped().begin()) );
|
||||
VERIFY( is_PointerBasedStlIterator(cA.reshaped().end()) );
|
||||
|
||||
VERIFY( is_DenseStlIterator(A.template reshaped<RowMajor>().begin()) );
|
||||
VERIFY( is_DenseStlIterator(A.template reshaped<RowMajor>().end()) );
|
||||
|
||||
#if EIGEN_HAS_CXX11
|
||||
i = 0;
|
||||
@ -49,6 +83,19 @@ void test_range_for_loop(int rows=Rows, int cols=Cols)
|
||||
i = 0;
|
||||
for(auto x : A.reshaped()) { VERIFY_IS_EQUAL(x,A(i++)); }
|
||||
|
||||
// check const_iterator
|
||||
{
|
||||
i = 0;
|
||||
for(auto x : cv) { VERIFY_IS_EQUAL(x,v[i++]); }
|
||||
|
||||
i = 0;
|
||||
for(auto x : cA.reshaped()) { VERIFY_IS_EQUAL(x,A(i++)); }
|
||||
|
||||
j = 0;
|
||||
i = internal::random<Index>(0,A.rows()-1);
|
||||
for(auto x : cA.row(i)) { VERIFY_IS_EQUAL(x,A(i,j++)); }
|
||||
}
|
||||
|
||||
Matrix<Scalar,Dynamic,Dynamic,ColMajor> Bc = B;
|
||||
i = 0;
|
||||
for(auto x : B.reshaped()) { VERIFY_IS_EQUAL(x,Bc(i++)); }
|
||||
@ -57,6 +104,41 @@ void test_range_for_loop(int rows=Rows, int cols=Cols)
|
||||
i = 0;
|
||||
for(auto& x : w) { x = v(i++); }
|
||||
VERIFY_IS_EQUAL(v,w);
|
||||
|
||||
{
|
||||
j = internal::random<Index>(0,A.cols()-1);
|
||||
auto it = A.col(j).begin();
|
||||
for(i=0;i<rows;++i) {
|
||||
VERIFY_IS_EQUAL(it[i],A(i,j));
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
i = internal::random<Index>(0,A.rows()-1);
|
||||
auto it = A.row(i).begin();
|
||||
for(j=0;j<cols;++j) { VERIFY_IS_EQUAL(it[j],A(i,j)); }
|
||||
}
|
||||
|
||||
{
|
||||
j = internal::random<Index>(0,A.cols()-1);
|
||||
// this would produce a dangling pointer:
|
||||
// auto it = (A+2*A).col(j).begin();
|
||||
// we need to name the temporary expression:
|
||||
auto tmp = (A+2*A).col(j);
|
||||
auto it = tmp.begin();
|
||||
for(i=0;i<rows;++i) {
|
||||
VERIFY_IS_APPROX(it[i],3*A(i,j));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// {
|
||||
// j = internal::random<Index>(0,A.cols()-1);
|
||||
// auto it = (A+2*A).col(j).begin();
|
||||
// for(i=0;i<rows;++i) {
|
||||
// VERIFY_IS_APPROX(it[i],3*A(i,j));
|
||||
// }
|
||||
// }
|
||||
#endif
|
||||
|
||||
if(rows>=3) {
|
||||
@ -78,16 +160,45 @@ void test_range_for_loop(int rows=Rows, int cols=Cols)
|
||||
VERIFY(std::is_sorted(begin(v),end(v)));
|
||||
VERIFY(!std::is_sorted(make_reverse_iterator(end(v)),make_reverse_iterator(begin(v))));
|
||||
|
||||
// std::sort with pointer-based iterator and default increment
|
||||
{
|
||||
j = internal::random<Index>(0,A.cols()-1);
|
||||
// std::sort(begin(A.col(j)),end(A.col(j))); // does not compile because this returns const iterators
|
||||
typename ColMatrixType::ColXpr Acol = A.col(j);
|
||||
std::sort(begin(Acol),end(Acol));
|
||||
VERIFY(std::is_sorted(Acol.cbegin(),Acol.cend()));
|
||||
A.setRandom();
|
||||
|
||||
// This raises an assert because this creates a pair of iterator referencing two different proxy objects:
|
||||
// std::sort(A.col(j).begin(),A.col(j).end());
|
||||
// VERIFY(std::is_sorted(A.col(j).cbegin(),A.col(j).cend())); // same issue
|
||||
std::sort(A.col(j).begin(),A.col(j).end());
|
||||
VERIFY(std::is_sorted(A.col(j).cbegin(),A.col(j).cend()));
|
||||
A.setRandom();
|
||||
}
|
||||
|
||||
// std::sort with pointer-based iterator and runtime increment
|
||||
{
|
||||
i = internal::random<Index>(0,A.rows()-1);
|
||||
typename ColMatrixType::RowXpr Arow = A.row(i);
|
||||
VERIFY_IS_EQUAL( std::distance(begin(Arow),end(Arow)), cols);
|
||||
std::sort(begin(Arow),end(Arow));
|
||||
VERIFY(std::is_sorted(Arow.cbegin(),Arow.cend()));
|
||||
A.setRandom();
|
||||
|
||||
std::sort(A.row(i).begin(),A.row(i).end());
|
||||
VERIFY(std::is_sorted(A.row(i).cbegin(),A.row(i).cend()));
|
||||
A.setRandom();
|
||||
}
|
||||
|
||||
// std::sort with generic iterator
|
||||
{
|
||||
auto B1 = B.reshaped();
|
||||
std::sort(begin(B1),end(B1));
|
||||
VERIFY(std::is_sorted(B1.cbegin(),B1.cend()));
|
||||
B.setRandom();
|
||||
|
||||
// assertion because nested expressions are different
|
||||
// std::sort(B.reshaped().begin(),B.reshaped().end());
|
||||
// VERIFY(std::is_sorted(B.reshaped().cbegin(),B.reshaped().cend()));
|
||||
// B.setRandom();
|
||||
}
|
||||
|
||||
{
|
||||
@ -149,6 +260,7 @@ EIGEN_DECLARE_TEST(stl_iterators)
|
||||
for(int i = 0; i < g_repeat; i++) {
|
||||
CALL_SUBTEST_1(( test_range_for_loop<double,2,3>() ));
|
||||
CALL_SUBTEST_1(( test_range_for_loop<float,7,5>() ));
|
||||
CALL_SUBTEST_1(( test_range_for_loop<int,Dynamic,Dynamic>(internal::random<int>(5,10), internal::random<int>(5,10)) ));
|
||||
CALL_SUBTEST_1(( test_range_for_loop<int,Dynamic,Dynamic>(internal::random<int>(10,200), internal::random<int>(10,200)) ));
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user