ArmNN
 26.01
Loading...
Searching...
No Matches
SampleTensorHandle.cpp
Go to the documentation of this file.
1//
2// Copyright © 2020 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
7
8namespace sdb // sample dynamic backend
9{
10
12 std::shared_ptr<SampleMemoryManager> &memoryManager)
13 : m_TensorInfo(tensorInfo),
14 m_MemoryManager(memoryManager),
15 m_Pool(nullptr),
16 m_UnmanagedMemory(nullptr),
17 m_ImportFlags(static_cast<armnn::MemorySourceFlags>(armnn::MemorySource::Undefined)),
18 m_Imported(false)
19{
20
21}
22
24 armnn::MemorySourceFlags importFlags)
25 : m_TensorInfo(tensorInfo),
26 m_MemoryManager(nullptr),
27 m_Pool(nullptr),
28 m_UnmanagedMemory(nullptr),
29 m_ImportFlags(importFlags),
30 m_Imported(true)
31{
32
33}
34
36{
37 if (!m_Pool)
38 {
39 // unmanaged
40 if (!m_Imported)
41 {
42 ::operator delete(m_UnmanagedMemory);
43 }
44 }
45}
46
48{
49 m_Pool = m_MemoryManager->Manage(m_TensorInfo.GetNumBytes());
50}
51
53{
54 if (!m_UnmanagedMemory)
55 {
56 if (!m_Pool)
57 {
58 // unmanaged
59 m_UnmanagedMemory = ::operator new(m_TensorInfo.GetNumBytes());
60 }
61 else
62 {
63 m_MemoryManager->Allocate(m_Pool);
64 }
65 }
66 else
67 {
68 throw armnn::InvalidArgumentException("SampleTensorHandle::Allocate Trying to allocate a "
69 "SampleTensorHandle that already has allocated "
70 "memory.");
71 }
72}
73
74const void* SampleTensorHandle::Map(bool /*unused*/) const
75{
76 return GetPointer();
77}
78
79void* SampleTensorHandle::GetPointer() const
80{
81 if (m_UnmanagedMemory)
82 {
83 return m_UnmanagedMemory;
84 }
85 else
86 {
87 return m_MemoryManager->GetPointer(m_Pool);
88 }
89}
90
92{
93
94 if (m_ImportFlags & static_cast<armnn::MemorySourceFlags>(source))
95 {
96 if (source == armnn::MemorySource::Malloc)
97 {
98 // Check memory alignment
99 constexpr uintptr_t alignment = sizeof(size_t);
100 if (reinterpret_cast<uintptr_t>(memory) % alignment)
101 {
102 if (m_Imported)
103 {
104 m_Imported = false;
105 m_UnmanagedMemory = nullptr;
106 }
107
108 return false;
109 }
110
111 // m_UnmanagedMemory not yet allocated.
112 if (!m_Imported && !m_UnmanagedMemory)
113 {
114 m_UnmanagedMemory = memory;
115 m_Imported = true;
116 return true;
117 }
118
119 // m_UnmanagedMemory initially allocated with Allocate().
120 if (!m_Imported && m_UnmanagedMemory)
121 {
122 return false;
123 }
124
125 // m_UnmanagedMemory previously imported.
126 if (m_Imported)
127 {
128 m_UnmanagedMemory = memory;
129 return true;
130 }
131 }
132 }
133
134 return false;
135}
136
137void SampleTensorHandle::CopyOutTo(void* dest) const
138{
139 const void* src = GetPointer();
140 if (dest == nullptr)
141 {
142 throw armnn::Exception("SampleTensorHandle:CopyOutTo: Destination Ptr is null");
143 }
144 if (src == nullptr)
145 {
146 throw armnn::Exception("SampleTensorHandle:CopyOutTo: Source Ptr is null");
147 }
148 memcpy(dest, src, m_TensorInfo.GetNumBytes());
149}
150
151void SampleTensorHandle::CopyInFrom(const void* src)
152{
153 void* dest = GetPointer();
154 if (src == nullptr)
155 {
156 throw armnn::Exception("SampleTensorHandle:CopyInFrom: Source Ptr is null");
157 }
158 if (dest == nullptr)
159 {
160 throw armnn::Exception("SampleTensorHandle:CopyInFrom: Destination Ptr is null");
161 }
162 memcpy(dest, src, m_TensorInfo.GetNumBytes());
163}
164
165} // namespace sdb
Base class for all ArmNN exceptions so that users can filter to just those.
unsigned int GetNumBytes() const
Definition Tensor.cpp:427
virtual void Allocate() override
Indicate to the memory manager that this resource is no longer active.
virtual bool Import(void *memory, armnn::MemorySource source) override
Import externally allocated memory.
virtual void Manage() override
Indicate to the memory manager that this resource is active.
virtual const void * Map(bool) const override
Map the tensor data for access.
SampleTensorHandle(const armnn::TensorInfo &tensorInfo, std::shared_ptr< SampleMemoryManager > &memoryManager)
Copyright (c) 2021 ARM Limited and Contributors.
MemorySource
Define the Memory Source to reduce copies.
Definition Types.hpp:246
unsigned int MemorySourceFlags