ArmNN
 26.01
Loading...
Searching...
No Matches
ClRankWorkload.hpp
Go to the documentation of this file.
1//
2// Copyright © 2020-2023 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
8#include "ClBaseWorkload.hpp"
10
11#include "ClWorkloadUtils.hpp"
12
13namespace armnn
14{
15
16struct ClRankWorkload : public ClBaseWorkload<RankQueueDescriptor>
17{
18public:
20 virtual void Execute() const override
21 {
22 ARMNN_SCOPED_PROFILING_EVENT_CL_NAME_GUID("ClRankWorkload_Execute");
23
24 const ClTensorHandle* clTensorHandle = PolymorphicDowncast<const ClTensorHandle*>(m_Data.m_Inputs[0]);
25 const int32_t rank = static_cast<int32_t>(clTensorHandle->GetShape().GetNumDimensions());
26
27 std::memcpy(GetOutputTensorData<void>(0, m_Data), &rank, sizeof(int32_t));
28 m_Data.m_Outputs[0]->Unmap();
29 }
30};
31
32} //namespace armnn
#define ARMNN_SCOPED_PROFILING_EVENT_CL_NAME_GUID(label)
Creates a profiling event that uses GetGuid() and GetName() from the calling class.
QueueDescriptor m_Data
Definition Workload.hpp:74
ClBaseWorkload(const RankQueueDescriptor &descriptor, const WorkloadInfo &info)
TensorShape GetShape() const override
Get the number of elements for each dimension ordered from slowest iterating dimension to fastest iterating dimension.
unsigned int GetNumDimensions() const
Function that returns the tensor rank.
Definition Tensor.cpp:174
Copyright (c) 2021 ARM Limited and Contributors.
virtual void Execute() const override
std::vector< ITensorHandle * > m_Inputs
std::vector< ITensorHandle * > m_Outputs