ArmNN
 25.11
Loading...
Searching...
No Matches
ClScatterNdWorkload.cpp
Go to the documentation of this file.
1//
2// Copyright © 2024 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
7
8#include "ClWorkloadUtils.hpp"
9
11#include <cl/ClTensorHandle.hpp>
12
13#include <arm_compute/function_info/ScatterInfo.h>
14
15namespace armnn
16{
17
18using namespace armcomputetensorutils;
19
20arm_compute::Status ClScatterNdWorkloadValidate(const TensorInfo& inputInfo,
21 const TensorInfo& indicesInfo,
22 const TensorInfo& updatesInfo,
23 const TensorInfo& outputInfo,
24 const ScatterNdDescriptor& descriptor)
25{
26 const arm_compute::TensorInfo aclInputInfo = BuildArmComputeTensorInfo(inputInfo);
27 const arm_compute::TensorInfo aclIndicesInfo = BuildArmComputeTensorInfo(indicesInfo);
28 const arm_compute::TensorInfo aclUpdatesInfo = BuildArmComputeTensorInfo(updatesInfo);
29 const arm_compute::TensorInfo aclOutputInfo = BuildArmComputeTensorInfo(outputInfo);
30
31 arm_compute::ScatterInfo scatterInfo = BuildArmComputeScatterInfo(descriptor);
32
33 return arm_compute::CLScatter::validate(descriptor.m_InputEnabled ? &aclInputInfo : nullptr,
34 &aclUpdatesInfo,
35 &aclIndicesInfo,
36 &aclOutputInfo,
37 scatterInfo);
38}
39
41 const WorkloadInfo& info,
42 const arm_compute::CLCompileContext& clCompileContext)
44{
    // Configures m_ScatterNdLayer (an ACL CL scatter function) for this workload using
    // the tensor handles from m_Data and the compile context passed in.
45 // Report Profiling Details
46 ARMNN_REPORT_PROFILING_WORKLOAD_DESC("ClScatterNdWorkload_Construct",
47 descriptor.m_Parameters,
48 info,
49 this->GetGuid());
50
    // Expects exactly three inputs and one output.
51 m_Data.ValidateInputsOutputs("ClScatterNdWorkload", 3, 1);
52
    // Input slot mapping: [0] = input, [1] = indices, [2] = updates.
    // (updates is declared before indices here, but it is read from slot 2.)
53 arm_compute::ICLTensor& input = static_cast<IClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
54 arm_compute::ICLTensor& updates = static_cast<IClTensorHandle*>(m_Data.m_Inputs[2])->GetTensor();
55 arm_compute::ICLTensor& indices = static_cast<IClTensorHandle*>(m_Data.m_Inputs[1])->GetTensor();
56 arm_compute::ICLTensor& output = static_cast<IClTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
57
58 arm_compute::ScatterInfo scatterInfo = BuildArmComputeScatterInfo(descriptor.m_Parameters);
59
    // The inner braces bound the scoped profiling event to the configure() call.
60 {
61 ARMNN_SCOPED_PROFILING_EVENT_CL_NAME_GUID("ClScatterNdWorkload_configure");
    // The input tensor is only passed when m_InputEnabled is set; otherwise nullptr is passed.
    // The argument order (updates before indices) mirrors the CLScatter::validate call
    // in ClScatterNdWorkloadValidate.
62 m_ScatterNdLayer.configure(clCompileContext,
63 descriptor.m_Parameters.m_InputEnabled ? &input : nullptr,
64 &updates,
65 &indices,
66 &output,
67 scatterInfo);
68 }
69}
70
72{
    // Runs the scatter function configured in the constructor, wrapped in a scoped
    // profiling event; RunClFunction receives the call-site location for error reporting.
73 ARMNN_SCOPED_PROFILING_EVENT_CL_NAME_GUID("ClScatterNdWorkload_Execute");
74 RunClFunction(m_ScatterNdLayer, CHECK_LOCATION());
75}
76
77} //namespace armnn
#define ARMNN_SCOPED_PROFILING_EVENT_CL_NAME_GUID(label)
Creates a profiling event that uses GetGuid() and GetName() from the calling class.
#define CHECK_LOCATION()
#define ARMNN_REPORT_PROFILING_WORKLOAD_DESC(name, desc, infos, guid)
ClBaseWorkload(const ScatterNdQueueDescriptor &descriptor, const WorkloadInfo &info)
ClScatterNdWorkload(const ScatterNdQueueDescriptor &descriptor, const WorkloadInfo &info, const arm_compute::CLCompileContext &clCompileContext)
Copyright (c) 2021 ARM Limited and Contributors.
arm_compute::Status ClScatterNdWorkloadValidate(const TensorInfo &inputInfo, const TensorInfo &indicesInfo, const TensorInfo &updatesInfo, const TensorInfo &outputInfo, const ScatterNdDescriptor &descriptor)
void RunClFunction(arm_compute::IFunction &function, const CheckLocation &location)
A ScatterNdDescriptor for the ScatterNdLayer.
bool m_InputEnabled
Flag to show if input tensor is accepted.
Contains information about TensorInfos of a layer.