mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-03-13 18:37:27 +08:00
Disable use of recurrence for computing twiddle factors. Fixes FFT precision issues for large FFTs. https://github.com/tensorflow/tensorflow/issues/10749#issuecomment-354557689
This commit is contained in:
parent
f9bdcea022
commit
59985cfd26
@ -231,20 +231,32 @@ struct TensorEvaluator<const TensorFFTOp<FFT, ArgType, FFTResultType, FFTDir>, D
|
||||
// t_n = exp(sqrt(-1) * pi * n^2 / line_len)
|
||||
// for n = 0, 1,..., line_len-1.
|
||||
// For n > 2 we use the recurrence t_n = t_{n-1}^2 / t_{n-2} * t_1^2
|
||||
pos_j_base_powered[0] = ComplexScalar(1, 0);
|
||||
if (line_len > 1) {
|
||||
const RealScalar pi_over_len(EIGEN_PI / line_len);
|
||||
const ComplexScalar pos_j_base = ComplexScalar(
|
||||
std::cos(pi_over_len), std::sin(pi_over_len));
|
||||
pos_j_base_powered[1] = pos_j_base;
|
||||
if (line_len > 2) {
|
||||
const ComplexScalar pos_j_base_sq = pos_j_base * pos_j_base;
|
||||
for (int j = 2; j < line_len + 1; ++j) {
|
||||
pos_j_base_powered[j] = pos_j_base_powered[j - 1] *
|
||||
pos_j_base_powered[j - 1] /
|
||||
pos_j_base_powered[j - 2] * pos_j_base_sq;
|
||||
}
|
||||
}
|
||||
|
||||
// The recurrence is correct in exact arithmetic, but causes
|
||||
// numerical issues for large transforms, especially in
|
||||
// single-precision floating point.
|
||||
//
|
||||
// pos_j_base_powered[0] = ComplexScalar(1, 0);
|
||||
// if (line_len > 1) {
|
||||
// const ComplexScalar pos_j_base = ComplexScalar(
|
||||
// numext::cos(M_PI / line_len), numext::sin(M_PI / line_len));
|
||||
// pos_j_base_powered[1] = pos_j_base;
|
||||
// if (line_len > 2) {
|
||||
// const ComplexScalar pos_j_base_sq = pos_j_base * pos_j_base;
|
||||
// for (int i = 2; i < line_len + 1; ++i) {
|
||||
// pos_j_base_powered[i] = pos_j_base_powered[i - 1] *
|
||||
// pos_j_base_powered[i - 1] /
|
||||
// pos_j_base_powered[i - 2] *
|
||||
// pos_j_base_sq;
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// TODO(rmlarsen): Find a way to use Eigen's vectorized sin
|
||||
// and cosine functions here.
|
||||
for (int j = 0; j < line_len + 1; ++j) {
|
||||
double arg = ((EIGEN_PI * j) * j) / line_len;
|
||||
std::complex<double> tmp(numext::cos(arg), numext::sin(arg));
|
||||
pos_j_base_powered[j] = static_cast<ComplexScalar>(tmp);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -224,6 +224,32 @@ static void test_fft_real_input_energy() {
|
||||
}
|
||||
}
|
||||
|
||||
template <typename RealScalar>
|
||||
static void test_fft_non_power_of_2_round_trip(int exponent) {
|
||||
int n = (1 << exponent) + 1;
|
||||
|
||||
Eigen::DSizes<long, 1> dimensions;
|
||||
dimensions[0] = n;
|
||||
const DSizes<long, 1> arr = dimensions;
|
||||
Tensor<RealScalar, 1, ColMajor, long> input;
|
||||
|
||||
input.resize(arr);
|
||||
input.setRandom();
|
||||
|
||||
array<int, 1> fft;
|
||||
fft[0] = 0;
|
||||
|
||||
Tensor<std::complex<RealScalar>, 1, ColMajor> forward =
|
||||
input.template fft<BothParts, FFT_FORWARD>(fft);
|
||||
|
||||
Tensor<RealScalar, 1, ColMajor, long> output =
|
||||
forward.template fft<RealPart, FFT_REVERSE>(fft);
|
||||
|
||||
for (int i = 0; i < n; ++i) {
|
||||
VERIFY_IS_APPROX(input[i], output[i]);
|
||||
}
|
||||
}
|
||||
|
||||
void test_cxx11_tensor_fft() {
|
||||
test_fft_complex_input_golden();
|
||||
test_fft_real_input_golden();
|
||||
@ -270,4 +296,6 @@ void test_cxx11_tensor_fft() {
|
||||
test_fft_real_input_energy<RowMajor, double, true, Eigen::BothParts, FFT_FORWARD, 4>();
|
||||
test_fft_real_input_energy<RowMajor, float, false, Eigen::BothParts, FFT_FORWARD, 4>();
|
||||
test_fft_real_input_energy<RowMajor, double, false, Eigen::BothParts, FFT_FORWARD, 4>();
|
||||
|
||||
test_fft_non_power_of_2_round_trip<float>(7);
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user