13 #include <arm_compute/runtime/CL/CLTensor.h>
14 #include <arm_compute/runtime/CL/CLSubTensor.h>
15 #include <arm_compute/runtime/IMemoryGroup.h>
16 #include <arm_compute/runtime/MemoryGroup.h>
17 #include <arm_compute/core/TensorShape.h>
18 #include <arm_compute/core/Coordinates.h>
// NOTE(review): excerpt with source lines elided; the constructor signatures are
// not visible here. Both visible initializer lists start the handle with import
// disabled, then call BuildArmComputeTensor on m_Tensor with tensorInfo.
31 m_IsImportEnabled(false)
33 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo);
// Second visible constructor: additionally records importFlags and passes an
// explicit dataLayout to BuildArmComputeTensor.
39 : m_ImportFlags(importFlags),
41 m_IsImportEnabled(false)
43 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo, dataLayout);
// Direct access to the wrapped CLTensor (mutable overload).
46 arm_compute::CLTensor&
GetTensor()
override {
return m_Tensor; }
// Direct access to the wrapped CLTensor (const overload).
47 arm_compute::CLTensor
const&
GetTensor()
const override {
return m_Tensor; }
// NOTE(review): excerpt with elided lines; the enclosing member functions are not
// visible. Each fragment first checks m_IsImportEnabled. The guarded branches are
// elided, so the calls below are presumably the non-import path; confirm against
// the full source.
51 if (m_IsImportEnabled)
57 armnn::armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_Tensor);
65 if (m_IsImportEnabled)
// A memory group must have been set beforehand (asserted); m_Tensor is then
// handed to it via manage().
71 assert(m_MemoryGroup !=
nullptr);
72 m_MemoryGroup->manage(&m_Tensor);
// Maps the tensor's CL buffer (blocking by default) and returns a pointer to the
// first element: buffer() advanced by offset_first_element_in_bytes().
// map() is called through a const_cast because this override is const.
76 virtual const void*
Map(
bool blocking =
true)
const override
78 const_cast<arm_compute::CLTensor*
>(&m_Tensor)->map(blocking);
79 return static_cast<const void*
>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
// Unmaps the CL buffer; const_cast for the same reason as in Map().
82 virtual void Unmap()
const override {
const_cast<arm_compute::CLTensor*
>(&m_Tensor)->unmap(); }
// NOTE(review): bodies of small accessors; most of their signatures are elided
// from this excerpt.
// Reports the ACL data type of the wrapped tensor.
88 return m_Tensor.info()->data_type();
// Stores the generic memory group downcast to arm_compute::MemoryGroup via
// PolymorphicPointerDowncast (presumably callers must pass a MemoryGroup; confirm).
91 virtual void SetMemoryGroup(
const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup)
override
93 m_MemoryGroup = PolymorphicPointerDowncast<arm_compute::MemoryGroup>(memoryGroup);
// Strides (in bytes) and shape, converted from the ACL tensor info through the
// armcomputetensorutils helpers.
98 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
103 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
// Import flags and the import-enabled switch are plain stored values.
108 m_ImportFlags = importFlags;
113 return m_ImportFlags;
118 m_IsImportEnabled = importEnabledFlag;
// Copies the wrapped tensor's contents into caller-provided memory. `memory` is
// reinterpreted as the element type matching the tensor's ACL data type (see the
// cases below), so it must hold room for the tensor's elements of that type.
// NOTE(review): the switch header, the F16 destination cast and the break
// statements are elided from this excerpt.
139 void CopyOutTo(
void* memory)
const override
144 case arm_compute::DataType::F32:
145 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
146 static_cast<float*
>(memory));
148 case arm_compute::DataType::U8:
149 case arm_compute::DataType::QASYMM8:
150 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
151 static_cast<uint8_t*
>(memory));
// Signed 8-bit types share the int8_t path.
153 case arm_compute::DataType::QSYMM8:
154 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
155 case arm_compute::DataType::QASYMM8_SIGNED:
156 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
157 static_cast<int8_t*
>(memory));
159 case arm_compute::DataType::F16:
160 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
// Both 16-bit integer types share the int16_t path.
163 case arm_compute::DataType::S16:
164 case arm_compute::DataType::QSYMM16:
165 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
166 static_cast<int16_t*
>(memory));
168 case arm_compute::DataType::S32:
169 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
170 static_cast<int32_t*
>(memory));
// Copies caller-provided memory into the wrapped CLTensor. `memory` is
// reinterpreted as the element type matching the tensor's ACL data type (see the
// cases below), so it must hold the tensor's elements in that type.
// NOTE(review): the switch header, the GetTensor() destination arguments and the
// break statements are elided from this excerpt.
181 void CopyInFrom(
const void* memory)
override
186 case arm_compute::DataType::F32:
187 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const float*
>(memory),
190 case arm_compute::DataType::U8:
191 case arm_compute::DataType::QASYMM8:
192 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const uint8_t*
>(memory),
195 case arm_compute::DataType::F16:
196 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const armnn::Half*
>(memory),
// Signed 8-bit types share the int8_t path.
200 case arm_compute::DataType::QSYMM8:
201 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
202 case arm_compute::DataType::QASYMM8_SIGNED:
203 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const int8_t*
>(memory),
// Both 16-bit integer types share the int16_t path, matching CopyOutTo.
// S16 must not be copied through the int8_t path above: its elements are 2 bytes
// wide, so an int8_t copy would transfer only half of each row, from the wrong offsets.
199 case arm_compute::DataType::S16:
206 case arm_compute::DataType::QSYMM16:
207 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const int16_t*
>(memory),
210 case arm_compute::DataType::S32:
211 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const int32_t*
>(memory),
// The ACL tensor wrapped by this handle.
222 arm_compute::CLTensor m_Tensor;
// Memory group set through SetMemoryGroup; m_Tensor is registered with it via manage().
223 std::shared_ptr<arm_compute::MemoryGroup> m_MemoryGroup;
// Import switch; initialised to false in the visible constructors.
226 bool m_IsImportEnabled;
// NOTE(review): sub-tensor handle constructor; its name and leading parameters
// are elided from this excerpt. It builds m_Tensor as a CLSubTensor over the
// parent's tensor using the given shape and coords, and stores the parent as a
// raw pointer (presumably non-owning; confirm).
233 const arm_compute::TensorShape& shape,
235 : m_Tensor(&parent->
GetTensor(), shape, coords)
237 parentHandle = parent;
// Direct access to the wrapped CLSubTensor (mutable overload).
240 arm_compute::CLSubTensor&
GetTensor()
override {
return m_Tensor; }
// Direct access to the wrapped CLSubTensor (const overload).
241 arm_compute::CLSubTensor
const&
GetTensor()
const override {
return m_Tensor; }
// Maps the sub-tensor's CL buffer (blocking by default) and returns a pointer to
// its first element: buffer() advanced by offset_first_element_in_bytes().
// NOTE(review): m_Tensor is declared mutable in this class, so the const_cast
// here and in Unmap() appears redundant.
246 virtual const void*
Map(
bool blocking =
true)
const override
248 const_cast<arm_compute::CLSubTensor*
>(&m_Tensor)->map(blocking);
249 return static_cast<const void*
>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
// Unmaps the sub-tensor's CL buffer.
251 virtual void Unmap()
const override {
const_cast<arm_compute::CLSubTensor*
>(&m_Tensor)->unmap(); }
// NOTE(review): bodies of small accessors; most signatures are elided here.
// Reports the ACL data type of the wrapped sub-tensor.
257 return m_Tensor.info()->data_type();
// No-op for sub-tensors: the memory group argument is ignored.
260 virtual void SetMemoryGroup(
const std::shared_ptr<arm_compute::IMemoryGroup>&)
override {}
// Strides (in bytes) and shape, converted from the ACL tensor info through the
// armcomputetensorutils helpers.
264 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
269 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
// Copies the sub-tensor's contents into caller-provided memory. `memory` is
// reinterpreted as the element type matching the tensor's ACL data type (see the
// cases below).
// NOTE(review): the switch header, the F16 destination cast and the break
// statements are elided from this excerpt.
274 void CopyOutTo(
void* memory)
const override
279 case arm_compute::DataType::F32:
280 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
281 static_cast<float*
>(memory));
283 case arm_compute::DataType::U8:
284 case arm_compute::DataType::QASYMM8:
285 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
286 static_cast<uint8_t*
>(memory));
288 case arm_compute::DataType::F16:
289 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
// Signed 8-bit types share the int8_t path.
292 case arm_compute::DataType::QSYMM8:
293 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
294 case arm_compute::DataType::QASYMM8_SIGNED:
295 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
296 static_cast<int8_t*
>(memory));
// Both 16-bit integer types share the int16_t path.
298 case arm_compute::DataType::S16:
299 case arm_compute::DataType::QSYMM16:
300 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
301 static_cast<int16_t*
>(memory));
303 case arm_compute::DataType::S32:
304 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
305 static_cast<int32_t*
>(memory));
// Copies caller-provided memory into the sub-tensor. `memory` is reinterpreted as
// the element type matching the tensor's ACL data type (see the cases below).
// NOTE(review): the switch header, the GetTensor() destination arguments and the
// break statements are elided from this excerpt.
316 void CopyInFrom(
const void* memory)
override
321 case arm_compute::DataType::F32:
322 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const float*
>(memory),
325 case arm_compute::DataType::U8:
326 case arm_compute::DataType::QASYMM8:
327 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const uint8_t*
>(memory),
330 case arm_compute::DataType::F16:
331 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const armnn::Half*
>(memory),
// Signed 8-bit types share the int8_t path.
334 case arm_compute::DataType::QSYMM8:
335 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
336 case arm_compute::DataType::QASYMM8_SIGNED:
337 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const int8_t*
>(memory),
// Both 16-bit integer types share the int16_t path, mirroring CopyOutTo.
340 case arm_compute::DataType::S16:
341 case arm_compute::DataType::QSYMM16:
342 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const int16_t*
>(memory),
345 case arm_compute::DataType::S32:
346 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const int32_t*
>(memory),
// The ACL sub-tensor view wrapped by this handle; declared mutable, so const
// member functions may modify it.
357 mutable arm_compute::CLSubTensor m_Tensor;
// Parent handle recorded by the constructor; raw pointer, null by default.
358 ITensorHandle* parentHandle =
nullptr;