34 #include "src/core/NEON/kernels/assembly/depthwise.hpp"
37 #include "depthwise_common.hpp"
// NOTE(review): this file is an excerpt. Each line carries its original source
// line number as a prefix, and many lines (braces, declarations, case labels)
// are absent, so it does not compile as shown.
// Fixed tensor dimension indices read by the kernel-creation helpers: channels
// are in dimension 0 and batches in dimension 3 (the layout check in this file
// only admits NHWC). idx_width / idx_height are used below but are not declared
// in the visible lines -- TODO confirm their values.
52 constexpr
unsigned int idx_channels = 0;
53 constexpr
unsigned int idx_batches = 3;
// Builds an arm_conv depthwise kernel for non-quantized element types
// TSrc/TWeights/TDst. In this file it is instantiated for float and, when
// ENABLE_FP16_KERNELS is defined, for float16_t.
//
// Geometry is read from the tensor infos and the ConvolutionInfo, then packed
// into an arm_conv::depthwise::DepthwiseArgs. That is handed to
// arm_conv::depthwise::depthwise<TSrc, TWeights, TDst>() to select a kernel.
// On success, the kernel's name() is stored in `_name` and ownership of the
// kernel is moved into `kernel`.
//
// NOTE(review): these parts are not visible in this excerpt:
//  - how a nullptr result is handled (source lines 92-96);
//  - the declarations of cpu_info, padding, activation, src_rows/src_cols and
//    dst_rows/dst_cols.
55 template <
typename TSrc,
typename TWeights,
typename TDst>
56 void create_arm_dwc(
const ITensorInfo *
src,
57 const ITensorInfo *weights,
59 const ConvolutionInfo &
info,
61 std::unique_ptr<arm_conv::depthwise::IDepthwiseCommon> &kernel,
// stride() is unpacked as (cols, rows), mirroring dilation below where x()
// feeds the column value and y() the row value.
64 unsigned int stride_cols{};
65 unsigned int stride_rows{};
66 std::tie(stride_cols, stride_rows) =
info.pad_stride_info.stride();
68 unsigned int dilation_cols =
info.dilation.x();
69 unsigned int dilation_rows =
info.dilation.y();
// Batch and channel counts come from the source tensor. Kernel extents come
// from the weights tensor, located via the idx_* dimension indices.
73 const unsigned int n_batches =
src->dimension(idx_batches);
76 const unsigned int n_channels =
src->dimension(idx_channels);
80 const unsigned int kernel_cols = weights->dimension(
idx_width);
81 const unsigned int kernel_rows = weights->dimension(
idx_height);
// DepthwiseArgs takes rows before cols for the kernel, stride, dilation and
// the source/destination spatial sizes.
85 arm_conv::depthwise::DepthwiseArgs
args(&cpu_info, kernel_rows, kernel_cols, stride_rows, stride_cols,
86 dilation_rows, dilation_cols, n_batches, src_rows, src_cols, n_channels,
87 dst_rows, dst_cols,
info.depth_multiplier, padding, activation,
nullptr);
// Ask the arm_conv factory for a kernel matching these arguments. The result
// may be nullptr, which is checked below.
90 auto dwc_kernel_asm = arm_conv::depthwise::depthwise<TSrc, TWeights, TDst>(
args);
91 if (dwc_kernel_asm ==
nullptr)
// Record the selected kernel's name for the caller, then transfer ownership.
97 _name = dwc_kernel_asm->name();
98 kernel = std::move(dwc_kernel_asm);
// Quantized counterpart of create_arm_dwc. It builds the same DepthwiseArgs,
// prepares arm_gemm::Requantize32 parameters, and then asks
// arm_conv::depthwise::depthwise<TSrc, TWeights, TDst, arm_gemm::Requantize32>()
// for a kernel.
//
// Outputs:
//   kernel       - receives ownership of the selected kernel.
//   multipliers  - resized to one entry per weight quantization scale.
//   right_shifts - resized likewise; entry i holds min(-dst_shifts[i], 0).
//   left_shifts  - resized likewise; entry i holds max(-dst_shifts[i], 0).
//   _name        - set to the selected kernel's name().
// The Requantize32 construction receives raw data() pointers into these
// vectors. That is presumably why they are caller-owned rather than locals --
// confirm their lifetime requirement.
//
// NOTE(review): these parts are not visible in this excerpt:
//  - the computation of multipliers / dst_shifts (source lines 146-149);
//  - the declaration of min_activation;
//  - the activation-bound update (lines 153-160);
//  - the condition choosing between the two Requantize32 forms;
//  - the nullptr handling.
101 template <
typename TSrc,
typename TWeights,
typename TDst>
102 void create_arm_dwc_quant(
const ITensorInfo *
src,
103 const ITensorInfo *weights,
105 const ConvolutionInfo &
info,
107 std::unique_ptr<arm_conv::depthwise::IDepthwiseCommon> &kernel,
108 std::vector<int32_t> &multipliers,
109 std::vector<int32_t> &right_shifts,
110 std::vector<int32_t> &left_shifts,
// stride() is unpacked as (cols, rows), mirroring dilation x() -> cols and
// y() -> rows.
113 unsigned int stride_cols{};
114 unsigned int stride_rows{};
115 std::tie(stride_cols, stride_rows) =
info.pad_stride_info.stride();
117 unsigned int dilation_cols =
info.dilation.x();
118 unsigned int dilation_rows =
info.dilation.y();
122 const unsigned int n_batches =
src->dimension(idx_batches);
124 const unsigned int src_cols =
src->dimension(
idx_width);
125 const unsigned int n_channels =
src->dimension(idx_channels);
127 const unsigned int dst_cols =
dst->dimension(
idx_width);
129 const unsigned int kernel_cols = weights->dimension(
idx_width);
130 const unsigned int kernel_rows = weights->dimension(
idx_height);
// DepthwiseArgs takes rows before cols for each spatial quantity.
134 arm_conv::depthwise::DepthwiseArgs
args(&cpu_info, kernel_rows, kernel_cols, stride_rows, stride_cols,
135 dilation_rows, dilation_cols, n_batches, src_rows, src_cols, n_channels,
136 dst_rows, dst_cols,
info.depth_multiplier, padding, activation,
nullptr);
// Source and destination each use a single (uniform) quantization. The
// weights keep their full QuantizationInfo so the number of scales
// (num_filters) can be read from it.
138 const auto src_qinfo =
src->quantization_info().uniform();
139 const auto weights_qinfo = weights->quantization_info();
140 const auto dst_qinfo =
dst->quantization_info().uniform();
142 const unsigned int num_filters = weights_qinfo.scale().size();
144 multipliers.resize(num_filters);
145 std::vector<int32_t> dst_shifts(num_filters);
// Output clamp bounds: the upper bound defaults to the largest TSrc value and
// is overwritten when a fused activation is enabled.
150 int32_t max_activation = std::numeric_limits<TSrc>::max();
151 if (
info.act_info.enabled())
153 std::tie(min_activation, max_activation) =
161 left_shifts.resize(num_filters);
162 right_shifts.resize(num_filters);
// Split each dst_shift into a non-negative left shift and a non-positive
// right shift. Only a negative dst_shift produces a non-zero left shift, so
// the left-shift array is passed on only when at least one such entry exists.
// Otherwise nullptr is passed.
163 bool need_left_shift =
false;
164 for (
unsigned int i = 0; i < num_filters; ++i)
166 left_shifts[i] = std::max(-dst_shifts[i],
static_cast<int32_t
>(0));
167 right_shifts[i] = std::min(-dst_shifts[i],
static_cast<int32_t
>(0));
168 if (dst_shifts[i] < 0 && !need_left_shift)
170 need_left_shift =
true;
// Array form: passes pointers to the per-filter shift and multiplier arrays.
// The start of this construction is not visible in this excerpt.
175 dst_qinfo.offset, (need_left_shift) ? left_shifts.data() :
nullptr,
176 right_shifts.data(), multipliers.data(),
177 static_cast<TSrc
>(min_activation),
static_cast<TSrc
>(max_activation));
// Scalar form: uses only the first shift and multiplier. The selecting
// condition is not visible; it is presumably a single weight scale --
// TODO confirm.
182 dst_qinfo.offset, -dst_shifts[0], multipliers[0],
183 static_cast<TSrc
>(min_activation),
static_cast<TSrc
>(max_activation));
// Ask the arm_conv factory for a requantizing kernel. The result may be
// nullptr, which is checked below.
187 auto dwc_kernel_asm =
188 arm_conv::depthwise::depthwise<TSrc, TWeights, TDst, arm_gemm::Requantize32>(
args, requant_args);
189 if (dwc_kernel_asm ==
nullptr)
// Record the selected kernel's name, then transfer ownership to the caller.
194 _name = dwc_kernel_asm->name();
195 kernel = std::move(dwc_kernel_asm);
// Start unconfigured: no assembly kernel, empty requantization buffers and an
// empty name.
200 : _kernel_asm(nullptr), _multipliers(), _left_shifts(), _right_shifts(), _name()
// Kernel configuration (fragment). It resets the kernel name and then, only
// when building for AArch64, selects an arm_conv depthwise kernel according
// to src->data_type().
// Visible instantiations:
//   quantized: <uint8_t, int8_t, uint8_t>, <uint8_t, uint8_t, uint8_t>,
//              <int8_t, int8_t, int8_t>
//   float:     <float16_t, ...> (only with ENABLE_FP16_KERNELS), <float, ...>
// The case labels that map data types to these calls are not visible here.
// No kernel-creation call is visible outside the
// `#if defined(__aarch64__)` region.
219 _name =
"CpuDepthwiseConv2dAssemblyWrapperKernel";
220 std::string asm_kernel_name(
"");
221 #if defined(__aarch64__)
222 switch (
src->data_type())
227 create_arm_dwc_quant<uint8_t, int8_t, uint8_t>(
src, weights,
dst,
info, cpu_info, _kernel_asm,
228 _multipliers, _right_shifts, _left_shifts,
233 create_arm_dwc_quant<uint8_t, uint8_t, uint8_t>(
src, weights,
dst,
info, cpu_info, _kernel_asm,
234 _multipliers, _right_shifts, _left_shifts,
239 create_arm_dwc_quant<int8_t, int8_t, int8_t>(
src, weights,
dst,
info, cpu_info, _kernel_asm, _multipliers,
240 _right_shifts, _left_shifts, asm_kernel_name);
242 #if defined(ENABLE_FP16_KERNELS)
244 create_arm_dwc<float16_t, float16_t, float16_t>(
src, weights,
dst,
info, cpu_info, _kernel_asm,
247 #endif // defined(ENABLE_FP16_KERNELS)
249 create_arm_dwc<float, float, float>(
src, weights,
dst,
info, cpu_info, _kernel_asm, asm_kernel_name);
254 #endif // defined(__aarch64__)
// Register the execution window. `win` is computed in lines not visible here.
257 ICpuKernel::configure(win);
// When a kernel was selected, append its name, producing
// "CpuDepthwiseConv2dAssemblyWrapperKernel/<asm kernel name>".
258 if (_kernel_asm !=
nullptr)
260 _name +=
"/" + asm_kernel_name;
// Validation (fragment). Visible pieces:
//  * a block compiled only on non-AArch64 builds (its body is not visible);
//  * the assembly kernels accept only the NHWC data layout;
//  * further checks apply when dst is already initialized (total_size() > 0);
//  * the dilated kernel extent is computed per axis and compared against the
//    padding. Judging by the trailing condition, padding that reaches or
//    exceeds that extent is rejected -- TODO confirm the full predicate.
272 #if !defined(__aarch64__)
274 #endif // !defined(__aarch64__)
279 "Only NHWC is supported by assembly kernels");
306 if (
dst->total_size() > 0)
314 const auto &padding =
info.pad_stride_info;
315 const auto &dilation =
info.dilation;
// Effective kernel size with dilation gaps inserted: k + (k - 1) * (d - 1).
// wei_shape[1] is paired with dilation.x() (width) and wei_shape[2] with
// dilation.y() (height).
318 const auto dilated_wei_w = wei_shape[1] + (wei_shape[1] - 1) * (dilation.x() - 1);
319 const auto dilated_wei_h = wei_shape[2] + (wei_shape[2] - 1) * (dilation.y() - 1);
322 padding.pad_top() >= dilated_wei_h || padding.pad_bottom() >= dilated_wei_h);
// Execution (fragment). It resolves the first-element addresses of src and
// dst and derives the leading-dimension strides the assembly kernel expects.
// It then runs the kernel as thread info.thread_id of info.num_threads.
341 const auto src_ptr =
src->buffer() +
src->info()->offset_first_element_in_bytes();
342 auto dst_ptr =
dst->buffer() +
dst->info()->offset_first_element_in_bytes();
346 const auto src_shape =
src->info()->tensor_shape();
348 const auto src_padding =
src->info()->padding();
349 const auto dst_padding =
dst->info()->padding();
// Strides include each tensor's padding:
//   col   = shape[0] + left + right
//   row   = col * (shape[1] + top + bottom)
//   batch = row * shape[2]
// For NHWC, shape[0..2] are presumably C, W, H. The strides are presumably in
// elements rather than bytes -- confirm against IDepthwiseCommon::execute.
351 const size_t ld_src_col = src_shape[0] + src_padding.left + src_padding.right;
352 const size_t ld_src_row = ld_src_col * (src_shape[1] + src_padding.top + src_padding.bottom);
353 const size_t ld_src_batch = ld_src_row * src_shape[2];
354 const size_t ld_dst_col =
dst_shape[0] + dst_padding.left + dst_padding.right;
355 const size_t ld_dst_row = ld_dst_col * (
dst_shape[1] + dst_padding.top + dst_padding.bottom);
356 const size_t ld_dst_batch = ld_dst_row *
dst_shape[2];
// parameters_ptr, working_space and dst_shape are set up in lines not visible
// in this excerpt.
358 _kernel_asm->execute(src_ptr, ld_src_col, ld_src_row, ld_src_batch, parameters_ptr, dst_ptr, ld_dst_col, ld_dst_row,
359 ld_dst_batch, working_space,
info.thread_id,
info.num_threads);
363 void *parameters_ptr,
void *bias_ptr,
void *weights_ptr,
size_t ld_weights_col,
size_t ld_weight_row)
// Forward weight/bias packing unchanged to the selected assembly kernel.
// NOTE(review): _kernel_asm is dereferenced without a null check; callers
// presumably ensure is_configured() first.
365 _kernel_asm->pack_parameters(parameters_ptr, bias_ptr, weights_ptr, ld_weights_col, ld_weight_row);
// Size of the packed-parameter storage reported by the selected kernel.
// _kernel_asm is dereferenced unchecked.
370 return _kernel_asm->get_storage_size();
// Working-space size the selected kernel reports for num_threads threads.
// _kernel_asm is dereferenced unchecked.
375 return _kernel_asm->get_working_size(num_threads);
// True when an assembly kernel has been selected (_kernel_asm is non-null).
380 return _kernel_asm !=
nullptr;
// Returns the kernel name. The pointer is owned by _name and is invalidated
// when _name is next modified (e.g. by configure()).
385 return _name.c_str();