ArmNN
 25.02
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
Workload.hpp
Go to the documentation of this file.
1 //
2 // Copyright © 2021-2024 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
7 #include "IWorkload.hpp"
8 #include "WorkloadData.hpp"
9 #include "WorkloadInfo.hpp"
10 
11 #include <armnn/Logging.hpp>
12 
13 #include <Profiling.hpp>
14 
15 #include <client/include/IProfilingService.hpp>
16 
17 #include <algorithm>
18 
19 namespace armnn
20 {
21 
22 // NullWorkload used to denote an unsupported workload when used by the MakeWorkload<> template
23 // in the various workload factories.
24 // There should never be an instantiation of a NullWorkload.
25 class NullWorkload : public IWorkload
26 {
27  NullWorkload()=delete;
28 };
29 
// BaseWorkload: common base for workloads parameterised on a QueueDescriptor type.
// It stores a copy of the descriptor (m_Data), a profiling GUID (m_Guid) and a name (m_Name),
// and supplies default implementations of several IWorkload virtual functions.
// NOTE(review): this rendered listing skips source lines 38 and 74. Line 74 is the declaration
// of m_Data ("QueueDescriptor m_Data", Definition: Workload.hpp:74). Line 38 presumably holds the
// last mem-initializer (likely for m_Name). Confirm both against the real header before relying
// on this listing.
30 template <typename QueueDescriptor>
31 class BaseWorkload : public IWorkload
32 {
33 public:
34 
// Copies the descriptor into m_Data, obtains the next GUID from
// arm::pipe::IProfilingService::GetNextGuid(), then calls m_Data.Validate(info).
// NOTE(review): Validate presumably throws on an invalid descriptor/info combination;
// confirm in WorkloadData.hpp.
35  BaseWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
36  : m_Data(descriptor),
37  m_Guid(arm::pipe::IProfilingService::GetNextGuid()),
39  {
40  m_Data.Validate(info);
41  }
42 
// Returns a reference to the stored workload name (m_Name).
43  virtual const std::string& GetName() const override
44  {
45  return m_Name;
46  }
47 
// Default: no extra configuration after tensor allocation. Not final, so derived workloads may override it.
48  void PostAllocationConfigure() override {}
49 
// Read-only access to this workload's copy of the queue descriptor.
50  const QueueDescriptor& GetData() const { return m_Data; }
51 
// Returns the GUID assigned in the constructor. Marked final so derived workloads cannot replace it.
52  arm::pipe::ProfilingGuid GetGuid() const final { return m_Guid; }
53 
// Default: this workload does not support swapping its tensor handles.
54  virtual bool SupportsTensorHandleReplacement() const override
55  {
56  return false;
57  }
58 
59  // Replace input tensor handle with the given TensorHandle
// Default implementation always throws armnn::UnimplementedException.
// NOTE(review): presumably overridden by workloads whose SupportsTensorHandleReplacement() returns true.
60  void ReplaceInputTensorHandle(ITensorHandle* tensorHandle, unsigned int slot) override
61  {
62  armnn::IgnoreUnused(tensorHandle, slot);
63  throw armnn::UnimplementedException("ReplaceInputTensorHandle not implemented for this workload");
64  }
65 
66  // Replace output tensor handle with the given TensorHandle
// Default implementation always throws armnn::UnimplementedException.
// NOTE(review): presumably overridden by workloads whose SupportsTensorHandleReplacement() returns true.
67  void ReplaceOutputTensorHandle(ITensorHandle* tensorHandle, unsigned int slot) override
68  {
69  armnn::IgnoreUnused(tensorHandle, slot);
70  throw armnn::UnimplementedException("ReplaceOutputTensorHandle not implemented for this workload");
71  }
72 
73 protected:
// (Source line 74, the m_Data declaration, is not shown in this rendering.)
// GUID allocated once, in the constructor.
75  const arm::pipe::ProfilingGuid m_Guid;
// Name returned by GetName().
76  const std::string m_Name;
77 };
78 
79 // TypedWorkload used
80 template <typename QueueDescriptor, armnn::DataType... DataTypes>
81 class TypedWorkload : public BaseWorkload<QueueDescriptor>
82 {
83 public:
84 
85  TypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
86  : BaseWorkload<QueueDescriptor>(descriptor, info)
87  {
88  std::vector<armnn::DataType> dataTypes = {DataTypes...};
89  armnn::DataType expectedInputType;
90 
91  if (!info.m_InputTensorInfos.empty())
92  {
93  expectedInputType = info.m_InputTensorInfos.front().GetDataType();
94 
95  if (std::find(dataTypes.begin(), dataTypes.end(), expectedInputType) == dataTypes.end())
96  {
97  throw armnn::Exception("Trying to create workload with incorrect type");
98  }
99  if (std::all_of(std::next(info.m_InputTensorInfos.begin()),
100  info.m_InputTensorInfos.end(),
101  [&](auto it){
102  return it.GetDataType() == expectedInputType;
103  }) == false)
104  {
105  throw armnn::Exception("Trying to create workload with incorrect type");
106  }
107  }
108  armnn::DataType expectedOutputType;
109 
110  if (!info.m_OutputTensorInfos.empty())
111  {
112  expectedOutputType = info.m_OutputTensorInfos.front().GetDataType();
113 
114  if (!info.m_InputTensorInfos.empty())
115  {
116  expectedInputType = info.m_InputTensorInfos.front().GetDataType();
117 
118  if (expectedOutputType != expectedInputType)
119  {
120  throw armnn::Exception( "Trying to create workload with incorrect type");
121  }
122  }
123  else if (std::find(dataTypes.begin(), dataTypes.end(), expectedOutputType) == dataTypes.end())
124  {
125  throw armnn::Exception("Trying to create workload with incorrect type");
126  }
127  if (std::all_of(std::next(info.m_OutputTensorInfos.begin()),
128  info.m_OutputTensorInfos.end(),
129  [&](auto it){
130  return it.GetDataType() == expectedOutputType;
131  }) == false)
132  {
133  throw armnn::Exception("Trying to create workload with incorrect type");
134  }
135  }
136  }
137 };
138 
139 template <typename QueueDescriptor, armnn::DataType InputDataType, armnn::DataType OutputDataType>
140 class MultiTypedWorkload : public BaseWorkload<QueueDescriptor>
141 {
142 public:
143 
144  MultiTypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
145  : BaseWorkload<QueueDescriptor>(descriptor, info)
146  {
147  if (std::all_of(info.m_InputTensorInfos.begin(),
148  info.m_InputTensorInfos.end(),
149  [&](auto it){
150  return it.GetDataType() == InputDataType;
151  }) == false)
152  {
153  throw armnn::Exception("Trying to create workload with incorrect type");
154  }
155  if (std::all_of(info.m_OutputTensorInfos.begin(),
156  info.m_OutputTensorInfos.end(),
157  [&](auto it){
158  return it.GetDataType() == OutputDataType;
159  }) == false)
160  {
161  throw armnn::Exception("Trying to create workload with incorrect type");
162  }
163  }
164 };
165 
166 // FirstInputTypedWorkload used to check type of the first input
167 template <typename QueueDescriptor, armnn::DataType DataType>
168 class FirstInputTypedWorkload : public BaseWorkload<QueueDescriptor>
169 {
170 public:
171 
172  FirstInputTypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
173  : BaseWorkload<QueueDescriptor>(descriptor, info)
174  {
175  if (!info.m_InputTensorInfos.empty())
176  {
177  if (info.m_InputTensorInfos.front().GetDataType() != DataType)
178  {
179  throw armnn::Exception("Trying to create workload with incorrect type");
180  }
181  }
182 
183  if (std::all_of(info.m_OutputTensorInfos.begin(),
184  info.m_OutputTensorInfos.end(),
185  [&](auto it){
186  return it.GetDataType() == DataType;
187  }) == false)
188  {
189  throw armnn::Exception("Trying to create workload with incorrect type");
190  }
191  }
192 };
193 
194 template <typename QueueDescriptor>
198 
199 template <typename QueueDescriptor>
201 
202 template <typename QueueDescriptor>
204 
205 template <typename QueueDescriptor>
207 
208 template <typename QueueDescriptor>
210 
211 template <typename QueueDescriptor>
215 
216 template <typename QueueDescriptor>
220 
221 template <typename QueueDescriptor>
225 
226 template <typename QueueDescriptor>
230 
231 template <typename QueueDescriptor>
235 
236 template <typename QueueDescriptor>
240 
241 template <typename QueueDescriptor>
245 
246 } //namespace armnn
virtual bool SupportsTensorHandleReplacement() const override
Definition: Workload.hpp:54
virtual const std::string & GetName() const override
Definition: Workload.hpp:43
const arm::pipe::ProfilingGuid m_Guid
Definition: Workload.hpp:75
const std::string m_Name
Definition: Workload.hpp:76
void PostAllocationConfigure() override
Definition: Workload.hpp:48
const QueueDescriptor & GetData() const
Definition: Workload.hpp:50
arm::pipe::ProfilingGuid GetGuid() const final
Definition: Workload.hpp:52
void ReplaceInputTensorHandle(ITensorHandle *tensorHandle, unsigned int slot) override
Definition: Workload.hpp:60
void ReplaceOutputTensorHandle(ITensorHandle *tensorHandle, unsigned int slot) override
Definition: Workload.hpp:67
BaseWorkload(const QueueDescriptor &descriptor, const WorkloadInfo &info)
Definition: Workload.hpp:35
QueueDescriptor m_Data
Definition: Workload.hpp:74
Base class for all ArmNN exceptions so that users can filter to just those.
Definition: Exceptions.hpp:47
FirstInputTypedWorkload(const QueueDescriptor &descriptor, const WorkloadInfo &info)
Definition: Workload.hpp:172
Workload interface to enqueue a layer computation.
Definition: IWorkload.hpp:14
MultiTypedWorkload(const QueueDescriptor &descriptor, const WorkloadInfo &info)
Definition: Workload.hpp:144
TypedWorkload(const QueueDescriptor &descriptor, const WorkloadInfo &info)
Definition: Workload.hpp:85
Copyright (c) 2021 ARM Limited and Contributors.
void IgnoreUnused(Ts &&...)
DataType
Definition: Types.hpp:49
Contains information about TensorInfos of a layer.