// rdft_1d_step: forward DFT of one row of real samples.
// Reads one real value src_ptr[n] per sample, n in [0, N). For each output
// bin k in [0, K), with alpha = 2 * pi * k * n / N, it accumulates
//   Xr += src_ptr[n] * cos(alpha)
//   Xi -= src_ptr[n] * sin(alpha)
// It then stores Xi at dst_ptr[k * 2 + 1]. The k * 2 + 1 index suggests
// interleaved (real, imag) output pairs.
// NOTE(review): this fragment is not compilable as shown. Each line starts
// with a bare integer that looks like a leftover source line number. Several
// lines are also missing: the template header, braces, the Xr/Xi
// declarations, and presumably the store of Xr to dst_ptr[k * 2] -- TODO
// confirm before restoring.
52 void rdft_1d_step(
const T *src_ptr,
size_t N, T *dst_ptr,
size_t K)
// The outer loop over output bins k carries an OpenMP parallel-for pragma.
55 #pragma omp parallel for
57 for(
unsigned int k = 0;
k <
K; ++
k)
// Inner loop: add the contribution of every input sample n to bin k.
61 for(
unsigned int n = 0;
n <
N; ++
n)
// NOTE(review): the right-hand side is a double expression (M_PI) that is
// then narrowed to float, so alpha loses precision as k * n grows. Reducing
// the phase modulo N first and keeping alpha in double would be more
// accurate.
63 const float alpha = (2 *
M_PI *
k *
n) /
N;
64 const float val_r = src_ptr[
n];
66 Xr += val_r * cos(alpha);
67 Xi -= val_r * sin(alpha);
// Imaginary part of bin k goes to the odd slot of the output pair.
71 dst_ptr[
k * 2 + 1] = Xi;
// dft_1d_step: forward DFT of one row of N complex samples.
// Input is read as interleaved pairs (re, im) = (src_ptr[2 * n],
// src_ptr[2 * n + 1]). For each bin k in [0, N), with
// alpha = 2 * pi * k * n / N, it accumulates
//   Xr += re * cos(alpha) + im * sin(alpha)
//   Xi += im * cos(alpha) - re * sin(alpha)
// These are the real and imaginary parts of x[n] * exp(-i * alpha).
// It then stores Xi at dst_ptr[k * 2 + 1].
// NOTE(review): some lines are not visible in this fragment: the template
// header, braces, the Xr/Xi declarations, and presumably the store of Xr to
// dst_ptr[k * 2]. Each line also starts with a stray integer.
82 void dft_1d_step(
const T *src_ptr, T *dst_ptr,
size_t N)
// The outer loop over output bins k carries an OpenMP parallel-for pragma.
85 #pragma omp parallel for
87 for(
unsigned int k = 0;
k <
N; ++
k)
// Inner loop over input samples n.
91 for(
unsigned int n = 0;
n <
N; ++
n)
// NOTE(review): double expression narrowed to float; precision drops as
// k * n grows.
93 const float alpha = (2 *
M_PI *
k *
n) /
N;
94 const float val_r = src_ptr[2 *
n];
95 const float val_i = src_ptr[2 *
n + 1];
// cos/sin are computed once per (k, n) and reused for both components.
96 const float cos_alpha = cos(alpha);
97 const float sin_alpha = sin(alpha);
99 Xr += val_r * cos_alpha + val_i * sin_alpha;
100 Xi += val_i * cos_alpha - val_r * sin_alpha;
104 dst_ptr[
k * 2 + 1] = Xi;
// irdft_1d_step: inverse DFT producing N real samples from K stored complex
// bins. src_ptr holds K interleaved (real, imag) pairs. For each output
// sample n:
//  * Each stored bin k in [0, K) contributes Re(X[k] * exp(+i * alpha)),
//    with alpha = 2 * pi * k * n / N. That equals
//    re * cos(alpha) - im * sin(alpha).
//  * The Nleft = N - K bins that are presumably not stored are rebuilt from
//    stored bin j, starting at tail_start, as Re(conj(X[j]) * exp(+i * alpha)).
//    Here alpha = 2 * pi * (k + K) * n / N, so each term is
//    re * cos(alpha) + im * sin(alpha).
//    This relies on the Hermitian symmetry of a real signal's spectrum.
// No 1/N normalisation happens in the visible lines.
// NOTE(review): some names are not declared in this fragment: is_odd, xr,
// and the store of xr (presumably dst_ptr[n]). is_odd presumably means
// N % 2 != 0 -- TODO confirm.
115 template <
typename T>
116 void irdft_1d_step(
const T *src_ptr,
size_t K, T *dst_ptr,
size_t N)
// Number of mirrored bins. Unsigned arithmetic, so this assumes K <= N.
119 const unsigned int Nleft =
N -
K;
// Start of the mirror walk. When N is even it skips bin K - 1, presumably
// the unpaired Nyquist bin.
120 const int tail_start =
is_odd ?
K - 1 :
K - 2;
// Output samples n are computed under an OpenMP parallel-for.
122 #pragma omp parallel for
124 for(
unsigned int n = 0;
n <
N; ++
n)
// Contribution of the K stored bins.
127 for(
unsigned int k = 0;
k <
K; ++
k)
129 const float alpha = (2 *
M_PI *
k *
n) /
N;
130 xr += src_ptr[2 *
k] * cos(alpha) - src_ptr[2 *
k + 1] * sin(alpha);
// Contribution of the Nleft mirrored (conjugate) bins.
// NOTE(review): tail_start is an int converted to unsigned here, so it must be
// non-negative whenever Nleft > 0. Also, j is never modified in the visible
// lines, but it must step downwards for the mirror to be correct. The
// embedded line numbers jump from 137 to 151, so the update (e.g. --j) may be
// in a hidden line -- confirm it exists.
133 unsigned int j = tail_start;
134 for(
unsigned int k = 0;
k < Nleft; ++
k)
136 const float alpha = (2 *
M_PI * (
k +
K) *
n) /
N;
137 xr += src_ptr[2 * j] * cos(alpha) + src_ptr[2 * j + 1] * sin(alpha);
// idft_1d_step: unnormalised inverse DFT of one row of N complex samples.
// For each output sample n, it accumulates over all bins k, with
// alpha = 2 * pi * k * n / N and (re, im) = (src_ptr[2k], src_ptr[2k+1]):
//   xr += re * cos(alpha) - im * sin(alpha)
//   xi += im * cos(alpha) + re * sin(alpha)
// This is the sum of X[k] * exp(+i * alpha). The result xi is stored at
// dst_ptr[2 * n + 1]. No division by N occurs in the visible lines.
// NOTE(review): the xr/xi declarations and the store of xr (presumably to
// dst_ptr[2 * n]) are not visible in this fragment.
151 template <
typename T>
152 void idft_1d_step(
const T *src_ptr, T *dst_ptr,
size_t N)
// Output samples n are computed under an OpenMP parallel-for.
155 #pragma omp parallel for
157 for(
unsigned int n = 0;
n <
N; ++
n)
161 for(
unsigned int k = 0;
k <
N; ++
k)
// NOTE(review): double expression narrowed to float; precision drops as
// k * n grows.
163 const float alpha = (2 *
M_PI *
k *
n) /
N;
164 const float cos_alpha = cos(alpha);
165 const float sin_alpha = sin(alpha);
166 const float val_r = src_ptr[2 *
k];
167 const float val_i = src_ptr[2 *
k + 1];
169 xr += val_r * cos_alpha - val_i * sin_alpha;
170 xi += val_i * cos_alpha + val_r * sin_alpha;
174 dst_ptr[2 *
n + 1] = xi;
// Row-wise driver for the real-input DFT steps. Its signature is not visible
// here; it is presumably rdft_1d_core(src, direction, is_odd) -- confirm.
// It treats dimension 0 (length N) as the transform axis and loops over
// upper_dims = src.shape().total_size_upper(1) rows. Per direction:
//  * FFTDirection::Forward calls rdft_1d_step.
//  * Any other direction calls irdft_1d_step.
// Source rows advance by N * src.num_channels(); destination rows advance by
// K * dst.num_channels().
// NOTE(review): several names are defined in lines not visible here: K, dst,
// direction, is_odd, and the return. inverse_tail is unused in the visible
// lines; it is presumably used to compute K -- confirm.
178 template <
typename T>
185 const unsigned int inverse_tail =
is_odd ? 1 : 0;
186 const unsigned int N =
src.shape()[0];
195 const unsigned int upper_dims =
src.shape().total_size_upper(1);
// Rows are processed under an OpenMP parallel-for.
// NOTE(review): the step functions called below open their own
// `omp parallel for` regions. Unless nested parallelism is enabled, those
// inner regions typically run on a single thread each.
197 #pragma omp parallel for
199 for(
unsigned int du = 0; du < upper_dims; ++du)
201 const T *src_row_ptr =
src.data() + du *
N *
src.num_channels();
202 T *dst_row_ptr =
dst.data() + du *
K *
dst.num_channels();
// irdft_1d_step's signature is (src_ptr, K, dst_ptr, N). In the inverse
// call, this function's N (input length) is passed as irdft's K, and this
// function's K is passed as irdft's N (output length).
203 direction ==
FFTDirection::Forward ? rdft_1d_step(src_row_ptr,
N, dst_row_ptr,
K) : irdft_1d_step(src_row_ptr,
N, dst_row_ptr,
K);
// dft_1d_core: applies a 1D complex DFT along dimension 0 of src,
// independently for every row. FFTDirection::Forward uses dft_1d_step; any
// other direction uses idft_1d_step, which has no 1/N scaling in its visible
// lines.
// dst is created with the same shape, data type and channel count as src.
// In both tensors, each row of N elements starts at du * N * num_channels().
// NOTE(review): the return statement is not visible here; it presumably
// returns dst.
209 template <
typename T>
210 SimpleTensor<T> dft_1d_core(
const SimpleTensor<T> &
src,
FFTDirection direction)
214 const unsigned int N =
src.shape()[0];
216 SimpleTensor<T>
dst(
src.shape(),
src.data_type(),
src.num_channels());
// Number of independent rows: all dimensions above dimension 0.
218 const unsigned int upper_dims =
src.shape().total_size_upper(1);
// Rows are processed under an OpenMP parallel-for. The step functions also
// contain parallel-for pragmas, which nest inside this region.
220 #pragma omp parallel for
222 for(
unsigned int du = 0; du < upper_dims; ++du)
224 const T *src_row_ptr =
src.data() + du *
N *
src.num_channels();
225 T *dst_row_ptr =
dst.data() + du *
N *
dst.num_channels();
226 direction ==
FFTDirection::Forward ? dft_1d_step(src_row_ptr, dst_row_ptr,
N) : idft_1d_step(src_row_ptr, dst_row_ptr,
N);
// scale: divides every scalar of `tensor` in place by scaling_factor.
// Despite the name, this is a division. Other fragments in this file compute
// scaling_factor as the transform length, presumably for 1/N normalisation
// of inverse transforms.
// total_elements counts num_elements() * num_channels() scalars, so every
// channel of every element is divided.
// NOTE(review): that product is narrowed to int, so a very large tensor
// could overflow it.
237 template <
typename T>
238 void scale(SimpleTensor<T> &
tensor, T scaling_factor)
240 const int total_elements =
tensor.num_elements() *
tensor.num_channels();
241 T *data_ptr =
tensor.data();
// Each element is independent, so the flat loop runs under an OpenMP
// parallel-for.
243 #pragma omp parallel for
245 for(
int i = 0; i < total_elements; ++i)
247 data_ptr[i] /= scaling_factor;
// complex_mul_and_reduce: frequency-domain multiply-accumulate over input
// channels. For every batch b, output channel co and position (h, w):
//   dst[b, co, h, w] = sum over ci of input[b, ci, h, w] * weights[co, ci, h, w]
// Every element is a complex number stored as two scalars
// (ptr[0] = real, ptr[1] = imag).
// Layout, from the flat index formulas below (w varies fastest):
//   input   = W x H x Ci x N
//   weights = W x H x Ci x Co   (Co is read from dimension 3)
//   output  = W x H x Co x N
// N, the batch count, is derived from input's total size.
// NOTE(review): some lines are not visible in this fragment: the declaration
// of dst, the construction of i_coords / o_coords (presumably
// index2coords(...) on i_index / o_index, like w_coords -- confirm), and the
// return.
258 template <
typename T>
259 SimpleTensor<T> complex_mul_and_reduce(
const SimpleTensor<T> &
input,
const SimpleTensor<T> &weights)
261 const uint32_t W =
input.shape().x();
262 const uint32_t H =
input.shape().y();
263 const uint32_t Ci =
input.shape().z();
264 const uint32_t Co = weights.shape()[3];
// NOTE(review): this divides by zero if W, H or Ci is 0.
265 const uint32_t
N =
input.shape().total_size() / (W * H * Ci);
// dst is zero-filled so the loop below can accumulate with +=.
272 const auto total_element_count =
dst.num_channels() *
dst.num_elements();
273 std::fill_n(
dst.data(), total_element_count, 0);
275 for(uint32_t
b = 0;
b <
N; ++
b)
277 for(uint32_t co = 0; co < Co; ++co)
279 for(uint32_t
ci = 0;
ci < Ci; ++
ci)
281 for(uint32_t h = 0; h < H; ++h)
283 for(uint32_t
w = 0;
w < W; ++
w)
// Flat element indices, computed in uint32_t, into input, weights and
// output.
285 const uint32_t i_index =
w + h * W +
ci * H * W +
b * H * W * Ci;
286 const uint32_t w_index =
w + h * W +
ci * H * W + co * H * W * Ci;
287 const uint32_t o_index =
w + h * W + co * H * W +
b * H * W * Co;
289 const Coordinates w_coords =
index2coords(weights.shape(), w_index);
292 auto i_ptr =
static_cast<const T *
>(
input(i_coords));
293 auto w_ptr =
static_cast<const T *
>(weights(w_coords));
294 auto o_ptr =
static_cast<T *
>(
dst(o_coords));
296 const T Rin = i_ptr[0];
297 const T Iin = i_ptr[1];
298 const T Rw = w_ptr[0];
299 const T Iw = w_ptr[1];
// Complex product (Rin + i*Iin) * (Rw + i*Iw), added to the output.
301 o_ptr[0] += Rin * Rw - Iin * Iw;
302 o_ptr[1] += Rin * Iw + Rw * Iin;
// NOTE(review): only this template header (source line 312) is visible. The
// definition it introduces is not shown in this view.
312 template <
typename T>
// Fragment of source lines 318-323. It computes scaling_factor as the length
// of dimension 0 of dst, presumably to normalise an inverse 1D transform by
// 1/N. The signature, the creation of dst and the use of scaling_factor are
// not visible here.
318 template <
typename T>
323 const T scaling_factor = T(
dst.shape()[0]);
// Fragment of source lines 329-335. It runs dft_1d_core on src in the
// requested direction, then computes scaling_factor = dst.shape()[0],
// presumably for 1/N normalisation of the inverse. Any condition around the
// scaling, and the rest of the function, is not visible.
329 template <
typename T>
332 auto dst = dft_1d_core(
src, direction);
335 const T scaling_factor = T(
dst.shape()[0]);
// Fragment of source lines 341-349: a separable 2D transform of real input.
// It runs in two steps:
//  1. A real-input 1D pass along dimension 0: rdft_1d_core with
//     is_odd = false.
//  2. A complex 1D pass: dft_1d_core on `transposed`.
// `transposed` is built in lines not visible here; it is presumably a
// transpose of first_pass -- confirm.
341 template <
typename T>
347 auto first_pass = rdft_1d_core(
src, direction,
false);
349 auto second_pass = dft_1d_core(transposed, direction);
// Fragment of source lines 353-364: a 2D transform that ends in real output.
// It runs a complex 1D pass on `transposed`, then rdft_1d_core on
// `transposed_2` using is_odd. The scaling factor is
// dst.shape()[0] * dst.shape()[1], presumably for 1/(W*H) normalisation.
// The transposes and the scaling call are not visible.
353 template <
typename T>
360 auto first_pass = dft_1d_core(transposed, direction);
362 auto dst = rdft_1d_core(transposed_2, direction,
is_odd);
364 const T scaling_factor = T(
dst.shape()[0] *
dst.shape()[1]);
// Fragment of source lines 369-388: a 2D complex DFT built from two
// dft_1d_core passes with transposes in between.
// Source lines 376-378 and 384-386 are two separate sequences, and both
// declare first_pass. They are presumably branches for the two directions.
// Only the second computes a scaling factor, dst.shape()[0] * dst.shape()[1].
// The branching and the transposes are not visible.
369 template <
typename T>
376 auto first_pass = dft_1d_core(
src, direction);
378 auto second_pass = dft_1d_core(transposed, direction);
384 auto first_pass = dft_1d_core(transposed, direction);
386 auto dst = dft_1d_core(transposed_2, direction);
388 const T scaling_factor = T(
dst.shape()[0] *
dst.shape()[1]);
// Fragment of source lines 395-426, which continues past this view:
// convolution computed via DFT. Visible steps:
//  * End-padding ranges for dimensions 0 and 1. padding_in uses w's size
//    minus 1; paddings_w uses src's size minus 1.
//  * Padding of flipped_w.
//  * rdft_2d of padded_src.
//  * The per-frequency product complex_mul_and_reduce(Fsrc, Fw).
//  * Crop bounds computed from w's size and conv_info's padding.
// Several names are defined in lines not visible here: flipped_w,
// padded_src, Fw, conv_res, axis and conv_info.
395 template <
typename T>
399 const PaddingList padding_in = { { 0,
w.shape()[0] - 1 }, { 0,
w.shape()[1] - 1 } };
// NOTE(review): axis_v has only two entries, so this copy reads past
// axis_v's end if axis.shape().x() > 2.
403 std::vector<uint32_t> axis_v = { 0, 1 };
405 std::copy(axis_v.begin(), axis_v.begin() + axis.shape().x(), axis.data());
409 const PaddingList paddings_w = { { 0,
src.shape()[0] - 1 }, { 0,
src.shape()[1] - 1 } };
410 auto padded_w =
pad_layer(flipped_w, paddings_w);
413 auto Fsrc =
rdft_2d(padded_src);
417 auto Fdst = complex_mul_and_reduce(Fsrc, Fw);
// Crop bounds within conv_res, presumably selecting the region that matches
// the requested padding -- confirm against the hidden remainder.
// NOTE(review): `end_botton` looks like a typo for `end_bottom`. Rename it
// together with its uses, which are not visible here.
423 const int start_left =
w.shape().x() -
conv_info.pad_left() - 1;
424 const int start_top =
w.shape().y() -
conv_info.pad_top() - 1;
425 const int end_right = conv_res.shape().x() - (
w.shape().x() -
conv_info.pad_right() - 1);
426 const int end_botton = conv_res.shape().y() - (
w.shape().y() -
conv_info.pad_bottom() - 1);