ArmNN
 26.01
Loading...
Searching...
No Matches
DataLayoutIndexed.hpp
Go to the documentation of this file.
1//
2// Copyright © 2018-2021,2023 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
8#include <armnn/Types.hpp>
9#include <armnn/Tensor.hpp>
10
12
13namespace armnnUtils
14{
15
16/// Provides access to the appropriate indexes for Channels, Height, Width and Depth based on DataLayout
18{
19public:
21
22 armnn::DataLayout GetDataLayout() const { return m_DataLayout; }
23 unsigned int GetChannelsIndex() const { return m_ChannelsIndex; }
24 unsigned int GetHeightIndex() const { return m_HeightIndex; }
25 unsigned int GetWidthIndex() const { return m_WidthIndex; }
26 unsigned int GetDepthIndex() const { return m_DepthIndex; }
27
28 inline unsigned int GetIndex(const armnn::TensorShape& shape,
29 unsigned int batchIndex, unsigned int channelIndex,
30 unsigned int heightIndex, unsigned int widthIndex) const
31 {
32 if (batchIndex >= shape[0] && !( shape[0] == 0 && batchIndex == 0))
33 {
34 throw armnn::Exception("Unable to get batch index", CHECK_LOCATION());
35 }
36 if (channelIndex >= shape[m_ChannelsIndex] &&
37 !(shape[m_ChannelsIndex] == 0 && channelIndex == 0))
38 {
39 throw armnn::Exception("Unable to get channel index", CHECK_LOCATION());
40
41 }
42 if (heightIndex >= shape[m_HeightIndex] &&
43 !( shape[m_HeightIndex] == 0 && heightIndex == 0))
44 {
45 throw armnn::Exception("Unable to get height index", CHECK_LOCATION());
46 }
47 if (widthIndex >= shape[m_WidthIndex] &&
48 ( shape[m_WidthIndex] == 0 && widthIndex == 0))
49 {
50 throw armnn::Exception("Unable to get width index", CHECK_LOCATION());
51 }
52
53 /// Offset the given indices appropriately depending on the data layout
54 switch (m_DataLayout)
55 {
57 batchIndex *= shape[1] * shape[2] * shape[3]; // batchIndex *= heightIndex * widthIndex * channelIndex
58 heightIndex *= shape[m_WidthIndex] * shape[m_ChannelsIndex];
59 widthIndex *= shape[m_ChannelsIndex];
60 /// channelIndex stays unchanged
61 break;
63 default:
64 batchIndex *= shape[1] * shape[2] * shape[3]; // batchIndex *= heightIndex * widthIndex * channelIndex
65 channelIndex *= shape[m_HeightIndex] * shape[m_WidthIndex];
66 heightIndex *= shape[m_WidthIndex];
67 /// widthIndex stays unchanged
68 break;
69 }
70
71 /// Get the value using the correct offset
72 return batchIndex + channelIndex + heightIndex + widthIndex;
73 }
74
75private:
76 armnn::DataLayout m_DataLayout;
77 unsigned int m_ChannelsIndex;
78 unsigned int m_HeightIndex;
79 unsigned int m_WidthIndex;
80 unsigned int m_DepthIndex;
81};
82
83/// Equality methods
84bool operator==(const armnn::DataLayout& dataLayout, const DataLayoutIndexed& indexed);
85bool operator==(const DataLayoutIndexed& indexed, const armnn::DataLayout& dataLayout);
86
87} // namespace armnnUtils
#define CHECK_LOCATION()
Base class for all ArmNN exceptions so that users can filter to just those.
Provides access to the appropriate indexes for Channels, Height and Width based on DataLayout.
unsigned int GetIndex(const armnn::TensorShape &shape, unsigned int batchIndex, unsigned int channelIndex, unsigned int heightIndex, unsigned int widthIndex) const
unsigned int GetHeightIndex() const
armnn::DataLayout GetDataLayout() const
unsigned int GetChannelsIndex() const
DataLayoutIndexed(armnn::DataLayout dataLayout)
DataLayout
Definition Types.hpp:63
bool operator==(const armnn::DataLayout &dataLayout, const DataLayoutIndexed &indexed)
Equality methods.