ArmNN
 25.02
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
DataTypeUtils.hpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017, 2024 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include <ResolveType.hpp>
9 
10 
12 
13 #include <vector>
14 
15 // Utility tenmplate to convert a collection of values to the correct type
16 template <armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
17 std::vector<T> ConvertToDataType(const std::vector<float>& input,
18  const armnn::TensorInfo& inputTensorInfo)
19 {
20  std::vector<T> output(input.size());
21  auto outputTensorInfo = inputTensorInfo;
22  outputTensorInfo.SetDataType(ArmnnType);
23 
24  if(sizeof(T) > 4)
25  {
26  std::unique_ptr<armnn::Encoder<double>> pOutputEncoder = armnn::MakeEncoder<double>(outputTensorInfo,
27  output.data());
28  armnn::Encoder<double>& rOutputEncoder = *pOutputEncoder;
29  for (auto it = input.begin(); it != input.end(); ++it)
30  {
31  rOutputEncoder.Set(*it);
32  ++rOutputEncoder;
33  }
34  }
35  else
36  {
37  std::unique_ptr<armnn::Encoder<float>> pOutputEncoder = armnn::MakeEncoder<float>(outputTensorInfo,
38  output.data());
39  armnn::Encoder<float>& rOutputEncoder = *pOutputEncoder;
40  for (auto it = input.begin(); it != input.end(); ++it)
41  {
42  rOutputEncoder.Set(*it);
43  ++rOutputEncoder;
44  }
45  }
46 
47  return output;
48 }
49 
50 // Utility tenmplate to convert a single value to the correct type
51 template <typename T>
52 T ConvertToDataType(const float& value,
53  const armnn::TensorInfo& tensorInfo)
54 {
55  std::vector<T> output(1);
56 
57  if(sizeof(T) > 4)
58  {
59  std::unique_ptr<armnn::Encoder<double>> pEncoder = armnn::MakeEncoder<double>(tensorInfo, output.data());
60  armnn::Encoder<double>& rEncoder = *pEncoder;
61  rEncoder.Set(value);
62  }
63  else
64  {
65  std::unique_ptr<armnn::Encoder<float>> pEncoder = armnn::MakeEncoder<float>(tensorInfo, output.data());
66  armnn::Encoder<float>& rEncoder = *pEncoder;
67  rEncoder.Set(value);
68  }
69 
70  return output[0];
71 }
std::vector< T > ConvertToDataType(const std::vector< float > &input, const armnn::TensorInfo &inputTensorInfo)
virtual void Set(IType right)=0
void SetDataType(DataType type)
Definition: Tensor.hpp:201