ArmNN 25.02
NeonTensorHandle.hpp
//
// Copyright © 2017-2024 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include <BFloat16.hpp>
#include <Half.hpp>

#include <aclCommon/ArmComputeTensorHandle.hpp>
#include <aclCommon/ArmComputeTensorUtils.hpp>
#include <armnn/Exceptions.hpp>
#include <armnn/utility/PolymorphicDowncast.hpp>

#include <arm_compute/runtime/MemoryGroup.h>
#include <arm_compute/runtime/IMemoryGroup.h>
#include <arm_compute/runtime/Tensor.h>
#include <arm_compute/runtime/SubTensor.h>
#include <arm_compute/core/TensorShape.h>
#include <arm_compute/core/Coordinates.h>
#include "armnn/TypesUtils.hpp"

namespace armnn
{
class NeonTensorHandleDecorator;

class NeonTensorHandle : public IAclTensorHandle
{
public:
    NeonTensorHandle(const TensorInfo& tensorInfo)
        : m_ImportFlags(static_cast<MemorySourceFlags>(MemorySource::Malloc)),
          m_Imported(false),
          m_IsImportEnabled(false),
          m_TypeAlignment(GetDataTypeSize(tensorInfo.GetDataType()))
    {
        armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo);
    }

    NeonTensorHandle(const TensorInfo& tensorInfo,
                     DataLayout dataLayout,
                     MemorySourceFlags importFlags = static_cast<MemorySourceFlags>(MemorySource::Malloc))
        : m_ImportFlags(importFlags),
          m_Imported(false),
          m_IsImportEnabled(false),
          m_TypeAlignment(GetDataTypeSize(tensorInfo.GetDataType()))
    {
        armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo, dataLayout);
    }

    arm_compute::ITensor& GetTensor() override { return m_Tensor; }
    arm_compute::ITensor const& GetTensor() const override { return m_Tensor; }

    virtual void Allocate() override
    {
        // If we have enabled Importing, don't Allocate the tensor
        if (!m_IsImportEnabled)
        {
            armnn::armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_Tensor);
        }
    }

    virtual void Manage() override
    {
        // If we have enabled Importing, don't manage the tensor
        if (!m_IsImportEnabled)
        {
            ARMNN_THROW_INVALIDARG_MSG_IF_FALSE(m_MemoryGroup, "arm_compute::MemoryGroup is null.");
            m_MemoryGroup->manage(&m_Tensor);
        }
    }

    virtual ITensorHandle* GetParent() const override { return nullptr; }

    virtual arm_compute::DataType GetDataType() const override
    {
        return m_Tensor.info()->data_type();
    }

    virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) override
    {
        m_MemoryGroup = PolymorphicPointerDowncast<arm_compute::MemoryGroup>(memoryGroup);
    }

    virtual const void* Map(bool /* blocking = true */) const override
    {
        return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
    }

    virtual void Unmap() const override {}

    TensorShape GetStrides() const override
    {
        return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
    }

    TensorShape GetShape() const override
    {
        return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
    }

    void SetImportFlags(MemorySourceFlags importFlags)
    {
        m_ImportFlags = importFlags;
    }

    MemorySourceFlags GetImportFlags() const override
    {
        return m_ImportFlags;
    }

    void SetImportEnabledFlag(bool importEnabledFlag)
    {
        m_IsImportEnabled = importEnabledFlag;
    }

    bool CanBeImported(void* memory, MemorySource source) override
    {
        if (source != MemorySource::Malloc || reinterpret_cast<uintptr_t>(memory) % m_TypeAlignment)
        {
            return false;
        }
        return true;
    }

    virtual bool Import(void* memory, MemorySource source) override
    {
        if (m_ImportFlags & static_cast<MemorySourceFlags>(source))
        {
            if (source == MemorySource::Malloc && m_IsImportEnabled)
            {
                if (!CanBeImported(memory, source))
                {
                    throw MemoryImportException("NeonTensorHandle::Import Attempting to import unaligned memory");
                }

                // m_Tensor not yet Allocated
                if (!m_Imported && !m_Tensor.buffer())
                {
                    arm_compute::Status status = m_Tensor.allocator()->import_memory(memory);
                    // Use the overloaded bool operator of Status to check if it worked, if not throw an exception
                    // with the Status error message
                    m_Imported = bool(status);
                    if (!m_Imported)
                    {
                        throw MemoryImportException(status.error_description());
                    }
                    return m_Imported;
                }

                // m_Tensor.buffer() initially allocated with Allocate().
                if (!m_Imported && m_Tensor.buffer())
                {
                    throw MemoryImportException(
                        "NeonTensorHandle::Import Attempting to import on an already allocated tensor");
                }

                // m_Tensor.buffer() previously imported.
                if (m_Imported)
                {
                    arm_compute::Status status = m_Tensor.allocator()->import_memory(memory);
                    // Use the overloaded bool operator of Status to check if it worked, if not throw an exception
                    // with the Status error message
                    m_Imported = bool(status);
                    if (!m_Imported)
                    {
                        throw MemoryImportException(status.error_description());
                    }
                    return m_Imported;
                }
            }
            else
            {
                throw MemoryImportException("NeonTensorHandle::Import is disabled");
            }
        }
        else
        {
            throw MemoryImportException("NeonTensorHandle::Incorrect import flag");
        }
        return false;
    }

    virtual std::shared_ptr<ITensorHandle> DecorateTensorHandle(const TensorInfo& tensorInfo) override;

private:
    // Only used for testing
    void CopyOutTo(void* memory) const override
    {
        switch (this->GetDataType())
        {
            case arm_compute::DataType::F32:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<float*>(memory));
                break;
            case arm_compute::DataType::U8:
            case arm_compute::DataType::QASYMM8:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<uint8_t*>(memory));
                break;
            case arm_compute::DataType::QSYMM8:
            case arm_compute::DataType::QASYMM8_SIGNED:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<int8_t*>(memory));
                break;
            case arm_compute::DataType::BFLOAT16:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<armnn::BFloat16*>(memory));
                break;
            case arm_compute::DataType::F16:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<armnn::Half*>(memory));
                break;
            case arm_compute::DataType::S16:
            case arm_compute::DataType::QSYMM16:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<int16_t*>(memory));
                break;
            case arm_compute::DataType::S32:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<int32_t*>(memory));
                break;
            case arm_compute::DataType::S64:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<int64_t*>(memory));
                break;
            default:
            {
                throw armnn::UnimplementedException();
            }
        }
    }

    // Only used for testing
    void CopyInFrom(const void* memory) override
    {
        switch (this->GetDataType())
        {
            case arm_compute::DataType::F32:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
                                                                 this->GetTensor());
                break;
            case arm_compute::DataType::U8:
            case arm_compute::DataType::QASYMM8:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
                                                                 this->GetTensor());
                break;
            case arm_compute::DataType::QSYMM8:
            case arm_compute::DataType::QASYMM8_SIGNED:
            case arm_compute::DataType::QSYMM8_PER_CHANNEL:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory),
                                                                 this->GetTensor());
                break;
            case arm_compute::DataType::BFLOAT16:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::BFloat16*>(memory),
                                                                 this->GetTensor());
                break;
            case arm_compute::DataType::F16:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory),
                                                                 this->GetTensor());
                break;
            case arm_compute::DataType::S16:
            case arm_compute::DataType::QSYMM16:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
                                                                 this->GetTensor());
                break;
            case arm_compute::DataType::S32:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
                                                                 this->GetTensor());
                break;
            case arm_compute::DataType::S64:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int64_t*>(memory),
                                                                 this->GetTensor());
                break;
            default:
            {
                throw armnn::UnimplementedException();
            }
        }
    }

    arm_compute::Tensor m_Tensor;
    std::shared_ptr<arm_compute::MemoryGroup> m_MemoryGroup;
    MemorySourceFlags m_ImportFlags;
    bool m_Imported;
    bool m_IsImportEnabled;
    const uintptr_t m_TypeAlignment;
    std::vector<std::shared_ptr<NeonTensorHandleDecorator>> m_Decorated;
};
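
// Illustrative usage sketch (editor's addition, not part of the original header): how a caller
// might zero-copy import an externally allocated buffer into a NeonTensorHandle. Import only
// succeeds for MemorySource::Malloc, with importing enabled and the pointer aligned to the
// element size; otherwise Allocate()/Manage() fall back to memory owned by the handle.
// The function name and arguments below are hypothetical.
inline bool ExampleImportExternalBuffer(const TensorInfo& tensorInfo, void* alignedBuffer)
{
    NeonTensorHandle handle(tensorInfo);
    handle.SetImportEnabledFlag(true);                       // skip internal allocation paths
    if (!handle.CanBeImported(alignedBuffer, MemorySource::Malloc))
    {
        return false;                                        // e.g. misaligned pointer
    }
    bool imported = handle.Import(alignedBuffer, MemorySource::Malloc);  // wraps the buffer, no copy
    const void* mapped = handle.Map(true);                   // points into the imported buffer
    return imported && mapped != nullptr;
}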

class NeonSubTensorHandle : public IAclTensorHandle
{
public:
    NeonSubTensorHandle(IAclTensorHandle* parent,
                        const arm_compute::TensorShape& shape,
                        const arm_compute::Coordinates& coords)
        : m_Tensor(&parent->GetTensor(), shape, coords, true)
    {
        parentHandle = parent;
    }

    arm_compute::ITensor& GetTensor() override { return m_Tensor; }
    arm_compute::ITensor const& GetTensor() const override { return m_Tensor; }

    virtual void Allocate() override {}
    virtual void Manage() override {}

    virtual ITensorHandle* GetParent() const override { return parentHandle; }

    virtual arm_compute::DataType GetDataType() const override
    {
        return m_Tensor.info()->data_type();
    }

    virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>&) override {}

    virtual const void* Map(bool /* blocking = true */) const override
    {
        return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
    }
    virtual void Unmap() const override {}

    TensorShape GetStrides() const override
    {
        return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
    }

    TensorShape GetShape() const override
    {
        return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
    }

    virtual std::shared_ptr<ITensorHandle> DecorateTensorHandle(const TensorInfo&) override
    {
        return nullptr;
    }

private:
    // Only used for testing
    void CopyOutTo(void* memory) const override
    {
        switch (this->GetDataType())
        {
            case arm_compute::DataType::F32:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<float*>(memory));
                break;
            case arm_compute::DataType::U8:
            case arm_compute::DataType::QASYMM8:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<uint8_t*>(memory));
                break;
            case arm_compute::DataType::QSYMM8:
            case arm_compute::DataType::QASYMM8_SIGNED:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<int8_t*>(memory));
                break;
            case arm_compute::DataType::S16:
            case arm_compute::DataType::QSYMM16:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<int16_t*>(memory));
                break;
            case arm_compute::DataType::S32:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<int32_t*>(memory));
                break;
            default:
            {
                throw armnn::UnimplementedException();
            }
        }
    }

    // Only used for testing
    void CopyInFrom(const void* memory) override
    {
        switch (this->GetDataType())
        {
            case arm_compute::DataType::F32:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
                                                                 this->GetTensor());
                break;
            case arm_compute::DataType::U8:
            case arm_compute::DataType::QASYMM8:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
                                                                 this->GetTensor());
                break;
            case arm_compute::DataType::QSYMM8:
            case arm_compute::DataType::QASYMM8_SIGNED:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory),
                                                                 this->GetTensor());
                break;
            case arm_compute::DataType::S16:
            case arm_compute::DataType::QSYMM16:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
                                                                 this->GetTensor());
                break;
            case arm_compute::DataType::S32:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
                                                                 this->GetTensor());
                break;
            default:
            {
                throw armnn::UnimplementedException();
            }
        }
    }

    arm_compute::SubTensor m_Tensor;
    ITensorHandle* parentHandle = nullptr;
};
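
// Illustrative usage sketch (editor's addition, not part of the original header): creating a
// zero-copy view over part of a parent tensor, as splitter/concat style workloads do. The
// function name and the shape/offset values are hypothetical examples.
inline void ExampleSubTensorView(IAclTensorHandle* parent)
{
    arm_compute::TensorShape subShape(8U, 8U);   // extents of the sub-region
    arm_compute::Coordinates offset(0, 0);       // position of the sub-region within the parent
    NeonSubTensorHandle subView(parent, subShape, offset);

    // The sub-tensor aliases the parent's storage, so Allocate()/Manage() are deliberate no-ops
    // and Map() returns a pointer into the parent's buffer at the sub-region's offset
    // (this assumes the parent tensor has already been allocated or imported).
    const void* data = subView.Map(true);
    (void) data;
}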

/// NeonTensorDecorator wraps an existing Neon tensor allowing us to override the TensorInfo for it
class NeonTensorDecorator : public arm_compute::ITensor
{
public:
    NeonTensorDecorator();

    NeonTensorDecorator(arm_compute::ITensor* original, const TensorInfo& info);

    ~NeonTensorDecorator() = default;

    NeonTensorDecorator(const NeonTensorDecorator&) = delete;

    NeonTensorDecorator& operator=(const NeonTensorDecorator&) = delete;

    NeonTensorDecorator(NeonTensorDecorator&&) = default;

    NeonTensorDecorator& operator=(NeonTensorDecorator&&) = default;

    // Inherited methods overridden:
    arm_compute::ITensorInfo* info() const override;

    arm_compute::ITensorInfo* info() override;

    uint8_t* buffer() const override;

private:
    arm_compute::ITensor* m_Original;
    mutable arm_compute::TensorInfo m_TensorInfo;
};

class NeonTensorHandleDecorator : public IAclTensorHandle
{
public:
    NeonTensorHandleDecorator(IAclTensorHandle* parent, const TensorInfo& info)
        : m_Tensor(&parent->GetTensor(), info)
    {
        parentHandle = parent;
    }

    arm_compute::ITensor& GetTensor() override { return m_Tensor; }
    arm_compute::ITensor const& GetTensor() const override { return m_Tensor; }

    virtual void Allocate() override {}
    virtual void Manage() override {}

    virtual ITensorHandle* GetParent() const override { return nullptr; }

    virtual arm_compute::DataType GetDataType() const override
    {
        return m_Tensor.info()->data_type();
    }

    virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>&) override {}

    virtual const void* Map(bool /* blocking = true */) const override
    {
        return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
    }
    virtual void Unmap() const override {}

    TensorShape GetStrides() const override
    {
        return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
    }

    TensorShape GetShape() const override
    {
        return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
    }

    virtual std::shared_ptr<ITensorHandle> DecorateTensorHandle(const TensorInfo&) override
    {
        return nullptr;
    }

private:
    // Only used for testing
    void CopyOutTo(void* memory) const override
    {
        switch (this->GetDataType())
        {
            case arm_compute::DataType::F32:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<float*>(memory));
                break;
            case arm_compute::DataType::U8:
            case arm_compute::DataType::QASYMM8:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<uint8_t*>(memory));
                break;
            case arm_compute::DataType::QSYMM8:
            case arm_compute::DataType::QASYMM8_SIGNED:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<int8_t*>(memory));
                break;
            case arm_compute::DataType::S16:
            case arm_compute::DataType::QSYMM16:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<int16_t*>(memory));
                break;
            case arm_compute::DataType::S32:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<int32_t*>(memory));
                break;
            default:
            {
                throw armnn::UnimplementedException();
            }
        }
    }

    // Only used for testing
    void CopyInFrom(const void* memory) override
    {
        switch (this->GetDataType())
        {
            case arm_compute::DataType::F32:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
                                                                 this->GetTensor());
                break;
            case arm_compute::DataType::U8:
            case arm_compute::DataType::QASYMM8:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
                                                                 this->GetTensor());
                break;
            case arm_compute::DataType::QSYMM8:
            case arm_compute::DataType::QASYMM8_SIGNED:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory),
                                                                 this->GetTensor());
                break;
            case arm_compute::DataType::S16:
            case arm_compute::DataType::QSYMM16:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
                                                                 this->GetTensor());
                break;
            case arm_compute::DataType::S32:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
                                                                 this->GetTensor());
                break;
            default:
            {
                throw armnn::UnimplementedException();
            }
        }
    }

    NeonTensorDecorator m_Tensor;
    ITensorHandle* parentHandle = nullptr;
};
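
// Illustrative usage sketch (editor's addition, not part of the original header): reinterpreting
// an existing handle's data under a different TensorInfo. NeonTensorHandle::DecorateTensorHandle()
// returns a decorated handle that aliases the parent's buffer while reporting the overridden
// TensorInfo; the parent also records it in m_Decorated, which presumably ties the decorator's
// lifetime to the parent handle. The function and variable names are hypothetical.
inline std::shared_ptr<ITensorHandle> ExampleDecorateHandle(NeonTensorHandle& parent,
                                                            const TensorInfo& overriddenInfo)
{
    return parent.DecorateTensorHandle(overriddenInfo);
}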

} // namespace armnn