15 #include <arm_compute/runtime/CL/CLTensor.h>
16 #include <arm_compute/runtime/CL/CLSubTensor.h>
17 #include <arm_compute/runtime/IMemoryGroup.h>
18 #include <arm_compute/runtime/MemoryGroup.h>
19 #include <arm_compute/core/TensorShape.h>
20 #include <arm_compute/core/Coordinates.h>
// Forward declaration: decorator type stored in m_Decorated below.
26 class ClTensorHandleDecorator;
// NOTE(review): constructor fragments — signatures and braces are elided in this
// chunk. Import appears disabled by default — confirm against upstream.
34 m_IsImportEnabled(false)
// Builds only the ACL tensor metadata from the ArmNN TensorInfo; no allocation here.
36 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo);
// Second constructor: additionally records import flags and an explicit data layout.
42 : m_ImportFlags(importFlags),
44 m_IsImportEnabled(false)
46 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo, dataLayout);
49 arm_compute::CLTensor&
GetTensor()
override {
return m_Tensor; }
50 arm_compute::CLTensor
const&
GetTensor()
const override {
return m_Tensor; }
// NOTE(review): Allocate()/Manage() fragments — signatures, braces and the
// non-import branches are elided in this chunk.
54 if (m_IsImportEnabled)
// With import enabled the buffer comes from outside, so only empty ACL tensor
// initialisation happens here.
60 armnn::armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_Tensor);
68 if (m_IsImportEnabled)
// Hands lifetime management of m_Tensor to the ACL memory group set earlier.
74 assert(m_MemoryGroup !=
nullptr);
75 m_MemoryGroup->manage(&m_Tensor);
79 virtual const void*
Map(
bool blocking =
true)
const override
81 const_cast<arm_compute::CLTensor*
>(&m_Tensor)->map(blocking);
82 return static_cast<const void*
>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
85 virtual void Unmap()
const override {
const_cast<arm_compute::CLTensor*
>(&m_Tensor)->unmap(); }
// GetDataType() body fragment (signature elided): ACL data type from tensor info.
91 return m_Tensor.info()->data_type();
94 virtual void SetMemoryGroup(
const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup)
override
96 m_MemoryGroup = PolymorphicPointerDowncast<arm_compute::MemoryGroup>(memoryGroup);
// Body fragments below — signatures elided in this chunk.
// Byte strides per dimension, converted to ArmNN's representation.
101 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
// Tensor shape, converted to ArmNN's representation.
106 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
// Import-flag setter/getter and import-enable toggle bodies.
111 m_ImportFlags = importFlags;
116 return m_ImportFlags;
121 m_IsImportEnabled = importEnabledFlag;
146 void CopyOutTo(
void* memory)
const override
151 case arm_compute::DataType::F32:
152 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
153 static_cast<float*
>(memory));
155 case arm_compute::DataType::U8:
156 case arm_compute::DataType::QASYMM8:
157 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
158 static_cast<uint8_t*
>(memory));
160 case arm_compute::DataType::QSYMM8:
161 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
162 case arm_compute::DataType::QASYMM8_SIGNED:
163 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
164 static_cast<int8_t*
>(memory));
166 case arm_compute::DataType::F16:
167 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
170 case arm_compute::DataType::S16:
171 case arm_compute::DataType::QSYMM16:
172 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
173 static_cast<int16_t*
>(memory));
175 case arm_compute::DataType::S32:
176 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
177 static_cast<int32_t*
>(memory));
179 case arm_compute::DataType::S64:
180 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
181 static_cast<int64_t*
>(memory));
193 void CopyInFrom(
const void* memory)
override
198 case arm_compute::DataType::F32:
199 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const float*
>(memory),
202 case arm_compute::DataType::U8:
203 case arm_compute::DataType::QASYMM8:
204 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const uint8_t*
>(memory),
207 case arm_compute::DataType::F16:
208 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const armnn::Half*
>(memory),
211 case arm_compute::DataType::S16:
212 case arm_compute::DataType::QSYMM8:
213 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
214 case arm_compute::DataType::QASYMM8_SIGNED:
215 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const int8_t*
>(memory),
218 case arm_compute::DataType::QSYMM16:
219 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const int16_t*
>(memory),
222 case arm_compute::DataType::S32:
223 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const int32_t*
>(memory),
226 case arm_compute::DataType::S64:
227 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const int64_t*
>(memory),
// The owned ACL OpenCL tensor.
238 arm_compute::CLTensor m_Tensor;
// Memory group managing this tensor's allocation; set via SetMemoryGroup().
239 std::shared_ptr<arm_compute::MemoryGroup> m_MemoryGroup;
// When true, the buffer is imported rather than allocated by ACL.
242 bool m_IsImportEnabled;
// Decorators created over this handle; the shared_ptrs keep them alive here.
243 std::vector<std::shared_ptr<ClTensorHandleDecorator>> m_Decorated;
// NOTE(review): ClSubTensorHandle constructor fragment — the parent-handle
// parameter and body braces are elided in this chunk. The CLSubTensor is a
// view into the parent's tensor at 'coords' with the given shape.
250 const arm_compute::TensorShape& shape,
252 : m_Tensor(&parent->
GetTensor(), shape, coords)
// Remember the parent so GetParent()-style queries can reach it.
254 parentHandle = parent;
257 arm_compute::CLSubTensor&
GetTensor()
override {
return m_Tensor; }
258 arm_compute::CLSubTensor
const&
GetTensor()
const override {
return m_Tensor; }
263 virtual const void*
Map(
bool blocking =
true)
const override
265 const_cast<arm_compute::CLSubTensor*
>(&m_Tensor)->map(blocking);
266 return static_cast<const void*
>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
268 virtual void Unmap()
const override {
const_cast<arm_compute::CLSubTensor*
>(&m_Tensor)->unmap(); }
// GetDataType() body fragment (signature elided): ACL data type of the sub-tensor.
274 return m_Tensor.info()->data_type();
277 virtual void SetMemoryGroup(
const std::shared_ptr<arm_compute::IMemoryGroup>&)
override {}
// GetStrides()/GetShape() body fragments (signatures elided in this chunk).
281 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
286 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
291 void CopyOutTo(
void* memory)
const override
296 case arm_compute::DataType::F32:
297 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
298 static_cast<float*
>(memory));
300 case arm_compute::DataType::U8:
301 case arm_compute::DataType::QASYMM8:
302 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
303 static_cast<uint8_t*
>(memory));
305 case arm_compute::DataType::F16:
306 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
309 case arm_compute::DataType::QSYMM8:
310 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
311 case arm_compute::DataType::QASYMM8_SIGNED:
312 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
313 static_cast<int8_t*
>(memory));
315 case arm_compute::DataType::S16:
316 case arm_compute::DataType::QSYMM16:
317 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
318 static_cast<int16_t*
>(memory));
320 case arm_compute::DataType::S32:
321 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
322 static_cast<int32_t*
>(memory));
333 void CopyInFrom(
const void* memory)
override
338 case arm_compute::DataType::F32:
339 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const float*
>(memory),
342 case arm_compute::DataType::U8:
343 case arm_compute::DataType::QASYMM8:
344 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const uint8_t*
>(memory),
347 case arm_compute::DataType::F16:
348 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const armnn::Half*
>(memory),
351 case arm_compute::DataType::QSYMM8:
352 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
353 case arm_compute::DataType::QASYMM8_SIGNED:
354 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const int8_t*
>(memory),
357 case arm_compute::DataType::S16:
358 case arm_compute::DataType::QSYMM16:
359 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const int16_t*
>(memory),
362 case arm_compute::DataType::S32:
363 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const int32_t*
>(memory),
// 'mutable' so the const Map()/Unmap() overrides can map the view.
374 mutable arm_compute::CLSubTensor m_Tensor;
// Raw pointer to the parent handle this sub-tensor views into (set in ctor).
375 ITensorHandle* parentHandle =
nullptr;
// ClTensorDecorator member declarations (bodies defined elsewhere): wraps
// another ICLTensor while exposing its own TensorInfo.
396 arm_compute::ICLTensor*
parent();
398 void map(
bool blocking =
true);
399 using arm_compute::ICLTensor::map;
402 using arm_compute::ICLTensor::unmap;
// info() overrides return the decorator's own m_TensorInfo — presumably a
// re-typed view of the wrapped tensor's info; confirm in the .cpp.
404 virtual arm_compute::ITensorInfo*
info()
const override;
405 virtual arm_compute::ITensorInfo*
info()
override;
406 const cl::Buffer&
cl_buffer()
const override;
407 arm_compute::CLQuantization
quantization()
const override;
// ICLTensor mapping hooks, forwarded to the wrapped tensor.
411 uint8_t*
do_map(cl::CommandQueue& q,
bool blocking)
override;
412 void do_unmap(cl::CommandQueue& q)
override;
// Raw pointer to the wrapped tensor; m_TensorInfo is mutable so the const
// info() override can hand it out.
415 arm_compute::ICLTensor* m_Original;
416 mutable arm_compute::TensorInfo m_TensorInfo;
// NOTE(review): ClTensorHandleDecorator constructor fragment — records the
// decorated parent handle; signature elided in this chunk.
425 m_OriginalHandle = parent;
428 arm_compute::ICLTensor&
GetTensor()
override {
return m_Tensor; }
429 arm_compute::ICLTensor
const&
GetTensor()
const override {
return m_Tensor; }
434 virtual const void*
Map(
bool blocking =
true)
const override
436 m_Tensor.
map(blocking);
437 return static_cast<const void*
>(m_Tensor.buffer() + m_Tensor.
info()->offset_first_element_in_bytes());
// GetDataType() body fragment (signature elided in this chunk).
449 return m_Tensor.
info()->data_type();
452 virtual void SetMemoryGroup(
const std::shared_ptr<arm_compute::IMemoryGroup>&)
override {}
// GetStrides()/GetShape() body fragments (signatures elided in this chunk).
456 return armcomputetensorutils::GetStrides(m_Tensor.
info()->strides_in_bytes());
461 return armcomputetensorutils::GetShape(m_Tensor.
info()->tensor_shape());
466 void CopyOutTo(
void* memory)
const override
471 case arm_compute::DataType::F32:
472 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
473 static_cast<float*
>(memory));
475 case arm_compute::DataType::U8:
476 case arm_compute::DataType::QASYMM8:
477 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
478 static_cast<uint8_t*
>(memory));
480 case arm_compute::DataType::F16:
481 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
484 case arm_compute::DataType::QSYMM8:
485 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
486 case arm_compute::DataType::QASYMM8_SIGNED:
487 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
488 static_cast<int8_t*
>(memory));
490 case arm_compute::DataType::S16:
491 case arm_compute::DataType::QSYMM16:
492 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
493 static_cast<int16_t*
>(memory));
495 case arm_compute::DataType::S32:
496 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
497 static_cast<int32_t*
>(memory));
508 void CopyInFrom(
const void* memory)
override
513 case arm_compute::DataType::F32:
514 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const float*
>(memory),
517 case arm_compute::DataType::U8:
518 case arm_compute::DataType::QASYMM8:
519 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const uint8_t*
>(memory),
522 case arm_compute::DataType::F16:
523 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const armnn::Half*
>(memory),
526 case arm_compute::DataType::QSYMM8:
527 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
528 case arm_compute::DataType::QASYMM8_SIGNED:
529 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const int8_t*
>(memory),
532 case arm_compute::DataType::S16:
533 case arm_compute::DataType::QSYMM16:
534 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const int16_t*
>(memory),
537 case arm_compute::DataType::S32:
538 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const int32_t*
>(memory),
// 'mutable' so the const Map() override can map through the decorator.
549 mutable ClTensorDecorator m_Tensor;
// Raw pointer to the decorated handle (set in the constructor); lifetime is
// presumably managed by the owning ClTensorHandle — confirm.
550 IClTensorHandle* m_OriginalHandle =
nullptr;