mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-03-07 18:27:40 +08:00
Added a test to cover threaded tensor shuffling
This commit is contained in:
parent
32088c06a1
commit
9de155d153
@ -283,6 +283,31 @@ void test_multithread_random()
|
||||
t.device(device) = t.random<Eigen::internal::NormalRandomGenerator<float>>();
|
||||
}
|
||||
|
||||
template<int DataLayout>
|
||||
void test_multithread_shuffle()
|
||||
{
|
||||
Tensor<float, 4, DataLayout> tensor(17,5,7,11);
|
||||
tensor.setRandom();
|
||||
|
||||
const int num_threads = internal::random<int>(2, 11);
|
||||
ThreadPool threads(num_threads);
|
||||
Eigen::ThreadPoolDevice device(&threads, num_threads);
|
||||
|
||||
Tensor<float, 4, DataLayout> shuffle(7,5,11,17);
|
||||
array<ptrdiff_t, 4> shuffles = {{2,1,3,0}};
|
||||
shuffle.device(device) = tensor.shuffle(shuffles);
|
||||
|
||||
for (int i = 0; i < 17; ++i) {
|
||||
for (int j = 0; j < 5; ++j) {
|
||||
for (int k = 0; k < 7; ++k) {
|
||||
for (int l = 0; l < 11; ++l) {
|
||||
VERIFY_IS_EQUAL(tensor(i,j,k,l), shuffle(k,j,l,i));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void test_cxx11_tensor_thread_pool()
|
||||
{
|
||||
@ -304,4 +329,6 @@ void test_cxx11_tensor_thread_pool()
|
||||
|
||||
CALL_SUBTEST_6(test_memcpy());
|
||||
CALL_SUBTEST_6(test_multithread_random());
|
||||
CALL_SUBTEST_6(test_multithread_shuffle<ColMajor>());
|
||||
CALL_SUBTEST_6(test_multithread_shuffle<RowMajor>());
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user