92 #include "depthwise.hpp"
// NOTE(review): this text is a lossy, line-filtered extraction. Original line
// numbers are fused into the code and intervening lines are missing. The
// enclosing declaration, data members and constructor body are not visible here.
// Constructor fragment: WorkspaceArgs takes a strategy pointer, the depthwise
// arguments and an output stage. The output stage defaults to `{}`, with
// OutputStage itself defaulting to Nothing.
// NOTE(review): if `os` is stored in a reference member, the `= {}` default
// binds to a temporary that dies at the end of the constructing
// full-expression. The stored reference would then dangle. Confirm how the
// members are declared.
102 template <
class StratType,
class OutputStage=Nothing>
109 WorkspaceArgs(
const StratType *strat,
const DepthwiseArgs &dwargs,
const OutputStage &os = {})
// Size query for an element that reserves no extra buffer space: always
// returns 0. The enclosing class is not visible in this extraction; it is
// presumably EmptyElement (confirm).
124 template <
class StratType,
class OutputStage>
125 static size_t get_element_size(
const WorkspaceArgs<StratType, OutputStage> &) {
return 0; }
// Initialiser for the same no-space element. The workspace pointer and args
// are unnamed, so they are not used. The body is not visible here. Because the
// element reserves 0 bytes, it presumably returns `buffer` unchanged (verify).
127 template <
class WorkspaceType,
class StratType,
class OutputStage>
128 static void *initialise(WorkspaceType *,
void *buffer,
const WorkspaceArgs<StratType, OutputStage> &)
// ActivationsElement<T, OutputStage>: workspace element holding the activation
// clamp bounds (ws->activation_min / ws->activation_max). Its initialise member
// writes these. OutputStage defaults to Nothing. Requantize32 has its own
// specialisation.
139 template <
typename T,
class OutputStage=Nothing>
140 class ActivationsElement
// Size query. The body is not visible in this extraction. The bounds appear to
// live in the workspace struct itself rather than in the trailing buffer
// (confirm).
148 template <
typename StratType>
149 static size_t get_element_size(
const WorkspaceArgs<StratType, OutputStage> &)
// Sets ws->activation_min / ws->activation_max from the configured activation.
// Both start at -inf / +inf (converted to T) and are then narrowed according to
// args.depthwise_args.activation.type. One case sets the max from param1;
// another sets the min to 0. The case labels, any fall-through between them,
// and the return statement are not visible in this extraction.
// NOTE(review): converting +/-infinity to an integral T with static_cast is
// undefined behaviour. Quantised output seems to go to the Requantize32
// specialisation instead; confirm no integral T reaches this path.
154 template <
class WorkspaceType,
class StratType>
155 static void *initialise(WorkspaceType *ws,
void *buffer,
const WorkspaceArgs<StratType, OutputStage> &
args)
// Start unbounded; the switch below tightens the bounds.
157 ws->activation_min =
static_cast<T
>(-std::numeric_limits<float>::infinity());
158 ws->activation_max =
static_cast<T
>(std::numeric_limits<float>::infinity());
160 switch (
args.depthwise_args.activation.type)
// Upper bound taken from param1. Presumably the bounded-ReLU case (confirm).
163 ws->activation_max =
static_cast<T
>(
args.depthwise_args.activation.param1);
// Lower bound of zero. Presumably the ReLU case, possibly also reached by
// fall-through from the case above (confirm).
166 ws->activation_min =
static_cast<T
>(0);
// Specialisation for requantised output. It derives from EmptyElement, so it
// reuses that element's size/initialise members. Its own body is not visible;
// clamping is presumably left to the Requantize32 output stage (confirm).
179 template <
typename T>
180 class ActivationsElement<T,
arm_gemm::Requantize32> :
public EmptyElement
// Byte value used to fill the input buffer. InputBufferElement::initialise
// passes it to memset. This generic overload ignores its argument. The body
// (and any Requantize32-specific overload) is not visible here. It presumably
// returns 0 for non-quantised stages (confirm).
188 template <
typename OutputStage>
189 char get_input_buffer_fill_value(
const OutputStage &)
// InputBufferElement<T>: reserves a buffer of T in the trailing workspace
// memory and exposes it as ws->input_buffer.
208 template <
typename T>
209 class InputBufferElement
// Bytes reserved: sizeof(T) * input_channels * channel_multiplier, i.e. one T
// per output channel.
217 template <
typename StratType,
typename OutputStage>
218 static size_t get_element_size(
const WorkspaceArgs<StratType, OutputStage> &
args)
220 return sizeof(T) *
args.depthwise_args.input_channels *
args.depthwise_args.channel_multiplier;
// Points ws->input_buffer at `buffer` and fills the whole element with
// get_input_buffer_fill_value(args.output_stage). Returns the cursor advanced
// past this element.
// NOTE(review): memset repeats a single byte. When sizeof(T) > 1, only a fill
// value of 0 produces the matching T value. Confirm non-zero fills are only
// used with byte-sized T.
223 template <
class WorkspaceType,
typename StratType,
typename OutputStage>
224 static void *initialise(WorkspaceType *ws,
void *buffer,
const WorkspaceArgs<StratType, OutputStage> &
args)
226 ws->input_buffer =
reinterpret_cast<T*
>(buffer);
227 memset(ws->input_buffer, get_input_buffer_fill_value(
args.output_stage), get_element_size(
args));
// Hand the next element the first byte after this one.
228 return reinterpret_cast<char *
>(buffer) + get_element_size(
args);
// OutputArrayElement<T>: an array of output pointers (ws->outptr_array),
// followed by an output buffer of T (ws->output_buffer).
// Only defined for WorkspaceArgs<IDepthfirstStrategy, ...>, because the pointer
// array is sized from the strategy's output tile.
236 template <
typename T>
237 class OutputArrayElement
// Bytes reserved: pointer array plus output buffer.
246 template <
typename OutputStage>
247 static size_t get_element_size(
const WorkspaceArgs<IDepthfirstStrategy, OutputStage> &
args)
249 return sizeof_outptr_array(
args) + sizeof_output_buffer(
args);
// Splits this element's bytes into two parts: first ws->outptr_array (T**),
// then ws->output_buffer (T*). A byte cursor is advanced past each part. The
// return statement is not visible here. It presumably returns the advanced
// cursor, as the other elements do (confirm).
252 template <
class WorkspaceType,
typename OutputStage>
253 static void *initialise(WorkspaceType *ws,
void *buffer,
const WorkspaceArgs<IDepthfirstStrategy, OutputStage> &
args)
255 char *buffer_bytes =
reinterpret_cast<char *
>(buffer);
257 ws->outptr_array =
reinterpret_cast<T **
>(buffer_bytes);
258 buffer_bytes += sizeof_outptr_array(
args);
260 ws->output_buffer =
reinterpret_cast<T *
>(buffer_bytes);
261 buffer_bytes += sizeof_output_buffer(
args);
// Bytes for the output-pointer array: one pointer per point of the strategy's
// output tile (get_output_rows() * get_output_cols()).
// NOTE(review): outptr_array is T**, so each element is a T*. sizeof(T *) would
// state the intent more precisely than sizeof(T **). Both are pointer-sized on
// mainstream ABIs.
267 template <
typename OutputStage>
268 static size_t sizeof_outptr_array(
const WorkspaceArgs<IDepthfirstStrategy, OutputStage> &
args)
270 return sizeof(T **) *
args.strategy->get_output_rows() *
args.strategy->get_output_cols();
// Bytes for the output buffer: one T per output channel
// (input_channels * channel_multiplier).
273 template <
typename OutputStage>
274 static size_t sizeof_output_buffer(
const WorkspaceArgs<IDepthfirstStrategy, OutputStage> &
args)
276 return sizeof(T) *
args.depthwise_args.input_channels *
args.depthwise_args.channel_multiplier;
// IntermediateBufferElement<T>: a scratch buffer of T exposed as
// ws->intermediate_buffer.
284 template <
typename T>
285 class IntermediateBufferElement
// Bytes reserved: sizeof(T) * (input_cols + kernel_cols)
//   * (strategy->get_input_rows() + kernel_rows)
//   * (input_channels * channel_multiplier).
// NOTE(review): columns are sized from the full input width, while rows come
// from the strategy's input tile. This is presumably intentional (a full-width
// strip of rows); confirm.
293 template <
typename StratType,
typename OutputStage>
294 static size_t get_element_size(
const WorkspaceArgs<StratType, OutputStage> &
args)
296 auto cols =
args.depthwise_args.input_cols +
args.depthwise_args.kernel_cols;
297 auto rows =
args.strategy->get_input_rows() +
args.depthwise_args.kernel_rows;
298 auto channels =
args.depthwise_args.input_channels *
args.depthwise_args.channel_multiplier;
// sizeof(T) comes first, so the product is evaluated in size_t.
299 return sizeof(T) *
cols *
rows * channels;
// Points ws->intermediate_buffer at `buffer` and returns the cursor advanced by
// get_element_size(args). The visible lines do not clear or fill the buffer
// contents.
302 template <
class WorkspaceType,
typename StratType,
typename OutputStage>
303 static void *initialise(WorkspaceType *ws,
void *buffer,
const WorkspaceArgs<StratType, OutputStage> &
args)
305 ws->intermediate_buffer =
reinterpret_cast<T*
>(buffer);
306 return reinterpret_cast<char *
>(buffer) + get_element_size(
args);
// RequantizationParametersElement: supplies ws->bias, ws->requant_muls and
// ws->requant_shifts for a Requantize32 output stage. Trailing storage is
// reserved only for arrays the output stage does not already provide; each
// sizeof_* helper returns 0 when the corresponding output-stage pointer is
// non-null.
317 class RequantizationParametersElement
// Bytes reserved: sum of the bias, multiplier and shift array sizes.
325 template <
typename StratType>
326 static size_t get_element_size(
const WorkspaceArgs<StratType, arm_gemm::Requantize32> &
args)
328 return sizeof_bias(
args) + sizeof_requant_muls(
args) + sizeof_requant_shifts(
args);
// For each of bias / requant_muls / requant_shifts: use the output stage's
// array if it is non-null. Otherwise point at the next slice of `buffer` and
// fill it for every output channel: bias with 0, muls with per_layer_mul,
// shifts with per_layer_right_shift. The return statement is not visible here;
// it presumably returns buffer_bytes (confirm).
// NOTE(review): requant_shifts is assigned from reinterpret_cast<int32_t *>,
// while bias and requant_muls use const int32_t *. This is inconsistent,
// though harmless if the member is a pointer to const.
// NOTE(review): these int32_t slices start wherever the previous element ended,
// with no alignment padding. If an earlier element's size is not a multiple of
// alignof(int32_t) (e.g. a byte-typed buffer with an odd channel count), these
// pointers will be misaligned. Confirm the element ordering.
331 template <
typename WorkspaceType,
typename StratType>
332 static void *initialise(WorkspaceType *ws,
void *buffer,
const WorkspaceArgs<StratType, arm_gemm::Requantize32> &
args)
334 const auto n_output_channels =
args.depthwise_args.input_channels *
args.depthwise_args.channel_multiplier;
335 char *buffer_bytes =
reinterpret_cast<char *
>(buffer);
// Default to the caller-supplied arrays. Null entries are replaced below.
337 ws->bias =
args.output_stage.bias;
338 ws->requant_muls =
args.output_stage.per_channel_muls;
339 ws->requant_shifts =
args.output_stage.per_channel_right_shifts;
// No bias supplied: use a zero-filled array from the workspace.
341 if (ws->bias ==
nullptr)
343 ws->bias =
reinterpret_cast<const int32_t *
>(buffer_bytes);
344 memset(buffer_bytes, 0, sizeof_bias(
args));
345 buffer_bytes += sizeof_bias(
args);
// No per-channel multipliers supplied: replicate the per-layer multiplier.
348 if (ws->requant_muls ==
nullptr)
350 ws->requant_muls =
reinterpret_cast<const int32_t *
>(buffer_bytes);
351 auto muls =
reinterpret_cast<int32_t *
>(buffer_bytes);
352 buffer_bytes += sizeof_requant_muls(
args);
354 for (
auto n = 0u; n < n_output_channels; n++)
356 muls[n] =
args.output_stage.per_layer_mul;
// No per-channel shifts supplied: replicate the per-layer right shift.
360 if (ws->requant_shifts ==
nullptr)
362 ws->requant_shifts =
reinterpret_cast<int32_t *
>(buffer_bytes);
363 auto shifts =
reinterpret_cast<int32_t *
>(buffer_bytes);
364 buffer_bytes += sizeof_requant_shifts(
args);
366 for (
auto n = 0u; n < n_output_channels; n++)
368 shifts[n] =
args.output_stage.per_layer_right_shift;
// Bytes to reserve for the bias array. Returns 0 if the output stage already
// supplies one. Otherwise returns one int32_t per output channel.
376 template <
typename StratType>
377 static size_t sizeof_bias(
const WorkspaceArgs<StratType, arm_gemm::Requantize32> &
args)
379 return args.output_stage.bias !=
nullptr ?
380 0 :
sizeof(int32_t) *
args.depthwise_args.channel_multiplier *
args.depthwise_args.input_channels;
// Bytes to reserve for per-channel multipliers. Returns 0 if the output stage
// already supplies per_channel_muls. Otherwise returns one int32_t per output
// channel.
383 template <
typename StratType>
384 static size_t sizeof_requant_muls(
const WorkspaceArgs<StratType, arm_gemm::Requantize32> &
args)
386 return args.output_stage.per_channel_muls !=
nullptr ?
387 0 :
sizeof(int32_t) *
args.depthwise_args.channel_multiplier *
args.depthwise_args.input_channels;
// Bytes to reserve for per-channel right shifts. Returns 0 if the output stage
// already supplies per_channel_right_shifts. Otherwise returns one int32_t per
// output channel.
390 template <
typename StratType>
391 static size_t sizeof_requant_shifts(
const WorkspaceArgs<StratType, arm_gemm::Requantize32> &
args)
393 return args.output_stage.per_channel_right_shifts !=
nullptr ?
394 0 :
sizeof(int32_t) *
args.depthwise_args.channel_multiplier *
args.depthwise_args.input_channels;
// Workspace<Elements...>: composes a list of workspace elements. The primary
// template's declaration line is not visible in this extraction.
399 template <
typename ...Elements>
// Recursive case. WorkspaceType inherits from Element::Workspace and from the
// remaining elements' WorkspaceType, so one struct carries every element's
// fields.
402 template <
typename Element,
typename ...Elements>
403 class Workspace<Element, Elements...>
406 struct WorkspaceType : Element::Workspace, Workspace<Elements...>::WorkspaceType
// Treats the start of `buffer` as the WorkspaceType header. Every element's
// storage is laid out immediately after it, starting at `ws + 1`.
// `buffer` is presumably at least get_sizeof_workspace(args) bytes (confirm
// with the callers).
// NOTE(review): only a reinterpret_cast is visible; original lines 412-414 are
// missing. If no placement-new constructs WorkspaceType, its lifetime is never
// formally started. Confirm.
410 template <
class S,
class T>
411 static void initialise(
void *buffer,
const WorkspaceArgs<S, T> &
args)
415 auto ws =
reinterpret_cast<WorkspaceType *
>(buffer);
416 initialise_elements(ws, ws + 1,
args);
// Total bytes to allocate for the workspace: sizeof(WorkspaceType) plus the sum
// of every element's get_element_size(args).
419 template <
class S,
class T=Nothing>
420 static size_t get_sizeof_workspace(
const WorkspaceArgs<S, T> &
args)
422 return sizeof(WorkspaceType) + get_element_sizes(
args);
// Recursively sums Element::get_element_size(args) over the element pack. The
// recursion ends at the Workspace<> specialisation.
425 template <
class S,
class T>
426 static inline size_t get_element_sizes(
const WorkspaceArgs<S, T> &
args)
428 return Element::get_element_size(
args) + Workspace<Elements...>::get_element_sizes(
args);
// Initialises Element, then passes the cursor it returns to the remaining
// elements. Element storage is therefore laid out back-to-back in pack order,
// with no padding inserted between elements.
431 template <
class WorkspaceType,
class S,
class T>
432 static void initialise_elements(WorkspaceType *ws,
void *buffer,
const WorkspaceArgs<S, T> &
args)
434 buffer = Element::initialise(ws, buffer,
args);
435 Workspace<Elements...>::initialise_elements(ws, buffer,
args);
// Base case of the element-size recursion. The enclosing specialisation
// (presumably Workspace<>) and the body are not visible in this extraction.
// It is expected to return 0 (confirm).
447 template <
class S,
class T>
448 static inline size_t get_element_sizes(
const WorkspaceArgs<S, T> &)
// Base case of element initialisation. All parameters are unnamed, so none is
// used. The body is not visible here; it is presumably empty (confirm).
453 template <
class WorkspaceType,
class S,
class T>
454 static void initialise_elements(WorkspaceType *,
void *,
const WorkspaceArgs<S, T> &)