// Unsigned index type used below for dimension indices, loop counters and
// element strides.
20 using size_type =
unsigned int;
// Keeps a copy of the source shape and precomputes, for every source
// dimension, its element stride in the source buffer and in the transposed
// destination buffer.
23 : m_SrcShape(srcShape)
// Error text for a mismatch between the shape's rank and the mapping size.
// NOTE(review): the condition guarding this and what happens to `msg`
// (presumably an exception) are not visible in this view; confirm.
27 std::stringstream msg;
28 msg <<
"Transpose: Number of shape dimensions (" << srcShape.
GetNumDimensions() <<
29 ") does not match the size of the mappings (" << mappings.
GetSize() <<
")";
// Running products of extents, accumulated from the innermost (last)
// dimension outwards.
35 size_type srcStride = 1U;
36 size_type dstStride = 1U;
// i walks from numDims - 1 down to 0 while k counts iterations; k alone ends
// the loop, so the unsigned wrap of i after the last pass is never used.
38 for (size_type i = numDims - 1U, k = 0U; k < numDims; ++k, --i)
// srcStride is the product of srcShape[j] for j > i (row-major source stride).
// dstStride is the product of srcShape[mappings[j]] for j > i, i.e. the
// row-major stride of output position i when that position has the extent
// of source dimension mappings[i]. It is stored under the source dimension
// index mappings[i], so a walk over source dimensions can advance both
// buffers with the same index.
40 m_SrcStrides[i] = srcStride;
41 m_DstStrides[mappings[i]] = dstStride;
43 srcStride *= srcShape[i];
44 dstStride *= srcShape[mappings[i]];
// Copies m_SrcShape.GetNumElements() elements, each dataTypeSize bytes, from
// srcData into dstData in transposed order.
//   srcData      - source buffer; checked against nullptr.
//   dstData      - destination buffer; checked against nullptr.
//   dataTypeSize - size of one element in bytes; checked against 0.
// NOTE(review): how a failed check is reported is not visible in this view.
// The buffer sizes are not validated here; both are assumed to hold at least
// GetNumElements() * dataTypeSize bytes.
48 void Unroll(
const void* srcData,
void* dstData,
size_t dataTypeSize)
50 if (srcData ==
nullptr)
54 if (dstData ==
nullptr)
58 if (dataTypeSize == 0)
// Work on byte pointers so strides can be scaled by dataTypeSize.
63 const unsigned char* srcDataPtr =
reinterpret_cast<const unsigned char*
>(srcData);
64 unsigned char* dstDataPtr =
reinterpret_cast<unsigned char*
>(dstData);
// One-past-the-end pointers of each buffer, passed down to the recursive
// overload.
66 const unsigned char*
const srcEndPtr = srcDataPtr + m_SrcShape.GetNumElements() * dataTypeSize;
67 unsigned char*
const dstEndPtr = dstDataPtr + m_SrcShape.GetNumElements() * dataTypeSize;
// Start the recursive walk at the outermost source dimension (index 0).
69 Unroll(0, srcDataPtr, dstDataPtr, srcEndPtr, dstEndPtr, dataTypeSize);
// Recursive worker that visits every index of the source dimensions from
// `dimension` through the last one.
// - Once dimension reaches m_SrcShape.GetNumDimensions(), it copies a single
//   element of dataTypeSize bytes from srcData to dstData.
// - Otherwise it loops m_SrcShape[dimension] times. Each pass recurses one
//   dimension deeper, then advances srcData and dstData by this dimension's
//   source and destination strides, converted to bytes.
// srcEnd/dstEnd are only checked against nullptr in the visible lines;
// NOTE(review): any bounds checks against them are not visible here.
73 void Unroll(size_type dimension,
74 const unsigned char* srcData,
unsigned char* dstData,
75 const unsigned char* srcEnd,
unsigned char* dstEnd,
// Argument validation. NOTE(review): the response to a failed check is not
// visible in this view.
78 if (srcData ==
nullptr)
82 if (dstData ==
nullptr)
86 if (srcEnd ==
nullptr)
90 if (dstEnd ==
nullptr)
94 if (dataTypeSize == 0)
// Base case: every dimension is fixed, so copy exactly one element.
99 if (dimension >= m_SrcShape.GetNumDimensions())
101 ::memcpy(dstData, srcData, dataTypeSize);
// srcData/dstData are by-value pointer parameters, so advancing them here
// does not move the caller's pointers.
105 for (size_type i = 0; i < m_SrcShape[dimension]; i++)
107 Unroll(dimension + 1, srcData, dstData, srcEnd, dstEnd, dataTypeSize);
109 srcData += m_SrcStrides[dimension] * dataTypeSize;
110 dstData += m_DstStrides[dimension] * dataTypeSize;
// Element strides, indexed by source dimension, in the source buffer and in
// the transposed destination buffer. They are filled by the constructor and
// sized for the maximum supported tensor rank.
116 std::array<size_type, armnn::MaxNumOfTensorDimensions> m_SrcStrides;
117 std::array<size_type, armnn::MaxNumOfTensorDimensions> m_DstStrides;
// Error text for a mismatch between the shape's rank and the mapping size.
// NOTE(review): the guarding condition and what is done with `msg` are not
// visible in this view; confirm.
129 std::stringstream msg;
130 msg <<
"Transpose: Number of shape dimensions (" << srcShape.
GetNumDimensions() <<
131 ") does not match the size of the mappings (" << mappings.
GetSize() <<
")";
135 const unsigned int numDims = mappings.
GetSize();
// Output dimension i takes the extent of source dimension mappings[i].
138 for (
unsigned int i = 0U; i < numDims; ++i)
140 outDims[i] = srcShape[mappings[i]];
143 return permutedShape;
154 const void* src,
void* dst,
size_t dataTypeSize)
// Builds a TransposeLoop for (srcShape, mappings) and runs its public Unroll,
// which copies src into dst element by element in transposed order.
156 TransposeLoop(srcShape, mappings).Unroll(src, dst, dataTypeSize);