ArmNN
 25.11
Loading...
Searching...
No Matches
ClFullyConnectedWorkload Class Reference

#include <ClFullyConnectedWorkload.hpp>

Inheritance diagram for ClFullyConnectedWorkload:
[legend]
Collaboration diagram for ClFullyConnectedWorkload:
[legend]

Public Member Functions

 ClFullyConnectedWorkload (const FullyConnectedQueueDescriptor &descriptor, const WorkloadInfo &info, std::shared_ptr< arm_compute::MemoryManagerOnDemand > &memoryManager, const arm_compute::CLCompileContext &clCompileContext)
void Execute () const override
Public Member Functions inherited from ClBaseWorkload< FullyConnectedQueueDescriptor >
 ClBaseWorkload (const FullyConnectedQueueDescriptor &descriptor, const WorkloadInfo &info)
void ReplaceInputTensorHandle (ITensorHandle *tensorHandle, unsigned int slot) override
void ReplaceOutputTensorHandle (ITensorHandle *tensorHandle, unsigned int slot) override
Public Member Functions inherited from BaseWorkload< FullyConnectedQueueDescriptor >
 BaseWorkload (const FullyConnectedQueueDescriptor &descriptor, const WorkloadInfo &info)
virtual const std::string & GetName () const override
void PostAllocationConfigure () override
const FullyConnectedQueueDescriptor & GetData () const
arm::pipe::ProfilingGuid GetGuid () const final
virtual bool SupportsTensorHandleReplacement () const override
Public Member Functions inherited from IWorkload
virtual ~IWorkload ()
virtual void RegisterDebugCallback (const DebugCallbackFunction &)
virtual armnn::Optional< armnn::MemoryRequirements > GetMemoryRequirements ()

Additional Inherited Members

Protected Member Functions inherited from ClBaseWorkload< FullyConnectedQueueDescriptor >
virtual void Reconfigure ()
Protected Attributes inherited from BaseWorkload< FullyConnectedQueueDescriptor >
FullyConnectedQueueDescriptor m_Data
const arm::pipe::ProfilingGuid m_Guid
const std::string m_Name

Detailed Description

Definition at line 25 of file ClFullyConnectedWorkload.hpp.

Constructor & Destructor Documentation

◆ ClFullyConnectedWorkload()

ClFullyConnectedWorkload ( const FullyConnectedQueueDescriptor & descriptor,
const WorkloadInfo & info,
std::shared_ptr< arm_compute::MemoryManagerOnDemand > & memoryManager,
const arm_compute::CLCompileContext & clCompileContext )

Definition at line 53 of file ClFullyConnectedWorkload.cpp.

58 : ClBaseWorkload<FullyConnectedQueueDescriptor>(descriptor, info), m_FullyConnectedLayer(memoryManager)
59{
60 m_Data.ValidateInputsOutputs("ClFullyConnectedWorkload", descriptor.m_Parameters.GetNumInputs(), 1);
61
62 arm_compute::ICLTensor& input = PolymorphicDowncast<IClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
63 arm_compute::ICLTensor& output = PolymorphicDowncast<IClTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
64 arm_compute::ICLTensor& weights = PolymorphicDowncast<IClTensorHandle*>(m_Data.m_Inputs[1])->GetTensor();
65
66 weights.info()->set_are_values_constant(info.m_InputTensorInfos[1].IsConstant());
67
68 arm_compute::ICLTensor* bias = nullptr;
69 if (m_Data.m_Parameters.m_BiasEnabled)
70 {
71 bias = &PolymorphicDowncast<IClTensorHandle*>(m_Data.m_Inputs[2])->GetTensor();
72 bias->info()->set_are_values_constant(info.m_InputTensorInfos[2].IsConstant());
73 }
74
75 const arm_compute::ActivationLayerInfo activationInfo = ConvertAdditionalInfoToAclActivationLayerInfo(descriptor);
76
77 arm_compute::FullyConnectedLayerInfo fc_info =
78 ConvertFullyConnectedDescriptorToAclFullyConnectedLayerInfo(descriptor.m_Parameters,
79 activationInfo);
80
81 {
82 ARMNN_SCOPED_PROFILING_EVENT_CL_NAME_GUID("ClFullyConnectedWorkload_configure");
83 m_FullyConnectedLayer.configure(clCompileContext,
84 &input,
85 &weights,
86 bias,
87 &output,
88 fc_info);
89 }
90
91 // Add details for profiling output
92 WorkloadInfo detailsInfo;
93
94 detailsInfo.m_InputTensorInfos = info.m_InputTensorInfos;
95 detailsInfo.m_OutputTensorInfos = info.m_OutputTensorInfos;
96
97 // Report Profiling Details
98 ARMNN_REPORT_PROFILING_WORKLOAD_DESC("ClFullyConnectedWorkload_Construct",
99 descriptor.m_Parameters,
100 detailsInfo,
101 this->GetGuid());
102}
#define ARMNN_SCOPED_PROFILING_EVENT_CL_NAME_GUID(label)
Creates a profiling event that uses GetGuid() and GetName() from the calling class.
#define ARMNN_REPORT_PROFILING_WORKLOAD_DESC(name, desc, infos, guid)
arm_compute::ActivationLayerInfo ConvertAdditionalInfoToAclActivationLayerInfo(const QueueDescriptor &queueDescriptor)
arm_compute::FullyConnectedLayerInfo ConvertFullyConnectedDescriptorToAclFullyConnectedLayerInfo(const FullyConnectedDescriptor &fullyConnectedDesc, const ActivationDescriptor *activationDesc)

References ARMNN_REPORT_PROFILING_WORKLOAD_DESC, ARMNN_SCOPED_PROFILING_EVENT_CL_NAME_GUID, ClBaseWorkload< FullyConnectedQueueDescriptor >::ClBaseWorkload(), armnn::ConvertAdditionalInfoToAclActivationLayerInfo(), armnn::ConvertFullyConnectedDescriptorToAclFullyConnectedLayerInfo(), FullyConnectedDescriptor::GetNumInputs(), armnn::info, BaseWorkload< FullyConnectedQueueDescriptor >::m_Data, WorkloadInfo::m_InputTensorInfos, WorkloadInfo::m_OutputTensorInfos, QueueDescriptorWithParameters< LayerDescriptor >::m_Parameters, and armnn::PolymorphicDowncast().

Member Function Documentation

◆ Execute()

void Execute ( ) const
override virtual

Implements IWorkload.

Definition at line 104 of file ClFullyConnectedWorkload.cpp.

105{
106 ARMNN_SCOPED_PROFILING_EVENT_CL_NAME_GUID("ClFullyConnectedWorkload_Execute");
107 RunClFunction(m_FullyConnectedLayer, CHECK_LOCATION());
108}
#define CHECK_LOCATION()
void RunClFunction(arm_compute::IFunction &function, const CheckLocation &location)

References ARMNN_SCOPED_PROFILING_EVENT_CL_NAME_GUID, CHECK_LOCATION, and armnn::RunClFunction().


The documentation for this class was generated from the following files: