24.02
|
#include <ClGatherNdWorkload.hpp>
Public Member Functions | |
ClGatherNdWorkload (const GatherNdQueueDescriptor &descriptor, const WorkloadInfo &info, const arm_compute::CLCompileContext &clCompileContext) | |
virtual void | Execute () const override |
Public Member Functions inherited from ClBaseWorkload< GatherNdQueueDescriptor > | |
ClBaseWorkload (const GatherNdQueueDescriptor &descriptor, const WorkloadInfo &info) | |
void | ReplaceInputTensorHandle (ITensorHandle *tensorHandle, unsigned int slot) override |
void | ReplaceOutputTensorHandle (ITensorHandle *tensorHandle, unsigned int slot) override |
Public Member Functions inherited from BaseWorkload< GatherNdQueueDescriptor > | |
BaseWorkload (const GatherNdQueueDescriptor &descriptor, const WorkloadInfo &info) | |
virtual const std::string & | GetName () const override |
void | ExecuteAsync (ExecutionData &executionData) override |
void | PostAllocationConfigure () override |
const GatherNdQueueDescriptor & | GetData () const |
arm::pipe::ProfilingGuid | GetGuid () const final |
virtual bool | SupportsTensorHandleReplacement () const override |
Public Member Functions inherited from IWorkload | |
virtual | ~IWorkload () |
virtual arm::pipe::ProfilingGuid | GetGuid () const =0 |
virtual bool | SupportsTensorHandleReplacement () const =0 |
virtual const std::string & | GetName () const =0 |
virtual void | RegisterDebugCallback (const DebugCallbackFunction &) |
virtual armnn::Optional< armnn::MemoryRequirements > | GetMemoryRequirements () |
Additional Inherited Members | |
Protected Member Functions inherited from ClBaseWorkload< GatherNdQueueDescriptor > | |
virtual void | Reconfigure () |
Protected Attributes inherited from BaseWorkload< GatherNdQueueDescriptor > | |
GatherNdQueueDescriptor | m_Data |
const arm::pipe::ProfilingGuid | m_Guid |
const std::string | m_Name |
Definition at line 22 of file ClGatherNdWorkload.hpp.
ClGatherNdWorkload (const GatherNdQueueDescriptor &descriptor, const WorkloadInfo &info, const arm_compute::CLCompileContext &clCompileContext)
Calculate flattened indices: m_FlattenedIndices = indices * m_FlattenedCoeff. This could be done using MatMul instead of a multiplication followed by a reduce-sum operation, but GeMM does not support s32 at the moment.
Call Gather with adequate shapes.
Definition at line 99 of file ClGatherNdWorkload.cpp.
References armnn::CalculateGatherNdKeyIndices(), armnn::info, BaseWorkload< GatherNdQueueDescriptor >::m_Data, QueueDescriptor::m_Inputs, QueueDescriptor::m_Outputs, TensorInfo::SetShape(), and QueueDescriptor::ValidateInputsOutputs().
void Execute () const [override, virtual]
Implements IWorkload.
Definition at line 198 of file ClGatherNdWorkload.cpp.
References ARMNN_SCOPED_PROFILING_EVENT_CL_NAME_GUID, CHECK_LOCATION, and armnn::RunClFunction().