Compute Library
 23.11
CpuWinogradConv2d.cpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2021-2023 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
25 
26 #include "arm_compute/core/Error.h"
27 #include "arm_compute/core/Utils.h"
33 
34 #include "src/common/utils/Log.h"
35 #include "src/core/CPP/Validate.h"
38 #include "src/core/NEON/kernels/assembly/winograd.hpp"
39 #include "src/core/NEON/kernels/convolution/common/tensor.hpp"
40 #include "src/core/NEON/kernels/convolution/common/utils.hpp"
47 #include "support/Cast.h"
48 
49 namespace arm_compute
50 {
51 namespace cpu
52 {
53 using namespace arm_compute::experimental;
54 using namespace arm_compute::utils::cast;
55 
56 namespace
57 {
58 inline Tensor4DShape internal_get_shape(const ITensorInfo *in)
59 {
60  const DataLayout data_layout = in->data_layout();
61  const int in_width = in->dimension(get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH));
62  const int in_height = in->dimension(get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT));
63  const int in_channels = in->dimension(get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL));
64  const int in_batches = in->dimension(get_data_layout_dimension_index(data_layout, DataLayoutDimension::BATCHES));
65 
66  return Tensor4DShape{in_batches, in_height, in_width, in_channels};
67 }
68 
/** Check the stride and bias constraints that are visible in this listing.
 *
 * NOTE(review): the listing's embedded numbering jumps over 76, 82 and 85-86,
 * and src is not referenced by any visible statement, so further checks may
 * have been lost when this page was extracted - confirm against the upstream
 * file before relying on this description.
 *
 * @param[in] src       Source tensor info (not referenced by the visible statements).
 * @param[in] weights   Weights tensor info; marked unused.
 * @param[in] biases    Optional bias tensor info; may be nullptr.
 * @param[in] dst       Destination tensor info; marked unused.
 * @param[in] conv_info Padding/stride descriptor; only stride() is read here.
 *
 * @return An empty Status when every visible check passes; otherwise the
 *         error returned by the failing ARM_COMPUTE_RETURN_ERROR_* macro.
 */
69 Status validate_arguments(const ITensorInfo *src,
70  const ITensorInfo *weights,
71  const ITensorInfo *biases,
72  const ITensorInfo *dst,
73  const PadStrideInfo &conv_info)
74 {
75  ARM_COMPUTE_UNUSED(dst, weights);
77 
// Both components of the stride pair must be exactly 1.
78  ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.stride().first != 1 || conv_info.stride().second != 1,
79  "Winograd layer only supports unit strides.");
// A bias is optional; when one is supplied it may have at most one dimension.
80  if (biases != nullptr)
81  {
83  ARM_COMPUTE_RETURN_ERROR_ON(biases->num_dimensions() > 1);
84  }
87  return Status{};
88 }
89 
/** Ask the assembly Winograd backend for an implementation of this convolution.
 *
 * Builds an arm_conv::ConvolutionArgs from the tensor shapes and the top/left
 * padding, stores it in conv_args, then calls
 * arm_conv::winograd::get_implementation with the element type that matches
 * the source data type.
 *
 * @param[in]  src              Source tensor info; supplies the data type and the input shape.
 * @param[in]  weights          Weights tensor info; supplies the kernel shape.
 * @param[in]  dst              Destination tensor info; supplies the output shape.
 * @param[in]  conv_info        Padding/stride descriptor; pad_top() and pad_left() are read.
 * @param[in]  act_info         Activation descriptor. NOTE(review): no visible line references
 *                              it - see the truncation note inside the body.
 * @param[in]  enable_fast_math Forwarded unchanged to get_implementation.
 * @param[out] winograd_impl    Dereferenced unconditionally and passed by reference to
 *                              get_implementation (presumably populated there - confirm).
 * @param[out] conv_args        Replaced with a newly allocated ConvolutionArgs.
 *
 * @return The value returned by get_implementation for F32, and for F16 when built for
 *         aarch64 with FP16 vector arithmetic; false for every other data type.
 */
