ArmNN
 25.11
Loading...
Searching...
No Matches
RefScatterNdWorkload.cpp
Go to the documentation of this file.
1//
2// Copyright © 2024 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include <fmt/format.h>
9#include "ScatterNd.hpp"
10#include "Profiling.hpp"
11
12namespace armnn
13{
14
18
20 {
21 Execute(m_Data.m_Inputs, m_Data.m_Outputs);
22 }
23
// Reference (CPU) ScatterNd execution on explicit tensor handles.
//
// Two input layouts are handled:
//  - first branch: inputs = {input data, indices, updates}. The input tensor is decoded as
//    float and handed to the ScatterNd overload that takes (inputInfo, indicesInfo, updatesInfo, ...).
//  - else branch:  inputs = {shape, indices, updates}. inputs[0] is decoded as int (the output
//    shape) and handed to the ScatterNd overload that takes (indicesInfo, updatesInfo, shapeInfo, ...).
// In both branches indices are decoded as int, updates as float, and the result is written
// through a float Encoder mapped over outputs[0].
//
// @param inputs   three tensor handles, in the order described above.
// @param outputs  outputs[0] receives the scattered result.
//
// NOTE(review): this listing is a documentation scrape that dropped rendered lines 28, 58 and 90:
//  - line 28 is the condition that chooses between the two branches. It is presumably a test of
//    the descriptor's m_InputEnabled flag ("Flag to show if input tensor is accepted") — confirm.
//  - lines 58 and 90 are the final argument and closing `);` of each ScatterNd call. They are
//    presumably the descriptor parameters — confirm against the real source before relying on it.
24 void RefScatterNdWorkload::Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const
25 {
// Profiling scope for the whole call; the macro tags the event with this workload's name/GUID.
26 ARMNN_SCOPED_PROFILING_EVENT_REF_NAME_GUID("RefScatterNdWorkload_Execute");
27
// NOTE(review): rendered line 28 (the branch condition) is missing from this listing.
29 {
30 // TensorInfos for the three input slots: input data, indices, updates
31 const TensorInfo& inputInfo = GetTensorInfo(inputs[0]);
32 const TensorInfo& indicesInfo = GetTensorInfo(inputs[1]);
33 const TensorInfo& updatesInfo = GetTensorInfo(inputs[2]);
34
35 // Float decoder over the mapped input-data tensor (slot 0)
// NOTE(review): inputInfo already holds GetTensorInfo(inputs[0]) and could be reused here and below.
36 std::unique_ptr<Decoder<float>> inputDecoder = MakeDecoder<float>(GetTensorInfo(inputs[0]),
37 inputs[0]->Map());
38
39 // Int decoder over the mapped indices tensor (slot 1)
40 std::unique_ptr<Decoder<int>> indicesDecoder = MakeDecoder<int>(GetTensorInfo(inputs[1]),
41 inputs[1]->Map());
42
43 // Float decoder over the mapped updates tensor (slot 2)
44 std::unique_ptr<Decoder<float>> updatesDecoder = MakeDecoder<float>(GetTensorInfo(inputs[2]),
45 inputs[2]->Map());
46
47 // Float encoder that writes the result into the mapped output tensor
48 std::unique_ptr<Encoder<float>> outputEncoder = MakeEncoder<float>(GetTensorInfo(outputs[0]),
49 outputs[0]->Map());
50
// Input-tensor overload: infos first, then the decoders, then the output encoder.
51 ScatterNd(inputInfo,
52 indicesInfo,
53 updatesInfo,
54 *inputDecoder,
55 *indicesDecoder,
56 *updatesDecoder,
57 *outputEncoder,
// NOTE(review): rendered line 58 (final ScatterNd argument and closing `);`) is missing here.
59 }
60 else
61 {
62 // TensorInfos for the three input slots: shape, indices, updates
63 const TensorInfo& shapeInfo = GetTensorInfo(inputs[0]);
64 const TensorInfo& indicesInfo = GetTensorInfo(inputs[1]);
65 const TensorInfo& updatesInfo = GetTensorInfo(inputs[2]);
66
67 // Int decoder over the mapped shape tensor (slot 0)
68 std::unique_ptr<Decoder<int>> shapeDecoder = MakeDecoder<int>(GetTensorInfo(inputs[0]),
69 inputs[0]->Map());
70
71 // Int decoder over the mapped indices tensor (slot 1)
72 std::unique_ptr<Decoder<int>> indicesDecoder = MakeDecoder<int>(GetTensorInfo(inputs[1]),
73 inputs[1]->Map());
74
75 // Float decoder over the mapped updates tensor (slot 2)
76 std::unique_ptr<Decoder<float>> updatesDecoder = MakeDecoder<float>(GetTensorInfo(inputs[2]),
77 inputs[2]->Map());
78
79 // Float encoder that writes the result into the mapped output tensor
80 std::unique_ptr<Encoder<float>> outputEncoder = MakeEncoder<float>(GetTensorInfo(outputs[0]),
81 outputs[0]->Map());
82
// Shape overload: note the different argument order (indices, updates, shape).
83 ScatterNd(indicesInfo,
84 updatesInfo,
85 shapeInfo,
86 *indicesDecoder,
87 *updatesDecoder,
88 *shapeDecoder,
89 *outputEncoder,
// NOTE(review): rendered line 90 (final ScatterNd argument and closing `);`) is missing here.
91 }
92 }
93
94} // namespace armnn
#define ARMNN_SCOPED_PROFILING_EVENT_REF_NAME_GUID(label)
Creates a profiling event that uses GetGuid() and GetName() from the calling class.
RefBaseWorkload(const ScatterNdQueueDescriptor &descriptor, const WorkloadInfo &info)
RefScatterNdWorkload(const ScatterNdQueueDescriptor &descriptor, const WorkloadInfo &info)
Copyright (c) 2021 ARM Limited and Contributors.
std::unique_ptr< Decoder< T > > MakeDecoder(const TensorInfo &info, const void *data=nullptr)
std::unique_ptr< Encoder< T > > MakeEncoder(const TensorInfo &info, void *data=nullptr)
armnn::TensorInfo GetTensorInfo(unsigned int numberOfBatches, unsigned int numberOfChannels, unsigned int height, unsigned int width, const armnn::DataLayout dataLayout, const armnn::DataType dataType)
bool m_InputEnabled
Flag to show if input tensor is accepted.
Contains information about TensorInfos of a layer.