mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-04-06 19:10:36 +08:00
Allow specifying inner & outer stride for CWiseUnaryView - fixes #2398
This commit is contained in:
parent
27a78e4f96
commit
a491c7f898
@ -812,11 +812,11 @@ protected:
|
||||
|
||||
// -------------------- CwiseUnaryView --------------------
|
||||
|
||||
template<typename UnaryOp, typename ArgType>
|
||||
struct unary_evaluator<CwiseUnaryView<UnaryOp, ArgType>, IndexBased>
|
||||
: evaluator_base<CwiseUnaryView<UnaryOp, ArgType> >
|
||||
template<typename UnaryOp, typename ArgType, typename StrideType>
|
||||
struct unary_evaluator<CwiseUnaryView<UnaryOp, ArgType, StrideType>, IndexBased>
|
||||
: evaluator_base<CwiseUnaryView<UnaryOp, ArgType, StrideType> >
|
||||
{
|
||||
typedef CwiseUnaryView<UnaryOp, ArgType> XprType;
|
||||
typedef CwiseUnaryView<UnaryOp, ArgType, StrideType> XprType;
|
||||
|
||||
enum {
|
||||
CoeffReadCost = int(evaluator<ArgType>::CoeffReadCost) + int(functor_traits<UnaryOp>::Cost),
|
||||
|
@ -15,8 +15,8 @@
|
||||
namespace Eigen {
|
||||
|
||||
namespace internal {
|
||||
template<typename ViewOp, typename MatrixType>
|
||||
struct traits<CwiseUnaryView<ViewOp, MatrixType> >
|
||||
template<typename ViewOp, typename MatrixType, typename StrideType>
|
||||
struct traits<CwiseUnaryView<ViewOp, MatrixType, StrideType> >
|
||||
: traits<MatrixType>
|
||||
{
|
||||
typedef typename result_of<
|
||||
@ -30,17 +30,22 @@ struct traits<CwiseUnaryView<ViewOp, MatrixType> >
|
||||
MatrixTypeInnerStride = inner_stride_at_compile_time<MatrixType>::ret,
|
||||
// need to cast the sizeof's from size_t to int explicitly, otherwise:
|
||||
// "error: no integral type can represent all of the enumerator values
|
||||
InnerStrideAtCompileTime = MatrixTypeInnerStride == Dynamic
|
||||
? int(Dynamic)
|
||||
: int(MatrixTypeInnerStride) * int(sizeof(typename traits<MatrixType>::Scalar) / sizeof(Scalar)),
|
||||
OuterStrideAtCompileTime = outer_stride_at_compile_time<MatrixType>::ret == Dynamic
|
||||
? int(Dynamic)
|
||||
: outer_stride_at_compile_time<MatrixType>::ret * int(sizeof(typename traits<MatrixType>::Scalar) / sizeof(Scalar))
|
||||
InnerStrideAtCompileTime = StrideType::InnerStrideAtCompileTime == 0
|
||||
? (MatrixTypeInnerStride == Dynamic
|
||||
? int(Dynamic)
|
||||
: int(MatrixTypeInnerStride) * int(sizeof(typename traits<MatrixType>::Scalar) / sizeof(Scalar)))
|
||||
: int(StrideType::InnerStrideAtCompileTime),
|
||||
|
||||
OuterStrideAtCompileTime = StrideType::OuterStrideAtCompileTime == 0
|
||||
? (outer_stride_at_compile_time<MatrixType>::ret == Dynamic
|
||||
? int(Dynamic)
|
||||
: outer_stride_at_compile_time<MatrixType>::ret * int(sizeof(typename traits<MatrixType>::Scalar) / sizeof(Scalar)))
|
||||
: int(StrideType::OuterStrideAtCompileTime)
|
||||
};
|
||||
};
|
||||
}
|
||||
|
||||
template<typename ViewOp, typename MatrixType, typename StorageKind>
|
||||
template<typename ViewOp, typename MatrixType, typename StrideType, typename StorageKind>
|
||||
class CwiseUnaryViewImpl;
|
||||
|
||||
/** \class CwiseUnaryView
|
||||
@ -56,12 +61,12 @@ class CwiseUnaryViewImpl;
|
||||
*
|
||||
* \sa MatrixBase::unaryViewExpr(const CustomUnaryOp &) const, class CwiseUnaryOp
|
||||
*/
|
||||
template<typename ViewOp, typename MatrixType>
|
||||
class CwiseUnaryView : public CwiseUnaryViewImpl<ViewOp, MatrixType, typename internal::traits<MatrixType>::StorageKind>
|
||||
template<typename ViewOp, typename MatrixType, typename StrideType>
|
||||
class CwiseUnaryView : public CwiseUnaryViewImpl<ViewOp, MatrixType, StrideType, typename internal::traits<MatrixType>::StorageKind>
|
||||
{
|
||||
public:
|
||||
|
||||
typedef typename CwiseUnaryViewImpl<ViewOp, MatrixType,typename internal::traits<MatrixType>::StorageKind>::Base Base;
|
||||
typedef typename CwiseUnaryViewImpl<ViewOp, MatrixType, StrideType, typename internal::traits<MatrixType>::StorageKind>::Base Base;
|
||||
EIGEN_GENERIC_PUBLIC_INTERFACE(CwiseUnaryView)
|
||||
typedef typename internal::ref_selector<MatrixType>::non_const_type MatrixTypeNested;
|
||||
typedef typename internal::remove_all<MatrixType>::type NestedExpression;
|
||||
@ -93,22 +98,22 @@ class CwiseUnaryView : public CwiseUnaryViewImpl<ViewOp, MatrixType, typename in
|
||||
};
|
||||
|
||||
// Generic API dispatcher
|
||||
template<typename ViewOp, typename XprType, typename StorageKind>
|
||||
template<typename ViewOp, typename XprType, typename StrideType, typename StorageKind>
|
||||
class CwiseUnaryViewImpl
|
||||
: public internal::generic_xpr_base<CwiseUnaryView<ViewOp, XprType> >::type
|
||||
: public internal::generic_xpr_base<CwiseUnaryView<ViewOp, XprType, StrideType> >::type
|
||||
{
|
||||
public:
|
||||
typedef typename internal::generic_xpr_base<CwiseUnaryView<ViewOp, XprType> >::type Base;
|
||||
typedef typename internal::generic_xpr_base<CwiseUnaryView<ViewOp, XprType, StrideType> >::type Base;
|
||||
};
|
||||
|
||||
template<typename ViewOp, typename MatrixType>
|
||||
class CwiseUnaryViewImpl<ViewOp,MatrixType,Dense>
|
||||
: public internal::dense_xpr_base< CwiseUnaryView<ViewOp, MatrixType> >::type
|
||||
template<typename ViewOp, typename MatrixType, typename StrideType>
|
||||
class CwiseUnaryViewImpl<ViewOp,MatrixType,StrideType,Dense>
|
||||
: public internal::dense_xpr_base< CwiseUnaryView<ViewOp, MatrixType, StrideType> >::type
|
||||
{
|
||||
public:
|
||||
|
||||
typedef CwiseUnaryView<ViewOp, MatrixType> Derived;
|
||||
typedef typename internal::dense_xpr_base< CwiseUnaryView<ViewOp, MatrixType> >::type Base;
|
||||
typedef CwiseUnaryView<ViewOp, MatrixType,StrideType> Derived;
|
||||
typedef typename internal::dense_xpr_base< CwiseUnaryView<ViewOp, MatrixType,StrideType> >::type Base;
|
||||
|
||||
EIGEN_DENSE_PUBLIC_INTERFACE(Derived)
|
||||
EIGEN_INHERIT_ASSIGNMENT_OPERATORS(CwiseUnaryViewImpl)
|
||||
@ -118,12 +123,16 @@ class CwiseUnaryViewImpl<ViewOp,MatrixType,Dense>
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR inline Index innerStride() const
|
||||
{
|
||||
return derived().nestedExpression().innerStride() * sizeof(typename internal::traits<MatrixType>::Scalar) / sizeof(Scalar);
|
||||
return StrideType::InnerStrideAtCompileTime != 0
|
||||
? int(StrideType::InnerStrideAtCompileTime)
|
||||
: derived().nestedExpression().innerStride() * sizeof(typename internal::traits<MatrixType>::Scalar) / sizeof(Scalar);
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR inline Index outerStride() const
|
||||
{
|
||||
return derived().nestedExpression().outerStride() * sizeof(typename internal::traits<MatrixType>::Scalar) / sizeof(Scalar);
|
||||
return StrideType::OuterStrideAtCompileTime != 0
|
||||
? int(StrideType::OuterStrideAtCompileTime)
|
||||
: derived().nestedExpression().outerStride() * sizeof(typename internal::traits<MatrixType>::Scalar) / sizeof(Scalar);
|
||||
}
|
||||
protected:
|
||||
EIGEN_DEFAULT_EMPTY_CONSTRUCTOR_AND_DESTRUCTOR(CwiseUnaryViewImpl)
|
||||
|
@ -78,7 +78,6 @@ template<typename MatrixType> class Transpose;
|
||||
template<typename MatrixType> class Conjugate;
|
||||
template<typename NullaryOp, typename MatrixType> class CwiseNullaryOp;
|
||||
template<typename UnaryOp, typename MatrixType> class CwiseUnaryOp;
|
||||
template<typename ViewOp, typename MatrixType> class CwiseUnaryView;
|
||||
template<typename BinaryOp, typename Lhs, typename Rhs> class CwiseBinaryOp;
|
||||
template<typename TernaryOp, typename Arg1, typename Arg2, typename Arg3> class CwiseTernaryOp;
|
||||
template<typename Decomposition, typename Rhstype> class Solve;
|
||||
@ -108,6 +107,7 @@ template<typename MatrixType, int MapOptions=Unaligned, typename StrideType = St
|
||||
template<typename Derived> class RefBase;
|
||||
template<typename PlainObjectType, int Options = 0,
|
||||
typename StrideType = typename internal::conditional<PlainObjectType::IsVectorAtCompileTime,InnerStride<1>,OuterStride<> >::type > class Ref;
|
||||
template<typename ViewOp, typename MatrixType, typename StrideType = Stride<0,0>> class CwiseUnaryView;
|
||||
|
||||
template<typename Derived> class TriangularBase;
|
||||
template<typename MatrixType, unsigned int Mode> class TriangularView;
|
||||
|
@ -61,6 +61,9 @@ set(ei_smoke_test_list
|
||||
mapped_matrix_1
|
||||
mapstaticmethods_1
|
||||
mapstride_1
|
||||
unaryviewstride_1
|
||||
unaryviewstride_2
|
||||
unaryviewstride_3
|
||||
matrix_square_root_1
|
||||
meta
|
||||
minres_2
|
||||
|
@ -194,6 +194,7 @@ ei_add_test(commainitializer)
|
||||
ei_add_test(smallvectors)
|
||||
ei_add_test(mapped_matrix)
|
||||
ei_add_test(mapstride)
|
||||
ei_add_test(unaryviewstride)
|
||||
ei_add_test(mapstaticmethods)
|
||||
ei_add_test(array_cwise)
|
||||
ei_add_test(array_for_matrix)
|
||||
|
39
test/unaryviewstride.cpp
Normal file
39
test/unaryviewstride.cpp
Normal file
@ -0,0 +1,39 @@
|
||||
// This file is part of Eigen, a lightweight C++ template library
|
||||
// for linear algebra.
|
||||
//
|
||||
// Copyright (C) 2021 Andrew Johnson <andrew.johnson@arjohnsonau.com>
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla
|
||||
// Public License v. 2.0. If a copy of the MPL was not distributed
|
||||
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
#include "main.h"
|
||||
|
||||
template<int OuterStride,int InnerStride,typename VectorType> void unaryview_stride(const VectorType& m)
|
||||
{
|
||||
typedef typename VectorType::Scalar Scalar;
|
||||
Index rows = m.rows();
|
||||
Index cols = m.cols();
|
||||
VectorType vec = VectorType::Random(rows, cols);
|
||||
|
||||
struct view_op {
|
||||
EIGEN_EMPTY_STRUCT_CTOR(view_op)
|
||||
EIGEN_DEVICE_FUNC
|
||||
EIGEN_STRONG_INLINE const Scalar&
|
||||
operator()(const Scalar& v) const { return v; }
|
||||
};
|
||||
|
||||
CwiseUnaryView<view_op, VectorType, Stride<OuterStride,InnerStride>> vec_view(vec);
|
||||
VERIFY(vec_view.outerStride() == (OuterStride == 0 ? 0 : OuterStride));
|
||||
VERIFY(vec_view.innerStride() == (InnerStride == 0 ? 1 : InnerStride));
|
||||
}
|
||||
|
||||
EIGEN_DECLARE_TEST(unaryviewstride)
|
||||
{
|
||||
CALL_SUBTEST_1(( unaryview_stride<1,2>(MatrixXf()) ));
|
||||
CALL_SUBTEST_1(( unaryview_stride<0,0>(MatrixXf()) ));
|
||||
CALL_SUBTEST_2(( unaryview_stride<1,2>(VectorXf()) ));
|
||||
CALL_SUBTEST_2(( unaryview_stride<0,0>(VectorXf()) ));
|
||||
CALL_SUBTEST_3(( unaryview_stride<1,2>(RowVectorXf()) ));
|
||||
CALL_SUBTEST_3(( unaryview_stride<0,0>(RowVectorXf()) ));
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user