ArmNN
 25.02
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
ConvertConstants.hpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include "Optimization.hpp"
9 
13 
14 #include <Half.hpp>
15 
16 namespace armnn
17 {
18 namespace optimizations
19 {
20 
21 struct Float16ToFloat32
22 {
23  static void Func(std::shared_ptr<ConstTensorHandle>& handle)
24  {
25  const TensorInfo& info = handle->GetTensorInfo();
26 
27  if (info.GetDataType() == DataType::Float16)
28  {
29  std::vector<float> newValues(info.GetNumElements());
30 
31  armnnUtils::FloatingPointConverter::ConvertFloat16To32(handle->GetConstTensor<Half>(),
32  info.GetNumElements(),
33  newValues.data());
34 
35  TensorInfo newInfo(info.GetShape(), DataType::Float32, 0.0f, 0, true);
36  ConstTensor newInput(newInfo, newValues);
37  handle.reset(new ScopedTensorHandle(newInput));
38  }
39  }
40 };
41 
42 struct Float32ToFloat16
43 {
44  static void Func(std::shared_ptr<ConstTensorHandle>& handle)
45  {
46  const TensorInfo& info = handle->GetTensorInfo();
47 
48  if (info.GetDataType() == DataType::Float32)
49  {
50  std::vector<Half> newValues(info.GetNumElements());
51 
52  armnnUtils::FloatingPointConverter::ConvertFloat32To16(handle->GetConstTensor<float>(),
53  info.GetNumElements(),
54  newValues.data());
55 
56  TensorInfo newInfo(info.GetShape(), DataType::Float16, 0.0f, 0, true);
57  ConstTensor newInput(newInfo, newValues);
58  handle.reset(new ScopedTensorHandle(newInput));
59  }
60  }
61 };
62 
63 template<typename Converter, typename Predicate>
64 class ConvertConstants : public Optimization
65 {
66 public:
67  ConvertConstants() = default;
68  ConvertConstants(const ConvertConstants&) = default;
69  virtual ~ConvertConstants() = default;
70 
71  void Run(Graph& graph, Layer& layer) const override
72  {
73  IgnoreUnused(graph);
74  if (Predicate::Test(layer))
75  {
76  layer.OperateOnConstantTensors(Converter::Func);
77  }
78  }
79 protected:
80 };
81 
82 struct IsFloat32Layer
83 {
84  static bool Test(const Layer& layer)
85  {
86  return layer.GetDataType() == DataType::Float32;
87  }
88 };
89 
90 struct IsFloat16Layer
91 {
92  static bool Test(const Layer& layer)
93  {
94  return layer.GetDataType() == DataType::Float16;
95  }
96 };
97 
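98 using ConvertConstantsHalfToFloat = ConvertConstants<Float16ToFloat32, IsFloat16Layer>;
99 using ConvertConstantsFloatToHalf = ConvertConstants<Float32ToFloat16, IsFloat32Layer>;
100 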
101 } //namespace optimizations
102 } //namespace armnn
A tensor defined by a TensorInfo (shape and data type) and an immutable backing store.
Definition: Tensor.hpp:330
void OperateOnConstantTensors(Op op)
Definition: Layer.hpp:319
DataType GetDataType() const
Definition: Layer.cpp:345
ConvertConstants(const ConvertConstants &)=default
void Run(Graph &graph, Layer &layer) const override
static void ConvertFloat16To32(const void *srcFloat16Buffer, size_t numElements, float *dstFloat32Buffer)
static void ConvertFloat32To16(const float *srcFloat32Buffer, size_t numElements, void *dstFloat16Buffer)
Converts a buffer of FP32 values to FP16, and stores in the given dstFloat16Buffer.
Copyright (c) 2021 ARM Limited and Contributors.
half_float::half Half
Definition: Half.hpp:22
void IgnoreUnused(Ts &&...)
static void Func(std::shared_ptr< ConstTensorHandle > &handle)
static void Func(std::shared_ptr< ConstTensorHandle > &handle)
static bool Test(const Layer &layer)
static bool Test(const Layer &layer)