Add support for automatic-size deduction in reshaped, e.g.:

mat.reshaped(4,AutoSize); <-> mat.reshaped(4,mat.size()/4);
This commit is contained in:
Gael Guennebaud 2017-02-21 15:57:25 +01:00
parent f8179385bd
commit 3d200257d7
4 changed files with 165 additions and 40 deletions

View File

@ -431,6 +431,7 @@ using std::ptrdiff_t;
#include "src/Core/arch/CUDA/Complex.h" #include "src/Core/arch/CUDA/Complex.h"
#include "src/Core/util/IndexedViewHelper.h" #include "src/Core/util/IndexedViewHelper.h"
#include "src/Core/util/ReshapedHelper.h"
#include "src/Core/ArithmeticSequence.h" #include "src/Core/ArithmeticSequence.h"
#include "src/Core/DenseCoeffsBase.h" #include "src/Core/DenseCoeffsBase.h"
#include "src/Core/DenseBase.h" #include "src/Core/DenseBase.h"

View File

@ -0,0 +1,45 @@
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2017 Gael Guennebaud <gael.guennebaud@inria.fr>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#ifndef EIGEN_RESHAPED_HELPER_H
#define EIGEN_RESHAPED_HELPER_H
namespace Eigen {
enum AutoSize_t { AutoSize };
namespace internal {
template<typename SizeType,typename OtherSize, int TotalSize>
struct get_compiletime_reshape_size {
enum { value = get_fixed_value<SizeType>::value };
};
template<typename SizeType>
Index get_runtime_reshape_size(SizeType size, Index /*other*/, Index /*total*/) {
return internal::get_runtime_value(size);
}
template<typename OtherSize, int TotalSize>
struct get_compiletime_reshape_size<AutoSize_t,OtherSize,TotalSize> {
enum {
other_size = get_fixed_value<OtherSize>::value,
value = (TotalSize==Dynamic || other_size==Dynamic) ? Dynamic : TotalSize / other_size };
};
Index get_runtime_reshape_size(AutoSize_t /*size*/, Index other, Index total) {
return total/other;
}
}
} // end namespace Eigen
#endif // EIGEN_RESHAPED_HELPER_H

View File

