ArmNN
 25.11
Loading...
Searching...
No Matches
ClTensorHandle.hpp
Go to the documentation of this file.
1//
2// Copyright © 2017-2024 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
10
11#include <Half.hpp>
12
14
15#include <arm_compute/runtime/CL/CLTensor.h>
16#include <arm_compute/runtime/CL/CLSubTensor.h>
17#include <arm_compute/runtime/IMemoryGroup.h>
18#include <arm_compute/runtime/MemoryGroup.h>
19#include <arm_compute/core/TensorShape.h>
20#include <arm_compute/core/Coordinates.h>
21
23
24namespace armnn
25{
27
29{
30public:
31 ClTensorHandle(const TensorInfo& tensorInfo)
32 : m_ImportFlags(static_cast<MemorySourceFlags>(MemorySource::Undefined)),
33 m_Imported(false),
34 m_IsImportEnabled(false)
35 {
36 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo);
37 }
38
39 ClTensorHandle(const TensorInfo& tensorInfo,
40 DataLayout dataLayout,
42 : m_ImportFlags(importFlags),
43 m_Imported(false),
44 m_IsImportEnabled(false)
45 {
46 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo, dataLayout);
47 }
48
49 arm_compute::CLTensor& GetTensor() override { return m_Tensor; }
50 arm_compute::CLTensor const& GetTensor() const override { return m_Tensor; }
51 virtual void Allocate() override
52 {
53 // If we have enabled Importing, don't allocate the tensor
54 if (m_IsImportEnabled)
55 {
56 throw MemoryImportException("ClTensorHandle::Attempting to allocate memory when importing");
57 }
58 else
59 {
60 armnn::armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_Tensor);
61 }
62
63 }
64
65 virtual void Manage() override
66 {
67 // If we have enabled Importing, don't manage the tensor
68 if (m_IsImportEnabled)
69 {
70 throw MemoryImportException("ClTensorHandle::Attempting to manage memory when importing");
71 }
72 else
73 {
74 assert(m_MemoryGroup != nullptr);
75 m_MemoryGroup->manage(&m_Tensor);
76 }
77 }
78
79 virtual const void* Map(bool blocking = true) const override
80 {
81 const_cast<arm_compute::CLTensor*>(&m_Tensor)->map(blocking);
82 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
83 }
84
85 virtual void Unmap() const override { const_cast<arm_compute::CLTensor*>(&m_Tensor)->unmap(); }
86
87 virtual ITensorHandle* GetParent() const override { return nullptr; }
88
89 virtual arm_compute::DataType GetDataType() const override
90 {
91 return m_Tensor.info()->data_type();
92 }
93
// Receives the memory group that Manage() later uses through m_MemoryGroup.
// NOTE(review): the statement on original line 96 (this method's body) is missing from this
// listing, so as shown the group is discarded. Presumably the real body assigns m_MemoryGroup,
// e.g. via PolymorphicPointerDowncast<arm_compute::MemoryGroup>(memoryGroup) — confirm against
// the original header.
94 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) override
95 {
97 }
98
99 TensorShape GetStrides() const override
100 {
101 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
102 }
103
104 TensorShape GetShape() const override
105 {
106 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
107 }
108
110 {
111 m_ImportFlags = importFlags;
112 }
113
115 {
116 return m_ImportFlags;
117 }
118
119 void SetImportEnabledFlag(bool importEnabledFlag)
120 {
121 m_IsImportEnabled = importEnabledFlag;
122 }
123
124 virtual bool Import(void* memory, MemorySource source) override
125 {
126 armnn::IgnoreUnused(memory);
127 if (m_ImportFlags& static_cast<MemorySourceFlags>(source))
128 {
129 throw MemoryImportException("ClTensorHandle::Incorrect import flag");
130 }
131 m_Imported = false;
132 return false;
133 }
134
135 virtual bool CanBeImported(void* memory, MemorySource source) override
136 {
137 // This TensorHandle can never import.
138 armnn::IgnoreUnused(memory, source);
139 return false;
140 }
141
142 virtual std::shared_ptr<ITensorHandle> DecorateTensorHandle(const TensorInfo& tensorInfo) override;
143
144private:
145 // Only used for testing
// Copies the tensor contents out to the host buffer @p memory via CopyArmComputeITensorData.
// The host element type is selected from GetDataType(). The tensor is mapped (blocking) before
// the switch and unmapped after it.
// NOTE(review): @p memory is assumed to hold at least one element per tensor element — not checked.
// NOTE(review): Map()/Unmap() are const members, so the const_casts below are not required.
146 void CopyOutTo(void* memory) const override
147 {
148 const_cast<armnn::ClTensorHandle*>(this)->Map(true);
149 switch(this->GetDataType())
150 {
151 case arm_compute::DataType::F32:
152 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
153 static_cast<float*>(memory));
154 break;
155 case arm_compute::DataType::U8:
156 case arm_compute::DataType::QASYMM8:
157 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
158 static_cast<uint8_t*>(memory));
159 break;
160 case arm_compute::DataType::QSYMM8:
161 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
162 case arm_compute::DataType::QASYMM8_SIGNED:
163 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
164 static_cast<int8_t*>(memory));
165 break;
166 case arm_compute::DataType::F16:
167 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
168 static_cast<armnn::Half*>(memory));
169 break;
170 case arm_compute::DataType::S16:
171 case arm_compute::DataType::QSYMM16:
172 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
173 static_cast<int16_t*>(memory));
174 break;
175 case arm_compute::DataType::S32:
176 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
177 static_cast<int32_t*>(memory));
178 break;
179 case arm_compute::DataType::S64:
180 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
181 static_cast<int64_t*>(memory));
182 break;
183
// NOTE(review): the statement inside this default case (original line 186) is missing from
// this listing; the sibling CopyInFrom throws armnn::UnimplementedException there — confirm.
184 default:
185 {
187 }
188 }
189 const_cast<armnn::ClTensorHandle*>(this)->Unmap();
190 }
191
192 // Only used for testing
193 void CopyInFrom(const void* memory) override
194 {
195 this->Map(true);
196 switch(this->GetDataType())
197 {
198 case arm_compute::DataType::F32:
199 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
200 this->GetTensor());
201 break;
202 case arm_compute::DataType::U8:
203 case arm_compute::DataType::QASYMM8:
204 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
205 this->GetTensor());
206 break;
207 case arm_compute::DataType::F16:
208 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory),
209 this->GetTensor());
210 break;
211 case arm_compute::DataType::S16:
212 case arm_compute::DataType::QSYMM8:
213 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
214 case arm_compute::DataType::QASYMM8_SIGNED:
215 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory),
216 this->GetTensor());
217 break;
218 case arm_compute::DataType::QSYMM16:
219 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
220 this->GetTensor());
221 break;
222 case arm_compute::DataType::S32:
223 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
224 this->GetTensor());
225 break;
226 case arm_compute::DataType::S64:
227 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int64_t*>(memory),
228 this->GetTensor());
229 break;
230 default:
231 {
232 throw armnn::UnimplementedException();
233 }
234 }
235 this->Unmap();
236 }
237
// The OpenCL tensor backing this handle.
238 arm_compute::CLTensor m_Tensor;
// Memory group used by Manage() (asserted non-null there); presumably set by SetMemoryGroup,
// whose body is not visible in this listing.
239 std::shared_ptr<arm_compute::MemoryGroup> m_MemoryGroup;
// Memory sources reported by GetImportFlags(); Import() throws if asked to import from one of them.
240 MemorySourceFlags m_ImportFlags;
// Only ever set to false in the visible code (constructors and Import()).
241 bool m_Imported;
// When true, Allocate() and Manage() throw MemoryImportException.
242 bool m_IsImportEnabled;
// NOTE(review): not referenced in this listing; presumably holds handles created by
// DecorateTensorHandle — confirm.
243 std::vector<std::shared_ptr<ClTensorHandleDecorator>> m_Decorated;
244};
245
247{
248public:
250 const arm_compute::TensorShape& shape,
251 const arm_compute::Coordinates& coords)
252 : m_Tensor(&parent->GetTensor(), shape, coords)
253 {
254 parentHandle = parent;
255 }
256
257 arm_compute::CLSubTensor& GetTensor() override { return m_Tensor; }
258 arm_compute::CLSubTensor const& GetTensor() const override { return m_Tensor; }
259
260 virtual void Allocate() override {}
261 virtual void Manage() override {}
262
263 virtual const void* Map(bool blocking = true) const override
264 {
265 const_cast<arm_compute::CLSubTensor*>(&m_Tensor)->map(blocking);
266 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
267 }
268 virtual void Unmap() const override { const_cast<arm_compute::CLSubTensor*>(&m_Tensor)->unmap(); }
269
270 virtual ITensorHandle* GetParent() const override { return parentHandle; }
271
272 virtual arm_compute::DataType GetDataType() const override
273 {
274 return m_Tensor.info()->data_type();
275 }
276
277 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>&) override {}
278
279 TensorShape GetStrides() const override
280 {
281 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
282 }
283
284 TensorShape GetShape() const override
285 {
286 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
287 }
288
289private:
290 // Only used for testing
// Copies the sub-tensor contents out to the host buffer @p memory via CopyArmComputeITensorData.
// The host element type is selected from GetDataType(). The tensor is mapped (blocking) before
// the switch and unmapped after it.
// NOTE(review): unlike ClTensorHandle::CopyOutTo there is no S64 case here — confirm whether intended.
// NOTE(review): @p memory is assumed to hold at least one element per tensor element — not checked.
291 void CopyOutTo(void* memory) const override
292 {
293 const_cast<ClSubTensorHandle*>(this)->Map(true);
294 switch(this->GetDataType())
295 {
296 case arm_compute::DataType::F32:
297 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
298 static_cast<float*>(memory));
299 break;
300 case arm_compute::DataType::U8:
301 case arm_compute::DataType::QASYMM8:
302 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
303 static_cast<uint8_t*>(memory));
304 break;
305 case arm_compute::DataType::F16:
306 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
307 static_cast<armnn::Half*>(memory));
308 break;
309 case arm_compute::DataType::QSYMM8:
310 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
311 case arm_compute::DataType::QASYMM8_SIGNED:
312 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
313 static_cast<int8_t*>(memory));
314 break;
315 case arm_compute::DataType::S16:
316 case arm_compute::DataType::QSYMM16:
317 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
318 static_cast<int16_t*>(memory));
319 break;
320 case arm_compute::DataType::S32:
321 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
322 static_cast<int32_t*>(memory));
323 break;
// NOTE(review): the statement inside this default case (original line 326) is missing from
// this listing; the sibling CopyInFrom throws armnn::UnimplementedException there — confirm.
324 default:
325 {
327 }
328 }
329 const_cast<ClSubTensorHandle*>(this)->Unmap();
330 }
331
332 // Only used for testing
333 void CopyInFrom(const void* memory) override
334 {
335 this->Map(true);
336 switch(this->GetDataType())
337 {
338 case arm_compute::DataType::F32:
339 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
340 this->GetTensor());
341 break;
342 case arm_compute::DataType::U8:
343 case arm_compute::DataType::QASYMM8:
344 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
345 this->GetTensor());
346 break;
347 case arm_compute::DataType::F16:
348 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory),
349 this->GetTensor());
350 break;
351 case arm_compute::DataType::QSYMM8:
352 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
353 case arm_compute::DataType::QASYMM8_SIGNED:
354 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory),
355 this->GetTensor());
356 break;
357 case arm_compute::DataType::S16:
358 case arm_compute::DataType::QSYMM16:
359 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
360 this->GetTensor());
361 break;
362 case arm_compute::DataType::S32:
363 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
364 this->GetTensor());
365 break;
366 default:
367 {
368 throw armnn::UnimplementedException();
369 }
370 }
371 this->Unmap();
372 }
373
// View into the parent's CL tensor, built in the constructor from parent->GetTensor(),
// a shape and the first-element coordinates. Declared mutable; Map()/Unmap() use it from const methods.
374 mutable arm_compute::CLSubTensor m_Tensor;
// Handle this sub-tensor was created from (not deleted anywhere in this listing); returned by GetParent().
375 ITensorHandle* parentHandle = nullptr;
376};
377
378/** ClTensorDecorator wraps an existing CL tensor allowing us to override the TensorInfo for it */
// Member functions are only declared here; their definitions are not part of this listing.
379class ClTensorDecorator : public arm_compute::ICLTensor
380{
381public:
// NOTE(review): original lines 382 and 386-394 are missing from this listing. The file's member
// summary lists ClTensorDecorator(const ClTensorDecorator&) = delete,
// operator=(const ClTensorDecorator&) = delete, ClTensorDecorator(ClTensorDecorator&&) = default
// and operator=(ClTensorDecorator&&) = default, which presumably occupy some of those lines —
// confirm against the original header.
383
// Wraps @p original; presumably @p info is what info() reports afterwards — confirm in the definition.
384 ClTensorDecorator(arm_compute::ICLTensor* original, const TensorInfo& info);
385
387
389
391
393
395
// Presumably returns the wrapped tensor (m_Original) — confirm in the definition.
396 arm_compute::ICLTensor* parent();
397
// The non-const map/unmap overloads hide ICLTensor's overloads of the same name; the
// using-declarations bring the base-class (CommandQueue) overloads back into scope.
398 void map(bool blocking = true);
399 using arm_compute::ICLTensor::map;
400
401 void unmap();
402 using arm_compute::ICLTensor::unmap;
403
// Both info() overloads return a non-const ITensorInfo*; m_TensorInfo is mutable, presumably so
// the const overload can hand it out.
404 virtual arm_compute::ITensorInfo* info() const override;
405 virtual arm_compute::ITensorInfo* info() override;
406 const cl::Buffer& cl_buffer() const override;
407 arm_compute::CLQuantization quantization() const override;
408
409protected:
410 // Inherited methods overridden:
411 uint8_t* do_map(cl::CommandQueue& q, bool blocking) override;
412 void do_unmap(cl::CommandQueue& q) override;
413
414private:
// The wrapped tensor (ownership is not established by this listing).
415 arm_compute::ICLTensor* m_Original;
// The overriding tensor description.
416 mutable arm_compute::TensorInfo m_TensorInfo;
417};
418
420{
421public:
423 : m_Tensor(&parent->GetTensor(), info)
424 {
425 m_OriginalHandle = parent;
426 }
427
428 arm_compute::ICLTensor& GetTensor() override { return m_Tensor; }
429 arm_compute::ICLTensor const& GetTensor() const override { return m_Tensor; }
430
431 virtual void Allocate() override {}
432 virtual void Manage() override {}
433
434 virtual const void* Map(bool blocking = true) const override
435 {
436 m_Tensor.map(blocking);
437 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
438 }
439
440 virtual void Unmap() const override
441 {
442 m_Tensor.unmap();
443 }
444
445 virtual ITensorHandle* GetParent() const override { return nullptr; }
446
447 virtual arm_compute::DataType GetDataType() const override
448 {
449 return m_Tensor.info()->data_type();
450 }
451
452 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>&) override {}
453
454 TensorShape GetStrides() const override
455 {
456 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
457 }
458
459 TensorShape GetShape() const override
460 {
461 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
462 }
463
464private:
465 // Only used for testing
// Copies the decorated tensor's contents out to the host buffer @p memory via
// CopyArmComputeITensorData. The host element type is selected from GetDataType() (i.e. the
// decorator's info). The tensor is mapped (blocking) before the switch and unmapped after it.
// NOTE(review): unlike ClTensorHandle::CopyOutTo there is no S64 case here — confirm whether intended.
// NOTE(review): @p memory is assumed to hold at least one element per tensor element — not checked.
466 void CopyOutTo(void* memory) const override
467 {
468 const_cast<ClTensorHandleDecorator*>(this)->Map(true);
469 switch(this->GetDataType())
470 {
471 case arm_compute::DataType::F32:
472 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
473 static_cast<float*>(memory));
474 break;
475 case arm_compute::DataType::U8:
476 case arm_compute::DataType::QASYMM8:
477 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
478 static_cast<uint8_t*>(memory));
479 break;
480 case arm_compute::DataType::F16:
481 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
482 static_cast<armnn::Half*>(memory));
483 break;
484 case arm_compute::DataType::QSYMM8:
485 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
486 case arm_compute::DataType::QASYMM8_SIGNED:
487 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
488 static_cast<int8_t*>(memory));
489 break;
490 case arm_compute::DataType::S16:
491 case arm_compute::DataType::QSYMM16:
492 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
493 static_cast<int16_t*>(memory));
494 break;
495 case arm_compute::DataType::S32:
496 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
497 static_cast<int32_t*>(memory));
498 break;
// NOTE(review): the statement inside this default case (original line 501) is missing from
// this listing; the sibling CopyInFrom throws armnn::UnimplementedException there — confirm.
499 default:
500 {
502 }
503 }
504 const_cast<ClTensorHandleDecorator*>(this)->Unmap();
505 }
506
507 // Only used for testing
508 void CopyInFrom(const void* memory) override
509 {
510 this->Map(true);
511 switch(this->GetDataType())
512 {
513 case arm_compute::DataType::F32:
514 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
515 this->GetTensor());
516 break;
517 case arm_compute::DataType::U8:
518 case arm_compute::DataType::QASYMM8:
519 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
520 this->GetTensor());
521 break;
522 case arm_compute::DataType::F16:
523 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory),
524 this->GetTensor());
525 break;
526 case arm_compute::DataType::QSYMM8:
527 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
528 case arm_compute::DataType::QASYMM8_SIGNED:
529 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory),
530 this->GetTensor());
531 break;
532 case arm_compute::DataType::S16:
533 case arm_compute::DataType::QSYMM16:
534 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
535 this->GetTensor());
536 break;
537 case arm_compute::DataType::S32:
538 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
539 this->GetTensor());
540 break;
541 default:
542 {
543 throw armnn::UnimplementedException();
544 }
545 }
546 this->Unmap();
547 }
548
// Decorator wrapping the original handle's tensor with the overriding TensorInfo. It is mutable
// because the const Map()/Unmap() call its non-const map()/unmap() directly.
549 mutable ClTensorDecorator m_Tensor;
// The handle being decorated (not deleted anywhere in this listing). GetParent() returns nullptr, not this.
550 IClTensorHandle* m_OriginalHandle = nullptr;
551};
552
553} // namespace armnn
virtual void SetMemoryGroup(const std::shared_ptr< arm_compute::IMemoryGroup > &) override
virtual void Manage() override
Indicate to the memory manager that this resource is active.
arm_compute::CLSubTensor const & GetTensor() const override
virtual ITensorHandle * GetParent() const override
Get the parent tensor if this is a subtensor.
ClSubTensorHandle(IClTensorHandle *parent, const arm_compute::TensorShape &shape, const arm_compute::Coordinates &coords)
virtual void Unmap() const override
Unmap the tensor data.
TensorShape GetShape() const override
Get the number of elements for each dimension ordered from slowest iterating dimension to fastest iterating dimension.
TensorShape GetStrides() const override
Get the strides for each dimension ordered from largest to smallest where the smallest value is the same as the size of a single element in the tensor.
arm_compute::CLSubTensor & GetTensor() override
virtual void Allocate() override
Indicate to the memory manager that this resource is no longer active.
virtual arm_compute::DataType GetDataType() const override
virtual const void * Map(bool blocking=true) const override
Map the tensor data for access.
void map(bool blocking=true)
ClTensorDecorator & operator=(ClTensorDecorator &&)=default
uint8_t * do_map(cl::CommandQueue &q, bool blocking) override
void do_unmap(cl::CommandQueue &q) override
const cl::Buffer & cl_buffer() const override
arm_compute::CLQuantization quantization() const override
ClTensorDecorator(const ClTensorDecorator &)=delete
ClTensorDecorator & operator=(const ClTensorDecorator &)=delete
ClTensorDecorator(ClTensorDecorator &&)=default
arm_compute::ICLTensor * parent()
virtual arm_compute::ITensorInfo * info() const override
virtual void SetMemoryGroup(const std::shared_ptr< arm_compute::IMemoryGroup > &) override
virtual void Manage() override
Indicate to the memory manager that this resource is active.
arm_compute::ICLTensor & GetTensor() override
virtual ITensorHandle * GetParent() const override
Get the parent tensor if this is a subtensor.
ClTensorHandleDecorator(IClTensorHandle *parent, const TensorInfo &info)
virtual void Unmap() const override
Unmap the tensor data.
TensorShape GetShape() const override
Get the number of elements for each dimension ordered from slowest iterating dimension to fastest iterating dimension.
TensorShape GetStrides() const override
Get the strides for each dimension ordered from largest to smallest where the smallest value is the same as the size of a single element in the tensor.
arm_compute::ICLTensor const & GetTensor() const override
virtual void Allocate() override
Indicate to the memory manager that this resource is no longer active.
virtual arm_compute::DataType GetDataType() const override
virtual const void * Map(bool blocking=true) const override
Map the tensor data for access.
arm_compute::CLTensor const & GetTensor() const override
virtual void Manage() override
Indicate to the memory manager that this resource is active.
arm_compute::CLTensor & GetTensor() override
virtual bool Import(void *memory, MemorySource source) override
Import externally allocated memory.
virtual ITensorHandle * GetParent() const override
Get the parent tensor if this is a subtensor.
ClTensorHandle(const TensorInfo &tensorInfo)
virtual void Unmap() const override
Unmap the tensor data.
TensorShape GetShape() const override
Get the number of elements for each dimension ordered from slowest iterating dimension to fastest iterating dimension.
virtual std::shared_ptr< ITensorHandle > DecorateTensorHandle(const TensorInfo &tensorInfo) override
Returns a decorated version of this TensorHandle allowing us to override the TensorInfo for it.
TensorShape GetStrides() const override
Get the strides for each dimension ordered from largest to smallest where the smallest value is the same as the size of a single element in the tensor.
void SetImportEnabledFlag(bool importEnabledFlag)
MemorySourceFlags GetImportFlags() const override
Get flags describing supported import sources.
virtual bool CanBeImported(void *memory, MemorySource source) override
Implementations must determine if this memory block can be imported.
virtual void Allocate() override
Indicate to the memory manager that this resource is no longer active.
void SetImportFlags(MemorySourceFlags importFlags)
ClTensorHandle(const TensorInfo &tensorInfo, DataLayout dataLayout, MemorySourceFlags importFlags=static_cast< MemorySourceFlags >(MemorySource::Undefined))
virtual arm_compute::DataType GetDataType() const override
virtual const void * Map(bool blocking=true) const override
Map the tensor data for access.
virtual void SetMemoryGroup(const std::shared_ptr< arm_compute::IMemoryGroup > &memoryGroup) override
Copyright (c) 2021 ARM Limited and Contributors.
half_float::half Half
Definition Half.hpp:22
MemorySource
Define the Memory Source to reduce copies.
Definition Types.hpp:246
auto PolymorphicPointerDowncast(const SourceType &value)
Polymorphic downcast for shared pointers and built-in pointers.
unsigned int MemorySourceFlags
DataLayout
Definition Types.hpp:63
void IgnoreUnused(Ts &&...)