ArmNN
 25.11
Loading...
Searching...
No Matches
TosaRefPreCompiledWorkload Class Reference

#include <TosaRefPreCompiledWorkload.hpp>

Inheritance diagram for TosaRefPreCompiledWorkload:
[legend]
Collaboration diagram for TosaRefPreCompiledWorkload:
[legend]

Public Member Functions

 TosaRefPreCompiledWorkload (const PreCompiledQueueDescriptor &descriptor, const WorkloadInfo &info)
void Execute () const override
Public Member Functions inherited from BaseWorkload< PreCompiledQueueDescriptor >
 BaseWorkload (const PreCompiledQueueDescriptor &descriptor, const WorkloadInfo &info)
virtual const std::string & GetName () const override
void PostAllocationConfigure () override
const PreCompiledQueueDescriptor & GetData () const
arm::pipe::ProfilingGuid GetGuid () const final
Public Member Functions inherited from IWorkload
virtual ~IWorkload ()
virtual void RegisterDebugCallback (const DebugCallbackFunction &)
virtual armnn::Optional< armnn::MemoryRequirements > GetMemoryRequirements ()

Additional Inherited Members

Protected Attributes inherited from BaseWorkload< PreCompiledQueueDescriptor >
PreCompiledQueueDescriptor m_Data
const arm::pipe::ProfilingGuid m_Guid
const std::string m_Name

Detailed Description

Definition at line 22 of file TosaRefPreCompiledWorkload.hpp.

Constructor & Destructor Documentation

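◆ TosaRefPreCompiledWorkload()

TosaRefPreCompiledWorkload ( const PreCompiledQueueDescriptor & descriptor, const WorkloadInfo & info )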

Definition at line 11 of file TosaRefPreCompiledWorkload.cpp.

13 : BaseWorkload<PreCompiledQueueDescriptor>(descriptor, info)
14 , m_workloadInfo(info)
15{
16 // Check that the workload is holding a pointer to a valid pre-compiled object
17 if (m_Data.m_PreCompiledObject == nullptr)
18 {
19 throw InvalidArgumentException(
20 "TosaRefPreCompiledWorkload requires a valid pre-compiled object (TosaSerializationHandler).");
21 }
22}

References BaseWorkload< PreCompiledQueueDescriptor >::BaseWorkload(), armnn::info, and BaseWorkload< PreCompiledQueueDescriptor >::m_Data.

Member Function Documentation

◆ Execute()

void Execute ( ) const
override virtual

Implements IWorkload.

Definition at line 23 of file TosaRefPreCompiledWorkload.cpp.

24{
25 tosa::TosaSerializationHandler* handler = static_cast<tosa::TosaSerializationHandler*>(m_Data.m_PreCompiledObject);
26
27 std::vector<std::string> inputNames = handler->GetMainRegion()->GetBlocks()[0]->GetInputs();
28 std::vector<std::string> outputNames = handler->GetMainRegion()->GetBlocks()[0]->GetOutputs();
29
30 TosaReference::IModelRunner runner;
31 GraphStatus status;
32
33 // Initialise the model runner with the TosaSerializationHandler
34 status = runner.initialize(*handler);
35 if(status != GraphStatus::TOSA_VALID)
36 {
37 throw armnn::Exception("An error has occurred while initialising the TOSA Reference Model.");
38 }
39
40 // Set the inputs
41 for (uint32_t inputSlotIdx = 0; inputSlotIdx < inputNames.size(); ++inputSlotIdx)
42 {
43 DataType dataType = m_workloadInfo.m_InputTensorInfos[inputSlotIdx].GetDataType();
44 switch (dataType)
45 {
46 case DataType::Float16:
47 SetInput<half_float::half>(runner, inputNames[inputSlotIdx], inputSlotIdx);
48 break;
49 case DataType::Float32:
50 SetInput<float>(runner, inputNames[inputSlotIdx], inputSlotIdx);
51 break;
52 case DataType::QAsymmU8:
53 SetInput<uint8_t, int32_t>(runner, inputNames[inputSlotIdx], inputSlotIdx);
54 break;
55 case DataType::QAsymmS8:
56 case DataType::QSymmS8:
57 SetInput<int8_t, int32_t>(runner, inputNames[inputSlotIdx], inputSlotIdx);
58 break;
59 case DataType::QSymmS16:
60 SetInput<int16_t, int32_t>(runner, inputNames[inputSlotIdx], inputSlotIdx);
61 break;
62 case DataType::Signed32:
63 SetInput<int32_t>(runner, inputNames[inputSlotIdx], inputSlotIdx);
64 break;
65 case DataType::Signed64:
66 SetInput<int64_t>(runner, inputNames[inputSlotIdx], inputSlotIdx);
67 break;
68 case DataType::Boolean:
69 SetInput<unsigned char>(runner, inputNames[inputSlotIdx], inputSlotIdx);
70 break;
71 default:
72 throw armnn::Exception("Input data type is unsupported in TOSA Reference Backend.");
73 }
74 }
75
76 // Run the TOSA Reference Model
77 status = runner.run();
78 if(status != GraphStatus::TOSA_VALID)
79 {
80 throw armnn::Exception("An error has occurred while running the TOSA Reference Model.");
81 }
82
83 // Gets the outputs
84 for (uint32_t outputSlotIdx = 0; outputSlotIdx < outputNames.size(); ++outputSlotIdx)
85 {
86 DataType dataType = m_workloadInfo.m_OutputTensorInfos[outputSlotIdx].GetDataType();
87 switch (dataType)
88 {
89 case DataType::Float16:
90 GetOutput<half_float::half>(runner, outputNames[outputSlotIdx], outputSlotIdx);
91 break;
92 case DataType::Float32:
93 GetOutput<float>(runner, outputNames[outputSlotIdx], outputSlotIdx);
94 break;
95 case DataType::QAsymmU8:
96 GetOutput<uint8_t, int32_t>(runner, outputNames[outputSlotIdx], outputSlotIdx);
97 break;
98 case DataType::QAsymmS8:
99 case DataType::QSymmS8:
100 GetOutput<int8_t, int32_t>(runner, outputNames[outputSlotIdx], outputSlotIdx);
101 break;
102 case DataType::QSymmS16:
103 GetOutput<int16_t, int32_t>(runner, outputNames[outputSlotIdx], outputSlotIdx);
104 break;
105 case DataType::Signed32:
106 GetOutput<int32_t>(runner, outputNames[outputSlotIdx], outputSlotIdx);
107 break;
108 case DataType::Signed64:
109 GetOutput<int64_t>(runner, outputNames[outputSlotIdx], outputSlotIdx);
110 break;
111 case DataType::Boolean:
112 GetOutput<unsigned char>(runner, outputNames[outputSlotIdx], outputSlotIdx);
113 break;
114 default:
115 throw armnn::Exception("Output data type is unsupported in TOSA Reference Backend.");
116 }
117 }
118}
DataType
Definition at line 49 of file Types.hpp.

References armnn::Boolean, armnn::Float16, armnn::Float32, BaseWorkload< PreCompiledQueueDescriptor >::m_Data, armnn::QAsymmS8, armnn::QAsymmU8, armnn::QSymmS16, armnn::QSymmS8, armnn::Signed32, and armnn::Signed64.


The documentation for this class was generated from the following files: