Fallback Reshaped to MapBase when possible (same storage order and linear access to the nested expression)

This commit is contained in:
Gael Guennebaud 2017-02-11 15:32:53 +01:00
parent 83d6a529c3
commit 4b22048cea
3 changed files with 156 additions and 120 deletions

View File

@ -57,34 +57,37 @@ struct traits<Reshaped<XprType, Rows, Cols, Order> > : traits<XprType>
ColsAtCompileTime = Cols,
MaxRowsAtCompileTime = Rows,
MaxColsAtCompileTime = Cols,
XprTypeIsRowMajor = (int(traits<XprType>::Flags) & RowMajorBit) != 0,
IsRowMajor = (RowsAtCompileTime == 1 && ColsAtCompileTime != 1) ? 1
: (ColsAtCompileTime == 1 && RowsAtCompileTime != 1) ? 0
: XprTypeIsRowMajor,
HasSameStorageOrderAsXprType = (IsRowMajor == XprTypeIsRowMajor),
InnerSize = IsRowMajor ? int(ColsAtCompileTime) : int(RowsAtCompileTime),
XpxStorageOrder = ((int(traits<XprType>::Flags) & RowMajorBit) == RowMajorBit) ? RowMajor : ColMajor,
ReshapedStorageOrder = (RowsAtCompileTime == 1 && ColsAtCompileTime != 1) ? RowMajor
: (ColsAtCompileTime == 1 && RowsAtCompileTime != 1) ? ColMajor
: XpxStorageOrder,
HasSameStorageOrderAsXprType = (ReshapedStorageOrder == XpxStorageOrder),
InnerSize = (ReshapedStorageOrder==RowMajor) ? int(ColsAtCompileTime) : int(RowsAtCompileTime),
InnerStrideAtCompileTime = HasSameStorageOrderAsXprType
? int(inner_stride_at_compile_time<XprType>::ret)
: int(outer_stride_at_compile_time<XprType>::ret),
OuterStrideAtCompileTime = HasSameStorageOrderAsXprType
? int(outer_stride_at_compile_time<XprType>::ret)
: int(inner_stride_at_compile_time<XprType>::ret),
: Dynamic,
OuterStrideAtCompileTime = Dynamic,
InOrder = Order,
HasDirectAccess = internal::has_direct_access<XprType>::ret
&& (Order==int(AutoOrderValue) || Order==int(XpxStorageOrder))
&& ((evaluator<XprType>::Flags&LinearAccessBit)==LinearAccessBit),
MaskPacketAccessBit = (InnerSize == Dynamic || (InnerSize % packet_traits<Scalar>::size) == 0)
&& (InnerStrideAtCompileTime == 1)
? PacketAccessBit : 0,
//MaskAlignedBit = ((OuterStrideAtCompileTime!=Dynamic) && (((OuterStrideAtCompileTime * int(sizeof(Scalar))) % 16) == 0)) ? AlignedBit : 0,
FlagsLinearAccessBit = (RowsAtCompileTime == 1 || ColsAtCompileTime == 1) ? LinearAccessBit : 0,
FlagsLvalueBit = is_lvalue<XprType>::value ? LvalueBit : 0,
FlagsRowMajorBit = IsRowMajor ? RowMajorBit : 0,
Flags0 = traits<XprType>::Flags & ( (HereditaryBits & ~RowMajorBit) | MaskPacketAccessBit)
& ~DirectAccessBit,
FlagsRowMajorBit = (ReshapedStorageOrder==RowMajor) ? RowMajorBit : 0,
FlagsDirectAccessBit = HasDirectAccess ? DirectAccessBit : 0,
Flags0 = traits<XprType>::Flags & ( (HereditaryBits & ~RowMajorBit) | MaskPacketAccessBit),
Flags = (Flags0 | FlagsLinearAccessBit | FlagsLvalueBit | FlagsRowMajorBit)
Flags = (Flags0 | FlagsLinearAccessBit | FlagsLvalueBit | FlagsRowMajorBit | FlagsDirectAccessBit)
};
};
template<typename XprType, int Rows=Dynamic, int Cols=Dynamic, int Order = 0,
bool HasDirectAccess = internal::has_direct_access<XprType>::ret> class ReshapedImpl_dense;
template<typename XprType, int Rows, int Cols, int Order, bool HasDirectAccess> class ReshapedImpl_dense;
} // end namespace internal
@ -127,9 +130,9 @@ template<typename XprType, int Rows, int Cols, int Order> class Reshaped
// that must be specialized for direct and non-direct access...
template<typename XprType, int Rows, int Cols, int Order>
class ReshapedImpl<XprType, Rows, Cols, Order, Dense>
: public internal::ReshapedImpl_dense<XprType, Rows, Cols, Order>
: public internal::ReshapedImpl_dense<XprType, Rows, Cols, Order,internal::traits<Reshaped<XprType,Rows,Cols,Order> >::HasDirectAccess>
{
typedef internal::ReshapedImpl_dense<XprType, Rows, Cols, Order> Impl;
typedef internal::ReshapedImpl_dense<XprType, Rows, Cols, Order,internal::traits<Reshaped<XprType,Rows,Cols,Order> >::HasDirectAccess> Impl;
public:
typedef Impl Base;
EIGEN_INHERIT_ASSIGNMENT_OPERATORS(ReshapedImpl)
@ -140,8 +143,9 @@ class ReshapedImpl<XprType, Rows, Cols, Order, Dense>
namespace internal {
/** \internal Internal implementation of dense Reshapeds in the general case. */
template<typename XprType, int Rows, int Cols, int Order, bool HasDirectAccess> class ReshapedImpl_dense
/** \internal Internal implementation of dense Reshaped in the general case. */
template<typename XprType, int Rows, int Cols, int Order>
class ReshapedImpl_dense<XprType,Rows,Cols,Order,false>
: public internal::dense_xpr_base<Reshaped<XprType, Rows, Cols, Order> >::type
{
typedef Reshaped<XprType, Rows, Cols, Order> ReshapedType;
@ -166,8 +170,7 @@ template<typename XprType, int Rows, int Cols, int Order, bool HasDirectAccess>
/** Dynamic-size constructor
*/
EIGEN_DEVICE_FUNC
inline ReshapedImpl_dense(XprType& xpr,
Index nRows, Index nCols)
inline ReshapedImpl_dense(XprType& xpr, Index nRows, Index nCols)
: m_xpr(xpr), m_rows(nRows), m_cols(nCols)
{}
@ -199,8 +202,106 @@ template<typename XprType, int Rows, int Cols, int Order, bool HasDirectAccess>
};
/** \internal Internal implementation of dense Reshaped in the direct access case.
  *
  * This specialization is used when the last template argument (HasDirectAccess) is true.
  * The reshaped view is then exposed as a MapBase built directly on the nested expression's
  * data pointer, so coefficients are addressed through innerStride()/outerStride() instead
  * of going through index remapping.
  */
template<typename XprType, int Rows, int Cols, int Order>
class ReshapedImpl_dense<XprType, Rows, Cols, Order, true>
  : public MapBase<Reshaped<XprType, Rows, Cols, Order> >
{
    typedef Reshaped<XprType, Rows, Cols, Order> ReshapedType;
    typedef typename internal::ref_selector<XprType>::non_const_type XprTypeNested;
  public:

    typedef MapBase<ReshapedType> Base;
    EIGEN_DENSE_PUBLIC_INTERFACE(ReshapedType)
    EIGEN_INHERIT_ASSIGNMENT_OPERATORS(ReshapedImpl_dense)

    /** Fixed-size constructor
      *
      * \param xpr the nested expression; the view starts at xpr.data().
      */
    EIGEN_DEVICE_FUNC
    inline ReshapedImpl_dense(XprType& xpr)
      : Base(xpr.data()), m_xpr(xpr)
    {}

    /** Dynamic-size constructor
      *
      * \param xpr   the nested expression; the view starts at xpr.data().
      * \param nRows number of rows of the reshaped view.
      * \param nCols number of columns of the reshaped view.
      */
    EIGEN_DEVICE_FUNC
    inline ReshapedImpl_dense(XprType& xpr, Index nRows, Index nCols)
      : Base(xpr.data(), nRows, nCols),
        m_xpr(xpr)
    {}

    /** \returns a read-only reference to the nested expression. */
    EIGEN_DEVICE_FUNC
    const typename internal::remove_all<XprTypeNested>::type& nestedExpression() const
    {
      return m_xpr;
    }

    /** \returns a writable reference to the nested expression. */
    EIGEN_DEVICE_FUNC
    XprType& nestedExpression() { return m_xpr; }

    /** \sa MapBase::innerStride()
      *
      * Consecutive coefficients of the view are spaced like those of the nested expression.
      */
    EIGEN_DEVICE_FUNC
    inline Index innerStride() const
    {
      return m_xpr.innerStride();
    }

    /** \sa MapBase::outerStride()
      *
      * One inner vector of the view holds cols() (row-major) or rows() (column-major)
      * coefficients, each of them innerStride() apart. The distance between two inner
      * vectors is therefore the inner size scaled by the nested inner stride; returning
      * the bare inner size would be wrong whenever the nested inner stride is not 1.
      */
    EIGEN_DEVICE_FUNC
    inline Index outerStride() const
    {
      return (((Flags&RowMajorBit)==RowMajorBit) ? this->cols() : this->rows()) * m_xpr.innerStride();
    }

  protected:

    XprTypeNested m_xpr;
};
// Evaluators
template<typename ArgType, int Rows, int Cols, int Order, bool HasDirectAccess> struct reshaped_evaluator;
/** \internal Evaluator of a Reshaped expression.
  *
  * This struct only gathers the compile-time constants (cost, flags, alignment). The
  * implementation is inherited from reshaped_evaluator, whose specialization is selected
  * by traits<Reshaped<...> >::HasDirectAccess.
  */
template<typename ArgType, int Rows, int Cols, int Order>
// NOTE(review): the unary_evaluator line below looks like the pre-change declaration kept
// by the diff rendering next to its replacement -- confirm which of the two heads is live.
struct unary_evaluator<Reshaped<ArgType, Rows, Cols, Order>, IndexBased>
struct evaluator<Reshaped<ArgType, Rows, Cols, Order> >
: reshaped_evaluator<ArgType, Rows, Cols, Order, traits<Reshaped<ArgType,Rows,Cols,Order> >::HasDirectAccess>
{
typedef Reshaped<ArgType, Rows, Cols, Order> XprType;
typedef typename XprType::Scalar Scalar;
// TODO: should check for smaller packet types
typedef typename packet_traits<Scalar>::type PacketScalar;
enum {
// Reading a coefficient costs the same as reading it from the nested expression.
CoeffReadCost = evaluator<ArgType>::CoeffReadCost,
// Mirrors the trait that also selects the reshaped_evaluator base class.
HasDirectAccess = traits<XprType>::HasDirectAccess,
// RowsAtCompileTime = traits<XprType>::RowsAtCompileTime,
// ColsAtCompileTime = traits<XprType>::ColsAtCompileTime,
// MaxRowsAtCompileTime = traits<XprType>::MaxRowsAtCompileTime,
// MaxColsAtCompileTime = traits<XprType>::MaxColsAtCompileTime,
//
// InnerStrideAtCompileTime = traits<XprType>::HasSameStorageOrderAsXprType
// ? int(inner_stride_at_compile_time<ArgType>::ret)
// : Dynamic,
// OuterStrideAtCompileTime = Dynamic,
// Linear access is advertised for compile-time vectors (one row or one column) and
// whenever the direct-access path is taken.
FlagsLinearAccessBit = (traits<XprType>::RowsAtCompileTime == 1 || traits<XprType>::ColsAtCompileTime == 1 || HasDirectAccess) ? LinearAccessBit : 0,
// The storage-order bit follows the reshaped storage order computed in traits<XprType>.
FlagsRowMajorBit = (traits<XprType>::ReshapedStorageOrder==RowMajor) ? RowMajorBit : 0,
FlagsDirectAccessBit = HasDirectAccess ? DirectAccessBit : 0,
// Hereditary bits come from the nested evaluator, except RowMajorBit which is set above.
Flags0 = evaluator<ArgType>::Flags & (HereditaryBits & ~RowMajorBit),
Flags = Flags0 | FlagsLinearAccessBit | FlagsRowMajorBit | FlagsDirectAccessBit,
// Alignment required by the packet type; not referenced elsewhere in this struct.
PacketAlignment = unpacket_traits<PacketScalar>::alignment,
// Alignment is forwarded unchanged from the nested evaluator.
Alignment = evaluator<ArgType>::Alignment
};
typedef reshaped_evaluator<ArgType, Rows, Cols, Order, HasDirectAccess> reshaped_evaluator_type;
// Forwards the expression to the selected reshaped_evaluator and checks the cost constant.
EIGEN_DEVICE_FUNC explicit evaluator(const XprType& xpr) : reshaped_evaluator_type(xpr)
{
EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost);
}
};
template<typename ArgType, int Rows, int Cols, int Order>
struct reshaped_evaluator<ArgType, Rows, Cols, Order, /* HasDirectAccess */ false>
: evaluator_base<Reshaped<ArgType, Rows, Cols, Order> >
{
typedef Reshaped<ArgType, Rows, Cols, Order> XprType;
@ -213,7 +314,7 @@ struct unary_evaluator<Reshaped<ArgType, Rows, Cols, Order>, IndexBased>
Alignment = 0
};
EIGEN_DEVICE_FUNC explicit unary_evaluator(const XprType& xpr) : m_argImpl(xpr.nestedExpression()), m_xpr(xpr)
EIGEN_DEVICE_FUNC explicit reshaped_evaluator(const XprType& xpr) : m_argImpl(xpr.nestedExpression()), m_xpr(xpr)
{
EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost);
}
@ -321,103 +422,21 @@ protected:
};
/** \internal Evaluator of a Reshaped expression in the direct access case.
  *
  * All of the evaluation is inherited from mapbase_evaluator, instantiated with the
  * Reshaped expression and its PlainObject type; this struct only adds the alignment
  * check performed at construction.
  */
