mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-01-30 17:40:05 +08:00
Speed up tensor FFT by up to ~25-50%.
Benchmark Base (ns) New (ns) Improvement ------------------------------------------------------------------ BM_tensor_fft_single_1D_cpu/8 132 134 -1.5% BM_tensor_fft_single_1D_cpu/9 1162 1229 -5.8% BM_tensor_fft_single_1D_cpu/16 199 195 +2.0% BM_tensor_fft_single_1D_cpu/17 2587 2267 +12.4% BM_tensor_fft_single_1D_cpu/32 373 341 +8.6% BM_tensor_fft_single_1D_cpu/33 5922 4879 +17.6% BM_tensor_fft_single_1D_cpu/64 797 675 +15.3% BM_tensor_fft_single_1D_cpu/65 13580 10481 +22.8% BM_tensor_fft_single_1D_cpu/128 1753 1375 +21.6% BM_tensor_fft_single_1D_cpu/129 31426 22789 +27.5% BM_tensor_fft_single_1D_cpu/256 4005 3008 +24.9% BM_tensor_fft_single_1D_cpu/257 70910 49549 +30.1% BM_tensor_fft_single_1D_cpu/512 8989 6524 +27.4% BM_tensor_fft_single_1D_cpu/513 165402 107751 +34.9% BM_tensor_fft_single_1D_cpu/999 198293 115909 +41.5% BM_tensor_fft_single_1D_cpu/1ki 21289 14143 +33.6% BM_tensor_fft_single_1D_cpu/1k 361980 233355 +35.5% BM_tensor_fft_double_1D_cpu/8 138 131 +5.1% BM_tensor_fft_double_1D_cpu/9 1253 1133 +9.6% BM_tensor_fft_double_1D_cpu/16 218 200 +8.3% BM_tensor_fft_double_1D_cpu/17 2770 2392 +13.6% BM_tensor_fft_double_1D_cpu/32 406 368 +9.4% BM_tensor_fft_double_1D_cpu/33 6418 5153 +19.7% BM_tensor_fft_double_1D_cpu/64 856 728 +15.0% BM_tensor_fft_double_1D_cpu/65 14666 11148 +24.0% BM_tensor_fft_double_1D_cpu/128 1913 1502 +21.5% BM_tensor_fft_double_1D_cpu/129 36414 24072 +33.9% BM_tensor_fft_double_1D_cpu/256 4226 3216 +23.9% BM_tensor_fft_double_1D_cpu/257 86638 52059 +39.9% BM_tensor_fft_double_1D_cpu/512 9397 6939 +26.2% BM_tensor_fft_double_1D_cpu/513 203208 114090 +43.9% BM_tensor_fft_double_1D_cpu/999 237841 125583 +47.2% BM_tensor_fft_double_1D_cpu/1ki 20921 15392 +26.4% BM_tensor_fft_double_1D_cpu/1k 455183 250763 +44.9% BM_tensor_fft_single_2D_cpu/8 1051 1005 +4.4% BM_tensor_fft_single_2D_cpu/9 16784 14837 +11.6% BM_tensor_fft_single_2D_cpu/16 4074 3772 +7.4% BM_tensor_fft_single_2D_cpu/17 75802 63884 +15.7% BM_tensor_fft_single_2D_cpu/32 20580 
16931 +17.7% BM_tensor_fft_single_2D_cpu/33 345798 278579 +19.4% BM_tensor_fft_single_2D_cpu/64 97548 81237 +16.7% BM_tensor_fft_single_2D_cpu/65 1592701 1227048 +23.0% BM_tensor_fft_single_2D_cpu/128 472318 384303 +18.6% BM_tensor_fft_single_2D_cpu/129 7038351 5445308 +22.6% BM_tensor_fft_single_2D_cpu/256 2309474 1850969 +19.9% BM_tensor_fft_single_2D_cpu/257 31849182 23797538 +25.3% BM_tensor_fft_single_2D_cpu/512 10395194 8077499 +22.3% BM_tensor_fft_single_2D_cpu/513 144053843 104242541 +27.6% BM_tensor_fft_single_2D_cpu/999 279885833 208389718 +25.5% BM_tensor_fft_single_2D_cpu/1ki 45967677 36070985 +21.5% BM_tensor_fft_single_2D_cpu/1k 619727095 456489500 +26.3% BM_tensor_fft_double_2D_cpu/8 1110 1016 +8.5% BM_tensor_fft_double_2D_cpu/9 17957 15768 +12.2% BM_tensor_fft_double_2D_cpu/16 4558 4000 +12.2% BM_tensor_fft_double_2D_cpu/17 79237 66901 +15.6% BM_tensor_fft_double_2D_cpu/32 21494 17699 +17.7% BM_tensor_fft_double_2D_cpu/33 357962 290357 +18.9% BM_tensor_fft_double_2D_cpu/64 105179 87435 +16.9% BM_tensor_fft_double_2D_cpu/65 1617143 1288006 +20.4% BM_tensor_fft_double_2D_cpu/128 512848 419397 +18.2% BM_tensor_fft_double_2D_cpu/129 7271322 5636884 +22.5% BM_tensor_fft_double_2D_cpu/256 2415529 1922032 +20.4% BM_tensor_fft_double_2D_cpu/257 32517952 24462177 +24.8% BM_tensor_fft_double_2D_cpu/512 10724898 8287617 +22.7% BM_tensor_fft_double_2D_cpu/513 146007419 108603266 +25.6% BM_tensor_fft_double_2D_cpu/999 296351330 221885776 +25.1% BM_tensor_fft_double_2D_cpu/1ki 59334166 48357539 +18.5% BM_tensor_fft_double_2D_cpu/1k 666660132 483840349 +27.4%
This commit is contained in:
parent
d90a2dac5e
commit
d5e2ec7447
@ -219,19 +219,56 @@ struct TensorEvaluator<const TensorFFTOp<FFT, ArgType, FFTResultType, FFTDir>, D
|
||||
ComplexScalar* b = is_power_of_two ? NULL : (ComplexScalar*)m_device.allocate(sizeof(ComplexScalar) * good_composite);
|
||||
ComplexScalar* pos_j_base_powered = is_power_of_two ? NULL : (ComplexScalar*)m_device.allocate(sizeof(ComplexScalar) * (line_len + 1));
|
||||
if (!is_power_of_two) {
|
||||
ComplexScalar pos_j_base = ComplexScalar(std::cos(M_PI/line_len), std::sin(M_PI/line_len));
|
||||
for (Index j = 0; j < line_len + 1; ++j) {
|
||||
pos_j_base_powered[j] = std::pow(pos_j_base, j * j);
|
||||
// Compute twiddle factors
|
||||
// t_n = exp(sqrt(-1) * pi * n^2 / line_len)
|
||||
// for n = 0, 1, ..., line_len.
|
||||
// For n >= 2 we use the recurrence t_n = t_{n-1}^2 / t_{n-2} * t_1^2
|
||||
pos_j_base_powered[0] = ComplexScalar(1, 0);
|
||||
if (line_len > 1) {
|
||||
const ComplexScalar pos_j_base = ComplexScalar(
|
||||
std::cos(M_PI / line_len), std::sin(M_PI / line_len));
|
||||
pos_j_base_powered[1] = pos_j_base;
|
||||
if (line_len > 2) {
|
||||
const ComplexScalar pos_j_base_sq = pos_j_base * pos_j_base;
|
||||
for (int i = 2; i < line_len + 1; ++i) {
|
||||
pos_j_base_powered[i] = pos_j_base_powered[i - 1] *
|
||||
pos_j_base_powered[i - 1] /
|
||||
pos_j_base_powered[i - 2] * pos_j_base_sq;
|
||||
}
|
||||
}
|
||||
}
|
||||
// Compute twiddle factors
|
||||
// t_n = exp(sqrt(-1) * pi * n^2 / line_len)
|
||||
// for n = 0, 1, ..., line_len.
|
||||
// For n >= 2 we use the recurrence t_n = t_{n-1}^2 / t_{n-2} * t_1^2
|
||||
pos_j_base_powered[0] = ComplexScalar(1, 0);
|
||||
if (line_len > 1) {
|
||||
const ComplexScalar pos_j_base = ComplexScalar(
|
||||
std::cos(M_PI / line_len), std::sin(M_PI / line_len));
|
||||
pos_j_base_powered[1] = pos_j_base;
|
||||
if (line_len > 2) {
|
||||
const ComplexScalar pos_j_base_sq = pos_j_base * pos_j_base;
|
||||
for (int i = 2; i < line_len + 1; ++i) {
|
||||
pos_j_base_powered[i] = pos_j_base_powered[i - 1] *
|
||||
pos_j_base_powered[i - 1] /
|
||||
pos_j_base_powered[i - 2] * pos_j_base_sq;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (Index partial_index = 0; partial_index < m_size / line_len; ++partial_index) {
|
||||
Index base_offset = getBaseOffsetFromIndex(partial_index, dim);
|
||||
const Index base_offset = getBaseOffsetFromIndex(partial_index, dim);
|
||||
|
||||
// get data into line_buf
|
||||
for (Index j = 0; j < line_len; ++j) {
|
||||
Index offset = getIndexFromOffset(base_offset, dim, j);
|
||||
line_buf[j] = buf[offset];
|
||||
const Index stride = m_strides[dim];
|
||||
if (stride == 1) {
|
||||
memcpy(line_buf, &buf[base_offset], line_len*sizeof(ComplexScalar));
|
||||
} else {
|
||||
Index offset = base_offset;
|
||||
for (int j = 0; j < line_len; ++j, offset += stride) {
|
||||
line_buf[j] = buf[offset];
|
||||
}
|
||||
}
|
||||
|
||||
// process the line
|
||||
@ -243,14 +280,18 @@ struct TensorEvaluator<const TensorFFTOp<FFT, ArgType, FFTResultType, FFTDir>, D
|
||||
}
|
||||
|
||||
// write back
|
||||
for (Index j = 0; j < line_len; ++j) {
|
||||
const ComplexScalar div_factor = (FFTDir == FFT_FORWARD) ? ComplexScalar(1, 0) : ComplexScalar(line_len, 0);
|
||||
Index offset = getIndexFromOffset(base_offset, dim, j);
|
||||
buf[offset] = line_buf[j] / div_factor;
|
||||
if (FFTDir == FFT_FORWARD && stride == 1) {
|
||||
memcpy(&buf[base_offset], line_buf, line_len*sizeof(ComplexScalar));
|
||||
} else {
|
||||
Index offset = base_offset;
|
||||
const ComplexScalar div_factor = ComplexScalar(1.0 / line_len, 0);
|
||||
for (int j = 0; j < line_len; ++j, offset += stride) {
|
||||
buf[offset] = (FFTDir == FFT_FORWARD) ? line_buf[j] : line_buf[j] * div_factor;
|
||||
}
|
||||
}
|
||||
}
|
||||
m_device.deallocate(line_buf);
|
||||
if (!pos_j_base_powered) {
|
||||
if (!is_power_of_two) {
|
||||
m_device.deallocate(a);
|
||||
m_device.deallocate(b);
|
||||
m_device.deallocate(pos_j_base_powered);
|
||||
@ -372,109 +413,130 @@ struct TensorEvaluator<const TensorFFTOp<FFT, ArgType, FFTResultType, FFTDir>, D
|
||||
}
|
||||
}
|
||||
|
||||
template<int Dir>
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void compute_1D_Butterfly(ComplexScalar* data, Index n, Index n_power_of_2) {
|
||||
template <int Dir>
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void butterfly_2(ComplexScalar* data) {
|
||||
ComplexScalar tmp = data[1];
|
||||
data[1] = data[0] - data[1];
|
||||
data[0] += tmp;
|
||||
}
|
||||
|
||||
template <int Dir>
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void butterfly_4(ComplexScalar* data) {
|
||||
ComplexScalar tmp[4];
|
||||
tmp[0] = data[0] + data[1];
|
||||
tmp[1] = data[0] - data[1];
|
||||
tmp[2] = data[2] + data[3];
|
||||
if (Dir == FFT_FORWARD) {
|
||||
tmp[3] = ComplexScalar(0.0, -1.0) * (data[2] - data[3]);
|
||||
} else {
|
||||
tmp[3] = ComplexScalar(0.0, 1.0) * (data[2] - data[3]);
|
||||
}
|
||||
data[0] = tmp[0] + tmp[2];
|
||||
data[1] = tmp[1] + tmp[3];
|
||||
data[2] = tmp[0] - tmp[2];
|
||||
data[3] = tmp[1] - tmp[3];
|
||||
}
|
||||
|
||||
template <int Dir>
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void butterfly_8(ComplexScalar* data) {
|
||||
ComplexScalar tmp_1[8];
|
||||
ComplexScalar tmp_2[8];
|
||||
|
||||
tmp_1[0] = data[0] + data[1];
|
||||
tmp_1[1] = data[0] - data[1];
|
||||
tmp_1[2] = data[2] + data[3];
|
||||
if (Dir == FFT_FORWARD) {
|
||||
tmp_1[3] = (data[2] - data[3]) * ComplexScalar(0, -1);
|
||||
} else {
|
||||
tmp_1[3] = (data[2] - data[3]) * ComplexScalar(0, 1);
|
||||
}
|
||||
tmp_1[4] = data[4] + data[5];
|
||||
tmp_1[5] = data[4] - data[5];
|
||||
tmp_1[6] = data[6] + data[7];
|
||||
if (Dir == FFT_FORWARD) {
|
||||
tmp_1[7] = (data[6] - data[7]) * ComplexScalar(0, -1);
|
||||
} else {
|
||||
tmp_1[7] = (data[6] - data[7]) * ComplexScalar(0, 1);
|
||||
}
|
||||
tmp_2[0] = tmp_1[0] + tmp_1[2];
|
||||
tmp_2[1] = tmp_1[1] + tmp_1[3];
|
||||
tmp_2[2] = tmp_1[0] - tmp_1[2];
|
||||
tmp_2[3] = tmp_1[1] - tmp_1[3];
|
||||
tmp_2[4] = tmp_1[4] + tmp_1[6];
|
||||
// SQRT2DIV2 = sqrt(2)/2
|
||||
#define SQRT2DIV2 0.7071067811865476
|
||||
if (Dir == FFT_FORWARD) {
|
||||
tmp_2[5] = (tmp_1[5] + tmp_1[7]) * ComplexScalar(SQRT2DIV2, -SQRT2DIV2);
|
||||
tmp_2[6] = (tmp_1[4] - tmp_1[6]) * ComplexScalar(0, -1);
|
||||
tmp_2[7] = (tmp_1[5] - tmp_1[7]) * ComplexScalar(-SQRT2DIV2, -SQRT2DIV2);
|
||||
} else {
|
||||
tmp_2[5] = (tmp_1[5] + tmp_1[7]) * ComplexScalar(SQRT2DIV2, SQRT2DIV2);
|
||||
tmp_2[6] = (tmp_1[4] - tmp_1[6]) * ComplexScalar(0, 1);
|
||||
tmp_2[7] = (tmp_1[5] - tmp_1[7]) * ComplexScalar(-SQRT2DIV2, SQRT2DIV2);
|
||||
}
|
||||
data[0] = tmp_2[0] + tmp_2[4];
|
||||
data[1] = tmp_2[1] + tmp_2[5];
|
||||
data[2] = tmp_2[2] + tmp_2[6];
|
||||
data[3] = tmp_2[3] + tmp_2[7];
|
||||
data[4] = tmp_2[0] - tmp_2[4];
|
||||
data[5] = tmp_2[1] - tmp_2[5];
|
||||
data[6] = tmp_2[2] - tmp_2[6];
|
||||
data[7] = tmp_2[3] - tmp_2[7];
|
||||
}
|
||||
|
||||
template <int Dir>
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void butterfly_1D_merge(
|
||||
ComplexScalar* data, int n, int n_power_of_2) {
|
||||
// Original code:
|
||||
// RealScalar wtemp = std::sin(M_PI/n);
|
||||
// RealScalar wpi = -std::sin(2 * M_PI/n);
|
||||
const RealScalar wtemp = m_sin_PI_div_n_LUT[n_power_of_2];
|
||||
const RealScalar wpi = (Dir == FFT_FORWARD)
|
||||
? m_minus_sin_2_PI_div_n_LUT[n_power_of_2]
|
||||
: -m_minus_sin_2_PI_div_n_LUT[n_power_of_2];
|
||||
|
||||
const ComplexScalar wp(wtemp, wpi);
|
||||
const ComplexScalar wp_one = wp + ComplexScalar(1, 0);
|
||||
const ComplexScalar wp_one_2 = wp_one * wp_one;
|
||||
const ComplexScalar wp_one_3 = wp_one_2 * wp_one;
|
||||
const ComplexScalar wp_one_4 = wp_one_3 * wp_one;
|
||||
const int n2 = n / 2;
|
||||
ComplexScalar w(1.0, 0.0);
|
||||
for (int i = 0; i < n2; i += 4) {
|
||||
ComplexScalar temp0(data[i + n2] * w);
|
||||
ComplexScalar temp1(data[i + 1 + n2] * w * wp_one);
|
||||
ComplexScalar temp2(data[i + 2 + n2] * w * wp_one_2);
|
||||
ComplexScalar temp3(data[i + 3 + n2] * w * wp_one_3);
|
||||
w = w * wp_one_4;
|
||||
|
||||
data[i + n2] = data[i] - temp0;
|
||||
data[i] += temp0;
|
||||
|
||||
data[i + 1 + n2] = data[i + 1] - temp1;
|
||||
data[i + 1] += temp1;
|
||||
|
||||
data[i + 2 + n2] = data[i + 2] - temp2;
|
||||
data[i + 2] += temp2;
|
||||
|
||||
data[i + 3 + n2] = data[i + 3] - temp3;
|
||||
data[i + 3] += temp3;
|
||||
}
|
||||
}
|
||||
|
||||
template <int Dir>
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void compute_1D_Butterfly(
|
||||
ComplexScalar* data, int n, int n_power_of_2) {
|
||||
eigen_assert(isPowerOfTwo(n));
|
||||
if (n == 1) {
|
||||
return;
|
||||
}
|
||||
else if (n == 2) {
|
||||
ComplexScalar tmp = data[1];
|
||||
data[1] = data[0] - data[1];
|
||||
data[0] += tmp;
|
||||
return;
|
||||
}
|
||||
else if (n == 4) {
|
||||
ComplexScalar tmp[4];
|
||||
tmp[0] = data[0] + data[1];
|
||||
tmp[1] = data[0] - data[1];
|
||||
tmp[2] = data[2] + data[3];
|
||||
if(Dir == FFT_FORWARD) {
|
||||
tmp[3] = ComplexScalar(0.0, -1.0) * (data[2] - data[3]);
|
||||
}
|
||||
else {
|
||||
tmp[3] = ComplexScalar(0.0, 1.0) * (data[2] - data[3]);
|
||||
}
|
||||
data[0] = tmp[0] + tmp[2];
|
||||
data[1] = tmp[1] + tmp[3];
|
||||
data[2] = tmp[0] - tmp[2];
|
||||
data[3] = tmp[1] - tmp[3];
|
||||
return;
|
||||
}
|
||||
else if (n == 8) {
|
||||
ComplexScalar tmp_1[8];
|
||||
ComplexScalar tmp_2[8];
|
||||
|
||||
tmp_1[0] = data[0] + data[1];
|
||||
tmp_1[1] = data[0] - data[1];
|
||||
tmp_1[2] = data[2] + data[3];
|
||||
if (Dir == FFT_FORWARD) {
|
||||
tmp_1[3] = (data[2] - data[3]) * ComplexScalar(0, -1);
|
||||
}
|
||||
else {
|
||||
tmp_1[3] = (data[2] - data[3]) * ComplexScalar(0, 1);
|
||||
}
|
||||
tmp_1[4] = data[4] + data[5];
|
||||
tmp_1[5] = data[4] - data[5];
|
||||
tmp_1[6] = data[6] + data[7];
|
||||
if (Dir == FFT_FORWARD) {
|
||||
tmp_1[7] = (data[6] - data[7]) * ComplexScalar(0, -1);
|
||||
}
|
||||
else {
|
||||
tmp_1[7] = (data[6] - data[7]) * ComplexScalar(0, 1);
|
||||
}
|
||||
tmp_2[0] = tmp_1[0] + tmp_1[2];
|
||||
tmp_2[1] = tmp_1[1] + tmp_1[3];
|
||||
tmp_2[2] = tmp_1[0] - tmp_1[2];
|
||||
tmp_2[3] = tmp_1[1] - tmp_1[3];
|
||||
tmp_2[4] = tmp_1[4] + tmp_1[6];
|
||||
// SQRT2DIV2 = sqrt(2)/2
|
||||
#define SQRT2DIV2 0.7071067811865476
|
||||
if (Dir == FFT_FORWARD) {
|
||||
tmp_2[5] = (tmp_1[5] + tmp_1[7]) * ComplexScalar(SQRT2DIV2, -SQRT2DIV2);
|
||||
tmp_2[6] = (tmp_1[4] - tmp_1[6]) * ComplexScalar(0, -1);
|
||||
tmp_2[7] = (tmp_1[5] - tmp_1[7]) * ComplexScalar(-SQRT2DIV2, -SQRT2DIV2);
|
||||
}
|
||||
else {
|
||||
tmp_2[5] = (tmp_1[5] + tmp_1[7]) * ComplexScalar(SQRT2DIV2, SQRT2DIV2);
|
||||
tmp_2[6] = (tmp_1[4] - tmp_1[6]) * ComplexScalar(0, 1);
|
||||
tmp_2[7] = (tmp_1[5] - tmp_1[7]) * ComplexScalar(-SQRT2DIV2, SQRT2DIV2);
|
||||
}
|
||||
data[0] = tmp_2[0] + tmp_2[4];
|
||||
data[1] = tmp_2[1] + tmp_2[5];
|
||||
data[2] = tmp_2[2] + tmp_2[6];
|
||||
data[3] = tmp_2[3] + tmp_2[7];
|
||||
data[4] = tmp_2[0] - tmp_2[4];
|
||||
data[5] = tmp_2[1] - tmp_2[5];
|
||||
data[6] = tmp_2[2] - tmp_2[6];
|
||||
data[7] = tmp_2[3] - tmp_2[7];
|
||||
|
||||
return;
|
||||
}
|
||||
else {
|
||||
compute_1D_Butterfly<Dir>(data, n/2, n_power_of_2 - 1);
|
||||
compute_1D_Butterfly<Dir>(data + n/2, n/2, n_power_of_2 - 1);
|
||||
//Original code:
|
||||
//RealScalar wtemp = std::sin(M_PI/n);
|
||||
//RealScalar wpi = -std::sin(2 * M_PI/n);
|
||||
RealScalar wtemp = m_sin_PI_div_n_LUT[n_power_of_2];
|
||||
RealScalar wpi;
|
||||
if (Dir == FFT_FORWARD) {
|
||||
wpi = m_minus_sin_2_PI_div_n_LUT[n_power_of_2];
|
||||
}
|
||||
else {
|
||||
wpi = 0 - m_minus_sin_2_PI_div_n_LUT[n_power_of_2];
|
||||
}
|
||||
|
||||
const ComplexScalar wp(wtemp, wpi);
|
||||
ComplexScalar w(1.0, 0.0);
|
||||
for(Index i = 0; i < n/2; i++) {
|
||||
ComplexScalar temp(data[i + n/2] * w);
|
||||
data[i + n/2] = data[i] - temp;
|
||||
data[i] += temp;
|
||||
w += w * wp;
|
||||
}
|
||||
return;
|
||||
if (n > 8) {
|
||||
compute_1D_Butterfly<Dir>(data, n / 2, n_power_of_2 - 1);
|
||||
compute_1D_Butterfly<Dir>(data + n / 2, n / 2, n_power_of_2 - 1);
|
||||
butterfly_1D_merge<Dir>(data, n, n_power_of_2);
|
||||
} else if (n == 8) {
|
||||
butterfly_8<Dir>(data);
|
||||
} else if (n == 4) {
|
||||
butterfly_4<Dir>(data);
|
||||
} else if (n == 2) {
|
||||
butterfly_2<Dir>(data);
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user