ArmNN
 25.11
Loading...
Searching...
No Matches
RefGatherWorkload.cpp
Go to the documentation of this file.
1//
2// Copyright © 2019-2024 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
7
8#include "Gather.hpp"
9#include "Profiling.hpp"
10#include "RefWorkloadUtils.hpp"
11#include <ResolveType.hpp>
12#include <fmt/format.h>
13
14namespace armnn
15{
16
18{
19 auto inputDataType = GetTensorInfo(m_Data.m_Inputs[0]).GetDataType();
20 if(inputDataType == DataType::Signed64)
21 {
22 Execute<double_t>(m_Data.m_Inputs, m_Data.m_Outputs);
23 }
24 else
25 {
26 Execute<float>(m_Data.m_Inputs, m_Data.m_Outputs);
27 }
28}
29
30template <typename T>
31void RefGatherWorkload::Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const
32{
33 ARMNN_SCOPED_PROFILING_EVENT_REF_NAME_GUID("RefGatherWorkload_Execute");
34
35 const TensorInfo& inputInfo0 = GetTensorInfo(inputs[0]);
36 const TensorInfo& inputInfo1 = GetTensorInfo(inputs[1]);
37 const TensorInfo& outputInfo = GetTensorInfo(outputs[0]);
38
39 const int32_t* indicesData = reinterpret_cast<int32_t*>(inputs[1]->Map());
40 // Check for negative indices, it could not be checked in validate as we do not have access to the values there
41 for (unsigned int i = 0; i < inputInfo1.GetNumElements(); ++i)
42 {
43 if (indicesData[i] < 0)
44 {
45 throw InvalidArgumentException((fmt::format("Gather: indices[{}] < 0", i)));
46 }
47 }
48
49 std::unique_ptr<Decoder<T>> decoderPtr = MakeDecoder<T>(inputInfo0, inputs[0]->Map());
50 Decoder<T>& decoder = *decoderPtr;
51
52 std::unique_ptr<Encoder<T>> encoderPtr = MakeEncoder<T>(outputInfo, outputs[0]->Map());
53 Encoder<T>& encoder = *encoderPtr;
54
55 Gather(inputInfo0, inputInfo1, outputInfo, decoder, indicesData, encoder, m_Data.m_Parameters.m_Axis);
56}
57} //namespace armnn
#define ARMNN_SCOPED_PROFILING_EVENT_REF_NAME_GUID(label)
Creates a profiling event that uses GetGuid() and GetName() from the calling class.
void Execute() const override
unsigned int GetNumElements() const
Definition Tensor.hpp:198
DataType GetDataType() const
Definition Tensor.hpp:200
Copyright (c) 2021 ARM Limited and Contributors.
std::unique_ptr< Decoder< T > > MakeDecoder(const TensorInfo &info, const void *data=nullptr)
std::unique_ptr< Encoder< T > > MakeEncoder(const TensorInfo &info, void *data=nullptr)
armnn::TensorInfo GetTensorInfo(unsigned int numberOfBatches, unsigned int numberOfChannels, unsigned int height, unsigned int width, const armnn::DataLayout dataLayout, const armnn::DataType dataType)