bug #1619: fix mixing of const and non-const generic iterators

This commit is contained in:
Gael Guennebaud 2018-11-09 21:45:10 +01:00
parent db9a9a12ba
commit 784a3f13cf
2 changed files with 82 additions and 18 deletions

View File

@ -11,9 +11,20 @@ namespace Eigen {
namespace internal { namespace internal {
template<typename XprType,typename Derived> template<typename IteratorType>
struct indexed_based_stl_iterator_traits;
template<typename Derived>
class indexed_based_stl_iterator_base class indexed_based_stl_iterator_base
{ {
protected:
typedef indexed_based_stl_iterator_traits<Derived> traits;
typedef typename traits::XprType XprType;
typedef indexed_based_stl_iterator_base<typename traits::non_const_iterator> non_const_iterator;
typedef indexed_based_stl_iterator_base<typename traits::const_iterator> const_iterator;
typedef typename internal::conditional<internal::is_const<XprType>::value,non_const_iterator,const_iterator>::type other_iterator;
friend const_iterator;
friend non_const_iterator;
public: public:
typedef Index difference_type; typedef Index difference_type;
typedef std::random_access_iterator_tag iterator_category; typedef std::random_access_iterator_tag iterator_category;
@ -21,6 +32,17 @@ public:
indexed_based_stl_iterator_base() : mp_xpr(0), m_index(0) {} indexed_based_stl_iterator_base() : mp_xpr(0), m_index(0) {}
indexed_based_stl_iterator_base(XprType& xpr, Index index) : mp_xpr(&xpr), m_index(index) {} indexed_based_stl_iterator_base(XprType& xpr, Index index) : mp_xpr(&xpr), m_index(index) {}
indexed_based_stl_iterator_base(const non_const_iterator& other)
: mp_xpr(other.mp_xpr), m_index(other.m_index)
{}
indexed_based_stl_iterator_base& operator=(const non_const_iterator& other)
{
mp_xpr = other.mp_xpr;
m_index = other.m_index;
return *this;
}
Derived& operator++() { ++m_index; return derived(); } Derived& operator++() { ++m_index; return derived(); }
Derived& operator--() { --m_index; return derived(); } Derived& operator--() { --m_index; return derived(); }
@ -35,14 +57,31 @@ public:
Derived& operator+=(Index b) { m_index += b; return derived(); } Derived& operator+=(Index b) { m_index += b; return derived(); }
Derived& operator-=(Index b) { m_index -= b; return derived(); } Derived& operator-=(Index b) { m_index -= b; return derived(); }
difference_type operator-(const indexed_based_stl_iterator_base& other) const { eigen_assert(mp_xpr == other.mp_xpr);return m_index - other.m_index; } difference_type operator-(const indexed_based_stl_iterator_base& other) const
{
eigen_assert(mp_xpr == other.mp_xpr);
return m_index - other.m_index;
}
bool operator==(const indexed_based_stl_iterator_base& other) { eigen_assert(mp_xpr == other.mp_xpr); return m_index == other.m_index; } difference_type operator-(const other_iterator& other) const
bool operator!=(const indexed_based_stl_iterator_base& other) { eigen_assert(mp_xpr == other.mp_xpr); return m_index != other.m_index; } {
bool operator< (const indexed_based_stl_iterator_base& other) { eigen_assert(mp_xpr == other.mp_xpr); return m_index < other.m_index; } eigen_assert(mp_xpr == other.mp_xpr);
bool operator<=(const indexed_based_stl_iterator_base& other) { eigen_assert(mp_xpr == other.mp_xpr); return m_index <= other.m_index; } return m_index - other.m_index;
bool operator> (const indexed_based_stl_iterator_base& other) { eigen_assert(mp_xpr == other.mp_xpr); return m_index > other.m_index; } }
bool operator>=(const indexed_based_stl_iterator_base& other) { eigen_assert(mp_xpr == other.mp_xpr); return m_index >= other.m_index; }
bool operator==(const indexed_based_stl_iterator_base& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index == other.m_index; }
bool operator!=(const indexed_based_stl_iterator_base& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index != other.m_index; }
bool operator< (const indexed_based_stl_iterator_base& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index < other.m_index; }
bool operator<=(const indexed_based_stl_iterator_base& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index <= other.m_index; }
bool operator> (const indexed_based_stl_iterator_base& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index > other.m_index; }
bool operator>=(const indexed_based_stl_iterator_base& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index >= other.m_index; }
bool operator==(const other_iterator& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index == other.m_index; }
bool operator!=(const other_iterator& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index != other.m_index; }
bool operator< (const other_iterator& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index < other.m_index; }
bool operator<=(const other_iterator& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index <= other.m_index; }
bool operator> (const other_iterator& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index > other.m_index; }
bool operator>=(const other_iterator& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index >= other.m_index; }
protected: protected:
@ -57,8 +96,8 @@ template<typename XprType>
class pointer_based_stl_iterator class pointer_based_stl_iterator
{ {
enum { is_lvalue = internal::is_lvalue<XprType>::value }; enum { is_lvalue = internal::is_lvalue<XprType>::value };
typedef pointer_based_stl_iterator<typename internal::remove_const<XprType>::type > non_const_iterator; typedef pointer_based_stl_iterator<typename internal::remove_const<XprType>::type> non_const_iterator;
typedef pointer_based_stl_iterator<typename internal::add_const<XprType>::type > const_iterator; typedef pointer_based_stl_iterator<typename internal::add_const<XprType>::type> const_iterator;
typedef typename internal::conditional<internal::is_const<XprType>::value,non_const_iterator,const_iterator>::type other_iterator; typedef typename internal::conditional<internal::is_const<XprType>::value,non_const_iterator,const_iterator>::type other_iterator;
friend const_iterator; friend const_iterator;
friend non_const_iterator; friend non_const_iterator;
@ -133,8 +172,16 @@ protected:
internal::variable_if_dynamic<Index, XprType::InnerStrideAtCompileTime> m_incr; internal::variable_if_dynamic<Index, XprType::InnerStrideAtCompileTime> m_incr;
}; };
template<typename _XprType>
struct indexed_based_stl_iterator_traits<generic_randaccess_stl_iterator<_XprType> >
{
typedef _XprType XprType;
typedef generic_randaccess_stl_iterator<typename internal::remove_const<XprType>::type> non_const_iterator;
typedef generic_randaccess_stl_iterator<typename internal::add_const<XprType>::type> const_iterator;
};
template<typename XprType> template<typename XprType>
class generic_randaccess_stl_iterator : public indexed_based_stl_iterator_base<XprType, generic_randaccess_stl_iterator<XprType> > class generic_randaccess_stl_iterator : public indexed_based_stl_iterator_base<generic_randaccess_stl_iterator<XprType> >
{ {
public: public:
typedef typename XprType::Scalar value_type; typedef typename XprType::Scalar value_type;
@ -146,7 +193,7 @@ protected:
is_lvalue = internal::is_lvalue<XprType>::value is_lvalue = internal::is_lvalue<XprType>::value
}; };
typedef indexed_based_stl_iterator_base<XprType,generic_randaccess_stl_iterator> Base; typedef indexed_based_stl_iterator_base<generic_randaccess_stl_iterator> Base;
using Base::m_index; using Base::m_index;
using Base::mp_xpr; using Base::mp_xpr;
@ -162,20 +209,30 @@ public:
generic_randaccess_stl_iterator() : Base() {} generic_randaccess_stl_iterator() : Base() {}
generic_randaccess_stl_iterator(XprType& xpr, Index index) : Base(xpr,index) {} generic_randaccess_stl_iterator(XprType& xpr, Index index) : Base(xpr,index) {}
generic_randaccess_stl_iterator(const typename Base::non_const_iterator& other) : Base(other) {}
using Base::operator=;
reference operator*() const { return (*mp_xpr)(m_index); } reference operator*() const { return (*mp_xpr)(m_index); }
reference operator[](Index i) const { return (*mp_xpr)(m_index+i); } reference operator[](Index i) const { return (*mp_xpr)(m_index+i); }
pointer operator->() const { return &((*mp_xpr)(m_index)); } pointer operator->() const { return &((*mp_xpr)(m_index)); }
}; };
template<typename _XprType, DirectionType Direction>
struct indexed_based_stl_iterator_traits<subvector_stl_iterator<_XprType,Direction> >
{
typedef _XprType XprType;
typedef subvector_stl_iterator<typename internal::remove_const<XprType>::type, Direction> non_const_iterator;
typedef subvector_stl_iterator<typename internal::add_const<XprType>::type, Direction> const_iterator;
};
template<typename XprType, DirectionType Direction> template<typename XprType, DirectionType Direction>
class subvector_stl_iterator : public indexed_based_stl_iterator_base<XprType, subvector_stl_iterator<XprType,Direction> > class subvector_stl_iterator : public indexed_based_stl_iterator_base<subvector_stl_iterator<XprType,Direction> >
{ {
protected: protected:
enum { is_lvalue = internal::is_lvalue<XprType>::value }; enum { is_lvalue = internal::is_lvalue<XprType>::value };
typedef indexed_based_stl_iterator_base<XprType,subvector_stl_iterator> Base; typedef indexed_based_stl_iterator_base<subvector_stl_iterator> Base;
using Base::m_index; using Base::m_index;
using Base::mp_xpr; using Base::mp_xpr;

View File

@ -66,9 +66,15 @@ void check_begin_end_for_loop(Xpr xpr)
{ {
// simple API check // simple API check
typename Xpr::const_iterator cit; typename Xpr::const_iterator cit = xpr.begin();
cit = xpr.begin();
cit = xpr.cbegin(); cit = xpr.cbegin();
#if EIGEN_HAS_CXX11
auto tmp1 = xpr.begin();
VERIFY(tmp1==xpr.begin());
auto tmp2 = xpr.cbegin();
VERIFY(tmp2==xpr.cbegin());
#endif
} }
VERIFY( xpr.end() -xpr.begin() == xpr.size() ); VERIFY( xpr.end() -xpr.begin() == xpr.size() );
@ -150,8 +156,9 @@ void test_stl_iterators(int rows=Rows, int cols=Cols)
{ {
check_begin_end_for_loop(v); check_begin_end_for_loop(v);
check_begin_end_for_loop(v.col(internal::random<Index>(0,A.cols()-1))); check_begin_end_for_loop(A.col(internal::random<Index>(0,A.cols()-1)));
check_begin_end_for_loop(v.row(internal::random<Index>(0,A.rows()-1))); check_begin_end_for_loop(A.row(internal::random<Index>(0,A.rows()-1)));
check_begin_end_for_loop(v+v);
} }
#if EIGEN_HAS_CXX11 #if EIGEN_HAS_CXX11