ArmNN
 25.11
Loading...
Searching...
No Matches
Workload.hpp
Go to the documentation of this file.
//
// Copyright © 2021-2024 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5#pragma once
6
7#include "IWorkload.hpp"
8#include "WorkloadData.hpp"
9#include "WorkloadInfo.hpp"
10
11#include <armnn/Logging.hpp>
12
13#include <Profiling.hpp>
14
15#include <client/include/IProfilingService.hpp>
16
17#include <algorithm>
18
19namespace armnn
20{
21
22// NullWorkload used to denote an unsupported workload when used by the MakeWorkload<> template
23// in the various workload factories.
24// There should never be an instantiation of a NullWorkload.
25class NullWorkload : public IWorkload
26{
27 NullWorkload()=delete;
28};
29
30template <typename QueueDescriptor>
31class BaseWorkload : public IWorkload
32{
33public:
34
35 BaseWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
36 : m_Data(descriptor),
37 m_Guid(arm::pipe::IProfilingService::GetNextGuid()),
39 {
40 m_Data.Validate(info);
41 }
42
43 virtual const std::string& GetName() const override
44 {
45 return m_Name;
46 }
47
48 void PostAllocationConfigure() override {}
49
50 const QueueDescriptor& GetData() const { return m_Data; }
51
52 arm::pipe::ProfilingGuid GetGuid() const final { return m_Guid; }
53
54 virtual bool SupportsTensorHandleReplacement() const override
55 {
56 return false;
57 }
58
59 // Replace input tensor handle with the given TensorHandle
60 void ReplaceInputTensorHandle(ITensorHandle* tensorHandle, unsigned int slot) override
61 {
62 armnn::IgnoreUnused(tensorHandle, slot);
63 throw armnn::UnimplementedException("ReplaceInputTensorHandle not implemented for this workload");
64 }
65
66 // Replace output tensor handle with the given TensorHandle
67 void ReplaceOutputTensorHandle(ITensorHandle* tensorHandle, unsigned int slot) override
68 {
69 armnn::IgnoreUnused(tensorHandle, slot);
70 throw armnn::UnimplementedException("ReplaceOutputTensorHandle not implemented for this workload");
71 }
72
73protected:
75 const arm::pipe::ProfilingGuid m_Guid;
76 const std::string m_Name;
77};
78
79// TypedWorkload used
80template <typename QueueDescriptor, armnn::DataType... DataTypes>
81class TypedWorkload : public BaseWorkload<QueueDescriptor>
82{
83public:
84
86 : BaseWorkload<QueueDescriptor>(descriptor, info)
87 {
88 std::vector<armnn::DataType> dataTypes = {DataTypes...};
89 armnn::DataType expectedInputType;
90
91 if (!info.m_InputTensorInfos.empty())
92 {
93 expectedInputType = info.m_InputTensorInfos.front().GetDataType();
94
95 if (std::find(dataTypes.begin(), dataTypes.end(), expectedInputType) == dataTypes.end())
96 {
97 throw armnn::Exception("Trying to create workload with incorrect type");
98 }
99 if (std::all_of(std::next(info.m_InputTensorInfos.begin()),
100 info.m_InputTensorInfos.end(),
101 [&](auto it){
102 return it.GetDataType() == expectedInputType;
103 }) == false)
104 {
105 throw armnn::Exception("Trying to create workload with incorrect type");
106 }
107 }
108 armnn::DataType expectedOutputType;
109
110 if (!info.m_OutputTensorInfos.empty())
111 {
112 expectedOutputType = info.m_OutputTensorInfos.front().GetDataType();
113
114 if (!info.m_InputTensorInfos.empty())
115 {
116 expectedInputType = info.m_InputTensorInfos.front().GetDataType();
117
118 if (expectedOutputType != expectedInputType)
119 {
120 throw armnn::Exception( "Trying to create workload with incorrect type");
121 }
122 }
123 else if (std::find(dataTypes.begin(), dataTypes.end(), expectedOutputType) == dataTypes.end())
124 {
125 throw armnn::Exception("Trying to create workload with incorrect type");
126 }
127 if (std::all_of(std::next(info.m_OutputTensorInfos.begin()),
128 info.m_OutputTensorInfos.end(),
129 [&](auto it){
130 return it.GetDataType() == expectedOutputType;
131 }) == false)
132 {
133 throw armnn::Exception("Trying to create workload with incorrect type");
134 }
135 }
136 }
137};
138
139template <typename QueueDescriptor, armnn::DataType InputDataType, armnn::DataType OutputDataType>
140class MultiTypedWorkload : public BaseWorkload<QueueDescriptor>
141{
142public:
143
145 : BaseWorkload<QueueDescriptor>(descriptor, info)
146 {
147 if (std::all_of(info.m_InputTensorInfos.begin(),
148 info.m_InputTensorInfos.end(),
149 [&](auto it){
150 return it.GetDataType() == InputDataType;
151 }) == false)
152 {
153 throw armnn::Exception("Trying to create workload with incorrect type");
154 }
155 if (std::all_of(info.m_OutputTensorInfos.begin(),
156 info.m_OutputTensorInfos.end(),
157 [&](auto it){
158 return it.GetDataType() == OutputDataType;
159 }) == false)
160 {
161 throw armnn::Exception("Trying to create workload with incorrect type");
162 }
163 }
164};
165
166// FirstInputTypedWorkload used to check type of the first input
167template <typename QueueDescriptor, armnn::DataType DataType>
168class FirstInputTypedWorkload : public BaseWorkload<QueueDescriptor>
169{
170public:
171
173 : BaseWorkload<QueueDescriptor>(descriptor, info)
174 {
175 if (!info.m_InputTensorInfos.empty())
176 {
177 if (info.m_InputTensorInfos.front().GetDataType() != DataType)
178 {
179 throw armnn::Exception("Trying to create workload with incorrect type");
180 }
181 }
182
183 if (std::all_of(info.m_OutputTensorInfos.begin(),
184 info.m_OutputTensorInfos.end(),
185 [&](auto it){
186 return it.GetDataType() == DataType;
187 }) == false)
188 {
189 throw armnn::Exception("Trying to create workload with incorrect type");
190 }
191 }
192};
193
194template <typename QueueDescriptor>
198
199template <typename QueueDescriptor>
201
202template <typename QueueDescriptor>
204
205template <typename QueueDescriptor>
207
208template <typename QueueDescriptor>
210
211template <typename QueueDescriptor>
215
216template <typename QueueDescriptor>
220
221template <typename QueueDescriptor>
225
226template <typename QueueDescriptor>
230
231template <typename QueueDescriptor>
235
236template <typename QueueDescriptor>
240
241template <typename QueueDescriptor>
245
246} //namespace armnn
virtual bool SupportsTensorHandleReplacement() const override
Definition Workload.hpp:54
const arm::pipe::ProfilingGuid m_Guid
Definition Workload.hpp:75
const QueueDescriptor & GetData() const
Definition Workload.hpp:50
const std::string m_Name
Definition Workload.hpp:76
void PostAllocationConfigure() override
Definition Workload.hpp:48
arm::pipe::ProfilingGuid GetGuid() const final
Definition Workload.hpp:52
void ReplaceInputTensorHandle(ITensorHandle *tensorHandle, unsigned int slot) override
Definition Workload.hpp:60
void ReplaceOutputTensorHandle(ITensorHandle *tensorHandle, unsigned int slot) override
Definition Workload.hpp:67
virtual const std::string & GetName() const override
Definition Workload.hpp:43
BaseWorkload(const QueueDescriptor &descriptor, const WorkloadInfo &info)
Definition Workload.hpp:35
QueueDescriptor m_Data
Definition Workload.hpp:74
FirstInputTypedWorkload(const QueueDescriptor &descriptor, const WorkloadInfo &info)
Definition Workload.hpp:172
Workload interface to enqueue a layer computation.
Definition IWorkload.hpp:14
MultiTypedWorkload(const QueueDescriptor &descriptor, const WorkloadInfo &info)
Definition Workload.hpp:144
TypedWorkload(const QueueDescriptor &descriptor, const WorkloadInfo &info)
Definition Workload.hpp:85
Copyright (c) 2021 ARM Limited and Contributors.
TypedWorkload< QueueDescriptor, armnn::DataType::Float32 > Float32Workload
Definition Workload.hpp:200
TypedWorkload< QueueDescriptor, armnn::DataType::Signed32 > Int32Workload
Definition Workload.hpp:206
MultiTypedWorkload< QueueDescriptor, armnn::DataType::Float32, armnn::DataType::Boolean > BaseFloat32ComparisonWorkload
Definition Workload.hpp:212
MultiTypedWorkload< QueueDescriptor, armnn::DataType::BFloat16, armnn::DataType::Float32 > BFloat16ToFloat32Workload
Definition Workload.hpp:222
MultiTypedWorkload< QueueDescriptor, armnn::DataType::Float16, armnn::DataType::Float32 > Float16ToFloat32Workload
Definition Workload.hpp:232
MultiTypedWorkload< QueueDescriptor, armnn::DataType::QAsymmU8, armnn::DataType::Boolean > BaseUint8ComparisonWorkload
Definition Workload.hpp:217
MultiTypedWorkload< QueueDescriptor, armnn::DataType::QAsymmU8, armnn::DataType::Float32 > Uint8ToFloat32Workload
Definition Workload.hpp:242
MultiTypedWorkload< QueueDescriptor, armnn::DataType::Float32, armnn::DataType::BFloat16 > Float32ToBFloat16Workload
Definition Workload.hpp:227
TypedWorkload< QueueDescriptor, armnn::DataType::Float16, armnn::DataType::Float32 > FloatWorkload
Definition Workload.hpp:195
TypedWorkload< QueueDescriptor, armnn::DataType::Boolean > BooleanWorkload
Definition Workload.hpp:209
MultiTypedWorkload< QueueDescriptor, armnn::DataType::Float32, armnn::DataType::Float16 > Float32ToFloat16Workload
Definition Workload.hpp:237
TypedWorkload< QueueDescriptor, armnn::DataType::QAsymmU8 > Uint8Workload
Definition Workload.hpp:203
DataType
Definition Types.hpp:49
void IgnoreUnused(Ts &&...)
Contains information about TensorInfos of a layer.