diff --git a/Eigen/src/Core/Reshaped.h b/Eigen/src/Core/Reshaped.h index 42ce2dbae..56fd3519a 100644 --- a/Eigen/src/Core/Reshaped.h +++ b/Eigen/src/Core/Reshaped.h @@ -21,7 +21,7 @@ namespace Eigen { * \tparam XprType the type of the expression in which we are taking a reshape * \tparam Rows the number of rows of the reshape we are taking at compile time (optional) * \tparam Cols the number of columns of the reshape we are taking at compile time (optional) - * \tparam Order + * \tparam Order can be ColMajor or RowMajor, default is ColMajor. * * This class represents an expression of either a fixed-size or dynamic-size reshape. * It is the return type of DenseBase::reshaped(NRowsType,NColsType) and @@ -68,9 +68,8 @@ struct traits > : traits : Dynamic, OuterStrideAtCompileTime = Dynamic, - InOrder = Order, HasDirectAccess = internal::has_direct_access::ret - && (Order==int(AutoOrderValue) || Order==int(XpxStorageOrder)) + && (Order==int(XpxStorageOrder)) && ((evaluator::Flags&LinearAccessBit)==LinearAccessBit), MaskPacketAccessBit = (InnerSize == Dynamic || (InnerSize % packet_traits::size) == 0) @@ -324,11 +323,20 @@ struct reshaped_evaluator RowCol; - inline RowCol index_remap(Index rowId, Index colId) const { - const Index nth_elem_idx = colId * m_xpr.rows() + rowId; - const Index actual_col = nth_elem_idx / m_xpr.nestedExpression().rows(); - const Index actual_row = nth_elem_idx % m_xpr.nestedExpression().rows(); - return RowCol(actual_row, actual_col); + inline RowCol index_remap(Index rowId, Index colId) const + { + if(Order==ColMajor) + { + const Index nth_elem_idx = colId * m_xpr.rows() + rowId; + return RowCol(nth_elem_idx % m_xpr.nestedExpression().rows(), + nth_elem_idx / m_xpr.nestedExpression().rows()); + } + else + { + const Index nth_elem_idx = colId + rowId * m_xpr.cols(); + return RowCol(nth_elem_idx / m_xpr.nestedExpression().cols(), + nth_elem_idx % m_xpr.nestedExpression().cols()); + } } EIGEN_DEVICE_FUNC diff --git a/Eigen/src/plugins/ReshapedMethods.h b/Eigen/src/plugins/ReshapedMethods.h index a9b4af7c3..7a11a4bcc 100644 --- a/Eigen/src/plugins/ReshapedMethods.h +++ b/Eigen/src/plugins/ReshapedMethods.h @@ -40,10 +40,12 @@ reshaped(NRowsType nRows, NColsType nCols) template EIGEN_DEVICE_FUNC -inline Reshaped::value,internal::get_fixed_value::value,OrderType::value> +inline Reshaped::value,internal::get_fixed_value::value, + OrderType::value==AutoOrderValue?Flags&RowMajorBit:OrderType::value> reshaped(NRowsType nRows, NColsType nCols, OrderType) { - return Reshaped::value,internal::get_fixed_value::value,OrderType::value>( + return Reshaped::value,internal::get_fixed_value::value, + OrderType::value==AutoOrderValue?Flags&RowMajorBit:OrderType::value>( derived(), internal::get_runtime_value(nRows), internal::get_runtime_value(nCols)); } @@ -59,10 +61,12 @@ reshaped(NRowsType nRows, NColsType nCols) const template EIGEN_DEVICE_FUNC -inline const Reshaped::value,internal::get_fixed_value::value,OrderType::value> +inline const Reshaped::value,internal::get_fixed_value::value, + OrderType::value==AutoOrderValue?Flags&RowMajorBit:OrderType::value> reshaped(NRowsType nRows, NColsType nCols, OrderType) const { - return Reshaped::value,internal::get_fixed_value::value,OrderType::value>( + return Reshaped::value,internal::get_fixed_value::value, + OrderType::value==AutoOrderValue?Flags&RowMajorBit:OrderType::value>( derived(), internal::get_runtime_value(nRows), internal::get_runtime_value(nCols)); } diff --git a/test/reshape.cpp b/test/reshape.cpp index 9b2825d86..516dce0ba 100644 --- a/test/reshape.cpp +++ b/test/reshape.cpp @@ -1,6 +1,7 @@ // This file is part of Eigen, a lightweight C++ template library // for linear algebra. // +// Copyright (C) 2017 Gael Guennebaud // Copyright (C) 2014 yoco // // This Source Code Form is subject to the terms of the Mozilla @@ -9,45 +10,44 @@ #include "main.h" -using Eigen::Map; -using Eigen::MatrixXi; - -// just test a 4x4 matrix, enumerate all combination manually, -// so I don't have to do template-meta-programming here. +// just test a 4x4 matrix, enumerate all combination manually template -void reshape_all_size(MatType m) +void reshape4x4(MatType m) { - typedef Eigen::Map MapMat; - // dynamic - VERIFY_IS_EQUAL((m.reshaped( 1, 16)), MapMat(m.data(), 1, 16)); - VERIFY_IS_EQUAL((m.reshaped( 2, 8)), MapMat(m.data(), 2, 8)); - VERIFY_IS_EQUAL((m.reshaped( 4, 4)), MapMat(m.data(), 4, 4)); - VERIFY_IS_EQUAL((m.reshaped( 8, 2)), MapMat(m.data(), 8, 2)); - VERIFY_IS_EQUAL((m.reshaped(16, 1)), MapMat(m.data(), 16, 1)); + if((MatType::Flags&RowMajorBit)==0) + { + typedef Map MapMat; + // dynamic + VERIFY_IS_EQUAL((m.reshaped( 1, 16)), MapMat(m.data(), 1, 16)); + VERIFY_IS_EQUAL((m.reshaped( 2, 8)), MapMat(m.data(), 2, 8)); + VERIFY_IS_EQUAL((m.reshaped( 4, 4)), MapMat(m.data(), 4, 4)); + VERIFY_IS_EQUAL((m.reshaped( 8, 2)), MapMat(m.data(), 8, 2)); + VERIFY_IS_EQUAL((m.reshaped(16, 1)), MapMat(m.data(), 16, 1)); - // static - VERIFY_IS_EQUAL(m.reshaped(fix< 1>, fix<16>), MapMat(m.data(), 1, 16)); - VERIFY_IS_EQUAL(m.reshaped(fix< 2>, fix< 8>), MapMat(m.data(), 2, 8)); - VERIFY_IS_EQUAL(m.reshaped(fix< 4>, fix< 4>), MapMat(m.data(), 4, 4)); - VERIFY_IS_EQUAL(m.reshaped(fix< 8>, fix< 2>), MapMat(m.data(), 8, 2)); - VERIFY_IS_EQUAL(m.reshaped(fix<16>, fix< 1>), MapMat(m.data(), 16, 1)); + // static + VERIFY_IS_EQUAL(m.reshaped(fix< 1>, fix<16>), MapMat(m.data(), 1, 16)); + VERIFY_IS_EQUAL(m.reshaped(fix< 2>, fix< 8>), MapMat(m.data(), 2, 8)); + VERIFY_IS_EQUAL(m.reshaped(fix< 4>, fix< 4>), MapMat(m.data(), 4, 4)); + VERIFY_IS_EQUAL(m.reshaped(fix< 8>, fix< 2>), MapMat(m.data(), 8, 2)); + VERIFY_IS_EQUAL(m.reshaped(fix<16>, fix< 1>), MapMat(m.data(), 16, 1)); - // reshape chain - VERIFY_IS_EQUAL( - (m - .reshaped( 1, 16) - .reshaped(fix< 2>,fix< 8>) - .reshaped(16, 1) - .reshaped(fix< 8>,fix< 2>) - .reshaped( 2, 8) - .reshaped(fix< 1>,fix<16>) - .reshaped( 4, 4) - .reshaped(fix<16>,fix< 1>) - .reshaped( 8, 2) - .reshaped(fix< 4>,fix< 4>) - ), - MapMat(m.data(), 4, 4) - ); + // reshape chain + VERIFY_IS_EQUAL( + (m + .reshaped( 1, 16) + .reshaped(fix< 2>,fix< 8>) + .reshaped(16, 1) + .reshaped(fix< 8>,fix< 2>) + .reshaped( 2, 8) + .reshaped(fix< 1>,fix<16>) + .reshaped( 4, 4) + .reshaped(fix<16>,fix< 1>) + .reshaped( 8, 2) + .reshaped(fix< 4>,fix< 4>) + ), + MapMat(m.data(), 4, 4) + ); + } VERIFY_IS_EQUAL(m.reshaped( 1, 16).data(), m.data()); VERIFY_IS_EQUAL(m.reshaped( 1, 16).innerStride(), 1); @@ -56,23 +56,43 @@ void reshape_all_size(MatType m) VERIFY_IS_EQUAL(m.reshaped( 2, 8).innerStride(), 1); VERIFY_IS_EQUAL(m.reshaped( 2, 8).outerStride(), 2); - m.reshaped(2,8,ColOrder); + if((MatType::Flags&RowMajorBit)==0) + { + VERIFY_IS_EQUAL(m.reshaped(2,8,ColOrder),m.reshaped(2,8)); + VERIFY_IS_EQUAL(m.reshaped(2,8,ColOrder),m.reshaped(2,8,AutoOrder)); + VERIFY_IS_EQUAL(m.transpose().reshaped(2,8,RowOrder),m.transpose().reshaped(2,8,AutoOrder)); + } + else + { + VERIFY_IS_EQUAL(m.reshaped(2,8,ColOrder),m.reshaped(2,8)); + VERIFY_IS_EQUAL(m.reshaped(2,8,RowOrder),m.reshaped(2,8,AutoOrder)); + VERIFY_IS_EQUAL(m.transpose().reshaped(2,8,ColOrder),m.transpose().reshaped(2,8,AutoOrder)); + VERIFY_IS_EQUAL(m.transpose().reshaped(2,8),m.transpose().reshaped(2,8,AutoOrder)); + } - MatrixXi m28r = m.reshaped(2,8,RowOrder); - std::cout << m28r << "\n"; + MatrixXi m28r1 = m.reshaped(2,8,RowOrder); + MatrixXi m28r2 = m.transpose().reshaped(8,2,ColOrder).transpose(); + VERIFY_IS_EQUAL( m28r1, m28r2); } void test_reshape() { - Eigen::MatrixXi mx = Eigen::MatrixXi::Random(4, 4); - Eigen::Matrix4i m4 = Eigen::Matrix4i::Random(4, 4); + typedef Matrix RowMatrixXi; + typedef Matrix RowMatrix4i; + MatrixXi mx = MatrixXi::Random(4, 4); + Matrix4i m4 = Matrix4i::Random(4, 4); + RowMatrixXi rmx = RowMatrixXi::Random(4, 4); + RowMatrix4i rm4 = RowMatrix4i::Random(4, 4); // test dynamic-size matrix - CALL_SUBTEST(reshape_all_size(mx)); + CALL_SUBTEST(reshape4x4(mx)); // test static-size matrix - CALL_SUBTEST(reshape_all_size(m4)); + CALL_SUBTEST(reshape4x4(m4)); // test dynamic-size const matrix - CALL_SUBTEST(reshape_all_size(static_cast(mx))); + CALL_SUBTEST(reshape4x4(static_cast(mx))); // test static-size const matrix - CALL_SUBTEST(reshape_all_size(static_cast(m4))); + CALL_SUBTEST(reshape4x4(static_cast(m4))); + + CALL_SUBTEST(reshape4x4(rmx)); + CALL_SUBTEST(reshape4x4(rm4)); }