@ -28,62 +28,70 @@ template<typename NRowsType, typename NColsType, typename OrderType>
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
inline const Reshaped<const Derived,...> inline const Reshaped<const Derived,...>
reshaped(NRowsType nRows, NColsType nCols, OrderType = ColOrder) const; reshaped(NRowsType nRows, NColsType nCols, OrderType = ColOrder) const;
#else #else
// This file is automatically included twice to generate const and non-const versions
#ifndef EIGEN_RESHAPED_METHOD_2ND_PASS
#define EIGEN_RESHAPED_METHOD_CONST const
#else
#define EIGEN_RESHAPED_METHOD_CONST
#endif
#ifndef EIGEN_RESHAPED_METHOD_2ND_PASS
// This part is included once
#endif
template<typename NRowsType, typename NColsType> template<typename NRowsType, typename NColsType>
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
inline Reshaped<Derived,internal::get_fixed_value<NRowsType>::value,internal::get_fixed_value<NColsType>::value> inline Reshaped<EIGEN_RESHAPED_METHOD_CONST Derived,
reshaped(NRowsType nRows, NColsType nCols) internal::get_compiletime_reshape_size<NRowsType,NColsType,SizeAtCompileTime>::value,
internal::get_compiletime_reshape_size<NColsType,NRowsType,SizeAtCompileTime>::value>
reshaped(NRowsType nRows, NColsType nCols) EIGEN_RESHAPED_METHOD_CONST
{ {
return Reshaped<Derived,internal::get_fixed_value<NRowsType>::value,internal::get_fixed_value<NColsType>::value>( return Reshaped<EIGEN_RESHAPED_METHOD_CONST Derived,
derived(), internal::get_runtime_value(nRows), internal::get_runtime_value(nCols)); internal::get_compiletime_reshape_size<NRowsType,NColsType,SizeAtCompileTime>::value,
internal::get_compiletime_reshape_size<NColsType,NRowsType,SizeAtCompileTime>::value>
(derived(),
internal::get_runtime_reshape_size(nRows,internal::get_runtime_value(nCols),size()),
internal::get_runtime_reshape_size(nCols,internal::get_runtime_value(nRows),size()));
} }
template<typename NRowsType, typename NColsType, typename OrderType> template<typename NRowsType, typename NColsType, typename OrderType>
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
inline Reshaped<Derived,internal::get_fixed_value<NRowsType>::value,internal::get_fixed_value<NColsType>::value, inline Reshaped<EIGEN_RESHAPED_METHOD_CONST Derived,
internal::get_compiletime_reshape_size<NRowsType,NColsType,SizeAtCompileTime>::value,
internal::get_compiletime_reshape_size<NColsType,NRowsType,SizeAtCompileTime>::value,
OrderType::value==AutoOrderValue?Flags&RowMajorBit:OrderType::value> OrderType::value==AutoOrderValue?Flags&RowMajorBit:OrderType::value>
reshaped(NRowsType nRows, NColsType nCols, OrderType) reshaped(NRowsType nRows, NColsType nCols, OrderType) EIGEN_RESHAPED_METHOD_CONST
{ {
return Reshaped<Derived,internal::get_fixed_value<NRowsType>::value,internal::get_fixed_value<NColsType>::value, return Reshaped<EIGEN_RESHAPED_METHOD_CONST Derived,
OrderType::value==AutoOrderValue?Flags&RowMajorBit:OrderType::value>( internal::get_compiletime_reshape_size<NRowsType,NColsType,SizeAtCompileTime>::value,
derived(), internal::get_runtime_value(nRows), internal::get_runtime_value(nCols)); internal::get_compiletime_reshape_size<NColsType,NRowsType,SizeAtCompileTime>::value,
}
template<typename NRowsType, typename NColsType>
EIGEN_DEVICE_FUNC
inline const Reshaped<const Derived,internal::get_fixed_value<NRowsType>::value,internal::get_fixed_value<NColsType>::value>
reshaped(NRowsType nRows, NColsType nCols) const
{
return Reshaped<const Derived,internal::get_fixed_value<NRowsType>::value,internal::get_fixed_value<NColsType>::value>(
derived(), internal::get_runtime_value(nRows), internal::get_runtime_value(nCols));
}
template<typename NRowsType, typename NColsType, typename OrderType>
EIGEN_DEVICE_FUNC
inline const Reshaped<const Derived,internal::get_fixed_value<NRowsType>::value,internal::get_fixed_value<NColsType>::value,
OrderType::value==AutoOrderValue?Flags&RowMajorBit:OrderType::value> OrderType::value==AutoOrderValue?Flags&RowMajorBit:OrderType::value>
reshaped(NRowsType nRows, NColsType nCols, OrderType) const (derived(),
{ internal::get_runtime_reshape_size(nRows,internal::get_runtime_value(nCols),size()),
return Reshaped<const Derived,internal::get_fixed_value<NRowsType>::value,internal::get_fixed_value<NColsType>::value, internal::get_runtime_reshape_size(nCols,internal::get_runtime_value(nRows),size()));
OrderType::value==AutoOrderValue?Flags&RowMajorBit:OrderType::value>(
derived(), internal::get_runtime_value(nRows), internal::get_runtime_value(nCols));
} }
// Views as linear vectors // Views as linear vectors
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
inline const Reshaped<Derived,SizeAtCompileTime,1> inline Reshaped<EIGEN_RESHAPED_METHOD_CONST Derived,SizeAtCompileTime,1>
operator()(const Eigen::internal::all_t&) operator()(const Eigen::internal::all_t&) EIGEN_RESHAPED_METHOD_CONST
{ {
return Reshaped<Derived,SizeAtCompileTime,1>(derived(),size(),1); return Reshaped<EIGEN_RESHAPED_METHOD_CONST Derived,SizeAtCompileTime,1>(derived(),size(),1);
} }
EIGEN_DEVICE_FUNC #undef EIGEN_RESHAPED_METHOD_CONST
inline const Reshaped<const Derived,SizeAtCompileTime,1>
operator()(const Eigen::internal::all_t&) const #ifndef EIGEN_RESHAPED_METHOD_2ND_PASS
{ #define EIGEN_RESHAPED_METHOD_2ND_PASS
return Reshaped<const Derived,SizeAtCompileTime,1>(derived(),size(),1); #include "ReshapedMethods.h"
} #undef EIGEN_RESHAPED_METHOD_2ND_PASS
#endif
#endif // EIGEN_PARSED_BY_DOXYGEN #endif // EIGEN_PARSED_BY_DOXYGEN

View File