template<typename ArgType, int Rows, int Cols, int Order>
struct reshaped_evaluator<ArgType, Rows, Cols, Order, /* HasDirectAccess */ true>
: mapbase_evaluator<Reshaped<ArgType, Rows, Cols, Order>,
typename Reshaped<ArgType, Rows, Cols, Order>::PlainObject>
{
typedef Reshaped<ArgType, Rows, Cols, Order> XprType;
typedef typename XprType::Scalar Scalar;
/** Builds the evaluator from \a xpr and asserts that xpr.data() is a multiple of
  * evaluator<XprType>::Alignment (clamped to at least 1 so the modulo is well defined).
  */
EIGEN_DEVICE_FUNC explicit reshaped_evaluator(const XprType& xpr)
: mapbase_evaluator<XprType, typename XprType::PlainObject>(xpr)
{
// TODO: for the 3.4 release, this should be turned to an internal assertion, but let's keep it as is for the beta lifetime
eigen_assert(((internal::UIntPtr(xpr.data()) % EIGEN_PLAIN_ENUM_MAX(1,evaluator<XprType>::Alignment)) == 0) && "data is not aligned");
}
};
} // end namespace internal

View File

@ -265,6 +265,11 @@ static const auto fix(int val);
#endif // EIGEN_PARSED_BY_DOXYGEN
/** Integer value identifying the "automatic" order; the Reshaped traits compare their
  * Order template argument against it.
  * NOTE(review): presumably chosen so that it differs from ColMajor and RowMajor -- confirm.
  */
