13#include <arm_compute/function_info/ScatterInfo.h>
18using namespace armcomputetensorutils;
26 const arm_compute::TensorInfo aclInputInfo = BuildArmComputeTensorInfo(inputInfo);
27 const arm_compute::TensorInfo aclIndicesInfo = BuildArmComputeTensorInfo(indicesInfo);
28 const arm_compute::TensorInfo aclUpdatesInfo = BuildArmComputeTensorInfo(updatesInfo);
29 const arm_compute::TensorInfo aclOutputInfo = BuildArmComputeTensorInfo(outputInfo);
31 arm_compute::ScatterInfo scatterInfo = BuildArmComputeScatterInfo(descriptor);
33 return arm_compute::CLScatter::validate(descriptor.
m_InputEnabled ? &aclInputInfo :
nullptr,
42 const arm_compute::CLCompileContext& clCompileContext)
51 m_Data.ValidateInputsOutputs(
"ClScatterNdWorkload", 3, 1);
58 arm_compute::ScatterInfo scatterInfo = BuildArmComputeScatterInfo(descriptor.
m_Parameters);
62 m_ScatterNdLayer.configure(clCompileContext,
#define ARMNN_SCOPED_PROFILING_EVENT_CL_NAME_GUID(label)
Creates a profiling event that uses GetGuid() and GetName() from the calling class.
#define ARMNN_REPORT_PROFILING_WORKLOAD_DESC(name, desc, infos, guid)
ScatterNdQueueDescriptor m_Data
ClBaseWorkload(const ScatterNdQueueDescriptor &descriptor, const WorkloadInfo &info)
ClScatterNdWorkload(const ScatterNdQueueDescriptor &descriptor, const WorkloadInfo &info, const arm_compute::CLCompileContext &clCompileContext)
void Execute() const override
Copyright (c) 2021 ARM Limited and Contributors.
arm_compute::Status ClScatterNdWorkloadValidate(const TensorInfo &inputInfo, const TensorInfo &indicesInfo, const TensorInfo &updatesInfo, const TensorInfo &outputInfo, const ScatterNdDescriptor &descriptor)
void RunClFunction(arm_compute::IFunction &function, const CheckLocation &location)
LayerDescriptor m_Parameters
A ScatterNdDescriptor for the ScatterNdLayer.
bool m_InputEnabled
Flag to show if input tensor is accepted.
Contains information about TensorInfos of a layer.