ArmNN
 24.08
TensorHandle.hpp
Go to the documentation of this file.
1 //
2 // Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include "ITensorHandle.hpp"
9 
10 #include <armnn/TypesUtils.hpp>
11 #include <armnn/utility/Assert.hpp>
13 
14 #include <algorithm>
15 
16 namespace armnn
17 {
18 
19 // Get a TensorShape representing the strides (in bytes) for each dimension
20 // of a tensor, assuming fully packed data with no padding
21 TensorShape GetUnpaddedTensorStrides(const TensorInfo& tensorInfo);
22 
23 // Abstract tensor handle wrapping a readable region of memory, interpreting it as tensor data.
25 {
26 public:
27  template <typename T>
28  const T* GetConstTensor() const
29  {
30  if (armnnUtils::CompatibleTypes<T>(GetTensorInfo().GetDataType()))
31  {
32  return reinterpret_cast<const T*>(m_Memory);
33  }
34  else
35  {
36  throw armnn::Exception("Attempting to get not compatible type tensor!");
37  }
38  }
39 
40  const TensorInfo& GetTensorInfo() const
41  {
42  return m_TensorInfo;
43  }
44 
45  virtual void Manage() override {}
46 
47  virtual ITensorHandle* GetParent() const override { return nullptr; }
48 
49  virtual const void* Map(bool /* blocking = true */) const override { return m_Memory; }
50  virtual void Unmap() const override {}
51 
52  TensorShape GetStrides() const override
53  {
54  return GetUnpaddedTensorStrides(m_TensorInfo);
55  }
56  TensorShape GetShape() const override { return m_TensorInfo.GetShape(); }
57 
58 protected:
59  ConstTensorHandle(const TensorInfo& tensorInfo);
60 
61  void SetConstMemory(const void* mem) { m_Memory = mem; }
62 
63 private:
64  // Only used for testing
65  void CopyOutTo(void *) const override { ARMNN_ASSERT_MSG(false, "Unimplemented"); }
66  void CopyInFrom(const void*) override { ARMNN_ASSERT_MSG(false, "Unimplemented"); }
67 
68  ConstTensorHandle(const ConstTensorHandle& other) = delete;
69  ConstTensorHandle& operator=(const ConstTensorHandle& other) = delete;
70 
71  TensorInfo m_TensorInfo;
72  const void* m_Memory;
73 };
74 
75 template<>
76 const void* ConstTensorHandle::GetConstTensor<void>() const;
77 
78 // Abstract specialization of ConstTensorHandle that allows write access to the same data.
80 {
81 public:
82  template <typename T>
83  T* GetTensor() const
84  {
85  if (armnnUtils::CompatibleTypes<T>(GetTensorInfo().GetDataType()))
86  {
87  return reinterpret_cast<T*>(m_MutableMemory);
88  }
89  else
90  {
91  throw armnn::Exception("Attempting to get not compatible type tensor!");
92  }
93  }
94 
95 protected:
96  TensorHandle(const TensorInfo& tensorInfo);
97 
98  void SetMemory(void* mem)
99  {
100  m_MutableMemory = mem;
101  SetConstMemory(m_MutableMemory);
102  }
103 
104 private:
105 
106  TensorHandle(const TensorHandle& other) = delete;
107  TensorHandle& operator=(const TensorHandle& other) = delete;
108  void* m_MutableMemory;
109 };
110 
111 template <>
112 void* TensorHandle::GetTensor<void>() const;
113 
114 // A TensorHandle that owns the wrapped memory region.
116 {
117 public:
118  explicit ScopedTensorHandle(const TensorInfo& tensorInfo);
119 
120  // Copies contents from a ConstTensor.
121  explicit ScopedTensorHandle(const ConstTensor& tensor);
122 
123  // Copies contents from ConstTensorHandle
124  explicit ScopedTensorHandle(const ConstTensorHandle& tensorHandle);
125 
129 
130  virtual void Allocate() override;
131 
132 private:
133  // Only used for testing
134  void CopyOutTo(void* memory) const override;
135  void CopyInFrom(const void* memory) override;
136 
137  void CopyFrom(const ScopedTensorHandle& other);
138  void CopyFrom(const void* srcMemory, unsigned int numBytes);
139 };
140 
141 // A TensorHandle that wraps an already allocated memory region.
142 //
143 // Clients must make sure the passed in memory region stays alive for the lifetime of
144 // the PassthroughTensorHandle instance.
145 //
146 // Note there is no polymorphism to/from ConstPassthroughTensorHandle.
148 {
149 public:
150  PassthroughTensorHandle(const TensorInfo& tensorInfo, void* mem)
151  : TensorHandle(tensorInfo)
152  {
153  SetMemory(mem);
154  }
155 
156  virtual void Allocate() override;
157 };
158 
159 // A ConstTensorHandle that wraps an already allocated memory region.
160 //
161 // This allows users to pass in const memory to a network.
162 // Clients must make sure the passed-in memory region stays alive for the lifetime of
163 // the ConstPassthroughTensorHandle instance.
164 //
165 // Note there is no polymorphism to/from PassthroughTensorHandle.
167 {
168 public:
169  ConstPassthroughTensorHandle(const TensorInfo& tensorInfo, const void* mem)
170  : ConstTensorHandle(tensorInfo)
171  {
172  SetConstMemory(mem);
173  }
174 
175  virtual void Allocate() override;
176 };
177 
178 
179 // Template specializations.
180 
181 template <>
182 const void* ConstTensorHandle::GetConstTensor() const;
183 
184 template <>
185 void* TensorHandle::GetTensor() const;
186 
188 {
189 
190 public:
191  explicit ManagedConstTensorHandle(std::shared_ptr<ConstTensorHandle> ptr)
192  : m_Mapped(false)
193  , m_TensorHandle(std::move(ptr)) {};
194 
195  /// Maps the wrapped tensor handle and records that it is mapped; the destructor unmaps it when this object goes out of scope.
196  const void* Map(bool blocking = true)
197  {
198  if (m_TensorHandle)
199  {
200  auto pRet = m_TensorHandle->Map(blocking);
201  m_Mapped = true;
202  return pRet;
203  }
204  else
205  {
206  throw armnn::Exception("Attempting to Map null TensorHandle");
207  }
208 
209  }
210 
211  // Delete the copy constructor: a copy would carry its own m_Mapped flag for the same shared handle,
211  // so both objects' destructors would call Unmap() on it. The parameter is ManagedConstTensorHandle
211  // (not ConstTensorHandle) so that this declaration really is the copy constructor.
212  ManagedConstTensorHandle(const ManagedConstTensorHandle& other) = delete;
213 
214  // Delete copy assignment as it's unnecessary
216 
217  // Delete move assignment as it's unnecessary
218  ManagedConstTensorHandle& operator=(ManagedConstTensorHandle&& other) noexcept = delete;
219 
221  {
222  // Bias tensor handles need to be initialized empty before entering scope of if statement checking if enabled
223  if (m_TensorHandle)
224  {
225  Unmap();
226  }
227  }
228 
229  void Unmap()
230  {
231  // Only unmap if mapped and TensorHandle exists.
232  if (m_Mapped && m_TensorHandle)
233  {
234  m_TensorHandle->Unmap();
235  m_Mapped = false;
236  }
237  }
238 
239  const TensorInfo& GetTensorInfo() const
240  {
241  return m_TensorHandle->GetTensorInfo();
242  }
243 
244  bool IsMapped() const
245  {
246  return m_Mapped;
247  }
248 
249 private:
250  bool m_Mapped;
251  std::shared_ptr<ConstTensorHandle> m_TensorHandle;
252 };
253 
254 } // namespace armnn
armnn::ManagedConstTensorHandle::operator=
ManagedConstTensorHandle & operator=(const ManagedConstTensorHandle &other)=delete
armnn::ScopedTensorHandle::~ScopedTensorHandle
~ScopedTensorHandle()
Definition: TensorHandle.cpp:86
armnn::ConstTensorHandle::GetShape
TensorShape GetShape() const override
Get the number of elements for each dimension ordered from slowest iterating dimension to fastest ite...
Definition: TensorHandle.hpp:56
armnn::ConstTensorHandle
Definition: TensorHandle.hpp:24
armnn::ManagedConstTensorHandle::Unmap
void Unmap()
Definition: TensorHandle.hpp:229
armnn::ConstTensorHandle::Map
virtual const void * Map(bool) const override
Map the tensor data for access.
Definition: TensorHandle.hpp:49
armnn::ConstPassthroughTensorHandle::Allocate
virtual void Allocate() override
Indicate to the memory manager that this resource is no longer active.
Definition: TensorHandle.cpp:163
TypesUtils.hpp
armnn::TensorInfo
Definition: Tensor.hpp:152
armnn::ScopedTensorHandle::ScopedTensorHandle
ScopedTensorHandle(const TensorInfo &tensorInfo)
Definition: TensorHandle.cpp:55
armnn::TensorHandle::GetTensor
T * GetTensor() const
Definition: TensorHandle.hpp:83
armnn::GetUnpaddedTensorStrides
TensorShape GetUnpaddedTensorStrides(const TensorInfo &tensorInfo)
Definition: TensorHandle.cpp:15
armnn::ITensorHandle
Definition: ITensorHandle.hpp:16
armnn::ConstPassthroughTensorHandle::ConstPassthroughTensorHandle
ConstPassthroughTensorHandle(const TensorInfo &tensorInfo, const void *mem)
Definition: TensorHandle.hpp:169
armnn::ManagedConstTensorHandle::IsMapped
bool IsMapped() const
Definition: TensorHandle.hpp:244
armnn::ManagedConstTensorHandle::ManagedConstTensorHandle
ManagedConstTensorHandle(std::shared_ptr< ConstTensorHandle > ptr)
Definition: TensorHandle.hpp:191
armnn::ConstTensorHandle::GetTensorInfo
const TensorInfo & GetTensorInfo() const
Definition: TensorHandle.hpp:40
armnn::TensorHandle
Definition: TensorHandle.hpp:79
ARMNN_ASSERT_MSG
#define ARMNN_ASSERT_MSG(COND, MSG)
Definition: Assert.hpp:15
armnn::ManagedConstTensorHandle
Definition: TensorHandle.hpp:187
armnn::ConstTensorHandle::Manage
virtual void Manage() override
Indicate to the memory manager that this resource is active.
Definition: TensorHandle.hpp:45
Assert.hpp
armnn::ConstTensorHandle::GetConstTensor
const T * GetConstTensor() const
Definition: TensorHandle.hpp:28
armnn::TensorShape
Definition: Tensor.hpp:20
ITensorHandle.hpp
armnn::TensorHandle::SetMemory
void SetMemory(void *mem)
Definition: TensorHandle.hpp:98
CompatibleTypes.hpp
armnn::ManagedConstTensorHandle::~ManagedConstTensorHandle
~ManagedConstTensorHandle()
Definition: TensorHandle.hpp:220
armnn::ManagedConstTensorHandle::Map
const void * Map(bool blocking=true)
RAII Managed resource Unmaps MemoryArea once out of scope.
Definition: TensorHandle.hpp:196
armnn::Exception
Base class for all ArmNN exceptions so that users can filter to just those.
Definition: Exceptions.hpp:46
armnn::TensorHandle::TensorHandle
TensorHandle(const TensorInfo &tensorInfo)
Definition: TensorHandle.cpp:43
armnn::ConstTensorHandle::GetParent
virtual ITensorHandle * GetParent() const override
Get the parent tensor if this is a subtensor.
Definition: TensorHandle.hpp:47
armnn::ScopedTensorHandle::Allocate
virtual void Allocate() override
Indicate to the memory manager that this resource is no longer active.
Definition: TensorHandle.cpp:91
armnn::ConstTensorHandle::Unmap
virtual void Unmap() const override
Unmap the tensor data.
Definition: TensorHandle.hpp:50
armnn::ConstTensorHandle::ConstTensorHandle
ConstTensorHandle(const TensorInfo &tensorInfo)
Definition: TensorHandle.cpp:31
armnn::TensorInfo::GetShape
const TensorShape & GetShape() const
Definition: Tensor.hpp:193
std
Definition: BackendId.hpp:149
armnn::ConstPassthroughTensorHandle
Definition: TensorHandle.hpp:166
armnn
Copyright (c) 2021 ARM Limited and Contributors.
Definition: 01_00_quick_start.dox:6
armnn::ConstTensorHandle::SetConstMemory
void SetConstMemory(const void *mem)
Definition: TensorHandle.hpp:61
armnn::ConstTensor
A tensor defined by a TensorInfo (shape and data type) and an immutable backing store.
Definition: Tensor.hpp:329
armnn::PassthroughTensorHandle::PassthroughTensorHandle
PassthroughTensorHandle(const TensorInfo &tensorInfo, void *mem)
Definition: TensorHandle.hpp:150
armnn::PassthroughTensorHandle::Allocate
virtual void Allocate() override
Indicate to the memory manager that this resource is no longer active.
Definition: TensorHandle.cpp:158
armnn::ScopedTensorHandle
Definition: TensorHandle.hpp:115
armnn::ConstTensorHandle::GetStrides
TensorShape GetStrides() const override
Get the strides for each dimension ordered from largest to smallest where the smallest value is the s...
Definition: TensorHandle.hpp:52
armnn::ScopedTensorHandle::operator=
ScopedTensorHandle & operator=(const ScopedTensorHandle &other)
Definition: TensorHandle.cpp:78
armnn::PassthroughTensorHandle
Definition: TensorHandle.hpp:147
armnn::ManagedConstTensorHandle::GetTensorInfo
const TensorInfo & GetTensorInfo() const
Definition: TensorHandle.hpp:239