ArmNN
 24.08
ConvertConstants.hpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include "Optimization.hpp"
9 
13 
14 #include <Half.hpp>
15 
16 namespace armnn
17 {
18 namespace optimizations
19 {
20 
22 {
// Float16ToFloat32::Func -- converter callback for ConvertConstants.
// If the constant tensor held by `handle` is Float16, replaces it with a
// Float32 tensor of the same shape holding the converted values; tensors of
// any other data type are left unchanged.
// @param handle  Constant tensor handle, taken by non-const reference so it can
//                be re-pointed at the converted tensor. It is dereferenced
//                unconditionally, so it must be non-null (presumably the caller
//                skips empty handles -- confirm).
23  static void Func(std::shared_ptr<ConstTensorHandle>& handle)
24  {
25  const TensorInfo& info = handle->GetTensorInfo();
26 
27  if (info.GetDataType() == DataType::Float16)
28  {
// Destination buffer: one float per element of the source tensor.
29  std::vector<float> newValues(info.GetNumElements());
30 
// NOTE(review): listing line 31 is missing from this page. It opens the call
// whose remaining arguments follow; the cross-reference index attributes it to
// armnnUtils::FloatingPointConverter::ConvertFloat16To32, which takes
// (srcFloat16Buffer, numElements, dstFloat32Buffer).
32  info.GetNumElements(),
33  newValues.data());
34 
// Same shape, Float32 type. The remaining arguments are presumably the
// quantization scale/offset and an is-constant flag -- confirm against
// TensorInfo's constructor.
35  TensorInfo newInfo(info.GetShape(), DataType::Float32, 0.0f, 0, true);
36  ConstTensor newInput(newInfo, newValues);
// newValues is a local vector, so this relies on ScopedTensorHandle keeping
// its own copy of newInput's data (presumably; confirm in TensorHandle.hpp).
37  handle.reset(new ScopedTensorHandle(newInput));
38  }
39  }
40 };
41 
43 {
44  static void Func(std::shared_ptr<ConstTensorHandle>& handle)
45  {
46  const TensorInfo& info = handle->GetTensorInfo();
47 
48  if (info.GetDataType() == DataType::Float32)
49  {
50  std::vector<Half> newValues(info.GetNumElements());
51 
52  armnnUtils::FloatingPointConverter::ConvertFloat32To16(handle->GetConstTensor<float>(),
53  info.GetNumElements(),
54  newValues.data());
55 
56  TensorInfo newInfo(info.GetShape(), DataType::Float16, 0.0f, 0, true);
57  ConstTensor newInput(newInfo, newValues);
58  handle.reset(new ScopedTensorHandle(newInput));
59  }
60  }
61 };
62 
63 template<typename Converter, typename Predicate>
65 {
66 public:
// Special members are all defaulted: no data members are declared in this
// class body, so the compiler-generated versions are sufficient.
// The destructor is virtual; Run() below is marked `override`, so this type
// derives from a polymorphic base (the class-head line naming it is missing
// from this listing; the cross-reference index points to armnn::Optimization).
// NOTE(review): declaring the destructor and copy constructor suppresses the
// implicit move operations and deprecates the implicit copy assignment;
// harmless for a stateless type, but consider defaulting all five.
67  ConvertConstants() = default;
68  ConvertConstants(const ConvertConstants&) = default;
69  virtual ~ConvertConstants() = default;
70 
71  void Run(Graph& graph, Layer& layer) const override
72  {
73  IgnoreUnused(graph);
74  if (Predicate::Test(layer))
75  {
76  layer.OperateOnConstantTensors(Converter::Func);
77  }
78  }
79 protected:
80 };
81 
83 {
84  static bool Test(const Layer& layer)
85  {
86  return layer.GetDataType() == DataType::Float32;
87  }
88 };
89 
91 {
92  static bool Test(const Layer& layer)
93  {
94  return layer.GetDataType() == DataType::Float16;
95  }
96 };
97 
100 
101 } //namespace optimizations
102 } //namespace armnn
armnn::Layer::OperateOnConstantTensors
void OperateOnConstantTensors(Op op)
Definition: Layer.hpp:319
armnn::TensorInfo
Definition: Tensor.hpp:152
armnn::optimizations::ConvertConstants
Definition: ConvertConstants.hpp:64
armnn::DataType::Float32
@ Float32
armnn::optimizations::ConvertConstants::Run
void Run(Graph &graph, Layer &layer) const override
Definition: ConvertConstants.hpp:71
armnn::optimizations::IsFloat32Layer::Test
static bool Test(const Layer &layer)
Definition: ConvertConstants.hpp:84
armnn::optimizations::ConvertConstants::ConvertConstants
ConvertConstants()=default
armnn::Half
half_float::half Half
Definition: Half.hpp:22
IgnoreUnused.hpp
armnn::optimizations::IsFloat32Layer
Definition: ConvertConstants.hpp:82
Optimization.hpp
armnn::Layer
Definition: Layer.hpp:230
armnn::optimizations::ConvertConstants::~ConvertConstants
virtual ~ConvertConstants()=default
armnn::DataType::Float16
@ Float16
armnn::optimizations::Float16ToFloat32
Definition: ConvertConstants.hpp:21
armnnUtils::FloatingPointConverter::ConvertFloat16To32
static void ConvertFloat16To32(const void *srcFloat16Buffer, size_t numElements, float *dstFloat32Buffer)
Definition: FloatingPointConverter.cpp:43
armnn::optimizations::IsFloat16Layer
Definition: ConvertConstants.hpp:90
armnn::optimizations::IsFloat16Layer::Test
static bool Test(const Layer &layer)
Definition: ConvertConstants.hpp:92
armnn::BoostLogSeverityMapping::info
@ info
Half.hpp
armnn::Layer::GetDataType
DataType GetDataType() const
Definition: Layer.cpp:345
TensorHandle.hpp
armnnUtils::FloatingPointConverter::ConvertFloat32To16
static void ConvertFloat32To16(const float *srcFloat32Buffer, size_t numElements, void *dstFloat16Buffer)
Converts a buffer of FP32 values to FP16 and stores the result in the given dstFloat16Buffer.
Definition: FloatingPointConverter.cpp:17
armnn::optimizations::Float32ToFloat16
Definition: ConvertConstants.hpp:42
armnn::IgnoreUnused
void IgnoreUnused(Ts &&...)
Definition: IgnoreUnused.hpp:14
armnn::Optimization
Definition: Optimization.hpp:15
armnn::optimizations::Float32ToFloat16::Func
static void Func(std::shared_ptr< ConstTensorHandle > &handle)
Definition: ConvertConstants.hpp:44
FloatingPointConverter.hpp
armnn
Copyright (c) 2021 ARM Limited and Contributors.
Definition: 01_00_quick_start.dox:6
armnn::ConstTensor
A tensor defined by a TensorInfo (shape and data type) and an immutable backing store.
Definition: Tensor.hpp:329
armnn::ScopedTensorHandle
Definition: TensorHandle.hpp:115
armnn::optimizations::Float16ToFloat32::Func
static void Func(std::shared_ptr< ConstTensorHandle > &handle)
Definition: ConvertConstants.hpp:23
armnn::Graph
Definition: Graph.hpp:30