23 unsigned int channelDimension3D = dataLayout.
GetDataLayout() == DataLayout::NCHW ? 1 : 2;
26 return (b * shape[dataLayout.
GetHeightIndex()] + h) * shape[channelDimension3D] + c;
55 if (rank != 3 && rank != 4 )
62 unsigned int channelDimension3D = params.
m_DataLayout == DataLayout::NCHW ? 1 : 2;
67 const unsigned int inputBatchSize = inputShape[0];
68 const unsigned int outputBatchSize = outputShape[0];
70 const unsigned int channels = (rank == 3) ? inputShape[channelDimension3D]
73 const unsigned int inputHeight = inputShape[dataLayout.
GetHeightIndex()];
74 const unsigned int inputWidth = (rank == 3) ? 1 : inputShape[dataLayout.
GetWidthIndex()];
75 const unsigned int outputHeight = outputShape[dataLayout.
GetHeightIndex()];
76 const unsigned int outputWidth = (rank == 3) ? 1 : outputShape[dataLayout.
GetWidthIndex()];
79 const unsigned int blockWidth = (rank == 3) ? 1 : params.
m_BlockShape[1];
81 const unsigned int paddingTop = params.
m_PadList[0].first;
82 const unsigned int paddingLeft = (rank == 3) ? 0 : params.
m_PadList[1].first;
84 for (
unsigned int outB = 0; outB < outputBatchSize; ++outB)
86 unsigned int inB = outB % inputBatchSize;
88 unsigned int shiftW = (outB / inputBatchSize) % blockWidth;
89 unsigned int shiftH = (outB / inputBatchSize) / blockWidth;
91 for (
unsigned int outH = 0; outH < outputHeight; ++outH)
93 for (
unsigned int outW = 0; outW < outputWidth; ++outW)
95 if (outH * blockHeight + shiftH < paddingTop ||
96 outH * blockHeight + shiftH >= paddingTop + inputHeight ||
97 outW * blockWidth + shiftW < paddingLeft ||
98 outW * blockWidth + shiftW >= paddingLeft + inputWidth)
100 for (
unsigned int c = 0; c < channels; c++)
102 unsigned int outOffset =
GetOffset(outputShape,
108 outputData += outOffset;
110 outputData -= outOffset;
115 for (
unsigned int c = 0; c < channels; c++)
117 unsigned int inOffset =
GetOffset(inputShape,
119 (outH * blockHeight + shiftH) - paddingTop,
120 (outW * blockWidth + shiftW) - paddingLeft,
124 unsigned int outOffset =
GetOffset(outputShape,
131 outputData += outOffset;
132 inputData += inOffset;
133 outputData.
Set(inputData.
Get());
134 inputData -= inOffset;
135 outputData -= outOffset;