eigen/test/assignment_threaded.cpp
2024-05-29 13:38:00 -07:00

85 lines
3.2 KiB
C++

// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2023 Charlie Schlosser <cs.schlosser@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#define EIGEN_USE_THREADS 1
#include "main.h"
#include <Eigen/ThreadPool>
namespace Eigen {
namespace internal {
// conveniently control vectorization logic
template <typename Scalar, bool Vectorize>
struct scalar_dummy_op {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar operator()(const Scalar& a) const { return a; }
template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a) const {
return a;
}
};
template <typename Scalar, bool Vectorize>
struct functor_traits<scalar_dummy_op<Scalar, Vectorize> > {
enum { Cost = 1'000'000, PacketAccess = Vectorize && packet_traits<Scalar>::Vectorizable };
};
} // namespace internal
} // namespace Eigen
template <typename PlainObject>
void test_threaded_assignment(const PlainObject&, Index rows = PlainObject::RowsAtCompileTime,
Index cols = PlainObject::ColsAtCompileTime) {
using Scalar = typename PlainObject::Scalar;
using VectorizationOff = internal::scalar_dummy_op<Scalar, false>;
using VectorizationOn = internal::scalar_dummy_op<Scalar, true>;
int threads = 4;
ThreadPool pool(threads);
CoreThreadPoolDevice threadPoolDevice(pool);
PlainObject dst(rows, cols), ref(rows, cols), rhs(rows, cols);
rhs.setRandom();
const auto rhs_xpr = rhs.cwiseAbs2();
// linear access
dst.setRandom();
ref.setRandom();
ref = rhs_xpr.unaryExpr(VectorizationOff());
dst.device(threadPoolDevice) = rhs_xpr.unaryExpr(VectorizationOff());
VERIFY_IS_CWISE_EQUAL(ref, dst);
ref = rhs_xpr.unaryExpr(VectorizationOn());
dst.device(threadPoolDevice) = rhs_xpr.unaryExpr(VectorizationOn());
VERIFY_IS_CWISE_EQUAL(ref, dst);
// outer-inner access
Index blockRows = numext::maxi(Index(1), rows - 1);
Index blockCols = numext::maxi(Index(1), cols - 1);
dst.setRandom();
ref.setRandom();
ref.bottomRightCorner(blockRows, blockCols) =
rhs_xpr.bottomRightCorner(blockRows, blockCols).unaryExpr(VectorizationOff());
dst.bottomRightCorner(blockRows, blockCols).device(threadPoolDevice) =
rhs_xpr.bottomRightCorner(blockRows, blockCols).unaryExpr(VectorizationOff());
VERIFY_IS_CWISE_EQUAL(ref.bottomRightCorner(blockRows, blockCols), dst.bottomRightCorner(blockRows, blockCols));
ref.setZero();
dst.setZero();
ref.bottomRightCorner(blockRows, blockCols) =
rhs_xpr.bottomRightCorner(blockRows, blockCols).unaryExpr(VectorizationOn());
dst.bottomRightCorner(blockRows, blockCols).device(threadPoolDevice) =
rhs_xpr.bottomRightCorner(blockRows, blockCols).unaryExpr(VectorizationOn());
VERIFY_IS_CWISE_EQUAL(ref.bottomRightCorner(blockRows, blockCols), dst.bottomRightCorner(blockRows, blockCols));
}
EIGEN_DECLARE_TEST(test) {
for (int i = 0; i < g_repeat; i++) {
CALL_SUBTEST(test_threaded_assignment(MatrixXd(), 123, 123));
CALL_SUBTEST(test_threaded_assignment(Matrix<float, 16, 16>()));
}
}