ArmNN
 25.11
Loading...
Searching...
No Matches
DataLayoutIndexed.hpp
Go to the documentation of this file.
1//
2// Copyright © 2018-2021,2023 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
8#include <armnn/Types.hpp>
9#include <armnn/Tensor.hpp>
10
12
13namespace armnnUtils
14{
15
/// Provides access to the appropriate indexes for Channels, Depth, Height and Width based on DataLayout
18{
19public:
21
22 armnn::DataLayout GetDataLayout() const { return m_DataLayout; }
23 unsigned int GetChannelsIndex() const { return m_ChannelsIndex; }
24 unsigned int GetHeightIndex() const { return m_HeightIndex; }
25 unsigned int GetWidthIndex() const { return m_WidthIndex; }
26 unsigned int GetDepthIndex() const { return m_DepthIndex; }
27
28 inline unsigned int GetIndex(const armnn::TensorShape& shape,
29 unsigned int batchIndex, unsigned int channelIndex,
30 unsigned int heightIndex, unsigned int widthIndex) const
31 {
32 if (batchIndex >= shape[0] && !( shape[0] == 0 && batchIndex == 0))
33 {
34 throw armnn::Exception("Unable to get batch index", CHECK_LOCATION());
35 }
36 if (channelIndex >= shape[m_ChannelsIndex] &&
37 !(shape[m_ChannelsIndex] == 0 && channelIndex == 0))
38 {
39 throw armnn::Exception("Unable to get channel index", CHECK_LOCATION());
40
41 }
42 if (heightIndex >= shape[m_HeightIndex] &&
43 !( shape[m_HeightIndex] == 0 && heightIndex == 0))
44 {
45 throw armnn::Exception("Unable to get height index", CHECK_LOCATION());
46 }
47 if (widthIndex >= shape[m_WidthIndex] &&
48 ( shape[m_WidthIndex] == 0 && widthIndex == 0))
49 {
50 throw armnn::Exception("Unable to get width index", CHECK_LOCATION());
51 }
52
53 /// Offset the given indices appropriately depending on the data layout
54 switch (m_DataLayout)
55 {
57 batchIndex *= shape[1] * shape[2] * shape[3]; // batchIndex *= heightIndex * widthIndex * channelIndex
58 heightIndex *= shape[m_WidthIndex] * shape[m_ChannelsIndex];
59 widthIndex *= shape[m_ChannelsIndex];
60 /// channelIndex stays unchanged
61 break;
63 default:
64 batchIndex *= shape[1] * shape[2] * shape[3]; // batchIndex *= heightIndex * widthIndex * channelIndex
65 channelIndex *= shape[m_HeightIndex] * shape[m_WidthIndex];
66 heightIndex *= shape[m_WidthIndex];
67 /// widthIndex stays unchanged
68 break;
69 }
70
71 /// Get the value using the correct offset
72 return batchIndex + channelIndex + heightIndex + widthIndex;
73 }
74
75private:
76 armnn::DataLayout m_DataLayout;
77 unsigned int m_ChannelsIndex;
78 unsigned int m_HeightIndex;
79 unsigned int m_WidthIndex;
80 unsigned int m_DepthIndex;
81};
82
83/// Equality methods
84bool operator==(const armnn::DataLayout& dataLayout, const DataLayoutIndexed& indexed);
85bool operator==(const DataLayoutIndexed& indexed, const armnn::DataLayout& dataLayout);
86
87} // namespace armnnUtils
#define CHECK_LOCATION()
Base class for all ArmNN exceptions so that users can filter to just those.
unsigned int GetIndex(const armnn::TensorShape &shape, unsigned int batchIndex, unsigned int channelIndex, unsigned int heightIndex, unsigned int widthIndex) const
unsigned int GetHeightIndex() const
armnn::DataLayout GetDataLayout() const
unsigned int GetChannelsIndex() const
DataLayoutIndexed(armnn::DataLayout dataLayout)
DataLayout
Definition Types.hpp:63
bool operator==(const armnn::DataLayout &dataLayout, const DataLayoutIndexed &indexed)
Equality methods.