ArmNN
 24.02
RefGatherNdWorkload.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include <fmt/format.h>
8 
9 #include "Gather.hpp"
10 #include "Profiling.hpp"
11 #include "RefWorkloadUtils.hpp"
13 
14 namespace armnn
15 {
16 
18 {
20 }
21 
23 {
24  WorkingMemDescriptor* workingMemDescriptor = static_cast<WorkingMemDescriptor*>(executionData.m_Data);
25  Execute(workingMemDescriptor->m_Inputs, workingMemDescriptor->m_Outputs);
26 }
27 
28 void RefGatherNdWorkload::Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const
29 {
30  ARMNN_SCOPED_PROFILING_EVENT_REF_NAME_GUID("RefGatherNdWorkload_Execute");
31 
32  const TensorInfo& inputInfo0 = GetTensorInfo(inputs[0]);
33  const TensorInfo& inputInfo1 = GetTensorInfo(inputs[1]);
34  const TensorInfo& outputInfo = GetTensorInfo(outputs[0]);
35 
36  std::unique_ptr<Decoder<float>> params_decoderPtr = MakeDecoder<float>(inputInfo0, inputs[0]->Map());
37 
38  const int32_t* indicesDataPtr = reinterpret_cast<int32_t*>(inputs[1]->Map());
39  std::vector<int32_t> indices(indicesDataPtr, indicesDataPtr + inputInfo1.GetNumElements());
40  // Check for negative indices, it could not be checked in validate as we do not have access to the values there
41  for (unsigned int i = 0; i < inputInfo1.GetNumElements(); ++i)
42  {
43  if (indices[i] < 0)
44  {
45  throw InvalidArgumentException((fmt::format("GatherNd: indices[{}] < 0", i)));
46  }
47  }
48 
49  std::unique_ptr<Encoder<float>> output_encoderPtr = MakeEncoder<float>(outputInfo, outputs[0]->Map());
50 
51  std::map<std::string, unsigned int> keyIndices = CalculateGatherNdKeyIndices(inputInfo0, inputInfo1);
52 
53  /// Calculate flattened indices: flattenedIndices = indices * flattenedCoefficients
54  // Calculate the flattened coefficients to use in the multiplication
55  // to calculate the flattened indices needed by gather
56  TensorShape paramsShape = inputInfo0.GetShape();
57  std::vector<unsigned int> flattenedCoeff(keyIndices["ND"], 1);
58  for (unsigned int i = 1; i < keyIndices["ND"]; ++i)
59  {
60  flattenedCoeff[i-1] = paramsShape[i];
61  }
62  for (unsigned int i = keyIndices["ND"]-1; i > 0; --i)
63  {
64  flattenedCoeff[i-1] *= flattenedCoeff[i];
65  }
66 
67  // Prepare the vector to store the output of the matrix multiplication,
68  // which will represent the flattened indices needed by gather
69  armnn::TensorInfo flattenedIndices_Info = inputInfo1;
70  flattenedIndices_Info.SetShape({ keyIndices["W"] });
71  std::vector<int32_t> flattenedIndices(flattenedIndices_Info.GetNumElements(), 0);
72 
73  // Multiplication to calculate the flattened indices, which are the indices needed by gather.
74  for (unsigned int i = 0; i < keyIndices["W"]; ++i)
75  {
76  for (unsigned int j = 0; j < keyIndices["ND"]; ++j)
77  {
78  flattenedIndices[i] += indices[i * keyIndices["ND"] + j] * static_cast<int32_t>(flattenedCoeff[j]);
79  }
80  }
81 
82  /// Call Gather with adequate shapes
83  // Reshape params into {K, C}
84  armnn::TensorInfo params_K_C_Info = inputInfo0;
85  params_K_C_Info.SetShape({ keyIndices["K"], keyIndices["C"] });
86 
87  // Reshape indices into {N, W}
88  armnn::TensorInfo indices_N_W_Info = inputInfo1;
89  indices_N_W_Info.SetShape({ keyIndices["N"], keyIndices["W"] });
90 
91  // Reshape output to have the shape given by gather {N, W, C}
92  // (the original outputInfo has the shape given by gatherNd)
93  armnn::TensorInfo outputGather_Info = outputInfo;
94  outputGather_Info.SetShape({ keyIndices["N"], keyIndices["W"], keyIndices["C"] });
95 
96  // output_gather = gather(params_K_C, indices_N_W)
97  Gather(params_K_C_Info, indices_N_W_Info, outputGather_Info,
98  *params_decoderPtr, flattenedIndices.data(), *output_encoderPtr, 0);
99 }
100 
101 } //namespace armnn
armnn::TensorInfo::GetNumElements
unsigned int GetNumElements() const
Definition: Tensor.hpp:198
armnn::RefGatherNdWorkload::ExecuteAsync
void ExecuteAsync(ExecutionData &executionData) override
Definition: RefGatherNdWorkload.cpp:22
WorkloadUtils.hpp
armnn::experimental::ExecutionData::m_Data
void * m_Data
Definition: ExecutionData.hpp:16
armnn::TensorInfo
Definition: Tensor.hpp:152
Profiling.hpp
ARMNN_SCOPED_PROFILING_EVENT_REF_NAME_GUID
#define ARMNN_SCOPED_PROFILING_EVENT_REF_NAME_GUID(label)
Creates a profiling event that uses GetGuid() and GetName() from the calling class.
Definition: RefWorkloadUtils.hpp:22
armnn::CalculateGatherNdKeyIndices
std::map< std::string, unsigned int > CalculateGatherNdKeyIndices(TensorInfo inputInfo0, TensorInfo inputInfo1)
Calculates the key index values needed for GatherNd: N, ND, K, W, C (N is always 1)
Definition: WorkloadUtils.cpp:312
armnn::InvalidArgumentException
Definition: Exceptions.hpp:80
armnn::RefGatherNdWorkload::Execute
void Execute() const override
Definition: RefGatherNdWorkload.cpp:17
armnn::GetTensorInfo
const TensorInfo & GetTensorInfo(const ITensorHandle *tensorHandle)
float32 helpers
Definition: RefWorkloadUtils.hpp:33
Gather.hpp
armnn::QueueDescriptor::m_Outputs
std::vector< ITensorHandle * > m_Outputs
Definition: WorkloadData.hpp:27
RefWorkloadUtils.hpp
armnn::BaseWorkload< GatherNdQueueDescriptor >::m_Data
GatherNdQueueDescriptor m_Data
Definition: Workload.hpp:89
armnn::TensorInfo::GetShape
const TensorShape & GetShape() const
Definition: Tensor.hpp:193
armnn::LayerType::Map
@ Map
armnn::experimental::WorkingMemDescriptor::m_Inputs
std::vector< ITensorHandle * > m_Inputs
Definition: WorkingMemDescriptor.hpp:20
armnn::TensorInfo::SetShape
void SetShape(const TensorShape &newShape)
Definition: Tensor.hpp:195
armnn
Copyright (c) 2021 ARM Limited and Contributors.
Definition: 01_00_quick_start.dox:6
armnn::experimental::WorkingMemDescriptor
Definition: WorkingMemDescriptor.hpp:18
RefGatherNdWorkload.hpp
armnn::experimental::WorkingMemDescriptor::m_Outputs
std::vector< ITensorHandle * > m_Outputs
Definition: WorkingMemDescriptor.hpp:21
armnn::QueueDescriptor::m_Inputs
std::vector< ITensorHandle * > m_Inputs
Definition: WorkloadData.hpp:26
armnn::Gather
void Gather(const TensorInfo &paramsInfo, const TensorInfo &indicesInfo, const TensorInfo &outputInfo, Decoder< float > &params, const int32_t *indices, Encoder< float > &output, const int32_t axis_int)
Definition: Gather.cpp:15
armnn::experimental::ExecutionData
Definition: ExecutionData.hpp:14