ArmNN
 24.08
Permute.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include <armnn/Tensor.hpp>
7 
8 #include <armnnUtils/Permute.hpp>
9 
10 #include "Half.hpp"
11 
12 #include <cstring>
13 
14 namespace
15 {
16 
17 class PermuteLoop
18 {
19 public:
20  using size_type = unsigned int;
21 
22  PermuteLoop(const armnn::TensorShape& dstShape, const armnn::PermutationVector& mappings)
23  : m_DstShape(dstShape)
24  {
25  if (dstShape.GetNumDimensions() != mappings.GetSize())
26  {
27  std::stringstream msg;
28  msg << "Permute: Number of shape dimensions (" << dstShape.GetNumDimensions() <<
29  ") does not match the size of the mappings (" << mappings.GetSize() << ")";
30  throw armnn::InvalidArgumentException(msg.str());
31  }
32 
33  const size_type numDims = dstShape.GetNumDimensions();
34 
35  size_type srcStride = 1U;
36  size_type dstStride = 1U;
37 
38  for (size_type i = numDims - 1U, k = 0U; k < numDims; ++k, --i)
39  {
40  m_SrcStrides[mappings[i]] = srcStride;
41  m_DstStrides[i] = dstStride;
42 
43  srcStride *= dstShape[mappings[i]];
44  dstStride *= dstShape[i];
45  }
46  }
47 
48  void Unroll(const void* srcData, void* dstData, size_t dataTypeSize)
49  {
50  if (srcData == nullptr)
51  {
52  throw armnn::InvalidArgumentException("Permute: Source Data pointer is null");
53  }
54  if (dstData == nullptr)
55  {
56  throw armnn::InvalidArgumentException("Permute: Destination Data pointer is null");
57  }
58  if (dataTypeSize == 0)
59  {
60  throw armnn::InvalidArgumentException("Permute: dataTypeSize is zero");
61  }
62 
63  const unsigned char* srcDataPtr = reinterpret_cast<const unsigned char*>(srcData);
64  unsigned char* dstDataPtr = reinterpret_cast<unsigned char*>(dstData);
65 
66  const unsigned char* const srcEndPtr = srcDataPtr + m_DstShape.GetNumElements() * dataTypeSize;
67  unsigned char* const dstEndPtr = dstDataPtr + m_DstShape.GetNumElements() * dataTypeSize;
68 
69  Unroll(0, srcDataPtr, dstDataPtr, srcEndPtr, dstEndPtr, dataTypeSize);
70  }
71 
72 private:
73  void Unroll(size_type dimension,
74  const unsigned char* srcData, unsigned char* dstData,
75  const unsigned char* srcEnd, unsigned char* dstEnd,
76  size_t dataTypeSize)
77  {
78  if (srcData == nullptr)
79  {
80  throw armnn::InvalidArgumentException("Permute: Source Data pointer is null");
81  }
82  if (dstData == nullptr)
83  {
84  throw armnn::InvalidArgumentException("Permute: Destination Data pointer is null");
85  }
86  if (srcEnd == nullptr)
87  {
88  throw armnn::InvalidArgumentException("Permute: Source End pointer is null");
89  }
90  if (dstEnd == nullptr)
91  {
92  throw armnn::InvalidArgumentException("Permute: Destination End pointer is null");
93  }
94  if (dataTypeSize == 0)
95  {
96  throw armnn::Exception("Permute: dataTypeSize is zero");
97  }
98 
99  if (dimension >= m_DstShape.GetNumDimensions())
100  {
101  ::memcpy(dstData, srcData, dataTypeSize);
102  }
103  else
104  {
105  for (size_type i = 0; i < m_DstShape[dimension]; i++)
106  {
107  Unroll(dimension + 1, srcData, dstData, srcEnd, dstEnd, dataTypeSize);
108 
109  srcData += m_SrcStrides[dimension] * dataTypeSize;
110  dstData += m_DstStrides[dimension] * dataTypeSize;
111  }
112  }
113  }
114 
115  armnn::TensorShape m_DstShape;
116  std::array<size_type, armnn::MaxNumOfTensorDimensions> m_SrcStrides;
117  std::array<size_type, armnn::MaxNumOfTensorDimensions> m_DstStrides;
118 };
119 
120 } // namespace
121 
122 namespace armnnUtils
123 {
124 
126  const armnn::PermutationVector& mappings)
127 {
128  if (srcShape.GetNumDimensions() != mappings.GetSize())
129  {
130  std::stringstream msg;
131  msg << "Permute: Number of shape dimensions (" << srcShape.GetNumDimensions() <<
132  ") does not match the size of the mappings (" << mappings.GetSize() << ")";
133  throw armnn::InvalidArgumentException(msg.str());
134  }
135 
136  const unsigned int numDims = mappings.GetSize();
137  unsigned int outDims[armnn::MaxNumOfTensorDimensions];
138 
139  for (unsigned int i = 0U; i < numDims; ++i)
140  {
141  outDims[mappings[i]] = srcShape[i];
142  }
143 
144  armnn::TensorShape permutedShape(numDims, outDims);
145  return permutedShape;
146 }
147 
149  const armnn::PermutationVector& mappings)
150 {
151  armnn::TensorInfo outInfo(info);
152  outInfo.SetShape(Permuted(info.GetShape(), mappings));
153 
154  // If TensorInfo has Per-Axis Quantization then it also has a QuantizationDim which needs to
155  // be permuted according to the mapping
156  if (info.GetQuantizationDim().has_value())
157  {
158  outInfo.SetQuantizationDim(mappings[info.GetQuantizationDim().value()]);
159  }
160 
161  return outInfo;
162 }
163 
164 void Permute(const armnn::TensorShape& dstShape, const armnn::PermutationVector& mappings,
165  const void* src, void* dst, size_t dataTypeSize)
166 {
167  PermuteLoop(dstShape, mappings).Unroll(src, dst, dataTypeSize);
168 }
169 
170 } // namespace armnnUtils
armnn::TensorInfo::SetQuantizationDim
void SetQuantizationDim(const Optional< unsigned int > &quantizationDim)
Definition: Tensor.cpp:503
armnn::TensorInfo
Definition: Tensor.hpp:152
armnn::MaxNumOfTensorDimensions
constexpr unsigned int MaxNumOfTensorDimensions
Definition: Types.hpp:31
armnnUtils::Permute
void Permute(const armnn::TensorShape &dstShape, const armnn::PermutationVector &mappings, const void *src, void *dst, size_t dataTypeSize)
Definition: Permute.cpp:164
armnnUtils::Permuted
armnn::TensorShape Permuted(const armnn::TensorShape &srcShape, const armnn::PermutationVector &mappings)
Definition: Permute.cpp:125
armnn::TensorShape
Definition: Tensor.hpp:20
armnn::TensorShape::GetNumDimensions
unsigned int GetNumDimensions() const
Function that returns the tensor rank.
Definition: Tensor.cpp:174
armnnUtils
Definition: CompatibleTypes.hpp:10
armnn::InvalidArgumentException
Definition: Exceptions.hpp:80
armnn::PermutationVector
Definition: Types.hpp:314
armnn::Exception
Base class for all ArmNN exceptions so that users can filter to just those.
Definition: Exceptions.hpp:46
Permute.hpp
armnn::BoostLogSeverityMapping::info
@ info
Half.hpp
armnn::PermutationVector::GetSize
SizeType GetSize() const
Definition: Types.hpp:357
Tensor.hpp
armnn::TensorInfo::SetShape
void SetShape(const TensorShape &newShape)
Definition: Tensor.hpp:195