                          const ITensorInfo *vector_sum_col,
                          const ITensorInfo *vector_sum_row,
    // Check if the input is a 3D reinterpretation
    const bool reinterpret_as_3d =
        mm_result->num_dimensions() > 1 && mm_result->tensor_shape().y() != vector_sum_row->tensor_shape().x();
    ARM_COMPUTE_RETURN_ERROR_ON(reinterpret_as_3d && vector_sum_row->dimension(0) !=
                                                         (mm_result->dimension(1) * mm_result->dimension(2)));
    const unsigned int output_batch_idx = reinterpret_as_3d ? 3 : 2;
    TensorShape vector_sum_row_shape = vector_sum_row->tensor_shape();
    vector_sum_row_shape.collapse_from(1);

    ARM_COMPUTE_RETURN_ERROR_ON_MSG(vector_sum_row_shape[1] != output_shape[output_batch_idx],
                                    "mm_result tensor must have the same number of batches as the output tensor");
    TensorShape vector_sum_col_shape = vector_sum_col->tensor_shape();
    vector_sum_col_shape.collapse_from(1);

    ARM_COMPUTE_RETURN_ERROR_ON_MSG(vector_sum_col_shape[1] != 1 &&
                                        vector_sum_col_shape[1] != vector_sum_row_shape[1],
                                    "vector_sum_col tensor must have the same number of batches as "
                                    "vector_sum_row, or its number of batches must be 1");
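// The update applied by run_offset_contribution() follows from expanding the quantized
// product. Assuming a_offset and b_offset are the values added to each element of the
// LHS and RHS (in practice the negated quantization zero points), each output element
// needs:
//
//   sum_k (A[y][k] + a_offset) * (B[k][x] + b_offset)
//     = sum_k A[y][k] * B[k][x]          // mm_result (already computed)
//     + a_offset * sum_k B[k][x]         // a_offset * vector_sum_col[x]
//     + b_offset * sum_k A[y][k]         // b_offset * vector_sum_row[y]
//     + k * a_offset * b_offset          // k_offset (constant over the whole tensor)
//
// so the kernel only has to add three precomputed terms to the raw matrix-multiply result.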
void run_offset_contribution(const Window  &window,
                             ITensor       *mm_result,
                             const ITensor *vector_sum_col,
                             const ITensor *vector_sum_row,
                             int32_t        a_offset,
                             int32_t        b_offset,
                             int32_t        k_offset,
                             bool           slide_vector_sum_col,
                             bool           is_gemm3d)
{
    Window collapsed_window = window.collapse_if_possible(window, Window::DimZ);
    collapsed_window.set(Window::DimX, Window::Dimension(0, 1, 1));

    // For a 3D-reinterpreted output, dimension 1 is the height and dimension 2 the depth
    const int height_input = is_gemm3d ? mm_result->info()->dimension(1) : 0;
    const int depth_input  = is_gemm3d ? mm_result->info()->dimension(2) : 1;

    const int window_start_x = window.x().start();
    const int window_end_x   = window.x().end();
    const int window_step_x  = 16;

    // If vector_sum_col is nullptr the stride is never used; default it to 0
    const size_t sum_col_stride_y = (vector_sum_col != nullptr) ? (vector_sum_col->info()->strides_in_bytes().y()) : 0;
    Iterator     mm_result_it(mm_result, collapsed_window);
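    // Note on the iteration strategy: Z is collapsed with any higher (batch) dimensions
    // where the strides allow it, and DimX is reduced to a single step so that each lambda
    // below can walk the whole x range itself: 16 int32 values (four int32x4 registers)
    // per vectorized step, followed by a scalar left-over loop for the tail.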
    if ((a_offset != 0) && (b_offset != 0) && (vector_sum_col != nullptr) && (vector_sum_row != nullptr))
    {
        // Set window for vector_sum_col
        Window win_vector_sum_col(collapsed_window);
        win_vector_sum_col.set(Window::DimY, Window::Dimension(0, 0, 0));
        win_vector_sum_col.set(Window::DimZ, Window::Dimension(0, 0, 0));

        // Set window for vector_sum_row
        Window win_vector_sum_row(collapsed_window);
        win_vector_sum_row.set(Window::DimX, Window::Dimension(0, 0, 0));
        win_vector_sum_row.set(Window::DimY, Window::Dimension(0, 0, 0));
        win_vector_sum_row.set(Window::DimZ, Window::Dimension(0, 0, 0));

        Iterator vector_sum_col_it(vector_sum_col, win_vector_sum_col);
        Iterator vector_sum_row_it(vector_sum_row, win_vector_sum_row);

        const size_t sum_row_stride_y = vector_sum_row->info()->strides_in_bytes().y();

        // Offset in case vector_sum_col is batched
        const int vector_sum_col_batch_offset =
            slide_vector_sum_col ? vector_sum_col->info()->strides_in_bytes().z() : 0;
        execute_window_loop(
            collapsed_window,
            [&](const Coordinates &id)
            {
                const int    batch_id           = id.z() / depth_input;
                const size_t batch_offset_col   = batch_id * sum_col_stride_y;
                auto         vector_sum_col_ptr = reinterpret_cast<const int32_t *>(
                    vector_sum_col_it.ptr() + batch_offset_col + batch_id * vector_sum_col_batch_offset);
                auto mm_result_ptr = reinterpret_cast<int32_t *>(mm_result_it.ptr());

                // Compute the leftover term due to b_offset; for a 3D-reinterpreted output the
                // row index within the batch is id.y() + (id.z() % depth_input) * height_input
                int32_t b_offset_term_s32 =
                    *(reinterpret_cast<const int32_t *>(vector_sum_row_it.ptr() + batch_id * sum_row_stride_y) +
                      id.y() + (id.z() % depth_input) * height_input);
                b_offset_term_s32 *= b_offset;
                const int32x4_t b_offset_term_s32_vec = vdupq_n_s32(b_offset_term_s32);

                int x = window_start_x;
                for (; x <= (window_end_x - window_step_x); x += window_step_x)
                {
                    // Compute the leftover term due to a_offset
                    int32x4x4_t a_offset_term_s32 = {
                        {vld1q_s32(vector_sum_col_ptr + x + 0), vld1q_s32(vector_sum_col_ptr + x + 4),
                         vld1q_s32(vector_sum_col_ptr + x + 8), vld1q_s32(vector_sum_col_ptr + x + 12)}};

                    a_offset_term_s32.val[0] = vmulq_n_s32(a_offset_term_s32.val[0], a_offset);
                    a_offset_term_s32.val[1] = vmulq_n_s32(a_offset_term_s32.val[1], a_offset);
                    a_offset_term_s32.val[2] = vmulq_n_s32(a_offset_term_s32.val[2], a_offset);
                    a_offset_term_s32.val[3] = vmulq_n_s32(a_offset_term_s32.val[3], a_offset);

                    // Combine k_offset with the a_offset and b_offset terms
                    int32x4x4_t offset_term_s32 = {
                        {vdupq_n_s32(k_offset), vdupq_n_s32(k_offset), vdupq_n_s32(k_offset), vdupq_n_s32(k_offset)}};

                    offset_term_s32.val[0] =
                        vaddq_s32(offset_term_s32.val[0], vaddq_s32(a_offset_term_s32.val[0], b_offset_term_s32_vec));
                    offset_term_s32.val[1] =
                        vaddq_s32(offset_term_s32.val[1], vaddq_s32(a_offset_term_s32.val[1], b_offset_term_s32_vec));
                    offset_term_s32.val[2] =
                        vaddq_s32(offset_term_s32.val[2], vaddq_s32(a_offset_term_s32.val[2], b_offset_term_s32_vec));
                    offset_term_s32.val[3] =
                        vaddq_s32(offset_term_s32.val[3], vaddq_s32(a_offset_term_s32.val[3], b_offset_term_s32_vec));

                    int32x4x4_t in_s32 = {{vld1q_s32(mm_result_ptr + x + 0), vld1q_s32(mm_result_ptr + x + 4),
                                           vld1q_s32(mm_result_ptr + x + 8), vld1q_s32(mm_result_ptr + x + 12)}};

                    // Add the offset terms to GEMM's result
                    in_s32.val[0] = vaddq_s32(in_s32.val[0], offset_term_s32.val[0]);
                    in_s32.val[1] = vaddq_s32(in_s32.val[1], offset_term_s32.val[1]);
                    in_s32.val[2] = vaddq_s32(in_s32.val[2], offset_term_s32.val[2]);
                    in_s32.val[3] = vaddq_s32(in_s32.val[3], offset_term_s32.val[3]);

                    // Store the result with the offset contribution
                    vst1q_s32(mm_result_ptr + x + 0, in_s32.val[0]);
                    vst1q_s32(mm_result_ptr + x + 4, in_s32.val[1]);
                    vst1q_s32(mm_result_ptr + x + 8, in_s32.val[2]);
                    vst1q_s32(mm_result_ptr + x + 12, in_s32.val[3]);
                }
                // Left-over loop: handle the remaining columns one element at a time
                for (; x < window_end_x; ++x)
                {
                    // Compute the leftover term due to a_offset
                    int32_t a_offset_term_s32 = *(vector_sum_col_ptr + x);

                    a_offset_term_s32 *= a_offset;

                    // Add the offset terms to GEMM's result and store
                    mm_result_ptr[x] += k_offset + a_offset_term_s32 + b_offset_term_s32;
                }
            },
            vector_sum_col_it, vector_sum_row_it, mm_result_it);
    }
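    // For reference, a minimal scalar sketch of the update performed by the branch above
    // (illustrative only, not part of the kernel; names are hypothetical):
    //
    //   void offset_contribution_ref(int32_t *mm_result, const int32_t *sum_col, const int32_t *sum_row,
    //                                int width, int height, int32_t a_offset, int32_t b_offset, int32_t k_offset)
    //   {
    //       for (int y = 0; y < height; ++y)
    //       {
    //           for (int x = 0; x < width; ++x)
    //           {
    //               mm_result[y * width + x] += k_offset + a_offset * sum_col[x] + b_offset * sum_row[y];
    //           }
    //       }
    //   }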
    else if ((a_offset == 0) && (b_offset != 0) && (vector_sum_row != nullptr))
    {
        // Only the row-sum term contributes: result += b_offset * vector_sum_row[y]
        // Set window for vector_sum_row
        Window win_vector_sum_row(collapsed_window);
        win_vector_sum_row.set(Window::DimX, Window::Dimension(0, 0, 0));
        win_vector_sum_row.set(Window::DimY, Window::Dimension(0, 0, 0));
        win_vector_sum_row.set(Window::DimZ, Window::Dimension(0, 0, 0));

        Iterator vector_sum_row_it(vector_sum_row, win_vector_sum_row);

        const size_t sum_row_stride_y = vector_sum_row->info()->strides_in_bytes().y();
        execute_window_loop(
            collapsed_window,
            [&](const Coordinates &id)
            {
                const int batch_id      = id.z() / depth_input;
                auto      mm_result_ptr = reinterpret_cast<int32_t *>(mm_result_it.ptr());

                // Compute the leftover term due to b_offset
                int32_t b_offset_term_s32 =
                    *(reinterpret_cast<const int32_t *>(vector_sum_row_it.ptr() + batch_id * sum_row_stride_y) +
                      id.y() + (id.z() % depth_input) * height_input);
                b_offset_term_s32 *= b_offset;

                const int32x4_t b_offset_term_s32_vec = vdupq_n_s32(b_offset_term_s32);
                int x = window_start_x;
                for (; x <= (window_end_x - window_step_x); x += window_step_x)
                {
                    int32x4x4_t in_s32 = {{vld1q_s32(mm_result_ptr + x + 0), vld1q_s32(mm_result_ptr + x + 4),
                                           vld1q_s32(mm_result_ptr + x + 8), vld1q_s32(mm_result_ptr + x + 12)}};

                    // Add the offset term to GEMM's result
                    in_s32.val[0] = vaddq_s32(in_s32.val[0], b_offset_term_s32_vec);
                    in_s32.val[1] = vaddq_s32(in_s32.val[1], b_offset_term_s32_vec);
                    in_s32.val[2] = vaddq_s32(in_s32.val[2], b_offset_term_s32_vec);
                    in_s32.val[3] = vaddq_s32(in_s32.val[3], b_offset_term_s32_vec);

                    // Store the result with the offset contribution
                    vst1q_s32(mm_result_ptr + x + 0, in_s32.val[0]);
                    vst1q_s32(mm_result_ptr + x + 4, in_s32.val[1]);
                    vst1q_s32(mm_result_ptr + x + 8, in_s32.val[2]);
                    vst1q_s32(mm_result_ptr + x + 12, in_s32.val[3]);
                }

                // Left-over loop
                for (; x < window_end_x; ++x)
                {
                    // Add the offset term to GEMM's result and store
                    mm_result_ptr[x] += b_offset_term_s32;
                }
            },
            vector_sum_row_it, mm_result_it);
    }
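    // vector_sum_col may hold either one set of column sums shared by every batch or one
    // set per batch; slide_vector_sum_col selects whether the pointer advances across
    // batches (vector_sum_col_batch_offset stays 0 when the sums are broadcast).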
    else if ((a_offset != 0) && (b_offset == 0) && (vector_sum_col != nullptr))
    {
        // Only the column-sum term contributes: result += a_offset * vector_sum_col[x]
        // Set window for vector_sum_col
        Window win_vector_sum_col(collapsed_window);
        win_vector_sum_col.set(Window::DimY, Window::Dimension(0, 0, 0));
        win_vector_sum_col.set(Window::DimZ, Window::Dimension(0, 0, 0));

        Iterator vector_sum_col_it(vector_sum_col, win_vector_sum_col);

        // Offset in case vector_sum_col is batched
        const int vector_sum_col_batch_offset =
            slide_vector_sum_col ? vector_sum_col->info()->strides_in_bytes().z() : 0;
        execute_window_loop(
            collapsed_window,
            [&](const Coordinates &id)
            {
                const int    batch_id           = id.z() / depth_input;
                const size_t batch_offset_col   = batch_id * sum_col_stride_y;
                auto         vector_sum_col_ptr = reinterpret_cast<const int32_t *>(
                    vector_sum_col_it.ptr() + batch_offset_col + batch_id * vector_sum_col_batch_offset);
                auto mm_result_ptr = reinterpret_cast<int32_t *>(mm_result_it.ptr());
                int x = window_start_x;
                for (; x <= (window_end_x - window_step_x); x += window_step_x)
                {
                    // Compute the leftover term due to a_offset
                    int32x4x4_t a_offset_term_s32 = {
                        {vld1q_s32(vector_sum_col_ptr + x + 0), vld1q_s32(vector_sum_col_ptr + x + 4),
                         vld1q_s32(vector_sum_col_ptr + x + 8), vld1q_s32(vector_sum_col_ptr + x + 12)}};

                    a_offset_term_s32.val[0] = vmulq_n_s32(a_offset_term_s32.val[0], a_offset);
                    a_offset_term_s32.val[1] = vmulq_n_s32(a_offset_term_s32.val[1], a_offset);
                    a_offset_term_s32.val[2] = vmulq_n_s32(a_offset_term_s32.val[2], a_offset);
                    a_offset_term_s32.val[3] = vmulq_n_s32(a_offset_term_s32.val[3], a_offset);

                    int32x4x4_t in_s32 = {{vld1q_s32(mm_result_ptr + x + 0), vld1q_s32(mm_result_ptr + x + 4),
                                           vld1q_s32(mm_result_ptr + x + 8), vld1q_s32(mm_result_ptr + x + 12)}};

                    // Add the offset term to GEMM's result
                    in_s32.val[0] = vaddq_s32(in_s32.val[0], a_offset_term_s32.val[0]);
                    in_s32.val[1] = vaddq_s32(in_s32.val[1], a_offset_term_s32.val[1]);
                    in_s32.val[2] = vaddq_s32(in_s32.val[2], a_offset_term_s32.val[2]);
                    in_s32.val[3] = vaddq_s32(in_s32.val[3], a_offset_term_s32.val[3]);

                    // Store the result with the offset contribution
                    vst1q_s32(mm_result_ptr + x + 0, in_s32.val[0]);
                    vst1q_s32(mm_result_ptr + x + 4, in_s32.val[1]);
                    vst1q_s32(mm_result_ptr + x + 8, in_s32.val[2]);
                    vst1q_s32(mm_result_ptr + x + 12, in_s32.val[3]);
                }

                // Left-over loop
                for (; x < window_end_x; ++x)
                {
                    const int32_t a_offset_term_s32 = *(vector_sum_col_ptr + x);

                    // Add the offset term to GEMM's result and store
                    mm_result_ptr[x] += a_offset_term_s32 * a_offset;
                }
            },
            vector_sum_col_it, mm_result_it);
    }
}
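// configure() below caches the scalar parameters; _k_offset folds the constant
// k * a_offset * b_offset term once (k being the depth of the matrix multiplication),
// so neither run loop has to recompute it per element.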
    _a_offset = a_offset;
    _b_offset = b_offset;
    _k_offset = a_offset * b_offset * k;

    ICpuKernel::configure(win);
    // Check if the matrix multiplication output is reinterpreted as 3D
    const bool reinterpret_as_3d = vector_sum_row != nullptr && mm_result->info()->num_dimensions() > 1 &&
                                   mm_result->info()->tensor_shape().y() != vector_sum_row->info()->tensor_shape().x();

    run_offset_contribution(window, mm_result, vector_sum_col, vector_sum_row, _a_offset, _b_offset, _k_offset,
                            _slide_vector_sum_col, reinterpret_as_3d);
    return "CpuGemmLowpOffsetContributionKernel";