18 const void* inputData,
20 unsigned int dataTypeSize)
22 const unsigned int blockSize = descriptor.
m_BlockSize;
25 const unsigned int batches = inputShape[0];
28 const unsigned int inDepth = inputShape[dataLayoutIndexed.
GetChannelsIndex()];
29 const unsigned int inHeight = inputShape[dataLayoutIndexed.
GetHeightIndex()];
30 const unsigned int inWidth = inputShape[dataLayoutIndexed.
GetWidthIndex()];
32 const unsigned int outDepth = inDepth / (blockSize * blockSize);
53 permDestShape =
TensorShape({ outDepth, inHeight, blockSize, inWidth, blockSize });
54 permVector = { 2, 4, 0, 1, 3 };
58 permDestShape =
TensorShape({ inHeight, blockSize, inWidth, blockSize, outDepth });
59 permVector = { 0, 2, 1, 3, 4 };
62 const unsigned int numElementsPerBatch = inputShape.
GetNumElements() / batches;
64 for (
unsigned int batchIndex = 0u; batchIndex < batches; ++batchIndex)
66 const uintptr_t batchDataOffset = batchIndex * (numElementsPerBatch * dataTypeSize);
70 static_cast<const void*
>(
reinterpret_cast<const uint8_t*
>(inputData) + batchDataOffset),
71 static_cast<void*
>(
reinterpret_cast<uint8_t*
>(outputData) + batchDataOffset),
const TensorShape & GetShape() const
unsigned int GetNumElements() const
Function that calculates the number of tensor elements by multiplying all dimension sizes that are specified.
Provides access to the appropriate indexes for Channels, Height and Width based on DataLayout.
unsigned int GetWidthIndex() const
unsigned int GetHeightIndex() const
unsigned int GetChannelsIndex() const
Copyright (c) 2021 ARM Limited and Contributors.
void DepthToSpace(const TensorInfo &inputInfo, const DepthToSpaceDescriptor &descriptor, const void *inputData, void *outputData, unsigned int dataTypeSize)
void Permute(const armnn::TensorShape &dstShape, const armnn::PermutationVector &mappings, const void *src, void *dst, size_t dataTypeSize)
A SpaceToDepthDescriptor for the SpaceToDepthLayer.
DataLayout m_DataLayout
The data layout to be used (NCHW, NHWC).
unsigned int m_BlockSize
Scalar specifying the input block size. It must be >= 1.