ArmNN
 26.01
Loading...
Searching...
No Matches
ClImportTensorHandleFactory.cpp
Go to the documentation of this file.
1//
2// Copyright © 2021 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
8
11
12#include <arm_compute/core/Coordinates.h>
13#include <arm_compute/runtime/CL/CLTensor.h>
14
15namespace armnn
16{
17
19
21 ITensorHandle& parent, const TensorShape& subTensorShape, const unsigned int* subTensorOrigin) const
22{
23 arm_compute::Coordinates coords;
24 arm_compute::TensorShape shape = armcomputetensorutils::BuildArmComputeTensorShape(subTensorShape);
25
26 coords.set_num_dimensions(subTensorShape.GetNumDimensions());
27 for (unsigned int i = 0; i < subTensorShape.GetNumDimensions(); ++i)
28 {
29 // Arm compute indexes tensor coords in reverse order.
30 unsigned int revertedIndex = subTensorShape.GetNumDimensions() - i - 1;
31 coords.set(i, armnn::numeric_cast<int>(subTensorOrigin[revertedIndex]));
32 }
33
34 const arm_compute::TensorShape parentShape = armcomputetensorutils::BuildArmComputeTensorShape(parent.GetShape());
35
36 // In order for ACL to support subtensors the concat axis cannot be on x or y and the values of x and y
37 // must match the parent shapes
38 if (coords.x() != 0 || coords.y() != 0)
39 {
40 return nullptr;
41 }
42 if ((parentShape.x() != shape.x()) || (parentShape.y() != shape.y()))
43 {
44 return nullptr;
45 }
46
47 if (!::arm_compute::error_on_invalid_subtensor(__func__, __FILE__, __LINE__, parentShape, coords, shape))
48 {
49 return nullptr;
50 }
51
52 return std::make_unique<ClImportSubTensorHandle>(
53 PolymorphicDowncast<IClTensorHandle*>(&parent), shape, coords);
54}
55
56std::unique_ptr<ITensorHandle> ClImportTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo) const
57{
58 std::unique_ptr<ClImportTensorHandle> tensorHandle = std::make_unique<ClImportTensorHandle>(tensorInfo,
60 return tensorHandle;
61}
62
63std::unique_ptr<ITensorHandle> ClImportTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
64 DataLayout dataLayout) const
65{
66 std::unique_ptr<ClImportTensorHandle> tensorHandle = std::make_unique<ClImportTensorHandle>(tensorInfo,
67 dataLayout,
69 return tensorHandle;
70}
71
72std::unique_ptr<ITensorHandle> ClImportTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
73 const bool IsMemoryManaged) const
74{
75 if (IsMemoryManaged)
76 {
77 throw InvalidArgumentException("ClImportTensorHandleFactory does not support memory managed tensors.");
78 }
79 return CreateTensorHandle(tensorInfo);
80}
81
82std::unique_ptr<ITensorHandle> ClImportTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
83 DataLayout dataLayout,
84 const bool IsMemoryManaged) const
85{
86 if (IsMemoryManaged)
87 {
88 throw InvalidArgumentException("ClImportTensorHandleFactory does not support memory managed tensors.");
89 }
90 return CreateTensorHandle(tensorInfo, dataLayout);
91}
92
94{
95 static const FactoryId s_Id(ClImportTensorHandleFactoryId());
96 return s_Id;
97}
98
100{
101 return GetIdStatic();
102}
103
105{
106 return true;
107}
108
110{
111 return false;
112}
113
115{
116 return m_ExportFlags;
117}
118
120{
121 return m_ImportFlags;
122}
123
125 const IConnectableLayer* connectedLayer,
126 CapabilityClass capabilityClass)
127{
128 IgnoreUnused(layer);
129 IgnoreUnused(connectedLayer);
130 std::vector<Capability> capabilities;
131 if (capabilityClass == CapabilityClass::FallbackImportDisabled)
132 {
134 capabilities.push_back(paddingCapability);
135 }
136 return capabilities;
137}
138
139} // namespace armnn
const FactoryId & GetId() const override
std::vector< Capability > GetCapabilities(const IConnectableLayer *layer, const IConnectableLayer *connectedLayer, CapabilityClass capabilityClass) override
std::unique_ptr< ITensorHandle > CreateTensorHandle(const TensorInfo &tensorInfo) const override
MemorySourceFlags GetExportFlags() const override
std::unique_ptr< ITensorHandle > CreateSubTensorHandle(ITensorHandle &parent, const TensorShape &subTensorShape, const unsigned int *subTensorOrigin) const override
MemorySourceFlags GetImportFlags() const override
Interface for a layer that is connectable to other layers via InputSlots and OutputSlots.
Definition INetwork.hpp:81
virtual TensorShape GetShape() const =0
Get the number of elements for each dimension ordered from slowest iterating dimension to fastest iterating dimension.
unsigned int GetNumDimensions() const
Function that returns the tensor rank.
Definition Tensor.cpp:174
Copyright (c) 2021 ARM Limited and Contributors.
CapabilityClass
Capability class to calculate in the GetCapabilities function, so that only the capabilities within the requested scope are calculated.
unsigned int MemorySourceFlags
constexpr const char * ClImportTensorHandleFactoryId()
ITensorHandleFactory::FactoryId FactoryId
DataLayout
Definition Types.hpp:63
void IgnoreUnused(Ts &&...)
Capability of the TensorHandleFactory.