ArmNN
 25.11
Loading...
Searching...
No Matches
ClTensorHandleFactory.cpp
Go to the documentation of this file.
1//
2// Copyright © 2017, 2024 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
7#include "ClTensorHandle.hpp"
8
12
13#include <arm_compute/core/Coordinates.h>
14#include <arm_compute/runtime/CL/CLSubTensor.h>
15#include <arm_compute/runtime/CL/CLTensor.h>
16
17namespace armnn
18{
19
21
22std::unique_ptr<ITensorHandle> ClTensorHandleFactory::CreateSubTensorHandle(ITensorHandle& parent,
23 const TensorShape& subTensorShape,
24 const unsigned int* subTensorOrigin) const
25{
26 arm_compute::Coordinates coords;
27 arm_compute::TensorShape shape = armcomputetensorutils::BuildArmComputeTensorShape(subTensorShape);
28
29 coords.set_num_dimensions(subTensorShape.GetNumDimensions());
30 for (unsigned int i = 0; i < subTensorShape.GetNumDimensions(); ++i)
31 {
32 // Arm compute indexes tensor coords in reverse order.
33 unsigned int revertedIndex = subTensorShape.GetNumDimensions() - i - 1;
34 coords.set(i, armnn::numeric_cast<int>(subTensorOrigin[revertedIndex]));
35 }
36
37 const arm_compute::TensorShape parentShape = armcomputetensorutils::BuildArmComputeTensorShape(parent.GetShape());
38
39 // In order for ACL to support subtensors the concat axis cannot be on x or y and the values of x and y
40 // must match the parent shapes
41 if (coords.x() != 0 || coords.y() != 0)
42 {
43 return nullptr;
44 }
45 if ((parentShape.x() != shape.x()) || (parentShape.y() != shape.y()))
46 {
47 return nullptr;
48 }
49
50 if (!::arm_compute::error_on_invalid_subtensor(__func__, __FILE__, __LINE__, parentShape, coords, shape))
51 {
52 return nullptr;
53 }
54
55 return std::make_unique<ClSubTensorHandle>(PolymorphicDowncast<IClTensorHandle*>(&parent), shape, coords);
56}
57
58std::unique_ptr<ITensorHandle> ClTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo) const
59{
60 return ClTensorHandleFactory::CreateTensorHandle(tensorInfo, true);
61}
62
63std::unique_ptr<ITensorHandle> ClTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
64 DataLayout dataLayout) const
65{
66 return ClTensorHandleFactory::CreateTensorHandle(tensorInfo, dataLayout, true);
67}
68
69std::unique_ptr<ITensorHandle> ClTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
70 const bool IsMemoryManaged) const
71{
72 std::unique_ptr<ClTensorHandle> tensorHandle = std::make_unique<ClTensorHandle>(tensorInfo);
73 if (!IsMemoryManaged)
74 {
75 ARMNN_LOG(warning) << "ClTensorHandleFactory only has support for memory managed.";
76 }
77 tensorHandle->SetMemoryGroup(m_MemoryManager->GetInterLayerMemoryGroup());
78 return tensorHandle;
79}
80
81std::unique_ptr<ITensorHandle> ClTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
82 DataLayout dataLayout,
83 const bool IsMemoryManaged) const
84{
85 std::unique_ptr<ClTensorHandle> tensorHandle = std::make_unique<ClTensorHandle>(tensorInfo, dataLayout);
86 if (!IsMemoryManaged)
87 {
88 ARMNN_LOG(warning) << "ClTensorHandleFactory only has support for memory managed.";
89 }
90 tensorHandle->SetMemoryGroup(m_MemoryManager->GetInterLayerMemoryGroup());
91 return tensorHandle;
92}
93
95{
96 static const FactoryId s_Id(ClTensorHandleFactoryId());
97 return s_Id;
98}
99
101{
102 return GetIdStatic();
103}
104
106{
107 return false;
108}
109
114
119
120} // namespace armnn
#define ARMNN_LOG(severity)
Definition Logging.hpp:212
const FactoryId & GetId() const override
std::unique_ptr< ITensorHandle > CreateTensorHandle(const TensorInfo &tensorInfo) const override
MemorySourceFlags GetExportFlags() const override
std::unique_ptr< ITensorHandle > CreateSubTensorHandle(ITensorHandle &parent, const TensorShape &subTensorShape, const unsigned int *subTensorOrigin) const override
MemorySourceFlags GetImportFlags() const override
static const FactoryId & GetIdStatic()
virtual TensorShape GetShape() const =0
Get the number of elements for each dimension ordered from slowest iterating dimension to fastest iterating dimension.
unsigned int GetNumDimensions() const
Function that returns the tensor rank.
Definition Tensor.cpp:174
Copyright (c) 2021 ARM Limited and Contributors.
unsigned int MemorySourceFlags
ITensorHandleFactory::FactoryId FactoryId
std::enable_if_t< std::is_unsigned< Source >::value &&std::is_unsigned< Dest >::value, Dest > numeric_cast(Source source)
DestType PolymorphicDowncast(SourceType *value)
Polymorphic downcast for built-in pointers only.
constexpr const char * ClTensorHandleFactoryId()
DataLayout
Definition Types.hpp:63