90 bool get_winograd_kernel_implementation(const ITensorInfo *src,
91  const ITensorInfo *weights,
92  const ITensorInfo *dst,
93  const PadStrideInfo &conv_info,
94  const ActivationLayerInfo &act_info,
95  bool enable_fast_math,
96  arm_conv::winograd::WinogradImpl *winograd_impl,
97  std::unique_ptr<arm_conv::ConvolutionArgs> &conv_args)
98 {
99  arm_conv::winograd::WinogradConfig winograd_cfg;
101 
102  const DataType data_type = src->data_type();
// All three shapes are expressed as (batches, rows, cols, channels) by internal_get_shape.
103  Tensor4DShape in_shape{internal_get_shape(src)};
104  Tensor4DShape out_shape{internal_get_shape(dst)};
105  Tensor4DShape kernel_shape{internal_get_shape(weights)};
// Thread count currently configured on the scheduler; handed to the backend query below.
106  uint32_t nthreads = NEScheduler::get().num_threads();
107  // Get configuration arguments for Winograd
108  winograd_cfg.output_rows = 0;
109  winograd_cfg.output_cols = 0;
// NOTE(review): the argument list below is never closed in this listing - the embedded
// numbering goes from 116 straight to 118, and act_info is otherwise unused - so a final
// argument line appears to have been dropped by the extraction. Confirm against upstream.
110  conv_args = std::make_unique<arm_conv::ConvolutionArgs>(
111  in_shape.n_batches,
112  arm_conv::Shape2D{static_cast<uint32_t>(in_shape.n_rows), static_cast<uint32_t>(in_shape.n_cols)},
113  in_shape.n_channels, conv_info.pad_top(), conv_info.pad_left(),
114  arm_conv::Shape2D{static_cast<uint32_t>(out_shape.n_rows), static_cast<uint32_t>(out_shape.n_cols)},
115  out_shape.n_channels,
116  arm_conv::Shape2D{static_cast<uint32_t>(kernel_shape.n_rows), static_cast<uint32_t>(kernel_shape.n_cols)},
118 
// Dispatch on the source data type; anything other than F32 / F16 reports failure.
119  bool success = false;
120  if (data_type == DataType::F32)
121  {
122  success = arm_conv::winograd::get_implementation<float>(*winograd_impl, &CPUInfo::get(), *conv_args, nthreads,
123  enable_fast_math, &winograd_cfg, nullptr);
124  }
// The F16 branch only exists when the target is aarch64 with FP16 vector arithmetic.
125 #if defined(__aarch64__) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
126  else if (data_type == DataType::F16)
127  {
128  success = arm_conv::winograd::get_implementation<__fp16>(*winograd_impl, &CPUInfo::get(), *conv_args, nthreads,
129  enable_fast_math, &winograd_cfg, nullptr);
130  }
131 #endif // defined(__aarch64__) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
132  else
133  {
134  success = false;
135  }
136  return success;
137 }
138 inline bool fuse_function_supported(const ActivationLayerInfo &act_info)
139 {
140  return act_info.activation() == ActivationLayerInfo::ActivationFunction::RELU ||
141  act_info.activation() == ActivationLayerInfo::ActivationFunction::BOUNDED_RELU;
142 }
143 } // namespace
144 
146 
147  : _gemm_function(std::make_unique<CpuGemm>()),
148  _activation_func(std::make_unique<CpuActivation>()),
149  _transform_input_kernel(nullptr),
150  _transform_output_kernel(nullptr),
151  _permute_input(std::make_unique<CpuPermute>()),
152  _permute_output(std::make_unique<CpuPermute>()),
153  _permute_weights(std::make_unique<CpuPermute>()),
154  _aux_mem(AuxTensorIdx::Count),
155  _conv_args{nullptr},
156  _winograd_impl{},
157  _data_layout(),
158  _winograd_transformed_input{},
159  _winograd_transformed_output{},
160  _winograd_transformed_weights{},
161  _input_workspace(),
162  _output_workspace(),
163  _weights_hwio(),
164  _input_nhwc(),
165  _output_nhwc(),
166  _is_prepared{false},
167  _run_activation{false}
168 {
169 }
170 
172 
174  const ITensorInfo *weights,
175  const ITensorInfo *biases,
176  ITensorInfo *dst,
177  const PadStrideInfo &conv_info,
179  bool enable_fast_math)
180 {
182  ARM_COMPUTE_ERROR_THROW_ON(validate(src, weights, biases, dst, conv_info, act_info, enable_fast_math));
183  ARM_COMPUTE_LOG_PARAMS(src, weights, biases, dst, conv_info, act_info, enable_fast_math);
184  ARM_COMPUTE_UNUSED(biases);
185  const DataType data_type = src->data_type();
186  uint32_t nthreads = NEScheduler::get().num_threads();
187  _data_layout = src->data_layout();
188  const Tensor4DShape kernel_shape{internal_get_shape(weights)};
189 
190  bool success = get_winograd_kernel_implementation(src, weights, dst, conv_info, act_info, enable_fast_math,
191  &_winograd_impl, _conv_args);
192 
193  ARM_COMPUTE_EXIT_ON_MSG_VAR(!success, "Unsupported kernel size: %d x %d.\n", kernel_shape.n_rows,
194  kernel_shape.n_cols);
196  _winograd_impl.input_transform->get_name().c_str());
198  _winograd_impl.input_transform->get_name().c_str());
200  _winograd_impl.input_transform->get_name().c_str());
201 
202  const bool has_impl = ((_winograd_impl.input_transform != nullptr) &&
203  (_winograd_impl.output_transform != nullptr) && (_winograd_impl.gemm_args != nullptr));
204  if (has_impl)
205  {
206  // Determine how much working space is required, allocate it.
207  const size_t input_workspace_size =
208  _winograd_impl.input_transform->get_working_space_size(*_conv_args, nthreads);
209  const size_t output_workspace_size =
210  _winograd_impl.output_transform->get_working_space_size(*_conv_args, nthreads);
211 
212  TensorInfo input_workspace_info(TensorShape(input_workspace_size), 1, DataType::U8);
213  TensorInfo output_workspace_info(TensorShape(output_workspace_size), 1, DataType::U8);
214  _input_workspace = input_workspace_info;
215  _output_workspace = output_workspace_info;
216 
217  const auto &wds = _winograd_impl.winograd_spec;
218 
219  // Preparing winograd transformed input tensor
220  const size_t data_type_size = src->element_size();
221  const uint32_t m = _winograd_impl.gemm_args->_Msize; // Total number of tiles
222  const uint32_t k = _winograd_impl.gemm_args->_Ksize; // Input channels
223  const uint32_t n = _winograd_impl.gemm_args->_Nsize; // Output channels
224  const uint32_t n_gemms = _winograd_impl.gemm_args->_nmulti;
225  const uint32_t n_batches = _winograd_impl.gemm_args->_nbatches;
226  constexpr size_t storage_alignment = 64;
227 
228  const TensorShape a_shape(k, m, n_batches, n_gemms);
229  Strides a_strides(data_type_size);
230  a_strides.set(1, data_type_size * _winograd_impl.winograd_spec.input_ld_row);
231  a_strides.set(2, data_type_size * _winograd_impl.winograd_spec.input_ld_batch);
232  a_strides.set(3, data_type_size * _winograd_impl.winograd_spec.input_ld_matrix);
233 
234  const TensorShape b_shape(n, k, n_gemms);
235  Strides b_strides(data_type_size);
236  b_strides.set(1, data_type_size * _winograd_impl.winograd_spec.weight_ld_row);
237  b_strides.set(2, data_type_size * _winograd_impl.winograd_spec.weight_ld_matrix);
238 
239  const TensorShape d_shape(n, m, n_batches, n_gemms);
240  Strides d_strides(data_type_size);
241  d_strides.set(1, data_type_size * _winograd_impl.winograd_spec.output_ld_row);
242  d_strides.set(2, data_type_size * _winograd_impl.winograd_spec.output_ld_batch);
243  d_strides.set(3, data_type_size * _winograd_impl.winograd_spec.output_ld_matrix);
244 
245  TensorInfo a_info{};
246  TensorInfo b_info{};
247  TensorInfo d_info{};
248  a_info.init(a_shape, 1, data_type, a_strides, 0, wds.input_matrix_size_bytes);
249  b_info.init(b_shape, 1, data_type, b_strides, 0, wds.weight_matrix_size_bytes);
250  d_info.init(d_shape, 1, data_type, d_strides, 0, wds.output_matrix_size_bytes);
251 
252  _winograd_transformed_input = a_info;
253  _winograd_transformed_weights = b_info;
254  _winograd_transformed_output = d_info;
255 
256  PermutationVector weights_permutation_vector(3U, 0U, 1U, 2U);
257 
258  // Configure the kernel to transform the input tensor from NCHW -> NHWC
259  if (_data_layout == DataLayout::NCHW)
260  {
261  _permute_input->configure(src, &_input_nhwc, PermutationVector(2U, 0U, 1U));
262  weights_permutation_vector = PermutationVector(3U, 2U, 0U, 1U);
263  }
264 
265  // Re-order a weight tensor from [Output feature map x Input feature map x Height x Width] to [Height x Width x Input feature map x Output feature map]
266  _permute_weights->configure(weights, &_weights_hwio, weights_permutation_vector);
267 
268  // Reorder the convoluted output to ACL's ordering NCHW
269  if (_data_layout == DataLayout::NCHW)
270  {
271  // configure and allocate dst tensor to be used to convert from winograd domain to spatial domain when calling to reshape_output()
272  TensorInfo info(TensorShape(dst->dimension(2), dst->dimension(0), dst->dimension(1), dst->dimension(3)), 1,
273  dst->data_type());
274  _output_nhwc = info;
275  _permute_output->configure(&_output_nhwc, dst, PermutationVector(1U, 2U, 0U));
276  }
277 
278  // Configure input transform kernel
279  _transform_input_kernel =
280  std::make_unique<CpuWinogradConv2dTransformInputKernel>(_winograd_impl, *_conv_args, nthreads);
281 
282  // Configure GEMM function
283  _gemm_function->configure(&_winograd_transformed_input, &_winograd_transformed_weights, nullptr,
284  &_winograd_transformed_output, 1.0f, 0.f);
285 
286  // Configure output transform kernel
287  _transform_output_kernel =
288  std::make_unique<CpuWinogradConv2dTransformOutputKernel>(_winograd_impl, *_conv_args, nthreads);
289 
290  //Configure Activation Layer
291  _run_activation = act_info.enabled() && !fuse_function_supported(act_info);
292  if (_run_activation)
293  {
294  _activation_func->configure(dst, nullptr, act_info);
295  }
296 
297  const auto mm_mem_req = _gemm_function->workspace();
298  for (unsigned int slot = 0; slot < mm_mem_req.size(); ++slot)
299  {
300  _aux_mem[slot] = mm_mem_req[slot];
301  }
302 
303  // Request temporary memory. Overlap memory needed for Input/Output transformations as they run on different non-overlapping time-steps.
304  _aux_mem[TransformedInput] = MemoryInfo(offset_int_vec(TransformedInput), MemoryLifetime::Temporary,
305  wds.input_matrix_size_bytes, storage_alignment);
306  _aux_mem[TransformedOutput] = MemoryInfo(offset_int_vec(TransformedOutput), MemoryLifetime::Temporary,
307  wds.output_matrix_size_bytes, storage_alignment);
308  _aux_mem[WorkspaceIO] = MemoryInfo(offset_int_vec(WorkspaceIO), MemoryLifetime::Temporary,
309  std::max(input_workspace_size, output_workspace_size));
310  _aux_mem[PermutedWeights] =
311  MemoryInfo(offset_int_vec(PermutedWeights), MemoryLifetime::Prepare, _weights_hwio.total_size());
312  _aux_mem[TransformedWeights] = MemoryInfo(offset_int_vec(TransformedWeights), MemoryLifetime::Persistent,
313  wds.weight_matrix_size_bytes, storage_alignment);
314  if (_data_layout == DataLayout::NCHW)
315  {
316  _aux_mem[PermutedInput].merge(offset_int_vec(PermutedInput), src->total_size());
317  _aux_mem[PermutedOutput].merge(offset_int_vec(PermutedOutput), dst->total_size());
318  }
319  }
320 }
322  const ITensorInfo *weights,
323  const ITensorInfo *biases,
324  const ITensorInfo *dst,
325  const PadStrideInfo &conv_info,
327  bool enable_fast_math)
328 {
331 
332  // Disable winograd for fp16 if fast math is false.
333  if (!enable_fast_math)
334  {
336  }
337 
338  const Tensor4DShape kernel_shape{internal_get_shape(weights)};
339  arm_conv::winograd::WinogradImpl winograd_impl{};
340 
341  std::unique_ptr<arm_conv::ConvolutionArgs> conv_args;
342  const bool success = get_winograd_kernel_implementation(src, weights, dst, conv_info, act_info, enable_fast_math,
343  &winograd_impl, conv_args);
344 
345  ARM_COMPUTE_RETURN_ERROR_ON_MSG_VAR(success == false, "Unsupported kernel size: %d x %d.\n", kernel_shape.n_rows,
346  kernel_shape.n_cols);
348  winograd_impl.input_transform->get_name().c_str());
350  winograd_impl.input_transform->get_name().c_str());
352  winograd_impl.input_transform->get_name().c_str());
353  return Status{};
354 }
355 
357 {
358  prepare(tensors);
359  auto src = tensors.get_const_tensor(ACL_SRC_0);
360  auto biases = tensors.get_const_tensor(ACL_SRC_2);
361  auto output = tensors.get_tensor(ACL_DST);
362  Window win;
363 
364  const uint32_t nthreads = NEScheduler::get().num_threads();
365 
366  // The Winograd transform implementation does fine-grain threading inside the transforms. Just pass thread_id and nthreads.
367  win.set(Window::DimX, Window::Dimension(0, nthreads, 1));
368 
369  // Wrap the winograd-domain tensorInfos created in configuration in tensors and allocate the required memory.
370  CpuAuxTensorHandler input_nhwc(offset_int_vec(PermutedInput), _input_nhwc, tensors, true);
371  CpuAuxTensorHandler winograd_input_transformed(offset_int_vec(TransformedInput), _winograd_transformed_input,
372  tensors, true);
373  CpuAuxTensorHandler input_workspace(offset_int_vec(WorkspaceIO), _input_workspace, tensors, true);
374  const bool is_nchw = _data_layout == DataLayout::NCHW;
375  if (is_nchw)
376  {
377  //Bring channels to the front as Winograd code expects the tensor to be in the format NHWC
378  ITensorPack pack{{ACL_SRC, src}, {ACL_DST, input_nhwc.get()}};
379  _permute_input->run(pack);
380  }
381 
382  CpuAuxTensorHandler winograd_output_transformed(offset_int_vec(TransformedOutput), _winograd_transformed_output,
383  tensors, true);
384  CpuAuxTensorHandler output_workspace(offset_int_vec(WorkspaceIO), _output_workspace, tensors, true);
385  CpuAuxTensorHandler output_nhwc(offset_int_vec(PermutedOutput), _output_nhwc, tensors, true);
386 
387  ITensorPack transform_input_pack{{ACL_SRC, is_nchw ? input_nhwc.get() : src},
388  {ACL_DST, winograd_input_transformed.get()},
389  {ACL_INT, input_workspace.get()}};
390  NEScheduler::get().schedule_op(_transform_input_kernel.get(), Window::DimX, win, transform_input_pack);
391 
392  CpuAuxTensorHandler winograd_weights_transformed(offset_int_vec(TransformedWeights), _winograd_transformed_weights,
393  tensors, true);
394 
395  // Run 16 GEMMs in multiple threads, each kernel runs one or more GEMMs
396  ITensorPack gemm_pack = tensors;
397  gemm_pack.add_const_tensor(ACL_SRC, winograd_input_transformed.get());
398  gemm_pack.add_const_tensor(ACL_SRC_1, winograd_weights_transformed.get());
399  gemm_pack.add_const_tensor(ACL_BIAS, nullptr);
400  gemm_pack.add_tensor(ACL_DST, winograd_output_transformed.get());
401  _gemm_function->run(gemm_pack);
402 
403  // Output transform
404  ITensorPack transform_output_pack{{ACL_SRC_0, winograd_output_transformed.get()},
405  {ACL_DST, is_nchw ? output_nhwc.get() : output},
406  {ACL_SRC_1, biases},
407  {ACL_INT, output_workspace.get()}};
408  NEScheduler::get().schedule_op(_transform_output_kernel.get(), Window::DimX, win, transform_output_pack);
409  if (is_nchw)
410  {
411  // Reorder the convoluted output to ACL's ordering NCHW
412  ITensorPack pack{{ACL_SRC, output_nhwc.get()}, {ACL_DST, output}};
413  _permute_output->run(pack);
414  }
415  if (_run_activation)
416  {
417  ITensorPack pack{{ACL_SRC, output}, {ACL_DST, output}};
418  _activation_func->run(pack);
419  }
420 }
421 
423 {
424  if (!_is_prepared)
425  {
426  const ITensor *weights = tensors.get_const_tensor(ACL_SRC_1);
427  ITensor *weights_aux =
428  utils::cast::polymorphic_cast<ITensor *>(tensors.get_tensor(offset_int_vec(PermutedWeights)));
429 
430  CpuAuxTensorHandler permuted_weights(_weights_hwio, *weights_aux);
431  ITensorPack permute_tensors{{ACL_SRC, weights}, {ACL_DST, permuted_weights.get()}};
432  _permute_weights->run(permute_tensors);
433  const int element_size_in_bytes = permuted_weights.get()->info()->element_size();
434  // Weights were in OHWI format, before being permuted "permuted_weights" to be in HWIO format.
435  const unsigned int height_idx = 3; // H in HWIO
436  const unsigned int width_idx = 2; // W in HWIO
437  const unsigned int channel_idx = 1; // I in HWIO
438 
439  const int permuted_weight_row_stride =
440  permuted_weights.get()->info()->strides_in_bytes()[height_idx] / element_size_in_bytes;
441  const int permuted_weight_col_stride =
442  permuted_weights.get()->info()->strides_in_bytes()[width_idx] / element_size_in_bytes;
443  const int permuted_weight_channel_stride =
444  permuted_weights.get()->info()->strides_in_bytes()[channel_idx] / element_size_in_bytes;
445 
446  // Wrap the winograd-domain transformed weight TensorInfo in Auxiliary tensor and allocate the required memory.
447  ITensor *weights_transf =
448  utils::cast::polymorphic_cast<ITensor *>(tensors.get_tensor(offset_int_vec(TransformedWeights)));
449  ARM_COMPUTE_ERROR_ON_NULLPTR(weights_transf);
450  CpuAuxTensorHandler winograd_transformed_weights(_winograd_transformed_weights, *weights_transf);
451 
452  const void *permuted_weights_ptr;
453  void *win_wght_transf_ptr;
454 
455  permuted_weights_ptr = reinterpret_cast<const void *>(
456  permuted_weights.get()->buffer() + permuted_weights.get()->info()->offset_first_element_in_bytes());
457  win_wght_transf_ptr =
458  reinterpret_cast<void *>(winograd_transformed_weights.get()->buffer() +
459  winograd_transformed_weights.get()->info()->offset_first_element_in_bytes());
460 
461  // Prepare Weights
462  _winograd_impl.weight_transform->execute(
463  *_conv_args, permuted_weights_ptr, permuted_weight_row_stride, permuted_weight_col_stride,
464  permuted_weight_channel_stride, win_wght_transf_ptr, _winograd_impl.winograd_spec, 0, 1 // Thread 1 of 1
465  );
466  ITensorPack gemm_pack = tensors;
467  gemm_pack.add_const_tensor(ACL_SRC_1, winograd_transformed_weights.get());
468  _gemm_function->prepare(gemm_pack);
469  _is_prepared = 1;
470  }
471 }
473 {
474  return _aux_mem;
475 }
476 
477 } // namespace cpu
478 } // namespace arm_compute
arm_compute::DataLayout::NCHW
@ NCHW
Num samples, channels, height, width.
arm_compute::cpu::CpuWinogradConv2d::prepare
void prepare(ITensorPack &constants) override
Prepare the function for executing.
Definition: CpuWinogradConv2d.cpp:422
arm_compute::cpu::CpuAuxTensorHandler
Definition: CpuAuxTensorHandler.h:39
Cast.h
arm_compute::experimental::MemoryRequirements
std::vector< MemoryInfo > MemoryRequirements
Definition: Types.h:123
arm_compute::cpu::CpuWinogradConv2d::~CpuWinogradConv2d
~CpuWinogradConv2d()
Destructor.
arm_compute::test::validation::src
SimpleTensor< float > src
Definition: DFT.cpp:155
CpuWinogradConv2dKernel.h
arm_compute::Dimensions::set
void set(size_t dimension, T value, bool increase_dim_unit=true)
Accessor to set the value of one of the dimensions.
Definition: Dimensions.h:75
arm_compute::DataLayout
DataLayout
[DataLayout enum definition]
Definition: CoreTypes.h:110
arm_compute::DataLayoutDimension::CHANNEL
@ CHANNEL
channel
FunctionDescriptors.h
ARM_COMPUTE_RETURN_ERROR_ON_MSG_VAR
#define ARM_COMPUTE_RETURN_ERROR_ON_MSG_VAR(cond, msg,...)
If the condition is true, an error is returned.
Definition: Error.h:228
arm_compute::TensorShape
Shape of a tensor.
Definition: TensorShape.h:39
arm_compute::IScheduler::schedule_op
virtual void schedule_op(ICPPKernel *kernel, const Hints &hints, const Window &window, ITensorPack &tensors)=0
Runs the kernel in the same thread as the caller synchronously.
arm_compute::test::validation::dst
auto dst
Definition: DFT.cpp:170
arm_compute::cpu::kernels::validate_arguments
Status validate_arguments(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *dst, const PadStrideInfo &conv_info)
Definition: CpuDirectConv2dKernel.cpp:57
arm_compute::ITensorInfo::element_size
virtual size_t element_size() const =0
Element size in bytes calculated as data_size() * num_channels()
arm_compute::Window::DimX
static constexpr size_t DimX
Alias for dimension 0 also known as X dimension.
Definition: Window.h:43
arm_compute::assembly_utils::map_to_arm_gemm_activation
arm_gemm::Activation map_to_arm_gemm_activation(const ActivationLayerInfo &act)
Performs a mapping between Compute Library ActivationLayerInfo and the assembly Activation structure.
Definition: AssemblyUtils.cpp:32
arm_compute::CPUInfo::get
static CPUInfo & get()
Access the CPUInfo singleton.
Definition: CPPTypes.cpp:41
ARM_COMPUTE_EXIT_ON_MSG_VAR
#define ARM_COMPUTE_EXIT_ON_MSG_VAR(cond, msg,...)
If the condition is true, the given message is printed and program exits.
Definition: Error.h:398
arm_compute::ITensor
Interface for CPU tensor.
Definition: ITensor.h:36
arm_compute::ITensorPack::add_tensor
void add_tensor(int id, ITensor *tensor)
Add tensor to the pack.
Definition: ITensorPack.cpp:38
arm_compute::ITensorPack::get_tensor
ITensor * get_tensor(int id)
Get tensor of a given id from the pack.
Definition: ITensorPack.cpp:63
arm_compute::experimental::MemoryLifetime::Prepare
@ Prepare
AssemblyUtils.h
arm_compute::ACL_SRC_0
@ ACL_SRC_0
Definition: Types.h:45
Error.h
arm_compute::cpu::CpuAuxTensorHandler::get
ITensor * get()
Definition: CpuAuxTensorHandler.h:105
arm_compute::ACL_SRC_1
@ ACL_SRC_1
Definition: Types.h:46
arm_compute::ITensorPack::add_const_tensor
void add_const_tensor(int id, const ITensor *tensor)
Add const tensor to the pack.
Definition: ITensorPack.cpp:48
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES
#define ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(...)
Definition: Validate.h:677
arm_compute::cpu::CpuGemm
Basic function to execute GEMM.
Definition: CpuGemm.h:64
arm_compute::DataLayoutDimension::WIDTH
@ WIDTH
width
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN
#define ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(t, c,...)
Definition: Validate.h:952
arm_compute::ACL_SRC_2
@ ACL_SRC_2
Definition: Types.h:47
arm_compute::cpu::data_layout
constexpr auto data_layout
Definition: impl.h:36
ARM_COMPUTE_LOG_MSG_WITH_FORMAT_ACL
#define ARM_COMPUTE_LOG_MSG_WITH_FORMAT_ACL(log_level, fmt,...)
Definition: Log.h:31
ARM_COMPUTE_RETURN_ON_ERROR
#define ARM_COMPUTE_RETURN_ON_ERROR(status)
Checks if a status contains an error and returns it.
Definition: Error.h:205
arm_compute::ActivationLayerInfo
Activation Layer Information class.
Definition: ActivationLayerInfo.h:55
arm_compute::Strides
Strides of an item in bytes.
Definition: Strides.h:38
arm_compute::test::validation::act_info
act_info
Definition: DirectConvolutionLayer.cpp:547
ARM_COMPUTE_ERROR_ON_NULLPTR
#define ARM_COMPUTE_ERROR_ON_NULLPTR(...)
Definition: Validate.h:159
arm_compute::experimental::MemoryInfo
Definition: Types.h:91
arm_compute::utils::cast::U
U
Definition: SaturateCast.h:65
arm_compute::PermutationVector
Strides PermutationVector
Permutation vector.
Definition: CoreTypes.h:38
arm_compute::ITensor::info
virtual ITensorInfo * info() const =0
Interface to be implemented by the child class to return the tensor's metadata.
arm_compute::ITensorPack::get_const_tensor
const ITensor * get_const_tensor(int id) const
Get constant tensor of a given id.
Definition: ITensorPack.cpp:53
ARM_COMPUTE_ERROR_THROW_ON
#define ARM_COMPUTE_ERROR_THROW_ON(status)
Definition: Error.h:455
arm_compute::cpu::CpuWinogradConv2d::configure
void configure(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, ITensorInfo *dst, const PadStrideInfo &conv_info, const ActivationLayerInfo &act_info=ActivationLayerInfo(), bool enable_fast_math=false)
Set the input and output tensors.
Definition: CpuWinogradConv2d.cpp:173
arm_compute::ITensorPack
Tensor packing service.
Definition: ITensorPack.h:39
CpuActivation.h
arm_compute::cpu::CpuWinogradConv2d::run
void run(ITensorPack &tensors) override
Run the kernels contained in the function.
Definition: CpuWinogradConv2d.cpp:356
arm_compute::DataLayoutDimension::HEIGHT
@ HEIGHT
height
ARM_COMPUTE_RETURN_ERROR_ON
#define ARM_COMPUTE_RETURN_ERROR_ON(cond)
If the condition is true, an error is returned.
Definition: Error.h:298
arm_compute::TensorInfo::total_size
size_t total_size() const override
Returns the total size of the tensor in bytes.
Definition: TensorInfo.h:261
arm_compute::ACL_DST
@ ACL_DST
Definition: Types.h:55
ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED
#define ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(tensor)
Definition: Validate.h:117
arm_compute::DataType::U8
@ U8
unsigned 8-bit number
arm_compute::Scheduler::get
static IScheduler & get()
Access the scheduler singleton.
Definition: Scheduler.cpp:94
arm_compute::Status
Status class.
Definition: Error.h:52
arm_compute::utils::cast
Definition: Cast.h:33
WindowHelpers.h
CpuAuxTensorHandler.h
CpuPermute.h
CpuWinogradConv2d.h
arm_gemm::GemmConfig
Definition: arm_gemm.hpp:105
arm_gemm.hpp
arm_compute::test::validation::pack
ITensorPack pack
Definition: Im2Col.cpp:188
ARM_COMPUTE_UNUSED
#define ARM_COMPUTE_UNUSED(...)
To avoid unused variables warnings.
Definition: Error.h:151
arm_compute::Window::Dimension
Describe one of the image's dimensions with a start, end and step.
Definition: Window.h:79
arm_compute::TensorInfo::init
void init(Format format)
Initialize the tensor info with just a format.
Definition: TensorInfo.cpp:137
arm_compute::PadStrideInfo
Definition: CoreTypes.h:139
arm_compute::Window::set
void set(size_t dimension, const Dimension &dim)
Set the values of a given dimension.
Definition: Window.inl:53
arm_compute::test::validation::data_type
data_type
Definition: Cast.cpp:222
AsymmHelpers.h
MemoryHelpers.h
arm_compute::experimental
Definition: Types.h:83
arm_compute::ACL_INT
@ ACL_INT
Definition: Types.h:62
arm_compute::cpu::CpuWinogradConv2d::workspace
experimental::MemoryRequirements workspace() const override
Return the memory requirements required by the workspace.
Definition: CpuWinogradConv2d.cpp:472
Utils.h
arm_compute::get_data_layout_dimension_index
size_t get_data_layout_dimension_index(const DataLayout &data_layout, const DataLayoutDimension &data_layout_dimension)
Get the index of the given dimension.
Definition: Helpers.inl:201
NEScheduler.h
ShapeCalculator.h
arm_compute::TensorInfo
Store the tensor's metadata.
Definition: TensorInfo.h:41
arm_compute::offset_int_vec
int offset_int_vec(int offset)
Definition: MemoryHelpers.h:38
arm_compute::Window
Describe a multidimensional execution window.
Definition: Window.h:39
Validate.h
arm_compute::cpu::CpuPermute
Basic function to run kernels::CpuPermuteKernel.
Definition: CpuPermute.h:34
arm_compute::ACL_BIAS
@ ACL_BIAS
Definition: Types.h:74
ARM_COMPUTE_RETURN_ERROR_ON_MSG
#define ARM_COMPUTE_RETURN_ERROR_ON_MSG(cond, msg)
If the condition is true, an error is returned.
Definition: Error.h:245
arm_compute
Copyright (c) 2017-2023 Arm Limited.
Definition: introduction.dox:24
arm_compute::test::validation::conv_info
conv_info
Definition: DirectConvolutionLayer.cpp:547
arm_compute::DataType::F16
@ F16
16-bit floating-point number
arm_compute::ITensorInfo::offset_first_element_in_bytes
virtual size_t offset_first_element_in_bytes() const =0
The offset from the beginning of the memory allocation to the first element of the tensor.
arm_compute::ITensorInfo::strides_in_bytes
virtual const Strides & strides_in_bytes() const =0
The strides in bytes for accessing each dimension of the tensor.
Log.h
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR
#define ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(...)
Definition: Validate.h:161
arm_compute::IScheduler::num_threads
virtual unsigned int num_threads() const =0
Returns the number of threads that the SingleThreadScheduler has in its pool.
arm_compute::cpu::channel_idx
const size_t channel_idx
Definition: impl.h:39
arm_compute::ACL_SRC
@ ACL_SRC
Definition: Types.h:44
arm_compute::ITensorInfo
Store the tensor's metadata.
Definition: ITensorInfo.h:44
arm_compute::DataLayoutDimension::BATCHES
@ BATCHES
batches
arm_compute::cpu::CpuWinogradConv2d::CpuWinogradConv2d
CpuWinogradConv2d()
Constructor.
Definition: CpuWinogradConv2d.cpp:145
arm_compute::DataType::F32
@ F32
32-bit floating-point number
arm_compute::cpu::width_idx
const size_t width_idx
Definition: impl.h:37
arm_compute::test::validation::info
ScaleKernelInfo info(interpolation_policy, default_border_mode, PixelValue(), sampling_policy, false)
arm_compute::cpu::CpuActivation
Basic function to run kernels::CpuActivationKernel.
Definition: CpuActivation.h:36
ARM_COMPUTE_LOG_PARAMS
#define ARM_COMPUTE_LOG_PARAMS(...)
Definition: Log.h:35
Validate.h
arm_compute::cpu::CpuWinogradConv2d::validate
static Status validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const PadStrideInfo &conv_info, const ActivationLayerInfo &act_info=ActivationLayerInfo(), bool enable_fast_math=false)
Static function to check if given info will lead to a valid configuration of CpuWinogradConv2d.
Definition: CpuWinogradConv2d.cpp:321
arm_compute::DataType
DataType
Available data types.
Definition: CoreTypes.h:83
arm_compute::cpu::height_idx
const size_t height_idx
Definition: impl.h:38
arm_compute::ITensor::buffer
virtual uint8_t * buffer() const =0
Interface to be implemented by the child class to return a pointer to CPU memory.
arm_compute::logging::LogLevel::INFO
@ INFO
Information log level.