From 3d200257d73d99a1f37b1cb23ce52b80264ba0d9 Mon Sep 17 00:00:00 2001 From: Gael Guennebaud Date: Tue, 21 Feb 2017 15:57:25 +0100 Subject: [PATCH] Add support for automatic-size deduction in reshaped, e.g.: mat.reshaped(4,AutoSize); <-> mat.reshaped(4,mat.size()/4); --- Eigen/Core | 1 + Eigen/src/Core/util/ReshapedHelper.h | 45 +++++++++++++++ Eigen/src/plugins/ReshapedMethods.h | 86 +++++++++++++++------------- test/reshape.cpp | 73 ++++++++++++++++++++++- 4 files changed, 165 insertions(+), 40 deletions(-) create mode 100644 Eigen/src/Core/util/ReshapedHelper.h diff --git a/Eigen/Core b/Eigen/Core index 1174d7d16..1af688637 100644 --- a/Eigen/Core +++ b/Eigen/Core @@ -431,6 +431,7 @@ using std::ptrdiff_t; #include "src/Core/arch/CUDA/Complex.h" #include "src/Core/util/IndexedViewHelper.h" +#include "src/Core/util/ReshapedHelper.h" #include "src/Core/ArithmeticSequence.h" #include "src/Core/DenseCoeffsBase.h" #include "src/Core/DenseBase.h" diff --git a/Eigen/src/Core/util/ReshapedHelper.h b/Eigen/src/Core/util/ReshapedHelper.h new file mode 100644 index 000000000..b5d59cfe8 --- /dev/null +++ b/Eigen/src/Core/util/ReshapedHelper.h @@ -0,0 +1,45 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2017 Gael Guennebaud +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + + +#ifndef EIGEN_RESHAPED_HELPER_H +#define EIGEN_RESHAPED_HELPER_H + +namespace Eigen { + +enum AutoSize_t { AutoSize }; + +namespace internal { + +template +struct get_compiletime_reshape_size { + enum { value = get_fixed_value::value }; +}; + +template +Index get_runtime_reshape_size(SizeType size, Index /*other*/, Index /*total*/) { + return internal::get_runtime_value(size); +} + +template +struct get_compiletime_reshape_size { + enum { + other_size = get_fixed_value::value, + value = (TotalSize==Dynamic || other_size==Dynamic) ? Dynamic : TotalSize / other_size }; +}; + +Index get_runtime_reshape_size(AutoSize_t /*size*/, Index other, Index total) { + return total/other; +} + +} + +} // end namespace Eigen + +#endif // EIGEN_RESHAPED_HELPER_H diff --git a/Eigen/src/plugins/ReshapedMethods.h b/Eigen/src/plugins/ReshapedMethods.h index fc7cdcfa7..118841798 100644 --- a/Eigen/src/plugins/ReshapedMethods.h +++ b/Eigen/src/plugins/ReshapedMethods.h @@ -28,62 +28,70 @@ template EIGEN_DEVICE_FUNC inline const Reshaped reshaped(NRowsType nRows, NColsType nCols, OrderType = ColOrder) const; + #else + +// This file is automatically included twice to generate const and non-const versions + +#ifndef EIGEN_RESHAPED_METHOD_2ND_PASS +#define EIGEN_RESHAPED_METHOD_CONST const +#else +#define EIGEN_RESHAPED_METHOD_CONST +#endif + +#ifndef EIGEN_RESHAPED_METHOD_2ND_PASS + +// This part is included once + +#endif + template EIGEN_DEVICE_FUNC -inline Reshaped::value,internal::get_fixed_value::value> -reshaped(NRowsType nRows, NColsType nCols) +inline Reshaped::value, + internal::get_compiletime_reshape_size::value> +reshaped(NRowsType nRows, NColsType nCols) EIGEN_RESHAPED_METHOD_CONST { - return Reshaped::value,internal::get_fixed_value::value>( - derived(), internal::get_runtime_value(nRows), internal::get_runtime_value(nCols)); + return Reshaped::value, + internal::get_compiletime_reshape_size::value> + (derived(), + internal::get_runtime_reshape_size(nRows,internal::get_runtime_value(nCols),size()), + internal::get_runtime_reshape_size(nCols,internal::get_runtime_value(nRows),size())); } template EIGEN_DEVICE_FUNC -inline Reshaped::value,internal::get_fixed_value::value, +inline Reshaped::value, + internal::get_compiletime_reshape_size::value, OrderType::value==AutoOrderValue?Flags&RowMajorBit:OrderType::value> -reshaped(NRowsType nRows, NColsType nCols, OrderType) +reshaped(NRowsType nRows, NColsType nCols, OrderType) EIGEN_RESHAPED_METHOD_CONST { - return Reshaped::value,internal::get_fixed_value::value, - OrderType::value==AutoOrderValue?Flags&RowMajorBit:OrderType::value>( - derived(), internal::get_runtime_value(nRows), internal::get_runtime_value(nCols)); -} - - -template -EIGEN_DEVICE_FUNC -inline const Reshaped::value,internal::get_fixed_value::value> -reshaped(NRowsType nRows, NColsType nCols) const -{ - return Reshaped::value,internal::get_fixed_value::value>( - derived(), internal::get_runtime_value(nRows), internal::get_runtime_value(nCols)); -} - -template -EIGEN_DEVICE_FUNC -inline const Reshaped::value,internal::get_fixed_value::value, - OrderType::value==AutoOrderValue?Flags&RowMajorBit:OrderType::value> -reshaped(NRowsType nRows, NColsType nCols, OrderType) const -{ - return Reshaped::value,internal::get_fixed_value::value, - OrderType::value==AutoOrderValue?Flags&RowMajorBit:OrderType::value>( - derived(), internal::get_runtime_value(nRows), internal::get_runtime_value(nCols)); + return Reshaped::value, + internal::get_compiletime_reshape_size::value, + OrderType::value==AutoOrderValue?Flags&RowMajorBit:OrderType::value> + (derived(), + internal::get_runtime_reshape_size(nRows,internal::get_runtime_value(nCols),size()), + internal::get_runtime_reshape_size(nCols,internal::get_runtime_value(nRows),size())); } // Views as linear vectors EIGEN_DEVICE_FUNC -inline const Reshaped -operator()(const Eigen::internal::all_t&) +inline Reshaped +operator()(const Eigen::internal::all_t&) EIGEN_RESHAPED_METHOD_CONST { - return Reshaped(derived(),size(),1); + return Reshaped(derived(),size(),1); } -EIGEN_DEVICE_FUNC -inline const Reshaped -operator()(const Eigen::internal::all_t&) const -{ - return Reshaped(derived(),size(),1); -} +#undef EIGEN_RESHAPED_METHOD_CONST + +#ifndef EIGEN_RESHAPED_METHOD_2ND_PASS +#define EIGEN_RESHAPED_METHOD_2ND_PASS +#include "ReshapedMethods.h" +#undef EIGEN_RESHAPED_METHOD_2ND_PASS +#endif #endif // EIGEN_PARSED_BY_DOXYGEN diff --git a/test/reshape.cpp b/test/reshape.cpp index a38f5e098..8fe0b9348 100644 --- a/test/reshape.cpp +++ b/test/reshape.cpp @@ -17,10 +17,48 @@ is_same_eq(const T1& a, const T2& b) return (a.array() == b.array()).all(); } +template +void check_auto_reshape4x4(MatType m,OrderType order) +{ + internal::VariableAndFixedInt v1( 1); + internal::VariableAndFixedInt v2( 2); + internal::VariableAndFixedInt v4( 4); + internal::VariableAndFixedInt v8( 8); + internal::VariableAndFixedInt v16(16); + + VERIFY(is_same_eq(m.reshaped( 1, AutoSize, order), m.reshaped( 1, 16, order))); + VERIFY(is_same_eq(m.reshaped(AutoSize, 16, order), m.reshaped( 1, 16, order))); + VERIFY(is_same_eq(m.reshaped( 2, AutoSize, order), m.reshaped( 2, 8, order))); + VERIFY(is_same_eq(m.reshaped(AutoSize, 8, order), m.reshaped( 2, 8, order))); + VERIFY(is_same_eq(m.reshaped( 4, AutoSize, order), m.reshaped( 4, 4, order))); + VERIFY(is_same_eq(m.reshaped(AutoSize, 4, order), m.reshaped( 4, 4, order))); + VERIFY(is_same_eq(m.reshaped( 8, AutoSize, order), m.reshaped( 8, 2, order))); + VERIFY(is_same_eq(m.reshaped(AutoSize, 2, order), m.reshaped( 8, 2, order))); + VERIFY(is_same_eq(m.reshaped(16, AutoSize, order), m.reshaped(16, 1, order))); + VERIFY(is_same_eq(m.reshaped(AutoSize, 1, order), m.reshaped(16, 1, order))); + + VERIFY(is_same_eq(m.reshaped(fix< 1>, AutoSize, order), m.reshaped(fix< 1>, v16, order))); + VERIFY(is_same_eq(m.reshaped(AutoSize, fix<16>, order), m.reshaped( v1, fix<16>, order))); + VERIFY(is_same_eq(m.reshaped(fix< 2>, AutoSize, order), m.reshaped(fix< 2>, v8, order))); + VERIFY(is_same_eq(m.reshaped(AutoSize, fix< 8>, order), m.reshaped( v2, fix< 8>, order))); + VERIFY(is_same_eq(m.reshaped(fix< 4>, AutoSize, order), m.reshaped(fix< 4>, v4, order))); + VERIFY(is_same_eq(m.reshaped(AutoSize, fix< 4>, order), m.reshaped( v4, fix< 4>, order))); + VERIFY(is_same_eq(m.reshaped(fix< 8>, AutoSize, order), m.reshaped(fix< 8>, v2, order))); + VERIFY(is_same_eq(m.reshaped(AutoSize, fix< 2>, order), m.reshaped( v8, fix< 2>, order))); + VERIFY(is_same_eq(m.reshaped(fix<16>, AutoSize, order), m.reshaped(fix<16>, v1, order))); + VERIFY(is_same_eq(m.reshaped(AutoSize, fix< 1>, order), m.reshaped(v16, fix< 1>, order))); +} + // just test a 4x4 matrix, enumerate all combination manually template void reshape4x4(MatType m) { + internal::VariableAndFixedInt v1( 1); + internal::VariableAndFixedInt v2( 2); + internal::VariableAndFixedInt v4( 4); + internal::VariableAndFixedInt v8( 8); + internal::VariableAndFixedInt v16(16); + if((MatType::Flags&RowMajorBit)==0) { typedef Map MapMat; @@ -38,6 +76,7 @@ void reshape4x4(MatType m) VERIFY_IS_EQUAL(m.reshaped(fix< 8>, fix< 2>), MapMat(m.data(), 8, 2)); VERIFY_IS_EQUAL(m.reshaped(fix<16>, fix< 1>), MapMat(m.data(), 16, 1)); + // reshape chain VERIFY_IS_EQUAL( (m @@ -56,6 +95,35 @@ void reshape4x4(MatType m) ); } + VERIFY(is_same_eq(m.reshaped( 1, AutoSize), m.reshaped( 1, 16))); + VERIFY(is_same_eq(m.reshaped(AutoSize, 16), m.reshaped( 1, 16))); + VERIFY(is_same_eq(m.reshaped( 2, AutoSize), m.reshaped( 2, 8))); + VERIFY(is_same_eq(m.reshaped(AutoSize, 8), m.reshaped( 2, 8))); + VERIFY(is_same_eq(m.reshaped( 4, AutoSize), m.reshaped( 4, 4))); + VERIFY(is_same_eq(m.reshaped(AutoSize, 4), m.reshaped( 4, 4))); + VERIFY(is_same_eq(m.reshaped( 8, AutoSize), m.reshaped( 8, 2))); + VERIFY(is_same_eq(m.reshaped(AutoSize, 2), m.reshaped( 8, 2))); + VERIFY(is_same_eq(m.reshaped(16, AutoSize), m.reshaped(16, 1))); + VERIFY(is_same_eq(m.reshaped(AutoSize, 1), m.reshaped(16, 1))); + + VERIFY(is_same_eq(m.reshaped(fix< 1>, AutoSize), m.reshaped(fix< 1>, v16))); + VERIFY(is_same_eq(m.reshaped(AutoSize, fix<16>), m.reshaped( v1, fix<16>))); + VERIFY(is_same_eq(m.reshaped(fix< 2>, AutoSize), m.reshaped(fix< 2>, v8))); + VERIFY(is_same_eq(m.reshaped(AutoSize, fix< 8>), m.reshaped( v2, fix< 8>))); + VERIFY(is_same_eq(m.reshaped(fix< 4>, AutoSize), m.reshaped(fix< 4>, v4))); + VERIFY(is_same_eq(m.reshaped(AutoSize, fix< 4>), m.reshaped( v4, fix< 4>))); + VERIFY(is_same_eq(m.reshaped(fix< 8>, AutoSize), m.reshaped(fix< 8>, v2))); + VERIFY(is_same_eq(m.reshaped(AutoSize, fix< 2>), m.reshaped( v8, fix< 2>))); + VERIFY(is_same_eq(m.reshaped(fix<16>, AutoSize), m.reshaped(fix<16>, v1))); + VERIFY(is_same_eq(m.reshaped(AutoSize, fix< 1>), m.reshaped(v16, fix< 1>))); + + check_auto_reshape4x4(m,ColOrder); + check_auto_reshape4x4(m,RowOrder); + check_auto_reshape4x4(m,AutoOrder); + check_auto_reshape4x4(m.transpose(),ColOrder); + check_auto_reshape4x4(m.transpose(),RowOrder); + check_auto_reshape4x4(m.transpose(),AutoOrder); + VERIFY_IS_EQUAL(m.reshaped( 1, 16).data(), m.data()); VERIFY_IS_EQUAL(m.reshaped( 1, 16).innerStride(), 1); @@ -82,12 +150,15 @@ void reshape4x4(MatType m) VERIFY_IS_EQUAL( m28r1, m28r2); using placeholders::all; - VERIFY(is_same_eq(m.reshaped(fix(m.size()),fix<1>), m(all))); + VERIFY(is_same_eq(m.reshaped(v16,fix<1>), m(all))); VERIFY_IS_EQUAL(m.reshaped(16,1), m(all)); VERIFY_IS_EQUAL(m.reshaped(1,16), m(all).transpose()); VERIFY_IS_EQUAL(m(all).reshaped(2,8), m.reshaped(2,8)); VERIFY_IS_EQUAL(m(all).reshaped(4,4), m.reshaped(4,4)); VERIFY_IS_EQUAL(m(all).reshaped(8,2), m.reshaped(8,2)); + + VERIFY(is_same_eq(m.reshaped(AutoSize,fix<1>), m(all))); + VERIFY_IS_EQUAL(m.reshaped(fix<1>,AutoSize,RowOrder), m.transpose()(all).transpose()); } void test_reshape()