14 , m_workloadInfo(
info)
20 "TosaRefPreCompiledWorkload requires a valid pre-compiled object (TosaSerializationHandler).");
// Graph input/output tensor names are taken from the first block of the handler's main region.
// NOTE(review): GetBlocks()[0] is not bounds-checked here - this assumes the main region has at
// least one block; confirm that is validated when the pre-compiled object is created.
27 std::vector<std::string> inputNames = handler->GetMainRegion()->GetBlocks()[0]->GetInputs();
28 std::vector<std::string> outputNames = handler->GetMainRegion()->GetBlocks()[0]->GetOutputs();
// A TOSA reference model runner is initialised from the serialized graph handler.
30 TosaReference::IModelRunner runner;
34 status = runner.initialize(*handler);
// Any status other than TOSA_VALID is reported to the caller as an armnn::Exception.
35 if(status != GraphStatus::TOSA_VALID)
37 throw armnn::Exception(
"An error has occurred while initialising the TOSA Reference Model.");
// Feed every graph input (by slot index and name) to the runner. One SetInput instantiation is
// chosen per input - presumably via a switch on the tensor's data type; the case labels are not
// visible here. Two-argument forms such as SetInput<uint8_t, int32_t> read the tensor elements as
// the first type and static_cast them to the second before passing them to the runner.
// Types with no matching SetInput call fall through to the exception below.
41 for (uint32_t inputSlotIdx = 0; inputSlotIdx < inputNames.size(); ++inputSlotIdx)
47 SetInput<half_float::half>(runner, inputNames[inputSlotIdx], inputSlotIdx);
50 SetInput<float>(runner, inputNames[inputSlotIdx], inputSlotIdx);
53 SetInput<uint8_t, int32_t>(runner, inputNames[inputSlotIdx], inputSlotIdx);
57 SetInput<int8_t, int32_t>(runner, inputNames[inputSlotIdx], inputSlotIdx);
60 SetInput<int16_t, int32_t>(runner, inputNames[inputSlotIdx], inputSlotIdx);
63 SetInput<int32_t>(runner, inputNames[inputSlotIdx], inputSlotIdx);
66 SetInput<int64_t>(runner, inputNames[inputSlotIdx], inputSlotIdx);
69 SetInput<unsigned char>(runner, inputNames[inputSlotIdx], inputSlotIdx);
72 throw armnn::Exception(
"Input data type is unsupported in TOSA Reference Backend.");
// Run the model; a failed run is reported the same way as a failed initialisation.
77 status = runner.run();
78 if(status != GraphStatus::TOSA_VALID)
80 throw armnn::Exception(
"An error has occurred while running the TOSA Reference Model.");
// Read every graph output back from the runner, mirroring the input dispatch above. Two-argument
// forms such as GetOutput<uint8_t, int32_t> fetch the runner's data as the second type and
// static_cast each element to the first. Unsupported types reach the exception below.
84 for (uint32_t outputSlotIdx = 0; outputSlotIdx < outputNames.size(); ++outputSlotIdx)
90 GetOutput<half_float::half>(runner, outputNames[outputSlotIdx], outputSlotIdx);
93 GetOutput<float>(runner, outputNames[outputSlotIdx], outputSlotIdx);
96 GetOutput<uint8_t, int32_t>(runner, outputNames[outputSlotIdx], outputSlotIdx);
100 GetOutput<int8_t, int32_t>(runner, outputNames[outputSlotIdx], outputSlotIdx);
103 GetOutput<int16_t, int32_t>(runner, outputNames[outputSlotIdx], outputSlotIdx);
106 GetOutput<int32_t>(runner, outputNames[outputSlotIdx], outputSlotIdx);
109 GetOutput<int64_t>(runner, outputNames[outputSlotIdx], outputSlotIdx);
112 GetOutput<unsigned char>(runner, outputNames[outputSlotIdx], outputSlotIdx);
115 throw armnn::Exception(
"Output data type is unsupported in TOSA Reference Backend.");
120 template <
typename T>
121 void TosaRefPreCompiledWorkload::SetInput(TosaReference::IModelRunner& runner,
122 std::string inputName,
123 uint32_t inputIndex)
const
125 SetInput<T, T>(runner, inputName, inputIndex);
128 template <
typename T,
typename Trunner>
129 void TosaRefPreCompiledWorkload::SetInput(TosaReference::IModelRunner& runner,
130 std::string inputName,
131 uint32_t inputIndex)
const
133 std::vector<T> inputData(
m_Data.
m_Inputs[inputIndex]->GetShape().GetNumElements());
134 std::vector<Trunner> inputDataRunner(
m_Data.
m_Inputs[inputIndex]->GetShape().GetNumElements());
138 std::transform(inputData.begin(), inputData.end(),
139 inputDataRunner.begin(), [](T x) { return static_cast<Trunner>(x);});
141 runner.setInput<Trunner>(inputName, inputDataRunner);
144 template <
typename T>
145 void TosaRefPreCompiledWorkload::GetOutput(TosaReference::IModelRunner& runner,
146 std::string outputName,
147 uint32_t outputIndex)
const
149 GetOutput<T, T>(runner, outputName, outputIndex);
152 template <
typename T,
typename Trunner>
153 void TosaRefPreCompiledWorkload::GetOutput(TosaReference::IModelRunner& runner,
154 std::string outputName,
155 uint32_t outputIndex)
const
157 std::vector<Trunner> actualOutputsRunner = runner.getOutput<Trunner>(outputName);
158 std::vector<T> actualOutputs (actualOutputsRunner.size());
160 std::transform(actualOutputsRunner.begin(), actualOutputsRunner.end(),
161 actualOutputs.begin(), [](Trunner x) { return static_cast<T>(x);});