ArmNN
 25.11
Loading...
Searching...
No Matches
ReverseV2Impl.cpp
Go to the documentation of this file.
1//
2// Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
#include "ReverseV2Impl.hpp"

#include <armnn/Logging.hpp>

#include <stdexcept>
#include <string>
11
12namespace armnn
13{
14
/// Splits a flat element offset into one coordinate per dimension.
///
/// Walks the dimensions from index 0 upwards; at each step the running
/// remainder is divided by that dimension's entry in @p elementNumInner to
/// obtain the coordinate, and the remainder of that division carries on to the
/// next dimension.
///
/// @param idx             Flat offset to decompose.
/// @param inputRank       Number of coordinates to produce.
/// @param elementNumInner Per-dimension divisor; must hold at least
///                        @p inputRank entries, none of which may be zero.
/// @return A vector of @p inputRank coordinates, in dimension order.
std::vector<unsigned int> ReverseGetMultIdx(const unsigned int idx,
                                            unsigned int inputRank,
                                            std::vector<unsigned int>& elementNumInner)
{
    std::vector<unsigned int> coordinates;
    coordinates.reserve(inputRank);

    unsigned int remainder = idx;
    unsigned int dim = 0;
    while (dim != inputRank)
    {
        const unsigned int stride = elementNumInner[dim];
        coordinates.push_back(remainder / stride);
        remainder = remainder % stride;
        ++dim;
    }

    return coordinates;
}
32
/// Collapses a per-dimension coordinate list back into a flat offset.
///
/// The result is the sum, over the first @p inputRank dimensions, of each
/// coordinate multiplied by the matching entry of @p elementNumInner.
///
/// @param idxList         Coordinates, one per dimension.
/// @param inputRank       Number of dimensions to accumulate.
/// @param elementNumInner Per-dimension multiplier; must hold at least
///                        @p inputRank entries.
/// @return The accumulated flat offset (0 when @p inputRank is 0).
unsigned int ReverseGetFlatIdx(const std::vector<unsigned int>& idxList,
                               unsigned int inputRank,
                               std::vector<unsigned int>& elementNumInner)
{
    unsigned int flatOffset = 0;

    unsigned int dim = 0;
    while (dim != inputRank)
    {
        const unsigned int contribution = idxList[dim] * elementNumInner[dim];
        flatOffset = flatOffset + contribution;
        ++dim;
    }

    return flatOffset;
}
47
/// Maps a flat input offset to the flat offset of the same element in the
/// reversed tensor.
///
/// For every dimension the coordinate is recovered from the running remainder,
/// mirrored (`dimSize - coord - 1`) when that dimension is flagged in
/// @p axisFlag, and immediately folded back into the output offset using the
/// same per-dimension stride.
///
/// This function is invoked once per tensor element, so the whole computation
/// is done in a single pass without building temporary coordinate vectors;
/// the arithmetic is the same as decomposing, mirroring and re-flattening in
/// three separate steps.
///
/// @param idx             Flat offset of the element in the input.
/// @param inputRank       Number of dimensions to process.
/// @param axisFlag        Per-dimension flag; true means the dimension is reversed.
/// @param dimSize         Extent of each dimension.
/// @param elementNumInner Per-dimension stride; entries must be non-zero.
/// @return Flat offset of the element in the output.
unsigned int ReverseRelocateIdx(unsigned int idx,
                                unsigned int inputRank,
                                std::vector<bool>& axisFlag,
                                std::vector<unsigned int>& dimSize,
                                std::vector<unsigned int>& elementNumInner)
{
    unsigned int remainder = idx;
    unsigned int outputIdx = 0;

    for (unsigned int iDim = 0; iDim < inputRank; ++iDim)
    {
        const unsigned int stride = elementNumInner[iDim];

        // Coordinate of the input element along this dimension.
        const unsigned int inputCoord = remainder / stride;
        remainder %= stride;

        // Mirror the coordinate only on dimensions that are being reversed.
        const unsigned int outputCoord = axisFlag[iDim] ? (dimSize[iDim] - inputCoord - 1)
                                                        : inputCoord;

        outputIdx += outputCoord * stride;
    }

    return outputIdx;
}
77
78void ReverseV2(const TensorInfo& inputInfo,
79 const TensorInfo& axisInfo,
80 Decoder<float>& inputDecoder,
81 Decoder<int>& axisDecoder,
82 Encoder<float>& outputEncoder)
83{
84 unsigned int axesRank = static_cast<unsigned int>(axisInfo.GetNumElements());
85
86 // Empty axis and empty tensor case: copy input to output
87 if ((axesRank == 0) || inputInfo.GetNumElements() == 0)
88 {
89 for (unsigned idx = 0; idx < inputInfo.GetNumElements(); idx++)
90 {
91 float inputValue = inputDecoder.Get();
92 inputDecoder += 1;
93 outputEncoder.Set(inputValue);
94 outputEncoder += 1;
95 }
96 return;
97 }
98
99 unsigned int inputRank = static_cast<unsigned int>(inputInfo.GetNumDimensions());
100
101 std::vector<bool> axisFlag(inputRank, false);
102 std::vector<unsigned int> dimSize(inputRank, 0);
103 std::vector<int32_t> axis(axesRank, 0);
104
105 // Decode the axis information
106 for (unsigned int i=0; i < axesRank; i++)
107 {
108 axis[i] = axisDecoder.Get();
109 axisDecoder += 1;
110 }
111
112 // Make sure the axes are positive
113 for (int32_t axisElement: axis)
114 {
115 axisElement = axisElement < 0 ? axisElement + static_cast<int32_t>(inputRank) : axisElement;
116 axisFlag[static_cast<uint32_t>(axisElement)] = true;
117 }
118
119 const TensorShape &inputShape = inputInfo.GetShape();
120
121 unsigned int elementNum = inputInfo.GetNumElements();
122 unsigned int baseDimSize = 1;
123
124 std::vector<unsigned int> elementNumInner;
125
126 // Get the number of element within the specific dimension
127 for (unsigned int iDim = 0; iDim < inputRank; ++iDim) {
128 dimSize[iDim] = inputShape[iDim];
129 baseDimSize *= dimSize[iDim];
130 elementNumInner.push_back(static_cast<unsigned int>(elementNum / baseDimSize));
131 }
132
133 // Iterate through all elements
134 for (unsigned int idx = 0; idx < elementNum; ++idx)
135 {
136 float inputValue = inputDecoder.Get();
137 inputDecoder += 1;
138 auto outputIdx = ReverseRelocateIdx(idx, inputRank, axisFlag, dimSize, elementNumInner);
139 outputEncoder[outputIdx];
140 outputEncoder.Set(inputValue);
141 }
142}
143
144} // namespace armnn
virtual IType Get() const =0
virtual void Set(IType right)=0
const TensorShape & GetShape() const
Definition Tensor.hpp:193
unsigned int GetNumDimensions() const
Definition Tensor.hpp:197
unsigned int GetNumElements() const
Definition Tensor.hpp:198
Copyright (c) 2021 ARM Limited and Contributors.
unsigned int ReverseRelocateIdx(unsigned int idx, unsigned int inputRank, std::vector< bool > &axisFlag, std::vector< unsigned int > &dimSize, std::vector< unsigned int > &elementNumInner)
std::vector< unsigned int > ReverseGetMultIdx(const unsigned int idx, unsigned int inputRank, std::vector< unsigned int > &elementNumInner)
unsigned int ReverseGetFlatIdx(const std::vector< unsigned int > &idxList, unsigned int inputRank, std::vector< unsigned int > &elementNumInner)