@ -17,10 +17,48 @@ is_same_eq(const T1& a, const T2& b)
return (a.array() == b.array()).all(); return (a.array() == b.array()).all();
} }
template <typename MatType,typename OrderType>
void check_auto_reshape4x4(MatType m,OrderType order)
{
internal::VariableAndFixedInt<MatType::SizeAtCompileTime==Dynamic?-1: 1> v1( 1);
internal::VariableAndFixedInt<MatType::SizeAtCompileTime==Dynamic?-1: 2> v2( 2);
internal::VariableAndFixedInt<MatType::SizeAtCompileTime==Dynamic?-1: 4> v4( 4);
internal::VariableAndFixedInt<MatType::SizeAtCompileTime==Dynamic?-1: 8> v8( 8);
internal::VariableAndFixedInt<MatType::SizeAtCompileTime==Dynamic?-1:16> v16(16);
VERIFY(is_same_eq(m.reshaped( 1, AutoSize, order), m.reshaped( 1, 16, order)));
VERIFY(is_same_eq(m.reshaped(AutoSize, 16, order), m.reshaped( 1, 16, order)));
VERIFY(is_same_eq(m.reshaped( 2, AutoSize, order), m.reshaped( 2, 8, order)));
VERIFY(is_same_eq(m.reshaped(AutoSize, 8, order), m.reshaped( 2, 8, order)));
VERIFY(is_same_eq(m.reshaped( 4, AutoSize, order), m.reshaped( 4, 4, order)));
VERIFY(is_same_eq(m.reshaped(AutoSize, 4, order), m.reshaped( 4, 4, order)));
VERIFY(is_same_eq(m.reshaped( 8, AutoSize, order), m.reshaped( 8, 2, order)));
VERIFY(is_same_eq(m.reshaped(AutoSize, 2, order), m.reshaped( 8, 2, order)));
VERIFY(is_same_eq(m.reshaped(16, AutoSize, order), m.reshaped(16, 1, order)));
VERIFY(is_same_eq(m.reshaped(AutoSize, 1, order), m.reshaped(16, 1, order)));
VERIFY(is_same_eq(m.reshaped(fix< 1>, AutoSize, order), m.reshaped(fix< 1>, v16, order)));
VERIFY(is_same_eq(m.reshaped(AutoSize, fix<16>, order), m.reshaped( v1, fix<16>, order)));
VERIFY(is_same_eq(m.reshaped(fix< 2>, AutoSize, order), m.reshaped(fix< 2>, v8, order)));
VERIFY(is_same_eq(m.reshaped(AutoSize, fix< 8>, order), m.reshaped( v2, fix< 8>, order)));
VERIFY(is_same_eq(m.reshaped(fix< 4>, AutoSize, order), m.reshaped(fix< 4>, v4, order)));
VERIFY(is_same_eq(m.reshaped(AutoSize, fix< 4>, order), m.reshaped( v4, fix< 4>, order)));
VERIFY(is_same_eq(m.reshaped(fix< 8>, AutoSize, order), m.reshaped(fix< 8>, v2, order)));
VERIFY(is_same_eq(m.reshaped(AutoSize, fix< 2>, order), m.reshaped( v8, fix< 2>, order)));
VERIFY(is_same_eq(m.reshaped(fix<16>, AutoSize, order), m.reshaped(fix<16>, v1, order)));
VERIFY(is_same_eq(m.reshaped(AutoSize, fix< 1>, order), m.reshaped(v16, fix< 1>, order)));
}
// just test a 4x4 matrix, enumerate all combination manually // just test a 4x4 matrix, enumerate all combination manually
template <typename MatType> template <typename MatType>
void reshape4x4(MatType m) void reshape4x4(MatType m)
{ {
internal::VariableAndFixedInt<MatType::SizeAtCompileTime==Dynamic?-1: 1> v1( 1);
internal::VariableAndFixedInt<MatType::SizeAtCompileTime==Dynamic?-1: 2> v2( 2);
internal::VariableAndFixedInt<MatType::SizeAtCompileTime==Dynamic?-1: 4> v4( 4);
internal::VariableAndFixedInt<MatType::SizeAtCompileTime==Dynamic?-1: 8> v8( 8);
internal::VariableAndFixedInt<MatType::SizeAtCompileTime==Dynamic?-1:16> v16(16);
if((MatType::Flags&RowMajorBit)==0) if((MatType::Flags&RowMajorBit)==0)
{ {
typedef Map<MatrixXi> MapMat; typedef Map<MatrixXi> MapMat;
@ -38,6 +76,7 @@ void reshape4x4(MatType m)
VERIFY_IS_EQUAL(m.reshaped(fix< 8>, fix< 2>), MapMat(m.data(), 8, 2)); VERIFY_IS_EQUAL(m.reshaped(fix< 8>, fix< 2>), MapMat(m.data(), 8, 2));
VERIFY_IS_EQUAL(m.reshaped(fix<16>, fix< 1>), MapMat(m.data(), 16, 1)); VERIFY_IS_EQUAL(m.reshaped(fix<16>, fix< 1>), MapMat(m.data(), 16, 1));
// reshape chain // reshape chain
VERIFY_IS_EQUAL( VERIFY_IS_EQUAL(
(m (m
@ -56,6 +95,35 @@ void reshape4x4(MatType m)
); );
} }
VERIFY(is_same_eq(m.reshaped( 1, AutoSize), m.reshaped( 1, 16)));
VERIFY(is_same_eq(m.reshaped(AutoSize, 16), m.reshaped( 1, 16)));
VERIFY(is_same_eq(m.reshaped( 2, AutoSize), m.reshaped( 2, 8)));
VERIFY(is_same_eq(m.reshaped(AutoSize, 8), m.reshaped( 2, 8)));
VERIFY(is_same_eq(m.reshaped( 4, AutoSize), m.reshaped( 4, 4)));
VERIFY(is_same_eq(m.reshaped(AutoSize, 4), m.reshaped( 4, 4)));
VERIFY(is_same_eq(m.reshaped( 8, AutoSize), m.reshaped( 8, 2)));
VERIFY(is_same_eq(m.reshaped(AutoSize, 2), m.reshaped( 8, 2)));
VERIFY(is_same_eq(m.reshaped(16, AutoSize), m.reshaped(16, 1)));
VERIFY(is_same_eq(m.reshaped(AutoSize, 1), m.reshaped(16, 1)));
VERIFY(is_same_eq(m.reshaped(fix< 1>, AutoSize), m.reshaped(fix< 1>, v16)));
VERIFY(is_same_eq(m.reshaped(AutoSize, fix<16>), m.reshaped( v1, fix<16>)));
VERIFY(is_same_eq(m.reshaped(fix< 2>, AutoSize), m.reshaped(fix< 2>, v8)));
VERIFY(is_same_eq(m.reshaped(AutoSize, fix< 8>), m.reshaped( v2, fix< 8>)));
VERIFY(is_same_eq(m.reshaped(fix< 4>, AutoSize), m.reshaped(fix< 4>, v4)));
VERIFY(is_same_eq(m.reshaped(AutoSize, fix< 4>), m.reshaped( v4, fix< 4>)));
VERIFY(is_same_eq(m.reshaped(fix< 8>, AutoSize), m.reshaped(fix< 8>, v2)));
VERIFY(is_same_eq(m.reshaped(AutoSize, fix< 2>), m.reshaped( v8, fix< 2>)));
VERIFY(is_same_eq(m.reshaped(fix<16>, AutoSize), m.reshaped(fix<16>, v1)));
VERIFY(is_same_eq(m.reshaped(AutoSize, fix< 1>), m.reshaped(v16, fix< 1>)));
check_auto_reshape4x4(m,ColOrder);
check_auto_reshape4x4(m,RowOrder);
check_auto_reshape4x4(m,AutoOrder);
check_auto_reshape4x4(m.transpose(),ColOrder);
check_auto_reshape4x4(m.transpose(),RowOrder);
check_auto_reshape4x4(m.transpose(),AutoOrder);
VERIFY_IS_EQUAL(m.reshaped( 1, 16).data(), m.data()); VERIFY_IS_EQUAL(m.reshaped( 1, 16).data(), m.data());
VERIFY_IS_EQUAL(m.reshaped( 1, 16).innerStride(), 1); VERIFY_IS_EQUAL(m.reshaped( 1, 16).innerStride(), 1);
@ -82,12 +150,15 @@ void reshape4x4(MatType m)
VERIFY_IS_EQUAL( m28r1, m28r2); VERIFY_IS_EQUAL( m28r1, m28r2);
using placeholders::all; using placeholders::all;
VERIFY(is_same_eq(m.reshaped(fix<MatType::SizeAtCompileTime>(m.size()),fix<1>), m(all))); VERIFY(is_same_eq(m.reshaped(v16,fix<1>), m(all)));
VERIFY_IS_EQUAL(m.reshaped(16,1), m(all)); VERIFY_IS_EQUAL(m.reshaped(16,1), m(all));
VERIFY_IS_EQUAL(m.reshaped(1,16), m(all).transpose()); VERIFY_IS_EQUAL(m.reshaped(1,16), m(all).transpose());
VERIFY_IS_EQUAL(m(all).reshaped(2,8), m.reshaped(2,8)); VERIFY_IS_EQUAL(m(all).reshaped(2,8), m.reshaped(2,8));
VERIFY_IS_EQUAL(m(all).reshaped(4,4), m.reshaped(4,4)); VERIFY_IS_EQUAL(m(all).reshaped(4,4), m.reshaped(4,4));
VERIFY_IS_EQUAL(m(all).reshaped(8,2), m.reshaped(8,2)); VERIFY_IS_EQUAL(m(all).reshaped(8,2), m.reshaped(8,2));
VERIFY(is_same_eq(m.reshaped(AutoSize,fix<1>), m(all)));
VERIFY_IS_EQUAL(m.reshaped(fix<1>,AutoSize,RowOrder), m.transpose()(all).transpose());
} }
void test_reshape() void test_reshape()