// This file is part of Eigen, a lightweight C++ template library // for linear algebra. // // Copyright (C) 2016 Igor Babuschkin // // This Source Code Form is subject to the terms of the Mozilla // Public License v. 2.0. If a copy of the MPL was not distributed // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. #include "main.h" #include #include #include using Eigen::Tensor; template static void test_1d_scan() { int size = 50; Tensor tensor(size); tensor.setRandom(); Tensor result = tensor.cumsum(0); VERIFY_IS_EQUAL(tensor.dimension(0), result.dimension(0)); float accum = 0; for (int i = 0; i < size; i++) { accum += tensor(i); VERIFY_IS_EQUAL(result(i), accum); } accum = 1; result = tensor.cumprod(0); for (int i = 0; i < size; i++) { accum *= tensor(i); VERIFY_IS_EQUAL(result(i), accum); } } template static void test_4d_scan() { int size = 5; Tensor tensor(size, size, size, size); tensor.setRandom(); Tensor result(size, size, size, size); result = tensor.cumsum(0); float accum = 0; for (int i = 0; i < size; i++) { accum += tensor(i, 0, 0, 0); VERIFY_IS_EQUAL(result(i, 0, 0, 0), accum); } result = tensor.cumsum(1); accum = 0; for (int i = 0; i < size; i++) { accum += tensor(0, i, 0, 0); VERIFY_IS_EQUAL(result(0, i, 0, 0), accum); } result = tensor.cumsum(2); accum = 0; for (int i = 0; i < size; i++) { accum += tensor(0, 0, i, 0); VERIFY_IS_EQUAL(result(0, 0, i, 0), accum); } result = tensor.cumsum(3); accum = 0; for (int i = 0; i < size; i++) { accum += tensor(0, 0, 0, i); VERIFY_IS_EQUAL(result(0, 0, 0, i), accum); } } template static void test_tensor_maps() { int inputs[20]; TensorMap > tensor_map(inputs, 20); tensor_map.setRandom(); Tensor result = tensor_map.cumsum(0); int accum = 0; for (int i = 0; i < 20; ++i) { accum += tensor_map(i); VERIFY_IS_EQUAL(result(i), accum); } } void test_cxx11_tensor_scan() { CALL_SUBTEST(test_1d_scan()); CALL_SUBTEST(test_1d_scan()); CALL_SUBTEST(test_4d_scan()); CALL_SUBTEST(test_4d_scan()); CALL_SUBTEST(test_tensor_maps()); CALL_SUBTEST(test_tensor_maps()); }