ArmNN
 25.02
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
RefComparisonWorkload.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2019-2024 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
7 
8 #include "Decoders.hpp"
10 #include "Encoders.hpp"
11 #include "RefWorkloadUtils.hpp"
12 
13 #include <Profiling.hpp>
14 
15 #include <armnn/TypesUtils.hpp>
16 
17 #include <functional>
18 
19 namespace armnn
20 {
21 
23  const WorkloadInfo& info)
25 {}
26 
28 {
30 }
31 
32 void RefComparisonWorkload::PostAllocationConfigure(std::vector<ITensorHandle*> inputs,
33  std::vector<ITensorHandle*> outputs)
34 {
35  const TensorInfo& inputInfo0 = GetTensorInfo(inputs[0]);
36  const TensorInfo& inputInfo1 = GetTensorInfo(inputs[1]);
37  const TensorInfo& outputInfo = GetTensorInfo(outputs[0]);
38 
39  m_Input0 = MakeDecoder<InType>(inputInfo0);
40  m_Input1 = MakeDecoder<InType>(inputInfo1);
41 
42  m_Output = MakeEncoder<OutType>(outputInfo);
43 }
44 
46 {
48 }
49 
50 void RefComparisonWorkload::Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const
51 {
52  ARMNN_SCOPED_PROFILING_EVENT_REF_NAME_GUID("RefComparisonWorkload_Execute");
53 
54  const TensorInfo& inputInfo0 = GetTensorInfo(inputs[0]);
55  const TensorInfo& inputInfo1 = GetTensorInfo(inputs[1]);
56  const TensorInfo& outputInfo = GetTensorInfo(outputs[0]);
57 
58  const TensorShape& inShape0 = inputInfo0.GetShape();
59  const TensorShape& inShape1 = inputInfo1.GetShape();
60  const TensorShape& outShape = outputInfo.GetShape();
61 
62  m_Input0->Reset(inputs[0]->Map());
63  m_Input1->Reset(inputs[1]->Map());
64  m_Output->Reset(outputs[0]->Map());
65 
67  using GreaterFunction = ElementwiseBinaryFunction<std::greater<InType>>;
68  using GreaterOrEqualFunction = ElementwiseBinaryFunction<std::greater_equal<InType>>;
69  using LessFunction = ElementwiseBinaryFunction<std::less<InType>>;
70  using LessOrEqualFunction = ElementwiseBinaryFunction<std::less_equal<InType>>;
72 
73  switch (m_Data.m_Parameters.m_Operation)
74  {
76  {
77  EqualFunction(inShape0, inShape1, outShape, *m_Input0, *m_Input1, *m_Output);
78  break;
79  }
81  {
82  GreaterFunction(inShape0, inShape1, outShape, *m_Input0, *m_Input1, *m_Output);
83  break;
84  }
86  {
87  GreaterOrEqualFunction(inShape0, inShape1, outShape, *m_Input0, *m_Input1, *m_Output);
88  break;
89  }
91  {
92  LessFunction(inShape0, inShape1, outShape, *m_Input0, *m_Input1, *m_Output);
93  break;
94  }
96  {
97  LessOrEqualFunction(inShape0, inShape1, outShape, *m_Input0, *m_Input1, *m_Output);
98  break;
99  }
101  {
102  NotEqualFunction(inShape0, inShape1, outShape, *m_Input0, *m_Input1, *m_Output);
103  break;
104  }
105  default:
106  {
107  throw InvalidArgumentException(std::string("Unsupported comparison operation ") +
108  GetComparisonOperationAsCString(m_Data.m_Parameters.m_Operation), CHECK_LOCATION());
109  }
110  }
111 }
112 
113 } // namespace armnn
#define CHECK_LOCATION()
Definition: Exceptions.hpp:203
#define ARMNN_SCOPED_PROFILING_EVENT_REF_NAME_GUID(label)
Creates a profiling event that uses GetGuid() and GetName() from the calling class.
QueueDescriptor m_Data
Definition: Workload.hpp:74
RefComparisonWorkload(const ComparisonQueueDescriptor &descriptor, const WorkloadInfo &info)
const TensorShape & GetShape() const
Definition: Tensor.hpp:193
Copyright (c) 2021 ARM Limited and Contributors.
const TensorInfo & GetTensorInfo(const ITensorHandle *tensorHandle)
float32 helpers
constexpr char const * GetComparisonOperationAsCString(ComparisonOperation operation)
Definition: TypesUtils.hpp:62
std::vector< ITensorHandle * > m_Inputs
std::vector< ITensorHandle * > m_Outputs
Contains information about TensorInfos of a layer.