Add templated subVector<Vertical/Horizonal>(Index) aliases to col/row(Index) methods (plus subVectors<>() to retrieve the number of rows/columns)

This commit is contained in:
Gael Guennebaud 2018-10-02 14:02:34 +02:00
parent 37e29fc893
commit 12487531ce
3 changed files with 37 additions and 18 deletions

View File

@ -186,24 +186,7 @@ template<typename ExpressionType, int Direction> class VectorwiseOp
};
protected:
typedef typename internal::conditional<isVertical,
typename ExpressionType::ColXpr,
typename ExpressionType::RowXpr>::type SubVector;
/** \internal
* \returns the i-th subvector according to the \c Direction */
EIGEN_DEVICE_FUNC
SubVector subVector(Index i)
{
return SubVector(m_matrix.derived(),i);
}
/** \internal
* \returns the number of subvectors in the direction \c Direction */
EIGEN_DEVICE_FUNC
Index subVectors() const
{ return isVertical?m_matrix.cols():m_matrix.rows(); }
template<typename OtherDerived> struct ExtendedType {
typedef Replicate<OtherDerived,
isVertical ? 1 : ExpressionType::RowsAtCompileTime,

View File

@ -1399,3 +1399,32 @@ innerVectors(Index outerStart, Index outerSize) const
IsRowMajor ? outerSize : rows(), IsRowMajor ? cols() : outerSize);
}
/** \returns the i-th subvector (column or vector) according to the \c Direction
* \sa subVectors()
*/
EIGEN_DEVICE_FUNC
template<DirectionType Direction>
typename internal::conditional<Direction==Vertical,ColXpr,RowXpr>::type
subVector(Index i)
{
return typename internal::conditional<Direction==Vertical,ColXpr,RowXpr>::type(derived(),i);
}
/** This is the const version of subVector(Index) */
EIGEN_DEVICE_FUNC
template<DirectionType Direction>
typename internal::conditional<Direction==Vertical,ConstColXpr,ConstRowXpr>::type
subVector(Index i) const
{
return typename internal::conditional<Direction==Vertical,ConstColXpr,ConstRowXpr>::type(derived(),i);
}
/** \returns the number of subvectors (rows or columns) in the direction \c Direction
* \sa subVector(Index)
*/
EIGEN_DEVICE_FUNC
template<DirectionType Direction>
Index subVectors() const
{ return (Direction==Vertical)?cols():rows(); }

View File

@ -220,6 +220,13 @@ template<typename MatrixType> void block(const MatrixType& m)
VERIFY_RAISES_ASSERT( m1.array() *= m1.col(0).array() );
VERIFY_RAISES_ASSERT( m1.array() /= m1.col(0).array() );
}
VERIFY_IS_EQUAL( m1.template subVector<Horizontal>(r1), m1.row(r1) );
VERIFY_IS_APPROX( (m1+m1).template subVector<Horizontal>(r1), (m1+m1).row(r1) );
VERIFY_IS_EQUAL( m1.template subVector<Vertical>(c1), m1.col(c1) );
VERIFY_IS_APPROX( (m1+m1).template subVector<Vertical>(c1), (m1+m1).col(c1) );
VERIFY_IS_EQUAL( m1.template subVectors<Horizontal>(), m1.rows() );
VERIFY_IS_EQUAL( m1.template subVectors<Vertical>(), m1.cols() );
}