Fixed the tensor_cuda test

This commit is contained in:
Benoit Steiner 2016-01-27 14:58:48 -08:00
parent 55a5204319
commit 47ca9dc809

View File

@ -131,8 +131,7 @@ void test_cuda_reduction()
cudaMemcpy(d_in1, in1.data(), in1_bytes, cudaMemcpyHostToDevice);
cudaStream_t stream;
assert(cudaStreamCreate(&stream) == cudaSuccess);
Eigen::CudaStreamDevice stream;
Eigen::GpuDevice gpu_device(&stream);
Eigen::TensorMap<Eigen::Tensor<float, 4> > gpu_in1(d_in1, 72,53,97,113);
@ -189,8 +188,7 @@ static void test_cuda_contraction()
cudaMemcpy(d_t_left, t_left.data(), t_left_bytes, cudaMemcpyHostToDevice);
cudaMemcpy(d_t_right, t_right.data(), t_right_bytes, cudaMemcpyHostToDevice);
cudaStream_t stream;
assert(cudaStreamCreate(&stream) == cudaSuccess);
Eigen::CudaStreamDevice stream;
Eigen::GpuDevice gpu_device(&stream);
Eigen::TensorMap<Eigen::Tensor<float, 4, DataLayout> > gpu_t_left(d_t_left, 6, 50, 3, 31);
@ -214,7 +212,7 @@ static void test_cuda_contraction()
for (size_t i = 0; i < t_result.dimensions().TotalSize(); i++) {
if (fabs(t_result.data()[i] - m_result.data()[i]) >= 1e-4) {
cout << "mismatch detected at index " << i << ": " << t_result.data()[i] << " vs " << m_result.data()[i] << endl;
std::cout << "mismatch detected at index " << i << ": " << t_result.data()[i] << " vs " << m_result.data()[i] << std::endl;
assert(false);
}
}
@ -243,8 +241,7 @@ static void test_cuda_convolution_1d()
cudaMemcpy(d_input, input.data(), input_bytes, cudaMemcpyHostToDevice);
cudaMemcpy(d_kernel, kernel.data(), kernel_bytes, cudaMemcpyHostToDevice);
cudaStream_t stream;
assert(cudaStreamCreate(&stream) == cudaSuccess);
Eigen::CudaStreamDevice stream;
Eigen::GpuDevice gpu_device(&stream);
Eigen::TensorMap<Eigen::Tensor<float, 4, DataLayout> > gpu_input(d_input, 74,37,11,137);
@ -293,8 +290,7 @@ static void test_cuda_convolution_inner_dim_col_major_1d()
cudaMemcpy(d_input, input.data(), input_bytes, cudaMemcpyHostToDevice);
cudaMemcpy(d_kernel, kernel.data(), kernel_bytes, cudaMemcpyHostToDevice);
cudaStream_t stream;
assert(cudaStreamCreate(&stream) == cudaSuccess);
Eigen::CudaStreamDevice stream;
Eigen::GpuDevice gpu_device(&stream);
Eigen::TensorMap<Eigen::Tensor<float, 4, ColMajor> > gpu_input(d_input,74,9,11,7);
@ -343,8 +339,7 @@ static void test_cuda_convolution_inner_dim_row_major_1d()
cudaMemcpy(d_input, input.data(), input_bytes, cudaMemcpyHostToDevice);
cudaMemcpy(d_kernel, kernel.data(), kernel_bytes, cudaMemcpyHostToDevice);
cudaStream_t stream;
assert(cudaStreamCreate(&stream) == cudaSuccess);
Eigen::CudaStreamDevice stream;
Eigen::GpuDevice gpu_device(&stream);
Eigen::TensorMap<Eigen::Tensor<float, 4, RowMajor> > gpu_input(d_input, 7,9,11,74);
@ -394,8 +389,7 @@ static void test_cuda_convolution_2d()
cudaMemcpy(d_input, input.data(), input_bytes, cudaMemcpyHostToDevice);
cudaMemcpy(d_kernel, kernel.data(), kernel_bytes, cudaMemcpyHostToDevice);
cudaStream_t stream;
assert(cudaStreamCreate(&stream) == cudaSuccess);
Eigen::CudaStreamDevice stream;
Eigen::GpuDevice gpu_device(&stream);
Eigen::TensorMap<Eigen::Tensor<float, 4, DataLayout> > gpu_input(d_input,74,37,11,137);
@ -455,8 +449,7 @@ static void test_cuda_convolution_3d()
cudaMemcpy(d_input, input.data(), input_bytes, cudaMemcpyHostToDevice);
cudaMemcpy(d_kernel, kernel.data(), kernel_bytes, cudaMemcpyHostToDevice);
cudaStream_t stream;
assert(cudaStreamCreate(&stream) == cudaSuccess);
Eigen::CudaStreamDevice stream;
Eigen::GpuDevice gpu_device(&stream);
Eigen::TensorMap<Eigen::Tensor<float, 5, DataLayout> > gpu_input(d_input,74,37,11,137,17);
@ -644,10 +637,6 @@ void test_cxx11_tensor_cuda()
CALL_SUBTEST(test_cuda_erfc<float>(5.0f)); // CUDA erfc lacks precision for large inputs
CALL_SUBTEST(test_cuda_erfc<float>(0.01f));
CALL_SUBTEST(test_cuda_erfc<float>(0.001f));
CALL_SUBTEST(test_cuda_tanh<double>(1.0));
CALL_SUBTEST(test_cuda_tanh<double>(100.0));
CALL_SUBTEST(test_cuda_tanh<double>(0.01));
CALL_SUBTEST(test_cuda_tanh<double>(0.001));
CALL_SUBTEST(test_cuda_lgamma<double>(1.0));
CALL_SUBTEST(test_cuda_lgamma<double>(100.0));
CALL_SUBTEST(test_cuda_lgamma<double>(0.01));