ArmNN
 25.11
Loading...
Searching...
No Matches
ClImportTensorHandleFactory.cpp
Go to the documentation of this file.
1//
2// Copyright © 2021 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
8
11
12#include <arm_compute/core/Coordinates.h>
13#include <arm_compute/runtime/CL/CLTensor.h>
14
15namespace armnn
16{
17
19
21 ITensorHandle& parent, const TensorShape& subTensorShape, const unsigned int* subTensorOrigin) const
22{
23 arm_compute::Coordinates coords;
24 arm_compute::TensorShape shape = armcomputetensorutils::BuildArmComputeTensorShape(subTensorShape);
25
26 coords.set_num_dimensions(subTensorShape.GetNumDimensions());
27 for (unsigned int i = 0; i < subTensorShape.GetNumDimensions(); ++i)
28 {
29 // Arm compute indexes tensor coords in reverse order.
30 unsigned int revertedIndex = subTensorShape.GetNumDimensions() - i - 1;
31 coords.set(i, armnn::numeric_cast<int>(subTensorOrigin[revertedIndex]));
32 }
33
34 const arm_compute::TensorShape parentShape = armcomputetensorutils::BuildArmComputeTensorShape(parent.GetShape());
35
36 // In order for ACL to support subtensors the concat axis cannot be on x or y and the values of x and y
37 // must match the parent shapes
38 if (coords.x() != 0 || coords.y() != 0)
39 {
40 return nullptr;
41 }
42 if ((parentShape.x() != shape.x()) || (parentShape.y() != shape.y()))
43 {
44 return nullptr;
45 }
46
47 if (!::arm_compute::error_on_invalid_subtensor(__func__, __FILE__, __LINE__, parentShape, coords, shape))
48 {
49 return nullptr;
50 }
51
52 return std::make_unique<ClImportSubTensorHandle>(
53 PolymorphicDowncast<IClTensorHandle*>(&parent), shape, coords);
54}
55
// Creates a ClImportTensorHandle for tensorInfo and returns it as an ITensorHandle.
// NOTE(review): doc line 59, which carries the remaining ClImportTensorHandle constructor argument(s), was lost
// when this listing was extracted and is not visible here — confirm against the repository before relying on it.
56std::unique_ptr<ITensorHandle> ClImportTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo) const
57{
58 std::unique_ptr<ClImportTensorHandle> tensorHandle = std::make_unique<ClImportTensorHandle>(tensorInfo,
60 return tensorHandle;
61}
62
// Creates a ClImportTensorHandle for tensorInfo using the given dataLayout and returns it as an ITensorHandle.
// NOTE(review): doc line 68, which carries the remaining ClImportTensorHandle constructor argument(s), was lost
// when this listing was extracted and is not visible here — confirm against the repository before relying on it.
63std::unique_ptr<ITensorHandle> ClImportTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
64 DataLayout dataLayout) const
65{
66 std::unique_ptr<ClImportTensorHandle> tensorHandle = std::make_unique<ClImportTensorHandle>(tensorInfo,
67 dataLayout,
69 return tensorHandle;
70}
71
72std::unique_ptr<ITensorHandle> ClImportTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
73 const bool IsMemoryManaged) const
74{
75 if (IsMemoryManaged)
76 {
77 throw InvalidArgumentException("ClImportTensorHandleFactory does not support memory managed tensors.");
78 }
79 return CreateTensorHandle(tensorInfo);
80}
81
82std::unique_ptr<ITensorHandle> ClImportTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
83 DataLayout dataLayout,
84 const bool IsMemoryManaged) const
85{
86 if (IsMemoryManaged)
87 {
88 throw InvalidArgumentException("ClImportTensorHandleFactory does not support memory managed tensors.");
89 }
90 return CreateTensorHandle(tensorInfo, dataLayout);
91}
92
// NOTE(review): the signature line (doc line 93) was lost when this listing was extracted. GetId() returns
// GetIdStatic(), so this is presumably GetIdStatic's body — confirm against the repository.
// Returns a reference to a function-local static FactoryId, constructed on first call from
// ClImportTensorHandleFactoryId().
94{
95 static const FactoryId s_Id(ClImportTensorHandleFactoryId());
96 return s_Id;
97}
98
100{
101 return GetIdStatic();
102}
103
// NOTE(review): the signature line (doc line 104) was lost when this listing was extracted, so the name of this
// query cannot be recovered from this page — confirm against ClImportTensorHandleFactory.hpp.
// Unconditionally returns true.
105{
106 return true;
107}
108
// NOTE(review): the signature line (doc line 109) was lost when this listing was extracted, so the name of this
// query cannot be recovered from this page — confirm against ClImportTensorHandleFactory.hpp.
// Unconditionally returns false.
110{
111 return false;
112}
113
115{
116 return m_ExportFlags;
117}
118
120{
121 return m_ImportFlags;
122}
123
// NOTE(review): the first signature line (doc line 124) was lost when this listing was extracted; the class
// declaration on this page reads `std::vector<Capability> GetCapabilities(const IConnectableLayer* layer,
// const IConnectableLayer* connectedLayer, CapabilityClass capabilityClass)`.
// Reports this factory's capabilities. The result is empty unless capabilityClass is
// CapabilityClass::FallbackImportDisabled, in which case it holds a single Capability.
125 const IConnectableLayer* connectedLayer,
126 CapabilityClass capabilityClass)
127{
// layer and connectedLayer are not consulted; the result depends only on capabilityClass.
128 IgnoreUnused(layer);
129 IgnoreUnused(connectedLayer);
130 std::vector<Capability> capabilities;
131 if (capabilityClass == CapabilityClass::FallbackImportDisabled)
132 {
// NOTE(review): the declaration of paddingCapability (doc line 133) was lost in extraction; its constructor
// arguments are not visible here — confirm against the repository. The name "padding" also looks inherited
// from another factory; consider a name matching FallbackImportDisabled.
134 capabilities.push_back(paddingCapability);
135 }
136 return capabilities;
137}
138
139} // namespace armnn
const FactoryId & GetId() const override
std::vector< Capability > GetCapabilities(const IConnectableLayer *layer, const IConnectableLayer *connectedLayer, CapabilityClass capabilityClass) override
std::unique_ptr< ITensorHandle > CreateTensorHandle(const TensorInfo &tensorInfo) const override
MemorySourceFlags GetExportFlags() const override
std::unique_ptr< ITensorHandle > CreateSubTensorHandle(ITensorHandle &parent, const TensorShape &subTensorShape, const unsigned int *subTensorOrigin) const override
MemorySourceFlags GetImportFlags() const override
Interface for a layer that is connectable to other layers via InputSlots and OutputSlots.
Definition INetwork.hpp:81
virtual TensorShape GetShape() const =0
Get the number of elements for each dimension ordered from slowest iterating dimension to fastest iterating dimension.
unsigned int GetNumDimensions() const
Function that returns the tensor rank.
Definition Tensor.cpp:174
Copyright (c) 2021 ARM Limited and Contributors.
CapabilityClass
Capability class to calculate in the GetCapabilities function so that only the capability in the scop...
unsigned int MemorySourceFlags
ITensorHandleFactory::FactoryId FactoryId
std::enable_if_t< std::is_unsigned< Source >::value &&std::is_unsigned< Dest >::value, Dest > numeric_cast(Source source)
constexpr const char * ClImportTensorHandleFactoryId()
DestType PolymorphicDowncast(SourceType *value)
Polymorphic downcast for built-in pointers only.
DataLayout
Definition Types.hpp:63
void IgnoreUnused(Ts &&...)
Capability of the TensorHandleFactory.