Speed up tensor FFT by up to ~25-50%.

Benchmark                          Base (ns)  New (ns) Improvement
------------------------------------------------------------------
BM_tensor_fft_single_1D_cpu/8            132       134     -1.5%
BM_tensor_fft_single_1D_cpu/9           1162      1229     -5.8%
BM_tensor_fft_single_1D_cpu/16           199       195     +2.0%
BM_tensor_fft_single_1D_cpu/17          2587      2267    +12.4%
BM_tensor_fft_single_1D_cpu/32           373       341     +8.6%
BM_tensor_fft_single_1D_cpu/33          5922      4879    +17.6%
BM_tensor_fft_single_1D_cpu/64           797       675    +15.3%
BM_tensor_fft_single_1D_cpu/65         13580     10481    +22.8%
BM_tensor_fft_single_1D_cpu/128         1753      1375    +21.6%
BM_tensor_fft_single_1D_cpu/129        31426     22789    +27.5%
BM_tensor_fft_single_1D_cpu/256         4005      3008    +24.9%
BM_tensor_fft_single_1D_cpu/257        70910     49549    +30.1%
BM_tensor_fft_single_1D_cpu/512         8989      6524    +27.4%
BM_tensor_fft_single_1D_cpu/513       165402    107751    +34.9%
BM_tensor_fft_single_1D_cpu/999       198293    115909    +41.5%
BM_tensor_fft_single_1D_cpu/1ki        21289     14143    +33.6%
BM_tensor_fft_single_1D_cpu/1k        361980    233355    +35.5%
BM_tensor_fft_double_1D_cpu/8            138       131     +5.1%
BM_tensor_fft_double_1D_cpu/9           1253      1133     +9.6%
BM_tensor_fft_double_1D_cpu/16           218       200     +8.3%
BM_tensor_fft_double_1D_cpu/17          2770      2392    +13.6%
BM_tensor_fft_double_1D_cpu/32           406       368     +9.4%
BM_tensor_fft_double_1D_cpu/33          6418      5153    +19.7%
BM_tensor_fft_double_1D_cpu/64           856       728    +15.0%
BM_tensor_fft_double_1D_cpu/65         14666     11148    +24.0%
BM_tensor_fft_double_1D_cpu/128         1913      1502    +21.5%
BM_tensor_fft_double_1D_cpu/129        36414     24072    +33.9%
BM_tensor_fft_double_1D_cpu/256         4226      3216    +23.9%
BM_tensor_fft_double_1D_cpu/257        86638     52059    +39.9%
BM_tensor_fft_double_1D_cpu/512         9397      6939    +26.2%
BM_tensor_fft_double_1D_cpu/513       203208    114090    +43.9%
BM_tensor_fft_double_1D_cpu/999       237841    125583    +47.2%
BM_tensor_fft_double_1D_cpu/1ki        20921     15392    +26.4%
BM_tensor_fft_double_1D_cpu/1k        455183    250763    +44.9%
BM_tensor_fft_single_2D_cpu/8           1051      1005     +4.4%
BM_tensor_fft_single_2D_cpu/9          16784     14837    +11.6%
BM_tensor_fft_single_2D_cpu/16          4074      3772     +7.4%
BM_tensor_fft_single_2D_cpu/17         75802     63884    +15.7%
BM_tensor_fft_single_2D_cpu/32         20580     16931    +17.7%
BM_tensor_fft_single_2D_cpu/33        345798    278579    +19.4%
BM_tensor_fft_single_2D_cpu/64         97548     81237    +16.7%
BM_tensor_fft_single_2D_cpu/65       1592701   1227048    +23.0%
BM_tensor_fft_single_2D_cpu/128       472318    384303    +18.6%
BM_tensor_fft_single_2D_cpu/129      7038351   5445308    +22.6%
BM_tensor_fft_single_2D_cpu/256      2309474   1850969    +19.9%
BM_tensor_fft_single_2D_cpu/257     31849182  23797538    +25.3%
BM_tensor_fft_single_2D_cpu/512     10395194   8077499    +22.3%
BM_tensor_fft_single_2D_cpu/513     144053843  104242541    +27.6%
BM_tensor_fft_single_2D_cpu/999     279885833  208389718    +25.5%
BM_tensor_fft_single_2D_cpu/1ki     45967677  36070985    +21.5%
BM_tensor_fft_single_2D_cpu/1k      619727095  456489500    +26.3%
BM_tensor_fft_double_2D_cpu/8           1110      1016     +8.5%
BM_tensor_fft_double_2D_cpu/9          17957     15768    +12.2%
BM_tensor_fft_double_2D_cpu/16          4558      4000    +12.2%
BM_tensor_fft_double_2D_cpu/17         79237     66901    +15.6%
BM_tensor_fft_double_2D_cpu/32         21494     17699    +17.7%
BM_tensor_fft_double_2D_cpu/33        357962    290357    +18.9%
BM_tensor_fft_double_2D_cpu/64        105179     87435    +16.9%
BM_tensor_fft_double_2D_cpu/65       1617143   1288006    +20.4%
BM_tensor_fft_double_2D_cpu/128       512848    419397    +18.2%
BM_tensor_fft_double_2D_cpu/129      7271322   5636884    +22.5%
BM_tensor_fft_double_2D_cpu/256      2415529   1922032    +20.4%
BM_tensor_fft_double_2D_cpu/257     32517952  24462177    +24.8%
BM_tensor_fft_double_2D_cpu/512     10724898   8287617    +22.7%
BM_tensor_fft_double_2D_cpu/513     146007419  108603266    +25.6%
BM_tensor_fft_double_2D_cpu/999     296351330  221885776    +25.1%
BM_tensor_fft_double_2D_cpu/1ki     59334166  48357539    +18.5%
BM_tensor_fft_double_2D_cpu/1k      666660132  483840349    +27.4%
This commit is contained in:
Rasmus Munk Larsen 2016-02-19 16:29:23 -08:00
parent d90a2dac5e
commit d5e2ec7447

