ArmNN
 25.02
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
RefChannelShuffleWorkload.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2021-2024 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
9 #include "RefWorkloadUtils.hpp"
10 #include "Profiling.hpp"
11 #include "Decoders.hpp"
12 #include "Encoders.hpp"
13 
14 namespace armnn
15 {
17 {
19 }
20 
21 // Reference implementation for channel shuffle taken from
22 // https://android.googlesource.com/platform/frameworks/ml/+/refs/heads/master/nn/common/operations/ChannelShuffle.cpp
23 void RefChannelShuffleWorkload::Execute(std::vector<ITensorHandle*> inputs,
24  std::vector<ITensorHandle*> outputs) const
25 {
26  ARMNN_SCOPED_PROFILING_EVENT_REF_NAME_GUID("RefChannelShuffleWorkload_Execute");
27 
28  const TensorInfo& inputInfo = GetTensorInfo(inputs[0]);
29  const TensorInfo& outputInfo = GetTensorInfo(outputs[0]);
30  std::unique_ptr<Decoder<float>> decoderPtr = MakeDecoder<float>(inputInfo, inputs[0]->Map());
31  Decoder<float>& decoder = *decoderPtr;
32 
33  std::unique_ptr<Encoder<float>> encoderPtr = MakeEncoder<float>(outputInfo, outputs[0]->Map());
34  Encoder<float>& encoder = *encoderPtr;
35 
36  auto getNumberOfElements = [](const TensorShape& tensorShape,uint32_t startAxis, uint32_t lastAxis)
37  {
38  uint32_t count = 1;
39  for (uint32_t i = startAxis; i < lastAxis; i++)
40  {
41  count *= tensorShape[i];
42  }
43  return count;
44  };
45  const TensorShape tensorShape = GetTensorInfo(inputs[0]).GetShape();
46  uint32_t channelsAxis = m_Data.m_Parameters.m_Axis; // channelsAxis to perform channel shuffle on
47 
48  const uint32_t numGroups = m_Data.m_Parameters.m_NumGroups;
49  const uint32_t groupSize = tensorShape[channelsAxis] / numGroups;
50 
51  uint32_t outerSize = getNumberOfElements(tensorShape, 0, channelsAxis);
52  uint32_t innerSize = getNumberOfElements(tensorShape, channelsAxis + 1, tensorShape.GetNumDimensions());
53 
54  for (uint32_t outer = 0; outer < outerSize; ++outer)
55  {
56  for (uint32_t inner = 0; inner < innerSize; ++inner)
57  {
58  uint32_t decoderStep1 = outer * tensorShape[channelsAxis] * innerSize + inner;
59  decoder += decoderStep1;
60  uint32_t encoderStep1 = outer * tensorShape[channelsAxis] * innerSize + inner;
61  encoder += encoderStep1;
62  for (uint32_t i = 0; i < groupSize; i++)
63  {
64  for (uint32_t j = 0; j < numGroups; j++, encoder += innerSize, encoderStep1 += innerSize)
65  {
66  decoder += innerSize * (i + j * groupSize);
67  float decoded = decoder.Get();
68  encoder.Set(decoded);
69  decoder -= innerSize * (i + j * groupSize);
70  }
71  }
72  decoder -= decoderStep1;
73  encoder -= encoderStep1;
74  }
75  }
76 }
77 }
#define ARMNN_SCOPED_PROFILING_EVENT_REF_NAME_GUID(label)
Creates a profiling event that uses GetGuid() and GetName() from the calling class.
QueueDescriptor m_Data
Definition: Workload.hpp:74
virtual IType Get() const =0
const TensorShape & GetShape() const
Definition: Tensor.hpp:193
Copyright (c) 2021 ARM Limited and Contributors.
const TensorInfo & GetTensorInfo(const ITensorHandle *tensorHandle)
float32 helpers
std::vector< ITensorHandle * > m_Inputs
std::vector< ITensorHandle * > m_Outputs