mirror of
https://gitlab.com/libeigen/eigen.git
synced 2024-12-21 07:19:46 +08:00
Add support for automatic-size deduction in reshaped, e.g.:
mat.reshaped(4,AutoSize); <-> mat.reshaped(4,mat.size()/4);
This commit is contained in:
parent
f8179385bd
commit
3d200257d7
@ -431,6 +431,7 @@ using std::ptrdiff_t;
|
|||||||
#include "src/Core/arch/CUDA/Complex.h"
|
#include "src/Core/arch/CUDA/Complex.h"
|
||||||
|
|
||||||
#include "src/Core/util/IndexedViewHelper.h"
|
#include "src/Core/util/IndexedViewHelper.h"
|
||||||
|
#include "src/Core/util/ReshapedHelper.h"
|
||||||
#include "src/Core/ArithmeticSequence.h"
|
#include "src/Core/ArithmeticSequence.h"
|
||||||
#include "src/Core/DenseCoeffsBase.h"
|
#include "src/Core/DenseCoeffsBase.h"
|
||||||
#include "src/Core/DenseBase.h"
|
#include "src/Core/DenseBase.h"
|
||||||
|
45
Eigen/src/Core/util/ReshapedHelper.h
Normal file
45
Eigen/src/Core/util/ReshapedHelper.h
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
// This file is part of Eigen, a lightweight C++ template library
|
||||||
|
// for linear algebra.
|
||||||
|
//
|
||||||
|
// Copyright (C) 2017 Gael Guennebaud <gael.guennebaud@inria.fr>
|
||||||
|
//
|
||||||
|
// This Source Code Form is subject to the terms of the Mozilla
|
||||||
|
// Public License v. 2.0. If a copy of the MPL was not distributed
|
||||||
|
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||||
|
|
||||||
|
|
||||||
|
#ifndef EIGEN_RESHAPED_HELPER_H
|
||||||
|
#define EIGEN_RESHAPED_HELPER_H
|
||||||
|
|
||||||
|
namespace Eigen {
|
||||||
|
|
||||||
|
enum AutoSize_t { AutoSize };
|
||||||
|
|
||||||
|
namespace internal {
|
||||||
|
|
||||||
|
template<typename SizeType,typename OtherSize, int TotalSize>
|
||||||
|
struct get_compiletime_reshape_size {
|
||||||
|
enum { value = get_fixed_value<SizeType>::value };
|
||||||
|
};
|
||||||
|
|
||||||
|
template<typename SizeType>
|
||||||
|
Index get_runtime_reshape_size(SizeType size, Index /*other*/, Index /*total*/) {
|
||||||
|
return internal::get_runtime_value(size);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename OtherSize, int TotalSize>
|
||||||
|
struct get_compiletime_reshape_size<AutoSize_t,OtherSize,TotalSize> {
|
||||||
|
enum {
|
||||||
|
other_size = get_fixed_value<OtherSize>::value,
|
||||||
|
value = (TotalSize==Dynamic || other_size==Dynamic) ? Dynamic : TotalSize / other_size };
|
||||||
|
};
|
||||||
|
|
||||||
|
Index get_runtime_reshape_size(AutoSize_t /*size*/, Index other, Index total) {
|
||||||
|
return total/other;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
} // end namespace Eigen
|
||||||
|
|
||||||
|
#endif // EIGEN_RESHAPED_HELPER_H
|
@ -28,62 +28,70 @@ template<typename NRowsType, typename NColsType, typename OrderType>
|
|||||||
EIGEN_DEVICE_FUNC
|
EIGEN_DEVICE_FUNC
|
||||||
inline const Reshaped<const Derived,...>
|
inline const Reshaped<const Derived,...>
|
||||||
reshaped(NRowsType nRows, NColsType nCols, OrderType = ColOrder) const;
|
reshaped(NRowsType nRows, NColsType nCols, OrderType = ColOrder) const;
|
||||||
|
|
||||||
#else
|
#else
|
||||||
|
|
||||||
|
// This file is automatically included twice to generate const and non-const versions
|
||||||
|
|
||||||
|
#ifndef EIGEN_RESHAPED_METHOD_2ND_PASS
|
||||||
|
#define EIGEN_RESHAPED_METHOD_CONST const
|
||||||
|
#else
|
||||||
|
#define EIGEN_RESHAPED_METHOD_CONST
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#ifndef EIGEN_RESHAPED_METHOD_2ND_PASS
|
||||||
|
|
||||||
|
// This part is included once
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
template<typename NRowsType, typename NColsType>
|
template<typename NRowsType, typename NColsType>
|
||||||
EIGEN_DEVICE_FUNC
|
EIGEN_DEVICE_FUNC
|
||||||
inline Reshaped<Derived,internal::get_fixed_value<NRowsType>::value,internal::get_fixed_value<NColsType>::value>
|
inline Reshaped<EIGEN_RESHAPED_METHOD_CONST Derived,
|
||||||
reshaped(NRowsType nRows, NColsType nCols)
|
internal::get_compiletime_reshape_size<NRowsType,NColsType,SizeAtCompileTime>::value,
|
||||||
|
internal::get_compiletime_reshape_size<NColsType,NRowsType,SizeAtCompileTime>::value>
|
||||||
|
reshaped(NRowsType nRows, NColsType nCols) EIGEN_RESHAPED_METHOD_CONST
|
||||||
{
|
{
|
||||||
return Reshaped<Derived,internal::get_fixed_value<NRowsType>::value,internal::get_fixed_value<NColsType>::value>(
|
return Reshaped<EIGEN_RESHAPED_METHOD_CONST Derived,
|
||||||
derived(), internal::get_runtime_value(nRows), internal::get_runtime_value(nCols));
|
internal::get_compiletime_reshape_size<NRowsType,NColsType,SizeAtCompileTime>::value,
|
||||||
|
internal::get_compiletime_reshape_size<NColsType,NRowsType,SizeAtCompileTime>::value>
|
||||||
|
(derived(),
|
||||||
|
internal::get_runtime_reshape_size(nRows,internal::get_runtime_value(nCols),size()),
|
||||||
|
internal::get_runtime_reshape_size(nCols,internal::get_runtime_value(nRows),size()));
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename NRowsType, typename NColsType, typename OrderType>
|
template<typename NRowsType, typename NColsType, typename OrderType>
|
||||||
EIGEN_DEVICE_FUNC
|
EIGEN_DEVICE_FUNC
|
||||||
inline Reshaped<Derived,internal::get_fixed_value<NRowsType>::value,internal::get_fixed_value<NColsType>::value,
|
inline Reshaped<EIGEN_RESHAPED_METHOD_CONST Derived,
|
||||||
|
internal::get_compiletime_reshape_size<NRowsType,NColsType,SizeAtCompileTime>::value,
|
||||||
|
internal::get_compiletime_reshape_size<NColsType,NRowsType,SizeAtCompileTime>::value,
|
||||||
OrderType::value==AutoOrderValue?Flags&RowMajorBit:OrderType::value>
|
OrderType::value==AutoOrderValue?Flags&RowMajorBit:OrderType::value>
|
||||||
reshaped(NRowsType nRows, NColsType nCols, OrderType)
|
reshaped(NRowsType nRows, NColsType nCols, OrderType) EIGEN_RESHAPED_METHOD_CONST
|
||||||
{
|
{
|
||||||
return Reshaped<Derived,internal::get_fixed_value<NRowsType>::value,internal::get_fixed_value<NColsType>::value,
|
return Reshaped<EIGEN_RESHAPED_METHOD_CONST Derived,
|
||||||
OrderType::value==AutoOrderValue?Flags&RowMajorBit:OrderType::value>(
|
internal::get_compiletime_reshape_size<NRowsType,NColsType,SizeAtCompileTime>::value,
|
||||||
derived(), internal::get_runtime_value(nRows), internal::get_runtime_value(nCols));
|
internal::get_compiletime_reshape_size<NColsType,NRowsType,SizeAtCompileTime>::value,
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
template<typename NRowsType, typename NColsType>
|
|
||||||
EIGEN_DEVICE_FUNC
|
|
||||||
inline const Reshaped<const Derived,internal::get_fixed_value<NRowsType>::value,internal::get_fixed_value<NColsType>::value>
|
|
||||||
reshaped(NRowsType nRows, NColsType nCols) const
|
|
||||||
{
|
|
||||||
return Reshaped<const Derived,internal::get_fixed_value<NRowsType>::value,internal::get_fixed_value<NColsType>::value>(
|
|
||||||
derived(), internal::get_runtime_value(nRows), internal::get_runtime_value(nCols));
|
|
||||||
}
|
|
||||||
|
|
||||||
template<typename NRowsType, typename NColsType, typename OrderType>
|
|
||||||
EIGEN_DEVICE_FUNC
|
|
||||||
inline const Reshaped<const Derived,internal::get_fixed_value<NRowsType>::value,internal::get_fixed_value<NColsType>::value,
|
|
||||||
OrderType::value==AutoOrderValue?Flags&RowMajorBit:OrderType::value>
|
OrderType::value==AutoOrderValue?Flags&RowMajorBit:OrderType::value>
|
||||||
reshaped(NRowsType nRows, NColsType nCols, OrderType) const
|
(derived(),
|
||||||
{
|
internal::get_runtime_reshape_size(nRows,internal::get_runtime_value(nCols),size()),
|
||||||
return Reshaped<const Derived,internal::get_fixed_value<NRowsType>::value,internal::get_fixed_value<NColsType>::value,
|
internal::get_runtime_reshape_size(nCols,internal::get_runtime_value(nRows),size()));
|
||||||
OrderType::value==AutoOrderValue?Flags&RowMajorBit:OrderType::value>(
|
|
||||||
derived(), internal::get_runtime_value(nRows), internal::get_runtime_value(nCols));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Views as linear vectors
|
// Views as linear vectors
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC
|
EIGEN_DEVICE_FUNC
|
||||||
inline const Reshaped<Derived,SizeAtCompileTime,1>
|
inline Reshaped<EIGEN_RESHAPED_METHOD_CONST Derived,SizeAtCompileTime,1>
|
||||||
operator()(const Eigen::internal::all_t&)
|
operator()(const Eigen::internal::all_t&) EIGEN_RESHAPED_METHOD_CONST
|
||||||
{
|
{
|
||||||
return Reshaped<Derived,SizeAtCompileTime,1>(derived(),size(),1);
|
return Reshaped<EIGEN_RESHAPED_METHOD_CONST Derived,SizeAtCompileTime,1>(derived(),size(),1);
|
||||||
}
|
}
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC
|
#undef EIGEN_RESHAPED_METHOD_CONST
|
||||||
inline const Reshaped<const Derived,SizeAtCompileTime,1>
|
|
||||||
operator()(const Eigen::internal::all_t&) const
|
#ifndef EIGEN_RESHAPED_METHOD_2ND_PASS
|
||||||
{
|
#define EIGEN_RESHAPED_METHOD_2ND_PASS
|
||||||
return Reshaped<const Derived,SizeAtCompileTime,1>(derived(),size(),1);
|
#include "ReshapedMethods.h"
|
||||||
}
|
#undef EIGEN_RESHAPED_METHOD_2ND_PASS
|
||||||
|
#endif
|
||||||
|
|
||||||
#endif // EIGEN_PARSED_BY_DOXYGEN
|
#endif // EIGEN_PARSED_BY_DOXYGEN
|
||||||
|
@ -17,10 +17,48 @@ is_same_eq(const T1& a, const T2& b)
|
|||||||
return (a.array() == b.array()).all();
|
return (a.array() == b.array()).all();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename MatType,typename OrderType>
|
||||||
|
void check_auto_reshape4x4(MatType m,OrderType order)
|
||||||
|
{
|
||||||
|
internal::VariableAndFixedInt<MatType::SizeAtCompileTime==Dynamic?-1: 1> v1( 1);
|
||||||
|
internal::VariableAndFixedInt<MatType::SizeAtCompileTime==Dynamic?-1: 2> v2( 2);
|
||||||
|
internal::VariableAndFixedInt<MatType::SizeAtCompileTime==Dynamic?-1: 4> v4( 4);
|
||||||
|
internal::VariableAndFixedInt<MatType::SizeAtCompileTime==Dynamic?-1: 8> v8( 8);
|
||||||
|
internal::VariableAndFixedInt<MatType::SizeAtCompileTime==Dynamic?-1:16> v16(16);
|
||||||
|
|
||||||
|
VERIFY(is_same_eq(m.reshaped( 1, AutoSize, order), m.reshaped( 1, 16, order)));
|
||||||
|
VERIFY(is_same_eq(m.reshaped(AutoSize, 16, order), m.reshaped( 1, 16, order)));
|
||||||
|
VERIFY(is_same_eq(m.reshaped( 2, AutoSize, order), m.reshaped( 2, 8, order)));
|
||||||
|
VERIFY(is_same_eq(m.reshaped(AutoSize, 8, order), m.reshaped( 2, 8, order)));
|
||||||
|
VERIFY(is_same_eq(m.reshaped( 4, AutoSize, order), m.reshaped( 4, 4, order)));
|
||||||
|
VERIFY(is_same_eq(m.reshaped(AutoSize, 4, order), m.reshaped( 4, 4, order)));
|
||||||
|
VERIFY(is_same_eq(m.reshaped( 8, AutoSize, order), m.reshaped( 8, 2, order)));
|
||||||
|
VERIFY(is_same_eq(m.reshaped(AutoSize, 2, order), m.reshaped( 8, 2, order)));
|
||||||
|
VERIFY(is_same_eq(m.reshaped(16, AutoSize, order), m.reshaped(16, 1, order)));
|
||||||
|
VERIFY(is_same_eq(m.reshaped(AutoSize, 1, order), m.reshaped(16, 1, order)));
|
||||||
|
|
||||||
|
VERIFY(is_same_eq(m.reshaped(fix< 1>, AutoSize, order), m.reshaped(fix< 1>, v16, order)));
|
||||||
|
VERIFY(is_same_eq(m.reshaped(AutoSize, fix<16>, order), m.reshaped( v1, fix<16>, order)));
|
||||||
|
VERIFY(is_same_eq(m.reshaped(fix< 2>, AutoSize, order), m.reshaped(fix< 2>, v8, order)));
|
||||||
|
VERIFY(is_same_eq(m.reshaped(AutoSize, fix< 8>, order), m.reshaped( v2, fix< 8>, order)));
|
||||||
|
VERIFY(is_same_eq(m.reshaped(fix< 4>, AutoSize, order), m.reshaped(fix< 4>, v4, order)));
|
||||||
|
VERIFY(is_same_eq(m.reshaped(AutoSize, fix< 4>, order), m.reshaped( v4, fix< 4>, order)));
|
||||||
|
VERIFY(is_same_eq(m.reshaped(fix< 8>, AutoSize, order), m.reshaped(fix< 8>, v2, order)));
|
||||||
|
VERIFY(is_same_eq(m.reshaped(AutoSize, fix< 2>, order), m.reshaped( v8, fix< 2>, order)));
|
||||||
|
VERIFY(is_same_eq(m.reshaped(fix<16>, AutoSize, order), m.reshaped(fix<16>, v1, order)));
|
||||||
|
VERIFY(is_same_eq(m.reshaped(AutoSize, fix< 1>, order), m.reshaped(v16, fix< 1>, order)));
|
||||||
|
}
|
||||||
|
|
||||||
// just test a 4x4 matrix, enumerate all combination manually
|
// just test a 4x4 matrix, enumerate all combination manually
|
||||||
template <typename MatType>
|
template <typename MatType>
|
||||||
void reshape4x4(MatType m)
|
void reshape4x4(MatType m)
|
||||||
{
|
{
|
||||||
|
internal::VariableAndFixedInt<MatType::SizeAtCompileTime==Dynamic?-1: 1> v1( 1);
|
||||||
|
internal::VariableAndFixedInt<MatType::SizeAtCompileTime==Dynamic?-1: 2> v2( 2);
|
||||||
|
internal::VariableAndFixedInt<MatType::SizeAtCompileTime==Dynamic?-1: 4> v4( 4);
|
||||||
|
internal::VariableAndFixedInt<MatType::SizeAtCompileTime==Dynamic?-1: 8> v8( 8);
|
||||||
|
internal::VariableAndFixedInt<MatType::SizeAtCompileTime==Dynamic?-1:16> v16(16);
|
||||||
|
|
||||||
if((MatType::Flags&RowMajorBit)==0)
|
if((MatType::Flags&RowMajorBit)==0)
|
||||||
{
|
{
|
||||||
typedef Map<MatrixXi> MapMat;
|
typedef Map<MatrixXi> MapMat;
|
||||||
@ -38,6 +76,7 @@ void reshape4x4(MatType m)
|
|||||||
VERIFY_IS_EQUAL(m.reshaped(fix< 8>, fix< 2>), MapMat(m.data(), 8, 2));
|
VERIFY_IS_EQUAL(m.reshaped(fix< 8>, fix< 2>), MapMat(m.data(), 8, 2));
|
||||||
VERIFY_IS_EQUAL(m.reshaped(fix<16>, fix< 1>), MapMat(m.data(), 16, 1));
|
VERIFY_IS_EQUAL(m.reshaped(fix<16>, fix< 1>), MapMat(m.data(), 16, 1));
|
||||||
|
|
||||||
|
|
||||||
// reshape chain
|
// reshape chain
|
||||||
VERIFY_IS_EQUAL(
|
VERIFY_IS_EQUAL(
|
||||||
(m
|
(m
|
||||||
@ -56,6 +95,35 @@ void reshape4x4(MatType m)
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
VERIFY(is_same_eq(m.reshaped( 1, AutoSize), m.reshaped( 1, 16)));
|
||||||
|
VERIFY(is_same_eq(m.reshaped(AutoSize, 16), m.reshaped( 1, 16)));
|
||||||
|
VERIFY(is_same_eq(m.reshaped( 2, AutoSize), m.reshaped( 2, 8)));
|
||||||
|
VERIFY(is_same_eq(m.reshaped(AutoSize, 8), m.reshaped( 2, 8)));
|
||||||
|
VERIFY(is_same_eq(m.reshaped( 4, AutoSize), m.reshaped( 4, 4)));
|
||||||
|
VERIFY(is_same_eq(m.reshaped(AutoSize, 4), m.reshaped( 4, 4)));
|
||||||
|
VERIFY(is_same_eq(m.reshaped( 8, AutoSize), m.reshaped( 8, 2)));
|
||||||
|
VERIFY(is_same_eq(m.reshaped(AutoSize, 2), m.reshaped( 8, 2)));
|
||||||
|
VERIFY(is_same_eq(m.reshaped(16, AutoSize), m.reshaped(16, 1)));
|
||||||
|
VERIFY(is_same_eq(m.reshaped(AutoSize, 1), m.reshaped(16, 1)));
|
||||||
|
|
||||||
|
VERIFY(is_same_eq(m.reshaped(fix< 1>, AutoSize), m.reshaped(fix< 1>, v16)));
|
||||||
|
VERIFY(is_same_eq(m.reshaped(AutoSize, fix<16>), m.reshaped( v1, fix<16>)));
|
||||||
|
VERIFY(is_same_eq(m.reshaped(fix< 2>, AutoSize), m.reshaped(fix< 2>, v8)));
|
||||||
|
VERIFY(is_same_eq(m.reshaped(AutoSize, fix< 8>), m.reshaped( v2, fix< 8>)));
|
||||||
|
VERIFY(is_same_eq(m.reshaped(fix< 4>, AutoSize), m.reshaped(fix< 4>, v4)));
|
||||||
|
VERIFY(is_same_eq(m.reshaped(AutoSize, fix< 4>), m.reshaped( v4, fix< 4>)));
|
||||||
|
VERIFY(is_same_eq(m.reshaped(fix< 8>, AutoSize), m.reshaped(fix< 8>, v2)));
|
||||||
|
VERIFY(is_same_eq(m.reshaped(AutoSize, fix< 2>), m.reshaped( v8, fix< 2>)));
|
||||||
|
VERIFY(is_same_eq(m.reshaped(fix<16>, AutoSize), m.reshaped(fix<16>, v1)));
|
||||||
|
VERIFY(is_same_eq(m.reshaped(AutoSize, fix< 1>), m.reshaped(v16, fix< 1>)));
|
||||||
|
|
||||||
|
check_auto_reshape4x4(m,ColOrder);
|
||||||
|
check_auto_reshape4x4(m,RowOrder);
|
||||||
|
check_auto_reshape4x4(m,AutoOrder);
|
||||||
|
check_auto_reshape4x4(m.transpose(),ColOrder);
|
||||||
|
check_auto_reshape4x4(m.transpose(),RowOrder);
|
||||||
|
check_auto_reshape4x4(m.transpose(),AutoOrder);
|
||||||
|
|
||||||
VERIFY_IS_EQUAL(m.reshaped( 1, 16).data(), m.data());
|
VERIFY_IS_EQUAL(m.reshaped( 1, 16).data(), m.data());
|
||||||
VERIFY_IS_EQUAL(m.reshaped( 1, 16).innerStride(), 1);
|
VERIFY_IS_EQUAL(m.reshaped( 1, 16).innerStride(), 1);
|
||||||
|
|
||||||
@ -82,12 +150,15 @@ void reshape4x4(MatType m)
|
|||||||
VERIFY_IS_EQUAL( m28r1, m28r2);
|
VERIFY_IS_EQUAL( m28r1, m28r2);
|
||||||
|
|
||||||
using placeholders::all;
|
using placeholders::all;
|
||||||
VERIFY(is_same_eq(m.reshaped(fix<MatType::SizeAtCompileTime>(m.size()),fix<1>), m(all)));
|
VERIFY(is_same_eq(m.reshaped(v16,fix<1>), m(all)));
|
||||||
VERIFY_IS_EQUAL(m.reshaped(16,1), m(all));
|
VERIFY_IS_EQUAL(m.reshaped(16,1), m(all));
|
||||||
VERIFY_IS_EQUAL(m.reshaped(1,16), m(all).transpose());
|
VERIFY_IS_EQUAL(m.reshaped(1,16), m(all).transpose());
|
||||||
VERIFY_IS_EQUAL(m(all).reshaped(2,8), m.reshaped(2,8));
|
VERIFY_IS_EQUAL(m(all).reshaped(2,8), m.reshaped(2,8));
|
||||||
VERIFY_IS_EQUAL(m(all).reshaped(4,4), m.reshaped(4,4));
|
VERIFY_IS_EQUAL(m(all).reshaped(4,4), m.reshaped(4,4));
|
||||||
VERIFY_IS_EQUAL(m(all).reshaped(8,2), m.reshaped(8,2));
|
VERIFY_IS_EQUAL(m(all).reshaped(8,2), m.reshaped(8,2));
|
||||||
|
|
||||||
|
VERIFY(is_same_eq(m.reshaped(AutoSize,fix<1>), m(all)));
|
||||||
|
VERIFY_IS_EQUAL(m.reshaped(fix<1>,AutoSize,RowOrder), m.transpose()(all).transpose());
|
||||||
}
|
}
|
||||||
|
|
||||||
void test_reshape()
|
void test_reshape()
|
||||||
|
Loading…
Reference in New Issue
Block a user