Added a test for shuffling

This commit is contained in:
Benoit Steiner 2015-07-29 15:01:21 -07:00
parent 0570594f2c
commit e1d28b7ea7
2 changed files with 49 additions and 8 deletions

View File

@ -67,7 +67,7 @@ class TensorShufflingOp : public TensorBase<TensorShufflingOp<Shuffle, XprType>
: m_xpr(expr), m_shuffle(shuffle) {}
EIGEN_DEVICE_FUNC
const Shuffle& shuffle() const { return m_shuffle; }
const Shuffle& shufflePermutation() const { return m_shuffle; }
EIGEN_DEVICE_FUNC
const typename internal::remove_all<typename XprType::Nested>::type&
@ -119,7 +119,7 @@ struct TensorEvaluator<const TensorShufflingOp<Shuffle, ArgType>, Device>
: m_impl(op.expression(), device)
{
const typename TensorEvaluator<ArgType, Device>::Dimensions& input_dims = m_impl.dimensions();
const Shuffle& shuffle = op.shuffle();
const Shuffle& shuffle = op.shufflePermutation();
for (int i = 0; i < NumDims; ++i) {
m_dimensions[i] = input_dims[shuffle[i]];
}

View File

@ -176,12 +176,53 @@ static void test_shuffling_as_value()
}
}
template <int DataLayout>
static void test_shuffle_unshuffle()
{
Tensor<float, 4, DataLayout> tensor(2,3,5,7);
tensor.setRandom();
// Choose a random permutation.
array<ptrdiff_t, 4> shuffles;
for (int i = 0; i < 4; ++i) {
shuffles[i] = i;
}
array<ptrdiff_t, 4> shuffles_inverse;
for (int i = 0; i < 4; ++i) {
const ptrdiff_t index = internal::random<ptrdiff_t>(i, 3);
shuffles_inverse[shuffles[index]] = i;
std::swap(shuffles[i], shuffles[index]);
}
Tensor<float, 4, DataLayout> shuffle;
shuffle = tensor.shuffle(shuffles).shuffle(shuffles_inverse);
VERIFY_IS_EQUAL(shuffle.dimension(0), 2);
VERIFY_IS_EQUAL(shuffle.dimension(1), 3);
VERIFY_IS_EQUAL(shuffle.dimension(2), 5);
VERIFY_IS_EQUAL(shuffle.dimension(3), 7);
for (int i = 0; i < 2; ++i) {
for (int j = 0; j < 3; ++j) {
for (int k = 0; k < 5; ++k) {
for (int l = 0; l < 7; ++l) {
VERIFY_IS_EQUAL(tensor(i,j,k,l), shuffle(i,j,k,l));
}
}
}
}
}
void test_cxx11_tensor_shuffling()
{
CALL_SUBTEST(test_simple_shuffling<ColMajor>());
CALL_SUBTEST(test_simple_shuffling<RowMajor>());
CALL_SUBTEST(test_expr_shuffling<ColMajor>());
CALL_SUBTEST(test_expr_shuffling<RowMajor>());
CALL_SUBTEST(test_shuffling_as_value<ColMajor>());
CALL_SUBTEST(test_shuffling_as_value<RowMajor>());
CALL_SUBTEST(test_simple_shuffling<ColMajor>());
CALL_SUBTEST(test_simple_shuffling<RowMajor>());
CALL_SUBTEST(test_expr_shuffling<ColMajor>());
CALL_SUBTEST(test_expr_shuffling<RowMajor>());
CALL_SUBTEST(test_shuffling_as_value<ColMajor>());
CALL_SUBTEST(test_shuffling_as_value<RowMajor>());
CALL_SUBTEST(test_shuffle_unshuffle<ColMajor>());
CALL_SUBTEST(test_shuffle_unshuffle<RowMajor>());
}