16 #include <arm_compute/runtime/MemoryGroup.h>
17 #include <arm_compute/runtime/IMemoryGroup.h>
18 #include <arm_compute/runtime/Tensor.h>
19 #include <arm_compute/runtime/SubTensor.h>
20 #include <arm_compute/core/TensorShape.h>
21 #include <arm_compute/core/Coordinates.h>
26 class NeonTensorHandleDecorator;
34 m_IsImportEnabled(false),
37 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo);
43 : m_ImportFlags(importFlags),
45 m_IsImportEnabled(false),
50 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo, dataLayout);
53 arm_compute::ITensor&
GetTensor()
override {
return m_Tensor; }
54 arm_compute::ITensor
const&
GetTensor()
const override {
return m_Tensor; }
59 if (!m_IsImportEnabled)
61 armnn::armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_Tensor);
68 if (!m_IsImportEnabled)
71 m_MemoryGroup->manage(&m_Tensor);
79 return m_Tensor.info()->data_type();
82 virtual void SetMemoryGroup(
const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup)
override
84 m_MemoryGroup = PolymorphicPointerDowncast<arm_compute::MemoryGroup>(memoryGroup);
87 virtual const void*
Map(
bool )
const override
89 return static_cast<const void*
>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
92 virtual void Unmap()
const override {}
96 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
101 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
106 m_ImportFlags = importFlags;
111 return m_ImportFlags;
116 m_IsImportEnabled = importEnabledFlag;
140 if (!m_Imported && !m_Tensor.buffer())
145 m_Imported = bool(status);
154 if (!m_Imported && m_Tensor.buffer())
157 "NeonTensorHandle::Import Attempting to import on an already allocated tensor");
166 m_Imported = bool(status);
190 void CopyOutTo(
void* memory)
const override
194 case arm_compute::DataType::F32:
195 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
196 static_cast<float*
>(memory));
198 case arm_compute::DataType::U8:
199 case arm_compute::DataType::QASYMM8:
200 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
201 static_cast<uint8_t*
>(memory));
203 case arm_compute::DataType::QSYMM8:
204 case arm_compute::DataType::QASYMM8_SIGNED:
205 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
206 static_cast<int8_t*
>(memory));
208 case arm_compute::DataType::BFLOAT16:
209 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
212 case arm_compute::DataType::F16:
213 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
216 case arm_compute::DataType::S16:
217 case arm_compute::DataType::QSYMM16:
218 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
219 static_cast<int16_t*
>(memory));
221 case arm_compute::DataType::S32:
222 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
223 static_cast<int32_t*
>(memory));
225 case arm_compute::DataType::S64:
226 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
227 static_cast<int64_t*
>(memory));
237 void CopyInFrom(
const void* memory)
override
241 case arm_compute::DataType::F32:
242 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const float*
>(memory),
245 case arm_compute::DataType::U8:
246 case arm_compute::DataType::QASYMM8:
247 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const uint8_t*
>(memory),
250 case arm_compute::DataType::QSYMM8:
251 case arm_compute::DataType::QASYMM8_SIGNED:
252 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
253 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const int8_t*
>(memory),
256 case arm_compute::DataType::BFLOAT16:
257 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const armnn::BFloat16*
>(memory),
260 case arm_compute::DataType::F16:
261 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const armnn::Half*
>(memory),
264 case arm_compute::DataType::S16:
265 case arm_compute::DataType::QSYMM16:
266 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const int16_t*
>(memory),
269 case arm_compute::DataType::S32:
270 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const int32_t*
>(memory),
273 case arm_compute::DataType::S64:
274 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const int64_t*
>(memory),
284 arm_compute::Tensor m_Tensor;
285 std::shared_ptr<arm_compute::MemoryGroup> m_MemoryGroup;
288 bool m_IsImportEnabled;
289 const uintptr_t m_TypeAlignment;
290 std::vector<std::shared_ptr<NeonTensorHandleDecorator>> m_Decorated;
297 const arm_compute::TensorShape& shape,
299 : m_Tensor(&parent->
GetTensor(), shape, coords, true)
301 parentHandle = parent;
304 arm_compute::ITensor&
GetTensor()
override {
return m_Tensor; }
305 arm_compute::ITensor
const&
GetTensor()
const override {
return m_Tensor; }
314 return m_Tensor.info()->data_type();
317 virtual void SetMemoryGroup(
const std::shared_ptr<arm_compute::IMemoryGroup>&)
override {}
319 virtual const void*
Map(
bool )
const override
321 return static_cast<const void*
>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
323 virtual void Unmap()
const override {}
327 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
332 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
342 void CopyOutTo(
void* memory)
const override
346 case arm_compute::DataType::F32:
347 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
348 static_cast<float*
>(memory));
350 case arm_compute::DataType::U8:
351 case arm_compute::DataType::QASYMM8:
352 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
353 static_cast<uint8_t*
>(memory));
355 case arm_compute::DataType::QSYMM8:
356 case arm_compute::DataType::QASYMM8_SIGNED:
357 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
358 static_cast<int8_t*
>(memory));
360 case arm_compute::DataType::S16:
361 case arm_compute::DataType::QSYMM16:
362 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
363 static_cast<int16_t*
>(memory));
365 case arm_compute::DataType::S32:
366 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
367 static_cast<int32_t*
>(memory));
377 void CopyInFrom(
const void* memory)
override
381 case arm_compute::DataType::F32:
382 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const float*
>(memory),
385 case arm_compute::DataType::U8:
386 case arm_compute::DataType::QASYMM8:
387 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const uint8_t*
>(memory),
390 case arm_compute::DataType::QSYMM8:
391 case arm_compute::DataType::QASYMM8_SIGNED:
392 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const int8_t*
>(memory),
395 case arm_compute::DataType::S16:
396 case arm_compute::DataType::QSYMM16:
397 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const int16_t*
>(memory),
400 case arm_compute::DataType::S32:
401 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const int32_t*
>(memory),
411 arm_compute::SubTensor m_Tensor;
412 ITensorHandle* parentHandle =
nullptr;
434 arm_compute::ITensorInfo*
info()
const override;
436 arm_compute::ITensorInfo*
info()
override;
438 uint8_t*
buffer()
const override;
441 arm_compute::ITensor* m_Original;
442 mutable arm_compute::TensorInfo m_TensorInfo;
451 parentHandle = parent;
454 arm_compute::ITensor&
GetTensor()
override {
return m_Tensor; }
455 arm_compute::ITensor
const&
GetTensor()
const override {
return m_Tensor; }
464 return m_Tensor.
info()->data_type();
467 virtual void SetMemoryGroup(
const std::shared_ptr<arm_compute::IMemoryGroup>&)
override {}
469 virtual const void*
Map(
bool )
const override
471 return static_cast<const void*
>(m_Tensor.
buffer() + m_Tensor.
info()->offset_first_element_in_bytes());
473 virtual void Unmap()
const override {}
477 return armcomputetensorutils::GetStrides(m_Tensor.
info()->strides_in_bytes());
482 return armcomputetensorutils::GetShape(m_Tensor.
info()->tensor_shape());
492 void CopyOutTo(
void* memory)
const override
496 case arm_compute::DataType::F32:
497 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
498 static_cast<float*
>(memory));
500 case arm_compute::DataType::U8:
501 case arm_compute::DataType::QASYMM8:
502 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
503 static_cast<uint8_t*
>(memory));
505 case arm_compute::DataType::QSYMM8:
506 case arm_compute::DataType::QASYMM8_SIGNED:
507 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
508 static_cast<int8_t*
>(memory));
510 case arm_compute::DataType::S16:
511 case arm_compute::DataType::QSYMM16:
512 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
513 static_cast<int16_t*
>(memory));
515 case arm_compute::DataType::S32:
516 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
517 static_cast<int32_t*
>(memory));
527 void CopyInFrom(
const void* memory)
override
531 case arm_compute::DataType::F32:
532 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const float*
>(memory),
535 case arm_compute::DataType::U8:
536 case arm_compute::DataType::QASYMM8:
537 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const uint8_t*
>(memory),
540 case arm_compute::DataType::QSYMM8:
541 case arm_compute::DataType::QASYMM8_SIGNED:
542 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const int8_t*
>(memory),
545 case arm_compute::DataType::S16:
546 case arm_compute::DataType::QSYMM16:
547 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const int16_t*
>(memory),
550 case arm_compute::DataType::S32:
551 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const int32_t*
>(memory),
561 NeonTensorDecorator m_Tensor;
562 ITensorHandle* parentHandle =
nullptr;