// Unsigned type used by this helper for dimension indices, loop counters and
// per-dimension element strides.
20 using size_type =
unsigned int;
// Constructor (its signature is not visible in this listing).
// It keeps a copy of the destination shape and precomputes, for every dimension,
// how many elements the source and destination pointers advance per step along
// that dimension.
23 : m_DstShape(dstShape)
// Builds the error text for when the destination shape's rank differs from the
// number of mappings.
// NOTE(review): the guarding condition and the throw/report of `msg` are not
// visible in this listing. Confirm against the full file.
27 std::stringstream msg;
28 msg <<
"Permute: Number of shape dimensions (" << dstShape.
GetNumDimensions() <<
29 ") does not match the size of the mappings (" << mappings.
GetSize() <<
")";
// Running strides, counted in elements (not bytes). The last dimension starts
// at stride 1.
35 size_type srcStride = 1U;
36 size_type dstStride = 1U;
// Walk the dimensions from the last to the first. `k` counts iterations, so the
// loop stops after numDims steps. The unsigned `i` wraps past zero on the final
// decrement, but it is never read after that.
// NOTE(review): the declaration of numDims is not visible here. It is presumably
// dstShape.GetNumDimensions().
38 for (size_type i = numDims - 1U, k = 0U; k < numDims; ++k, --i)
// The source stride is stored at index mappings[i]. Unroll reads m_SrcStrides
// using the dimension index it iterates over m_DstShape.
40 m_SrcStrides[mappings[i]] = srcStride;
41 m_DstStrides[i] = dstStride;
// Grow each running stride by the extent of the dimension just processed.
43 srcStride *= dstShape[mappings[i]];
44 dstStride *= dstShape[i];
// Copies every element of the source buffer to its permuted position in the
// destination buffer.
//   srcData      - start of the source buffer (read only).
//   dstData      - start of the destination buffer (written).
//   dataTypeSize - size in bytes of one element. Elements are copied as raw bytes.
// NOTE(review): the bodies of the three argument checks below are not visible
// in this listing (presumably they throw or bail out). Confirm.
48 void Unroll(
const void* srcData,
void* dstData,
size_t dataTypeSize)
50 if (srcData ==
nullptr)
54 if (dstData ==
nullptr)
58 if (dataTypeSize == 0)
// Use byte pointers so that strides can be applied as byte offsets.
63 const unsigned char* srcDataPtr =
reinterpret_cast<const unsigned char*
>(srcData);
64 unsigned char* dstDataPtr =
reinterpret_cast<unsigned char*
>(dstData);
// One-past-the-end pointers. Both use the destination shape's element count.
// This presumes the source holds the same number of elements, which is true
// when it is a permutation of the destination shape.
66 const unsigned char*
const srcEndPtr = srcDataPtr + m_DstShape.GetNumElements() * dataTypeSize;
67 unsigned char*
const dstEndPtr = dstDataPtr + m_DstShape.GetNumElements() * dataTypeSize;
// Start the recursive walk at dimension 0.
69 Unroll(0, srcDataPtr, dstDataPtr, srcEndPtr, dstEndPtr, dataTypeSize);
// Recursive worker. It iterates over dimension `dimension` of the destination
// shape and recurses into the next dimension. Once every dimension has been
// indexed, it copies a single element.
//   srcData/dstData - current byte positions in the source and destination.
//   srcEnd/dstEnd   - one-past-the-end byte pointers. In the lines visible here
//                     they are only null-checked and forwarded, not used as bounds.
// NOTE(review): the declaration of the trailing dataTypeSize parameter and the
// bodies of the argument checks are not visible in this listing. Confirm.
73 void Unroll(size_type dimension,
74 const unsigned char* srcData,
unsigned char* dstData,
75 const unsigned char* srcEnd,
unsigned char* dstEnd,
78 if (srcData ==
nullptr)
82 if (dstData ==
nullptr)
86 if (srcEnd ==
nullptr)
90 if (dstEnd ==
nullptr)
94 if (dataTypeSize == 0)
// Base case: every dimension has been indexed, so both pointers address one
// element. Copy its dataTypeSize bytes.
99 if (dimension >= m_DstShape.GetNumDimensions())
101 ::memcpy(dstData, srcData, dataTypeSize);
// Otherwise, visit each index along this destination dimension.
105 for (size_type i = 0; i < m_DstShape[dimension]; i++)
107 Unroll(dimension + 1, srcData, dstData, srcEnd, dstEnd, dataTypeSize);
// Move both pointers to the next index along this dimension. Strides are
// stored in elements, so scale them to bytes.
109 srcData += m_SrcStrides[dimension] * dataTypeSize;
110 dstData += m_DstStrides[dimension] * dataTypeSize;
// Per-dimension strides, counted in elements and filled by the constructor.
// Unroll indexes both arrays by the destination dimension it is iterating.
// m_SrcStrides[d] is how far the source pointer moves per step; m_DstStrides[d]
// is the same for the destination pointer. Unroll only reads entries below the
// shape's rank.
116 std::array<size_type, armnn::MaxNumOfTensorDimensions> m_SrcStrides;
117 std::array<size_type, armnn::MaxNumOfTensorDimensions> m_DstStrides;
// Shape-permutation fragment. The enclosing signature is not visible in this
// listing, nor are the declarations of outDims and permutedShape or the
// rank-check condition.
// Builds the error text for when the source shape's rank differs from the
// number of mappings.
130 std::stringstream msg;
131 msg <<
"Permute: Number of shape dimensions (" << srcShape.
GetNumDimensions() <<
132 ") does not match the size of the mappings (" << mappings.
GetSize() <<
")";
// Source dimension i moves to output dimension mappings[i].
136 const unsigned int numDims = mappings.
GetSize();
139 for (
unsigned int i = 0U; i < numDims; ++i)
141 outDims[mappings[i]] = srcShape[i];
// NOTE(review): permutedShape is presumably constructed from outDims on a line
// not visible here. Confirm.
145 return permutedShape;
// Fragment: checks whether `info` carries a quantization dimension
// (GetQuantizationDim() returns a value that supports has_value()). The
// enclosing function's signature and this branch's body are not visible in
// this listing.
// NOTE(review): the quantization axis is presumably remapped through the
// permutation here. Confirm against the full file.
156 if (
info.GetQuantizationDim().has_value())
// Tail of a permute entry point. Its name and leading parameters (dstShape,
// mappings) are not visible in this listing. It builds a temporary PermuteLoop
// from dstShape and mappings and uses it to copy src into dst.
//   src          - source buffer (read only).
//   dst          - destination buffer (written).
//   dataTypeSize - size in bytes of one element.
165 const void* src,
void* dst,
size_t dataTypeSize)
167 PermuteLoop(dstShape, mappings).Unroll(src, dst, dataTypeSize);