ArmNN
 25.02
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
DataLayoutIndexed.hpp
Go to the documentation of this file.
1 //
2 // Copyright © 2018-2021,2023 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include <armnn/Types.hpp>
9 #include <armnn/Tensor.hpp>
10 
11 #include <armnn/utility/Assert.hpp>
12 
13 namespace armnnUtils
14 {
15 
16 /// Provides access to the appropriate indexes for Channels, Height, Width and Depth based on DataLayout
18 {
19 public:
21 
22  armnn::DataLayout GetDataLayout() const { return m_DataLayout; }
23  unsigned int GetChannelsIndex() const { return m_ChannelsIndex; }
24  unsigned int GetHeightIndex() const { return m_HeightIndex; }
25  unsigned int GetWidthIndex() const { return m_WidthIndex; }
26  unsigned int GetDepthIndex() const { return m_DepthIndex; }
27 
28  inline unsigned int GetIndex(const armnn::TensorShape& shape,
29  unsigned int batchIndex, unsigned int channelIndex,
30  unsigned int heightIndex, unsigned int widthIndex) const
31  {
32  if (batchIndex >= shape[0] && !( shape[0] == 0 && batchIndex == 0))
33  {
34  throw armnn::Exception("Unable to get batch index", CHECK_LOCATION());
35  }
36  if (channelIndex >= shape[m_ChannelsIndex] &&
37  !(shape[m_ChannelsIndex] == 0 && channelIndex == 0))
38  {
39  throw armnn::Exception("Unable to get channel index", CHECK_LOCATION());
40 
41  }
42  if (heightIndex >= shape[m_HeightIndex] &&
43  !( shape[m_HeightIndex] == 0 && heightIndex == 0))
44  {
45  throw armnn::Exception("Unable to get height index", CHECK_LOCATION());
46  }
47  if (widthIndex >= shape[m_WidthIndex] &&
48  ( shape[m_WidthIndex] == 0 && widthIndex == 0))
49  {
50  throw armnn::Exception("Unable to get width index", CHECK_LOCATION());
51  }
52 
53  /// Offset the given indices appropriately depending on the data layout
54  switch (m_DataLayout)
55  {
57  batchIndex *= shape[1] * shape[2] * shape[3]; // batchIndex *= heightIndex * widthIndex * channelIndex
58  heightIndex *= shape[m_WidthIndex] * shape[m_ChannelsIndex];
59  widthIndex *= shape[m_ChannelsIndex];
60  /// channelIndex stays unchanged
61  break;
63  default:
64  batchIndex *= shape[1] * shape[2] * shape[3]; // batchIndex *= heightIndex * widthIndex * channelIndex
65  channelIndex *= shape[m_HeightIndex] * shape[m_WidthIndex];
66  heightIndex *= shape[m_WidthIndex];
67  /// widthIndex stays unchanged
68  break;
69  }
70 
71  /// Get the value using the correct offset
72  return batchIndex + channelIndex + heightIndex + widthIndex;
73  }
74 
75 private:
76  armnn::DataLayout m_DataLayout;
77  unsigned int m_ChannelsIndex;
78  unsigned int m_HeightIndex;
79  unsigned int m_WidthIndex;
80  unsigned int m_DepthIndex;
81 };
82 
83 /// Equality operators comparing an armnn::DataLayout with a DataLayoutIndexed, in either operand order
84 bool operator==(const armnn::DataLayout& dataLayout, const DataLayoutIndexed& indexed);
85 bool operator==(const DataLayoutIndexed& indexed, const armnn::DataLayout& dataLayout);
86 
87 } // namespace armnnUtils
#define CHECK_LOCATION()
Definition: Exceptions.hpp:203
Base class for all ArmNN exceptions so that users can filter to just those.
Definition: Exceptions.hpp:47
Provides access to the appropriate indexes for Channels, Height and Width based on DataLayout.
unsigned int GetIndex(const armnn::TensorShape &shape, unsigned int batchIndex, unsigned int channelIndex, unsigned int heightIndex, unsigned int widthIndex) const
unsigned int GetWidthIndex() const
unsigned int GetDepthIndex() const
unsigned int GetHeightIndex() const
armnn::DataLayout GetDataLayout() const
unsigned int GetChannelsIndex() const
DataLayoutIndexed(armnn::DataLayout dataLayout)
DataLayout
Definition: Types.hpp:63
bool operator==(const armnn::DataLayout &dataLayout, const DataLayoutIndexed &indexed)
Equality methods.