Compute Library
 21.11
NEFFTRadixStageKernel.cpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2019-2021 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
25 
28 #include "arm_compute/core/Types.h"
29 #include "arm_compute/core/Utils.h"
36 
37 #include <arm_neon.h>
38 #include <cmath>
39 #include <complex>
40 #include <map>
41 
42 namespace arm_compute
43 {
44 namespace
45 {
// Pi. Defined as an exact literal because M_PI is POSIX, not ISO C++;
// the literal rounds to the same float value as float(M_PI).
constexpr float kPi = 3.14159265358979323846f;

// sqrt(3) / 2, used in the fft_3 kernel
constexpr float kSqrt3Div2 = 0.866025403784438f;

// Constants used in the fft_5 kernel: cos/sin magnitudes of the 5th roots of unity
constexpr float kW5_0 = 0.30901699437494f; // cos(2*pi/5)
constexpr float kW5_1 = 0.95105651629515f; // sin(2*pi/5)
constexpr float kW5_2 = 0.80901699437494f; // |cos(4*pi/5)|
constexpr float kW5_3 = 0.58778525229247f; // sin(4*pi/5)

// Constants used in the fft_7 kernel: cos/sin magnitudes of the 7th roots of unity
constexpr float kW7_0 = 0.62348980185873f; // cos(2*pi/7)
constexpr float kW7_1 = 0.78183148246802f; // sin(2*pi/7)
constexpr float kW7_2 = 0.22252093395631f; // |cos(4*pi/7)|
constexpr float kW7_3 = 0.97492791218182f; // sin(4*pi/7)
constexpr float kW7_4 = 0.90096886790241f; // |cos(6*pi/7)|
constexpr float kW7_5 = 0.43388373911755f; // sin(6*pi/7)

// sqrt(2) / 2, used in the fft_8 kernel
constexpr float kSqrt2Div2 = 0.707106781186548f;
68 
69 float32x2_t c_mul_neon(float32x2_t a, float32x2_t b)
70 {
71  using ExactTagType = typename wrapper::traits::neon_vector<float, 2>::tag_type;
72 
73  const float32x2_t mask = { -1.0, 1.0 };
74  const float32x2_t tmp0 = wrapper::vdup_n(wrapper::vgetlane(a, 0), ExactTagType{});
75  const float32x2_t tmp1 = wrapper::vdup_n(wrapper::vgetlane(a, 1), ExactTagType{});
76 
77  float32x2_t res = wrapper::vmul(tmp0, b);
78 
79  b = wrapper::vrev64(b);
80  b = wrapper::vmul(b, mask);
81  res = wrapper::vmla(res, tmp1, b);
82 
83  return res;
84 }
85 
86 float32x2_t c_mul_neon_img(float32x2_t a, float img_constant)
87 {
88  const float a_r = wrapper::vgetlane(a, 0);
89  const float a_i = wrapper::vgetlane(a, 1);
90 
91  const auto out = wrapper::vmul(float32x2_t{ -a_i, a_r }, float32x2_t{ img_constant, img_constant });
92  return out;
93 }
94 
95 float32x2_t reduce_sum_5(float32x2_t a, float32x2_t b, float32x2_t c, float32x2_t d, float32x2_t e)
96 {
97  const auto t0 = wrapper::vadd(a, b);
98  const auto t1 = wrapper::vadd(c, d);
99  const auto t2 = wrapper::vadd(t0, t1);
100  return wrapper::vadd(t2, e);
101 }
102 
103 float32x2_t reduce_sum_7(float32x2_t x1, float32x2_t x2, float32x2_t x3, float32x2_t x4, float32x2_t x5, float32x2_t x6, float32x2_t x7)
104 {
105  const auto t0 = wrapper::vadd(x1, x2);
106  const auto t1 = wrapper::vadd(x3, x4);
107  const auto t2 = wrapper::vadd(x5, x6);
108  const auto t00 = wrapper::vadd(t0, t1);
109  const auto t01 = wrapper::vadd(t2, x7);
110 
111  return wrapper::vadd(t00, t01);
112 }
113 
114 float32x2_t reduce_sum_8(float32x2_t x1, float32x2_t x2, float32x2_t x3, float32x2_t x4, float32x2_t x5, float32x2_t x6, float32x2_t x7, float32x2_t x8)
115 {
116  const auto t0 = wrapper::vadd(x1, x2);
117  const auto t1 = wrapper::vadd(x3, x4);
118  const auto t2 = wrapper::vadd(x5, x6);
119  const auto t3 = wrapper::vadd(x7, x8);
120  const auto t00 = wrapper::vadd(t0, t1);
121  const auto t01 = wrapper::vadd(t2, t3);
122 
123  return wrapper::vadd(t00, t01);
124 }
125 
126 void fft_2(float32x2_t &x, float32x2_t &y, float32x2_t &w)
127 {
128  float32x2_t a = x;
129  float32x2_t b = c_mul_neon(w, y);
130 
131  x = wrapper::vadd(a, b);
132  y = wrapper::vsub(a, b);
133 }
134 
135 void fft_3(float32x2_t &x, float32x2_t &y, float32x2_t &z, const float32x2_t &w, const float32x2_t &w2)
136 {
137  float32x2_t a = x;
138  float32x2_t b = c_mul_neon(w, y);
139  float32x2_t c = c_mul_neon(w2, z);
140 
141  x = wrapper::vadd(a, b);
142  x = wrapper::vadd(x, c);
143 
144  const auto v1 = wrapper::vmul(float32x2_t{ 0.5f, 0.5 }, wrapper::vadd(b, c));
145  const auto v2 = c_mul_neon(float32x2_t{ 0.f, -kSqrt3Div2 }, wrapper::vsub(b, c));
146 
147  y = z = wrapper::vsub(a, v1);
148  y = wrapper::vadd(y, v2);
149  z = wrapper::vsub(z, v2);
150 }
151 
152 void fft_4(float32x2_t &x1, float32x2_t &x2, float32x2_t &x3, float32x2_t &x4, const float32x2_t &w, const float32x2_t &w2, const float32x2_t &w3)
153 {
154  float32x2_t a = x1;
155  float32x2_t b = c_mul_neon(w, x2);
156  float32x2_t c = c_mul_neon(w2, x3);
157  float32x2_t d = c_mul_neon(w3, x4);
158 
159  const auto x11 = wrapper::vadd(a, b);
160  const auto x12 = wrapper::vadd(c, d);
161  x1 = wrapper::vadd(x11, x12);
162 
163  const auto x21 = wrapper::vadd(a, c_mul_neon_img(b, -1));
164  const auto x22 = wrapper::vadd(wrapper::vneg(c), c_mul_neon_img(d, 1.f));
165  x2 = wrapper::vadd(x21, x22);
166 
167  const auto x31 = wrapper::vadd(a, wrapper::vneg(b));
168  const auto x32 = wrapper::vadd(c, wrapper::vneg(d));
169  x3 = wrapper::vadd(x31, x32);
170 
171  const auto x41 = wrapper::vadd(a, c_mul_neon_img(b, 1));
172  const auto x42 = wrapper::vadd(wrapper::vneg(c), c_mul_neon_img(d, -1));
173  x4 = wrapper::vadd(x41, x42);
174 }
175 
// In-place radix-5 butterfly on five packed (real, imag) complex values.
// w..w4 are the stage twiddle factors. The hard-coded {+-kW5_*, +-kW5_*}
// pairs are the fixed 5th roots of unity e^(-2*pi*i*k/5): output bin m sums
// each twiddled input rotated by the root for (input index * m) mod 5.
void fft_5(float32x2_t &x1, float32x2_t &x2, float32x2_t &x3, float32x2_t &x4, float32x2_t &x5, const float32x2_t &w, const float32x2_t &w2, const float32x2_t &w3, const float32x2_t &w4)
{
    // Apply the stage twiddle factors.
    const auto a = x1;
    const auto b = c_mul_neon(w, x2);
    const auto c = c_mul_neon(w2, x3);
    const auto d = c_mul_neon(w3, x4);
    const auto e = c_mul_neon(w4, x5);

    // bK/cK/dK/eK: input rotated by the 5th root of unity used for bin K+1.
    const auto b0 = c_mul_neon(float32x2_t{ kW5_0, -kW5_1 }, b);
    const auto b1 = c_mul_neon(float32x2_t{ -kW5_2, -kW5_3 }, b);
    const auto b2 = c_mul_neon(float32x2_t{ -kW5_2, kW5_3 }, b);
    const auto b3 = c_mul_neon(float32x2_t{ kW5_0, kW5_1 }, b);

    const auto c0 = c_mul_neon(float32x2_t{ -kW5_2, -kW5_3 }, c);
    const auto c1 = c_mul_neon(float32x2_t{ kW5_0, kW5_1 }, c);
    const auto c2 = c_mul_neon(float32x2_t{ kW5_0, -kW5_1 }, c);
    const auto c3 = c_mul_neon(float32x2_t{ -kW5_2, kW5_3 }, c);

    const auto d0 = c_mul_neon(float32x2_t{ -kW5_2, kW5_3 }, d);
    const auto d1 = c_mul_neon(float32x2_t{ kW5_0, -kW5_1 }, d);
    const auto d2 = c_mul_neon(float32x2_t{ kW5_0, kW5_1 }, d);
    const auto d3 = c_mul_neon(float32x2_t{ -kW5_2, -kW5_3 }, d);

    const auto e0 = c_mul_neon(float32x2_t{ kW5_0, kW5_1 }, e);
    const auto e1 = c_mul_neon(float32x2_t{ -kW5_2, kW5_3 }, e);
    const auto e2 = c_mul_neon(float32x2_t{ -kW5_2, -kW5_3 }, e);
    const auto e3 = c_mul_neon(float32x2_t{ kW5_0, -kW5_1 }, e);

    // Combine into the five output bins (bin 0 is the unrotated sum).
    x1 = reduce_sum_5(a, b, c, d, e);
    x2 = reduce_sum_5(a, b0, c0, d0, e0);
    x3 = reduce_sum_5(a, b1, c1, d1, e1);
    x4 = reduce_sum_5(a, b2, c2, d2, e2);
    x5 = reduce_sum_5(a, b3, c3, d3, e3);
}
210 
// In-place radix-7 butterfly on seven packed (real, imag) complex values.
// w..w6 are the stage twiddle factors. The hard-coded {+-kW7_*, +-kW7_*}
// pairs are the fixed 7th roots of unity e^(-2*pi*i*k/7): output bin m sums
// each twiddled input rotated by the root for (input index * m) mod 7.
void fft_7(float32x2_t &x1, float32x2_t &x2, float32x2_t &x3, float32x2_t &x4, float32x2_t &x5, float32x2_t &x6, float32x2_t &x7, const float32x2_t &w, const float32x2_t &w2, const float32x2_t &w3,
           const float32x2_t &w4,
           const float32x2_t &w5, const float32x2_t &w6)
{
    // Apply the stage twiddle factors.
    const auto a = x1;
    const auto b = c_mul_neon(w, x2);
    const auto c = c_mul_neon(w2, x3);
    const auto d = c_mul_neon(w3, x4);
    const auto e = c_mul_neon(w4, x5);
    const auto f = c_mul_neon(w5, x6);
    const auto g = c_mul_neon(w6, x7);

    // bK/cK/.../gK: input rotated by the 7th root of unity used for bin K+1.
    const auto b0 = c_mul_neon(float32x2_t{ kW7_0, -kW7_1 }, b);
    const auto b1 = c_mul_neon(float32x2_t{ -kW7_2, -kW7_3 }, b);
    const auto b2 = c_mul_neon(float32x2_t{ -kW7_4, -kW7_5 }, b);
    const auto b3 = c_mul_neon(float32x2_t{ -kW7_4, kW7_5 }, b);
    const auto b4 = c_mul_neon(float32x2_t{ -kW7_2, kW7_3 }, b);
    const auto b5 = c_mul_neon(float32x2_t{ kW7_0, kW7_1 }, b);

    const auto c0 = c_mul_neon(float32x2_t{ -kW7_2, -kW7_3 }, c);
    const auto c1 = c_mul_neon(float32x2_t{ -kW7_4, kW7_5 }, c);
    const auto c2 = c_mul_neon(float32x2_t{ kW7_0, kW7_1 }, c);
    const auto c3 = c_mul_neon(float32x2_t{ kW7_0, -kW7_1 }, c);
    const auto c4 = c_mul_neon(float32x2_t{ -kW7_4, -kW7_5 }, c);
    const auto c5 = c_mul_neon(float32x2_t{ -kW7_2, kW7_3 }, c);

    const auto d0 = c_mul_neon(float32x2_t{ -kW7_4, -kW7_5 }, d);
    const auto d1 = c_mul_neon(float32x2_t{ kW7_0, kW7_1 }, d);
    const auto d2 = c_mul_neon(float32x2_t{ -kW7_2, -kW7_3 }, d);
    const auto d3 = c_mul_neon(float32x2_t{ -kW7_2, +kW7_3 }, d);
    const auto d4 = c_mul_neon(float32x2_t{ kW7_0, -kW7_1 }, d);
    const auto d5 = c_mul_neon(float32x2_t{ -kW7_4, kW7_5 }, d);

    const auto e0 = c_mul_neon(float32x2_t{ -kW7_4, kW7_5 }, e);
    const auto e1 = c_mul_neon(float32x2_t{ kW7_0, -kW7_1 }, e);
    const auto e2 = c_mul_neon(float32x2_t{ -kW7_2, kW7_3 }, e);
    const auto e3 = c_mul_neon(float32x2_t{ -kW7_2, -kW7_3 }, e);
    const auto e4 = c_mul_neon(float32x2_t{ kW7_0, kW7_1 }, e);
    const auto e5 = c_mul_neon(float32x2_t{ -kW7_4, -kW7_5 }, e);

    const auto f0 = c_mul_neon(float32x2_t{ -kW7_2, kW7_3 }, f);
    const auto f1 = c_mul_neon(float32x2_t{ -kW7_4, -kW7_5 }, f);
    const auto f2 = c_mul_neon(float32x2_t{ kW7_0, -kW7_1 }, f);
    const auto f3 = c_mul_neon(float32x2_t{ kW7_0, kW7_1 }, f);
    const auto f4 = c_mul_neon(float32x2_t{ -kW7_4, kW7_5 }, f);
    const auto f5 = c_mul_neon(float32x2_t{ -kW7_2, -kW7_3 }, f);

    const auto g0 = c_mul_neon(float32x2_t{ kW7_0, kW7_1 }, g);
    const auto g1 = c_mul_neon(float32x2_t{ -kW7_2, kW7_3 }, g);
    const auto g2 = c_mul_neon(float32x2_t{ -kW7_4, kW7_5 }, g);
    const auto g3 = c_mul_neon(float32x2_t{ -kW7_4, -kW7_5 }, g);
    const auto g4 = c_mul_neon(float32x2_t{ -kW7_2, -kW7_3 }, g);
    const auto g5 = c_mul_neon(float32x2_t{ kW7_0, -kW7_1 }, g);

    // Combine into the seven output bins (bin 0 is the unrotated sum).
    x1 = reduce_sum_7(a, b, c, d, e, f, g);
    x2 = reduce_sum_7(a, b0, c0, d0, e0, f0, g0);
    x3 = reduce_sum_7(a, b1, c1, d1, e1, f1, g1);
    x4 = reduce_sum_7(a, b2, c2, d2, e2, f2, g2);
    x5 = reduce_sum_7(a, b3, c3, d3, e3, f3, g3);
    x6 = reduce_sum_7(a, b4, c4, d4, e4, f4, g4);
    x7 = reduce_sum_7(a, b5, c5, d5, e5, f5, g5);
}
273 
// In-place radix-8 butterfly on eight packed (real, imag) complex values.
// w..w7 are the stage twiddle factors. The hard-coded coefficient pairs are
// the fixed 8th roots of unity e^(-2*pi*i*k/8), i.e. combinations of
// {0, +-1, +-sqrt(2)/2}: output bin m sums each twiddled input rotated by the
// root for (input index * m) mod 8.
void fft_8(float32x2_t &x1, float32x2_t &x2, float32x2_t &x3, float32x2_t &x4, float32x2_t &x5, float32x2_t &x6, float32x2_t &x7, float32x2_t &x8, const float32x2_t &w, const float32x2_t &w2,
           const float32x2_t &w3,
           const float32x2_t &w4, const float32x2_t &w5, const float32x2_t &w6,
           const float32x2_t &w7)
{
    // Apply the stage twiddle factors.
    const auto a = x1;
    const auto b = c_mul_neon(w, x2);
    const auto c = c_mul_neon(w2, x3);
    const auto d = c_mul_neon(w3, x4);
    const auto e = c_mul_neon(w4, x5);
    const auto f = c_mul_neon(w5, x6);
    const auto g = c_mul_neon(w6, x7);
    const auto h = c_mul_neon(w7, x8);

    // bK/cK/.../hK: input rotated by the 8th root of unity used for bin K+1.
    const auto b0 = c_mul_neon(float32x2_t{ kSqrt2Div2, -kSqrt2Div2 }, b);
    const auto b1 = c_mul_neon(float32x2_t{ 0, -1 }, b);
    const auto b2 = c_mul_neon(float32x2_t{ -kSqrt2Div2, -kSqrt2Div2 }, b);
    const auto b3 = c_mul_neon(float32x2_t{ -1, 0 }, b);
    const auto b4 = c_mul_neon(float32x2_t{ -kSqrt2Div2, kSqrt2Div2 }, b);
    const auto b5 = c_mul_neon(float32x2_t{ 0, 1 }, b);
    const auto b6 = c_mul_neon(float32x2_t{ kSqrt2Div2, kSqrt2Div2 }, b);

    const auto c0 = c_mul_neon(float32x2_t{ 0, -1 }, c);
    const auto c1 = c_mul_neon(float32x2_t{ -1, 0 }, c);
    const auto c2 = c_mul_neon(float32x2_t{ 0, 1 }, c);
    const auto c3 = c_mul_neon(float32x2_t{ 1, 0 }, c);
    const auto c4 = c_mul_neon(float32x2_t{ 0, -1 }, c);
    const auto c5 = c_mul_neon(float32x2_t{ -1, 0 }, c);
    const auto c6 = c_mul_neon(float32x2_t{ 0, 1 }, c);

    const auto d0 = c_mul_neon(float32x2_t{ -kSqrt2Div2, -kSqrt2Div2 }, d);
    const auto d1 = c_mul_neon(float32x2_t{ 0, 1 }, d);
    const auto d2 = c_mul_neon(float32x2_t{ kSqrt2Div2, -kSqrt2Div2 }, d);
    const auto d3 = c_mul_neon(float32x2_t{ -1, 0 }, d);
    const auto d4 = c_mul_neon(float32x2_t{ kSqrt2Div2, kSqrt2Div2 }, d);
    const auto d5 = c_mul_neon(float32x2_t{ 0, -1 }, d);
    const auto d6 = c_mul_neon(float32x2_t{ -kSqrt2Div2, kSqrt2Div2 }, d);

    const auto e0 = c_mul_neon(float32x2_t{ -1, 0 }, e);
    const auto e1 = c_mul_neon(float32x2_t{ 1, 0 }, e);
    const auto e2 = c_mul_neon(float32x2_t{ -1, 0 }, e);
    const auto e3 = c_mul_neon(float32x2_t{ 1, 0 }, e);
    const auto e4 = c_mul_neon(float32x2_t{ -1, 0 }, e);
    const auto e5 = c_mul_neon(float32x2_t{ 1, 0 }, e);
    const auto e6 = c_mul_neon(float32x2_t{ -1, 0 }, e);

    const auto f0 = c_mul_neon(float32x2_t{ -kSqrt2Div2, kSqrt2Div2 }, f);
    const auto f1 = c_mul_neon(float32x2_t{ 0, -1 }, f);
    const auto f2 = c_mul_neon(float32x2_t{ kSqrt2Div2, kSqrt2Div2 }, f);
    const auto f3 = c_mul_neon(float32x2_t{ -1, 0 }, f);
    const auto f4 = c_mul_neon(float32x2_t{ kSqrt2Div2, -kSqrt2Div2 }, f);
    const auto f5 = c_mul_neon(float32x2_t{ 0, 1 }, f);
    const auto f6 = c_mul_neon(float32x2_t{ -kSqrt2Div2, -kSqrt2Div2 }, f);

    const auto g0 = c_mul_neon(float32x2_t{ 0, 1 }, g);
    const auto g1 = c_mul_neon(float32x2_t{ -1, 0 }, g);
    const auto g2 = c_mul_neon(float32x2_t{ 0, -1 }, g);
    const auto g3 = c_mul_neon(float32x2_t{ 1, 0 }, g);
    const auto g4 = c_mul_neon(float32x2_t{ 0, 1 }, g);
    const auto g5 = c_mul_neon(float32x2_t{ -1, 0 }, g);
    const auto g6 = c_mul_neon(float32x2_t{ 0, -1 }, g);

    const auto h0 = c_mul_neon(float32x2_t{ kSqrt2Div2, kSqrt2Div2 }, h);
    const auto h1 = c_mul_neon(float32x2_t{ 0, 1 }, h);
    const auto h2 = c_mul_neon(float32x2_t{ -kSqrt2Div2, kSqrt2Div2 }, h);
    const auto h3 = c_mul_neon(float32x2_t{ -1, 0 }, h);
    const auto h4 = c_mul_neon(float32x2_t{ -kSqrt2Div2, -kSqrt2Div2 }, h);
    const auto h5 = c_mul_neon(float32x2_t{ 0, -1 }, h);
    const auto h6 = c_mul_neon(float32x2_t{ kSqrt2Div2, -kSqrt2Div2 }, h);

    // Combine into the eight output bins (bin 0 is the unrotated sum).
    x1 = reduce_sum_8(a, b, c, d, e, f, g, h);
    x2 = reduce_sum_8(a, b0, c0, d0, e0, f0, g0, h0);
    x3 = reduce_sum_8(a, b1, c1, d1, e1, f1, g1, h1);
    x4 = reduce_sum_8(a, b2, c2, d2, e2, f2, g2, h2);
    x5 = reduce_sum_8(a, b3, c3, d3, e3, f3, g3, h3);
    x6 = reduce_sum_8(a, b4, c4, d4, e4, f4, g4, h4);
    x7 = reduce_sum_8(a, b5, c5, d5, e5, f5, g5, h5);
    x8 = reduce_sum_8(a, b6, c6, d6, e6, f6, g6, h6);
}
353 
// Radix-2 FFT stage along axis 0 (contiguous interleaved-complex data,
// 2 floats per element).
//   Nx:      butterfly span of this stage, in complex elements.
//   NxRadix: Nx * 2, the period of one butterfly group.
//   w_m:     per-group twiddle increment; w starts at 1 + 0i.
//   N:       total number of complex elements along the axis.
// When first_stage is true the two butterfly inputs are adjacent (Nx == 1 at
// the first stage — TODO confirm against caller), so both are moved with a
// single 128-bit load/store.
template <bool first_stage>
void fft_radix_2_axes_0(float *out, float *in, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N)
{
    float32x2_t w{ 1.0f, 0.0f }; // current twiddle factor
    for(unsigned int j = 0; j < Nx; j++)
    {
        // k walks the start offsets (in floats) of the butterflies sharing w.
        for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix)
        {
            auto a = float32x2_t{ 0, 0 };
            auto b = float32x2_t{ 0, 0 };

            // Load inputs
            if(first_stage)
            {
                const auto ab = wrapper::vloadq(in + k);
                a = wrapper::vgetlow(ab);
                b = wrapper::vgethigh(ab);
            }
            else
            {
                a = wrapper::vload(in + k);
                b = wrapper::vload(in + k + 2 * Nx);
            }

            // Base-case prime transform
            fft_2(a, b, w);

            // Write outputs
            if(first_stage)
            {
                wrapper::vstore(out + k, wrapper::vcombine(a, b));
            }
            else
            {
                wrapper::vstore(out + k, a);
                wrapper::vstore(out + k + 2 * Nx, b);
            }
        }

        // Advance the twiddle factor for the next group.
        w = c_mul_neon(w, w_m);
    }
}
396 
// Radix-2 FFT stage along axis 1: elements of one transform are separated by
// whole rows, so each access strides by (N + pad) floats. M is the number of
// complex elements along the transformed axis; in_pad_x/out_pad_x are the
// per-row padding of the input/output buffers, in floats.
void fft_radix_2_axes_1(float *out, float *in, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N, unsigned int M, unsigned int in_pad_x, unsigned int out_pad_x)
{
    float32x2_t w{ 1.0f, 0.0f }; // current twiddle factor
    for(unsigned int j = 0; j < Nx; j++)
    {
        for(unsigned int k = 2 * j; k < 2 * M; k += 2 * NxRadix)
        {
            // Load inputs
            float32x2_t a = wrapper::vload(in + (N + in_pad_x) * k);
            float32x2_t b = wrapper::vload(in + (N + in_pad_x) * (k + 2 * Nx));

            // Base-case prime transform
            fft_2(a, b, w);

            // Write outputs
            wrapper::vstore(out + (N + out_pad_x) * k, a);
            wrapper::vstore(out + (N + out_pad_x) * (k + 2 * Nx), b);
        }

        // Advance the twiddle factor for the next group.
        w = c_mul_neon(w, w_m);
    }
}
419 
// Radix-3 FFT stage along axis 0 (contiguous interleaved-complex data).
// See fft_radix_2_axes_0 for the meaning of Nx/NxRadix/w_m/N. In the
// first_stage path only the first two inputs are adjacent enough for a
// 128-bit load; the third is always loaded separately.
template <bool first_stage>
void fft_radix_3_axes_0(float *out, float *in, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N)
{
    float32x2_t w{ 1.0f, 0.0f }; // current twiddle factor
    for(unsigned int j = 0; j < Nx; j++)
    {
        // Higher powers of the twiddle factor for the extra inputs.
        const auto w2 = c_mul_neon(w, w);

        for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix)
        {
            // Load inputs
            float32x2_t a = { 0, 0 };
            float32x2_t b = { 0, 0 };
            float32x2_t c = { 0, 0 };
            if(first_stage)
            {
                const auto ab = wrapper::vloadq(in + k);
                a = wrapper::vgetlow(ab);
                b = wrapper::vgethigh(ab);
            }
            else
            {
                a = wrapper::vload(in + k);
                b = wrapper::vload(in + k + 2 * Nx);
            }
            c = wrapper::vload(in + k + 4 * Nx);

            // Base-case prime transform
            fft_3(a, b, c, w, w2);

            // Write outputs
            if(first_stage)
            {
                wrapper::vstore(out + k, wrapper::vcombine(a, b));
            }
            else
            {
                wrapper::vstore(out + k, a);
                wrapper::vstore(out + k + 2 * Nx, b);
            }
            wrapper::vstore(out + k + 4 * Nx, c);
        }
        // Advance the twiddle factor for the next group.
        w = c_mul_neon(w, w_m);
    }
}
464 
// Radix-3 FFT stage along axis 1: strided row accesses, see
// fft_radix_2_axes_1 for the meaning of N/M/in_pad_x/out_pad_x.
void fft_radix_3_axes_1(float *out, float *in, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N, unsigned int M, unsigned int in_pad_x, unsigned int out_pad_x)
{
    float32x2_t w{ 1.0f, 0.0f }; // current twiddle factor
    for(unsigned int j = 0; j < Nx; j++)
    {
        // Square of the twiddle factor for the third input.
        const auto w2 = c_mul_neon(w, w);

        for(unsigned int k = 2 * j; k < 2 * M; k += 2 * NxRadix)
        {
            // Load inputs
            float32x2_t a = wrapper::vload(in + (N + in_pad_x) * k);
            float32x2_t b = wrapper::vload(in + (N + in_pad_x) * (k + 2 * Nx));
            float32x2_t c = wrapper::vload(in + (N + in_pad_x) * (k + 4 * Nx));

            // Base-case prime transform
            fft_3(a, b, c, w, w2);

            // Store the output
            wrapper::vstore(out + (N + out_pad_x) * k, a);
            wrapper::vstore(out + (N + out_pad_x) * (k + 2 * Nx), b);
            wrapper::vstore(out + (N + out_pad_x) * (k + 4 * Nx), c);
        }
        // Advance the twiddle factor for the next group.
        w = c_mul_neon(w, w_m);
    }
}
490 
// Radix-4 FFT stage along axis 0 (contiguous interleaved-complex data).
// See fft_radix_2_axes_0 for the meaning of Nx/NxRadix/w_m/N. In the
// first_stage path the four inputs form two adjacent complex pairs, each
// moved with one 128-bit access.
template <bool first_stage>
void fft_radix_4_axes_0(float *out, float *in, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N)
{
    float32x2_t w{ 1.0f, 0.0f }; // current twiddle factor
    for(unsigned int j = 0; j < Nx; j++)
    {
        // Higher powers of the twiddle factor for the extra inputs.
        const auto w2 = c_mul_neon(w, w);
        const auto w3 = c_mul_neon(w2, w);

        for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix)
        {
            float32x2_t a = { 0, 0 };
            float32x2_t b = { 0, 0 };
            float32x2_t c = { 0, 0 };
            float32x2_t d = { 0, 0 };
            if(first_stage)
            {
                const auto ab = wrapper::vloadq(in + k);
                const auto cd = wrapper::vloadq(in + k + 4 * Nx);
                a = wrapper::vgetlow(ab);
                b = wrapper::vgethigh(ab);
                c = wrapper::vgetlow(cd);
                d = wrapper::vgethigh(cd);
            }
            else
            {
                // Load inputs
                a = wrapper::vload(in + k);
                b = wrapper::vload(in + k + 2 * Nx);
                c = wrapper::vload(in + k + 4 * Nx);
                d = wrapper::vload(in + k + 6 * Nx);
            }

            // Base-case prime transform
            fft_4(a, b, c, d, w, w2, w3);

            // Write outputs
            if(first_stage)
            {
                wrapper::vstore(out + k, wrapper::vcombine(a, b));
                wrapper::vstore(out + k + 4 * Nx, wrapper::vcombine(c, d));
            }
            else
            {
                wrapper::vstore(out + k, a);
                wrapper::vstore(out + k + 2 * Nx, b);
                wrapper::vstore(out + k + 4 * Nx, c);
                wrapper::vstore(out + k + 6 * Nx, d);
            }
        }

        // Advance the twiddle factor for the next group.
        w = c_mul_neon(w, w_m);
    }
}
544 
// Radix-4 FFT stage along axis 1: strided row accesses, see
// fft_radix_2_axes_1 for the meaning of N/M/in_pad_x/out_pad_x.
void fft_radix_4_axes_1(float *out, float *in, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N, unsigned int M, unsigned int in_pad_x, unsigned int out_pad_x)
{
    float32x2_t w{ 1.0f, 0.0f }; // current twiddle factor
    for(unsigned int j = 0; j < Nx; j++)
    {
        // Higher powers of the twiddle factor for the extra inputs.
        const auto w2 = c_mul_neon(w, w);
        const auto w3 = c_mul_neon(w2, w);

        for(unsigned int k = 2 * j; k < 2 * M; k += 2 * NxRadix)
        {
            // Load inputs
            float32x2_t a = wrapper::vload(in + (N + in_pad_x) * k);
            float32x2_t b = wrapper::vload(in + (N + in_pad_x) * (k + 2 * Nx));
            float32x2_t c = wrapper::vload(in + (N + in_pad_x) * (k + 4 * Nx));
            float32x2_t d = wrapper::vload(in + (N + in_pad_x) * (k + 6 * Nx));

            // Base-case prime transform
            fft_4(a, b, c, d, w, w2, w3);

            // Store outputs
            wrapper::vstore(out + (N + out_pad_x) * k, a);
            wrapper::vstore(out + (N + out_pad_x) * (k + 2 * Nx), b);
            wrapper::vstore(out + (N + out_pad_x) * (k + 4 * Nx), c);
            wrapper::vstore(out + (N + out_pad_x) * (k + 6 * Nx), d);
        }

        // Advance the twiddle factor for the next group.
        w = c_mul_neon(w, w_m);
    }
}
573 
// Radix-5 FFT stage along axis 0 (contiguous interleaved-complex data).
// See fft_radix_2_axes_0 for the meaning of Nx/NxRadix/w_m/N. In the
// first_stage path the first four inputs are moved as two 128-bit vectors;
// the fifth is always loaded separately.
template <bool first_stage>
void fft_radix_5_axes_0(float *out, float *in, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N)
{
    float32x2_t w{ 1.0f, 0.0f }; // current twiddle factor
    for(unsigned int j = 0; j < Nx; j++)
    {
        // Higher powers of the twiddle factor for the extra inputs.
        const float32x2_t w2 = c_mul_neon(w, w);
        const float32x2_t w3 = c_mul_neon(w2, w);
        const float32x2_t w4 = c_mul_neon(w3, w);

        for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix)
        {
            float32x2_t a = { 0, 0 };
            float32x2_t b = { 0, 0 };
            float32x2_t c = { 0, 0 };
            float32x2_t d = { 0, 0 };
            float32x2_t e = { 0, 0 };

            // Load inputs
            if(first_stage)
            {
                const auto ab = wrapper::vloadq(in + k);
                const auto cd = wrapper::vloadq(in + k + 4 * Nx);

                a = wrapper::vgetlow(ab);
                b = wrapper::vgethigh(ab);
                c = wrapper::vgetlow(cd);
                d = wrapper::vgethigh(cd);
            }
            else
            {
                a = wrapper::vload(in + k);
                b = wrapper::vload(in + k + 2 * Nx);
                c = wrapper::vload(in + k + 4 * Nx);
                d = wrapper::vload(in + k + 6 * Nx);
            }
            e = wrapper::vload(in + k + 8 * Nx);

            // Base-case prime transform
            fft_5(a, b, c, d, e, w, w2, w3, w4);

            // Store outputs
            if(first_stage)
            {
                wrapper::vstore(out + k, wrapper::vcombine(a, b));
                wrapper::vstore(out + k + 4 * Nx, wrapper::vcombine(c, d));
            }
            else
            {
                wrapper::vstore(out + k, a);
                wrapper::vstore(out + k + 2 * Nx, b);
                wrapper::vstore(out + k + 4 * Nx, c);
                wrapper::vstore(out + k + 6 * Nx, d);
            }
            wrapper::vstore(out + k + 8 * Nx, e);
        }

        // Advance the twiddle factor for the next group.
        w = c_mul_neon(w, w_m);
    }
}
634 
// Radix-5 FFT stage along axis 1: strided row accesses, see
// fft_radix_2_axes_1 for the meaning of N/M/in_pad_x/out_pad_x.
void fft_radix_5_axes_1(float *out, float *in, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N, unsigned int M, unsigned int in_pad_x, unsigned int out_pad_x)
{
    float32x2_t w{ 1.0f, 0.0f }; // current twiddle factor
    for(unsigned int j = 0; j < Nx; j++)
    {
        // Higher powers of the twiddle factor for the extra inputs.
        const float32x2_t w2 = c_mul_neon(w, w);
        const float32x2_t w3 = c_mul_neon(w2, w);
        const float32x2_t w4 = c_mul_neon(w3, w);

        for(unsigned int k = 2 * j; k < 2 * M; k += 2 * NxRadix)
        {
            // Load inputs
            float32x2_t a = wrapper::vload(in + (N + in_pad_x) * k);
            float32x2_t b = wrapper::vload(in + (N + in_pad_x) * (k + 2 * Nx));
            float32x2_t c = wrapper::vload(in + (N + in_pad_x) * (k + 4 * Nx));
            float32x2_t d = wrapper::vload(in + (N + in_pad_x) * (k + 6 * Nx));
            float32x2_t e = wrapper::vload(in + (N + in_pad_x) * (k + 8 * Nx));

            // Base-case prime transform
            fft_5(a, b, c, d, e, w, w2, w3, w4);

            // Store outputs
            wrapper::vstore(out + (N + out_pad_x) * k, a);
            wrapper::vstore(out + (N + out_pad_x) * (k + 2 * Nx), b);
            wrapper::vstore(out + (N + out_pad_x) * (k + 4 * Nx), c);
            wrapper::vstore(out + (N + out_pad_x) * (k + 6 * Nx), d);
            wrapper::vstore(out + (N + out_pad_x) * (k + 8 * Nx), e);
        }

        // Advance the twiddle factor for the next group.
        w = c_mul_neon(w, w_m);
    }
}
667 
// Radix-7 FFT stage along axis 0 (contiguous interleaved-complex data).
// See fft_radix_2_axes_0 for the meaning of Nx/NxRadix/w_m/N. In the
// first_stage path the first six inputs are moved as three 128-bit vectors;
// the seventh is always loaded separately.
template <bool first_stage>
void fft_radix_7_axes_0(float *out, float *in, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N)
{
    float32x2_t w{ 1.0f, 0.0f }; // current twiddle factor
    for(unsigned int j = 0; j < Nx; j++)
    {
        // Higher powers of the twiddle factor for the extra inputs.
        const float32x2_t w2 = c_mul_neon(w, w);
        const float32x2_t w3 = c_mul_neon(w2, w);
        const float32x2_t w4 = c_mul_neon(w3, w);
        const float32x2_t w5 = c_mul_neon(w4, w);
        const float32x2_t w6 = c_mul_neon(w5, w);

        for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix)
        {
            float32x2_t a = { 0, 0 };
            float32x2_t b = { 0, 0 };
            float32x2_t c = { 0, 0 };
            float32x2_t d = { 0, 0 };
            float32x2_t e = { 0, 0 };
            float32x2_t f = { 0, 0 };
            float32x2_t g = { 0, 0 };

            // Load inputs
            if(first_stage)
            {
                const auto ab = wrapper::vloadq(in + k);
                const auto cd = wrapper::vloadq(in + k + 4 * Nx);
                const auto ef = wrapper::vloadq(in + k + 8 * Nx);

                a = wrapper::vgetlow(ab);
                b = wrapper::vgethigh(ab);
                c = wrapper::vgetlow(cd);
                d = wrapper::vgethigh(cd);
                e = wrapper::vgetlow(ef);
                f = wrapper::vgethigh(ef);
            }
            else
            {
                a = wrapper::vload(in + k);
                b = wrapper::vload(in + k + 2 * Nx);
                c = wrapper::vload(in + k + 4 * Nx);
                d = wrapper::vload(in + k + 6 * Nx);
                e = wrapper::vload(in + k + 8 * Nx);
                f = wrapper::vload(in + k + 10 * Nx);
            }
            g = wrapper::vload(in + k + 12 * Nx);

            // Base-case prime transform
            fft_7(a, b, c, d, e, f, g, w, w2, w3, w4, w5, w6);

            // Write outputs
            if(first_stage)
            {
                wrapper::vstore(out + k, wrapper::vcombine(a, b));
                wrapper::vstore(out + k + 4 * Nx, wrapper::vcombine(c, d));
                wrapper::vstore(out + k + 8 * Nx, wrapper::vcombine(e, f));
            }
            else
            {
                wrapper::vstore(out + k, a);
                wrapper::vstore(out + k + 2 * Nx, b);
                wrapper::vstore(out + k + 4 * Nx, c);
                wrapper::vstore(out + k + 6 * Nx, d);
                wrapper::vstore(out + k + 8 * Nx, e);
                wrapper::vstore(out + k + 10 * Nx, f);
            }
            wrapper::vstore(out + k + 12 * Nx, g);
        }

        // Advance the twiddle factor for the next group.
        w = c_mul_neon(w, w_m);
    }
}
739 
// Radix-7 FFT stage along axis 1: strided row accesses, see
// fft_radix_2_axes_1 for the meaning of N/M/in_pad_x/out_pad_x.
void fft_radix_7_axes_1(float *out, float *in, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N, unsigned int M, unsigned int in_pad_x, unsigned int out_pad_x)
{
    float32x2_t w{ 1.0f, 0.0f }; // current twiddle factor
    for(unsigned int j = 0; j < Nx; j++)
    {
        // Higher powers of the twiddle factor for the extra inputs.
        const float32x2_t w2 = c_mul_neon(w, w);
        const float32x2_t w3 = c_mul_neon(w2, w);
        const float32x2_t w4 = c_mul_neon(w3, w);
        const float32x2_t w5 = c_mul_neon(w4, w);
        const float32x2_t w6 = c_mul_neon(w5, w);

        for(unsigned int k = 2 * j; k < 2 * M; k += 2 * NxRadix)
        {
            // Load inputs
            float32x2_t a = wrapper::vload(in + (N + in_pad_x) * k);
            float32x2_t b = wrapper::vload(in + (N + in_pad_x) * (k + 2 * Nx));
            float32x2_t c = wrapper::vload(in + (N + in_pad_x) * (k + 4 * Nx));
            float32x2_t d = wrapper::vload(in + (N + in_pad_x) * (k + 6 * Nx));
            float32x2_t e = wrapper::vload(in + (N + in_pad_x) * (k + 8 * Nx));
            float32x2_t f = wrapper::vload(in + (N + in_pad_x) * (k + 10 * Nx));
            float32x2_t g = wrapper::vload(in + (N + in_pad_x) * (k + 12 * Nx));

            // Base-case prime transform
            fft_7(a, b, c, d, e, f, g, w, w2, w3, w4, w5, w6);

            // Store outputs
            wrapper::vstore(out + (N + out_pad_x) * k, a);
            wrapper::vstore(out + (N + out_pad_x) * (k + 2 * Nx), b);
            wrapper::vstore(out + (N + out_pad_x) * (k + 4 * Nx), c);
            wrapper::vstore(out + (N + out_pad_x) * (k + 6 * Nx), d);
            wrapper::vstore(out + (N + out_pad_x) * (k + 8 * Nx), e);
            wrapper::vstore(out + (N + out_pad_x) * (k + 10 * Nx), f);
            wrapper::vstore(out + (N + out_pad_x) * (k + 12 * Nx), g);
        }

        // Advance the twiddle factor for the next group.
        w = c_mul_neon(w, w_m);
    }
}
778 
// Radix-8 FFT stage along axis 0 (contiguous interleaved-complex data).
// See fft_radix_2_axes_0 for the meaning of Nx/NxRadix/w_m/N. In the
// first_stage path all eight inputs are moved as four 128-bit vectors.
template <bool first_stage>
void fft_radix_8_axes_0(float *out, float *in, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N)
{
    float32x2_t w{ 1.0f, 0.0f }; // current twiddle factor
    for(unsigned int j = 0; j < Nx; j++)
    {
        // Higher powers of the twiddle factor for the extra inputs.
        const float32x2_t w2 = c_mul_neon(w, w);
        const float32x2_t w3 = c_mul_neon(w2, w);
        const float32x2_t w4 = c_mul_neon(w3, w);
        const float32x2_t w5 = c_mul_neon(w4, w);
        const float32x2_t w6 = c_mul_neon(w5, w);
        const float32x2_t w7 = c_mul_neon(w6, w);

        for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix)
        {
            float32x2_t a = { 0, 0 };
            float32x2_t b = { 0, 0 };
            float32x2_t c = { 0, 0 };
            float32x2_t d = { 0, 0 };
            float32x2_t e = { 0, 0 };
            float32x2_t f = { 0, 0 };
            float32x2_t g = { 0, 0 };
            float32x2_t h = { 0, 0 };

            // Load inputs
            if(first_stage)
            {
                const auto ab = wrapper::vloadq(in + k);
                const auto cd = wrapper::vloadq(in + k + 4 * Nx);
                const auto ef = wrapper::vloadq(in + k + 8 * Nx);
                const auto gh = wrapper::vloadq(in + k + 12 * Nx);

                a = wrapper::vgetlow(ab);
                b = wrapper::vgethigh(ab);
                c = wrapper::vgetlow(cd);
                d = wrapper::vgethigh(cd);
                e = wrapper::vgetlow(ef);
                f = wrapper::vgethigh(ef);
                g = wrapper::vgetlow(gh);
                h = wrapper::vgethigh(gh);
            }
            else
            {
                a = wrapper::vload(in + k);
                b = wrapper::vload(in + k + 2 * Nx);
                c = wrapper::vload(in + k + 4 * Nx);
                d = wrapper::vload(in + k + 6 * Nx);
                e = wrapper::vload(in + k + 8 * Nx);
                f = wrapper::vload(in + k + 10 * Nx);
                g = wrapper::vload(in + k + 12 * Nx);
                h = wrapper::vload(in + k + 14 * Nx);
            }

            // Base-case prime transform (applies the twiddle factors internally)
            fft_8(a, b, c, d, e, f, g, h, w, w2, w3, w4, w5, w6, w7);

            // Store outputs
            if(first_stage)
            {
                wrapper::vstore(out + k, wrapper::vcombine(a, b));
                wrapper::vstore(out + k + 4 * Nx, wrapper::vcombine(c, d));
                wrapper::vstore(out + k + 8 * Nx, wrapper::vcombine(e, f));
                wrapper::vstore(out + k + 12 * Nx, wrapper::vcombine(g, h));
            }
            else
            {
                wrapper::vstore(out + k, a);
                wrapper::vstore(out + k + 2 * Nx, b);
                wrapper::vstore(out + k + 4 * Nx, c);
                wrapper::vstore(out + k + 6 * Nx, d);
                wrapper::vstore(out + k + 8 * Nx, e);
                wrapper::vstore(out + k + 10 * Nx, f);
                wrapper::vstore(out + k + 12 * Nx, g);
                wrapper::vstore(out + k + 14 * Nx, h);
            }
        }

        // Advance the twiddle factor for the next group.
        w = c_mul_neon(w, w_m);
    }
}
860 
// Radix-8 FFT stage along axis 1: strided row accesses, see
// fft_radix_2_axes_1 for the meaning of N/M/in_pad_x/out_pad_x.
void fft_radix_8_axes_1(float *out, float *in, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N, unsigned int M, unsigned int in_pad_x, unsigned int out_pad_x)
{
    float32x2_t w{ 1.0f, 0.0f }; // current twiddle factor
    for(unsigned int j = 0; j < Nx; j++)
    {
        // Higher powers of the twiddle factor for the extra inputs.
        const float32x2_t w2 = c_mul_neon(w, w);
        const float32x2_t w3 = c_mul_neon(w2, w);
        const float32x2_t w4 = c_mul_neon(w3, w);
        const float32x2_t w5 = c_mul_neon(w4, w);
        const float32x2_t w6 = c_mul_neon(w5, w);
        const float32x2_t w7 = c_mul_neon(w6, w);

        for(unsigned int k = 2 * j; k < 2 * M; k += 2 * NxRadix)
        {
            // Load inputs
            float32x2_t a = wrapper::vload(in + (N + in_pad_x) * k);
            float32x2_t b = wrapper::vload(in + (N + in_pad_x) * (k + 2 * Nx));
            float32x2_t c = wrapper::vload(in + (N + in_pad_x) * (k + 4 * Nx));
            float32x2_t d = wrapper::vload(in + (N + in_pad_x) * (k + 6 * Nx));
            float32x2_t e = wrapper::vload(in + (N + in_pad_x) * (k + 8 * Nx));
            float32x2_t f = wrapper::vload(in + (N + in_pad_x) * (k + 10 * Nx));
            float32x2_t g = wrapper::vload(in + (N + in_pad_x) * (k + 12 * Nx));
            float32x2_t h = wrapper::vload(in + (N + in_pad_x) * (k + 14 * Nx));

            // Base-case prime transform
            fft_8(a, b, c, d, e, f, g, h, w, w2, w3, w4, w5, w6, w7);

            // Store outputs
            wrapper::vstore(out + (N + out_pad_x) * k, a);
            wrapper::vstore(out + (N + out_pad_x) * (k + 2 * Nx), b);
            wrapper::vstore(out + (N + out_pad_x) * (k + 4 * Nx), c);
            wrapper::vstore(out + (N + out_pad_x) * (k + 6 * Nx), d);
            wrapper::vstore(out + (N + out_pad_x) * (k + 8 * Nx), e);
            wrapper::vstore(out + (N + out_pad_x) * (k + 10 * Nx), f);
            wrapper::vstore(out + (N + out_pad_x) * (k + 12 * Nx), g);
            wrapper::vstore(out + (N + out_pad_x) * (k + 14 * Nx), h);
        }

        // Advance the twiddle factor for the next group.
        w = c_mul_neon(w, w_m);
    }
}
902 
Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, const FFTRadixStageKernelInfo &config)
{
    // Only a 2D decomposition is supported: the FFT axis must be 0 or 1.
    ARM_COMPUTE_RETURN_ERROR_ON(config.axis > 1);
    // Silences the unused-parameter warning on builds where the check above compiles away.
    ARM_COMPUTE_UNUSED(config);

    // Checks performed when output is configured
    if((output != nullptr) && (output->total_size() != 0))
    {
        // NOTE(review): this body is empty in this extraction; the jumps in the
        // residual line numbering (911 -> 915) suggest shape/data-type match
        // checks were lost here — verify against the upstream source.
    }

    return Status{};
}
919 
920 std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *output, const FFTRadixStageKernelInfo &config)
921 {
922  ARM_COMPUTE_UNUSED(config);
923 
924  if(output != nullptr)
925  {
926  auto_init_if_empty(*output, *input);
927  }
928 
929  Window win = calculate_max_window(*input, Steps());
930 
931  return std::make_pair(Status{}, win);
932 }
933 } // namespace
934 
// Default state: no tensors bound, Nx/axis/radix zeroed, and no axis-0/axis-1
// stage function selected yet (configure() fills these in).
// NOTE(review): the constructor declarator line is not visible in this extraction.
 : _input(nullptr), _output(nullptr), _Nx(0), _axis(0), _radix(0), _func_0(), _func_1()
{
}
939 
940 void NEFFTRadixStageKernel::set_radix_stage_axis0(const FFTRadixStageKernelInfo &config)
941 {
942  // FFT table axis 0: [radix, first_stage]
943  static std::map<unsigned int, std::map<bool, FFTFunctionPointerAxis0>> fft_table_axis0;
944 
945  if(fft_table_axis0.empty())
946  {
947  fft_table_axis0[2][false] = &fft_radix_2_axes_0<false>;
948  fft_table_axis0[3][false] = &fft_radix_3_axes_0<false>;
949  fft_table_axis0[4][false] = &fft_radix_4_axes_0<false>;
950  fft_table_axis0[5][false] = &fft_radix_5_axes_0<false>;
951  fft_table_axis0[7][false] = &fft_radix_7_axes_0<false>;
952  fft_table_axis0[8][false] = &fft_radix_8_axes_0<false>;
953 
954  fft_table_axis0[2][true] = &fft_radix_2_axes_0<true>;
955  fft_table_axis0[3][true] = &fft_radix_3_axes_0<true>;
956  fft_table_axis0[4][true] = &fft_radix_4_axes_0<true>;
957  fft_table_axis0[5][true] = &fft_radix_5_axes_0<true>;
958  fft_table_axis0[7][true] = &fft_radix_7_axes_0<true>;
959  fft_table_axis0[8][true] = &fft_radix_8_axes_0<true>;
960  }
961 
962  _func_0 = fft_table_axis0[config.radix][config.is_first_stage];
963 }
964 
965 void NEFFTRadixStageKernel::set_radix_stage_axis1(const FFTRadixStageKernelInfo &config)
966 {
967  // FFT table axis 1: [radix, first_stage]
968  static std::map<unsigned int, FFTFunctionPointerAxis1> fft_table_axis1;
969 
970  if(fft_table_axis1.empty())
971  {
972  fft_table_axis1[2] = &fft_radix_2_axes_1;
973  fft_table_axis1[3] = &fft_radix_3_axes_1;
974  fft_table_axis1[4] = &fft_radix_4_axes_1;
975  fft_table_axis1[5] = &fft_radix_5_axes_1;
976  fft_table_axis1[7] = &fft_radix_7_axes_1;
977  fft_table_axis1[8] = &fft_radix_8_axes_1;
978  }
979 
980  _func_1 = fft_table_axis1[config.radix];
981 }
982 
{
    // NOTE(review): the declarator (NEFFTRadixStageKernel::configure) and, per the
    // residual numbering gap (984 -> 986), a null-pointer check on `input` are not
    // visible in this extraction — verify against the upstream source.

    // Output auto inizialitation if not yet initialized
    if(output != nullptr)
    {
        auto_init_if_empty(*output->info(), *input->info()->clone());
    }

    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), (output != nullptr) ? output->info() : nullptr, config));

    // A null output selects in-place execution on the input tensor.
    _input = input;
    _output = (output == nullptr) ? input : output;
    _Nx = config.Nx;
    _axis = config.axis;
    _radix = config.radix;

    // Bind the stage function matching the requested axis.
    switch(config.axis)
    {
        case 0:
            set_radix_stage_axis0(config);
            break;
        case 1:
            set_radix_stage_axis1(config);
            break;
        default:
            ARM_COMPUTE_ERROR("Axis not supported");
            break;
    }

    // Configure kernel window
    auto win_config = validate_and_configure_window(input->info(), (output != nullptr) ? output->info() : nullptr, config);
    ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
    INEKernel::configure(win_config.second);
}
1019 
{
    // NOTE(review): the declarator (static Status NEFFTRadixStageKernel::validate)
    // is not visible in this extraction.

    // A null output (or output aliasing input) means the stage runs in place.
    const bool run_in_place = (output == nullptr) || (output == input);
    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, config));
    // Validate window configuration on clones so the caller's infos stay untouched.
    ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input->clone().get(),
                                                              (run_in_place) ? nullptr : output->clone().get(),
                                                              config)
                                .first);

    return Status{};
}
1031 
{
    // NOTE(review): the declarator (NEFFTRadixStageKernel::supported_radix) is not
    // visible in this extraction.
    // Radix values that have a dedicated butterfly implementation in this file.
    return std::set<unsigned int> { 2, 3, 4, 5, 7, 8 };
}
1036 
{
    // NOTE(review): the declarator (NEFFTRadixStageKernel::run) and, per the gaps in
    // the residual numbering, the usual configured-kernel/window validation macros
    // are not visible in this extraction — verify against the upstream source.
    ARM_COMPUTE_UNUSED(info);

    // Collapse the FFT axis of the window: each stage function walks that whole
    // dimension internally.
    Window input_window = window;
    input_window.set(_axis, 0);

    Iterator in(_input, input_window);
    Iterator out(_output, input_window);

    // Precompute FFT constants
    const unsigned int NxRadix = _radix * _Nx;
    const float alpha = 2.0f * kPi / float(NxRadix);
    // Base twiddle factor w_m = e^(-i*alpha), stored as { real, imag }.
    const float32x2_t w_m{ cosf(alpha), -sinf(alpha) };

    if(_axis == 0)
    {
        const unsigned int N = _input->info()->dimension(0);
        execute_window_loop(input_window, [&](const Coordinates &)
        {
            _func_0(reinterpret_cast<float *>(out.ptr()), reinterpret_cast<float *>(in.ptr()), _Nx, NxRadix, w_m, N);
        },
        in, out);
    }
    else
    {
        const unsigned int N = _input->info()->dimension(0);
        const unsigned int M = _input->info()->dimension(1);
        execute_window_loop(input_window, [&](const Coordinates &)
        {
            // Total horizontal padding (left + right) lets the stage function
            // convert element offsets into padded-row strides.
            _func_1(reinterpret_cast<float *>(out.ptr()), reinterpret_cast<float *>(in.ptr()), _Nx, NxRadix, w_m, N, M,
                    _input->info()->padding().right + _input->info()->padding().left,
                    _output->info()->padding().right + _output->info()->padding().left);
        },
        in, out);
    }
}
1079 } // namespace arm_compute
Window calculate_max_window(const ValidRegion &valid_region, const Steps &steps, bool skip_border, BorderSize border_size)
SimpleTensor< float > w
Definition: DFT.cpp:156
Traits defined on Arm® Neon™ vectors.
const Window & window() const
The maximum window the kernel can be executed on.
Definition: IKernel.cpp:28
virtual size_t dimension(size_t index) const =0
Return the size of the requested dimension.
SimpleTensor< float > b
Definition: DFT.cpp:157
#define ARM_COMPUTE_ERROR(msg)
Print the given message then throw an std::runtime_error.
Definition: Error.h:352
uint8x16_t vloadq(const uint8_t *ptr)
Definition: load.h:58
#define ARM_COMPUTE_RETURN_ON_ERROR(status)
Checks if a status contains an error and returns it.
Definition: Error.h:204
uint8x8_t vadd(const uint8x8_t &a, const uint8x8_t &b)
Definition: add.h:39
static Status validate(const ITensorInfo *input, const ITensorInfo *output, const FFTRadixStageKernelInfo &config)
Static function to check if given info will lead to a valid configuration of NEFFTRadixStageKernel.
1 channel, 1 F32 per channel
void configure(ITensor *input, ITensor *output, const FFTRadixStageKernelInfo &config)
Set the input and output tensors.
Store the tensor&#39;s metadata.
Definition: ITensorInfo.h:40
#define ARM_COMPUTE_ERROR_THROW_ON(status)
Definition: Error.h:455
uint8x8_t vsub(const uint8x8_t &a, const uint8x8_t &b)
Definition: sub.h:39
#define M_PI
unsigned int M
unsigned int axis
Axis to run the kernel on.
Status class.
Definition: Error.h:52
#define ARM_COMPUTE_RETURN_ERROR_ON(cond)
If the condition is true, an error is returned.
Definition: Error.h:296
Interface for CPU tensor.
Definition: ITensor.h:36
Copyright (c) 2017-2021 Arm Limited.
uint8_t vgetlane(const uint8x8_t vector, const unsigned int lane)
Definition: getlane.h:91
#define ARM_COMPUTE_UNUSED(...)
To avoid unused variables warnings.
Definition: Error.h:152
Coordinates of an item.
Definition: Coordinates.h:37
unsigned int N
static std::set< unsigned int > supported_radix()
Returns the radix that are support by the FFT kernel.
bool auto_init_if_empty(ITensorInfo &info, const TensorShape &shape, int num_channels, DataType data_type, QuantizationInfo quantization_info=QuantizationInfo())
Auto initialize the tensor info (shape, number of channels and data type) if the current assignment i...
Descriptor used by the FFT core kernels.
virtual std::unique_ptr< T > clone() const =0
Provide a clone of the current object of class T.
virtual ITensorInfo * info() const =0
Interface to be implemented by the child class to return the tensor&#39;s metadata.
constexpr uint8_t * ptr() const
Return a pointer to the current pixel.
Definition: Helpers.inl:139
uint8x8_t vgetlow(const uint8x16_t val)
Definition: getlow.h:39
void set(size_t dimension, const Dimension &dim)
Set the values of a given dimension.
Definition: Window.inl:49
virtual PaddingSize padding() const =0
Padding of tensor.
uint8x16_t vcombine(const uint8x8_t &a, const uint8x8_t &b)
Definition: combine.h:39
unsigned int left
left of the border
Definition: Types.h:380
unsigned int right
right of the border
Definition: Types.h:378
#define ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(k)
Definition: Validate.h:915
int8x8_t vneg(const int8x8_t &a)
Definition: neg.h:39
uint8x8_t vgethigh(const uint8x16_t val)
Definition: gethigh.h:39
unsigned int radix
Radix to use.
void run(const Window &window, const ThreadInfo &info) override
Execute the kernel on the passed window.
ScaleKernelInfo info(interpolation_policy, default_border_mode, PixelValue(), sampling_policy, false)
uint8x8_t vmul(const uint8x8_t &a, const uint8x8_t &b)
Definition: mul.h:39
Information about executing thread and CPU.
Definition: CPPTypes.h:158
#define ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(...)
Definition: Validate.h:439
bool is_first_stage
Flags if the FFT kernels is the first stage of a decomposed FFT.
unsigned int Nx
Nx coefficient.
#define ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(...)
Definition: Validate.h:541
#define ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(t, c,...)
Definition: Validate.h:788
uint8x8_t vrev64(const uint8x8_t &a)
Definition: rev64.h:39
uint8x8_t vload(const uint8_t *ptr)
Definition: load.h:39
void vstore(uint8_t *ptr, uint8x8_t val)
Definition: store.h:39
#define ARM_COMPUTE_ERROR_ON_NULLPTR(...)
Definition: Validate.h:157
uint8x8_t vdup_n(uint8_t value, traits::vector_64_tag)
Definition: dup_n.h:41
void execute_window_loop(const Window &w, L &&lambda_function, Ts &&... iterators)
Iterate through the passed window, automatically adjusting the iterators and calling the lambda_funct...
Definition: Helpers.inl:77
Includes all wrapper headers at once.
uint8x8_t vmla(const uint8x8_t &a, const uint8x8_t &b, const uint8x8_t &c)
Definition: mla.h:46
Iterator updated by execute_window_loop for each window element.
Definition: Helpers.h:46
Describe a multidimensional execution window.
Definition: Window.h:39
#define ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(f, s)
Definition: Validate.h:201