mirror of
https://gitlab.com/libeigen/eigen.git
synced 2024-12-27 07:29:52 +08:00
Support reshaping with static shapes and dimensions conversion in tensor broadcasting
This commit is contained in:
parent
9b864cdb37
commit
1b8d70a22b
@ -641,7 +641,7 @@ struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device>
|
||||
return;
|
||||
}
|
||||
|
||||
const Dimensions& input_dims = m_impl.dimensions();
|
||||
const Dimensions& input_dims = Dimensions(m_impl.dimensions());
|
||||
|
||||
// Pre-fill input_block_sizes, broadcast_block_sizes,
|
||||
// broadcast_block_strides, and broadcast_tensor_strides. Later on we will
|
||||
|
@ -290,6 +290,16 @@ struct DSizes : array<DenseIndex, NumDims> {
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef EIGEN_HAS_INDEX_LIST
|
||||
EIGEN_DEVICE_FUNC
|
||||
template <typename FirstType, typename... OtherTypes>
|
||||
DSizes(const Eigen::IndexList<FirstType, OtherTypes...>& dimensions) {
|
||||
for (int i = 0; i < dimensions.count; ++i) {
|
||||
(*this)[i] = dimensions[i];
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifndef EIGEN_EMULATE_CXX11_META_H
|
||||
template <typename std::ptrdiff_t... Indices>
|
||||
EIGEN_DEVICE_FUNC DSizes(const Sizes<Indices...>& a) {
|
||||
|
@ -115,7 +115,7 @@ static void test_static_broadcasting()
|
||||
Tensor<float, 3, DataLayout> tensor(8,3,5);
|
||||
tensor.setRandom();
|
||||
|
||||
#if EIGEN_HAS_CONSTEXPR
|
||||
#if defined(EIGEN_HAS_INDEX_LIST)
|
||||
Eigen::IndexList<Eigen::type2index<2>, Eigen::type2index<3>, Eigen::type2index<4>> broadcasts;
|
||||
#else
|
||||
Eigen::array<int, 3> broadcasts;
|
||||
|
@ -41,6 +41,28 @@ static void test_simple_reshape()
|
||||
}
|
||||
}
|
||||
|
||||
template <typename>
|
||||
static void test_static_reshape() {
|
||||
#if defined(EIGEN_HAS_INDEX_LIST)
|
||||
using Eigen::type2index;
|
||||
|
||||
Tensor<float, 5> tensor(2, 3, 1, 7, 1);
|
||||
tensor.setRandom();
|
||||
|
||||
// New dimensions: [2, 3, 7]
|
||||
Eigen::IndexList<type2index<2>, type2index<3>, type2index<7>> dim;
|
||||
Tensor<float, 3> reshaped = tensor.reshape(dim);
|
||||
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
for (int j = 0; j < 3; ++j) {
|
||||
for (int k = 0; k < 7; ++k) {
|
||||
VERIFY_IS_EQUAL(tensor(i, j, 0, k, 0), reshaped(i, j, k));
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
template<typename>
|
||||
static void test_reshape_in_expr() {
|
||||
MatrixXf m1(2,3*5*7*11);
|
||||
@ -462,6 +484,7 @@ static void test_composition()
|
||||
EIGEN_DECLARE_TEST(cxx11_tensor_morphing)
|
||||
{
|
||||
CALL_SUBTEST_1(test_simple_reshape<void>());
|
||||
CALL_SUBTEST_1(test_static_reshape<void>());
|
||||
CALL_SUBTEST_1(test_reshape_in_expr<void>());
|
||||
CALL_SUBTEST_1(test_reshape_as_lvalue<void>());
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user