ArmNN
 25.11
Loading...
Searching...
No Matches
ClConstantWorkload.cpp
Go to the documentation of this file.
1//
2// Copyright © 2017-2018,2020-2024 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
7
8#include <Half.hpp>
10#include <cl/ClTensorHandle.hpp>
12
13#include "ClWorkloadUtils.hpp"
14
15namespace armnn
16{
17
18arm_compute::Status ClConstantWorkloadValidate(const TensorInfo& output)
19{
20 const arm_compute::TensorInfo neonOutputInfo = armcomputetensorutils::BuildArmComputeTensorInfo(output);
21
22 std::array<arm_compute::DataType,9> supportedTypes = {
23 arm_compute::DataType::F16,
24 arm_compute::DataType::F32,
25 arm_compute::DataType::QASYMM8,
26 arm_compute::DataType::QASYMM8_SIGNED,
27 arm_compute::DataType::QSYMM16,
28 arm_compute::DataType::QSYMM8,
29 arm_compute::DataType::QSYMM8_PER_CHANNEL,
30 arm_compute::DataType::S32,
31 arm_compute::DataType::S64
32 };
33 auto it = std::find(begin(supportedTypes), end(supportedTypes), neonOutputInfo.data_type());
34
35 if (it != end(supportedTypes))
36 {
37 return arm_compute::Status{};
38 }
39 else
40 {
41 return arm_compute::Status{arm_compute::ErrorCode::RUNTIME_ERROR, "Unsupported DataType"};
42 }
43}
44
46 const WorkloadInfo& info,
47 const arm_compute::CLCompileContext&) // The CL compile context parameter is unnamed, so this constructor does not use it.
49 , m_RanOnce(false) // No data has been copied yet; the first Execute() call performs the copy.
50{
 // Intentionally empty: per the comment in Execute(), the constant data cannot be uploaded at construction
 // time because the ACL kernel of the consuming layer may not have been configured yet.
51}
52
// Execute(): copies the constant tensor data held in data.m_LayerOutput into the CL output tensor
// data.m_Outputs[0] on the first call only; later calls do nothing because m_RanOnce is then true.
// NOTE(review): Execute() is const yet assigns m_RanOnce, so m_RanOnce is presumably declared `mutable`
// in the header — confirm. There is no locking around the check-then-set of m_RanOnce; confirm that a
// single workload instance is never executed concurrently, otherwise two callers could both copy.
54{
55 ARMNN_SCOPED_PROFILING_EVENT_CL_NAME_GUID("ClConstantWorkload_Execute");
56
57 // The intermediate tensor held by the corresponding layer output handler can be initialised with the given data
58 // on the first inference, then reused for subsequent inferences.
59 // The initialisation cannot happen at workload construction time since the ACL kernel for the next layer may not
60 // have been configured at the time.
61 if (!m_RanOnce)
62 {
63 const ConstantQueueDescriptor& data = this->m_Data;
64
 // Fail early when there is no constant data to copy.
 // NOTE(review): the message mentions the "Output tensor handle" but the pointer checked is m_LayerOutput
 // (the constant-data handle). m_Outputs is not checked for being non-empty, nor m_Outputs[0] for null.
65 ARMNN_THROW_INVALIDARG_MSG_IF_FALSE(data.m_LayerOutput, "Output tensor handle is null.");
 // The output handle is assumed to be a ClTensorHandle (unchecked static_cast).
66 arm_compute::CLTensor& output = static_cast<ClTensorHandle*>(data.m_Outputs[0])->GetTensor();
67 arm_compute::DataType computeDataType = static_cast<ClTensorHandle*>(data.m_Outputs[0])->GetDataType();
68
 // Dispatch on the output's ACL data type. The handled types are exactly those accepted by
 // ClConstantWorkloadValidate. Presumably each case calls
 // CopyArmComputeClTensorData(output, data.m_LayerOutput->GetConstTensor<T>()) with T matching the
 // data type — confirm against the full source.
69 switch (computeDataType)
70 {
71 case arm_compute::DataType::F16:
72 {
74 break;
75 }
76 case arm_compute::DataType::F32:
77 {
79 break;
80 }
81 case arm_compute::DataType::QASYMM8:
82 {
84 break;
85 }
86 case arm_compute::DataType::QASYMM8_SIGNED:
87 {
89 break;
90 }
91 case arm_compute::DataType::QSYMM16:
92 {
94 break;
95 }
 // Both 8-bit symmetric quantized variants share a single case body.
96 case arm_compute::DataType::QSYMM8:
97 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
98 {
100 break;
101 }
102 case arm_compute::DataType::S32:
103 {
105 break;
106 }
107 case arm_compute::DataType::S64:
108 {
110 break;
111 }
 // Any other data type is rejected with an exception.
112 default:
113 {
114 throw InvalidArgumentException("Unknown data type.");
115 }
116 }
117
 // Only reached when nothing above threw. After a failed first attempt m_RanOnce stays false,
 // so the next Execute() call will try the copy again.
118 m_RanOnce = true;
119 }
120}
121
122} //namespace armnn
#define ARMNN_SCOPED_PROFILING_EVENT_CL_NAME_GUID(label)
Creates a profiling event that uses GetGuid() and GetName() from the calling class.
#define ARMNN_THROW_INVALIDARG_MSG_IF_FALSE(_cond, _str)
ClBaseWorkload(const ConstantQueueDescriptor &descriptor, const WorkloadInfo &info)
void Execute() const override
ClConstantWorkload(const ConstantQueueDescriptor &descriptor, const WorkloadInfo &info, const arm_compute::CLCompileContext &clCompileContext)
const T * GetConstTensor() const
Copyright (c) 2021 ARM Limited and Contributors.
half_float::half Half
Definition Half.hpp:22
arm_compute::Status ClConstantWorkloadValidate(const TensorInfo &output)
void CopyArmComputeClTensorData(arm_compute::CLTensor &dstTensor, const T *srcData)
const ConstTensorHandle * m_LayerOutput
std::vector< ITensorHandle * > m_Outputs
Contains information about TensorInfos of a layer.