ArmNN
 25.11
Loading...
Searching...
No Matches
RefDebugWorkload.cpp
Go to the documentation of this file.
//
// Copyright © 2018-2024 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
7#include "Debug.hpp"
9
10#include <ResolveType.hpp>
11
12#include <cstring>
13
14namespace armnn
15{
16
17template<armnn::DataType DataType>
19{
20 Execute(m_Data.m_Inputs);
21}
22
23template<armnn::DataType DataType>
24void RefDebugWorkload<DataType>::Execute(std::vector<ITensorHandle*> inputs) const
25{
26 using T = ResolveType<DataType>;
27
29
30 const TensorInfo& inputInfo = GetTensorInfo(inputs[0]);
31
32 const T* inputData = GetInputTensorData<T>(0, m_Data);
33 T* outputData = GetOutputTensorData<T>(0, m_Data);
34
35 if (m_Callback)
36 {
37 m_Callback(m_Data.m_Guid, m_Data.m_SlotIndex, inputs[0]);
38 }
39 else
40 {
41 Debug(inputInfo, inputData, m_Data.m_Guid, m_Data.m_LayerName, m_Data.m_SlotIndex, m_Data.m_LayerOutputToFile);
42 }
43
44 std::memcpy(outputData, inputData, inputInfo.GetNumElements()*sizeof(T));
45}
46
47template<armnn::DataType DataType>
49{
50 m_Callback = func;
51}
52
63
64} // namespace armnn
#define ARMNN_SCOPED_PROFILING_EVENT_REF_NAME_GUID(label)
Creates a profiling event that uses GetGuid() and GetName() from the calling class.
void RegisterDebugCallback(const DebugCallbackFunction &func) override
void Execute() const override
unsigned int GetNumElements() const
Definition Tensor.hpp:198
Copyright (c) 2021 ARM Limited and Contributors.
typename ResolveTypeImpl< DT >::Type ResolveType
std::function< void(LayerGuid guid, unsigned int slotIndex, ITensorHandle *tensorHandle)> DebugCallbackFunction
Defines the type of the callback that the Debug layer calls.
Definition Types.hpp:400
const DataType * GetInputTensorData(unsigned int idx, const PayloadType &data)
DataType * GetOutputTensorData(unsigned int idx, const PayloadType &data)
armnn::TensorInfo GetTensorInfo(unsigned int numberOfBatches, unsigned int numberOfChannels, unsigned int height, unsigned int width, const armnn::DataLayout dataLayout, const armnn::DataType dataType)