View File

@ -219,19 +219,56 @@ struct TensorEvaluator<const TensorFFTOp<FFT, ArgType, FFTResultType, FFTDir>, D
ComplexScalar* b = is_power_of_two ? NULL : (ComplexScalar*)m_device.allocate(sizeof(ComplexScalar) * good_composite); ComplexScalar* b = is_power_of_two ? NULL : (ComplexScalar*)m_device.allocate(sizeof(ComplexScalar) * good_composite);
ComplexScalar* pos_j_base_powered = is_power_of_two ? NULL : (ComplexScalar*)m_device.allocate(sizeof(ComplexScalar) * (line_len + 1)); ComplexScalar* pos_j_base_powered = is_power_of_two ? NULL : (ComplexScalar*)m_device.allocate(sizeof(ComplexScalar) * (line_len + 1));
if (!is_power_of_two) { if (!is_power_of_two) {
ComplexScalar pos_j_base = ComplexScalar(std::cos(M_PI/line_len), std::sin(M_PI/line_len)); // Compute twiddle factors
for (Index j = 0; j < line_len + 1; ++j) { // t_n = exp(sqrt(-1) * pi * n^2 / line_len)
pos_j_base_powered[j] = std::pow(pos_j_base, j * j); // for n = 0, 1,..., line_len-1.
// For n > 2 we use the recurrence t_n = t_{n-1}^2 / t_{n-2} * t_1^2
pos_j_base_powered[0] = ComplexScalar(1, 0);
if (line_len > 1) {
const ComplexScalar pos_j_base = ComplexScalar(
std::cos(M_PI / line_len), std::sin(M_PI / line_len));
pos_j_base_powered[1] = pos_j_base;
if (line_len > 2) {
const ComplexScalar pos_j_base_sq = pos_j_base * pos_j_base;
for (int i = 2; i < line_len + 1; ++i) {
pos_j_base_powered[i] = pos_j_base_powered[i - 1] *
pos_j_base_powered[i - 1] /
pos_j_base_powered[i - 2] * pos_j_base_sq;
}
}
}
// Compute twiddle factors
// t_n = exp(sqrt(-1) * pi * n^2 / line_len)
// for n = 0, 1,..., line_len-1.
// For n > 2 we use the recurrence t_n = t_{n-1}^2 / t_{n-2} * t_1^2
pos_j_base_powered[0] = ComplexScalar(1, 0);
if (line_len > 1) {
const ComplexScalar pos_j_base = ComplexScalar(
std::cos(M_PI / line_len), std::sin(M_PI / line_len));
pos_j_base_powered[1] = pos_j_base;
if (line_len > 2) {
const ComplexScalar pos_j_base_sq = pos_j_base * pos_j_base;
for (int i = 2; i < line_len + 1; ++i) {
pos_j_base_powered[i] = pos_j_base_powered[i - 1] *
pos_j_base_powered[i - 1] /
pos_j_base_powered[i - 2] * pos_j_base_sq;
}
}
} }
} }
for (Index partial_index = 0; partial_index < m_size / line_len; ++partial_index) { for (Index partial_index = 0; partial_index < m_size / line_len; ++partial_index) {
Index base_offset = getBaseOffsetFromIndex(partial_index, dim); const Index base_offset = getBaseOffsetFromIndex(partial_index, dim);
// get data into line_buf // get data into line_buf
for (Index j = 0; j < line_len; ++j) { const Index stride = m_strides[dim];
Index offset = getIndexFromOffset(base_offset, dim, j); if (stride == 1) {
line_buf[j] = buf[offset]; memcpy(line_buf, &buf[base_offset], line_len*sizeof(ComplexScalar));
} else {
Index offset = base_offset;
for (int j = 0; j < line_len; ++j, offset += stride) {
line_buf[j] = buf[offset];
}
} }
// processs the line // processs the line
@ -243,14 +280,18 @@ struct TensorEvaluator<const TensorFFTOp<FFT, ArgType, FFTResultType, FFTDir>, D
} }
// write back // write back
for (Index j = 0; j < line_len; ++j) { if (FFTDir == FFT_FORWARD && stride == 1) {
const ComplexScalar div_factor = (FFTDir == FFT_FORWARD) ? ComplexScalar(1, 0) : ComplexScalar(line_len, 0); memcpy(&buf[base_offset], line_buf, line_len*sizeof(ComplexScalar));
Index offset = getIndexFromOffset(base_offset, dim, j); } else {
buf[offset] = line_buf[j] / div_factor; Index offset = base_offset;
const ComplexScalar div_factor = ComplexScalar(1.0 / line_len, 0);
for (int j = 0; j < line_len; ++j, offset += stride) {
buf[offset] = (FFTDir == FFT_FORWARD) ? line_buf[j] : line_buf[j] * div_factor;
}
} }
} }
m_device.deallocate(line_buf); m_device.deallocate(line_buf);
if (!pos_j_base_powered) { if (!is_power_of_two) {
m_device.deallocate(a); m_device.deallocate(a);
m_device.deallocate(b); m_device.deallocate(b);
m_device.deallocate(pos_j_base_powered); m_device.deallocate(pos_j_base_powered);
@ -372,109 +413,130 @@ struct TensorEvaluator<const TensorFFTOp<FFT, ArgType, FFTResultType, FFTDir>, D
} }
} }
template<int Dir> template <int Dir>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void compute_1D_Butterfly(ComplexScalar* data, Index n, Index n_power_of_2) { EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void butterfly_2(ComplexScalar* data) {
ComplexScalar tmp = data[1];
data[1] = data[0] - data[1];
data[0] += tmp;
}
// In-place 4-point butterfly over data[0..3]; compute_1D_Butterfly
// dispatches here when n == 4.
// Dir selects the twiddle applied to (data[2] - data[3]):
// (0, -1) for FFT_FORWARD, (0, 1) otherwise.
template <int Dir>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void butterfly_4(ComplexScalar* data) {
  // Direction-dependent factor, chosen once up front.
  const ComplexScalar rot = (Dir == FFT_FORWARD) ? ComplexScalar(0.0, -1.0)
                                                 : ComplexScalar(0.0, 1.0);

  // First stage: sums and differences of the two adjacent pairs.
  const ComplexScalar sum01 = data[0] + data[1];
  const ComplexScalar dif01 = data[0] - data[1];
  const ComplexScalar sum23 = data[2] + data[3];
  const ComplexScalar rot23 = rot * (data[2] - data[3]);

  // Second stage: combine the pair results back into data.
  data[0] = sum01 + sum23;
  data[1] = dif01 + rot23;
  data[2] = sum01 - sum23;
  data[3] = dif01 - rot23;
}
// In-place 8-point butterfly over data[0..7]; compute_1D_Butterfly
// dispatches here when n == 8.
// Dir selects the sign of the imaginary part of every twiddle factor:
// negative for FFT_FORWARD, positive otherwise.
template <int Dir>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void butterfly_8(ComplexScalar* data) {
  // Twiddle (0, -1) / (0, 1), shared by the first and second stage.
  const ComplexScalar quarter =
      (Dir == FFT_FORWARD) ? ComplexScalar(0, -1) : ComplexScalar(0, 1);

  // Stage 1: the same pairwise pattern applied to data[0..3] and data[4..7].
  ComplexScalar stage1[8];
  for (int base = 0; base < 8; base += 4) {
    stage1[base] = data[base] + data[base + 1];
    stage1[base + 1] = data[base] - data[base + 1];
    stage1[base + 2] = data[base + 2] + data[base + 3];
    stage1[base + 3] = (data[base + 2] - data[base + 3]) * quarter;
  }

  // Stage 2: the lower half needs no twiddles; the upper half is rotated.
  ComplexScalar stage2[8];
  stage2[0] = stage1[0] + stage1[2];
  stage2[1] = stage1[1] + stage1[3];
  stage2[2] = stage1[0] - stage1[2];
  stage2[3] = stage1[1] - stage1[3];
  stage2[4] = stage1[4] + stage1[6];
// SQRT2DIV2 = sqrt(2)/2
#define SQRT2DIV2 0.7071067811865476
  const ComplexScalar eighth = (Dir == FFT_FORWARD)
                                   ? ComplexScalar(SQRT2DIV2, -SQRT2DIV2)
                                   : ComplexScalar(SQRT2DIV2, SQRT2DIV2);
  const ComplexScalar three_eighths =
      (Dir == FFT_FORWARD) ? ComplexScalar(-SQRT2DIV2, -SQRT2DIV2)
                           : ComplexScalar(-SQRT2DIV2, SQRT2DIV2);
  stage2[5] = (stage1[5] + stage1[7]) * eighth;
  stage2[6] = (stage1[4] - stage1[6]) * quarter;
  stage2[7] = (stage1[5] - stage1[7]) * three_eighths;

  // Stage 3: element k gets the sum and element k + 4 the difference.
  for (int k = 0; k < 4; ++k) {
    data[k] = stage2[k] + stage2[k + 4];
    data[k + 4] = stage2[k] - stage2[k + 4];
  }
}
// Merge step of compute_1D_Butterfly, which calls it after recursing on
// data[0, n/2) and data[n/2, n). Each element of the upper half is scaled
// by a running factor w, then element i receives the sum and element
// i + n/2 the difference.
//
// data          buffer of n values, updated in place.
// n             number of values; the loop consumes the lower half four at
//               a time, and the caller only reaches this path for n > 8.
// n_power_of_2  index into the two member lookup tables
//               (NOTE(review): presumably log2(n) - confirm at the caller).
template <int Dir>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void butterfly_1D_merge(
    ComplexScalar* data, int n, int n_power_of_2) {
  // The table lookups replace the formulas recorded in the original note:
  //   RealScalar wtemp = std::sin(M_PI/n);
  //   RealScalar wpi = -std::sin(2 * M_PI/n);
  const RealScalar step_re = m_sin_PI_div_n_LUT[n_power_of_2];
  const RealScalar step_im = (Dir == FFT_FORWARD)
                                 ? m_minus_sin_2_PI_div_n_LUT[n_power_of_2]
                                 : -m_minus_sin_2_PI_div_n_LUT[n_power_of_2];
  const ComplexScalar step(step_re, step_im);

  // Per-element multiplier (1 + step) and its powers up to the fourth, so
  // that four consecutive elements can be handled per iteration.
  const ComplexScalar rot1 = step + ComplexScalar(1, 0);
  const ComplexScalar rot2 = rot1 * rot1;
  const ComplexScalar rot3 = rot2 * rot1;
  const ComplexScalar rot4 = rot3 * rot1;

  const int half = n / 2;
  ComplexScalar w(1.0, 0.0);
  for (int i = 0; i < half; i += 4) {
    ComplexScalar* const lo = data + i;
    ComplexScalar* const hi = lo + half;

    // Scale the four upper-half values by w, w*rot1, w*rot2 and w*rot3.
    const ComplexScalar t0(hi[0] * w);
    const ComplexScalar t1(hi[1] * w * rot1);
    const ComplexScalar t2(hi[2] * w * rot2);
    const ComplexScalar t3(hi[3] * w * rot3);
    // Advance the running factor to the start of the next group of four.
    w = w * rot4;

    hi[0] = lo[0] - t0;
    lo[0] += t0;
    hi[1] = lo[1] - t1;
    lo[1] += t1;
    hi[2] = lo[2] - t2;
    lo[2] += t2;
    hi[3] = lo[3] - t3;
    lo[3] += t3;
  }
}
template <int Dir>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void compute_1D_Butterfly(
ComplexScalar* data, int n, int n_power_of_2) {
eigen_assert(isPowerOfTwo(n)); eigen_assert(isPowerOfTwo(n));
if (n == 1) { if (n > 8) {
return; compute_1D_Butterfly<Dir>(data, n / 2, n_power_of_2 - 1);
} compute_1D_Butterfly<Dir>(data + n / 2, n / 2, n_power_of_2 - 1);
else if (n == 2) { butterfly_1D_merge<Dir>(data, n, n_power_of_2);
ComplexScalar tmp = data[1]; } else if (n == 8) {
data[1] = data[0] - data[1]; butterfly_8<Dir>(data);
data[0] += tmp; } else if (n == 4) {
return; butterfly_4<Dir>(data);
} } else if (n == 2) {
else if (n == 4) { butterfly_2<Dir>(data);
ComplexScalar tmp[4];
tmp[0] = data[0] + data[1];
tmp[1] = data[0] - data[1];
tmp[2] = data[2] + data[3];
if(Dir == FFT_FORWARD) {
tmp[3] = ComplexScalar(0.0, -1.0) * (data[2] - data[3]);
}
else {
tmp[3] = ComplexScalar(0.0, 1.0) * (data[2] - data[3]);
}
data[0] = tmp[0] + tmp[2];
data[1] = tmp[1] + tmp[3];
data[2] = tmp[0] - tmp[2];
data[3] = tmp[1] - tmp[3];
return;
}
else if (n == 8) {
ComplexScalar tmp_1[8];
ComplexScalar tmp_2[8];
tmp_1[0] = data[0] + data[1];
tmp_1[1] = data[0] - data[1];
tmp_1[2] = data[2] + data[3];
if (Dir == FFT_FORWARD) {
tmp_1[3] = (data[2] - data[3]) * ComplexScalar(0, -1);
}
else {
tmp_1[3] = (data[2] - data[3]) * ComplexScalar(0, 1);
}
tmp_1[4] = data[4] + data[5];
tmp_1[5] = data[4] - data[5];
tmp_1[6] = data[6] + data[7];
if (Dir == FFT_FORWARD) {
tmp_1[7] = (data[6] - data[7]) * ComplexScalar(0, -1);
}
else {
tmp_1[7] = (data[6] - data[7]) * ComplexScalar(0, 1);
}
tmp_2[0] = tmp_1[0] + tmp_1[2];
tmp_2[1] = tmp_1[1] + tmp_1[3];
tmp_2[2] = tmp_1[0] - tmp_1[2];
tmp_2[3] = tmp_1[1] - tmp_1[3];
tmp_2[4] = tmp_1[4] + tmp_1[6];
// SQRT2DIV2 = sqrt(2)/2
#define SQRT2DIV2 0.7071067811865476
if (Dir == FFT_FORWARD) {
tmp_2[5] = (tmp_1[5] + tmp_1[7]) * ComplexScalar(SQRT2DIV2, -SQRT2DIV2);
tmp_2[6] = (tmp_1[4] - tmp_1[6]) * ComplexScalar(0, -1);
tmp_2[7] = (tmp_1[5] - tmp_1[7]) * ComplexScalar(-SQRT2DIV2, -SQRT2DIV2);
}
else {
tmp_2[5] = (tmp_1[5] + tmp_1[7]) * ComplexScalar(SQRT2DIV2, SQRT2DIV2);
tmp_2[6] = (tmp_1[4] - tmp_1[6]) * ComplexScalar(0, 1);
tmp_2[7] = (tmp_1[5] - tmp_1[7]) * ComplexScalar(-SQRT2DIV2, SQRT2DIV2);
}
data[0] = tmp_2[0] + tmp_2[4];
data[1] = tmp_2[1] + tmp_2[5];
data[2] = tmp_2[2] + tmp_2[6];
data[3] = tmp_2[3] + tmp_2[7];
data[4] = tmp_2[0] - tmp_2[4];
data[5] = tmp_2[1] - tmp_2[5];
data[6] = tmp_2[2] - tmp_2[6];
data[7] = tmp_2[3] - tmp_2[7];
return;
}
else {
compute_1D_Butterfly<Dir>(data, n/2, n_power_of_2 - 1);
compute_1D_Butterfly<Dir>(data + n/2, n/2, n_power_of_2 - 1);
//Original code:
//RealScalar wtemp = std::sin(M_PI/n);
//RealScalar wpi = -std::sin(2 * M_PI/n);
RealScalar wtemp = m_sin_PI_div_n_LUT[n_power_of_2];
RealScalar wpi;
if (Dir == FFT_FORWARD) {
wpi = m_minus_sin_2_PI_div_n_LUT[n_power_of_2];
}
else {
wpi = 0 - m_minus_sin_2_PI_div_n_LUT[n_power_of_2];
}
const ComplexScalar wp(wtemp, wpi);
ComplexScalar w(1.0, 0.0);
for(Index i = 0; i < n/2; i++) {
ComplexScalar temp(data[i + n/2] * w);
data[i + n/2] = data[i] - temp;
data[i] += temp;
w += w * wp;
}
return;
} }
} }