24.08
|
Go to the documentation of this file.
13 #include <arm_compute/function_info/ScatterInfo.h>
18 using namespace armcomputetensorutils;
26 const arm_compute::TensorInfo aclInputInfo = BuildArmComputeTensorInfo(inputInfo);
27 const arm_compute::TensorInfo aclIndicesInfo = BuildArmComputeTensorInfo(indicesInfo);
28 const arm_compute::TensorInfo aclUpdatesInfo = BuildArmComputeTensorInfo(updatesInfo);
29 const arm_compute::TensorInfo aclOutputInfo = BuildArmComputeTensorInfo(outputInfo);
31 arm_compute::ScatterInfo scatterInfo = BuildArmComputeScatterInfo(descriptor);
33 return arm_compute::CLScatter::validate(descriptor.m_InputEnabled ? &aclInputInfo : nullptr,
42 const arm_compute::CLCompileContext& clCompileContext)
58 arm_compute::ScatterInfo scatterInfo = BuildArmComputeScatterInfo(descriptor.m_Parameters);
62 m_ScatterNdLayer.configure(clCompileContext,
void ValidateInputsOutputs(const std::string &descName, unsigned int numExpectedIn, unsigned int numExpectedOut) const
bool m_InputEnabled
Flag to show if input tensor is accepted.
void Execute() const override
LayerDescriptor m_Parameters
#define ARMNN_SCOPED_PROFILING_EVENT_CL_NAME_GUID(label)
Creates a profiling event that uses GetGuid() and GetName() from the calling class.
Contains information about TensorInfos of a layer.
arm_compute::Status ClScatterNdWorkloadValidate(const TensorInfo &inputInfo, const TensorInfo &indicesInfo, const TensorInfo &updatesInfo, const TensorInfo &outputInfo, const ScatterNdDescriptor &descriptor)
std::vector< ITensorHandle * > m_Outputs
#define ARMNN_REPORT_PROFILING_WORKLOAD_DESC(name, desc, infos, guid)
ScatterNdQueueDescriptor m_Data
void RunClFunction(arm_compute::IFunction &function, const CheckLocation &location)
Copyright (c) 2021 ARM Limited and Contributors.
A ScatterNdDescriptor for the ScatterNdLayer.
ClScatterNdWorkload(const ScatterNdQueueDescriptor &descriptor, const WorkloadInfo &info, const arm_compute::CLCompileContext &clCompileContext)
std::vector< ITensorHandle * > m_Inputs