ArmNN
 25.11
Loading...
Searching...
No Matches
Permute.cpp
Go to the documentation of this file.
1//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include <armnn/Tensor.hpp>
7
9
10#include "Half.hpp"
11
12#include <cstring>
13
14namespace
15{
16
17class PermuteLoop
18{
19public:
20 using size_type = unsigned int;
21
22 PermuteLoop(const armnn::TensorShape& dstShape, const armnn::PermutationVector& mappings)
23 : m_DstShape(dstShape)
24 {
25 if (dstShape.GetNumDimensions() != mappings.GetSize())
26 {
27 std::stringstream msg;
28 msg << "Permute: Number of shape dimensions (" << dstShape.GetNumDimensions() <<
29 ") does not match the size of the mappings (" << mappings.GetSize() << ")";
30 throw armnn::InvalidArgumentException(msg.str());
31 }
32
33 const size_type numDims = dstShape.GetNumDimensions();
34
35 size_type srcStride = 1U;
36 size_type dstStride = 1U;
37
38 for (size_type i = numDims - 1U, k = 0U; k < numDims; ++k, --i)
39 {
40 m_SrcStrides[mappings[i]] = srcStride;
41 m_DstStrides[i] = dstStride;
42
43 srcStride *= dstShape[mappings[i]];
44 dstStride *= dstShape[i];
45 }
46 }
47
48 void Unroll(const void* srcData, void* dstData, size_t dataTypeSize)
49 {
50 if (srcData == nullptr)
51 {
52 throw armnn::InvalidArgumentException("Permute: Source Data pointer is null");
53 }
54 if (dstData == nullptr)
55 {
56 throw armnn::InvalidArgumentException("Permute: Destination Data pointer is null");
57 }
58 if (dataTypeSize == 0)
59 {
60 throw armnn::InvalidArgumentException("Permute: dataTypeSize is zero");
61 }
62
63 const unsigned char* srcDataPtr = reinterpret_cast<const unsigned char*>(srcData);
64 unsigned char* dstDataPtr = reinterpret_cast<unsigned char*>(dstData);
65
66 const unsigned char* const srcEndPtr = srcDataPtr + m_DstShape.GetNumElements() * dataTypeSize;
67 unsigned char* const dstEndPtr = dstDataPtr + m_DstShape.GetNumElements() * dataTypeSize;
68
69 Unroll(0, srcDataPtr, dstDataPtr, srcEndPtr, dstEndPtr, dataTypeSize);
70 }
71
72private:
73 void Unroll(size_type dimension,
74 const unsigned char* srcData, unsigned char* dstData,
75 const unsigned char* srcEnd, unsigned char* dstEnd,
76 size_t dataTypeSize)
77 {
78 if (srcData == nullptr)
79 {
80 throw armnn::InvalidArgumentException("Permute: Source Data pointer is null");
81 }
82 if (dstData == nullptr)
83 {
84 throw armnn::InvalidArgumentException("Permute: Destination Data pointer is null");
85 }
86 if (srcEnd == nullptr)
87 {
88 throw armnn::InvalidArgumentException("Permute: Source End pointer is null");
89 }
90 if (dstEnd == nullptr)
91 {
92 throw armnn::InvalidArgumentException("Permute: Destination End pointer is null");
93 }
94 if (dataTypeSize == 0)
95 {
96 throw armnn::Exception("Permute: dataTypeSize is zero");
97 }
98
99 if (dimension >= m_DstShape.GetNumDimensions())
100 {
101 ::memcpy(dstData, srcData, dataTypeSize);
102 }
103 else
104 {
105 for (size_type i = 0; i < m_DstShape[dimension]; i++)
106 {
107 Unroll(dimension + 1, srcData, dstData, srcEnd, dstEnd, dataTypeSize);
108
109 srcData += m_SrcStrides[dimension] * dataTypeSize;
110 dstData += m_DstStrides[dimension] * dataTypeSize;
111 }
112 }
113 }
114
115 armnn::TensorShape m_DstShape;
116 std::array<size_type, armnn::MaxNumOfTensorDimensions> m_SrcStrides;
117 std::array<size_type, armnn::MaxNumOfTensorDimensions> m_DstStrides;
118};
119
120} // namespace
121
122namespace armnnUtils
123{
124
126 const armnn::PermutationVector& mappings)
127{
128 if (srcShape.GetNumDimensions() != mappings.GetSize())
129 {
130 std::stringstream msg;
131 msg << "Permute: Number of shape dimensions (" << srcShape.GetNumDimensions() <<
132 ") does not match the size of the mappings (" << mappings.GetSize() << ")";
133 throw armnn::InvalidArgumentException(msg.str());
134 }
135
136 const unsigned int numDims = mappings.GetSize();
137 unsigned int outDims[armnn::MaxNumOfTensorDimensions];
138
139 for (unsigned int i = 0U; i < numDims; ++i)
140 {
141 outDims[mappings[i]] = srcShape[i];
142 }
143
144 armnn::TensorShape permutedShape(numDims, outDims);
145 return permutedShape;
146}
147
149 const armnn::PermutationVector& mappings)
150{
// Start from a full copy of info, then overwrite only its shape with the permuted shape.
151 armnn::TensorInfo outInfo(info);
152 outInfo.SetShape(Permuted(info.GetShape(), mappings));
153
154 // If TensorInfo has Per-Axis Quantization then it also has a QuantizationDim which needs to
155 // be permuted according to the mapping
156 if (info.GetQuantizationDim().has_value())
157 {
// Source axis q ends up at destination axis mappings[q], the same rule the shape permutation
// applies to dimension extents.
158 outInfo.SetQuantizationDim(mappings[info.GetQuantizationDim().value()]);
159 }
160
161 return outInfo;
162}
163
164void Permute(const armnn::TensorShape& dstShape, const armnn::PermutationVector& mappings,
165 const void* src, void* dst, size_t dataTypeSize)
166{
167 PermuteLoop(dstShape, mappings).Unroll(src, dst, dataTypeSize);
168}
169
170} // namespace armnnUtils
SizeType GetSize() const
Definition Types.hpp:359
void SetQuantizationDim(const Optional< unsigned int > &quantizationDim)
Definition Tensor.cpp:503
void SetShape(const TensorShape &newShape)
Definition Tensor.hpp:195
unsigned int GetNumDimensions() const
Function that returns the tensor rank.
Definition Tensor.cpp:174
constexpr unsigned int MaxNumOfTensorDimensions
Definition Types.hpp:31
armnn::TensorShape Permuted(const armnn::TensorShape &srcShape, const armnn::PermutationVector &mappings)
Definition Permute.cpp:125