14 , m_workloadInfo(
info)
20 "TosaRefPreCompiledWorkload requires a valid pre-compiled object (TosaSerializationHandler).");
27 std::vector<std::string> inputNames = handler->GetMainRegion()->GetBlocks()[0]->GetInputs();
28 std::vector<std::string> outputNames = handler->GetMainRegion()->GetBlocks()[0]->GetOutputs();
30 TosaReference::IModelRunner runner;
34 status = runner.initialize(*handler);
35 if(status != GraphStatus::TOSA_VALID)
37 throw armnn::Exception(
"An error has occurred while initialising the TOSA Reference Model.");
41 for (uint32_t inputSlotIdx = 0; inputSlotIdx < inputNames.size(); ++inputSlotIdx)
47 SetInput<half_float::half>(runner, inputNames[inputSlotIdx], inputSlotIdx);
50 SetInput<float>(runner, inputNames[inputSlotIdx], inputSlotIdx);
57 SetInput<int32_t>(runner, inputNames[inputSlotIdx], inputSlotIdx);
60 SetInput<int64_t>(runner, inputNames[inputSlotIdx], inputSlotIdx);
63 SetInput<unsigned char>(runner, inputNames[inputSlotIdx], inputSlotIdx);
66 throw armnn::Exception(
"Input data type is unsupported in TOSA Reference Backend.");
71 status = runner.run();
72 if(status != GraphStatus::TOSA_VALID)
74 throw armnn::Exception(
"An error has occurred while running the TOSA Reference Model.");
78 for (uint32_t outputSlotIdx = 0; outputSlotIdx < outputNames.size(); ++outputSlotIdx)
84 GetOutput<half_float::half>(runner, outputNames[outputSlotIdx], outputSlotIdx);
87 GetOutput<float>(runner, outputNames[outputSlotIdx], outputSlotIdx);
94 GetOutput<int32_t>(runner, outputNames[outputSlotIdx], outputSlotIdx);
97 GetOutput<int64_t>(runner, outputNames[outputSlotIdx], outputSlotIdx);
100 GetOutput<unsigned char>(runner, outputNames[outputSlotIdx], outputSlotIdx);
103 throw armnn::Exception(
"Output data type is unsupported in TOSA Reference Backend.");
108 template <
typename T>
109 void TosaRefPreCompiledWorkload::SetInput(TosaReference::IModelRunner& runner,
110 std::string inputName,
111 uint32_t inputIndex)
const
113 std::vector<T> inputData(
m_Data.
m_Inputs[inputIndex]->GetShape().GetNumElements());
116 runner.setInput<T>(inputName, inputData);
119 template <
typename T>
120 void TosaRefPreCompiledWorkload::GetOutput(TosaReference::IModelRunner& runner,
121 std::string outputName,
122 uint32_t outputIndex)
const
124 std::vector<T> actualOutputs = runner.getOutput<T>(outputName);