15 #include <arm_compute/runtime/CL/CLTensor.h>
16 #include <arm_compute/runtime/CL/CLSubTensor.h>
17 #include <arm_compute/runtime/IMemoryGroup.h>
18 #include <arm_compute/runtime/MemoryGroup.h>
19 #include <arm_compute/core/TensorShape.h>
20 #include <arm_compute/core/Coordinates.h>
26 class ClTensorHandleDecorator;
34 m_IsImportEnabled(false)
36 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo);
42 : m_ImportFlags(importFlags),
44 m_IsImportEnabled(false)
46 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo, dataLayout);
49 arm_compute::CLTensor&
GetTensor()
override {
return m_Tensor; }
50 arm_compute::CLTensor
const&
GetTensor()
const override {
return m_Tensor; }
54 if (m_IsImportEnabled)
60 armnn::armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_Tensor);
68 if (m_IsImportEnabled)
74 assert(m_MemoryGroup !=
nullptr);
75 m_MemoryGroup->manage(&m_Tensor);
79 virtual const void*
Map(
bool blocking =
true)
const override
81 const_cast<arm_compute::CLTensor*
>(&m_Tensor)->map(blocking);
82 return static_cast<const void*
>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
85 virtual void Unmap()
const override {
const_cast<arm_compute::CLTensor*
>(&m_Tensor)->unmap(); }
91 return m_Tensor.info()->data_type();
94 virtual void SetMemoryGroup(
const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup)
override
96 m_MemoryGroup = PolymorphicPointerDowncast<arm_compute::MemoryGroup>(memoryGroup);
101 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
106 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
111 m_ImportFlags = importFlags;
116 return m_ImportFlags;
121 m_IsImportEnabled = importEnabledFlag;
146 void CopyOutTo(
void* memory)
const override
151 case arm_compute::DataType::F32:
152 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
153 static_cast<float*
>(memory));
155 case arm_compute::DataType::U8:
156 case arm_compute::DataType::QASYMM8:
157 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
158 static_cast<uint8_t*
>(memory));
160 case arm_compute::DataType::QSYMM8:
161 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
162 case arm_compute::DataType::QASYMM8_SIGNED:
163 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
164 static_cast<int8_t*
>(memory));
166 case arm_compute::DataType::F16:
167 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
170 case arm_compute::DataType::S16:
171 case arm_compute::DataType::QSYMM16:
172 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
173 static_cast<int16_t*
>(memory));
175 case arm_compute::DataType::S32:
176 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
177 static_cast<int32_t*
>(memory));
188 void CopyInFrom(
const void* memory)
override
193 case arm_compute::DataType::F32:
194 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const float*
>(memory),
197 case arm_compute::DataType::U8:
198 case arm_compute::DataType::QASYMM8:
199 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const uint8_t*
>(memory),
202 case arm_compute::DataType::F16:
203 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const armnn::Half*
>(memory),
206 case arm_compute::DataType::S16:
207 case arm_compute::DataType::QSYMM8:
208 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
209 case arm_compute::DataType::QASYMM8_SIGNED:
210 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const int8_t*
>(memory),
213 case arm_compute::DataType::QSYMM16:
214 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const int16_t*
>(memory),
217 case arm_compute::DataType::S32:
218 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const int32_t*
>(memory),
229 arm_compute::CLTensor m_Tensor;
230 std::shared_ptr<arm_compute::MemoryGroup> m_MemoryGroup;
233 bool m_IsImportEnabled;
234 std::vector<std::shared_ptr<ClTensorHandleDecorator>> m_Decorated;
241 const arm_compute::TensorShape& shape,
243 : m_Tensor(&parent->
GetTensor(), shape, coords)
245 parentHandle = parent;
248 arm_compute::CLSubTensor&
GetTensor()
override {
return m_Tensor; }
249 arm_compute::CLSubTensor
const&
GetTensor()
const override {
return m_Tensor; }
254 virtual const void*
Map(
bool blocking =
true)
const override
256 const_cast<arm_compute::CLSubTensor*
>(&m_Tensor)->map(blocking);
257 return static_cast<const void*
>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
259 virtual void Unmap()
const override {
const_cast<arm_compute::CLSubTensor*
>(&m_Tensor)->unmap(); }
265 return m_Tensor.info()->data_type();
268 virtual void SetMemoryGroup(
const std::shared_ptr<arm_compute::IMemoryGroup>&)
override {}
272 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
277 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
282 void CopyOutTo(
void* memory)
const override
287 case arm_compute::DataType::F32:
288 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
289 static_cast<float*
>(memory));
291 case arm_compute::DataType::U8:
292 case arm_compute::DataType::QASYMM8:
293 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
294 static_cast<uint8_t*
>(memory));
296 case arm_compute::DataType::F16:
297 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
300 case arm_compute::DataType::QSYMM8:
301 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
302 case arm_compute::DataType::QASYMM8_SIGNED:
303 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
304 static_cast<int8_t*
>(memory));
306 case arm_compute::DataType::S16:
307 case arm_compute::DataType::QSYMM16:
308 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
309 static_cast<int16_t*
>(memory));
311 case arm_compute::DataType::S32:
312 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
313 static_cast<int32_t*
>(memory));
324 void CopyInFrom(
const void* memory)
override
329 case arm_compute::DataType::F32:
330 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const float*
>(memory),
333 case arm_compute::DataType::U8:
334 case arm_compute::DataType::QASYMM8:
335 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const uint8_t*
>(memory),
338 case arm_compute::DataType::F16:
339 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const armnn::Half*
>(memory),
342 case arm_compute::DataType::QSYMM8:
343 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
344 case arm_compute::DataType::QASYMM8_SIGNED:
345 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const int8_t*
>(memory),
348 case arm_compute::DataType::S16:
349 case arm_compute::DataType::QSYMM16:
350 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const int16_t*
>(memory),
353 case arm_compute::DataType::S32:
354 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const int32_t*
>(memory),
365 mutable arm_compute::CLSubTensor m_Tensor;
366 ITensorHandle* parentHandle =
nullptr;
387 arm_compute::ICLTensor*
parent();
389 void map(
bool blocking =
true);
390 using arm_compute::ICLTensor::map;
393 using arm_compute::ICLTensor::unmap;
395 virtual arm_compute::ITensorInfo*
info()
const override;
396 virtual arm_compute::ITensorInfo*
info()
override;
397 const cl::Buffer&
cl_buffer()
const override;
398 arm_compute::CLQuantization
quantization()
const override;
402 uint8_t*
do_map(cl::CommandQueue& q,
bool blocking)
override;
403 void do_unmap(cl::CommandQueue& q)
override;
406 arm_compute::ICLTensor* m_Original;
407 mutable arm_compute::TensorInfo m_TensorInfo;
416 m_OriginalHandle = parent;
419 arm_compute::ICLTensor&
GetTensor()
override {
return m_Tensor; }
420 arm_compute::ICLTensor
const&
GetTensor()
const override {
return m_Tensor; }
425 virtual const void*
Map(
bool blocking =
true)
const override
427 m_Tensor.
map(blocking);
428 return static_cast<const void*
>(m_Tensor.buffer() + m_Tensor.
info()->offset_first_element_in_bytes());
440 return m_Tensor.
info()->data_type();
443 virtual void SetMemoryGroup(
const std::shared_ptr<arm_compute::IMemoryGroup>&)
override {}
447 return armcomputetensorutils::GetStrides(m_Tensor.
info()->strides_in_bytes());
452 return armcomputetensorutils::GetShape(m_Tensor.
info()->tensor_shape());
457 void CopyOutTo(
void* memory)
const override
462 case arm_compute::DataType::F32:
463 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
464 static_cast<float*
>(memory));
466 case arm_compute::DataType::U8:
467 case arm_compute::DataType::QASYMM8:
468 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
469 static_cast<uint8_t*
>(memory));
471 case arm_compute::DataType::F16:
472 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
475 case arm_compute::DataType::QSYMM8:
476 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
477 case arm_compute::DataType::QASYMM8_SIGNED:
478 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
479 static_cast<int8_t*
>(memory));
481 case arm_compute::DataType::S16:
482 case arm_compute::DataType::QSYMM16:
483 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
484 static_cast<int16_t*
>(memory));
486 case arm_compute::DataType::S32:
487 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
488 static_cast<int32_t*
>(memory));
499 void CopyInFrom(
const void* memory)
override
504 case arm_compute::DataType::F32:
505 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const float*
>(memory),
508 case arm_compute::DataType::U8:
509 case arm_compute::DataType::QASYMM8:
510 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const uint8_t*
>(memory),
513 case arm_compute::DataType::F16:
514 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const armnn::Half*
>(memory),
517 case arm_compute::DataType::QSYMM8:
518 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
519 case arm_compute::DataType::QASYMM8_SIGNED:
520 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const int8_t*
>(memory),
523 case arm_compute::DataType::S16:
524 case arm_compute::DataType::QSYMM16:
525 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const int16_t*
>(memory),
528 case arm_compute::DataType::S32:
529 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const int32_t*
>(memory),
540 mutable ClTensorDecorator m_Tensor;
541 IClTensorHandle* m_OriginalHandle =
nullptr;