ArmNN
 25.11
Loading...
Searching...
No Matches
MakeWorkloadHelper.hpp
Go to the documentation of this file.
1//
2// Copyright © 2017, 2024 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5#pragma once
6
7namespace armnn
8{
9namespace
10{
11
12// Make a workload of the specified WorkloadType.
13template<typename WorkloadType>
14struct MakeWorkloadForType
15{
16 template<typename QueueDescriptorType, typename... Args>
17 static std::unique_ptr<WorkloadType> Func(const QueueDescriptorType& descriptor,
18 const WorkloadInfo& info,
19 Args&&... args)
20 {
21 return std::make_unique<WorkloadType>(descriptor, info, std::forward<Args>(args)...);
22 }
23};
24
25// Specialization for void workload type used for unsupported workloads.
26template<>
27struct MakeWorkloadForType<NullWorkload>
28{
29 template<typename QueueDescriptorType, typename... Args>
30 static std::unique_ptr<NullWorkload> Func(const QueueDescriptorType& descriptor,
31 const WorkloadInfo& info,
32 Args&&... args)
33 {
34 IgnoreUnused(descriptor);
36 IgnoreUnused(args...);
37 return nullptr;
38 }
39};
40
41// Makes a workload for one the specified types based on the data type requirements of the tensorinfo.
42// Specify type void as the WorkloadType for unsupported DataType/WorkloadType combos.
43template <typename Float16Workload, typename Float32Workload, typename Uint8Workload, typename Int32Workload,
44 typename BooleanWorkload, typename Int8Workload, typename QueueDescriptorType, typename... Args>
45std::unique_ptr<IWorkload> MakeWorkloadHelper(const QueueDescriptorType& descriptor,
46 const WorkloadInfo& info,
47 Args&&... args)
48{
49 const DataType dataType = !info.m_InputTensorInfos.empty() ?
50 info.m_InputTensorInfos[0].GetDataType()
51 : info.m_OutputTensorInfos[0].GetDataType();
52
53 switch (dataType)
54 {
55
57 return MakeWorkloadForType<Float16Workload>::Func(descriptor, info, std::forward<Args>(args)...);
59 return MakeWorkloadForType<Float32Workload>::Func(descriptor, info, std::forward<Args>(args)...);
61 return MakeWorkloadForType<Uint8Workload>::Func(descriptor, info, std::forward<Args>(args)...);
64 return MakeWorkloadForType<Int8Workload>::Func(descriptor, info, std::forward<Args>(args)...);
66 return MakeWorkloadForType<Int32Workload>::Func(descriptor, info, std::forward<Args>(args)...);
68 return MakeWorkloadForType<BooleanWorkload>::Func(descriptor, info, std::forward<Args>(args)...);
71 return nullptr;
72 default:
73 throw InvalidArgumentException("Unknown data type passed to MakeWorkloadHelper");
74 }
75}
76
77// Makes a workload for one the specified types based on the data type requirements of the tensorinfo.
78// Calling this method is the equivalent of calling the five typed MakeWorkload method with <FloatWorkload,
79// FloatWorkload, Uint8Workload, NullWorkload, NullWorkload, NullWorkload>.
80// Specify type void as the WorkloadType for unsupported DataType/WorkloadType combos.
81template <typename FloatWorkload, typename Uint8Workload, typename QueueDescriptorType, typename... Args>
82std::unique_ptr<IWorkload> MakeWorkloadHelper(const QueueDescriptorType& descriptor,
83 const WorkloadInfo& info,
84 Args&&... args)
85{
86 return MakeWorkloadHelper<FloatWorkload, FloatWorkload, Uint8Workload, NullWorkload, NullWorkload, NullWorkload>(
87 descriptor,
88 info,
89 std::forward<Args>(args)...);
90}
91
92} //namespace
93} //namespace armnn
Copyright (c) 2021 ARM Limited and Contributors.
TypedWorkload< QueueDescriptor, armnn::DataType::Float32 > Float32Workload
Definition Workload.hpp:200
TypedWorkload< QueueDescriptor, armnn::DataType::Signed32 > Int32Workload
Definition Workload.hpp:206
TypedWorkload< QueueDescriptor, armnn::DataType::Float16, armnn::DataType::Float32 > FloatWorkload
Definition Workload.hpp:195
TypedWorkload< QueueDescriptor, armnn::DataType::Boolean > BooleanWorkload
Definition Workload.hpp:209
TypedWorkload< QueueDescriptor, armnn::DataType::QAsymmU8 > Uint8Workload
Definition Workload.hpp:203
DataType
Definition Types.hpp:49
void IgnoreUnused(Ts &&...)
Contains information about TensorInfos of a layer.