Compute Library 21.05: NEDirectConvolutionDetail.h
/*
 * Copyright (c) 2017-2021 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#ifndef ARM_COMPUTE_NEDIRECTCONVOLUTIONDETAIL_H
#define ARM_COMPUTE_NEDIRECTCONVOLUTIONDETAIL_H

#include "src/core/NEON/wrapper/wrapper.h"
#include "support/Requires.h"

#include <arm_neon.h>

namespace arm_compute
{
namespace detail
{
/** Loads one row of a 3x3 matrix, duplicating each element across all lanes (float).
 *
 * @param[in] ptr            Pointer to a row of a float 3x3 matrix.
 * @param[in] weights_offset (Optional) Weights quantization offset.
 *
 * @return The loaded matrix row.
 */
inline float32x4x3_t load_matrix_row(const float *ptr, int weights_offset = 0)
{
    ARM_COMPUTE_UNUSED(weights_offset);
    const float32x4x3_t r =
    {
        {
            vld1q_dup_f32(ptr),
            vld1q_dup_f32(1 + ptr),
            vld1q_dup_f32(2 + ptr)
        }
    };
    return r;
}
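
// Usage sketch (illustrative, not part of the original header): a 3x3 filter is
// typically loaded row by row, so each broadcast coefficient can multiply a whole
// vector of input values.
//
//     const float weights[9] = { 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f };
//     const float32x4x3_t m0 = load_matrix_row(weights);     // row 0: lanes of 1, 2, 3
//     const float32x4x3_t m1 = load_matrix_row(weights + 3); // row 1
//     const float32x4x3_t m2 = load_matrix_row(weights + 6); // row 2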

/** Loads one row of a 3x3 matrix (uint8_t/int8_t), adding the quantization offset and duplicating each element across all lanes.
 *
 * @param[in] ptr            Pointer to a row of a uint8_t/int8_t 3x3 matrix.
 * @param[in] weights_offset (Optional) Weights quantization offset.
 *
 * @return The loaded matrix row.
 */
template < typename T, ARM_COMPUTE_REQUIRES_TA(std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value) >
inline int32x4x3_t load_matrix_row(const T *ptr, int weights_offset = 0)
{
    const int32x4_t v_weights_offset = vdupq_n_s32(weights_offset);

    /* ptr is a pointer to a row in a 3x3 matrix; the function returns 3 vectors holding exactly the same value in all lanes:
       r.val[0] contains the first element, r.val[1] the second element and r.val[2] the third element (in all lanes) */
    int32x4x3_t r =
    {
        {
            vaddq_s32(v_weights_offset, vdupq_n_s32(*ptr)),
            vaddq_s32(v_weights_offset, vdupq_n_s32(*(ptr + 1))),
            vaddq_s32(v_weights_offset, vdupq_n_s32(*(ptr + 2)))
        }
    };
    return r;
}
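
// Usage sketch (illustrative): for quantized weights the zero-point is folded in
// at load time; the offset below is an assumed example value (callers typically
// pass the negated zero-point from the weights' quantization info).
//
//     const uint8_t w[9]     = { 130, 132, 134, 128, 129, 130, 122, 124, 126 };
//     const int     w_offset = -128; // assumption for illustration
//     const int32x4x3_t m0   = load_matrix_row(w, w_offset); // lanes hold w[i] + w_offset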

/** Stores a float32x4x2_t array into a memory location.
 *
 * @param[in] buffer Pointer to the memory location where the values will be stored.
 * @param[in] values Values that will be stored.
 */
template <unsigned int stridex>
void store_results(float *buffer, const float32x4x2_t &values);

template <>
inline void store_results<1>(float *buffer, const float32x4x2_t &values)
{
    vst1q_f32(buffer, values.val[0]);
    vst1q_f32(buffer + 4, values.val[1]);
}

template <>
inline void store_results<2>(float *buffer, const float32x4x2_t &values)
{
    vst1q_f32(buffer, values.val[0]);
}

template <>
inline void store_results<3>(float *buffer, const float32x4x2_t &values)
{
    vst1_f32(buffer, vget_low_f32(values.val[0]));
}
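
// Usage sketch (illustrative): the stridex template argument selects how many of
// the computed lanes are written back: 8 floats for stride 1, 4 for stride 2 and
// 2 for stride 3 (the convolution helpers first compact the valid lanes into
// values.val[0]).
//
//     float out[8];
//     const float32x4x2_t values = { { vdupq_n_f32(1.f), vdupq_n_f32(2.f) } };
//     store_results<1>(out, values); // writes out[0..7]
//     store_results<2>(out, values); // writes out[0..3]
//     store_results<3>(out, values); // writes out[0..1]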

/** Stores an int32x4x2_t array into a memory location.
 *
 * @param[in] buffer Pointer to the memory location where the values will be stored.
 * @param[in] values Values that will be stored.
 */
template <unsigned int stridex>
void store_results(int32_t *buffer, const int32x4x2_t &values);

template <>
inline void store_results<1>(int32_t *buffer, const int32x4x2_t &values)
{
    vst1q_s32(buffer, values.val[0]);
    vst1q_s32(buffer + 4, values.val[1]);
}

template <>
inline void store_results<2>(int32_t *buffer, const int32x4x2_t &values)
{
    vst1q_s32(buffer, values.val[0]);
}

template <>
inline void store_results<3>(int32_t *buffer, const int32x4x2_t &values)
{
    vst1_s32(buffer, vget_low_s32(values.val[0]));
}

/** Adds the values in a float32x4x2_t to a memory location (read-modify-write).
 *
 * @param[in,out] buffer Pointer to the memory location to accumulate into.
 * @param[in]     values Values that will be added.
 */
template <unsigned int stridex>
inline void accumulate_results(float *buffer, const float32x4x2_t &values);

template <>
inline void accumulate_results<1>(float *buffer, const float32x4x2_t &values)
{
    vst1q_f32(buffer, vaddq_f32(vld1q_f32(buffer), values.val[0]));
    vst1q_f32(buffer + 4, vaddq_f32(vld1q_f32(buffer + 4), values.val[1]));
}

template <>
inline void accumulate_results<2>(float *buffer, const float32x4x2_t &values)
{
    vst1q_f32(buffer, vaddq_f32(vld1q_f32(buffer), values.val[0]));
}

template <>
inline void accumulate_results<3>(float *buffer, const float32x4x2_t &values)
{
    vst1_f32(buffer, vadd_f32(vld1_f32(buffer), vget_low_f32(values.val[0])));
}

/** Adds the values in an int32x4x2_t to a memory location (read-modify-write).
 *
 * @param[in,out] buffer Pointer to the memory location to accumulate into.
 * @param[in]     values Values that will be added.
 */
template <unsigned int stridex>
void accumulate_results(int32_t *buffer, const int32x4x2_t &values);

template <>
inline void accumulate_results<1>(int32_t *buffer, const int32x4x2_t &values)
{
    vst1q_s32(buffer, vaddq_s32(vld1q_s32(buffer), values.val[0]));
    vst1q_s32(buffer + 4, vaddq_s32(vld1q_s32(buffer + 4), values.val[1]));
}

template <>
inline void accumulate_results<2>(int32_t *buffer, const int32x4x2_t &values)
{
    vst1q_s32(buffer, vaddq_s32(vld1q_s32(buffer), values.val[0]));
}

template <>
inline void accumulate_results<3>(int32_t *buffer, const int32x4x2_t &values)
{
    vst1_s32(buffer, vadd_s32(vld1_s32(buffer), vget_low_s32(values.val[0])));
}
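
// Usage sketch (illustrative): accumulation is how partial sums are combined
// across input channels; running the same partial result twice doubles the
// buffer contents.
//
//     int32_t acc[8] = { 0 };
//     const int32x4x2_t partial = { { vdupq_n_s32(3), vdupq_n_s32(4) } };
//     accumulate_results<1>(acc, partial); // acc[0..3] == 3, acc[4..7] == 4
//     accumulate_results<1>(acc, partial); // acc[0..3] == 6, acc[4..7] == 8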

#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
/** Stores a float16x8x2_t array into a memory location.
 *
 * @param[in] buffer Pointer to the memory location where the values will be stored.
 * @param[in] values Values that will be stored.
 */
template <unsigned int stridex>
void store_results(float16_t *buffer, const float16x8x2_t &values);

template <>
inline void store_results<1>(float16_t *buffer, const float16x8x2_t &values)
{
    vst1q_f16(buffer, values.val[0]);
    vst1q_f16(buffer + 8, values.val[1]);
}

template <>
inline void store_results<2>(float16_t *buffer, const float16x8x2_t &values)
{
    vst1q_f16(buffer, values.val[0]);
}

template <>
inline void store_results<3>(float16_t *buffer, const float16x8x2_t &values)
{
    vst1_f16(buffer, vget_low_f16(values.val[0]));
}

/** Adds the values in a float16x8x2_t to a memory location (read-modify-write).
 *
 * @param[in,out] buffer Pointer to the memory location to accumulate into.
 * @param[in]     values Values that will be added.
 */
template <unsigned int stridex>
inline void accumulate_results(float16_t *buffer, const float16x8x2_t &values);

template <>
inline void accumulate_results<1>(float16_t *buffer, const float16x8x2_t &values)
{
    vst1q_f16(buffer, vaddq_f16(vld1q_f16(buffer), values.val[0]));
    vst1q_f16(buffer + 8, vaddq_f16(vld1q_f16(buffer + 8), values.val[1]));
}

template <>
inline void accumulate_results<2>(float16_t *buffer, const float16x8x2_t &values)
{
    vst1q_f16(buffer, vaddq_f16(vld1q_f16(buffer), values.val[0]));
}

template <>
inline void accumulate_results<3>(float16_t *buffer, const float16x8x2_t &values)
{
    vst1_f16(buffer, vadd_f16(vld1_f16(buffer), vget_low_f16(values.val[0])));
}
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */

/** Perform a 3x3 convolution for 4 consecutive elements on float32 when dilation.x() or dilation.y() is not 1.
 *
 * @param[in] in_top       Pointer to the first row of the input.
 * @param[in] in_mid       Pointer to the second row of the input.
 * @param[in] in_low       Pointer to the third row of the input.
 * @param[in] m0           First row of the filter.
 * @param[in] m1           Second row of the filter.
 * @param[in] m2           Third row of the filter.
 * @param[in] dilation_x   Dilation, in elements across x.
 * @param[in] input_offset Input quantization offset (unused for float).
 *
 * @return The convolved values.
 */
inline float32x4_t single_convolve_3x3_dilation(const float *in_top, const float *in_mid, const float *in_low,
                                                const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2,
                                                const size_t dilation_x, int input_offset)
{
    ARM_COMPUTE_UNUSED(input_offset);

    const float32x4x3_t vtop =
    {
        {
            vld1q_f32(in_top),
            vld1q_f32(in_top + dilation_x),
            vld1q_f32(in_top + 2 * dilation_x)
        }
    };
    const float32x4x3_t vmid =
    {
        {
            vld1q_f32(in_mid),
            vld1q_f32(in_mid + dilation_x),
            vld1q_f32(in_mid + 2 * dilation_x)
        }
    };
    const float32x4x3_t vlow =
    {
        {
            vld1q_f32(in_low),
            vld1q_f32(in_low + dilation_x),
            vld1q_f32(in_low + 2 * dilation_x)
        }
    };
    float32x4_t out = vmulq_f32(vtop.val[0], m0.val[0]);
    out             = vmlaq_f32(out, vtop.val[1], m0.val[1]);
    out             = vmlaq_f32(out, vtop.val[2], m0.val[2]);

    out = vmlaq_f32(out, vmid.val[0], m1.val[0]);
    out = vmlaq_f32(out, vmid.val[1], m1.val[1]);
    out = vmlaq_f32(out, vmid.val[2], m1.val[2]);

    out = vmlaq_f32(out, vlow.val[0], m2.val[0]);
    out = vmlaq_f32(out, vlow.val[1], m2.val[1]);
    out = vmlaq_f32(out, vlow.val[2], m2.val[2]);

    return out;
}
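
// Usage sketch (illustrative, pointer names hypothetical): with dilation_x == 2
// each filter row is applied to input columns x, x + 2 and x + 4; the caller is
// expected to pick input rows that are dilation_y apart.
//
//     // in0, in1 and in2 point at rows y, y + dilation_y and y + 2 * dilation_y.
//     const float32x4_t out4 = single_convolve_3x3_dilation(in0, in1, in2,
//                                                           m0, m1, m2, 2, 0);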

/** Perform a 3x3 convolution for 8 consecutive elements on float32 when dilation.x() or dilation.y() is not 1.
 *
 * @param[in] in_top       Pointer to the first row of the input.
 * @param[in] in_mid       Pointer to the second row of the input.
 * @param[in] in_low       Pointer to the third row of the input.
 * @param[in] m0           First row of the filter.
 * @param[in] m1           Second row of the filter.
 * @param[in] m2           Third row of the filter.
 * @param[in] dilation_x   Dilation, in elements across x.
 * @param[in] stridex      Stride value in elements across x.
 * @param[in] input_offset (Optional) Input quantization offset.
 *
 * @return The convolved values, with the valid lanes compacted into val[0] when stridex is 2 or 3.
 */
inline float32x4x2_t convolve_3x3_dilation(const float *in_top, const float *in_mid, const float *in_low,
                                           const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2,
                                           const size_t dilation_x, unsigned int stridex, int input_offset = 0)
{
    ARM_COMPUTE_ERROR_ON(stridex > 3);
    float32x4x2_t out =
    {
        {
            single_convolve_3x3_dilation(in_top, in_mid, in_low, m0, m1, m2, dilation_x, input_offset),
            single_convolve_3x3_dilation(in_top + 4, in_mid + 4, in_low + 4, m0, m1, m2, dilation_x, input_offset)
        }
    };

    if(stridex == 2)
    {
        out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 2), out.val[0], 1);
        out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[1], 0), out.val[0], 2);
        out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[1], 2), out.val[0], 3);
    }
    else if(stridex == 3)
    {
        out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 3), out.val[0], 1);
    }

    return out;
}
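
// Usage sketch (illustrative, pointer names hypothetical): one call produces up
// to 8 outputs; for stridex 2 or 3 the valid lanes are compacted into out.val[0]
// so a matching store_results<stridex> writes the right number of elements.
//
//     const float32x4x2_t out = convolve_3x3_dilation(in0, in1, in2, m0, m1, m2,
//                                                     2 /* dilation_x */, 1 /* stridex */);
//     store_results<1>(out_ptr, out);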

/** Perform a 3x3 convolution on float32.
 *
 * @param[in]  in_top       Pointer to the first row of the input.
 * @param[in]  in_mid       Pointer to the second row of the input.
 * @param[in]  in_low       Pointer to the third row of the input.
 * @param[out] out_ptr      Pointer to the output.
 * @param[in]  m0           First row of the filter.
 * @param[in]  m1           Second row of the filter.
 * @param[in]  m2           Third row of the filter.
 * @param[in]  stridex      Stride value in elements across x.
 * @param[in]  input_offset (Optional) Input quantization offset.
 */
template <bool accumulate>
void convolve_3x3(const float *in_top, const float *in_mid, const float *in_low, float *out_ptr,
                  const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2,
                  unsigned int stridex, int input_offset = 0);

template <bool accumulate>
inline void convolve_3x3(const float *in_top, const float *in_mid, const float *in_low, float *out_ptr,
                         const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2,
                         unsigned int stridex, int input_offset)
{
    ARM_COMPUTE_UNUSED(input_offset);
    ARM_COMPUTE_ERROR_ON(stridex > 3);

    float32x4x2_t out =
    {
        {
            vdupq_n_f32(0.f),
            vdupq_n_f32(0.f)
        }
    };
    if(stridex == 2)
    {
        const float32x4x2_t vtop     = vld2q_f32(in_top);
        const float32x4x2_t vmid     = vld2q_f32(in_mid);
        const float32x4x2_t vlow     = vld2q_f32(in_low);
        const float32x4_t   vtop_end = vld1q_f32(in_top + 8);
        const float32x4_t   vmid_end = vld1q_f32(in_mid + 8);
        const float32x4_t   vlow_end = vld1q_f32(in_low + 8);

        out.val[0] = vmulq_f32(vtop.val[0], m0.val[0]);

        out.val[0] = vmlaq_f32(out.val[0], vtop.val[1], m0.val[1]);
        out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vtop.val[0], vtop_end, 1), m0.val[2]);

        out.val[0] = vmlaq_f32(out.val[0], vmid.val[0], m1.val[0]);
        out.val[0] = vmlaq_f32(out.val[0], vmid.val[1], m1.val[1]);
        out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vmid.val[0], vmid_end, 1), m1.val[2]);

        out.val[0] = vmlaq_f32(out.val[0], vlow.val[0], m2.val[0]);
        out.val[0] = vmlaq_f32(out.val[0], vlow.val[1], m2.val[1]);
        out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vlow.val[0], vlow_end, 1), m2.val[2]);

        accumulate ? accumulate_results<2>(out_ptr, out) : store_results<2>(out_ptr, out);
    }
    else
    {
        const float32x4x3_t vtop =
        {
            {
                vld1q_f32(in_top),
                vld1q_f32(in_top + 4),
                vld1q_f32(in_top + 8)
            }
        };
        const float32x4x3_t vmid =
        {
            {
                vld1q_f32(in_mid),
                vld1q_f32(in_mid + 4),
                vld1q_f32(in_mid + 8)
            }
        };
        const float32x4x3_t vlow =
        {
            {
                vld1q_f32(in_low),
                vld1q_f32(in_low + 4),
                vld1q_f32(in_low + 8)
            }
        };
        out.val[0] = vmulq_f32(vtop.val[0], m0.val[0]);
        out.val[1] = vmulq_f32(vtop.val[1], m0.val[0]);

        out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vtop.val[0], vtop.val[1], 1), m0.val[1]);
        out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vtop.val[0], vtop.val[1], 2), m0.val[2]);

        out.val[0] = vmlaq_f32(out.val[0], vmid.val[0], m1.val[0]);
        out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vmid.val[0], vmid.val[1], 1), m1.val[1]);
        out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vmid.val[0], vmid.val[1], 2), m1.val[2]);

        out.val[0] = vmlaq_f32(out.val[0], vlow.val[0], m2.val[0]);
        out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vlow.val[0], vlow.val[1], 1), m2.val[1]);
        out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vlow.val[0], vlow.val[1], 2), m2.val[2]);

        out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vtop.val[1], vtop.val[2], 1), m0.val[1]);
        out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vtop.val[1], vtop.val[2], 2), m0.val[2]);

        out.val[1] = vmlaq_f32(out.val[1], vmid.val[1], m1.val[0]);
        out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vmid.val[1], vmid.val[2], 1), m1.val[1]);
        out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vmid.val[1], vmid.val[2], 2), m1.val[2]);

        out.val[1] = vmlaq_f32(out.val[1], vlow.val[1], m2.val[0]);
        out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vlow.val[1], vlow.val[2], 1), m2.val[1]);
        out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vlow.val[1], vlow.val[2], 2), m2.val[2]);

        if(stridex == 3)
        {
            out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 3), out.val[0], 1);
            accumulate ? accumulate_results<3>(out_ptr, out) : store_results<3>(out_ptr, out);
        }
        else
        {
            accumulate ? accumulate_results<1>(out_ptr, out) : store_results<1>(out_ptr, out);
        }
    }
}
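
// Usage sketch (illustrative): a typical inner loop computes 8 outputs per call
// and advances the input pointers by get_input_num_elems_processed(8, stridex).
// A minimal stride-1, non-accumulating step:
//
//     convolve_3x3<false>(in_top, in_mid, in_low, out_ptr, m0, m1, m2, 1);
//     in_top += 8; in_mid += 8; in_low += 8; out_ptr += 8;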

/** Perform a 3x3 convolution for 4 consecutive 8-bit elements when dilation.x() or dilation.y() is not 1.
 *
 * @param[in] in_top       Pointer to the first row of the input.
 * @param[in] in_mid       Pointer to the second row of the input.
 * @param[in] in_low       Pointer to the third row of the input.
 * @param[in] m0           First row of the filter.
 * @param[in] m1           Second row of the filter.
 * @param[in] m2           Third row of the filter.
 * @param[in] dilation_x   Dilation, in elements across x.
 * @param[in] input_offset Input quantization offset.
 *
 * @return The convolved values.
 */
template < typename T, ARM_COMPUTE_REQUIRES_TA(std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value) >
inline int32x4_t single_convolve_3x3_dilation(const T *in_top, const T *in_mid, const T *in_low,
                                              const int32x4x3_t &m0, const int32x4x3_t &m1, const int32x4x3_t &m2,
                                              size_t dilation_x, int32_t input_offset)
{
    using VectorType    = typename std::conditional<std::is_same<T, uint8_t>::value, uint8x8x3_t, int8x8x3_t>::type;
    using OutputTagType = typename wrapper::traits::neon_bitvector_tag_t<int32_t, wrapper::traits::BitWidth::W128>;

    const int32x4_t v_input_offset = wrapper::vdup_n(input_offset, OutputTagType{});

    const VectorType vtop =
    {
        {
            wrapper::vload(in_top),
            wrapper::vload(in_top + dilation_x),
            wrapper::vload(in_top + 2 * dilation_x)
        }
    };
    const VectorType vmid =
    {
        {
            wrapper::vload(in_mid),
            wrapper::vload(in_mid + dilation_x),
            wrapper::vload(in_mid + 2 * dilation_x)
        }
    };
    const VectorType vlow =
    {
        {
            wrapper::vload(in_low),
            wrapper::vload(in_low + dilation_x),
            wrapper::vload(in_low + 2 * dilation_x)
        }
    };

    const int32x4x3_t vtop_s32 =
    {
        {
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vtop.val[0])))),
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vtop.val[1])))),
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vtop.val[2]))))
        }
    };
    const int32x4x3_t vmid_s32 =
    {
        {
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vmid.val[0])))),
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vmid.val[1])))),
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vmid.val[2]))))
        }
    };
    const int32x4x3_t vlow_s32 =
    {
        {
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vlow.val[0])))),
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vlow.val[1])))),
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vlow.val[2]))))
        }
    };

    int32x4_t out = wrapper::vmul(vtop_s32.val[0], m0.val[0]);
    out           = wrapper::vmla(out, vtop_s32.val[1], m0.val[1]);
    out           = wrapper::vmla(out, vtop_s32.val[2], m0.val[2]);

    out = wrapper::vmla(out, vmid_s32.val[0], m1.val[0]);
    out = wrapper::vmla(out, vmid_s32.val[1], m1.val[1]);
    out = wrapper::vmla(out, vmid_s32.val[2], m1.val[2]);

    out = wrapper::vmla(out, vlow_s32.val[0], m2.val[0]);
    out = wrapper::vmla(out, vlow_s32.val[1], m2.val[1]);
    out = wrapper::vmla(out, vlow_s32.val[2], m2.val[2]);

    return out;
}

/** Perform a 3x3 convolution for 8 consecutive 8-bit elements when dilation.x() or dilation.y() is not 1.
 *
 * @param[in] in_top       Pointer to the first row of the input.
 * @param[in] in_mid       Pointer to the second row of the input.
 * @param[in] in_low       Pointer to the third row of the input.
 * @param[in] m0           First row of the filter.
 * @param[in] m1           Second row of the filter.
 * @param[in] m2           Third row of the filter.
 * @param[in] dilation_x   Dilation, in elements across x.
 * @param[in] stridex      Stride value in elements across x.
 * @param[in] input_offset Input quantization offset.
 *
 * @return The convolved values, with the valid lanes compacted into val[0] when stridex is 2 or 3.
 */
template < typename T, ARM_COMPUTE_REQUIRES_TA(std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value) >
inline int32x4x2_t convolve_3x3_dilation(const T *in_top, const T *in_mid, const T *in_low, const int32x4x3_t &m0, const int32x4x3_t &m1, const int32x4x3_t &m2,
                                         const size_t dilation_x, unsigned int stridex, int input_offset)
{
    ARM_COMPUTE_ERROR_ON(stridex > 3);
    int32x4x2_t out =
    {
        {
            single_convolve_3x3_dilation(in_top, in_mid, in_low, m0, m1, m2, dilation_x, input_offset),
            single_convolve_3x3_dilation(in_top + 4, in_mid + 4, in_low + 4, m0, m1, m2, dilation_x, input_offset)
        }
    };

    if(stridex == 2)
    {
        out.val[0] = wrapper::vsetlane(wrapper::vgetlane(out.val[0], 2), out.val[0], 1);
        out.val[0] = wrapper::vsetlane(wrapper::vgetlane(out.val[1], 0), out.val[0], 2);
        out.val[0] = wrapper::vsetlane(wrapper::vgetlane(out.val[1], 2), out.val[0], 3);
    }
    else if(stridex == 3)
    {
        out.val[0] = wrapper::vsetlane(wrapper::vgetlane(out.val[0], 3), out.val[0], 1);
    }
    return out;
}

/** Perform a 3x3 convolution on 8-bit elements.
 *
 * @param[in]  in_top       Pointer to the first row of the input.
 * @param[in]  in_mid       Pointer to the second row of the input.
 * @param[in]  in_low       Pointer to the third row of the input.
 * @param[out] out_ptr      Pointer to the output.
 * @param[in]  m0           First row of the filter.
 * @param[in]  m1           Second row of the filter.
 * @param[in]  m2           Third row of the filter.
 * @param[in]  stridex      Stride value in elements across x.
 * @param[in]  input_offset Input quantization offset.
 */
template < bool accumulate, typename T1, typename T2, ARM_COMPUTE_REQUIRES_TA(std::is_same<T1, uint8_t>::value || std::is_same<T1, int8_t>::value) >
void convolve_3x3(const T1 *in_top, const T1 *in_mid, const T1 *in_low, T2 *out_ptr,
                  const int32x4x3_t &m0, const int32x4x3_t &m1, const int32x4x3_t &m2,
                  unsigned int stridex, int32_t input_offset)
{
    ARM_COMPUTE_ERROR_ON(stridex > 3);
    using VectorType    = typename std::conditional<std::is_same<T1, uint8_t>::value, uint8x8x2_t, int8x8x2_t>::type;
    using OutputTagType = typename wrapper::traits::neon_bitvector_tag_t<int32_t, wrapper::traits::BitWidth::W128>;

    const int32x4_t v_input_offset = wrapper::vdup_n(input_offset, OutputTagType{});

    const VectorType vtop =
    {
        {
            wrapper::vload(in_top),
            wrapper::vload(in_top + 8)
        }
    };
    const VectorType vmid =
    {
        {
            wrapper::vload(in_mid),
            wrapper::vload(in_mid + 8)
        }
    };
    const VectorType vlow =
    {
        {
            wrapper::vload(in_low),
            wrapper::vload(in_low + 8)
        }
    };

    const int32x4x3_t vtop_s32 =
    {
        {
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vtop.val[0])))),
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgethigh(wrapper::vmovl(vtop.val[0])))),
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vtop.val[1]))))
        }
    };
    const int32x4x3_t vmid_s32 =
    {
        {
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vmid.val[0])))),
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgethigh(wrapper::vmovl(vmid.val[0])))),
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vmid.val[1]))))
        }
    };
    const int32x4x3_t vlow_s32 =
    {
        {
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vlow.val[0])))),
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgethigh(wrapper::vmovl(vlow.val[0])))),
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vlow.val[1]))))
        }
    };

    int32x4x2_t out
    {
        {
            wrapper::vdup_n(static_cast<int32_t>(0), OutputTagType{}),
            wrapper::vdup_n(static_cast<int32_t>(0), OutputTagType{})
        }
    };

    // Outputs 0..3
    out.val[0] = wrapper::vmla(out.val[0], vtop_s32.val[0], m0.val[0]);
    out.val[0] = wrapper::vmla(out.val[0], wrapper::vext_1(vtop_s32.val[0], vtop_s32.val[1]), m0.val[1]);
    out.val[0] = wrapper::vmla(out.val[0], wrapper::vext_2(vtop_s32.val[0], vtop_s32.val[1]), m0.val[2]);

    out.val[0] = wrapper::vmla(out.val[0], vmid_s32.val[0], m1.val[0]);
    out.val[0] = wrapper::vmla(out.val[0], wrapper::vext_1(vmid_s32.val[0], vmid_s32.val[1]), m1.val[1]);
    out.val[0] = wrapper::vmla(out.val[0], wrapper::vext_2(vmid_s32.val[0], vmid_s32.val[1]), m1.val[2]);

    out.val[0] = wrapper::vmla(out.val[0], vlow_s32.val[0], m2.val[0]);
    out.val[0] = wrapper::vmla(out.val[0], wrapper::vext_1(vlow_s32.val[0], vlow_s32.val[1]), m2.val[1]);
    out.val[0] = wrapper::vmla(out.val[0], wrapper::vext_2(vlow_s32.val[0], vlow_s32.val[1]), m2.val[2]);

    // Outputs 4..7
    out.val[1] = wrapper::vmla(out.val[1], vtop_s32.val[1], m0.val[0]);
    out.val[1] = wrapper::vmla(out.val[1], wrapper::vext_1(vtop_s32.val[1], vtop_s32.val[2]), m0.val[1]);
    out.val[1] = wrapper::vmla(out.val[1], wrapper::vext_2(vtop_s32.val[1], vtop_s32.val[2]), m0.val[2]);

    out.val[1] = wrapper::vmla(out.val[1], vmid_s32.val[1], m1.val[0]);
    out.val[1] = wrapper::vmla(out.val[1], wrapper::vext_1(vmid_s32.val[1], vmid_s32.val[2]), m1.val[1]);
    out.val[1] = wrapper::vmla(out.val[1], wrapper::vext_2(vmid_s32.val[1], vmid_s32.val[2]), m1.val[2]);

    out.val[1] = wrapper::vmla(out.val[1], vlow_s32.val[1], m2.val[0]);
    out.val[1] = wrapper::vmla(out.val[1], wrapper::vext_1(vlow_s32.val[1], vlow_s32.val[2]), m2.val[1]);
    out.val[1] = wrapper::vmla(out.val[1], wrapper::vext_2(vlow_s32.val[1], vlow_s32.val[2]), m2.val[2]);

    if(stridex == 1)
    {
        accumulate ? accumulate_results<1>(out_ptr, out) : store_results<1>(out_ptr, out);
    }
    else if(stridex == 2)
    {
        out.val[0] = wrapper::vsetlane(wrapper::vgetlane(out.val[0], 2), out.val[0], 1);
        out.val[0] = wrapper::vsetlane(wrapper::vgetlane(out.val[1], 0), out.val[0], 2);
        out.val[0] = wrapper::vsetlane(wrapper::vgetlane(out.val[1], 2), out.val[0], 3);

        accumulate ? accumulate_results<2>(out_ptr, out) : store_results<2>(out_ptr, out);
    }
    else if(stridex == 3)
    {
        out.val[0] = wrapper::vsetlane(wrapper::vgetlane(out.val[0], 3), out.val[0], 1);
        accumulate ? accumulate_results<3>(out_ptr, out) : store_results<3>(out_ptr, out);
    }
}
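
// Usage sketch (illustrative, pointer names hypothetical): for quantized data the
// zero-points are folded in through weights_offset (at load time) and input_offset
// (here), and the int32 partial sums are accumulated for later requantization.
//
//     const int32x4x3_t m0 = load_matrix_row(w_row0, w_offset);
//     convolve_3x3<true>(in_top, in_mid, in_low, acc_s32, m0, m1, m2, 1, in_offset);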

#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
/** Loads one row of a 3x3 matrix, duplicating each element across all lanes (float16_t).
 *
 * @param[in] ptr            Pointer to a row of a float16_t 3x3 matrix.
 * @param[in] weights_offset (Optional) Weights quantization offset.
 *
 * @return The loaded matrix row.
 */
inline float16x8x3_t load_matrix_row(const float16_t *ptr, int weights_offset = 0)
{
    ARM_COMPUTE_UNUSED(weights_offset);
    /* ptr is a pointer to a row in a 3x3 matrix; the function returns 3 vectors holding exactly the same value in all lanes:
       r.val[0] contains the first element, r.val[1] the second element and r.val[2] the third element (in all lanes) */
    const float16x8x3_t r =
    {
        {
            vld1q_dup_f16(ptr),
            vld1q_dup_f16(1 + ptr),
            vld1q_dup_f16(2 + ptr)
        }
    };
    return r;
}

/** Perform a 3x3 convolution for 8 consecutive elements on float16 when dilation.x() or dilation.y() is not 1.
 *
 * @param[in] in_top       Pointer to the first row of the input.
 * @param[in] in_mid       Pointer to the second row of the input.
 * @param[in] in_low       Pointer to the third row of the input.
 * @param[in] m0           First row of the filter.
 * @param[in] m1           Second row of the filter.
 * @param[in] m2           Third row of the filter.
 * @param[in] dilation_x   Dilation, in elements across x.
 * @param[in] input_offset (Optional) Input quantization offset.
 *
 * @return The convolved values.
 */
inline float16x8_t single_convolve_3x3_dilation(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low,
                                                const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2,
                                                const size_t dilation_x, int input_offset = 0)
{
    ARM_COMPUTE_UNUSED(input_offset);
    const float16x8x3_t vtop =
    {
        {
            vld1q_f16(in_top),
            vld1q_f16(in_top + dilation_x),
            vld1q_f16(in_top + 2 * dilation_x)
        }
    };
    const float16x8x3_t vmid =
    {
        {
            vld1q_f16(in_mid),
            vld1q_f16(in_mid + dilation_x),
            vld1q_f16(in_mid + 2 * dilation_x)
        }
    };
    const float16x8x3_t vlow =
    {
        {
            vld1q_f16(in_low),
            vld1q_f16(in_low + dilation_x),
            vld1q_f16(in_low + 2 * dilation_x)
        }
    };
    float16x8_t out = vmulq_f16(vtop.val[0], m0.val[0]);
    out             = vaddq_f16(out, vmulq_f16(vtop.val[1], m0.val[1]));
    out             = vaddq_f16(out, vmulq_f16(vtop.val[2], m0.val[2]));

    out = vaddq_f16(out, vmulq_f16(vmid.val[0], m1.val[0]));
    out = vaddq_f16(out, vmulq_f16(vmid.val[1], m1.val[1]));
    out = vaddq_f16(out, vmulq_f16(vmid.val[2], m1.val[2]));

    out = vaddq_f16(out, vmulq_f16(vlow.val[0], m2.val[0]));
    out = vaddq_f16(out, vmulq_f16(vlow.val[1], m2.val[1]));
    out = vaddq_f16(out, vmulq_f16(vlow.val[2], m2.val[2]));

    return out;
}

/** Perform a 3x3 convolution for 16 consecutive elements on float16 when dilation.x() or dilation.y() is not 1.
 *
 * @param[in] in_top       Pointer to the first row of the input.
 * @param[in] in_mid       Pointer to the second row of the input.
 * @param[in] in_low       Pointer to the third row of the input.
 * @param[in] m0           First row of the filter.
 * @param[in] m1           Second row of the filter.
 * @param[in] m2           Third row of the filter.
 * @param[in] dilation_x   Dilation, in elements across x.
 * @param[in] stridex      Stride value in elements across x.
 * @param[in] input_offset (Optional) Input quantization offset.
 *
 * @return The convolved values, with the valid lanes compacted into val[0] when stridex is 2 or 3.
 */
inline float16x8x2_t convolve_3x3_dilation(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low,
                                           const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2,
                                           const size_t dilation_x, unsigned int stridex, int input_offset = 0)
{
    float16x8x2_t out =
    {
        {
            single_convolve_3x3_dilation(in_top, in_mid, in_low, m0, m1, m2, dilation_x, input_offset),
            single_convolve_3x3_dilation(in_top + 8, in_mid + 8, in_low + 8, m0, m1, m2, dilation_x, input_offset)
        }
    };

    if(stridex == 2)
    {
        out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 2), out.val[0], 1);
        out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 4), out.val[0], 2);
        out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 6), out.val[0], 3);
        out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 0), out.val[0], 4);
        out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 2), out.val[0], 5);
        out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 4), out.val[0], 6);
        out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 6), out.val[0], 7);
    }
    else if(stridex == 3)
    {
        out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 3), out.val[0], 1);
        out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 6), out.val[0], 2);
        out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 1), out.val[0], 3);
    }

    return out;
}

/** Perform a 3x3 convolution on float16.
 *
 * @param[in]  in_top       Pointer to the first row of the input.
 * @param[in]  in_mid       Pointer to the second row of the input.
 * @param[in]  in_low       Pointer to the third row of the input.
 * @param[out] out_ptr      Pointer to the output.
 * @param[in]  m0           First row of the filter.
 * @param[in]  m1           Second row of the filter.
 * @param[in]  m2           Third row of the filter.
 * @param[in]  stridex      Stride value in elements across x.
 * @param[in]  input_offset (Optional) Input quantization offset.
 */
template <bool accumulate>
inline void convolve_3x3(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low, float16_t *out_ptr,
                         const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2,
                         unsigned int stridex, int input_offset = 0)
{
    ARM_COMPUTE_UNUSED(input_offset);

    float16x8x2_t out =
    {
        {
            vdupq_n_f16(0),
            vdupq_n_f16(0)
        }
    };
    if(stridex == 2)
    {
        const float16x8x2_t vtop     = vld2q_f16(in_top);
        const float16x8x2_t vmid     = vld2q_f16(in_mid);
        const float16x8x2_t vlow     = vld2q_f16(in_low);
        const float16x8_t   vtop_end = vld1q_f16(in_top + 16);
        const float16x8_t   vmid_end = vld1q_f16(in_mid + 16);
        const float16x8_t   vlow_end = vld1q_f16(in_low + 16);

        out.val[0] = vmulq_f16(vtop.val[0], m0.val[0]);

        out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vtop.val[1], m0.val[1]));
        out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vtop.val[0], vtop_end, 1), m0.val[2]));

        out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vmid.val[0], m1.val[0]));
        out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vmid.val[1], m1.val[1]));
        out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vmid.val[0], vmid_end, 1), m1.val[2]));

        out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vlow.val[0], m2.val[0]));
        out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vlow.val[1], m2.val[1]));
        out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vlow.val[0], vlow_end, 1), m2.val[2]));

        accumulate ? accumulate_results<2>(out_ptr, out) : store_results<2>(out_ptr, out);
    }
    else
    {
        const float16x8x3_t vtop =
        {
            {
                vld1q_f16(in_top),
                vld1q_f16(in_top + 8),
                vld1q_f16(in_top + 16)
            }
        };
        const float16x8x3_t vmid =
        {
            {
                vld1q_f16(in_mid),
                vld1q_f16(in_mid + 8),
                vld1q_f16(in_mid + 16)
            }
        };
        const float16x8x3_t vlow =
        {
            {
                vld1q_f16(in_low),
                vld1q_f16(in_low + 8),
                vld1q_f16(in_low + 16)
            }
        };
        out.val[0] = vmulq_f16(vtop.val[0], m0.val[0]);
        out.val[1] = vmulq_f16(vtop.val[1], m0.val[0]);

        out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vtop.val[0], vtop.val[1], 1), m0.val[1]));
        out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vtop.val[0], vtop.val[1], 2), m0.val[2]));
        out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vmid.val[0], m1.val[0]));
        out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vmid.val[0], vmid.val[1], 1), m1.val[1]));
        out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vmid.val[0], vmid.val[1], 2), m1.val[2]));
        out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vlow.val[0], m2.val[0]));
        out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vlow.val[0], vlow.val[1], 1), m2.val[1]));
        out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vlow.val[0], vlow.val[1], 2), m2.val[2]));
        out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vtop.val[1], vtop.val[2], 1), m0.val[1]));
        out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vtop.val[1], vtop.val[2], 2), m0.val[2]));
        out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vmid.val[1], m1.val[0]));
        out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vmid.val[1], vmid.val[2], 1), m1.val[1]));
        out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vmid.val[1], vmid.val[2], 2), m1.val[2]));
        out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vlow.val[1], m2.val[0]));
        out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vlow.val[1], vlow.val[2], 1), m2.val[1]));
        out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vlow.val[1], vlow.val[2], 2), m2.val[2]));

        if(stridex == 3)
        {
            out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 3), out.val[0], 1);
            out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 6), out.val[0], 2);
            out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 1), out.val[0], 3);

            accumulate ? accumulate_results<3>(out_ptr, out) : store_results<3>(out_ptr, out);
        }
        else
        {
            accumulate ? accumulate_results<1>(out_ptr, out) : store_results<1>(out_ptr, out);
        }
    }
}
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */

/** Get the number of input elements processed per iteration of a 3x3 convolution.
 *
 * @param[in] num_elems_written_per_iteration Number of elements written per iteration of the 3x3 convolution.
 * @param[in] stridex                         Stride value in elements across x.
 *
 * @return The number of input elements processed.
 */
inline int get_input_num_elems_processed(unsigned int num_elems_written_per_iteration, unsigned int stridex)
{
    switch(stridex)
    {
        case 1:
            return num_elems_written_per_iteration;
        case 2:
            return num_elems_written_per_iteration << 1;
        case 3:
            return num_elems_written_per_iteration * 3;
        default:
            ARM_COMPUTE_ERROR("stridex not supported");
            return 0;
    }
}
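
// Usage sketch (illustrative): a kernel that writes 4 outputs per iteration must
// advance its input window by 4, 8 or 12 elements for stride 1, 2 or 3.
//
//     const int num_in = get_input_num_elems_processed(4, 2); // == 8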
} // namespace detail
} // namespace arm_compute
#endif /* ARM_COMPUTE_NEDIRECTCONVOLUTIONDETAIL_H */