diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorFFT.h b/unsupported/Eigen/CXX11/src/Tensor/TensorFFT.h index 10e0a8a6b..f81da318c 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorFFT.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorFFT.h @@ -231,20 +231,32 @@ struct TensorEvaluator, D // t_n = exp(sqrt(-1) * pi * n^2 / line_len) // for n = 0, 1,..., line_len-1. // For n > 2 we use the recurrence t_n = t_{n-1}^2 / t_{n-2} * t_1^2 - pos_j_base_powered[0] = ComplexScalar(1, 0); - if (line_len > 1) { - const RealScalar pi_over_len(EIGEN_PI / line_len); - const ComplexScalar pos_j_base = ComplexScalar( - std::cos(pi_over_len), std::sin(pi_over_len)); - pos_j_base_powered[1] = pos_j_base; - if (line_len > 2) { - const ComplexScalar pos_j_base_sq = pos_j_base * pos_j_base; - for (int j = 2; j < line_len + 1; ++j) { - pos_j_base_powered[j] = pos_j_base_powered[j - 1] * - pos_j_base_powered[j - 1] / - pos_j_base_powered[j - 2] * pos_j_base_sq; - } - } + + // The recurrence is correct in exact arithmetic, but causes + // numerical issues for large transforms, especially in + // single-precision floating point. + // + // pos_j_base_powered[0] = ComplexScalar(1, 0); + // if (line_len > 1) { + // const ComplexScalar pos_j_base = ComplexScalar( + // numext::cos(M_PI / line_len), numext::sin(M_PI / line_len)); + // pos_j_base_powered[1] = pos_j_base; + // if (line_len > 2) { + // const ComplexScalar pos_j_base_sq = pos_j_base * pos_j_base; + // for (int i = 2; i < line_len + 1; ++i) { + // pos_j_base_powered[i] = pos_j_base_powered[i - 1] * + // pos_j_base_powered[i - 1] / + // pos_j_base_powered[i - 2] * + // pos_j_base_sq; + // } + // } + // } + // TODO(rmlarsen): Find a way to use Eigen's vectorized sin + // and cosine functions here. + for (int j = 0; j < line_len + 1; ++j) { + double arg = ((EIGEN_PI * j) * j) / line_len; + std::complex tmp(numext::cos(arg), numext::sin(arg)); + pos_j_base_powered[j] = static_cast(tmp); } } diff --git a/unsupported/test/cxx11_tensor_fft.cpp b/unsupported/test/cxx11_tensor_fft.cpp index 2f14ebc62..a55369477 100644 --- a/unsupported/test/cxx11_tensor_fft.cpp +++ b/unsupported/test/cxx11_tensor_fft.cpp @@ -224,6 +224,32 @@ static void test_fft_real_input_energy() { } } +template +static void test_fft_non_power_of_2_round_trip(int exponent) { + int n = (1 << exponent) + 1; + + Eigen::DSizes dimensions; + dimensions[0] = n; + const DSizes arr = dimensions; + Tensor input; + + input.resize(arr); + input.setRandom(); + + array fft; + fft[0] = 0; + + Tensor, 1, ColMajor> forward = + input.template fft(fft); + + Tensor output = + forward.template fft(fft); + + for (int i = 0; i < n; ++i) { + VERIFY_IS_APPROX(input[i], output[i]); + } +} + void test_cxx11_tensor_fft() { test_fft_complex_input_golden(); test_fft_real_input_golden(); @@ -270,4 +296,6 @@ void test_cxx11_tensor_fft() { test_fft_real_input_energy(); test_fft_real_input_energy(); test_fft_real_input_energy(); + + test_fft_non_power_of_2_round_trip(7); }