ArmNN
 24.02
TensorBufferArrayView.hpp
Go to the documentation of this file.
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include <armnn/Tensor.hpp>

#include <armnnUtils/DataLayoutIndexed.hpp>

#include <armnn/utility/Assert.hpp>

14 namespace armnn
15 {
16 
17 // Utility class providing access to raw tensor memory based on indices along each dimension.
18 template <typename DataType>
20 {
21 public:
24  : m_Shape(shape)
25  , m_Data(data)
26  , m_DataLayout(dataLayout)
27  {
28  ARMNN_ASSERT(m_Shape.GetNumDimensions() == 4);
29  }
30 
31  DataType& Get(unsigned int b, unsigned int c, unsigned int h, unsigned int w) const
32  {
33  return m_Data[m_DataLayout.GetIndex(m_Shape, b, c, h, w)];
34  }
35 
36 private:
37  const TensorShape m_Shape;
38  DataType* m_Data;
39  armnnUtils::DataLayoutIndexed m_DataLayout;
40 };
41 
42 } //namespace armnn
ARMNN_ASSERT
#define ARMNN_ASSERT(COND)
Definition: Assert.hpp:14
armnn::TensorBufferArrayView::TensorBufferArrayView
TensorBufferArrayView(const TensorShape &shape, DataType *data, armnnUtils::DataLayoutIndexed dataLayout=DataLayout::NCHW)
Definition: TensorBufferArrayView.hpp:22
armnnUtils::DataLayoutIndexed
Provides access to the appropriate indexes for Channels, Height and Width based on DataLayout.
Definition: DataLayoutIndexed.hpp:17
Assert.hpp
armnn::TensorShape
Definition: Tensor.hpp:20
armnn::TensorShape::GetNumDimensions
unsigned int GetNumDimensions() const
Function that returns the tensor rank.
Definition: Tensor.cpp:174
armnn::DataType
DataType
Definition: Types.hpp:48
Tensor.hpp
armnn
Copyright (c) 2021 ARM Limited and Contributors.
Definition: 01_00_quick_start.dox:6
armnn::TensorBufferArrayView::Get
DataType & Get(unsigned int b, unsigned int c, unsigned int h, unsigned int w) const
Definition: TensorBufferArrayView.hpp:31
armnn::TensorBufferArrayView
Definition: TensorBufferArrayView.hpp:19
armnnUtils::DataLayoutIndexed::GetIndex
unsigned int GetIndex(const armnn::TensorShape &shape, unsigned int batchIndex, unsigned int channelIndex, unsigned int heightIndex, unsigned int widthIndex) const
Definition: DataLayoutIndexed.hpp:28
DataLayoutIndexed.hpp
armnn::DataLayout::NCHW
@ NCHW