const int AutoOrderValue = 2;
/** Compile-time tag carrying ColMajor, usable as the order argument of reshaped(). */
const internal::FixedInt<ColMajor> ColOrder;
/** Compile-time tag carrying RowMajor, usable as the order argument of reshaped(). */
const internal::FixedInt<RowMajor> RowOrder;
/** Compile-time tag carrying AutoOrderValue. */
const internal::FixedInt<AutoOrderValue> AutoOrder;
} // end namespace Eigen
#endif // EIGEN_INTEGRAL_CONSTANT_H

View File

@ -48,6 +48,18 @@ void reshape_all_size(MatType m)
),
MapMat(m.data(), 4, 4)
);
VERIFY_IS_EQUAL(m.reshaped( 1, 16).data(), m.data());
VERIFY_IS_EQUAL(m.reshaped( 1, 16).innerStride(), 1);
VERIFY_IS_EQUAL(m.reshaped( 2, 8).data(), m.data());
VERIFY_IS_EQUAL(m.reshaped( 2, 8).innerStride(), 1);
VERIFY_IS_EQUAL(m.reshaped( 2, 8).outerStride(), 2);
m.reshaped(2,8,ColOrder);
MatrixXi m28r = m.reshaped(2,8,RowOrder);
std::cout << m28r << "\n";
}
void test_reshape()