Add unitests for inverse and selfadjoint-eigenvalues on CUDA

This commit is contained in:
Gael Guennebaud 2018-07-06 09:58:45 +02:00
parent f7124b3e46
commit a8ab6060df

View File

@ -121,7 +121,7 @@ struct diagonal {
};
template<typename T>
struct eigenvalues {
struct eigenvalues_direct {
EIGEN_DEVICE_FUNC
void operator()(int i, const typename T::Scalar* in, typename T::Scalar* out) const
{
@ -136,6 +136,34 @@ struct eigenvalues {
}
};
template<typename T>
struct eigenvalues {
EIGEN_DEVICE_FUNC
void operator()(int i, const typename T::Scalar* in, typename T::Scalar* out) const
{
using namespace Eigen;
typedef Matrix<typename T::Scalar, T::RowsAtCompileTime, 1> Vec;
T M(in+i);
Map<Vec> res(out+i*Vec::MaxSizeAtCompileTime);
T A = M*M.adjoint();
SelfAdjointEigenSolver<T> eig;
eig.compute(M);
res = eig.eigenvalues();
}
};
template<typename T>
struct matrix_inverse {
EIGEN_DEVICE_FUNC
void operator()(int i, const typename T::Scalar* in, typename T::Scalar* out) const
{
using namespace Eigen;
T M(in+i);
Map<T> res(out+i*T::MaxSizeAtCompileTime);
res = M.inverse();
}
};
void test_cuda_basic()
{
ei_test_init_cuda();
@ -163,8 +191,13 @@ void test_cuda_basic()
CALL_SUBTEST( run_and_compare_to_cuda(diagonal<Matrix3f,Vector3f>(), nthreads, in, out) );
CALL_SUBTEST( run_and_compare_to_cuda(diagonal<Matrix4f,Vector4f>(), nthreads, in, out) );
CALL_SUBTEST( run_and_compare_to_cuda(matrix_inverse<Matrix2f>(), nthreads, in, out) );
CALL_SUBTEST( run_and_compare_to_cuda(matrix_inverse<Matrix3f>(), nthreads, in, out) );
CALL_SUBTEST( run_and_compare_to_cuda(matrix_inverse<Matrix4f>(), nthreads, in, out) );
CALL_SUBTEST( run_and_compare_to_cuda(eigenvalues<Matrix3f>(), nthreads, in, out) );
CALL_SUBTEST( run_and_compare_to_cuda(eigenvalues<Matrix2f>(), nthreads, in, out) );
CALL_SUBTEST( run_and_compare_to_cuda(eigenvalues_direct<Matrix3f>(), nthreads, in, out) );
CALL_SUBTEST( run_and_compare_to_cuda(eigenvalues_direct<Matrix2f>(), nthreads, in, out) );
CALL_SUBTEST( run_and_compare_to_cuda(eigenvalues<Matrix4f>(), nthreads, in, out) );
}