mirror of
https://gitlab.com/libeigen/eigen.git
synced 2024-12-21 07:19:46 +08:00
Add support for automatic-size deduction in reshaped, e.g.:
mat.reshaped(4,AutoSize); <-> mat.reshaped(4,mat.size()/4);
This commit is contained in:
parent
f8179385bd
commit
3d200257d7
@ -431,6 +431,7 @@ using std::ptrdiff_t;
|
||||
#include "src/Core/arch/CUDA/Complex.h"
|
||||
|
||||
#include "src/Core/util/IndexedViewHelper.h"
|
||||
#include "src/Core/util/ReshapedHelper.h"
|
||||
#include "src/Core/ArithmeticSequence.h"
|
||||
#include "src/Core/DenseCoeffsBase.h"
|
||||
#include "src/Core/DenseBase.h"
|
||||
|
45
Eigen/src/Core/util/ReshapedHelper.h
Normal file
45
Eigen/src/Core/util/ReshapedHelper.h
Normal file
@ -0,0 +1,45 @@
|
||||
// This file is part of Eigen, a lightweight C++ template library
|
||||
// for linear algebra.
|
||||
//
|
||||
// Copyright (C) 2017 Gael Guennebaud <gael.guennebaud@inria.fr>
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla
|
||||
// Public License v. 2.0. If a copy of the MPL was not distributed
|
||||
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
|
||||
#ifndef EIGEN_RESHAPED_HELPER_H
|
||||
#define EIGEN_RESHAPED_HELPER_H
|
||||
|
||||
namespace Eigen {
|
||||
|
||||
enum AutoSize_t { AutoSize };
|
||||
|
||||
namespace internal {
|
||||
|
||||
template<typename SizeType,typename OtherSize, int TotalSize>
|
||||
struct get_compiletime_reshape_size {
|
||||
enum { value = get_fixed_value<SizeType>::value };
|
||||
};
|
||||
|
||||
template<typename SizeType>
|
||||
Index get_runtime_reshape_size(SizeType size, Index /*other*/, Index /*total*/) {
|
||||
return internal::get_runtime_value(size);
|
||||
}
|
||||
|
||||
template<typename OtherSize, int TotalSize>
|
||||
struct get_compiletime_reshape_size<AutoSize_t,OtherSize,TotalSize> {
|
||||
enum {
|
||||
other_size = get_fixed_value<OtherSize>::value,
|
||||
value = (TotalSize==Dynamic || other_size==Dynamic) ? Dynamic : TotalSize / other_size };
|
||||
};
|
||||
|
||||
Index get_runtime_reshape_size(AutoSize_t /*size*/, Index other, Index total) {
|
||||
return total/other;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
} // end namespace Eigen
|
||||
|
||||
#endif // EIGEN_RESHAPED_HELPER_H
|
@ -28,62 +28,70 @@ template<typename NRowsType, typename NColsType, typename OrderType>
|
||||
EIGEN_DEVICE_FUNC
|
||||
inline const Reshaped<const Derived,...>
|
||||
reshaped(NRowsType nRows, NColsType nCols, OrderType = ColOrder) const;
|
||||
|
||||
#else
|
||||
|
||||
// This file is automatically included twice to generate const and non-const versions
|
||||
|
||||
#ifndef EIGEN_RESHAPED_METHOD_2ND_PASS
|
||||
#define EIGEN_RESHAPED_METHOD_CONST const
|
||||
#else
|
||||
#define EIGEN_RESHAPED_METHOD_CONST
|
||||
#endif
|
||||
|
||||
#ifndef EIGEN_RESHAPED_METHOD_2ND_PASS
|
||||
|
||||
// This part is included once
|
||||
|
||||
#endif
|
||||
|
||||
template<typename NRowsType, typename NColsType>
|
||||
EIGEN_DEVICE_FUNC
|
||||
inline Reshaped<Derived,internal::get_fixed_value<NRowsType>::value,internal::get_fixed_value<NColsType>::value>
|
||||
reshaped(NRowsType nRows, NColsType nCols)
|
||||
inline Reshaped<EIGEN_RESHAPED_METHOD_CONST Derived,
|
||||
internal::get_compiletime_reshape_size<NRowsType,NColsType,SizeAtCompileTime>::value,
|
||||
internal::get_compiletime_reshape_size<NColsType,NRowsType,SizeAtCompileTime>::value>
|
||||
reshaped(NRowsType nRows, NColsType nCols) EIGEN_RESHAPED_METHOD_CONST
|
||||
{
|
||||
return Reshaped<Derived,internal::get_fixed_value<NRowsType>::value,internal::get_fixed_value<NColsType>::value>(
|
||||
derived(), internal::get_runtime_value(nRows), internal::get_runtime_value(nCols));
|
||||
return Reshaped<EIGEN_RESHAPED_METHOD_CONST Derived,
|
||||
internal::get_compiletime_reshape_size<NRowsType,NColsType,SizeAtCompileTime>::value,
|
||||
internal::get_compiletime_reshape_size<NColsType,NRowsType,SizeAtCompileTime>::value>
|
||||
(derived(),
|
||||
internal::get_runtime_reshape_size(nRows,internal::get_runtime_value(nCols),size()),
|
||||
internal::get_runtime_reshape_size(nCols,internal::get_runtime_value(nRows),size()));
|
||||
}
|
||||
|
||||
template<typename NRowsType, typename NColsType, typename OrderType>
|
||||
EIGEN_DEVICE_FUNC
|
||||
inline Reshaped<Derived,internal::get_fixed_value<NRowsType>::value,internal::get_fixed_value<NColsType>::value,
|
||||
inline Reshaped<EIGEN_RESHAPED_METHOD_CONST Derived,
|
||||
internal::get_compiletime_reshape_size<NRowsType,NColsType,SizeAtCompileTime>::value,
|
||||
internal::get_compiletime_reshape_size<NColsType,NRowsType,SizeAtCompileTime>::value,
|
||||
OrderType::value==AutoOrderValue?Flags&RowMajorBit:OrderType::value>
|
||||
reshaped(NRowsType nRows, NColsType nCols, OrderType)
|
||||
reshaped(NRowsType nRows, NColsType nCols, OrderType) EIGEN_RESHAPED_METHOD_CONST
|
||||
{
|
||||
return Reshaped<Derived,internal::get_fixed_value<NRowsType>::value,internal::get_fixed_value<NColsType>::value,
|
||||
OrderType::value==AutoOrderValue?Flags&RowMajorBit:OrderType::value>(
|
||||
derived(), internal::get_runtime_value(nRows), internal::get_runtime_value(nCols));
|
||||
}
|
||||
|
||||
|
||||
template<typename NRowsType, typename NColsType>
|
||||
EIGEN_DEVICE_FUNC
|
||||
inline const Reshaped<const Derived,internal::get_fixed_value<NRowsType>::value,internal::get_fixed_value<NColsType>::value>
|
||||
reshaped(NRowsType nRows, NColsType nCols) const
|
||||
{
|
||||
return Reshaped<const Derived,internal::get_fixed_value<NRowsType>::value,internal::get_fixed_value<NColsType>::value>(
|
||||
derived(), internal::get_runtime_value(nRows), internal::get_runtime_value(nCols));
|
||||
}
|
||||
|
||||
template<typename NRowsType, typename NColsType, typename OrderType>
|
||||
EIGEN_DEVICE_FUNC
|
||||
inline const Reshaped<const Derived,internal::get_fixed_value<NRowsType>::value,internal::get_fixed_value<NColsType>::value,
|
||||
return Reshaped<EIGEN_RESHAPED_METHOD_CONST Derived,
|
||||
internal::get_compiletime_reshape_size<NRowsType,NColsType,SizeAtCompileTime>::value,
|
||||
internal::get_compiletime_reshape_size<NColsType,NRowsType,SizeAtCompileTime>::value,
|
||||
OrderType::value==AutoOrderValue?Flags&RowMajorBit:OrderType::value>
|
||||
reshaped(NRowsType nRows, NColsType nCols, OrderType) const
|
||||
{
|
||||
return Reshaped<const Derived,internal::get_fixed_value<NRowsType>::value,internal::get_fixed_value<NColsType>::value,
|
||||
OrderType::value==AutoOrderValue?Flags&RowMajorBit:OrderType::value>(
|
||||
derived(), internal::get_runtime_value(nRows), internal::get_runtime_value(nCols));
|
||||
(derived(),
|
||||
internal::get_runtime_reshape_size(nRows,internal::get_runtime_value(nCols),size()),
|
||||
internal::get_runtime_reshape_size(nCols,internal::get_runtime_value(nRows),size()));
|
||||
}
|
||||
|
||||
// Views as linear vectors
|
||||
|
||||
EIGEN_DEVICE_FUNC
|
||||
inline const Reshaped<Derived,SizeAtCompileTime,1>
|
||||
operator()(const Eigen::internal::all_t&)
|
||||
inline Reshaped<EIGEN_RESHAPED_METHOD_CONST Derived,SizeAtCompileTime,1>
|
||||
operator()(const Eigen::internal::all_t&) EIGEN_RESHAPED_METHOD_CONST
|
||||
{
|
||||
return Reshaped<Derived,SizeAtCompileTime,1>(derived(),size(),1);
|
||||
return Reshaped<EIGEN_RESHAPED_METHOD_CONST Derived,SizeAtCompileTime,1>(derived(),size(),1);
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC
|
||||
inline const Reshaped<const Derived,SizeAtCompileTime,1>
|
||||
operator()(const Eigen::internal::all_t&) const
|
||||
{
|
||||
return Reshaped<const Derived,SizeAtCompileTime,1>(derived(),size(),1);
|
||||
}
|
||||
#undef EIGEN_RESHAPED_METHOD_CONST
|
||||
|
||||
#ifndef EIGEN_RESHAPED_METHOD_2ND_PASS
|
||||
#define EIGEN_RESHAPED_METHOD_2ND_PASS
|
||||
#include "ReshapedMethods.h"
|
||||
#undef EIGEN_RESHAPED_METHOD_2ND_PASS
|
||||
#endif
|
||||
|
||||
#endif // EIGEN_PARSED_BY_DOXYGEN
|
||||
|
@ -17,10 +17,48 @@ is_same_eq(const T1& a, const T2& b)
|
||||
return (a.array() == b.array()).all();
|
||||
}
|
||||
|
||||
template <typename MatType,typename OrderType>
|
||||
void check_auto_reshape4x4(MatType m,OrderType order)
|
||||
{
|
||||
internal::VariableAndFixedInt<MatType::SizeAtCompileTime==Dynamic?-1: 1> v1( 1);
|
||||
internal::VariableAndFixedInt<MatType::SizeAtCompileTime==Dynamic?-1: 2> v2( 2);
|
||||
internal::VariableAndFixedInt<MatType::SizeAtCompileTime==Dynamic?-1: 4> v4( 4);
|
||||
internal::VariableAndFixedInt<MatType::SizeAtCompileTime==Dynamic?-1: 8> v8( 8);
|
||||
internal::VariableAndFixedInt<MatType::SizeAtCompileTime==Dynamic?-1:16> v16(16);
|
||||
|
||||
VERIFY(is_same_eq(m.reshaped( 1, AutoSize, order), m.reshaped( 1, 16, order)));
|
||||
VERIFY(is_same_eq(m.reshaped(AutoSize, 16, order), m.reshaped( 1, 16, order)));
|
||||
VERIFY(is_same_eq(m.reshaped( 2, AutoSize, order), m.reshaped( 2, 8, order)));
|
||||
VERIFY(is_same_eq(m.reshaped(AutoSize, 8, order), m.reshaped( 2, 8, order)));
|
||||
VERIFY(is_same_eq(m.reshaped( 4, AutoSize, order), m.reshaped( 4, 4, order)));
|
||||
VERIFY(is_same_eq(m.reshaped(AutoSize, 4, order), m.reshaped( 4, 4, order)));
|
||||
VERIFY(is_same_eq(m.reshaped( 8, AutoSize, order), m.reshaped( 8, 2, order)));
|
||||
VERIFY(is_same_eq(m.reshaped(AutoSize, 2, order), m.reshaped( 8, 2, order)));
|
||||
VERIFY(is_same_eq(m.reshaped(16, AutoSize, order), m.reshaped(16, 1, order)));
|
||||
VERIFY(is_same_eq(m.reshaped(AutoSize, 1, order), m.reshaped(16, 1, order)));
|
||||
|
||||
VERIFY(is_same_eq(m.reshaped(fix< 1>, AutoSize, order), m.reshaped(fix< 1>, v16, order)));
|
||||
VERIFY(is_same_eq(m.reshaped(AutoSize, fix<16>, order), m.reshaped( v1, fix<16>, order)));
|
||||
VERIFY(is_same_eq(m.reshaped(fix< 2>, AutoSize, order), m.reshaped(fix< 2>, v8, order)));
|
||||
VERIFY(is_same_eq(m.reshaped(AutoSize, fix< 8>, order), m.reshaped( v2, fix< 8>, order)));
|
||||
VERIFY(is_same_eq(m.reshaped(fix< 4>, AutoSize, order), m.reshaped(fix< 4>, v4, order)));
|
||||
VERIFY(is_same_eq(m.reshaped(AutoSize, fix< 4>, order), m.reshaped( v4, fix< 4>, order)));
|
||||
VERIFY(is_same_eq(m.reshaped(fix< 8>, AutoSize, order), m.reshaped(fix< 8>, v2, order)));
|
||||
VERIFY(is_same_eq(m.reshaped(AutoSize, fix< 2>, order), m.reshaped( v8, fix< 2>, order)));
|
||||
VERIFY(is_same_eq(m.reshaped(fix<16>, AutoSize, order), m.reshaped(fix<16>, v1, order)));
|
||||
VERIFY(is_same_eq(m.reshaped(AutoSize, fix< 1>, order), m.reshaped(v16, fix< 1>, order)));
|
||||
}
|
||||
|
||||
// just test a 4x4 matrix, enumerate all combination manually
|
||||
template <typename MatType>
|
||||
void reshape4x4(MatType m)
|
||||
{
|
||||
internal::VariableAndFixedInt<MatType::SizeAtCompileTime==Dynamic?-1: 1> v1( 1);
|
||||
internal::VariableAndFixedInt<MatType::SizeAtCompileTime==Dynamic?-1: 2> v2( 2);
|
||||
internal::VariableAndFixedInt<MatType::SizeAtCompileTime==Dynamic?-1: 4> v4( 4);
|
||||
internal::VariableAndFixedInt<MatType::SizeAtCompileTime==Dynamic?-1: 8> v8( 8);
|
||||
internal::VariableAndFixedInt<MatType::SizeAtCompileTime==Dynamic?-1:16> v16(16);
|
||||
|
||||
if((MatType::Flags&RowMajorBit)==0)
|
||||
{
|
||||
typedef Map<MatrixXi> MapMat;
|
||||
@ -38,6 +76,7 @@ void reshape4x4(MatType m)
|
||||
VERIFY_IS_EQUAL(m.reshaped(fix< 8>, fix< 2>), MapMat(m.data(), 8, 2));
|
||||
VERIFY_IS_EQUAL(m.reshaped(fix<16>, fix< 1>), MapMat(m.data(), 16, 1));
|
||||
|
||||
|
||||
// reshape chain
|
||||
VERIFY_IS_EQUAL(
|
||||
(m
|
||||
@ -56,6 +95,35 @@ void reshape4x4(MatType m)
|
||||
);
|
||||
}
|
||||
|
||||
VERIFY(is_same_eq(m.reshaped( 1, AutoSize), m.reshaped( 1, 16)));
|
||||
VERIFY(is_same_eq(m.reshaped(AutoSize, 16), m.reshaped( 1, 16)));
|
||||
VERIFY(is_same_eq(m.reshaped( 2, AutoSize), m.reshaped( 2, 8)));
|
||||
VERIFY(is_same_eq(m.reshaped(AutoSize, 8), m.reshaped( 2, 8)));
|
||||
VERIFY(is_same_eq(m.reshaped( 4, AutoSize), m.reshaped( 4, 4)));
|
||||
VERIFY(is_same_eq(m.reshaped(AutoSize, 4), m.reshaped( 4, 4)));
|
||||
VERIFY(is_same_eq(m.reshaped( 8, AutoSize), m.reshaped( 8, 2)));
|
||||
VERIFY(is_same_eq(m.reshaped(AutoSize, 2), m.reshaped( 8, 2)));
|
||||
VERIFY(is_same_eq(m.reshaped(16, AutoSize), m.reshaped(16, 1)));
|
||||
VERIFY(is_same_eq(m.reshaped(AutoSize, 1), m.reshaped(16, 1)));
|
||||
|
||||
VERIFY(is_same_eq(m.reshaped(fix< 1>, AutoSize), m.reshaped(fix< 1>, v16)));
|
||||
VERIFY(is_same_eq(m.reshaped(AutoSize, fix<16>), m.reshaped( v1, fix<16>)));
|
||||
VERIFY(is_same_eq(m.reshaped(fix< 2>, AutoSize), m.reshaped(fix< 2>, v8)));
|
||||
VERIFY(is_same_eq(m.reshaped(AutoSize, fix< 8>), m.reshaped( v2, fix< 8>)));
|
||||
VERIFY(is_same_eq(m.reshaped(fix< 4>, AutoSize), m.reshaped(fix< 4>, v4)));
|
||||
VERIFY(is_same_eq(m.reshaped(AutoSize, fix< 4>), m.reshaped( v4, fix< 4>)));
|
||||
VERIFY(is_same_eq(m.reshaped(fix< 8>, AutoSize), m.reshaped(fix< 8>, v2)));
|
||||
VERIFY(is_same_eq(m.reshaped(AutoSize, fix< 2>), m.reshaped( v8, fix< 2>)));
|
||||
VERIFY(is_same_eq(m.reshaped(fix<16>, AutoSize), m.reshaped(fix<16>, v1)));
|
||||
VERIFY(is_same_eq(m.reshaped(AutoSize, fix< 1>), m.reshaped(v16, fix< 1>)));
|
||||
|
||||
check_auto_reshape4x4(m,ColOrder);
|
||||
check_auto_reshape4x4(m,RowOrder);
|
||||
check_auto_reshape4x4(m,AutoOrder);
|
||||
check_auto_reshape4x4(m.transpose(),ColOrder);
|
||||
check_auto_reshape4x4(m.transpose(),RowOrder);
|
||||
check_auto_reshape4x4(m.transpose(),AutoOrder);
|
||||
|
||||
VERIFY_IS_EQUAL(m.reshaped( 1, 16).data(), m.data());
|
||||
VERIFY_IS_EQUAL(m.reshaped( 1, 16).innerStride(), 1);
|
||||
|
||||
@ -82,12 +150,15 @@ void reshape4x4(MatType m)
|
||||
VERIFY_IS_EQUAL( m28r1, m28r2);
|
||||
|
||||
using placeholders::all;
|
||||
VERIFY(is_same_eq(m.reshaped(fix<MatType::SizeAtCompileTime>(m.size()),fix<1>), m(all)));
|
||||
VERIFY(is_same_eq(m.reshaped(v16,fix<1>), m(all)));
|
||||
VERIFY_IS_EQUAL(m.reshaped(16,1), m(all));
|
||||
VERIFY_IS_EQUAL(m.reshaped(1,16), m(all).transpose());
|
||||
VERIFY_IS_EQUAL(m(all).reshaped(2,8), m.reshaped(2,8));
|
||||
VERIFY_IS_EQUAL(m(all).reshaped(4,4), m.reshaped(4,4));
|
||||
VERIFY_IS_EQUAL(m(all).reshaped(8,2), m.reshaped(8,2));
|
||||
|
||||
VERIFY(is_same_eq(m.reshaped(AutoSize,fix<1>), m(all)));
|
||||
VERIFY_IS_EQUAL(m.reshaped(fix<1>,AutoSize,RowOrder), m.transpose()(all).transpose());
|
||||
}
|
||||
|
||||
void test_reshape()
|
||||
|
Loading…
Reference in New Issue
Block a user