Disable use of recurrence for computing twiddle factors. Fixes FFT precision issues for large FFTs. https://github.com/tensorflow/tensorflow/issues/10749#issuecomment-354557689

This commit is contained in:
RJ Ryan 2017-12-31 10:44:56 -05:00
parent f9bdcea022
commit 59985cfd26
2 changed files with 54 additions and 14 deletions

View File

@ -231,20 +231,32 @@ struct TensorEvaluator<const TensorFFTOp<FFT, ArgType, FFTResultType, FFTDir>, D
// t_n = exp(sqrt(-1) * pi * n^2 / line_len)
// for n = 0, 1,..., line_len-1.
// For n > 2 we use the recurrence t_n = t_{n-1}^2 / t_{n-2} * t_1^2
pos_j_base_powered[0] = ComplexScalar(1, 0);
if (line_len > 1) {
const RealScalar pi_over_len(EIGEN_PI / line_len);
const ComplexScalar pos_j_base = ComplexScalar(
std::cos(pi_over_len), std::sin(pi_over_len));
pos_j_base_powered[1] = pos_j_base;
if (line_len > 2) {
const ComplexScalar pos_j_base_sq = pos_j_base * pos_j_base;
for (int j = 2; j < line_len + 1; ++j) {
pos_j_base_powered[j] = pos_j_base_powered[j - 1] *
pos_j_base_powered[j - 1] /
pos_j_base_powered[j - 2] * pos_j_base_sq;
}
}
// The recurrence is correct in exact arithmetic, but causes
// numerical issues for large transforms, especially in
// single-precision floating point.
//
// pos_j_base_powered[0] = ComplexScalar(1, 0);
// if (line_len > 1) {
// const ComplexScalar pos_j_base = ComplexScalar(
// numext::cos(M_PI / line_len), numext::sin(M_PI / line_len));
// pos_j_base_powered[1] = pos_j_base;
// if (line_len > 2) {
// const ComplexScalar pos_j_base_sq = pos_j_base * pos_j_base;
// for (int i = 2; i < line_len + 1; ++i) {
// pos_j_base_powered[i] = pos_j_base_powered[i - 1] *
// pos_j_base_powered[i - 1] /
// pos_j_base_powered[i - 2] *
// pos_j_base_sq;
// }
// }
// }
// TODO(rmlarsen): Find a way to use Eigen's vectorized sin
// and cosine functions here.
for (int j = 0; j < line_len + 1; ++j) {
double arg = ((EIGEN_PI * j) * j) / line_len;
std::complex<double> tmp(numext::cos(arg), numext::sin(arg));
pos_j_base_powered[j] = static_cast<ComplexScalar>(tmp);
}
}

View File

@ -224,6 +224,32 @@ static void test_fft_real_input_energy() {
}
}
template <typename RealScalar>
static void test_fft_non_power_of_2_round_trip(int exponent) {
int n = (1 << exponent) + 1;
Eigen::DSizes<long, 1> dimensions;
dimensions[0] = n;
const DSizes<long, 1> arr = dimensions;
Tensor<RealScalar, 1, ColMajor, long> input;
input.resize(arr);
input.setRandom();
array<int, 1> fft;
fft[0] = 0;
Tensor<std::complex<RealScalar>, 1, ColMajor> forward =
input.template fft<BothParts, FFT_FORWARD>(fft);
Tensor<RealScalar, 1, ColMajor, long> output =
forward.template fft<RealPart, FFT_REVERSE>(fft);
for (int i = 0; i < n; ++i) {
VERIFY_IS_APPROX(input[i], output[i]);
}
}
void test_cxx11_tensor_fft() {
test_fft_complex_input_golden();
test_fft_real_input_golden();
@ -270,4 +296,6 @@ void test_cxx11_tensor_fft() {
test_fft_real_input_energy<RowMajor, double, true, Eigen::BothParts, FFT_FORWARD, 4>();
test_fft_real_input_energy<RowMajor, float, false, Eigen::BothParts, FFT_FORWARD, 4>();
test_fft_real_input_energy<RowMajor, double, false, Eigen::BothParts, FFT_FORWARD, 4>();
test_fft_non_power_of_2_round_trip<float>(7);
}