ArmNN
 25.11
Loading...
Searching...
No Matches
TensorHandle.hpp
Go to the documentation of this file.
1//
2// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
8#include "ITensorHandle.hpp"
9
10#include <armnn/TypesUtils.hpp>
13
14#include <algorithm>
15
16namespace armnn
17{
18
19// Get a TensorShape representing the strides (in bytes) for each dimension
20// of a tensor, assuming fully packed data with no padding
22
23// Abstract tensor handle wrapping a readable region of memory, interpreting it as tensor data.
25{
26public:
27 template <typename T>
28 const T* GetConstTensor() const
29 {
31 {
32 return reinterpret_cast<const T*>(m_Memory);
33 }
34 else
35 {
36 throw armnn::Exception("Attempting to get not compatible type tensor!");
37 }
38 }
39
41 {
42 return m_TensorInfo;
43 }
44
45 virtual void Manage() override {}
46
47 virtual ITensorHandle* GetParent() const override { return nullptr; }
48
49 virtual const void* Map(bool /* blocking = true */) const override { return m_Memory; }
50 virtual void Unmap() const override {}
51
52 TensorShape GetStrides() const override
53 {
54 return GetUnpaddedTensorStrides(m_TensorInfo);
55 }
56 TensorShape GetShape() const override { return m_TensorInfo.GetShape(); }
57
58protected:
59 ConstTensorHandle(const TensorInfo& tensorInfo);
60
61 void SetConstMemory(const void* mem) { m_Memory = mem; }
62
63private:
64 // Only used for testing
65 void CopyOutTo(void *) const override { ARMNN_ASSERT_MSG(false, "Unimplemented"); }
66 void CopyInFrom(const void*) override { ARMNN_ASSERT_MSG(false, "Unimplemented"); }
67
68 ConstTensorHandle(const ConstTensorHandle& other) = delete;
69 ConstTensorHandle& operator=(const ConstTensorHandle& other) = delete;
70
71 TensorInfo m_TensorInfo;
72 const void* m_Memory;
73};
74
75template<>
77
78// Abstract specialization of ConstTensorHandle that allows write access to the same data.
80{
81public:
82 template <typename T>
83 T* GetTensor() const
84 {
86 {
87 return reinterpret_cast<T*>(m_MutableMemory);
88 }
89 else
90 {
91 throw armnn::Exception("Attempting to get not compatible type tensor!");
92 }
93 }
94
95protected:
96 TensorHandle(const TensorInfo& tensorInfo);
97
98 void SetMemory(void* mem)
99 {
100 m_MutableMemory = mem;
101 SetConstMemory(m_MutableMemory);
102 }
103
104private:
105
106 TensorHandle(const TensorHandle& other) = delete;
107 TensorHandle& operator=(const TensorHandle& other) = delete;
108 void* m_MutableMemory;
109};
110
111template <>
113
114// A TensorHandle that owns the wrapped memory region.
116{
117public:
118 explicit ScopedTensorHandle(const TensorInfo& tensorInfo);
119
120 // Copies contents from Tensor.
121 explicit ScopedTensorHandle(const ConstTensor& tensor);
122
123 // Copies contents from ConstTensorHandle
124 explicit ScopedTensorHandle(const ConstTensorHandle& tensorHandle);
125
129
130 virtual void Allocate() override;
131
132private:
133 // Only used for testing
134 void CopyOutTo(void* memory) const override;
135 void CopyInFrom(const void* memory) override;
136
137 void CopyFrom(const ScopedTensorHandle& other);
138 void CopyFrom(const void* srcMemory, unsigned int numBytes);
139};
140
141// A TensorHandle that wraps an already allocated memory region.
142//
143// Clients must make sure the passed in memory region stays alive for the lifetime of
144// the PassthroughTensorHandle instance.
145//
146// Note there is no polymorphism to/from ConstPassthroughTensorHandle.
148{
149public:
150 PassthroughTensorHandle(const TensorInfo& tensorInfo, void* mem)
151 : TensorHandle(tensorInfo)
152 {
153 SetMemory(mem);
154 }
155
156 virtual void Allocate() override;
157};
158
159// A ConstTensorHandle that wraps an already allocated memory region.
160//
161// This allows users to pass in const memory to a network.
162// Clients must make sure the passed in memory region stays alive for the lifetime of
163// the ConstPassthroughTensorHandle instance.
164//
165// Note there is no polymorphism to/from PassthroughTensorHandle.
167{
168public:
169 ConstPassthroughTensorHandle(const TensorInfo& tensorInfo, const void* mem)
170 : ConstTensorHandle(tensorInfo)
171 {
172 SetConstMemory(mem);
173 }
174
175 virtual void Allocate() override;
176};
177
178
179// Template specializations.
180
181template <>
183
184template <>
186
188{
189
190public:
191 explicit ManagedConstTensorHandle(std::shared_ptr<ConstTensorHandle> ptr)
192 : m_Mapped(false)
193 , m_TensorHandle(std::move(ptr)) {};
194
195 /// RAII Managed resource Unmaps MemoryArea once out of scope
196 const void* Map(bool blocking = true)
197 {
198 if (m_TensorHandle)
199 {
200 auto pRet = m_TensorHandle->Map(blocking);
201 m_Mapped = true;
202 return pRet;
203 }
204 else
205 {
206 throw armnn::Exception("Attempting to Map null TensorHandle");
207 }
208
209 }
210
211 // Delete copy constructor as it's unnecessary
213
214 // Delete copy assignment as it's unnecessary
216
217 // Delete move assignment as it's unnecessary
219
221 {
222 // Bias tensor handles need to be initialized empty before entering scope of if statement checking if enabled
223 if (m_TensorHandle)
224 {
225 Unmap();
226 }
227 }
228
229 void Unmap()
230 {
231 // Only unmap if mapped and TensorHandle exists.
232 if (m_Mapped && m_TensorHandle)
233 {
234 m_TensorHandle->Unmap();
235 m_Mapped = false;
236 }
237 }
238
240 {
241 return m_TensorHandle->GetTensorInfo();
242 }
243
244 bool IsMapped() const
245 {
246 return m_Mapped;
247 }
248
249private:
250 bool m_Mapped;
251 std::shared_ptr<ConstTensorHandle> m_TensorHandle;
252};
253
254} // namespace armnn
#define ARMNN_ASSERT_MSG(COND, MSG)
Definition Assert.hpp:15
ConstPassthroughTensorHandle(const TensorInfo &tensorInfo, const void *mem)
virtual void Allocate() override
Indicate to the memory manager that this resource is no longer active.
ConstTensorHandle(const TensorInfo &tensorInfo)
virtual void Manage() override
Indicate to the memory manager that this resource is active.
virtual const void * Map(bool) const override
Map the tensor data for access.
const TensorInfo & GetTensorInfo() const
virtual ITensorHandle * GetParent() const override
Get the parent tensor if this is a subtensor.
const T * GetConstTensor() const
virtual void Unmap() const override
Unmap the tensor data.
TensorShape GetShape() const override
Get the number of elements for each dimension ordered from slowest iterating dimension to fastest iterating dimension.
TensorShape GetStrides() const override
Get the strides for each dimension ordered from largest to smallest where the smallest value is the same as the size of a single element in the tensor.
void SetConstMemory(const void *mem)
A tensor defined by a TensorInfo (shape and data type) and an immutable backing store.
Definition Tensor.hpp:330
Base class for all ArmNN exceptions so that users can filter to just those.
const void * Map(bool blocking=true)
RAII Managed resource Unmaps MemoryArea once out of scope.
const TensorInfo & GetTensorInfo() const
ManagedConstTensorHandle(const ConstTensorHandle &other)=delete
ManagedConstTensorHandle & operator=(ManagedConstTensorHandle &&other) noexcept=delete
ManagedConstTensorHandle(std::shared_ptr< ConstTensorHandle > ptr)
ManagedConstTensorHandle & operator=(const ManagedConstTensorHandle &other)=delete
PassthroughTensorHandle(const TensorInfo &tensorInfo, void *mem)
virtual void Allocate() override
Indicate to the memory manager that this resource is no longer active.
virtual void Allocate() override
Indicate to the memory manager that this resource is no longer active.
ScopedTensorHandle(const TensorInfo &tensorInfo)
ScopedTensorHandle & operator=(const ScopedTensorHandle &other)
void SetMemory(void *mem)
TensorHandle(const TensorInfo &tensorInfo)
Copyright (c) 2021 ARM Limited and Contributors.
TensorShape GetUnpaddedTensorStrides(const TensorInfo &tensorInfo)
bool CompatibleTypes(armnn::DataType)