24.08
RefGatherWorkload.cpp
Go to the documentation of this file.
1
//
2
// Copyright © 2019-2023 Arm Ltd and Contributors. All rights reserved.
3
// SPDX-License-Identifier: MIT
4
//
5
6
#include "
RefGatherWorkload.hpp
"
7
8
#include "
Gather.hpp
"
9
#include "
Profiling.hpp
"
10
#include "
RefWorkloadUtils.hpp
"
11
#include <
ResolveType.hpp
>
12
#include <fmt/format.h>
13
14
namespace
armnn
15
{
16
17
void
RefGatherWorkload::Execute
()
const
18
{
19
Execute
(
m_Data
.
m_Inputs
,
m_Data
.
m_Outputs
);
20
}
21
22
/// Asynchronous entry point: runs the gather on the tensor handles supplied through
/// executionData instead of those held in m_Data.
/// executionData.m_Data is a type-erased void* that is treated here as a
/// WorkingMemDescriptor; it is dereferenced without a null check.
void RefGatherWorkload::ExecuteAsync(ExecutionData& executionData)
{
    // Recover the working-memory descriptor from the opaque void* payload.
    const WorkingMemDescriptor& memDescriptor = *static_cast<WorkingMemDescriptor*>(executionData.m_Data);
    Execute(memDescriptor.m_Inputs, memDescriptor.m_Outputs);
}
27
28
/// Gathers elements of inputs[0] (params) selected by the indices in inputs[1]
/// along the axis m_Data.m_Parameters.m_Axis, writing the result to outputs[0].
///
/// @param inputs   inputs[0]: params tensor (decoded as float);
///                 inputs[1]: indices tensor, read as int32_t.
/// @param outputs  outputs[0]: destination tensor (encoded from float).
/// @throws InvalidArgumentException if any index is negative (the Gather helper
///         may raise further errors of its own).
void RefGatherWorkload::Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const
{
    ARMNN_SCOPED_PROFILING_EVENT_REF_NAME_GUID("RefGatherWorkload_Execute");

    const TensorInfo& paramsInfo  = GetTensorInfo(inputs[0]);
    const TensorInfo& indicesInfo = GetTensorInfo(inputs[1]);
    const TensorInfo& outputInfo  = GetTensorInfo(outputs[0]);

    std::unique_ptr<Decoder<float>> decoderPtr = MakeDecoder<float>(paramsInfo, inputs[0]->Map());
    Decoder<float>& decoder = *decoderPtr;

    // The indices are only read, so view the mapped buffer through a const pointer.
    // static_cast is the correct named cast for void* -> T*; reinterpret_cast is not needed.
    const int32_t* indicesData = static_cast<const int32_t*>(inputs[1]->Map());

    // Negative indices must be rejected here: validation only sees the TensorInfo,
    // not the tensor values.
    const unsigned int numIndices = indicesInfo.GetNumElements();
    for (unsigned int i = 0; i < numIndices; ++i)
    {
        if (indicesData[i] < 0)
        {
            throw InvalidArgumentException((fmt::format("Gather: indices[{}] < 0", i)));
        }
    }

    // The output is only mapped once the indices have been validated.
    std::unique_ptr<Encoder<float>> encoderPtr = MakeEncoder<float>(outputInfo, outputs[0]->Map());
    Encoder<float>& encoder = *encoderPtr;

    Gather(paramsInfo, indicesInfo, outputInfo, decoder, indicesData, encoder, m_Data.m_Parameters.m_Axis);
}
54
55
}
//namespace armnn
armnn::Decoder< float >
armnn::TensorInfo::GetNumElements
unsigned int GetNumElements() const
Definition:
Tensor.hpp:198
armnn::experimental::ExecutionData::m_Data
void * m_Data
Definition:
ExecutionData.hpp:16
armnn::TensorInfo
Definition:
Tensor.hpp:152
Profiling.hpp
ARMNN_SCOPED_PROFILING_EVENT_REF_NAME_GUID
#define ARMNN_SCOPED_PROFILING_EVENT_REF_NAME_GUID(label)
Creates a profiling event that uses GetGuid() and GetName() from the calling class.
Definition:
RefWorkloadUtils.hpp:22
ResolveType.hpp
armnn::QueueDescriptorWithParameters::m_Parameters
LayerDescriptor m_Parameters
Definition:
WorkloadData.hpp:66
armnn::InvalidArgumentException
Definition:
Exceptions.hpp:80
armnn::GetTensorInfo
const TensorInfo & GetTensorInfo(const ITensorHandle *tensorHandle)
float32 helpers
Definition:
RefWorkloadUtils.hpp:33
armnn::GatherDescriptor::m_Axis
int32_t m_Axis
The axis in params to gather indices from.
Definition:
Descriptors.hpp:981
RefGatherWorkload.hpp
Gather.hpp
armnn::QueueDescriptor::m_Outputs
std::vector< ITensorHandle * > m_Outputs
Definition:
WorkloadData.hpp:27
RefWorkloadUtils.hpp
armnn::BaseWorkload< GatherQueueDescriptor >::m_Data
GatherQueueDescriptor m_Data
Definition:
Workload.hpp:89
armnn::RefGatherWorkload::Execute
void Execute() const override
Definition:
RefGatherWorkload.cpp:17
armnn::RefGatherWorkload::ExecuteAsync
void ExecuteAsync(ExecutionData &executionData) override
Definition:
RefGatherWorkload.cpp:22
armnn::LayerType::Map
@ Map
armnn::experimental::WorkingMemDescriptor::m_Inputs
std::vector< ITensorHandle * > m_Inputs
Definition:
WorkingMemDescriptor.hpp:20
armnn
Copyright (c) 2021 ARM Limited and Contributors.
Definition:
01_00_quick_start.dox:6
armnn::experimental::WorkingMemDescriptor
Definition:
WorkingMemDescriptor.hpp:18
armnn::experimental::WorkingMemDescriptor::m_Outputs
std::vector< ITensorHandle * > m_Outputs
Definition:
WorkingMemDescriptor.hpp:21
armnn::QueueDescriptor::m_Inputs
std::vector< ITensorHandle * > m_Inputs
Definition:
WorkloadData.hpp:26
armnn::Gather
void Gather(const TensorInfo &paramsInfo, const TensorInfo &indicesInfo, const TensorInfo &outputInfo, Decoder< float > &params, const int32_t *indices, Encoder< float > &output, const int32_t axis_int)
Definition:
Gather.cpp:15
armnn::experimental::ExecutionData
Definition:
ExecutionData.hpp:14
src
backends
reference
workloads
RefGatherWorkload.cpp
Generated on Wed Aug 28 2024 14:31:51 for Arm NN by
1.8.17