Fix test for TensorRef of trace.

This commit is contained in:
Rasmus Munk Larsen 2025-03-25 23:01:46 +00:00
parent 6579e36eb4
commit 3866cbfbe8

View File

@ -41,9 +41,6 @@ static void test_simple_lvalue_ref() {
for (int i = 0; i < 6; ++i) {
VERIFY_IS_EQUAL(input(i), -i * 2);
}
TensorRef<const Tensor<int, 1>> ref5(input.trace());
VERIFY_IS_EQUAL(ref5[0], input[0]);
}
static void test_simple_rvalue_ref() {
@ -111,6 +108,17 @@ static void test_slice() {
VERIFY_IS_EQUAL(slice.data(), tensor.data());
}
static void test_ref_of_trace() {
Tensor<int, 2> input(6, 6);
input.setRandom();
int trace = 0;
for (int i = 0; i < 6; ++i) {
trace += input(i, i);
}
TensorRef<const Tensor<int, 0>> ref(input.trace());
VERIFY_IS_EQUAL(ref.coeff(0), trace);
}
static void test_ref_of_ref() {
Tensor<float, 3> input(3, 5, 7);
input.setRandom();
@ -227,6 +235,7 @@ EIGEN_DECLARE_TEST(cxx11_tensor_ref) {
CALL_SUBTEST(test_simple_rvalue_ref());
CALL_SUBTEST(test_multiple_dims());
CALL_SUBTEST(test_slice());
CALL_SUBTEST(test_ref_of_trace());
CALL_SUBTEST(test_ref_of_ref());
CALL_SUBTEST(test_ref_in_expr());
CALL_SUBTEST(test_coeff_ref());