Status validate_arguments(const ITensorInfo *mm_result, const ITensorInfo *vector_sum_col, const ITensorInfo *vector_sum_row,
                          int32_t a_offset, int32_t b_offset)
{
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(mm_result, 1, DataType::S32);

    // If a_offset == 0, vector_sum_col can be a nullptr
    if(a_offset != 0)
    {
        ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(vector_sum_col, 1, DataType::S32);
        ARM_COMPUTE_RETURN_ERROR_ON(vector_sum_col->dimension(0) != mm_result->dimension(0));
    }

    // If b_offset == 0, vector_sum_row can be a nullptr
    if(b_offset != 0)
    {
        ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(vector_sum_row, 1, DataType::S32);
        // Check if the input is a 3D reinterpretation
        const bool reinterpret_as_3d = mm_result->num_dimensions() > 1 && mm_result->tensor_shape().y() != vector_sum_row->tensor_shape().x();

        // Validate input: vector_sum_row must cover every row of the (possibly 3D-reinterpreted) output
        ARM_COMPUTE_RETURN_ERROR_ON(reinterpret_as_3d && vector_sum_row->dimension(0) != (mm_result->dimension(1) * mm_result->dimension(2)));
        ARM_COMPUTE_RETURN_ERROR_ON(!reinterpret_as_3d && vector_sum_row->dimension(0) != mm_result->dimension(1));
        TensorShape output_shape = mm_result->tensor_shape();
        if(output_shape.num_dimensions() > 1)
        {
            const unsigned int output_batch_idx = reinterpret_as_3d ? 3 : 2;

            TensorShape vector_sum_row_shape = vector_sum_row->tensor_shape();
            vector_sum_row_shape.collapse_from(1);
            output_shape.collapse_from(output_batch_idx);

            ARM_COMPUTE_RETURN_ERROR_ON_MSG(vector_sum_row_shape[1] != output_shape[output_batch_idx],
                                            "mm_result tensor must have the same number of batches as the output tensor");

            if(a_offset != 0)
            {
                TensorShape vector_sum_col_shape = vector_sum_col->tensor_shape();
                vector_sum_col_shape.collapse_from(1);

                ARM_COMPUTE_RETURN_ERROR_ON_MSG(vector_sum_col_shape[1] != 1 && vector_sum_col_shape[1] != vector_sum_row_shape[1],
                                                "vector_sum_col tensor must have the same number of batches as vector_sum_row or a single batch");
            }
        }
    }

    return Status{};
}
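// Adds the quantization offset contribution to the int32 GEMM accumulators:
//
//   mm_result[x, y] += a_offset * vector_sum_col[x] + b_offset * vector_sum_row[y] + k_offset
//
// where vector_sum_col holds the column sums of matrix B, vector_sum_row holds the
// row sums of matrix A, and k_offset = a_offset * b_offset * k is the constant cross
// term of the offset expansion (precomputed in configure()).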
void run_offset_contribution(const Window &window,
                             ITensor *mm_result, const ITensor *vector_sum_col, const ITensor *vector_sum_row,
                             int32_t a_offset, int32_t b_offset, int32_t k_offset, bool slide_vector_sum_col, bool is_gemm3d)
{
    Window collapsed_window = window.collapse_if_possible(window, Window::DimZ);
    collapsed_window.set(Window::DimX, Window::Dimension(0, 1, 1));

    const int height_input = is_gemm3d ? mm_result->info()->dimension(1) : 0;
    const int depth_input  = is_gemm3d ? mm_result->info()->dimension(2) : 1;

    const int window_start_x = window.x().start();
    const int window_end_x   = window.x().end();
    const int window_step_x  = 16;

    Iterator mm_result_it(mm_result, collapsed_window);
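    // Dispatch to a specialised loop depending on which offsets are non-zero, so the
    // common cases (only column sums or only row sums needed) avoid touching the unused vector.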
    if((a_offset != 0) && (b_offset != 0) && (vector_sum_col != nullptr) && (vector_sum_row != nullptr)) // true, true
    {
        // Set window for vector_sum_col
        Window win_vector_sum_col(collapsed_window);
        win_vector_sum_col.set(Window::DimY, Window::Dimension(0, 0, 0));
        win_vector_sum_col.set(Window::DimZ, Window::Dimension(0, 0, 0));

        // Set window for vector_sum_row
        Window win_vector_sum_row(collapsed_window);
        win_vector_sum_row.set(Window::DimX, Window::Dimension(0, 0, 0));
        win_vector_sum_row.set(Window::DimY, Window::Dimension(0, 0, 0));
        win_vector_sum_row.set(Window::DimZ, Window::Dimension(0, 0, 0));

        Iterator vector_sum_col_it(vector_sum_col, win_vector_sum_col);
        Iterator vector_sum_row_it(vector_sum_row, win_vector_sum_row);

        const size_t sum_row_stride_y = vector_sum_row->info()->strides_in_bytes().y();

        // Offset in case vector_sum_col is batched
        const int vector_sum_col_batch_offset = slide_vector_sum_col ? vector_sum_col->info()->strides_in_bytes().z() : 0;
        execute_window_loop(collapsed_window, [&](const Coordinates & id)
        {
            const int batch_id           = id.z() / depth_input;
            auto      vector_sum_col_ptr = reinterpret_cast<const int32_t *>(vector_sum_col_it.ptr() + batch_id * vector_sum_col_batch_offset);
            auto      mm_result_ptr      = reinterpret_cast<int32_t *>(mm_result_it.ptr());

            // Compute the leftover term due to b_offset.
            int32_t b_offset_term_s32 = *(reinterpret_cast<const int32_t *>(vector_sum_row_it.ptr() + batch_id * sum_row_stride_y) + id.y() + (id.z() % depth_input) * height_input);
            b_offset_term_s32 *= b_offset;
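            // Note: for a 3D-reinterpreted output, id.y() + (id.z() % depth_input) * height_input
            // recovers the row of the original 2D GEMM result, while id.z() / depth_input selects
            // the batch; for a plain 2D output, depth_input is 1 and this reduces to id.y().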
            const int32x4_t b_offset_term_s32_vec = vdupq_n_s32(b_offset_term_s32);

            int x = window_start_x;
            for(; x <= (window_end_x - window_step_x); x += window_step_x)
            {
                // Compute the leftover term due to a_offset.
                int32x4x4_t a_offset_term_s32 =
                {
                    {
                        vld1q_s32(vector_sum_col_ptr + x + 0),
                        vld1q_s32(vector_sum_col_ptr + x + 4),
                        vld1q_s32(vector_sum_col_ptr + x + 8),
                        vld1q_s32(vector_sum_col_ptr + x + 12)
                    }
                };

                a_offset_term_s32.val[0] = vmulq_n_s32(a_offset_term_s32.val[0], a_offset);
                a_offset_term_s32.val[1] = vmulq_n_s32(a_offset_term_s32.val[1], a_offset);
                a_offset_term_s32.val[2] = vmulq_n_s32(a_offset_term_s32.val[2], a_offset);
                a_offset_term_s32.val[3] = vmulq_n_s32(a_offset_term_s32.val[3], a_offset);
                int32x4x4_t offset_term_s32 =
                {
                    {
                        vdupq_n_s32(k_offset),
                        vdupq_n_s32(k_offset),
                        vdupq_n_s32(k_offset),
                        vdupq_n_s32(k_offset)
                    }
                };

                offset_term_s32.val[0] = vaddq_s32(offset_term_s32.val[0], vaddq_s32(a_offset_term_s32.val[0], b_offset_term_s32_vec));
                offset_term_s32.val[1] = vaddq_s32(offset_term_s32.val[1], vaddq_s32(a_offset_term_s32.val[1], b_offset_term_s32_vec));
                offset_term_s32.val[2] = vaddq_s32(offset_term_s32.val[2], vaddq_s32(a_offset_term_s32.val[2], b_offset_term_s32_vec));
                offset_term_s32.val[3] = vaddq_s32(offset_term_s32.val[3], vaddq_s32(a_offset_term_s32.val[3], b_offset_term_s32_vec));

                int32x4x4_t in_s32 =
                {
                    {
                        vld1q_s32(mm_result_ptr + x + 0),
                        vld1q_s32(mm_result_ptr + x + 4),
                        vld1q_s32(mm_result_ptr + x + 8),
                        vld1q_s32(mm_result_ptr + x + 12)
                    }
                };

                // Add the offset terms to the matrix multiplication result
                in_s32.val[0] = vaddq_s32(in_s32.val[0], offset_term_s32.val[0]);
                in_s32.val[1] = vaddq_s32(in_s32.val[1], offset_term_s32.val[1]);
                in_s32.val[2] = vaddq_s32(in_s32.val[2], offset_term_s32.val[2]);
                in_s32.val[3] = vaddq_s32(in_s32.val[3], offset_term_s32.val[3]);

                // Store the result with the offset contribution
                vst1q_s32(mm_result_ptr + x + 0, in_s32.val[0]);
                vst1q_s32(mm_result_ptr + x + 4, in_s32.val[1]);
                vst1q_s32(mm_result_ptr + x + 8, in_s32.val[2]);
                vst1q_s32(mm_result_ptr + x + 12, in_s32.val[3]);
            }
            // Left-overs loop: process the remaining columns one element at a time
            for(; x < window_end_x; ++x)
            {
                // Compute the leftover term due to a_offset.
                int32_t a_offset_term_s32 = *(vector_sum_col_ptr + x);

                a_offset_term_s32 *= a_offset;

                // Add the offset terms to the matrix multiplication result and store
                mm_result_ptr[x] += k_offset + a_offset_term_s32 + b_offset_term_s32;
            }
        },
        vector_sum_col_it, vector_sum_row_it, mm_result_it);
    }
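    // Only the row sums contribute here: with a_offset == 0, both the per-column term
    // and the constant k_offset = a_offset * b_offset * k vanish.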
    else if((a_offset == 0) && (b_offset != 0) && (vector_sum_row != nullptr)) // false, true
    {
        // Set window for vector_sum_row
        Window win_vector_sum_row(collapsed_window);
        win_vector_sum_row.set(Window::DimX, Window::Dimension(0, 0, 0));
        win_vector_sum_row.set(Window::DimY, Window::Dimension(0, 0, 0));
        win_vector_sum_row.set(Window::DimZ, Window::Dimension(0, 0, 0));

        Iterator vector_sum_row_it(vector_sum_row, win_vector_sum_row);

        const size_t sum_row_stride_y = vector_sum_row->info()->strides_in_bytes().y();
        execute_window_loop(collapsed_window, [&](const Coordinates & id)
        {
            const int batch_id      = id.z() / depth_input;
            auto      mm_result_ptr = reinterpret_cast<int32_t *>(mm_result_it.ptr());

            // Compute the leftover term due to b_offset.
            int32_t b_offset_term_s32 = *(reinterpret_cast<const int32_t *>(vector_sum_row_it.ptr() + batch_id * sum_row_stride_y) + id.y() + (id.z() % depth_input) * height_input);
            b_offset_term_s32 *= b_offset;

            const int32x4_t b_offset_term_s32_vec = vdupq_n_s32(b_offset_term_s32);
            int x = window_start_x;
            for(; x <= (window_end_x - window_step_x); x += window_step_x)
            {
                int32x4x4_t in_s32 =
                {
                    {
                        vld1q_s32(mm_result_ptr + x + 0),
                        vld1q_s32(mm_result_ptr + x + 4),
                        vld1q_s32(mm_result_ptr + x + 8),
                        vld1q_s32(mm_result_ptr + x + 12)
                    }
                };
                // Add the offset term to the matrix multiplication result
                in_s32.val[0] = vaddq_s32(in_s32.val[0], b_offset_term_s32_vec);
                in_s32.val[1] = vaddq_s32(in_s32.val[1], b_offset_term_s32_vec);
                in_s32.val[2] = vaddq_s32(in_s32.val[2], b_offset_term_s32_vec);
                in_s32.val[3] = vaddq_s32(in_s32.val[3], b_offset_term_s32_vec);

                // Store the result with the offset contribution
                vst1q_s32(mm_result_ptr + x + 0, in_s32.val[0]);
                vst1q_s32(mm_result_ptr + x + 4, in_s32.val[1]);
                vst1q_s32(mm_result_ptr + x + 8, in_s32.val[2]);
                vst1q_s32(mm_result_ptr + x + 12, in_s32.val[3]);
            }

            // Left-overs loop
            for(; x < window_end_x; ++x)
            {
                // Add the offset term to the matrix multiplication result and store
                mm_result_ptr[x] += b_offset_term_s32;
            }
        },
        vector_sum_row_it, mm_result_it);
    }
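    // Only the column sums contribute here: with b_offset == 0, the per-row term and
    // k_offset vanish, so each column x just receives a_offset * vector_sum_col[x].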
    else if((a_offset != 0) && (b_offset == 0) && (vector_sum_col != nullptr)) // true, false
    {
        // Set window for vector_sum_col
        Window win_vector_sum_col(collapsed_window);
        win_vector_sum_col.set(Window::DimY, Window::Dimension(0, 0, 0));
        win_vector_sum_col.set(Window::DimZ, Window::Dimension(0, 0, 0));

        Iterator vector_sum_col_it(vector_sum_col, win_vector_sum_col);

        // Offset in case vector_sum_col is batched
        const int vector_sum_col_batch_offset = slide_vector_sum_col ? vector_sum_col->info()->strides_in_bytes().z() : 0;
        execute_window_loop(collapsed_window, [&](const Coordinates & id)
        {
            const int batch_id           = id.z() / depth_input;
            auto      vector_sum_col_ptr = reinterpret_cast<const int32_t *>(vector_sum_col_it.ptr() + batch_id * vector_sum_col_batch_offset);
            auto      mm_result_ptr      = reinterpret_cast<int32_t *>(mm_result_it.ptr());
            int x = window_start_x;
            for(; x <= (window_end_x - window_step_x); x += window_step_x)
            {
                // Compute the leftover term due to a_offset.
                int32x4x4_t a_offset_term_s32 =
                {
                    {
                        vld1q_s32(vector_sum_col_ptr + x + 0),
                        vld1q_s32(vector_sum_col_ptr + x + 4),
                        vld1q_s32(vector_sum_col_ptr + x + 8),
                        vld1q_s32(vector_sum_col_ptr + x + 12)
                    }
                };

                a_offset_term_s32.val[0] = vmulq_n_s32(a_offset_term_s32.val[0], a_offset);
                a_offset_term_s32.val[1] = vmulq_n_s32(a_offset_term_s32.val[1], a_offset);
                a_offset_term_s32.val[2] = vmulq_n_s32(a_offset_term_s32.val[2], a_offset);
                a_offset_term_s32.val[3] = vmulq_n_s32(a_offset_term_s32.val[3], a_offset);

                int32x4x4_t in_s32 =
                {
                    {
                        vld1q_s32(mm_result_ptr + x + 0),
                        vld1q_s32(mm_result_ptr + x + 4),
                        vld1q_s32(mm_result_ptr + x + 8),
                        vld1q_s32(mm_result_ptr + x + 12)
                    }
                };
                // Add the offset term to the matrix multiplication result
                in_s32.val[0] = vaddq_s32(in_s32.val[0], a_offset_term_s32.val[0]);
                in_s32.val[1] = vaddq_s32(in_s32.val[1], a_offset_term_s32.val[1]);
                in_s32.val[2] = vaddq_s32(in_s32.val[2], a_offset_term_s32.val[2]);
                in_s32.val[3] = vaddq_s32(in_s32.val[3], a_offset_term_s32.val[3]);

                // Store the result with the offset contribution
                vst1q_s32(mm_result_ptr + x + 0, in_s32.val[0]);
                vst1q_s32(mm_result_ptr + x + 4, in_s32.val[1]);
                vst1q_s32(mm_result_ptr + x + 8, in_s32.val[2]);
                vst1q_s32(mm_result_ptr + x + 12, in_s32.val[3]);
            }

            // Left-overs loop
            for(; x < window_end_x; ++x)
            {
                const int32_t a_offset_term_s32 = *(vector_sum_col_ptr + x);

                // Add the offset term to the matrix multiplication result and store
                mm_result_ptr[x] += a_offset_term_s32 * a_offset;
            }
        },
        vector_sum_col_it, mm_result_it);
    }
    else // false, false
    {
        // No offset contribution from matrix A or matrix B: nothing to do
        return;
    }
}
} // namespace

void CpuGemmLowpOffsetContributionKernel::configure(ITensorInfo *mm_result, ITensorInfo *vector_sum_col, ITensorInfo *vector_sum_row,
                                                    int32_t k, int32_t a_offset, int32_t b_offset)
{
    // Perform validate step
    ARM_COMPUTE_ERROR_ON_NULLPTR(mm_result);
    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(mm_result, vector_sum_col, vector_sum_row, a_offset, b_offset));

    _a_offset = a_offset;
    _b_offset = b_offset;
    // Precompute the constant cross term added to every element, where k is the
    // depth of the matrix multiplication (number of accumulated products)
    _k_offset = a_offset * b_offset * k;

    // If a_offset == 0, vector_sum_col can be a nullptr
    if(a_offset != 0)
    {
        // Slide vector_sum_col along the batch dimension only if it is batched itself
        _slide_vector_sum_col = vector_sum_col->tensor_shape().num_dimensions() > 1;
    }

    // Configure kernel window
    Window win = calculate_max_window(*mm_result, Steps());
    ICpuKernel::configure(win);
}
Status CpuGemmLowpOffsetContributionKernel::validate(const ITensorInfo *mm_result, const ITensorInfo *vector_sum_col, const ITensorInfo *vector_sum_row,
                                                     int32_t a_offset, int32_t b_offset)
{
    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(mm_result, vector_sum_col, vector_sum_row, a_offset, b_offset));
    return Status{};
}

void CpuGemmLowpOffsetContributionKernel::run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info)
{
    ARM_COMPUTE_UNUSED(info);
    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(IKernel::window(), window);

    auto vector_sum_col = tensors.get_const_tensor(TensorType::ACL_SRC_0);
    auto vector_sum_row = tensors.get_const_tensor(TensorType::ACL_SRC_1);
    auto mm_result      = tensors.get_tensor(TensorType::ACL_DST);

    // Check if the matrix multiplication output is reinterpreted as a 3D tensor
    const bool reinterpret_as_3d = vector_sum_row != nullptr
                                   && mm_result->info()->num_dimensions() > 1
                                   && mm_result->info()->tensor_shape().y() != vector_sum_row->info()->tensor_shape().x();

    run_offset_contribution(window, mm_result, vector_sum_col, vector_sum_row, _a_offset, _b_offset, _k_offset, _slide_vector_sum_col, reinterpret_as_3d);
}

const char *CpuGemmLowpOffsetContributionKernel::name() const
{
    return "CpuGemmLowpOffsetContributionKernel";
}
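// Usage sketch (illustrative only; tensor creation, allocation and scheduling follow
// the usual Compute Library flow and are omitted here):
//
//   CpuGemmLowpOffsetContributionKernel kernel;
//   kernel.configure(mm_result.info(), vector_sum_col.info(), vector_sum_row.info(),
//                    k, a_offset, b_offset);
//
//   ITensorPack pack;
//   pack.add_const_tensor(TensorType::ACL_SRC_0, &vector_sum_col);
//   pack.add_const_tensor(TensorType::ACL_SRC_1, &vector_sum_row);
//   pack.add_tensor(TensorType::ACL_DST, &mm_result);
//   kernel.run_op(pack, kernel.window(), ThreadInfo{});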