17 #include <arm_compute/runtime/MemoryGroup.h>
18 #include <arm_compute/runtime/IMemoryGroup.h>
19 #include <arm_compute/runtime/Tensor.h>
20 #include <arm_compute/runtime/SubTensor.h>
21 #include <arm_compute/core/TensorShape.h>
22 #include <arm_compute/core/Coordinates.h>
27 class NeonTensorHandleDecorator;
35 m_IsImportEnabled(false),
38 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo);
44 : m_ImportFlags(importFlags),
46 m_IsImportEnabled(false),
51 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo, dataLayout);
54 arm_compute::ITensor&
GetTensor()
override {
return m_Tensor; }
55 arm_compute::ITensor
const&
GetTensor()
const override {
return m_Tensor; }
60 if (!m_IsImportEnabled)
62 armnn::armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_Tensor);
69 if (!m_IsImportEnabled)
72 m_MemoryGroup->manage(&m_Tensor);
80 return m_Tensor.info()->data_type();
83 virtual void SetMemoryGroup(
const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup)
override
85 m_MemoryGroup = PolymorphicPointerDowncast<arm_compute::MemoryGroup>(memoryGroup);
88 virtual const void*
Map(
bool )
const override
90 return static_cast<const void*
>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
93 virtual void Unmap()
const override {}
97 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
102 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
107 m_ImportFlags = importFlags;
112 return m_ImportFlags;
117 m_IsImportEnabled = importEnabledFlag;
141 if (!m_Imported && !m_Tensor.buffer())
146 m_Imported = bool(status);
155 if (!m_Imported && m_Tensor.buffer())
158 "NeonTensorHandle::Import Attempting to import on an already allocated tensor");
167 m_Imported = bool(status);
191 void CopyOutTo(
void* memory)
const override
195 case arm_compute::DataType::F32:
196 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
197 static_cast<float*
>(memory));
199 case arm_compute::DataType::U8:
200 case arm_compute::DataType::QASYMM8:
201 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
202 static_cast<uint8_t*
>(memory));
204 case arm_compute::DataType::QSYMM8:
205 case arm_compute::DataType::QASYMM8_SIGNED:
206 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
207 static_cast<int8_t*
>(memory));
209 case arm_compute::DataType::BFLOAT16:
210 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
213 case arm_compute::DataType::F16:
214 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
217 case arm_compute::DataType::S16:
218 case arm_compute::DataType::QSYMM16:
219 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
220 static_cast<int16_t*
>(memory));
222 case arm_compute::DataType::S32:
223 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
224 static_cast<int32_t*
>(memory));
234 void CopyInFrom(
const void* memory)
override
238 case arm_compute::DataType::F32:
239 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const float*
>(memory),
242 case arm_compute::DataType::U8:
243 case arm_compute::DataType::QASYMM8:
244 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const uint8_t*
>(memory),
247 case arm_compute::DataType::QSYMM8:
248 case arm_compute::DataType::QASYMM8_SIGNED:
249 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
250 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const int8_t*
>(memory),
253 case arm_compute::DataType::BFLOAT16:
254 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const armnn::BFloat16*
>(memory),
257 case arm_compute::DataType::F16:
258 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const armnn::Half*
>(memory),
261 case arm_compute::DataType::S16:
262 case arm_compute::DataType::QSYMM16:
263 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const int16_t*
>(memory),
266 case arm_compute::DataType::S32:
267 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const int32_t*
>(memory),
277 arm_compute::Tensor m_Tensor;
278 std::shared_ptr<arm_compute::MemoryGroup> m_MemoryGroup;
281 bool m_IsImportEnabled;
282 const uintptr_t m_TypeAlignment;
283 std::vector<std::shared_ptr<NeonTensorHandleDecorator>> m_Decorated;
290 const arm_compute::TensorShape& shape,
292 : m_Tensor(&parent->
GetTensor(), shape, coords, true)
294 parentHandle = parent;
297 arm_compute::ITensor&
GetTensor()
override {
return m_Tensor; }
298 arm_compute::ITensor
const&
GetTensor()
const override {
return m_Tensor; }
307 return m_Tensor.info()->data_type();
310 virtual void SetMemoryGroup(
const std::shared_ptr<arm_compute::IMemoryGroup>&)
override {}
312 virtual const void*
Map(
bool )
const override
314 return static_cast<const void*
>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
316 virtual void Unmap()
const override {}
320 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
325 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
335 void CopyOutTo(
void* memory)
const override
339 case arm_compute::DataType::F32:
340 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
341 static_cast<float*
>(memory));
343 case arm_compute::DataType::U8:
344 case arm_compute::DataType::QASYMM8:
345 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
346 static_cast<uint8_t*
>(memory));
348 case arm_compute::DataType::QSYMM8:
349 case arm_compute::DataType::QASYMM8_SIGNED:
350 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
351 static_cast<int8_t*
>(memory));
353 case arm_compute::DataType::S16:
354 case arm_compute::DataType::QSYMM16:
355 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
356 static_cast<int16_t*
>(memory));
358 case arm_compute::DataType::S32:
359 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
360 static_cast<int32_t*
>(memory));
370 void CopyInFrom(
const void* memory)
override
374 case arm_compute::DataType::F32:
375 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const float*
>(memory),
378 case arm_compute::DataType::U8:
379 case arm_compute::DataType::QASYMM8:
380 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const uint8_t*
>(memory),
383 case arm_compute::DataType::QSYMM8:
384 case arm_compute::DataType::QASYMM8_SIGNED:
385 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const int8_t*
>(memory),
388 case arm_compute::DataType::S16:
389 case arm_compute::DataType::QSYMM16:
390 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const int16_t*
>(memory),
393 case arm_compute::DataType::S32:
394 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const int32_t*
>(memory),
404 arm_compute::SubTensor m_Tensor;
405 ITensorHandle* parentHandle =
nullptr;
427 arm_compute::ITensorInfo*
info()
const override;
429 arm_compute::ITensorInfo*
info()
override;
431 uint8_t*
buffer()
const override;
434 arm_compute::ITensor* m_Original;
435 mutable arm_compute::TensorInfo m_TensorInfo;
444 parentHandle = parent;
447 arm_compute::ITensor&
GetTensor()
override {
return m_Tensor; }
448 arm_compute::ITensor
const&
GetTensor()
const override {
return m_Tensor; }
457 return m_Tensor.
info()->data_type();
460 virtual void SetMemoryGroup(
const std::shared_ptr<arm_compute::IMemoryGroup>&)
override {}
462 virtual const void*
Map(
bool )
const override
464 return static_cast<const void*
>(m_Tensor.
buffer() + m_Tensor.
info()->offset_first_element_in_bytes());
466 virtual void Unmap()
const override {}
470 return armcomputetensorutils::GetStrides(m_Tensor.
info()->strides_in_bytes());
475 return armcomputetensorutils::GetShape(m_Tensor.
info()->tensor_shape());
485 void CopyOutTo(
void* memory)
const override
489 case arm_compute::DataType::F32:
490 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
491 static_cast<float*
>(memory));
493 case arm_compute::DataType::U8:
494 case arm_compute::DataType::QASYMM8:
495 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
496 static_cast<uint8_t*
>(memory));
498 case arm_compute::DataType::QSYMM8:
499 case arm_compute::DataType::QASYMM8_SIGNED:
500 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
501 static_cast<int8_t*
>(memory));
503 case arm_compute::DataType::S16:
504 case arm_compute::DataType::QSYMM16:
505 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
506 static_cast<int16_t*
>(memory));
508 case arm_compute::DataType::S32:
509 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
510 static_cast<int32_t*
>(memory));
520 void CopyInFrom(
const void* memory)
override
524 case arm_compute::DataType::F32:
525 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const float*
>(memory),
528 case arm_compute::DataType::U8:
529 case arm_compute::DataType::QASYMM8:
530 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const uint8_t*
>(memory),
533 case arm_compute::DataType::QSYMM8:
534 case arm_compute::DataType::QASYMM8_SIGNED:
535 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const int8_t*
>(memory),
538 case arm_compute::DataType::S16:
539 case arm_compute::DataType::QSYMM16:
540 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const int16_t*
>(memory),
543 case arm_compute::DataType::S32:
544 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const int32_t*
>(memory),
554 NeonTensorDecorator m_Tensor;
555 ITensorHandle* parentHandle =
nullptr;