ArmNN
 25.11
Loading...
Searching...
No Matches
TosaRefPreCompiledWorkload.hpp
Go to the documentation of this file.
1//
2// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
#pragma once

#include <armnn/backends/Workload.hpp>

#include <graph_status.h>
#include <model_runner.h>

#include <memory>
#include <string>
#include <vector>
16
17namespace armnn
18{
19
20bool TosaRefPreCompiledWorkloadValidate(std::string* reasonIfUnsupported);
21
22class TosaRefPreCompiledWorkload : public BaseWorkload<PreCompiledQueueDescriptor>
23{
24public:
26 const WorkloadInfo& info);
27 void Execute() const override;
28
29private:
30 bool SupportsTensorHandleReplacement() const override
31 {
32 return true;
33 }
34
35 void ReplaceInputTensorHandle(ITensorHandle* tensorHandle, unsigned int slot) override
36 {
37 this->m_Data.m_Inputs[slot] = tensorHandle;
38 }
39
40 void ReplaceOutputTensorHandle(ITensorHandle* tensorHandle, unsigned int slot) override
41 {
42 this->m_Data.m_Outputs[slot] = tensorHandle;
43 }
44
45 template <typename T, typename Trunner>
46 void SetInput(TosaReference::IModelRunner& runner, std::string inputName, uint32_t inputIndex) const;
47
48 template <typename T>
49 void SetInput(TosaReference::IModelRunner& runner, std::string inputName, uint32_t inputIndex) const;
50
51 template <typename T, typename Trunner>
52 void GetOutput(TosaReference::IModelRunner& runner, std::string outputName, uint32_t outputIndex) const;
53
54 template <typename T>
55 void GetOutput(TosaReference::IModelRunner& runner, std::string outputName, uint32_t outputIndex) const;
56
57 WorkloadInfo m_workloadInfo;
58};
59
60} //namespace armnn
virtual bool SupportsTensorHandleReplacement() const override
Definition Workload.hpp:54
BaseWorkload(const PreCompiledQueueDescriptor &descriptor, const WorkloadInfo &info)
Definition Workload.hpp:35
TosaRefPreCompiledWorkload(const PreCompiledQueueDescriptor &descriptor, const WorkloadInfo &info)
Copyright (c) 2021 ARM Limited and Contributors.
bool TosaRefPreCompiledWorkloadValidate(std::string *)
Contains information about TensorInfos of a layer.