ArmNN
 25.11
Loading...
Searching...
No Matches
RefCastWorkload.cpp
Go to the documentation of this file.
1//
2// Copyright © 2021-2024 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "RefCastWorkload.hpp"
8#include <ResolveType.hpp>
9#include "Encoders.hpp"
10#include "Decoders.hpp"
11
12namespace
13{
15 const uint32_t numElements, const armnn::DataType OutputDataType)
16 {
17 for (unsigned int i = 0; i < numElements; ++i)
18 {
19 switch (OutputDataType)
20 {
24 out.Set(in.Get());
25 break;
26 default:
27 out.Set(std::floor(in.Get()));
28 break;
29 }
30 ++in;
31 ++out;
32 }
33 }
34
35
36 // Cast Float to Int64
38 const uint32_t numElements, const armnn::DataType)
39 {
40 for (unsigned int i = 0; i < numElements; ++i)
41 {
42 out.Set(in.Get());
43 ++in;
44 ++out;
45 }
46 }
47
48 // Cast Int64 To Float
50 const uint32_t numElements, const armnn::DataType)
51 {
52 for (unsigned int i = 0; i < numElements; ++i)
53 {
54 out.Set(static_cast<float>(in.Get()));
55 ++in;
56 ++out;
57 }
58 }
59}
60
61
62namespace armnn
63{
64
66{
67 Execute(m_Data.m_Inputs, m_Data.m_Outputs);
68}
69
70void RefCastWorkload::Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const
71{
72 ARMNN_SCOPED_PROFILING_EVENT_REF_NAME_GUID("RefCastWorkload_Execute");
73
74 TensorInfo inputTensorInfo(GetTensorInfo(inputs[0]));
75 TensorInfo outputTensorInfo(GetTensorInfo(outputs[0]));
76
77 // Quantization info should set to default values.
78 if (inputTensorInfo.IsQuantized())
79 {
80 inputTensorInfo.SetQuantizationScale(1.0f);
81 inputTensorInfo.SetQuantizationOffset(0);
82 }
83 if (outputTensorInfo.IsQuantized())
84 {
85 outputTensorInfo.SetQuantizationScale(1.0f);
86 outputTensorInfo.SetQuantizationOffset(0);
87 }
88
89 if(inputTensorInfo.GetDataType() == DataType::Signed64)
90 {
91 Cast(*MakeDecoder<double_t>(inputTensorInfo, inputs[0]->Map()),
92 *MakeEncoder<float>(outputTensorInfo, outputs[0]->Map()),
93 inputTensorInfo.GetNumElements(),
94 outputTensorInfo.GetDataType());
95 }
96 else if(outputTensorInfo.GetDataType() == DataType::Signed64)
97 {
98 Cast(*MakeDecoder<float>(inputTensorInfo, inputs[0]->Map()),
99 *MakeEncoder<double_t>(outputTensorInfo, outputs[0]->Map()),
100 inputTensorInfo.GetNumElements(),
101 outputTensorInfo.GetDataType());
102 }
103 else
104 {
105 Cast(*MakeDecoder<float>(inputTensorInfo, inputs[0]->Map()),
106 *MakeEncoder<float>(outputTensorInfo, outputs[0]->Map()),
107 inputTensorInfo.GetNumElements(),
108 outputTensorInfo.GetDataType());
109 }
110}
111
112} //namespace armnn
#define ARMNN_SCOPED_PROFILING_EVENT_REF_NAME_GUID(label)
Creates a profiling event that uses GetGuid() and GetName() from the calling class.
virtual IType Get() const =0
virtual void Set(IType right)=0
void Execute() const override
Copyright (c) 2021 ARM Limited and Contributors.
std::unique_ptr< Decoder< T > > MakeDecoder(const TensorInfo &info, const void *data=nullptr)
std::unique_ptr< Encoder< T > > MakeEncoder(const TensorInfo &info, void *data=nullptr)
DataType
Definition Types.hpp:49
armnn::TensorInfo GetTensorInfo(unsigned int numberOfBatches, unsigned int numberOfChannels, unsigned int height, unsigned int width, const armnn::DataLayout dataLayout, const armnn::DataType dataType)