16 #include <arm_compute/runtime/MemoryGroup.h>
17 #include <arm_compute/runtime/IMemoryGroup.h>
18 #include <arm_compute/runtime/Tensor.h>
19 #include <arm_compute/runtime/SubTensor.h>
20 #include <arm_compute/core/TensorShape.h>
21 #include <arm_compute/core/Coordinates.h>
26 class NeonTensorHandleDecorator;
34 m_IsImportEnabled(false),
37 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo);
43 : m_ImportFlags(importFlags),
45 m_IsImportEnabled(false),
50 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo, dataLayout);
53 arm_compute::ITensor&
GetTensor()
override {
return m_Tensor; }
54 arm_compute::ITensor
const&
GetTensor()
const override {
return m_Tensor; }
59 if (!m_IsImportEnabled)
61 armnn::armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_Tensor);
68 if (!m_IsImportEnabled)
71 m_MemoryGroup->manage(&m_Tensor);
79 return m_Tensor.info()->data_type();
82 virtual void SetMemoryGroup(
const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup)
override
84 m_MemoryGroup = PolymorphicPointerDowncast<arm_compute::MemoryGroup>(memoryGroup);
87 virtual const void*
Map(
bool )
const override
89 return static_cast<const void*
>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
92 virtual void Unmap()
const override {}
96 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
101 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
106 m_ImportFlags = importFlags;
111 return m_ImportFlags;
116 m_IsImportEnabled = importEnabledFlag;
140 if (!m_Imported && !m_Tensor.buffer())
145 m_Imported = bool(status);
154 if (!m_Imported && m_Tensor.buffer())
157 "NeonTensorHandle::Import Attempting to import on an already allocated tensor");
166 m_Imported = bool(status);
190 void CopyOutTo(
void* memory)
const override
194 case arm_compute::DataType::F32:
195 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
196 static_cast<float*
>(memory));
198 case arm_compute::DataType::U8:
199 case arm_compute::DataType::QASYMM8:
200 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
201 static_cast<uint8_t*
>(memory));
203 case arm_compute::DataType::QSYMM8:
204 case arm_compute::DataType::QASYMM8_SIGNED:
205 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
206 static_cast<int8_t*
>(memory));
208 case arm_compute::DataType::BFLOAT16:
209 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
212 case arm_compute::DataType::F16:
213 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
216 case arm_compute::DataType::S16:
217 case arm_compute::DataType::QSYMM16:
218 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
219 static_cast<int16_t*
>(memory));
221 case arm_compute::DataType::S32:
222 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
223 static_cast<int32_t*
>(memory));
233 void CopyInFrom(
const void* memory)
override
237 case arm_compute::DataType::F32:
238 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const float*
>(memory),
241 case arm_compute::DataType::U8:
242 case arm_compute::DataType::QASYMM8:
243 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const uint8_t*
>(memory),
246 case arm_compute::DataType::QSYMM8:
247 case arm_compute::DataType::QASYMM8_SIGNED:
248 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
249 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const int8_t*
>(memory),
252 case arm_compute::DataType::BFLOAT16:
253 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const armnn::BFloat16*
>(memory),
256 case arm_compute::DataType::F16:
257 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const armnn::Half*
>(memory),
260 case arm_compute::DataType::S16:
261 case arm_compute::DataType::QSYMM16:
262 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const int16_t*
>(memory),
265 case arm_compute::DataType::S32:
266 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const int32_t*
>(memory),
276 arm_compute::Tensor m_Tensor;
277 std::shared_ptr<arm_compute::MemoryGroup> m_MemoryGroup;
280 bool m_IsImportEnabled;
281 const uintptr_t m_TypeAlignment;
282 std::vector<std::shared_ptr<NeonTensorHandleDecorator>> m_Decorated;
289 const arm_compute::TensorShape& shape,
291 : m_Tensor(&parent->
GetTensor(), shape, coords, true)
293 parentHandle = parent;
296 arm_compute::ITensor&
GetTensor()
override {
return m_Tensor; }
297 arm_compute::ITensor
const&
GetTensor()
const override {
return m_Tensor; }
306 return m_Tensor.info()->data_type();
309 virtual void SetMemoryGroup(
const std::shared_ptr<arm_compute::IMemoryGroup>&)
override {}
311 virtual const void*
Map(
bool )
const override
313 return static_cast<const void*
>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
315 virtual void Unmap()
const override {}
319 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
324 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
334 void CopyOutTo(
void* memory)
const override
338 case arm_compute::DataType::F32:
339 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
340 static_cast<float*
>(memory));
342 case arm_compute::DataType::U8:
343 case arm_compute::DataType::QASYMM8:
344 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
345 static_cast<uint8_t*
>(memory));
347 case arm_compute::DataType::QSYMM8:
348 case arm_compute::DataType::QASYMM8_SIGNED:
349 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
350 static_cast<int8_t*
>(memory));
352 case arm_compute::DataType::S16:
353 case arm_compute::DataType::QSYMM16:
354 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
355 static_cast<int16_t*
>(memory));
357 case arm_compute::DataType::S32:
358 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
359 static_cast<int32_t*
>(memory));
369 void CopyInFrom(
const void* memory)
override
373 case arm_compute::DataType::F32:
374 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const float*
>(memory),
377 case arm_compute::DataType::U8:
378 case arm_compute::DataType::QASYMM8:
379 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const uint8_t*
>(memory),
382 case arm_compute::DataType::QSYMM8:
383 case arm_compute::DataType::QASYMM8_SIGNED:
384 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const int8_t*
>(memory),
387 case arm_compute::DataType::S16:
388 case arm_compute::DataType::QSYMM16:
389 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const int16_t*
>(memory),
392 case arm_compute::DataType::S32:
393 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const int32_t*
>(memory),
403 arm_compute::SubTensor m_Tensor;
404 ITensorHandle* parentHandle =
nullptr;
426 arm_compute::ITensorInfo*
info()
const override;
428 arm_compute::ITensorInfo*
info()
override;
430 uint8_t*
buffer()
const override;
433 arm_compute::ITensor* m_Original;
434 mutable arm_compute::TensorInfo m_TensorInfo;
443 parentHandle = parent;
446 arm_compute::ITensor&
GetTensor()
override {
return m_Tensor; }
447 arm_compute::ITensor
const&
GetTensor()
const override {
return m_Tensor; }
456 return m_Tensor.
info()->data_type();
459 virtual void SetMemoryGroup(
const std::shared_ptr<arm_compute::IMemoryGroup>&)
override {}
461 virtual const void*
Map(
bool )
const override
463 return static_cast<const void*
>(m_Tensor.
buffer() + m_Tensor.
info()->offset_first_element_in_bytes());
465 virtual void Unmap()
const override {}
469 return armcomputetensorutils::GetStrides(m_Tensor.
info()->strides_in_bytes());
474 return armcomputetensorutils::GetShape(m_Tensor.
info()->tensor_shape());
484 void CopyOutTo(
void* memory)
const override
488 case arm_compute::DataType::F32:
489 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
490 static_cast<float*
>(memory));
492 case arm_compute::DataType::U8:
493 case arm_compute::DataType::QASYMM8:
494 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
495 static_cast<uint8_t*
>(memory));
497 case arm_compute::DataType::QSYMM8:
498 case arm_compute::DataType::QASYMM8_SIGNED:
499 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
500 static_cast<int8_t*
>(memory));
502 case arm_compute::DataType::S16:
503 case arm_compute::DataType::QSYMM16:
504 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
505 static_cast<int16_t*
>(memory));
507 case arm_compute::DataType::S32:
508 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
509 static_cast<int32_t*
>(memory));
519 void CopyInFrom(
const void* memory)
override
523 case arm_compute::DataType::F32:
524 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const float*
>(memory),
527 case arm_compute::DataType::U8:
528 case arm_compute::DataType::QASYMM8:
529 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const uint8_t*
>(memory),
532 case arm_compute::DataType::QSYMM8:
533 case arm_compute::DataType::QASYMM8_SIGNED:
534 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const int8_t*
>(memory),
537 case arm_compute::DataType::S16:
538 case arm_compute::DataType::QSYMM16:
539 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const int16_t*
>(memory),
542 case arm_compute::DataType::S32:
543 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const int32_t*
>(memory),
553 NeonTensorDecorator m_Tensor;
554 ITensorHandle* parentHandle =
nullptr;