Allow specifying inner & outer stride for CWiseUnaryView - fixes #2398

This commit is contained in:
Andrew Johnson 2022-01-05 19:24:46 +00:00 committed by Rasmus Munk Larsen
parent 27a78e4f96
commit a491c7f898
6 changed files with 79 additions and 27 deletions

View File

@ -812,11 +812,11 @@ protected:
// -------------------- CwiseUnaryView --------------------
template<typename UnaryOp, typename ArgType>
struct unary_evaluator<CwiseUnaryView<UnaryOp, ArgType>, IndexBased>
: evaluator_base<CwiseUnaryView<UnaryOp, ArgType> >
template<typename UnaryOp, typename ArgType, typename StrideType>
struct unary_evaluator<CwiseUnaryView<UnaryOp, ArgType, StrideType>, IndexBased>
: evaluator_base<CwiseUnaryView<UnaryOp, ArgType, StrideType> >
{
typedef CwiseUnaryView<UnaryOp, ArgType> XprType;
typedef CwiseUnaryView<UnaryOp, ArgType, StrideType> XprType;
enum {
CoeffReadCost = int(evaluator<ArgType>::CoeffReadCost) + int(functor_traits<UnaryOp>::Cost),

View File

@ -15,8 +15,8 @@
namespace Eigen {
namespace internal {
template<typename ViewOp, typename MatrixType>
struct traits<CwiseUnaryView<ViewOp, MatrixType> >
template<typename ViewOp, typename MatrixType, typename StrideType>
struct traits<CwiseUnaryView<ViewOp, MatrixType, StrideType> >
: traits<MatrixType>
{
typedef typename result_of<
@ -30,17 +30,22 @@ struct traits<CwiseUnaryView<ViewOp, MatrixType> >
MatrixTypeInnerStride = inner_stride_at_compile_time<MatrixType>::ret,
// need to cast the sizeof's from size_t to int explicitly, otherwise:
// "error: no integral type can represent all of the enumerator values
InnerStrideAtCompileTime = MatrixTypeInnerStride == Dynamic
? int(Dynamic)
: int(MatrixTypeInnerStride) * int(sizeof(typename traits<MatrixType>::Scalar) / sizeof(Scalar)),
OuterStrideAtCompileTime = outer_stride_at_compile_time<MatrixType>::ret == Dynamic
? int(Dynamic)
: outer_stride_at_compile_time<MatrixType>::ret * int(sizeof(typename traits<MatrixType>::Scalar) / sizeof(Scalar))
InnerStrideAtCompileTime = StrideType::InnerStrideAtCompileTime == 0
? (MatrixTypeInnerStride == Dynamic
? int(Dynamic)
: int(MatrixTypeInnerStride) * int(sizeof(typename traits<MatrixType>::Scalar) / sizeof(Scalar)))
: int(StrideType::InnerStrideAtCompileTime),
OuterStrideAtCompileTime = StrideType::OuterStrideAtCompileTime == 0
? (outer_stride_at_compile_time<MatrixType>::ret == Dynamic
? int(Dynamic)
: outer_stride_at_compile_time<MatrixType>::ret * int(sizeof(typename traits<MatrixType>::Scalar) / sizeof(Scalar)))
: int(StrideType::OuterStrideAtCompileTime)
};
};
}
template<typename ViewOp, typename MatrixType, typename StorageKind>
template<typename ViewOp, typename MatrixType, typename StrideType, typename StorageKind>
class CwiseUnaryViewImpl;
/** \class CwiseUnaryView
@ -56,12 +61,12 @@ class CwiseUnaryViewImpl;
*
* \sa MatrixBase::unaryViewExpr(const CustomUnaryOp &) const, class CwiseUnaryOp
*/
template<typename ViewOp, typename MatrixType>
class CwiseUnaryView : public CwiseUnaryViewImpl<ViewOp, MatrixType, typename internal::traits<MatrixType>::StorageKind>
template<typename ViewOp, typename MatrixType, typename StrideType>
class CwiseUnaryView : public CwiseUnaryViewImpl<ViewOp, MatrixType, StrideType, typename internal::traits<MatrixType>::StorageKind>
{
public:
typedef typename CwiseUnaryViewImpl<ViewOp, MatrixType,typename internal::traits<MatrixType>::StorageKind>::Base Base;
typedef typename CwiseUnaryViewImpl<ViewOp, MatrixType, StrideType, typename internal::traits<MatrixType>::StorageKind>::Base Base;
EIGEN_GENERIC_PUBLIC_INTERFACE(CwiseUnaryView)
typedef typename internal::ref_selector<MatrixType>::non_const_type MatrixTypeNested;
typedef typename internal::remove_all<MatrixType>::type NestedExpression;
@ -93,22 +98,22 @@ class CwiseUnaryView : public CwiseUnaryViewImpl<ViewOp, MatrixType, typename in
};
// Generic API dispatcher
template<typename ViewOp, typename XprType, typename StorageKind>
template<typename ViewOp, typename XprType, typename StrideType, typename StorageKind>
class CwiseUnaryViewImpl
: public internal::generic_xpr_base<CwiseUnaryView<ViewOp, XprType> >::type
: public internal::generic_xpr_base<CwiseUnaryView<ViewOp, XprType, StrideType> >::type
{
public:
typedef typename internal::generic_xpr_base<CwiseUnaryView<ViewOp, XprType> >::type Base;
typedef typename internal::generic_xpr_base<CwiseUnaryView<ViewOp, XprType, StrideType> >::type Base;
};
template<typename ViewOp, typename MatrixType>
class CwiseUnaryViewImpl<ViewOp,MatrixType,Dense>
: public internal::dense_xpr_base< CwiseUnaryView<ViewOp, MatrixType> >::type
template<typename ViewOp, typename MatrixType, typename StrideType>
class CwiseUnaryViewImpl<ViewOp,MatrixType,StrideType,Dense>
: public internal::dense_xpr_base< CwiseUnaryView<ViewOp, MatrixType, StrideType> >::type
{
public:
typedef CwiseUnaryView<ViewOp, MatrixType> Derived;
typedef typename internal::dense_xpr_base< CwiseUnaryView<ViewOp, MatrixType> >::type Base;
typedef CwiseUnaryView<ViewOp, MatrixType,StrideType> Derived;
typedef typename internal::dense_xpr_base< CwiseUnaryView<ViewOp, MatrixType,StrideType> >::type Base;
EIGEN_DENSE_PUBLIC_INTERFACE(Derived)
EIGEN_INHERIT_ASSIGNMENT_OPERATORS(CwiseUnaryViewImpl)
@ -118,12 +123,16 @@ class CwiseUnaryViewImpl<ViewOp,MatrixType,Dense>
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR inline Index innerStride() const
{
return derived().nestedExpression().innerStride() * sizeof(typename internal::traits<MatrixType>::Scalar) / sizeof(Scalar);
return StrideType::InnerStrideAtCompileTime != 0
? int(StrideType::InnerStrideAtCompileTime)
: derived().nestedExpression().innerStride() * sizeof(typename internal::traits<MatrixType>::Scalar) / sizeof(Scalar);
}
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR inline Index outerStride() const
{
return derived().nestedExpression().outerStride() * sizeof(typename internal::traits<MatrixType>::Scalar) / sizeof(Scalar);
return StrideType::OuterStrideAtCompileTime != 0
? int(StrideType::OuterStrideAtCompileTime)
: derived().nestedExpression().outerStride() * sizeof(typename internal::traits<MatrixType>::Scalar) / sizeof(Scalar);
}
protected:
EIGEN_DEFAULT_EMPTY_CONSTRUCTOR_AND_DESTRUCTOR(CwiseUnaryViewImpl)

View File

@ -78,7 +78,6 @@ template<typename MatrixType> class Transpose;
template<typename MatrixType> class Conjugate;
template<typename NullaryOp, typename MatrixType> class CwiseNullaryOp;
template<typename UnaryOp, typename MatrixType> class CwiseUnaryOp;
template<typename ViewOp, typename MatrixType> class CwiseUnaryView;
template<typename BinaryOp, typename Lhs, typename Rhs> class CwiseBinaryOp;
template<typename TernaryOp, typename Arg1, typename Arg2, typename Arg3> class CwiseTernaryOp;
template<typename Decomposition, typename Rhstype> class Solve;
@ -108,6 +107,7 @@ template<typename MatrixType, int MapOptions=Unaligned, typename StrideType = St
template<typename Derived> class RefBase;
template<typename PlainObjectType, int Options = 0,
typename StrideType = typename internal::conditional<PlainObjectType::IsVectorAtCompileTime,InnerStride<1>,OuterStride<> >::type > class Ref;
template<typename ViewOp, typename MatrixType, typename StrideType = Stride<0,0>> class CwiseUnaryView;
template<typename Derived> class TriangularBase;
template<typename MatrixType, unsigned int Mode> class TriangularView;

View File

@ -61,6 +61,9 @@ set(ei_smoke_test_list
mapped_matrix_1
mapstaticmethods_1
mapstride_1
unaryviewstride_1
unaryviewstride_2
unaryviewstride_3
matrix_square_root_1
meta
minres_2

View File

@ -194,6 +194,7 @@ ei_add_test(commainitializer)
ei_add_test(smallvectors)
ei_add_test(mapped_matrix)
ei_add_test(mapstride)
ei_add_test(unaryviewstride)
ei_add_test(mapstaticmethods)
ei_add_test(array_cwise)
ei_add_test(array_for_matrix)

39
test/unaryviewstride.cpp Normal file
View File

@ -0,0 +1,39 @@
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2021 Andrew Johnson <andrew.johnson@arjohnsonau.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#include "main.h"
template<int OuterStride,int InnerStride,typename VectorType> void unaryview_stride(const VectorType& m)
{
typedef typename VectorType::Scalar Scalar;
Index rows = m.rows();
Index cols = m.cols();
VectorType vec = VectorType::Random(rows, cols);
struct view_op {
EIGEN_EMPTY_STRUCT_CTOR(view_op)
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const Scalar&
operator()(const Scalar& v) const { return v; }
};
CwiseUnaryView<view_op, VectorType, Stride<OuterStride,InnerStride>> vec_view(vec);
VERIFY(vec_view.outerStride() == (OuterStride == 0 ? 0 : OuterStride));
VERIFY(vec_view.innerStride() == (InnerStride == 0 ? 1 : InnerStride));
}
EIGEN_DECLARE_TEST(unaryviewstride)
{
CALL_SUBTEST_1(( unaryview_stride<1,2>(MatrixXf()) ));
CALL_SUBTEST_1(( unaryview_stride<0,0>(MatrixXf()) ));
CALL_SUBTEST_2(( unaryview_stride<1,2>(VectorXf()) ));
CALL_SUBTEST_2(( unaryview_stride<0,0>(VectorXf()) ));
CALL_SUBTEST_3(( unaryview_stride<1,2>(RowVectorXf()) ));
CALL_SUBTEST_3(( unaryview_stride<0,0>(RowVectorXf()) ));
}