92 #include "depthwise.hpp" 102 template <
class StratType,
class OutputStage=Nothing>
109 WorkspaceArgs(
const StratType *strat,
const DepthwiseArgs &dwargs,
const OutputStage &os = {})
124 template <
class StratType,
class OutputStage>
125 static size_t get_element_size(
const WorkspaceArgs<StratType, OutputStage> &) {
return 0; }
127 template <
class WorkspaceType,
class StratType,
class OutputStage>
128 static void *initialise(WorkspaceType *,
void *buffer,
const WorkspaceArgs<StratType, OutputStage> &)
139 template <
typename T,
class OutputStage=Nothing>
140 class ActivationsElement
148 template <
typename StratType>
149 static size_t get_element_size(
const WorkspaceArgs<StratType, OutputStage> &)
// Fills in the activation clamp range (ws->activation_min / ws->activation_max)
// from args.depthwise_args.activation.
// NOTE(review): this listing omits the switch's case labels, the braces and
// the return statement, so which activation type selects which assignment
// cannot be confirmed from here.
154 template <
class WorkspaceType,
class StratType>
155 static void *initialise(WorkspaceType *ws,
void *buffer,
const WorkspaceArgs<StratType, OutputStage> &
args)
// Default range: negative to positive float infinity, converted to T.
157 ws->activation_min =
static_cast<T
>(-std::numeric_limits<float>::infinity());
158 ws->activation_max =
static_cast<T
>(std::numeric_limits<float>::infinity());
// Narrow the range according to the requested activation type.
160 switch (args.depthwise_args.activation.type)
// Upper bound taken from the activation's first parameter.
163 ws->activation_max =
static_cast<T
>(args.depthwise_args.activation.param1);
// Lower bound set to zero.
166 ws->activation_min =
static_cast<T
>(0);
179 template <
typename T>
180 class ActivationsElement<T, arm_gemm::Requantize32> :
public EmptyElement
188 template <
typename OutputStage>
189 char get_input_buffer_fill_value(
const OutputStage &)
208 template <
typename T>
209 class InputBufferElement
217 template <
typename StratType,
typename OutputStage>
218 static size_t get_element_size(
const WorkspaceArgs<StratType, OutputStage> &
args)
220 return sizeof(T) * args.depthwise_args.input_channels;
223 template <
class WorkspaceType,
typename StratType,
typename OutputStage>
224 static void *initialise(WorkspaceType *ws,
void *buffer,
const WorkspaceArgs<StratType, OutputStage> &
args)
226 ws->input_buffer =
reinterpret_cast<T*
>(buffer);
227 memset(ws->input_buffer, get_input_buffer_fill_value(args.output_stage), get_element_size(args));
228 return reinterpret_cast<char *
>(buffer) + get_element_size(args);
236 template <
typename T>
237 class OutputArrayElement
246 template <
typename OutputStage>
247 static size_t get_element_size(
const WorkspaceArgs<IDepthfirstStrategy, OutputStage> &args)
249 return sizeof_outptr_array(args) + sizeof_output_buffer(args);
// Carves two consecutive regions out of `buffer`: first the output pointer
// array (ws->outptr_array), then the output buffer (ws->output_buffer).
// NOTE(review): the braces and the return statement are not present in this
// listing; presumably the advanced buffer_bytes is returned - confirm against
// the real source.
252 template <
class WorkspaceType,
typename OutputStage>
253 static void *initialise(WorkspaceType *ws,
void *buffer,
const WorkspaceArgs<IDepthfirstStrategy, OutputStage> &args)
// Byte-typed cursor so the regions can be stepped over by size.
255 char *buffer_bytes =
reinterpret_cast<char *
>(buffer);
// Region 1: the array of output pointers, sized by sizeof_outptr_array.
257 ws->outptr_array =
reinterpret_cast<T **
>(buffer_bytes);
258 buffer_bytes += sizeof_outptr_array(args);
// Region 2: the output buffer, sized by sizeof_output_buffer.
260 ws->output_buffer =
reinterpret_cast<T *
>(buffer_bytes);
261 buffer_bytes += sizeof_output_buffer(args);
267 template <
typename OutputStage>
268 static size_t sizeof_outptr_array(
const WorkspaceArgs<IDepthfirstStrategy, OutputStage> &args)
270 return sizeof(T **) * args.strategy->get_output_rows() * args.strategy->get_output_cols();
273 template <
typename OutputStage>
274 static size_t sizeof_output_buffer(
const WorkspaceArgs<IDepthfirstStrategy, OutputStage> &args)
276 return sizeof(T) * args.depthwise_args.input_channels * args.depthwise_args.channel_multiplier;
287 class RequantizationParametersElement
295 template <
typename StratType>
296 static size_t get_element_size(
const WorkspaceArgs<StratType, arm_gemm::Requantize32> &args)
298 return sizeof_bias(args) + sizeof_requant_muls(args) + sizeof_requant_shifts(args);
// Points ws->bias, ws->requant_muls and ws->requant_shifts at the arrays held
// by the output stage. For any of the three that is null, a replacement array
// is carved out of `buffer` and filled: zeros for the bias, the per-layer
// multiplier for the multipliers, the per-layer right shift for the shifts.
// NOTE(review): the braces and the return statement are not present in this
// listing; presumably the advanced buffer_bytes is returned - confirm against
// the real source.
301 template <
typename WorkspaceType,
typename StratType>
302 static void *initialise(WorkspaceType *ws,
void *buffer,
const WorkspaceArgs<StratType, arm_gemm::Requantize32> &args)
// Number of entries written into each fallback array below.
304 const auto n_output_channels = args.depthwise_args.input_channels * args.depthwise_args.channel_multiplier;
// Byte-typed cursor into the remaining working space.
305 char *buffer_bytes =
reinterpret_cast<char *
>(buffer);
// Start from whatever the output stage provides; nulls are patched below.
307 ws->bias = args.output_stage.bias;
308 ws->requant_muls = args.output_stage.per_channel_muls;
309 ws->requant_shifts = args.output_stage.per_channel_right_shifts;
// No bias supplied: use a zero-filled array taken from the buffer.
311 if (ws->bias ==
nullptr)
313 ws->bias =
reinterpret_cast<const int32_t *
>(buffer_bytes);
314 memset(buffer_bytes, 0, sizeof_bias(args));
315 buffer_bytes += sizeof_bias(args);
// No per-channel multipliers: replicate the per-layer multiplier.
318 if (ws->requant_muls ==
nullptr)
320 ws->requant_muls =
reinterpret_cast<const int32_t *
>(buffer_bytes);
// Writable view of the same storage for the fill loop.
321 auto muls =
reinterpret_cast<int32_t *
>(buffer_bytes);
322 buffer_bytes += sizeof_requant_muls(args);
324 for (
auto n = 0u;
n < n_output_channels;
n++)
326 muls[
n] = args.output_stage.per_layer_mul;
// No per-channel shifts: replicate the per-layer right shift.
330 if (ws->requant_shifts ==
nullptr)
// NOTE(review): this cast targets non-const int32_t *, unlike the const casts
// used for bias and requant_muls above - confirm whether that is intended.
332 ws->requant_shifts =
reinterpret_cast<int32_t *
>(buffer_bytes);
// Writable view of the same storage for the fill loop.
333 auto shifts =
reinterpret_cast<int32_t *
>(buffer_bytes);
334 buffer_bytes += sizeof_requant_shifts(args);
336 for (
auto n = 0u;
n < n_output_channels;
n++)
338 shifts[
n] = args.output_stage.per_layer_right_shift;
346 template <
typename StratType>
347 static size_t sizeof_bias(
const WorkspaceArgs<StratType, arm_gemm::Requantize32> &args)
349 return args.output_stage.bias !=
nullptr ?
350 0 :
sizeof(int32_t) * args.depthwise_args.channel_multiplier * args.depthwise_args.input_channels;
353 template <
typename StratType>
354 static size_t sizeof_requant_muls(
const WorkspaceArgs<StratType, arm_gemm::Requantize32> &args)
356 return args.output_stage.per_channel_muls !=
nullptr ?
357 0 :
sizeof(int32_t) * args.depthwise_args.channel_multiplier * args.depthwise_args.input_channels;
360 template <
typename StratType>
361 static size_t sizeof_requant_shifts(
const WorkspaceArgs<StratType, arm_gemm::Requantize32> &args)
363 return args.output_stage.per_channel_right_shifts !=
nullptr ?
364 0 :
sizeof(int32_t) * args.depthwise_args.channel_multiplier * args.depthwise_args.input_channels;
369 template <
typename ...Elements>
372 template <
typename Element,
typename ...Elements>
373 class Workspace<Element, Elements...>
376 struct WorkspaceType : Element::Workspace, Workspace<Elements...>::WorkspaceType
// Treats the start of `buffer` as the WorkspaceType struct and initialises
// every element, handing them the memory that begins immediately after that
// struct (ws + 1).
// NOTE(review): listing lines 383-384 and the braces are not present here, so
// anything the original does before the cast cannot be confirmed.
380 template <
class S,
class T>
381 static void initialise(
void *buffer,
const WorkspaceArgs<S, T> &args)
// The struct itself sits at the very start of the buffer.
385 auto ws =
reinterpret_cast<WorkspaceType *
>(buffer);
// Element storage starts right after the struct.
386 initialise_elements(ws, ws + 1, args);
389 template <
class S,
class T=Nothing>
390 static size_t get_sizeof_workspace(
const WorkspaceArgs<S, T> &args)
392 return sizeof(WorkspaceType) + get_element_sizes(args);
395 template <
class S,
class T>
396 static inline size_t get_element_sizes(
const WorkspaceArgs<S, T> &args)
398 return Element::get_element_size(args) + Workspace<Elements...>::get_element_sizes(args);
401 template <
class WorkspaceType,
class S,
class T>
402 static void initialise_elements(WorkspaceType *ws,
void *buffer,
const WorkspaceArgs<S, T> &args)
404 buffer = Element::initialise(ws, buffer, args);
405 Workspace<Elements...>::initialise_elements(ws, buffer, args);
417 template <
class S,
class T>
418 static inline size_t get_element_sizes(
const WorkspaceArgs<S, T> &)
423 template <
class WorkspaceType,
class S,
class T>
424 static void initialise_elements(WorkspaceType *,
void *,
const WorkspaceArgs<S, T> &)
const DepthwiseArgs & depthwise_args
const OutputStage & output_stage
const StratType * strategy
const int32_t * requant_muls
template UniqueDepthwiseCommon< float > depthwise(const DepthwiseArgs &, const Nothing &)
const int32_t * requant_shifts